From ef6d6e96be91366234e907d6ae25c9bf855e48f0 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Wed, 17 Feb 2021 22:32:39 +0100 Subject: [PATCH 01/25] Add new black-check and black-format Make targets which verify code comforms to black formatting rules and run it on CI. --- Makefile | 59 ++++++++++++++++++++++++++++++++++++++++++- pyproject.toml | 13 ++++++++++ test-requirements.txt | 1 + 3 files changed, 72 insertions(+), 1 deletion(-) create mode 100644 pyproject.toml diff --git a/Makefile b/Makefile index 65ca6204ae0..abf89f450b4 100644 --- a/Makefile +++ b/Makefile @@ -326,6 +326,63 @@ schemasgen: requirements .schemasgen . $(VIRTUALENV_DIR)/bin/activate; pylint -j $(PYLINT_CONCURRENCY) -E --rcfile=./lint-configs/python/.pylintrc --load-plugins=pylint_plugins.api_models tools/*.py || exit 1; . $(VIRTUALENV_DIR)/bin/activate; pylint -j $(PYLINT_CONCURRENCY) -E --rcfile=./lint-configs/python/.pylintrc pylint_plugins/*.py || exit 1; +# Black task which checks if the code comforts to black code style +.PHONY: black-check +black: requirements .black-check + +.PHONY: .black-check +.black: + @echo + @echo "================== black-check ====================" + @echo + # st2 components + @for component in $(COMPONENTS); do\ + echo "==========================================================="; \ + echo "Running black on" $$component; \ + echo "==========================================================="; \ + . $(VIRTUALENV_DIR)/bin/activate ; black --check --config pyproject.toml $$component/ || exit 1; \ + done + # runner modules and packages + @for component in $(COMPONENTS_RUNNERS); do\ + echo "==========================================================="; \ + echo "Running black on" $$component; \ + echo "==========================================================="; \ + . $(VIRTUALENV_DIR)/bin/activate ; black --check --config pyproject.toml $$component/ || exit 1; \ + done + # Python pack management actions + . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml contrib/* || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml scripts/*.py || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml tools/*.py || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml pylint_plugins/*.py || exit 1; + +# Black task which reformats the code using black +.PHONY: black-format +black: requirements .black-format + +.PHONY: .black-format +.black-format: + @echo + @echo "================== black ====================" + @echo + # st2 components + @for component in $(COMPONENTS); do\ + echo "==========================================================="; \ + echo "Running black on" $$component; \ + echo "==========================================================="; \ + . $(VIRTUALENV_DIR)/bin/activate ; black --config pyproject.toml $$component/ || exit 1; \ + done + # runner modules and packages + @for component in $(COMPONENTS_RUNNERS); do\ + echo "==========================================================="; \ + echo "Running black on" $$component; \ + echo "==========================================================="; \ + . $(VIRTUALENV_DIR)/bin/activate ; black --config pyproject.toml $$component/ || exit 1; \ + done + . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml contrib/ || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml scripts/*.py || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml tools/*.py || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml pylint_plugins/*.py || exit 1; + .PHONY: lint-api-spec lint-api-spec: requirements .lint-api-spec @@ -979,7 +1036,7 @@ debs: ci: ci-checks ci-unit ci-integration ci-packs-tests .PHONY: ci-checks -ci-checks: .generated-files-check .pylint .flake8 check-requirements check-sdist-requirements .st2client-dependencies-check .st2common-circular-dependencies-check circle-lint-api-spec .rst-check .st2client-install-check check-python-packages +ci-checks: .generated-files-check .black-check .pylint .flake8 check-requirements check-sdist-requirements .st2client-dependencies-check .st2common-circular-dependencies-check circle-lint-api-spec .rst-check .st2client-install-check check-python-packages .PHONY: .rst-check .rst-check: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000000..4d034829943 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,13 @@ +[tool.black] +max-line-length = 100 +target_version = ['py36'] +include = '\.pyi?$' +exclude = ''' +( + /( + | \.git + | \.virtualenv + | __pycache__ + )/ +) +''' diff --git a/test-requirements.txt b/test-requirements.txt index 6ca0e9608df..b1909e45351 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -5,6 +5,7 @@ st2flake8==0.1.0 astroid==2.4.2 pylint==2.6.0 pylint-plugin-utils>=0.4 +black==20.8b1 bandit==1.5.1 ipython<6.0.0 isort>=4.2.5 From 8496bb2407b969f0937431992172b98b545f6756 Mon Sep 17 00:00:00 2001 From: StackStorm CodeFormat Date: Wed, 17 Feb 2021 22:34:26 +0100 Subject: [PATCH 02/25] Reformat all the code using black tool. --- .../actions/format_execution_result.py | 43 +- contrib/chatops/actions/match.py | 17 +- contrib/chatops/actions/match_and_execute.py | 27 +- contrib/chatops/tests/test_format_result.py | 28 +- contrib/core/actions/generate_uuid.py | 8 +- contrib/core/actions/inject_trigger.py | 13 +- contrib/core/actions/pause.py | 4 +- .../core/tests/test_action_inject_trigger.py | 26 +- contrib/core/tests/test_action_sendmail.py | 265 +-- contrib/core/tests/test_action_uuid.py | 6 +- contrib/examples/actions/noop.py | 4 +- contrib/examples/actions/print_config.py | 4 +- .../actions/print_to_stdout_and_stderr.py | 6 +- .../actions/python-mock-core-remote.py | 13 +- .../examples/actions/python-mock-create-vm.py | 9 +- .../actions/pythonactions/fibonacci.py | 5 +- ...loop_increase_index_and_check_condition.py | 6 +- .../forloop_parse_github_repos.py | 4 +- .../examples/actions/pythonactions/isprime.py | 13 +- .../pythonactions/json_string_to_object.py | 1 - .../actions/pythonactions/object_return.py | 3 +- .../pythonactions/print_python_environment.py | 11 +- .../pythonactions/print_python_version.py | 5 +- .../pythonactions/yaml_string_to_object.py | 1 - .../ubuntu_pkg_info/lib/datatransformer.py | 8 +- .../ubuntu_pkg_info/ubuntu_pkg_info.py | 11 +- contrib/examples/sensors/echo_flask_app.py | 21 +- contrib/examples/sensors/fibonacci_sensor.py | 15 +- contrib/hello_st2/sensors/sensor1.py | 10 +- contrib/linux/actions/checks/check_loadavg.py | 22 +- .../linux/actions/checks/check_processes.py | 29 +- contrib/linux/actions/dig.py | 28 +- contrib/linux/actions/service.py | 23 +- contrib/linux/actions/wait_for_ssh.py | 40 +- contrib/linux/sensors/file_watch_sensor.py | 24 +- contrib/linux/tests/test_action_dig.py | 18 +- contrib/packs/actions/get_config.py | 4 +- contrib/packs/actions/pack_mgmt/delete.py | 19 +- contrib/packs/actions/pack_mgmt/download.py | 93 +- .../packs/actions/pack_mgmt/get_installed.py | 33 +- .../pack_mgmt/get_pack_dependencies.py | 51 +- .../actions/pack_mgmt/get_pack_warnings.py | 6 +- contrib/packs/actions/pack_mgmt/register.py | 77 +- contrib/packs/actions/pack_mgmt/search.py | 48 +- .../actions/pack_mgmt/setup_virtualenv.py | 71 +- .../packs/actions/pack_mgmt/show_remote.py | 5 +- contrib/packs/actions/pack_mgmt/unload.py | 73 +- .../pack_mgmt/virtualenv_setup_prerun.py | 2 +- contrib/packs/tests/test_action_aliases.py | 60 +- contrib/packs/tests/test_action_download.py | 535 +++--- contrib/packs/tests/test_action_unload.py | 19 +- .../packs/tests/test_get_pack_dependencies.py | 103 +- contrib/packs/tests/test_get_pack_warnings.py | 45 +- .../tests/test_virtualenv_setup_prerun.py | 23 +- .../action_chain_runner/__init__.py | 2 +- .../action_chain_runner.py | 570 +++--- .../runners/action_chain_runner/dist_utils.py | 65 +- contrib/runners/action_chain_runner/setup.py | 30 +- .../tests/unit/test_actionchain.py | 835 +++++---- .../tests/unit/test_actionchain_cancel.py | 164 +- .../unit/test_actionchain_notifications.py | 56 +- .../unit/test_actionchain_params_rendering.py | 100 +- .../unit/test_actionchain_pause_resume.py | 596 +++--- .../announcement_runner/__init__.py | 2 +- .../announcement_runner.py | 35 +- .../runners/announcement_runner/dist_utils.py | 65 +- contrib/runners/announcement_runner/setup.py | 28 +- .../tests/unit/test_announcementrunner.py | 64 +- contrib/runners/http_runner/dist_utils.py | 65 +- .../http_runner/http_runner/__init__.py | 2 +- .../http_runner/http_runner/http_runner.py | 199 +- contrib/runners/http_runner/setup.py | 28 +- .../tests/unit/test_http_runner.py | 337 ++-- contrib/runners/inquirer_runner/dist_utils.py | 65 +- .../inquirer_runner/__init__.py | 2 +- .../inquirer_runner/inquirer_runner.py | 47 +- contrib/runners/inquirer_runner/setup.py | 28 +- .../tests/unit/test_inquirer_runner.py | 96 +- contrib/runners/local_runner/dist_utils.py | 65 +- .../local_runner/local_runner/__init__.py | 2 +- .../runners/local_runner/local_runner/base.py | 169 +- .../local_shell_command_runner.py | 36 +- .../local_runner/local_shell_script_runner.py | 42 +- contrib/runners/local_runner/setup.py | 32 +- .../tests/integration/test_localrunner.py | 453 ++--- contrib/runners/noop_runner/dist_utils.py | 65 +- .../noop_runner/noop_runner/__init__.py | 2 +- .../noop_runner/noop_runner/noop_runner.py | 25 +- contrib/runners/noop_runner/setup.py | 26 +- .../noop_runner/tests/unit/test_nooprunner.py | 14 +- contrib/runners/orquesta_runner/dist_utils.py | 65 +- .../orquesta_functions/runtime.py | 44 +- .../orquesta_functions/st2kv.py | 20 +- .../orquesta_runner/__init__.py | 2 +- .../orquesta_runner/orquesta_runner.py | 132 +- contrib/runners/orquesta_runner/setup.py | 90 +- .../test_wiring_functions_st2kv.py | 63 +- .../orquesta_runner/tests/unit/base.py | 12 +- .../orquesta_runner/tests/unit/test_basic.py | 339 ++-- .../orquesta_runner/tests/unit/test_cancel.py | 182 +- .../tests/unit/test_context.py | 285 +-- .../tests/unit/test_data_flow.py | 92 +- .../orquesta_runner/tests/unit/test_delay.py | 114 +- .../tests/unit/test_error_handling.py | 779 ++++---- .../tests/unit/test_functions_common.py | 214 +-- .../tests/unit/test_functions_st2kv.py | 118 +- .../tests/unit/test_functions_task.py | 208 ++- .../tests/unit/test_inquiries.py | 365 ++-- .../orquesta_runner/tests/unit/test_notify.py | 228 +-- .../tests/unit/test_output_schema.py | 112 +- .../tests/unit/test_pause_and_resume.py | 660 ++++--- .../tests/unit/test_policies.py | 107 +- .../orquesta_runner/tests/unit/test_rerun.py | 398 ++-- .../tests/unit/test_with_items.py | 293 +-- contrib/runners/python_runner/dist_utils.py | 65 +- .../python_runner/python_runner/__init__.py | 2 +- .../python_runner/python_action_wrapper.py | 196 +- .../python_runner/python_runner.py | 226 ++- contrib/runners/python_runner/setup.py | 26 +- .../test_python_action_process_wrapper.py | 91 +- .../integration/test_pythonrunner_behavior.py | 43 +- .../tests/unit/test_output_schema.py | 35 +- .../tests/unit/test_pythonrunner.py | 841 +++++---- contrib/runners/remote_runner/dist_utils.py | 65 +- .../remote_runner/remote_runner/__init__.py | 2 +- .../remote_runner/remote_command_runner.py | 68 +- .../remote_runner/remote_script_runner.py | 154 +- contrib/runners/remote_runner/setup.py | 32 +- contrib/runners/winrm_runner/dist_utils.py | 65 +- contrib/runners/winrm_runner/setup.py | 34 +- .../tests/unit/test_winrm_base.py | 1180 ++++++------ .../tests/unit/test_winrm_command_runner.py | 15 +- .../unit/test_winrm_ps_command_runner.py | 15 +- .../tests/unit/test_winrm_ps_script_runner.py | 29 +- .../winrm_runner/winrm_runner/__init__.py | 2 +- .../winrm_runner/winrm_runner/winrm_base.py | 218 ++- .../winrm_runner/winrm_command_runner.py | 18 +- .../winrm_runner/winrm_ps_command_runner.py | 18 +- .../winrm_runner/winrm_ps_script_runner.py | 20 +- lint-configs/python/.flake8 | 5 +- pylint_plugins/api_models.py | 40 +- pylint_plugins/db_models.py | 13 +- scripts/dist_utils.py | 65 +- scripts/dist_utils_old.py | 47 +- scripts/fixate-requirements.py | 122 +- st2actions/dist_utils.py | 65 +- st2actions/setup.py | 28 +- st2actions/st2actions/__init__.py | 2 +- st2actions/st2actions/cmd/actionrunner.py | 32 +- st2actions/st2actions/cmd/scheduler.py | 41 +- st2actions/st2actions/cmd/st2notifier.py | 27 +- st2actions/st2actions/cmd/workflow_engine.py | 23 +- st2actions/st2actions/config.py | 7 +- st2actions/st2actions/container/base.py | 230 ++- st2actions/st2actions/notifier/config.py | 15 +- st2actions/st2actions/notifier/notifier.py | 220 ++- st2actions/st2actions/policies/concurrency.py | 59 +- .../policies/concurrency_by_attr.py | 84 +- st2actions/st2actions/policies/retry.py | 118 +- st2actions/st2actions/runners/pythonrunner.py | 4 +- st2actions/st2actions/scheduler/config.py | 57 +- st2actions/st2actions/scheduler/entrypoint.py | 37 +- st2actions/st2actions/scheduler/handler.py | 199 +- st2actions/st2actions/worker.py | 160 +- st2actions/st2actions/workflows/config.py | 15 +- st2actions/st2actions/workflows/workflows.py | 62 +- st2actions/tests/unit/policies/test_base.py | 62 +- .../tests/unit/policies/test_concurrency.py | 240 ++- .../unit/policies/test_concurrency_by_attr.py | 249 ++- .../tests/unit/policies/test_retry_policy.py | 120 +- .../tests/unit/test_action_runner_worker.py | 13 +- .../tests/unit/test_actions_registrar.py | 178 +- st2actions/tests/unit/test_async_runner.py | 16 +- .../tests/unit/test_execution_cancellation.py | 190 +- st2actions/tests/unit/test_executions.py | 125 +- st2actions/tests/unit/test_notifier.py | 429 +++-- st2actions/tests/unit/test_parallel_ssh.py | 404 +++-- .../test_paramiko_remote_script_runner.py | 293 +-- st2actions/tests/unit/test_paramiko_ssh.py | 975 ++++++---- .../tests/unit/test_paramiko_ssh_runner.py | 212 ++- st2actions/tests/unit/test_policies.py | 98 +- .../tests/unit/test_polling_async_runner.py | 16 +- st2actions/tests/unit/test_queue_consumers.py | 65 +- st2actions/tests/unit/test_remote_runners.py | 25 +- .../tests/unit/test_runner_container.py | 278 +-- st2actions/tests/unit/test_scheduler.py | 113 +- .../tests/unit/test_scheduler_entrypoint.py | 41 +- st2actions/tests/unit/test_scheduler_retry.py | 103 +- st2actions/tests/unit/test_worker.py | 79 +- st2actions/tests/unit/test_workflow_engine.py | 140 +- st2api/dist_utils.py | 65 +- st2api/setup.py | 22 +- st2api/st2api/__init__.py | 2 +- st2api/st2api/app.py | 53 +- st2api/st2api/cmd/__init__.py | 2 +- st2api/st2api/cmd/api.py | 35 +- st2api/st2api/config.py | 49 +- st2api/st2api/controllers/base.py | 10 +- .../controllers/controller_transforms.py | 8 +- st2api/st2api/controllers/resource.py | 473 +++-- st2api/st2api/controllers/root.py | 16 +- st2api/st2api/controllers/v1/action_views.py | 155 +- st2api/st2api/controllers/v1/actionalias.py | 202 ++- .../st2api/controllers/v1/actionexecutions.py | 694 ++++--- st2api/st2api/controllers/v1/actions.py | 212 ++- .../st2api/controllers/v1/aliasexecution.py | 198 +- st2api/st2api/controllers/v1/auth.py | 116 +- .../st2api/controllers/v1/execution_views.py | 46 +- st2api/st2api/controllers/v1/inquiries.py | 90 +- st2api/st2api/controllers/v1/keyvalue.py | 222 ++- .../controllers/v1/pack_config_schemas.py | 24 +- st2api/st2api/controllers/v1/pack_configs.py | 61 +- st2api/st2api/controllers/v1/pack_views.py | 98 +- st2api/st2api/controllers/v1/packs.py | 252 +-- st2api/st2api/controllers/v1/policies.py | 215 ++- st2api/st2api/controllers/v1/rbac.py | 90 +- .../controllers/v1/rule_enforcement_views.py | 93 +- .../controllers/v1/rule_enforcements.py | 69 +- st2api/st2api/controllers/v1/rule_views.py | 122 +- st2api/st2api/controllers/v1/rules.py | 198 +- st2api/st2api/controllers/v1/ruletypes.py | 25 +- st2api/st2api/controllers/v1/runnertypes.py | 85 +- st2api/st2api/controllers/v1/sensors.py | 74 +- .../st2api/controllers/v1/service_registry.py | 30 +- st2api/st2api/controllers/v1/timers.py | 52 +- st2api/st2api/controllers/v1/traces.py | 60 +- st2api/st2api/controllers/v1/triggers.py | 374 ++-- st2api/st2api/controllers/v1/user.py | 28 +- st2api/st2api/controllers/v1/webhooks.py | 99 +- .../controllers/v1/workflow_inspection.py | 11 +- st2api/st2api/validation.py | 25 +- st2api/st2api/wsgi.py | 8 +- .../integration/test_gunicorn_configs.py | 24 +- st2api/tests/unit/controllers/test_root.py | 10 +- .../unit/controllers/v1/test_action_alias.py | 199 +- .../unit/controllers/v1/test_action_views.py | 290 +-- .../tests/unit/controllers/v1/test_actions.py | 689 +++---- .../controllers/v1/test_alias_execution.py | 374 ++-- st2api/tests/unit/controllers/v1/test_auth.py | 186 +- .../unit/controllers/v1/test_auth_api_keys.py | 271 +-- st2api/tests/unit/controllers/v1/test_base.py | 90 +- .../unit/controllers/v1/test_executions.py | 1470 ++++++++------- .../controllers/v1/test_executions_auth.py | 232 +-- .../v1/test_executions_descendants.py | 65 +- .../controllers/v1/test_executions_filters.py | 330 ++-- .../unit/controllers/v1/test_inquiries.py | 235 ++- st2api/tests/unit/controllers/v1/test_kvps.py | 639 +++---- .../controllers/v1/test_pack_config_schema.py | 39 +- .../unit/controllers/v1/test_pack_configs.py | 92 +- .../tests/unit/controllers/v1/test_packs.py | 697 +++---- .../unit/controllers/v1/test_packs_views.py | 94 +- .../unit/controllers/v1/test_policies.py | 161 +- .../v1/test_rule_enforcement_views.py | 126 +- .../controllers/v1/test_rule_enforcements.py | 70 +- .../unit/controllers/v1/test_rule_views.py | 65 +- .../tests/unit/controllers/v1/test_rules.py | 442 +++-- .../unit/controllers/v1/test_ruletypes.py | 24 +- .../unit/controllers/v1/test_runnertypes.py | 67 +- .../unit/controllers/v1/test_sensortypes.py | 103 +- .../controllers/v1/test_service_registry.py | 56 +- .../tests/unit/controllers/v1/test_timers.py | 61 +- .../tests/unit/controllers/v1/test_traces.py | 180 +- .../controllers/v1/test_triggerinstances.py | 195 +- .../unit/controllers/v1/test_triggers.py | 99 +- .../unit/controllers/v1/test_triggertypes.py | 67 +- .../unit/controllers/v1/test_webhooks.py | 441 +++-- .../v1/test_workflow_inspection.py | 77 +- st2api/tests/unit/test_validation_utils.py | 48 +- st2auth/dist_utils.py | 65 +- st2auth/setup.py | 28 +- st2auth/st2auth/__init__.py | 2 +- st2auth/st2auth/app.py | 46 +- st2auth/st2auth/backends/__init__.py | 21 +- st2auth/st2auth/backends/base.py | 12 +- st2auth/st2auth/backends/constants.py | 10 +- st2auth/st2auth/cmd/api.py | 48 +- st2auth/st2auth/config.py | 85 +- st2auth/st2auth/controllers/v1/auth.py | 42 +- st2auth/st2auth/controllers/v1/sso.py | 43 +- st2auth/st2auth/handlers.py | 182 +- st2auth/st2auth/sso/__init__.py | 24 +- st2auth/st2auth/sso/base.py | 8 +- st2auth/st2auth/sso/noop.py | 6 +- st2auth/st2auth/validation.py | 20 +- st2auth/st2auth/wsgi.py | 8 +- st2auth/tests/base.py | 1 - st2auth/tests/unit/controllers/v1/test_sso.py | 104 +- .../tests/unit/controllers/v1/test_token.py | 206 ++- st2auth/tests/unit/test_auth_backends.py | 2 +- st2auth/tests/unit/test_handlers.py | 166 +- st2auth/tests/unit/test_validation_utils.py | 47 +- st2client/dist_utils.py | 65 +- st2client/setup.py | 64 +- st2client/st2client/__init__.py | 2 +- st2client/st2client/base.py | 228 ++- st2client/st2client/client.py | 310 ++-- st2client/st2client/commands/__init__.py | 27 +- st2client/st2client/commands/action.py | 1604 ++++++++++------- st2client/st2client/commands/action_alias.py | 164 +- st2client/st2client/commands/auth.py | 417 +++-- st2client/st2client/commands/inquiry.py | 134 +- st2client/st2client/commands/keyvalue.py | 410 +++-- st2client/st2client/commands/pack.py | 508 ++++-- st2client/st2client/commands/policy.py | 84 +- st2client/st2client/commands/rbac.py | 172 +- st2client/st2client/commands/resource.py | 480 +++-- st2client/st2client/commands/rule.py | 166 +- .../st2client/commands/rule_enforcement.py | 194 +- st2client/st2client/commands/sensor.py | 66 +- .../st2client/commands/service_registry.py | 67 +- st2client/st2client/commands/timer.py | 35 +- st2client/st2client/commands/trace.py | 349 ++-- st2client/st2client/commands/trigger.py | 90 +- .../st2client/commands/triggerinstance.py | 194 +- st2client/st2client/commands/webhook.py | 40 +- st2client/st2client/commands/workflow.py | 45 +- st2client/st2client/config.py | 5 +- st2client/st2client/config_parser.py | 130 +- st2client/st2client/exceptions/base.py | 5 +- st2client/st2client/formatters/__init__.py | 4 +- st2client/st2client/formatters/doc.py | 25 +- st2client/st2client/formatters/execution.py | 74 +- st2client/st2client/formatters/table.py | 118 +- st2client/st2client/models/__init__.py | 24 +- st2client/st2client/models/action.py | 36 +- st2client/st2client/models/action_alias.py | 29 +- st2client/st2client/models/aliasexecution.py | 25 +- st2client/st2client/models/auth.py | 16 +- st2client/st2client/models/config.py | 16 +- st2client/st2client/models/core.py | 318 ++-- st2client/st2client/models/inquiry.py | 17 +- st2client/st2client/models/keyvalue.py | 10 +- st2client/st2client/models/pack.py | 10 +- st2client/st2client/models/policy.py | 14 +- st2client/st2client/models/rbac.py | 29 +- st2client/st2client/models/reactor.py | 56 +- .../st2client/models/service_registry.py | 35 +- st2client/st2client/models/timer.py | 8 +- st2client/st2client/models/trace.py | 10 +- st2client/st2client/models/webhook.py | 10 +- st2client/st2client/shell.py | 348 ++-- st2client/st2client/utils/color.py | 62 +- st2client/st2client/utils/date.py | 11 +- st2client/st2client/utils/httpclient.py | 50 +- st2client/st2client/utils/interactive.py | 218 +-- st2client/st2client/utils/jsutil.py | 13 +- st2client/st2client/utils/logging.py | 6 +- st2client/st2client/utils/misc.py | 4 +- st2client/st2client/utils/schema.py | 28 +- st2client/st2client/utils/strutil.py | 12 +- st2client/st2client/utils/terminal.py | 33 +- st2client/st2client/utils/types.py | 15 +- st2client/tests/base.py | 36 +- st2client/tests/fixtures/loader.py | 15 +- st2client/tests/unit/test_action.py | 716 ++++---- st2client/tests/unit/test_action_alias.py | 28 +- st2client/tests/unit/test_app.py | 18 +- st2client/tests/unit/test_auth.py | 472 ++--- st2client/tests/unit/test_client.py | 133 +- st2client/tests/unit/test_client_actions.py | 55 +- .../tests/unit/test_client_executions.py | 236 ++- .../tests/unit/test_command_actionrun.py | 207 ++- st2client/tests/unit/test_commands.py | 354 ++-- st2client/tests/unit/test_config_parser.py | 114 +- .../tests/unit/test_execution_tail_command.py | 437 ++--- st2client/tests/unit/test_formatters.py | 287 +-- st2client/tests/unit/test_inquiry.py | 274 +-- st2client/tests/unit/test_interactive.py | 372 ++-- st2client/tests/unit/test_keyvalue.py | 326 ++-- st2client/tests/unit/test_models.py | 275 ++- st2client/tests/unit/test_shell.py | 659 ++++--- st2client/tests/unit/test_ssl.py | 99 +- st2client/tests/unit/test_trace_commands.py | 237 ++- st2client/tests/unit/test_util_date.py | 24 +- st2client/tests/unit/test_util_json.py | 151 +- st2client/tests/unit/test_util_misc.py | 30 +- st2client/tests/unit/test_util_strutil.py | 8 +- st2client/tests/unit/test_util_terminal.py | 28 +- st2client/tests/unit/test_workflow.py | 87 +- ...grate-datastore-to-include-scope-secret.py | 26 +- .../v2.1/st2-migrate-datastore-scopes.py | 26 +- .../v3.1/st2-cleanup-policy-delayed.py | 10 +- st2common/bin/paramiko_ssh_evenlets_tester.py | 72 +- st2common/dist_utils.py | 65 +- st2common/setup.py | 62 +- st2common/st2common/__init__.py | 2 +- .../st2common/bootstrap/actionsregistrar.py | 120 +- .../st2common/bootstrap/aliasesregistrar.py | 103 +- st2common/st2common/bootstrap/base.py | 79 +- .../st2common/bootstrap/configsregistrar.py | 74 +- .../st2common/bootstrap/policiesregistrar.py | 91 +- .../st2common/bootstrap/rulesregistrar.py | 112 +- .../st2common/bootstrap/ruletypesregistrar.py | 37 +- .../st2common/bootstrap/runnersregistrar.py | 38 +- .../st2common/bootstrap/sensorsregistrar.py | 103 +- .../st2common/bootstrap/triggersregistrar.py | 90 +- st2common/st2common/callback/base.py | 3 +- st2common/st2common/cmd/download_pack.py | 50 +- st2common/st2common/cmd/generate_api_spec.py | 8 +- st2common/st2common/cmd/install_pack.py | 60 +- st2common/st2common/cmd/purge_executions.py | 43 +- .../st2common/cmd/purge_trigger_instances.py | 19 +- .../st2common/cmd/setup_pack_virtualenv.py | 48 +- st2common/st2common/cmd/validate_api_spec.py | 41 +- st2common/st2common/cmd/validate_config.py | 38 +- st2common/st2common/config.py | 797 ++++---- st2common/st2common/constants/action.py | 132 +- st2common/st2common/constants/api.py | 8 +- st2common/st2common/constants/auth.py | 30 +- .../st2common/constants/error_messages.py | 23 +- st2common/st2common/constants/exit_codes.py | 8 +- .../st2common/constants/garbage_collection.py | 8 +- st2common/st2common/constants/keyvalue.py | 57 +- st2common/st2common/constants/logging.py | 6 +- st2common/st2common/constants/meta.py | 9 +- st2common/st2common/constants/pack.py | 69 +- st2common/st2common/constants/policy.py | 9 +- .../st2common/constants/rule_enforcement.py | 13 +- st2common/st2common/constants/rules.py | 10 +- st2common/st2common/constants/runners.py | 54 +- st2common/st2common/constants/scheduler.py | 9 +- st2common/st2common/constants/secrets.py | 19 +- st2common/st2common/constants/sensors.py | 8 +- st2common/st2common/constants/system.py | 17 +- st2common/st2common/constants/timer.py | 9 +- st2common/st2common/constants/trace.py | 6 +- st2common/st2common/constants/triggers.py | 479 ++--- st2common/st2common/constants/types.py | 50 +- st2common/st2common/content/bootstrap.py | 220 +-- st2common/st2common/content/loader.py | 69 +- st2common/st2common/content/utils.py | 115 +- st2common/st2common/content/validators.py | 15 +- st2common/st2common/database_setup.py | 40 +- st2common/st2common/exceptions/__init__.py | 26 +- st2common/st2common/exceptions/action.py | 6 +- st2common/st2common/exceptions/actionalias.py | 4 +- st2common/st2common/exceptions/api.py | 3 +- st2common/st2common/exceptions/auth.py | 26 +- st2common/st2common/exceptions/connection.py | 3 + st2common/st2common/exceptions/db.py | 7 +- st2common/st2common/exceptions/inquiry.py | 17 +- st2common/st2common/exceptions/keyvalue.py | 6 +- st2common/st2common/exceptions/rbac.py | 53 +- st2common/st2common/exceptions/ssh.py | 4 +- st2common/st2common/exceptions/workflow.py | 38 +- .../st2common/expressions/functions/data.py | 32 +- .../expressions/functions/datastore.py | 12 +- .../st2common/expressions/functions/path.py | 5 +- .../st2common/expressions/functions/regex.py | 7 +- .../st2common/expressions/functions/time.py | 29 +- .../expressions/functions/version.py | 14 +- st2common/st2common/fields.py | 15 +- .../garbage_collection/executions.py | 147 +- .../st2common/garbage_collection/inquiries.py | 25 +- .../garbage_collection/trigger_instances.py | 34 +- st2common/st2common/log.py | 93 +- st2common/st2common/logging/filters.py | 15 +- st2common/st2common/logging/formatters.py | 71 +- st2common/st2common/logging/handlers.py | 33 +- st2common/st2common/logging/misc.py | 46 +- st2common/st2common/metrics/base.py | 32 +- .../st2common/metrics/drivers/echo_driver.py | 16 +- .../st2common/metrics/drivers/noop_driver.py | 4 +- .../metrics/drivers/statsd_driver.py | 51 +- st2common/st2common/metrics/utils.py | 9 +- st2common/st2common/middleware/cors.py | 45 +- .../st2common/middleware/error_handling.py | 34 +- .../st2common/middleware/instrumentation.py | 47 +- st2common/st2common/middleware/logging.py | 52 +- st2common/st2common/middleware/streaming.py | 8 +- st2common/st2common/models/api/action.py | 505 +++--- .../st2common/models/api/actionrunner.py | 13 +- st2common/st2common/models/api/auth.py | 134 +- st2common/st2common/models/api/base.py | 31 +- st2common/st2common/models/api/execution.py | 132 +- st2common/st2common/models/api/inquiry.py | 131 +- st2common/st2common/models/api/keyvalue.py | 188 +- .../st2common/models/api/notification.py | 81 +- st2common/st2common/models/api/pack.py | 355 ++-- st2common/st2common/models/api/policy.py | 153 +- st2common/st2common/models/api/rbac.py | 377 ++-- st2common/st2common/models/api/rule.py | 258 ++- .../st2common/models/api/rule_enforcement.py | 113 +- st2common/st2common/models/api/sensor.py | 61 +- st2common/st2common/models/api/tag.py | 12 +- st2common/st2common/models/api/trace.py | 188 +- st2common/st2common/models/api/trigger.py | 196 +- st2common/st2common/models/api/webhook.py | 15 +- st2common/st2common/models/base.py | 4 +- st2common/st2common/models/db/__init__.py | 386 ++-- st2common/st2common/models/db/action.py | 80 +- st2common/st2common/models/db/actionalias.py | 65 +- st2common/st2common/models/db/auth.py | 55 +- st2common/st2common/models/db/execution.py | 121 +- .../st2common/models/db/execution_queue.py | 48 +- .../st2common/models/db/executionstate.py | 23 +- st2common/st2common/models/db/keyvalue.py | 20 +- st2common/st2common/models/db/liveaction.py | 67 +- st2common/st2common/models/db/marker.py | 14 +- st2common/st2common/models/db/notification.py | 30 +- st2common/st2common/models/db/pack.py | 36 +- st2common/st2common/models/db/policy.py | 93 +- st2common/st2common/models/db/rbac.py | 52 +- st2common/st2common/models/db/reactor.py | 19 +- st2common/st2common/models/db/rule.py | 78 +- .../st2common/models/db/rule_enforcement.py | 56 +- st2common/st2common/models/db/runner.py | 38 +- st2common/st2common/models/db/sensor.py | 33 +- st2common/st2common/models/db/stormbase.py | 90 +- st2common/st2common/models/db/timer.py | 4 +- st2common/st2common/models/db/trace.py | 70 +- st2common/st2common/models/db/trigger.py | 65 +- st2common/st2common/models/db/webhook.py | 4 +- st2common/st2common/models/db/workflow.py | 40 +- st2common/st2common/models/system/action.py | 387 ++-- .../st2common/models/system/actionchain.py | 93 +- st2common/st2common/models/system/common.py | 26 +- st2common/st2common/models/system/keyvalue.py | 6 +- .../models/system/paramiko_command_action.py | 26 +- .../models/system/paramiko_script_action.py | 52 +- .../models/utils/action_alias_utils.py | 137 +- .../models/utils/action_param_utils.py | 65 +- st2common/st2common/models/utils/profiling.py | 75 +- .../models/utils/sensor_type_utils.py | 100 +- st2common/st2common/operators.py | 162 +- st2common/st2common/persistence/action.py | 12 +- st2common/st2common/persistence/auth.py | 34 +- st2common/st2common/persistence/base.py | 104 +- st2common/st2common/persistence/cleanup.py | 57 +- st2common/st2common/persistence/db_init.py | 58 +- st2common/st2common/persistence/execution.py | 4 +- .../st2common/persistence/execution_queue.py | 4 +- .../st2common/persistence/executionstate.py | 8 +- st2common/st2common/persistence/keyvalue.py | 70 +- st2common/st2common/persistence/liveaction.py | 4 +- st2common/st2common/persistence/marker.py | 4 +- st2common/st2common/persistence/pack.py | 6 +- st2common/st2common/persistence/policy.py | 12 +- st2common/st2common/persistence/rbac.py | 7 +- st2common/st2common/persistence/reactor.py | 10 +- st2common/st2common/persistence/rule.py | 2 +- st2common/st2common/persistence/runner.py | 2 +- st2common/st2common/persistence/sensor.py | 4 +- st2common/st2common/persistence/trace.py | 10 +- st2common/st2common/persistence/trigger.py | 20 +- st2common/st2common/persistence/workflow.py | 5 +- st2common/st2common/policies/__init__.py | 5 +- st2common/st2common/policies/base.py | 13 +- st2common/st2common/policies/concurrency.py | 15 +- st2common/st2common/rbac/backends/__init__.py | 12 +- st2common/st2common/rbac/backends/base.py | 20 +- st2common/st2common/rbac/backends/noop.py | 16 +- st2common/st2common/rbac/migrations.py | 8 +- st2common/st2common/rbac/types.py | 598 +++--- st2common/st2common/router.py | 481 +++-- st2common/st2common/runners/__init__.py | 9 +- st2common/st2common/runners/base.py | 226 +-- st2common/st2common/runners/base_action.py | 19 +- st2common/st2common/runners/parallel_ssh.py | 229 ++- st2common/st2common/runners/paramiko_ssh.py | 326 ++-- .../st2common/runners/paramiko_ssh_runner.py | 158 +- st2common/st2common/runners/utils.py | 75 +- st2common/st2common/script_setup.py | 22 +- st2common/st2common/service_setup.py | 85 +- st2common/st2common/services/access.py | 32 +- st2common/st2common/services/action.py | 213 ++- st2common/st2common/services/config.py | 6 +- st2common/st2common/services/coordination.py | 57 +- st2common/st2common/services/datastore.py | 61 +- st2common/st2common/services/executions.py | 141 +- st2common/st2common/services/inquiry.py | 36 +- st2common/st2common/services/keyvalues.py | 69 +- st2common/st2common/services/packs.py | 98 +- st2common/st2common/services/policies.py | 43 +- st2common/st2common/services/queries.py | 8 +- st2common/st2common/services/rules.py | 15 +- .../st2common/services/sensor_watcher.py | 57 +- st2common/st2common/services/trace.py | 161 +- .../st2common/services/trigger_dispatcher.py | 54 +- st2common/st2common/services/triggers.py | 293 +-- .../st2common/services/triggerwatcher.py | 73 +- st2common/st2common/services/workflows.py | 748 ++++---- st2common/st2common/signal_handlers.py | 2 +- st2common/st2common/stream/listener.py | 124 +- st2common/st2common/transport/__init__.py | 16 +- .../transport/actionexecutionstate.py | 12 +- st2common/st2common/transport/announcement.py | 30 +- st2common/st2common/transport/bootstrap.py | 7 +- .../st2common/transport/bootstrap_utils.py | 93 +- .../transport/connection_retry_wrapper.py | 41 +- st2common/st2common/transport/consumers.py | 76 +- st2common/st2common/transport/execution.py | 35 +- st2common/st2common/transport/liveaction.py | 16 +- st2common/st2common/transport/publishers.py | 57 +- st2common/st2common/transport/queues.py | 109 +- st2common/st2common/transport/reactor.py | 32 +- st2common/st2common/transport/utils.py | 68 +- st2common/st2common/transport/workflow.py | 26 +- st2common/st2common/triggers.py | 68 +- st2common/st2common/util/action_db.py | 216 ++- .../st2common/util/actionalias_helpstring.py | 24 +- .../st2common/util/actionalias_matching.py | 113 +- st2common/st2common/util/api.py | 8 +- st2common/st2common/util/argument_parser.py | 40 +- st2common/st2common/util/auth.py | 33 +- st2common/st2common/util/casts.py | 12 +- st2common/st2common/util/compat.py | 13 +- st2common/st2common/util/concurrency.py | 102 +- st2common/st2common/util/config_loader.py | 83 +- st2common/st2common/util/config_parser.py | 14 +- st2common/st2common/util/crypto.py | 156 +- st2common/st2common/util/date.py | 11 +- st2common/st2common/util/debugging.py | 6 +- st2common/st2common/util/deprecation.py | 8 +- st2common/st2common/util/driver_loader.py | 17 +- st2common/st2common/util/enum.py | 11 +- st2common/st2common/util/file_system.py | 11 +- st2common/st2common/util/green/shell.py | 103 +- st2common/st2common/util/greenpooldispatch.py | 39 +- st2common/st2common/util/gunicorn_workers.py | 6 +- st2common/st2common/util/hash.py | 6 +- st2common/st2common/util/http.py | 27 +- st2common/st2common/util/ip_utils.py | 30 +- st2common/st2common/util/isotime.py | 30 +- st2common/st2common/util/jinja.py | 93 +- st2common/st2common/util/jsonify.py | 26 +- st2common/st2common/util/keyvalue.py | 43 +- st2common/st2common/util/loader.py | 71 +- st2common/st2common/util/misc.py | 48 +- st2common/st2common/util/mongoescape.py | 19 +- st2common/st2common/util/monkey_patch.py | 24 +- st2common/st2common/util/output_schema.py | 33 +- st2common/st2common/util/pack.py | 128 +- st2common/st2common/util/pack_management.py | 267 +-- st2common/st2common/util/param.py | 193 +- st2common/st2common/util/payload.py | 5 +- st2common/st2common/util/queues.py | 8 +- st2common/st2common/util/reference.py | 25 +- st2common/st2common/util/sandboxing.py | 86 +- st2common/st2common/util/schema/__init__.py | 327 ++-- st2common/st2common/util/secrets.py | 38 +- st2common/st2common/util/service.py | 4 +- st2common/st2common/util/shell.py | 44 +- st2common/st2common/util/spec_loader.py | 39 +- st2common/st2common/util/system_info.py | 14 +- st2common/st2common/util/templating.py | 12 +- st2common/st2common/util/types.py | 15 +- st2common/st2common/util/uid.py | 8 +- st2common/st2common/util/ujson.py | 4 +- st2common/st2common/util/url.py | 6 +- st2common/st2common/util/versioning.py | 17 +- st2common/st2common/util/virtualenvs.py | 188 +- st2common/st2common/util/wsgi.py | 8 +- st2common/st2common/validators/api/action.py | 75 +- st2common/st2common/validators/api/misc.py | 8 +- st2common/st2common/validators/api/reactor.py | 93 +- .../st2common/validators/workflow/base.py | 1 - .../tests/fixtures/mock_runner/mock_runner.py | 11 +- st2common/tests/fixtures/version_file.py | 2 +- .../integration/test_rabbitmq_ssl_listener.py | 183 +- .../test_register_content_script.py | 112 +- .../test_service_setup_log_level_filtering.py | 83 +- st2common/tests/unit/base.py | 39 +- st2common/tests/unit/services/test_access.py | 26 +- st2common/tests/unit/services/test_action.py | 384 ++-- .../tests/unit/services/test_keyvalue.py | 23 +- st2common/tests/unit/services/test_policy.py | 51 +- .../unit/services/test_synchronization.py | 8 +- st2common/tests/unit/services/test_trace.py | 635 ++++--- .../test_trace_injection_action_services.py | 46 +- .../tests/unit/services/test_workflow.py | 263 +-- .../services/test_workflow_cancellation.py | 52 +- .../test_workflow_identify_orphans.py | 189 +- .../unit/services/test_workflow_rerun.py | 214 ++- .../services/test_workflow_service_retries.py | 205 ++- .../tests/unit/test_action_alias_utils.py | 267 +-- .../tests/unit/test_action_api_validator.py | 108 +- st2common/tests/unit/test_action_db_utils.py | 506 +++--- .../tests/unit/test_action_param_utils.py | 106 +- .../tests/unit/test_action_system_models.py | 81 +- .../tests/unit/test_actionchain_schema.py | 44 +- st2common/tests/unit/test_aliasesregistrar.py | 14 +- .../tests/unit/test_api_model_validation.py | 245 +-- st2common/tests/unit/test_casts.py | 16 +- st2common/tests/unit/test_config_loader.py | 503 +++--- st2common/tests/unit/test_config_parser.py | 14 +- .../tests/unit/test_configs_registrar.py | 186 +- .../unit/test_connection_retry_wrapper.py | 13 +- st2common/tests/unit/test_content_loader.py | 67 +- st2common/tests/unit/test_content_utils.py | 277 +-- st2common/tests/unit/test_crypto_utils.py | 166 +- st2common/tests/unit/test_datastore.py | 103 +- st2common/tests/unit/test_date_utils.py | 30 +- st2common/tests/unit/test_db.py | 682 ++++--- st2common/tests/unit/test_db_action_state.py | 6 +- st2common/tests/unit/test_db_auth.py | 35 +- st2common/tests/unit/test_db_base.py | 59 +- st2common/tests/unit/test_db_execution.py | 144 +- st2common/tests/unit/test_db_fields.py | 10 +- st2common/tests/unit/test_db_liveaction.py | 89 +- st2common/tests/unit/test_db_marker.py | 11 +- st2common/tests/unit/test_db_model_uids.py | 78 +- st2common/tests/unit/test_db_pack.py | 24 +- st2common/tests/unit/test_db_policy.py | 203 ++- st2common/tests/unit/test_db_rbac.py | 46 +- .../tests/unit/test_db_rule_enforcement.py | 85 +- st2common/tests/unit/test_db_task.py | 55 +- st2common/tests/unit/test_db_trace.py | 149 +- st2common/tests/unit/test_db_uid_mixin.py | 41 +- st2common/tests/unit/test_db_workflow.py | 35 +- st2common/tests/unit/test_dist_utils.py | 94 +- .../tests/unit/test_exceptions_workflow.py | 5 +- st2common/tests/unit/test_executions.py | 269 +-- st2common/tests/unit/test_executions_util.py | 254 +-- .../tests/unit/test_greenpooldispatch.py | 17 +- st2common/tests/unit/test_hash.py | 5 +- st2common/tests/unit/test_ip_utils.py | 43 +- st2common/tests/unit/test_isotime_utils.py | 102 +- .../unit/test_jinja_render_crypto_filters.py | 103 +- .../unit/test_jinja_render_data_filters.py | 59 +- .../test_jinja_render_json_escape_filters.py | 41 +- ...est_jinja_render_jsonpath_query_filters.py | 59 +- .../unit/test_jinja_render_path_filters.py | 29 +- .../unit/test_jinja_render_regex_filters.py | 47 +- .../unit/test_jinja_render_time_filters.py | 16 +- .../unit/test_jinja_render_version_filters.py | 97 +- st2common/tests/unit/test_json_schema.py | 478 ++--- st2common/tests/unit/test_jsonify.py | 19 +- st2common/tests/unit/test_keyvalue_lookup.py | 151 +- .../tests/unit/test_keyvalue_system_model.py | 22 +- st2common/tests/unit/test_logger.py | 382 ++-- st2common/tests/unit/test_logging.py | 22 +- .../tests/unit/test_logging_middleware.py | 56 +- st2common/tests/unit/test_metrics.py | 214 ++- st2common/tests/unit/test_misc_utils.py | 118 +- .../tests/unit/test_model_utils_profiling.py | 22 +- st2common/tests/unit/test_mongoescape.py | 90 +- .../tests/unit/test_notification_helper.py | 133 +- st2common/tests/unit/test_operators.py | 1329 +++++++------- ...st_pack_action_alias_unit_testing_utils.py | 152 +- st2common/tests/unit/test_pack_management.py | 24 +- st2common/tests/unit/test_param_utils.py | 1100 ++++++----- .../test_paramiko_command_action_model.py | 107 +- .../unit/test_paramiko_script_action_model.py | 166 +- st2common/tests/unit/test_persistence.py | 112 +- .../unit/test_persistence_change_revision.py | 17 +- st2common/tests/unit/test_plugin_loader.py | 65 +- st2common/tests/unit/test_policies.py | 46 +- .../tests/unit/test_policies_registrar.py | 109 +- st2common/tests/unit/test_purge_executions.py | 287 +-- .../unit/test_purge_trigger_instances.py | 43 +- st2common/tests/unit/test_queue_consumer.py | 31 +- st2common/tests/unit/test_queue_utils.py | 53 +- st2common/tests/unit/test_rbac_types.py | 400 ++-- st2common/tests/unit/test_reference.py | 33 +- .../unit/test_register_internal_trigger.py | 5 +- .../tests/unit/test_resource_reference.py | 71 +- .../tests/unit/test_resource_registrar.py | 182 +- st2common/tests/unit/test_runners_base.py | 11 +- st2common/tests/unit/test_runners_utils.py | 28 +- .../tests/unit/test_sensor_type_utils.py | 60 +- st2common/tests/unit/test_sensor_watcher.py | 37 +- st2common/tests/unit/test_service_setup.py | 118 +- .../unit/test_shell_action_system_model.py | 494 ++--- st2common/tests/unit/test_state_publisher.py | 17 +- st2common/tests/unit/test_stream_generator.py | 21 +- st2common/tests/unit/test_system_info.py | 5 +- st2common/tests/unit/test_tags.py | 54 +- .../tests/unit/test_time_jinja_filters.py | 22 +- st2common/tests/unit/test_transport.py | 66 +- st2common/tests/unit/test_trigger_services.py | 247 +-- .../tests/unit/test_triggers_registrar.py | 22 +- .../tests/unit/test_unit_testing_mocks.py | 71 +- .../unit/test_util_actionalias_helpstrings.py | 161 +- .../unit/test_util_actionalias_matching.py | 142 +- st2common/tests/unit/test_util_api.py | 41 +- st2common/tests/unit/test_util_compat.py | 12 +- st2common/tests/unit/test_util_db.py | 81 +- st2common/tests/unit/test_util_file_system.py | 30 +- st2common/tests/unit/test_util_http.py | 20 +- st2common/tests/unit/test_util_jinja.py | 106 +- st2common/tests/unit/test_util_keyvalue.py | 108 +- .../tests/unit/test_util_output_schema.py | 60 +- st2common/tests/unit/test_util_pack.py | 44 +- st2common/tests/unit/test_util_payload.py | 28 +- st2common/tests/unit/test_util_sandboxing.py | 213 ++- st2common/tests/unit/test_util_secrets.py | 1140 +++++------- st2common/tests/unit/test_util_shell.py | 32 +- st2common/tests/unit/test_util_templating.py | 36 +- st2common/tests/unit/test_util_types.py | 4 +- st2common/tests/unit/test_util_url.py | 16 +- st2common/tests/unit/test_versioning_utils.py | 54 +- st2common/tests/unit/test_virtualenvs.py | 341 ++-- st2exporter/dist_utils.py | 65 +- st2exporter/setup.py | 22 +- .../st2exporter/cmd/st2exporter_starter.py | 20 +- st2exporter/st2exporter/config.py | 23 +- st2exporter/st2exporter/exporter/dumper.py | 67 +- .../st2exporter/exporter/file_writer.py | 12 +- .../st2exporter/exporter/json_converter.py | 7 +- st2exporter/st2exporter/worker.py | 52 +- .../integration/test_dumper_integration.py | 63 +- .../tests/integration/test_export_worker.py | 55 +- st2exporter/tests/unit/test_dumper.py | 133 +- st2exporter/tests/unit/test_json_converter.py | 29 +- st2reactor/dist_utils.py | 65 +- st2reactor/setup.py | 32 +- st2reactor/st2reactor/__init__.py | 2 +- st2reactor/st2reactor/cmd/garbagecollector.py | 31 +- st2reactor/st2reactor/cmd/rule_tester.py | 40 +- st2reactor/st2reactor/cmd/rulesengine.py | 28 +- st2reactor/st2reactor/cmd/sensormanager.py | 42 +- st2reactor/st2reactor/cmd/timersengine.py | 29 +- st2reactor/st2reactor/cmd/trigger_re_fire.py | 58 +- .../st2reactor/container/hash_partitioner.py | 31 +- st2reactor/st2reactor/container/manager.py | 90 +- .../container/partitioner_lookup.py | 35 +- .../st2reactor/container/partitioners.py | 49 +- .../st2reactor/container/process_container.py | 193 +- .../st2reactor/container/sensor_wrapper.py | 237 ++- st2reactor/st2reactor/container/utils.py | 10 +- .../st2reactor/garbage_collector/base.py | 143 +- .../st2reactor/garbage_collector/config.py | 65 +- st2reactor/st2reactor/rules/config.py | 15 +- st2reactor/st2reactor/rules/enforcer.py | 146 +- st2reactor/st2reactor/rules/engine.py | 42 +- st2reactor/st2reactor/rules/filter.py | 157 +- st2reactor/st2reactor/rules/matcher.py | 57 +- st2reactor/st2reactor/rules/tester.py | 103 +- st2reactor/st2reactor/rules/worker.py | 46 +- st2reactor/st2reactor/sensor/base.py | 9 +- st2reactor/st2reactor/sensor/config.py | 63 +- st2reactor/st2reactor/timer/base.py | 111 +- st2reactor/st2reactor/timer/config.py | 7 +- .../integration/test_garbage_collector.py | 218 ++- .../tests/integration/test_rules_engine.py | 56 +- .../integration/test_sensor_container.py | 90 +- .../tests/integration/test_sensor_watcher.py | 27 +- st2reactor/tests/unit/test_container_utils.py | 63 +- st2reactor/tests/unit/test_enforce.py | 535 +++--- st2reactor/tests/unit/test_filter.py | 325 ++-- .../tests/unit/test_garbage_collector.py | 58 +- .../tests/unit/test_hash_partitioner.py | 53 +- st2reactor/tests/unit/test_partitioners.py | 82 +- .../tests/unit/test_process_container.py | 170 +- st2reactor/tests/unit/test_rule_engine.py | 208 +-- st2reactor/tests/unit/test_rule_matcher.py | 310 ++-- .../unit/test_sensor_and_rule_registration.py | 81 +- st2reactor/tests/unit/test_sensor_service.py | 203 ++- st2reactor/tests/unit/test_sensor_wrapper.py | 172 +- st2reactor/tests/unit/test_tester.py | 90 +- st2reactor/tests/unit/test_timer.py | 26 +- st2stream/dist_utils.py | 65 +- st2stream/setup.py | 22 +- st2stream/st2stream/__init__.py | 2 +- st2stream/st2stream/app.py | 47 +- st2stream/st2stream/cmd/__init__.py | 2 +- st2stream/st2stream/cmd/api.py | 49 +- st2stream/st2stream/config.py | 27 +- .../st2stream/controllers/v1/executions.py | 78 +- st2stream/st2stream/controllers/v1/root.py | 4 +- st2stream/st2stream/controllers/v1/stream.py | 70 +- st2stream/st2stream/signal_handlers.py | 4 +- st2stream/st2stream/wsgi.py | 7 +- st2stream/tests/unit/controllers/v1/base.py | 4 +- .../tests/unit/controllers/v1/test_stream.py | 202 +-- .../v1/test_stream_execution_output.py | 138 +- st2tests/dist_utils.py | 65 +- st2tests/integration/orquesta/base.py | 55 +- .../integration/orquesta/test_performance.py | 23 +- st2tests/integration/orquesta/test_wiring.py | 112 +- .../orquesta/test_wiring_cancel.py | 56 +- .../orquesta/test_wiring_data_flow.py | 42 +- .../integration/orquesta/test_wiring_delay.py | 25 +- .../orquesta/test_wiring_error_handling.py | 349 ++-- .../orquesta/test_wiring_functions.py | 211 +-- .../orquesta/test_wiring_functions_st2kv.py | 150 +- .../orquesta/test_wiring_functions_task.py | 81 +- .../orquesta/test_wiring_inquiry.py | 57 +- .../orquesta/test_wiring_pause_and_resume.py | 122 +- .../integration/orquesta/test_wiring_rerun.py | 90 +- .../orquesta/test_wiring_task_retry.py | 22 +- .../orquesta/test_wiring_with_items.py | 89 +- st2tests/setup.py | 20 +- st2tests/st2tests/__init__.py | 12 +- st2tests/st2tests/action_aliases.py | 45 +- st2tests/st2tests/actions.py | 14 +- st2tests/st2tests/api.py | 173 +- st2tests/st2tests/base.py | 277 +-- st2tests/st2tests/config.py | 408 +++-- .../fixtures/history_views/__init__.py | 4 +- .../localrunner_pack/actions/text_gen.py | 8 +- .../actions/render_config_context.py | 1 - .../dummy_pack_9/actions/invalid_syntax.py | 4 +- .../fixtures/packs/executions/__init__.py | 8 +- .../test_async_runner/test_async_runner.py | 16 +- .../test_polling_async_runner.py | 16 +- .../actions/get_library_path.py | 4 +- st2tests/st2tests/fixturesloader.py | 234 ++- st2tests/st2tests/http.py | 1 - st2tests/st2tests/mocks/action.py | 13 +- st2tests/st2tests/mocks/auth.py | 18 +- st2tests/st2tests/mocks/datastore.py | 29 +- st2tests/st2tests/mocks/execution.py | 6 +- st2tests/st2tests/mocks/liveaction.py | 7 +- .../st2tests/mocks/runners/async_runner.py | 16 +- .../mocks/runners/polling_async_runner.py | 16 +- st2tests/st2tests/mocks/runners/runner.py | 30 +- st2tests/st2tests/mocks/sensor.py | 24 +- st2tests/st2tests/mocks/workflow.py | 5 +- st2tests/st2tests/pack_resource.py | 10 +- st2tests/st2tests/policies/concurrency.py | 12 +- st2tests/st2tests/policies/mock_exception.py | 3 +- .../packs/pythonactions/actions/echoer.py | 2 +- .../pythonactions/actions/non_simple_type.py | 8 +- .../packs/pythonactions/actions/pascal_row.py | 44 +- .../actions/print_config_item_doesnt_exist.py | 4 +- .../actions/print_to_stdout_and_stderr.py | 4 +- .../pythonactions/actions/python_paths.py | 4 +- .../packs/pythonactions/actions/test.py | 2 +- st2tests/st2tests/sensors.py | 20 +- .../checks/actions/checks/check_loadavg.py | 28 +- tools/config_gen.py | 130 +- tools/diff-db-disk.py | 200 +- tools/direct_queue_publisher.py | 25 +- tools/enumerate-runners.py | 9 +- tools/json2yaml.py | 39 +- tools/list_group_members.py | 26 +- tools/log_watcher.py | 90 +- tools/migrate_messaging_setup.py | 14 +- tools/migrate_rules_to_include_pack.py | 49 +- .../migrate_triggers_to_include_ref_count.py | 7 +- tools/queue_consumer.py | 39 +- tools/queue_producer.py | 20 +- tools/st2-analyze-links.py | 83 +- tools/st2-inject-trigger-instances.py | 90 +- tools/visualize_action_chain.py | 101 +- 937 files changed, 54139 insertions(+), 42097 deletions(-) diff --git a/contrib/chatops/actions/format_execution_result.py b/contrib/chatops/actions/format_execution_result.py index 8790ae4ae73..d6830df004d 100755 --- a/contrib/chatops/actions/format_execution_result.py +++ b/contrib/chatops/actions/format_execution_result.py @@ -23,51 +23,50 @@ class FormatResultAction(Action): def __init__(self, config=None, action_service=None): - super(FormatResultAction, self).__init__(config=config, action_service=action_service) - api_url = os.environ.get('ST2_ACTION_API_URL', None) - token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None) + super(FormatResultAction, self).__init__( + config=config, action_service=action_service + ) + api_url = os.environ.get("ST2_ACTION_API_URL", None) + token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None) self.client = Client(api_url=api_url, token=token) self.jinja = jinja_utils.get_jinja_environment(allow_undefined=True) - self.jinja.tests['in'] = lambda item, list: item in list + self.jinja.tests["in"] = lambda item, list: item in list path = os.path.dirname(os.path.realpath(__file__)) - with open(os.path.join(path, 'templates/default.j2'), 'r') as f: + with open(os.path.join(path, "templates/default.j2"), "r") as f: self.default_template = f.read() def run(self, execution_id): execution = self._get_execution(execution_id) - context = { - 'six': six, - 'execution': execution - } + context = {"six": six, "execution": execution} template = self.default_template result = {"enabled": True} - alias_id = execution['context'].get('action_alias_ref', {}).get('id', None) + alias_id = execution["context"].get("action_alias_ref", {}).get("id", None) if alias_id: - alias = self.client.managers['ActionAlias'].get_by_id(alias_id) + alias = self.client.managers["ActionAlias"].get_by_id(alias_id) - context.update({ - 'alias': alias - }) + context.update({"alias": alias}) - result_params = getattr(alias, 'result', None) + result_params = getattr(alias, "result", None) if result_params: - if not result_params.get('enabled', True): + if not result_params.get("enabled", True): result["enabled"] = False else: - if 'format' in alias.result: - template = alias.result['format'] - if 'extra' in alias.result: - result['extra'] = jinja_utils.render_values(alias.result['extra'], context) + if "format" in alias.result: + template = alias.result["format"] + if "extra" in alias.result: + result["extra"] = jinja_utils.render_values( + alias.result["extra"], context + ) - result['message'] = self.jinja.from_string(template).render(context) + result["message"] = self.jinja.from_string(template).render(context) return result def _get_execution(self, execution_id): if not execution_id: - raise ValueError('Invalid execution_id provided.') + raise ValueError("Invalid execution_id provided.") execution = self.client.liveactions.get_by_id(id=execution_id) if not execution: return None diff --git a/contrib/chatops/actions/match.py b/contrib/chatops/actions/match.py index 46dac1ff648..7ee2154b425 100644 --- a/contrib/chatops/actions/match.py +++ b/contrib/chatops/actions/match.py @@ -23,23 +23,16 @@ class MatchAction(Action): def __init__(self, config=None): super(MatchAction, self).__init__(config=config) - api_url = os.environ.get('ST2_ACTION_API_URL', None) - token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None) + api_url = os.environ.get("ST2_ACTION_API_URL", None) + token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None) self.client = Client(api_url=api_url, token=token) def run(self, text): alias_match = ActionAliasMatch() alias_match.command = text - matches = self.client.managers['ActionAlias'].match(alias_match) - return { - 'alias': _format_match(matches[0]), - 'representation': matches[1] - } + matches = self.client.managers["ActionAlias"].match(alias_match) + return {"alias": _format_match(matches[0]), "representation": matches[1]} def _format_match(match): - return { - 'name': match.name, - 'pack': match.pack, - 'action_ref': match.action_ref - } + return {"name": match.name, "pack": match.pack, "action_ref": match.action_ref} diff --git a/contrib/chatops/actions/match_and_execute.py b/contrib/chatops/actions/match_and_execute.py index 11388e599b6..5e90080f031 100644 --- a/contrib/chatops/actions/match_and_execute.py +++ b/contrib/chatops/actions/match_and_execute.py @@ -19,25 +19,26 @@ from st2common.runners.base_action import Action from st2client.models.action_alias import ActionAliasMatch from st2client.models.aliasexecution import ActionAliasExecution -from st2client.commands.action import (LIVEACTION_STATUS_REQUESTED, - LIVEACTION_STATUS_SCHEDULED, - LIVEACTION_STATUS_RUNNING, - LIVEACTION_STATUS_CANCELING) +from st2client.commands.action import ( + LIVEACTION_STATUS_REQUESTED, + LIVEACTION_STATUS_SCHEDULED, + LIVEACTION_STATUS_RUNNING, + LIVEACTION_STATUS_CANCELING, +) from st2client.client import Client class ExecuteActionAliasAction(Action): def __init__(self, config=None): super(ExecuteActionAliasAction, self).__init__(config=config) - api_url = os.environ.get('ST2_ACTION_API_URL', None) - token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None) + api_url = os.environ.get("ST2_ACTION_API_URL", None) + token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None) self.client = Client(api_url=api_url, token=token) def run(self, text, source_channel=None, user=None): alias_match = ActionAliasMatch() alias_match.command = text - alias, representation = self.client.managers['ActionAlias'].match( - alias_match) + alias, representation = self.client.managers["ActionAlias"].match(alias_match) execution = ActionAliasExecution() execution.name = alias.name @@ -48,20 +49,20 @@ def run(self, text, source_channel=None, user=None): execution.notification_route = None execution.user = user - action_exec_mgr = self.client.managers['ActionAliasExecution'] + action_exec_mgr = self.client.managers["ActionAliasExecution"] execution = action_exec_mgr.create(execution) - self._wait_execution_to_finish(execution.execution['id']) - return execution.execution['id'] + self._wait_execution_to_finish(execution.execution["id"]) + return execution.execution["id"] def _wait_execution_to_finish(self, execution_id): pending_statuses = [ LIVEACTION_STATUS_REQUESTED, LIVEACTION_STATUS_SCHEDULED, LIVEACTION_STATUS_RUNNING, - LIVEACTION_STATUS_CANCELING + LIVEACTION_STATUS_CANCELING, ] - action_exec_mgr = self.client.managers['LiveAction'] + action_exec_mgr = self.client.managers["LiveAction"] execution = action_exec_mgr.get_by_id(execution_id) while execution.status in pending_statuses: time.sleep(1) diff --git a/contrib/chatops/tests/test_format_result.py b/contrib/chatops/tests/test_format_result.py index e700af74548..05114cb3611 100644 --- a/contrib/chatops/tests/test_format_result.py +++ b/contrib/chatops/tests/test_format_result.py @@ -20,9 +20,7 @@ from format_execution_result import FormatResultAction -__all__ = [ - 'FormatResultActionTestCase' -] +__all__ = ["FormatResultActionTestCase"] class FormatResultActionTestCase(BaseActionTestCase): @@ -30,47 +28,45 @@ class FormatResultActionTestCase(BaseActionTestCase): def test_rendering_works_remote_shell_cmd(self): remote_shell_cmd_execution_model = json.loads( - self.get_fixture_content('remote_cmd_execution.json') + self.get_fixture_content("remote_cmd_execution.json") ) action = self.get_action_instance() action._get_execution = mock.MagicMock( return_value=remote_shell_cmd_execution_model ) - result = action.run(execution_id='57967f9355fc8c19a96d9e4f') + result = action.run(execution_id="57967f9355fc8c19a96d9e4f") self.assertTrue(result) - self.assertIn('web_url', result['message']) - self.assertIn('Took 2s to complete', result['message']) + self.assertIn("web_url", result["message"]) + self.assertIn("Took 2s to complete", result["message"]) def test_rendering_local_shell_cmd(self): local_shell_cmd_execution_model = json.loads( - self.get_fixture_content('local_cmd_execution.json') + self.get_fixture_content("local_cmd_execution.json") ) action = self.get_action_instance() action._get_execution = mock.MagicMock( return_value=local_shell_cmd_execution_model ) - self.assertTrue(action.run(execution_id='5799522f55fc8c2d33ac03e0')) + self.assertTrue(action.run(execution_id="5799522f55fc8c2d33ac03e0")) def test_rendering_http_request(self): http_execution_model = json.loads( - self.get_fixture_content('http_execution.json') + self.get_fixture_content("http_execution.json") ) action = self.get_action_instance() - action._get_execution = mock.MagicMock( - return_value=http_execution_model - ) - self.assertTrue(action.run(execution_id='579955f055fc8c2d33ac03e3')) + action._get_execution = mock.MagicMock(return_value=http_execution_model) + self.assertTrue(action.run(execution_id="579955f055fc8c2d33ac03e3")) def test_rendering_python_action(self): python_action_execution_model = json.loads( - self.get_fixture_content('python_action_execution.json') + self.get_fixture_content("python_action_execution.json") ) action = self.get_action_instance() action._get_execution = mock.MagicMock( return_value=python_action_execution_model ) - self.assertTrue(action.run(execution_id='5799572a55fc8c2d33ac03ec')) + self.assertTrue(action.run(execution_id="5799572a55fc8c2d33ac03ec")) diff --git a/contrib/core/actions/generate_uuid.py b/contrib/core/actions/generate_uuid.py index 972b7cb5523..88d8125549d 100644 --- a/contrib/core/actions/generate_uuid.py +++ b/contrib/core/actions/generate_uuid.py @@ -18,16 +18,14 @@ from st2common.runners.base_action import Action -__all__ = [ - 'GenerateUUID' -] +__all__ = ["GenerateUUID"] class GenerateUUID(Action): def run(self, uuid_type): - if uuid_type == 'uuid1': + if uuid_type == "uuid1": return str(uuid.uuid1()) - elif uuid_type == 'uuid4': + elif uuid_type == "uuid4": return str(uuid.uuid4()) else: raise ValueError("Unknown uuid_type. Only uuid1 and uuid4 are supported") diff --git a/contrib/core/actions/inject_trigger.py b/contrib/core/actions/inject_trigger.py index 706e2165db1..a6b2e683170 100644 --- a/contrib/core/actions/inject_trigger.py +++ b/contrib/core/actions/inject_trigger.py @@ -17,9 +17,7 @@ from st2common.runners.base_action import Action -__all__ = [ - 'InjectTriggerAction' -] +__all__ = ["InjectTriggerAction"] class InjectTriggerAction(Action): @@ -34,8 +32,11 @@ def run(self, trigger, payload=None, trace_tag=None): # results in a TriggerInstanceDB database object creation or not. The object is created # inside rulesengine service and could fail due to the user providing an invalid trigger # reference or similar. - self.logger.debug('Injecting trigger "%s" with payload="%s"' % (trigger, str(payload))) - result = client.webhooks.post_generic_webhook(trigger=trigger, payload=payload, - trace_tag=trace_tag) + self.logger.debug( + 'Injecting trigger "%s" with payload="%s"' % (trigger, str(payload)) + ) + result = client.webhooks.post_generic_webhook( + trigger=trigger, payload=payload, trace_tag=trace_tag + ) return result diff --git a/contrib/core/actions/pause.py b/contrib/core/actions/pause.py index 99b9ed9e9bf..7ef8b4eccbf 100755 --- a/contrib/core/actions/pause.py +++ b/contrib/core/actions/pause.py @@ -19,9 +19,7 @@ from st2common.runners.base_action import Action -__all__ = [ - 'PauseAction' -] +__all__ = ["PauseAction"] class PauseAction(Action): diff --git a/contrib/core/tests/test_action_inject_trigger.py b/contrib/core/tests/test_action_inject_trigger.py index 4e0c3b1a291..7c8e44ac986 100644 --- a/contrib/core/tests/test_action_inject_trigger.py +++ b/contrib/core/tests/test_action_inject_trigger.py @@ -27,50 +27,46 @@ class InjectTriggerActionTestCase(BaseActionTestCase): action_cls = InjectTriggerAction - @mock.patch('st2common.services.datastore.BaseDatastoreService.get_api_client') + @mock.patch("st2common.services.datastore.BaseDatastoreService.get_api_client") def test_inject_trigger_only_trigger_no_payload(self, mock_get_api_client): mock_api_client = mock.Mock() mock_get_api_client.return_value = mock_api_client action = self.get_action_instance() - action.run(trigger='dummy_pack.trigger1') + action.run(trigger="dummy_pack.trigger1") mock_api_client.webhooks.post_generic_webhook.assert_called_with( - trigger='dummy_pack.trigger1', - payload={}, - trace_tag=None + trigger="dummy_pack.trigger1", payload={}, trace_tag=None ) mock_api_client.webhooks.post_generic_webhook.reset() - @mock.patch('st2common.services.datastore.BaseDatastoreService.get_api_client') + @mock.patch("st2common.services.datastore.BaseDatastoreService.get_api_client") def test_inject_trigger_trigger_and_payload(self, mock_get_api_client): mock_api_client = mock.Mock() mock_get_api_client.return_value = mock_api_client action = self.get_action_instance() - action.run(trigger='dummy_pack.trigger2', payload={'foo': 'bar'}) + action.run(trigger="dummy_pack.trigger2", payload={"foo": "bar"}) mock_api_client.webhooks.post_generic_webhook.assert_called_with( - trigger='dummy_pack.trigger2', - payload={'foo': 'bar'}, - trace_tag=None + trigger="dummy_pack.trigger2", payload={"foo": "bar"}, trace_tag=None ) mock_api_client.webhooks.post_generic_webhook.reset() - @mock.patch('st2common.services.datastore.BaseDatastoreService.get_api_client') + @mock.patch("st2common.services.datastore.BaseDatastoreService.get_api_client") def test_inject_trigger_trigger_payload_trace_tag(self, mock_get_api_client): mock_api_client = mock.Mock() mock_get_api_client.return_value = mock_api_client action = self.get_action_instance() - action.run(trigger='dummy_pack.trigger3', payload={'foo': 'bar'}, trace_tag='Tag1') + action.run( + trigger="dummy_pack.trigger3", payload={"foo": "bar"}, trace_tag="Tag1" + ) mock_api_client.webhooks.post_generic_webhook.assert_called_with( - trigger='dummy_pack.trigger3', - payload={'foo': 'bar'}, - trace_tag='Tag1' + trigger="dummy_pack.trigger3", payload={"foo": "bar"}, trace_tag="Tag1" ) diff --git a/contrib/core/tests/test_action_sendmail.py b/contrib/core/tests/test_action_sendmail.py index 241fd35d68c..b821ca5f12f 100644 --- a/contrib/core/tests/test_action_sendmail.py +++ b/contrib/core/tests/test_action_sendmail.py @@ -33,12 +33,10 @@ from local_runner.local_shell_script_runner import LocalShellScriptRunner -__all__ = [ - 'SendmailActionTestCase' -] +__all__ = ["SendmailActionTestCase"] MOCK_EXECUTION = mock.Mock() -MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b' +MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b" HOSTNAME = socket.gethostname() @@ -47,134 +45,151 @@ class SendmailActionTestCase(RunnerTestCase, CleanDbTestCase, CleanFilesTestCase NOTE: Those tests rely on stanley user being available on the system and having passwordless sudo access. """ + fixtures_loader = FixturesLoader() def test_sendmail_default_text_html_content_type(self): action_parameters = { - 'sendmail_binary': 'cat', - - 'from': 'from.user@example.tld1', - 'to': 'to.user@example.tld2', - 'subject': 'this is subject 1', - 'send_empty_body': False, - 'content_type': 'text/html', - 'body': 'Hello there html.', - 'attachments': '' + "sendmail_binary": "cat", + "from": "from.user@example.tld1", + "to": "to.user@example.tld2", + "subject": "this is subject 1", + "send_empty_body": False, + "content_type": "text/html", + "body": "Hello there html.", + "attachments": "", } - expected_body = ('Hello there html.\n' - '

\n' - 'This message was generated by StackStorm action ' - 'send_mail running on %s' % (HOSTNAME)) + expected_body = ( + "Hello there html.\n" + "

\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) - status, _, email_data, message = self._run_action(action_parameters=action_parameters) + status, _, email_data, message = self._run_action( + action_parameters=action_parameters + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) # Verify subject contains utf-8 charset and is base64 encoded - self.assertIn('SUBJECT: =?UTF-8?B?', email_data) + self.assertIn("SUBJECT: =?UTF-8?B?", email_data) - self.assertEqual(message.to[0][1], action_parameters['to']) - self.assertEqual(message.from_[0][1], action_parameters['from']) - self.assertEqual(message.subject, action_parameters['subject']) + self.assertEqual(message.to[0][1], action_parameters["to"]) + self.assertEqual(message.from_[0][1], action_parameters["from"]) + self.assertEqual(message.subject, action_parameters["subject"]) self.assertEqual(message.body, expected_body) - self.assertEqual(message.content_type, 'text/html; charset=UTF-8') + self.assertEqual(message.content_type, "text/html; charset=UTF-8") def test_sendmail_text_plain_content_type(self): action_parameters = { - 'sendmail_binary': 'cat', - - 'from': 'from.user@example.tld1', - 'to': 'to.user@example.tld2', - 'subject': 'this is subject 2', - 'send_empty_body': False, - 'content_type': 'text/plain', - 'body': 'Hello there plain.', - 'attachments': '' + "sendmail_binary": "cat", + "from": "from.user@example.tld1", + "to": "to.user@example.tld2", + "subject": "this is subject 2", + "send_empty_body": False, + "content_type": "text/plain", + "body": "Hello there plain.", + "attachments": "", } - expected_body = ('Hello there plain.\n\n' - 'This message was generated by StackStorm action ' - 'send_mail running on %s' % (HOSTNAME)) + expected_body = ( + "Hello there plain.\n\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) - status, _, email_data, message = self._run_action(action_parameters=action_parameters) + status, _, email_data, message = self._run_action( + action_parameters=action_parameters + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) # Verify subject contains utf-8 charset and is base64 encoded - self.assertIn('SUBJECT: =?UTF-8?B?', email_data) + self.assertIn("SUBJECT: =?UTF-8?B?", email_data) - self.assertEqual(message.to[0][1], action_parameters['to']) - self.assertEqual(message.from_[0][1], action_parameters['from']) - self.assertEqual(message.subject, action_parameters['subject']) + self.assertEqual(message.to[0][1], action_parameters["to"]) + self.assertEqual(message.from_[0][1], action_parameters["from"]) + self.assertEqual(message.subject, action_parameters["subject"]) self.assertEqual(message.body, expected_body) - self.assertEqual(message.content_type, 'text/plain; charset=UTF-8') + self.assertEqual(message.content_type, "text/plain; charset=UTF-8") def test_sendmail_utf8_subject_and_body(self): # 1. tex/html action_parameters = { - 'sendmail_binary': 'cat', - - 'from': 'from.user@example.tld1', - 'to': 'to.user@example.tld2', - 'subject': u'Å unicode subject 😃😃', - 'send_empty_body': False, - 'content_type': 'text/html', - 'body': u'Hello there 😃😃.', - 'attachments': '' + "sendmail_binary": "cat", + "from": "from.user@example.tld1", + "to": "to.user@example.tld2", + "subject": "Å unicode subject 😃😃", + "send_empty_body": False, + "content_type": "text/html", + "body": "Hello there 😃😃.", + "attachments": "", } if six.PY2: - expected_body = (u'Hello there 😃😃.\n' - u'

\n' - u'This message was generated by StackStorm action ' - u'send_mail running on %s' % (HOSTNAME)) + expected_body = ( + "Hello there 😃😃.\n" + "

\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) else: - expected_body = (u'Hello there \\U0001f603\\U0001f603.\n' - u'

\n' - u'This message was generated by StackStorm action ' - u'send_mail running on %s' % (HOSTNAME)) - - status, _, email_data, message = self._run_action(action_parameters=action_parameters) + expected_body = ( + "Hello there \\U0001f603\\U0001f603.\n" + "

\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) + + status, _, email_data, message = self._run_action( + action_parameters=action_parameters + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) # Verify subject contains utf-8 charset and is base64 encoded - self.assertIn('SUBJECT: =?UTF-8?B?', email_data) + self.assertIn("SUBJECT: =?UTF-8?B?", email_data) - self.assertEqual(message.to[0][1], action_parameters['to']) - self.assertEqual(message.from_[0][1], action_parameters['from']) - self.assertEqual(message.subject, action_parameters['subject']) + self.assertEqual(message.to[0][1], action_parameters["to"]) + self.assertEqual(message.from_[0][1], action_parameters["from"]) + self.assertEqual(message.subject, action_parameters["subject"]) self.assertEqual(message.body, expected_body) - self.assertEqual(message.content_type, 'text/html; charset=UTF-8') + self.assertEqual(message.content_type, "text/html; charset=UTF-8") # 2. text/plain action_parameters = { - 'sendmail_binary': 'cat', - - 'from': 'from.user@example.tld1', - 'to': 'to.user@example.tld2', - 'subject': u'Å unicode subject 😃😃', - 'send_empty_body': False, - 'content_type': 'text/plain', - 'body': u'Hello there 😃😃.', - 'attachments': '' + "sendmail_binary": "cat", + "from": "from.user@example.tld1", + "to": "to.user@example.tld2", + "subject": "Å unicode subject 😃😃", + "send_empty_body": False, + "content_type": "text/plain", + "body": "Hello there 😃😃.", + "attachments": "", } if six.PY2: - expected_body = (u'Hello there 😃😃.\n\n' - u'This message was generated by StackStorm action ' - u'send_mail running on %s' % (HOSTNAME)) + expected_body = ( + "Hello there 😃😃.\n\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) else: - expected_body = (u'Hello there \\U0001f603\\U0001f603.\n\n' - u'This message was generated by StackStorm action ' - u'send_mail running on %s' % (HOSTNAME)) - - status, _, email_data, message = self._run_action(action_parameters=action_parameters) + expected_body = ( + "Hello there \\U0001f603\\U0001f603.\n\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) + + status, _, email_data, message = self._run_action( + action_parameters=action_parameters + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(message.to[0][1], action_parameters['to']) - self.assertEqual(message.from_[0][1], action_parameters['from']) - self.assertEqual(message.subject, action_parameters['subject']) + self.assertEqual(message.to[0][1], action_parameters["to"]) + self.assertEqual(message.from_[0][1], action_parameters["from"]) + self.assertEqual(message.subject, action_parameters["subject"]) self.assertEqual(message.body, expected_body) - self.assertEqual(message.content_type, 'text/plain; charset=UTF-8') + self.assertEqual(message.content_type, "text/plain; charset=UTF-8") def test_sendmail_with_attachments(self): _, path_1 = tempfile.mkstemp() @@ -185,48 +200,52 @@ def test_sendmail_with_attachments(self): self.to_delete_files.append(path_1) self.to_delete_files.append(path_2) - with open(path_1, 'w') as fp: - fp.write('content 1') + with open(path_1, "w") as fp: + fp.write("content 1") - with open(path_2, 'w') as fp: - fp.write('content 2') + with open(path_2, "w") as fp: + fp.write("content 2") action_parameters = { - 'sendmail_binary': 'cat', - - 'from': 'from.user@example.tld1', - 'to': 'to.user@example.tld2', - 'subject': 'this is email with attachments', - 'send_empty_body': False, - 'content_type': 'text/plain', - 'body': 'Hello there plain.', - 'attachments': '%s,%s' % (path_1, path_2) + "sendmail_binary": "cat", + "from": "from.user@example.tld1", + "to": "to.user@example.tld2", + "subject": "this is email with attachments", + "send_empty_body": False, + "content_type": "text/plain", + "body": "Hello there plain.", + "attachments": "%s,%s" % (path_1, path_2), } - expected_body = ('Hello there plain.\n\n' - 'This message was generated by StackStorm action ' - 'send_mail running on %s' % (HOSTNAME)) + expected_body = ( + "Hello there plain.\n\n" + "This message was generated by StackStorm action " + "send_mail running on %s" % (HOSTNAME) + ) - status, _, email_data, message = self._run_action(action_parameters=action_parameters) + status, _, email_data, message = self._run_action( + action_parameters=action_parameters + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) # Verify subject contains utf-8 charset and is base64 encoded - self.assertIn('SUBJECT: =?UTF-8?B?', email_data) + self.assertIn("SUBJECT: =?UTF-8?B?", email_data) - self.assertEqual(message.to[0][1], action_parameters['to']) - self.assertEqual(message.from_[0][1], action_parameters['from']) - self.assertEqual(message.subject, action_parameters['subject']) + self.assertEqual(message.to[0][1], action_parameters["to"]) + self.assertEqual(message.from_[0][1], action_parameters["from"]) + self.assertEqual(message.subject, action_parameters["subject"]) self.assertEqual(message.body, expected_body) - self.assertEqual(message.content_type, - 'multipart/mixed; boundary="ZZ_/afg6432dfgkl.94531q"') + self.assertEqual( + message.content_type, 'multipart/mixed; boundary="ZZ_/afg6432dfgkl.94531q"' + ) # There should be 3 message parts - 2 for attachments, one for body - self.assertEqual(email_data.count('--ZZ_/afg6432dfgkl.94531q'), 3) + self.assertEqual(email_data.count("--ZZ_/afg6432dfgkl.94531q"), 3) # There should be 2 attachments - self.assertEqual(email_data.count('Content-Transfer-Encoding: base64'), 2) - self.assertIn(base64.b64encode(b'content 1').decode('utf-8'), email_data) - self.assertIn(base64.b64encode(b'content 2').decode('utf-8'), email_data) + self.assertEqual(email_data.count("Content-Transfer-Encoding: base64"), 2) + self.assertIn(base64.b64encode(b"content 1").decode("utf-8"), email_data) + self.assertIn(base64.b64encode(b"content 2").decode("utf-8"), email_data) def _run_action(self, action_parameters): """ @@ -234,10 +253,12 @@ def _run_action(self, action_parameters): parse the output email data. """ models = self.fixtures_loader.load_models( - fixtures_pack='packs/core', fixtures_dict={'actions': ['sendmail.yaml']}) - action_db = models['actions']['sendmail.yaml'] + fixtures_pack="packs/core", fixtures_dict={"actions": ["sendmail.yaml"]} + ) + action_db = models["actions"]["sendmail.yaml"] entry_point = self.fixtures_loader.get_fixture_file_path_abs( - 'packs/core', 'actions', 'send_mail/send_mail') + "packs/core", "actions", "send_mail/send_mail" + ) runner = self._get_runner(action_db, entry_point=entry_point) runner.pre_run() @@ -246,13 +267,13 @@ def _run_action(self, action_parameters): # Remove footer added by the action which is not part of raw email data and parse # the message - if 'stdout' in result: - email_data = result['stdout'] - email_data = email_data.split('\n')[:-2] - email_data = '\n'.join(email_data) + if "stdout" in result: + email_data = result["stdout"] + email_data = email_data.split("\n")[:-2] + email_data = "\n".join(email_data) if six.PY2 and isinstance(email_data, six.text_type): - email_data = email_data.encode('utf-8') + email_data = email_data.encode("utf-8") message = mailparser.parse_from_string(email_data) else: @@ -273,5 +294,5 @@ def _get_runner(self, action_db, entry_point): runner.callback = dict() runner.libs_dir_path = None runner.auth_token = mock.Mock() - runner.auth_token.token = 'mock-token' + runner.auth_token.token = "mock-token" return runner diff --git a/contrib/core/tests/test_action_uuid.py b/contrib/core/tests/test_action_uuid.py index 4e946f30622..e13cfd4c18a 100644 --- a/contrib/core/tests/test_action_uuid.py +++ b/contrib/core/tests/test_action_uuid.py @@ -28,13 +28,13 @@ def test_run(self): action = self.get_action_instance() # accepts uuid1 as a type - result = action.run(uuid_type='uuid1') + result = action.run(uuid_type="uuid1") self.assertTrue(result) # accepts uuid4 as a type - result = action.run(uuid_type='uuid4') + result = action.run(uuid_type="uuid4") self.assertTrue(result) # fails on incorrect type with self.assertRaises(ValueError): - result = action.run(uuid_type='foobar') + result = action.run(uuid_type="foobar") diff --git a/contrib/examples/actions/noop.py b/contrib/examples/actions/noop.py index 0283499ce15..bbdf5e67e6a 100644 --- a/contrib/examples/actions/noop.py +++ b/contrib/examples/actions/noop.py @@ -5,6 +5,6 @@ class PrintParametersAction(Action): def run(self, **parameters): - print('=========') + print("=========") pprint(parameters) - print('=========') + print("=========") diff --git a/contrib/examples/actions/print_config.py b/contrib/examples/actions/print_config.py index 68bdf1e2d6e..15b3103b618 100644 --- a/contrib/examples/actions/print_config.py +++ b/contrib/examples/actions/print_config.py @@ -5,6 +5,6 @@ class PrintConfigAction(Action): def run(self): - print('=========') + print("=========") pprint(self.config) - print('=========') + print("=========") diff --git a/contrib/examples/actions/print_to_stdout_and_stderr.py b/contrib/examples/actions/print_to_stdout_and_stderr.py index da31dc14b46..124a32a67cc 100644 --- a/contrib/examples/actions/print_to_stdout_and_stderr.py +++ b/contrib/examples/actions/print_to_stdout_and_stderr.py @@ -23,12 +23,12 @@ class PrintToStdoutAndStderrAction(Action): def run(self, count=100, sleep_delay=0.5): for i in range(0, count): if i % 2 == 0: - text = 'stderr' + text = "stderr" stream = sys.stderr else: - text = 'stdout' + text = "stdout" stream = sys.stdout - stream.write('%s -> Line: %s\n' % (text, (i + 1))) + stream.write("%s -> Line: %s\n" % (text, (i + 1))) stream.flush() time.sleep(sleep_delay) diff --git a/contrib/examples/actions/python-mock-core-remote.py b/contrib/examples/actions/python-mock-core-remote.py index cd4d44500ea..52c13d804ee 100644 --- a/contrib/examples/actions/python-mock-core-remote.py +++ b/contrib/examples/actions/python-mock-core-remote.py @@ -2,7 +2,6 @@ class MockCoreRemoteAction(Action): - def run(self, cmd, hosts, hosts_dict): if hosts_dict: return hosts_dict @@ -10,14 +9,14 @@ def run(self, cmd, hosts, hosts_dict): if not hosts: return None - host_list = hosts.split(',') + host_list = hosts.split(",") results = {} for h in hosts: results[h] = { - 'failed': False, - 'return_code': 0, - 'stderr': '', - 'succeeded': True, - 'stdout': cmd, + "failed": False, + "return_code": 0, + "stderr": "", + "succeeded": True, + "stdout": cmd, } return results diff --git a/contrib/examples/actions/python-mock-create-vm.py b/contrib/examples/actions/python-mock-create-vm.py index 60a88b79675..62fdaa36c15 100644 --- a/contrib/examples/actions/python-mock-create-vm.py +++ b/contrib/examples/actions/python-mock-create-vm.py @@ -5,17 +5,12 @@ class MockCreateVMAction(Action): - def run(self, cpu_cores, memory_mb, vm_name, ip): eventlet.sleep(5) data = { - 'vm_id': 'vm' + str(random.randint(0, 10000)), - ip: { - 'cpu_cores': cpu_cores, - 'memory_mb': memory_mb, - 'vm_name': vm_name - } + "vm_id": "vm" + str(random.randint(0, 10000)), + ip: {"cpu_cores": cpu_cores, "memory_mb": memory_mb, "vm_name": vm_name}, } return data diff --git a/contrib/examples/actions/pythonactions/fibonacci.py b/contrib/examples/actions/pythonactions/fibonacci.py index afab612161d..bd9a479f353 100755 --- a/contrib/examples/actions/pythonactions/fibonacci.py +++ b/contrib/examples/actions/pythonactions/fibonacci.py @@ -12,12 +12,13 @@ def fib(n): return n return fib(n - 2) + fib(n - 1) -if __name__ == '__main__': + +if __name__ == "__main__": try: startNumber = int(float(sys.argv[1])) endNumber = int(float(sys.argv[2])) results = map(str, map(fib, list(range(startNumber, endNumber)))) - results = ' '.join(results) + results = " ".join(results) print(results) except Exception as e: traceback.print_exc(file=sys.stderr) diff --git a/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py b/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py index 989467570c4..8cb3c42f4b2 100644 --- a/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py +++ b/contrib/examples/actions/pythonactions/forloop_increase_index_and_check_condition.py @@ -3,13 +3,13 @@ class IncreaseIndexAndCheckCondition(Action): def run(self, index, pagesize, input): - if pagesize and pagesize != '': + if pagesize and pagesize != "": if len(input) < int(pagesize): return (False, "Breaking out of the loop") else: pagesize = 0 - if not index or index == '': + if not index or index == "": index = 1 - return(True, int(index) + 1) + return (True, int(index) + 1) diff --git a/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py b/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py index a2cdfd1063d..dbefc1b07ec 100644 --- a/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py +++ b/contrib/examples/actions/pythonactions/forloop_parse_github_repos.py @@ -6,12 +6,12 @@ class ParseGithubRepos(Action): def run(self, content): try: - soup = BeautifulSoup(content, 'html.parser') + soup = BeautifulSoup(content, "html.parser") repo_list = soup.find_all("h3") output = {} for each_item in repo_list: - repo_half_url = each_item.find("a")['href'] + repo_half_url = each_item.find("a")["href"] repo_name = repo_half_url.split("/")[-1] repo_url = "https://github.com" + repo_half_url output[repo_name] = repo_url diff --git a/contrib/examples/actions/pythonactions/isprime.py b/contrib/examples/actions/pythonactions/isprime.py index 911594a01ef..e55d202922a 100644 --- a/contrib/examples/actions/pythonactions/isprime.py +++ b/contrib/examples/actions/pythonactions/isprime.py @@ -6,18 +6,19 @@ class PrimeCheckerAction(Action): def run(self, value=0): - self.logger.debug('PYTHONPATH: %s', get_environ('PYTHONPATH')) - self.logger.debug('value=%s' % (value)) + self.logger.debug("PYTHONPATH: %s", get_environ("PYTHONPATH")) + self.logger.debug("value=%s" % (value)) if math.floor(value) != value: - raise ValueError('%s should be an integer.' % value) + raise ValueError("%s should be an integer." % value) if value < 2: return False - for test in range(2, int(math.floor(math.sqrt(value)))+1): + for test in range(2, int(math.floor(math.sqrt(value))) + 1): if value % test == 0: return False return True -if __name__ == '__main__': + +if __name__ == "__main__": checker = PrimeCheckerAction() for i in range(0, 10): - print('%s : %s' % (i, checker.run(value=1))) + print("%s : %s" % (i, checker.run(value=1))) diff --git a/contrib/examples/actions/pythonactions/json_string_to_object.py b/contrib/examples/actions/pythonactions/json_string_to_object.py index 1072c4554e2..e3c492d7a2e 100644 --- a/contrib/examples/actions/pythonactions/json_string_to_object.py +++ b/contrib/examples/actions/pythonactions/json_string_to_object.py @@ -4,6 +4,5 @@ class JsonStringToObject(Action): - def run(self, json_str): return json.loads(json_str) diff --git a/contrib/examples/actions/pythonactions/object_return.py b/contrib/examples/actions/pythonactions/object_return.py index ecaaf57391b..f8a008b73da 100644 --- a/contrib/examples/actions/pythonactions/object_return.py +++ b/contrib/examples/actions/pythonactions/object_return.py @@ -2,6 +2,5 @@ class ObjectReturnAction(Action): - def run(self): - return {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}} + return {"a": "b", "c": {"d": "e", "f": 1, "g": True}} diff --git a/contrib/examples/actions/pythonactions/print_python_environment.py b/contrib/examples/actions/pythonactions/print_python_environment.py index 9c070cc1c07..dd92bfc2028 100644 --- a/contrib/examples/actions/pythonactions/print_python_environment.py +++ b/contrib/examples/actions/pythonactions/print_python_environment.py @@ -6,10 +6,9 @@ class PrintPythonEnvironmentAction(Action): - def run(self): - print('Using Python executable: %s' % (sys.executable)) - print('Using Python version: %s' % (sys.version)) - print('Platform: %s' % (platform.platform())) - print('PYTHONPATH: %s' % (os.environ.get('PYTHONPATH'))) - print('sys.path: %s' % (sys.path)) + print("Using Python executable: %s" % (sys.executable)) + print("Using Python version: %s" % (sys.version)) + print("Platform: %s" % (platform.platform())) + print("PYTHONPATH: %s" % (os.environ.get("PYTHONPATH"))) + print("sys.path: %s" % (sys.path)) diff --git a/contrib/examples/actions/pythonactions/print_python_version.py b/contrib/examples/actions/pythonactions/print_python_version.py index 0ae2a27b184..201c68dd5f0 100644 --- a/contrib/examples/actions/pythonactions/print_python_version.py +++ b/contrib/examples/actions/pythonactions/print_python_version.py @@ -4,7 +4,6 @@ class PrintPythonVersionAction(Action): - def run(self): - print('Using Python executable: %s' % (sys.executable)) - print('Using Python version: %s' % (sys.version)) + print("Using Python executable: %s" % (sys.executable)) + print("Using Python version: %s" % (sys.version)) diff --git a/contrib/examples/actions/pythonactions/yaml_string_to_object.py b/contrib/examples/actions/pythonactions/yaml_string_to_object.py index 297451cdad6..aa888ce4088 100644 --- a/contrib/examples/actions/pythonactions/yaml_string_to_object.py +++ b/contrib/examples/actions/pythonactions/yaml_string_to_object.py @@ -4,6 +4,5 @@ class YamlStringToObject(Action): - def run(self, yaml_str): return yaml.safe_load(yaml_str) diff --git a/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py b/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py index c2c0198a422..14f19582fd9 100644 --- a/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py +++ b/contrib/examples/actions/ubuntu_pkg_info/lib/datatransformer.py @@ -5,11 +5,11 @@ def to_json(out, err, code): payload = {} if err: - payload['err'] = err - payload['exit_code'] = code + payload["err"] = err + payload["exit_code"] = code return json.dumps(payload) - payload['pkg_info'] = out - payload['exit_code'] = code + payload["pkg_info"] = out + payload["exit_code"] = code return json.dumps(payload) diff --git a/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py b/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py index d8213f43427..ec5e5f7ace8 100755 --- a/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py +++ b/contrib/examples/actions/ubuntu_pkg_info/ubuntu_pkg_info.py @@ -7,17 +7,20 @@ def main(args): - command_list = shlex.split('apt-cache policy ' + ' '.join(args[1:])) - process = subprocess.Popen(command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + command_list = shlex.split("apt-cache policy " + " ".join(args[1:])) + process = subprocess.Popen( + command_list, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) command_stdout, command_stderr = process.communicate() command_exitcode = process.returncode try: payload = transformer.to_json(command_stdout, command_stderr, command_exitcode) except Exception as e: - sys.stderr.write('JSON conversion failed. %s' % six.text_type(e)) + sys.stderr.write("JSON conversion failed. %s" % six.text_type(e)) sys.exit(1) sys.stdout.write(payload) -if __name__ == '__main__': + +if __name__ == "__main__": main(sys.argv) diff --git a/contrib/examples/sensors/echo_flask_app.py b/contrib/examples/sensors/echo_flask_app.py index c4025ae441e..5177df23068 100644 --- a/contrib/examples/sensors/echo_flask_app.py +++ b/contrib/examples/sensors/echo_flask_app.py @@ -6,13 +6,12 @@ class EchoFlaskSensor(Sensor): def __init__(self, sensor_service, config): super(EchoFlaskSensor, self).__init__( - sensor_service=sensor_service, - config=config + sensor_service=sensor_service, config=config ) - self._host = '127.0.0.1' + self._host = "127.0.0.1" self._port = 5000 - self._path = '/echo' + self._path = "/echo" self._log = self._sensor_service.get_logger(__name__) self._app = Flask(__name__) @@ -21,15 +20,19 @@ def setup(self): pass def run(self): - @self._app.route(self._path, methods=['POST']) + @self._app.route(self._path, methods=["POST"]) def echo(): payload = request.get_json(force=True) - self._sensor_service.dispatch(trigger="examples.echoflasksensor", - payload=payload) + self._sensor_service.dispatch( + trigger="examples.echoflasksensor", payload=payload + ) return request.data - self._log.info('Listening for payload on http://{}:{}{}'.format( - self._host, self._port, self._path)) + self._log.info( + "Listening for payload on http://{}:{}{}".format( + self._host, self._port, self._path + ) + ) self._app.run(host=self._host, port=self._port, threaded=False) def cleanup(self): diff --git a/contrib/examples/sensors/fibonacci_sensor.py b/contrib/examples/sensors/fibonacci_sensor.py index 266e81aba3f..2df956335bb 100644 --- a/contrib/examples/sensors/fibonacci_sensor.py +++ b/contrib/examples/sensors/fibonacci_sensor.py @@ -4,12 +4,9 @@ class FibonacciSensor(PollingSensor): - def __init__(self, sensor_service, config, poll_interval=20): super(FibonacciSensor, self).__init__( - sensor_service=sensor_service, - config=config, - poll_interval=poll_interval + sensor_service=sensor_service, config=config, poll_interval=poll_interval ) self.a = None self.b = None @@ -26,19 +23,21 @@ def setup(self): def poll(self): # Reset a and b if there are large enough to avoid integer overflow problems if self.a > 10000 or self.b > 10000: - self.logger.debug('Reseting values to avoid integer overflow issues') + self.logger.debug("Reseting values to avoid integer overflow issues") self.a = 0 self.b = 1 self.count = 2 - fib = (self.a + self.b) - self.logger.debug('Count: %d, a: %d, b: %d, fib: %s', self.count, self.a, self.b, fib) + fib = self.a + self.b + self.logger.debug( + "Count: %d, a: %d, b: %d, fib: %s", self.count, self.a, self.b, fib + ) payload = { "count": self.count, "fibonacci": fib, - "pythonpath": os.environ.get("PYTHONPATH", None) + "pythonpath": os.environ.get("PYTHONPATH", None), } self.sensor_service.dispatch(trigger="examples.fibonacci", payload=payload) diff --git a/contrib/hello_st2/sensors/sensor1.py b/contrib/hello_st2/sensors/sensor1.py index 501de54a981..a4914cdf8b7 100644 --- a/contrib/hello_st2/sensors/sensor1.py +++ b/contrib/hello_st2/sensors/sensor1.py @@ -14,11 +14,11 @@ def setup(self): def run(self): while not self._stop: - self._logger.debug('HelloSensor dispatching trigger...') - count = self.sensor_service.get_value('hello_st2.count') or 0 - payload = {'greeting': 'Yo, StackStorm!', 'count': int(count) + 1} - self.sensor_service.dispatch(trigger='hello_st2.event1', payload=payload) - self.sensor_service.set_value('hello_st2.count', payload['count']) + self._logger.debug("HelloSensor dispatching trigger...") + count = self.sensor_service.get_value("hello_st2.count") or 0 + payload = {"greeting": "Yo, StackStorm!", "count": int(count) + 1} + self.sensor_service.dispatch(trigger="hello_st2.event1", payload=payload) + self.sensor_service.set_value("hello_st2.count", payload["count"]) eventlet.sleep(60) def cleanup(self): diff --git a/contrib/linux/actions/checks/check_loadavg.py b/contrib/linux/actions/checks/check_loadavg.py index fb7d3938ccb..04036924e81 100755 --- a/contrib/linux/actions/checks/check_loadavg.py +++ b/contrib/linux/actions/checks/check_loadavg.py @@ -29,7 +29,7 @@ output = {} try: - fh = open(loadAvgFile, 'r') + fh = open(loadAvgFile, "r") load = fh.readline().split()[0:3] except: print("Error opening %s" % loadAvgFile) @@ -38,7 +38,7 @@ fh.close() try: - fh = open(cpuInfoFile, 'r') + fh = open(cpuInfoFile, "r") for line in fh: if "processor" in line: cpus += 1 @@ -48,16 +48,16 @@ finally: fh.close() -output['1'] = str(float(load[0]) / cpus) -output['5'] = str(float(load[1]) / cpus) -output['15'] = str(float(load[2]) / cpus) +output["1"] = str(float(load[0]) / cpus) +output["5"] = str(float(load[1]) / cpus) +output["15"] = str(float(load[2]) / cpus) -if time == '1' or time == 'one': - print(output['1']) -elif time == '5' or time == 'five': - print(output['5']) -elif time == '15' or time == 'fifteen': - print(output['15']) +if time == "1" or time == "one": + print(output["1"]) +elif time == "5" or time == "five": + print(output["5"]) +elif time == "15" or time == "fifteen": + print(output["15"]) else: print(json.dumps(output)) diff --git a/contrib/linux/actions/checks/check_processes.py b/contrib/linux/actions/checks/check_processes.py index b1ff1af0aec..d2a7db195f8 100755 --- a/contrib/linux/actions/checks/check_processes.py +++ b/contrib/linux/actions/checks/check_processes.py @@ -41,8 +41,11 @@ def setup(self, debug=False, pidlist=False): if debug is True: print("Debug is on") - self.allProcs = [procs for procs in os.listdir(self.procDir) if procs.isdigit() and - int(procs) != int(self.myPid)] + self.allProcs = [ + procs + for procs in os.listdir(self.procDir) + if procs.isdigit() and int(procs) != int(self.myPid) + ] def process(self, criteria): for p in self.allProcs: @@ -58,37 +61,37 @@ def process(self, criteria): cmdfh.close() fh.close() - if criteria == 'state': + if criteria == "state": if pInfo[2] == self.state: self.interestingProcs.append(pInfo) - elif criteria == 'name': + elif criteria == "name": if re.search(self.name, pInfo[1]): self.interestingProcs.append(pInfo) - elif criteria == 'pid': + elif criteria == "pid": if pInfo[0] == self.pid: self.interestingProcs.append(pInfo) def byState(self, state): self.state = state - self.process(criteria='state') + self.process(criteria="state") self.show() def byPid(self, pid): self.pid = pid - self.process(criteria='pid') + self.process(criteria="pid") self.show() def byName(self, name): self.name = name - self.process(criteria='name') + self.process(criteria="name") self.show() def run(self, foo, criteria): - if foo == 'state': + if foo == "state": self.byState(criteria) - elif foo == 'name': + elif foo == "name": self.byName(criteria) - elif foo == 'pid': + elif foo == "pid": self.byPid(criteria) def show(self): @@ -99,13 +102,13 @@ def show(self): prettyOut[proc[0]] = proc[1] if self.pidlist is True: - pidlist = ' '.join(prettyOut.keys()) + pidlist = " ".join(prettyOut.keys()) sys.stderr.write(pidlist) print(json.dumps(prettyOut)) -if __name__ == '__main__': +if __name__ == "__main__": if "pidlist" in sys.argv: pidlist = True else: diff --git a/contrib/linux/actions/dig.py b/contrib/linux/actions/dig.py index 9a3b58a5cd8..7eb8518a2a3 100644 --- a/contrib/linux/actions/dig.py +++ b/contrib/linux/actions/dig.py @@ -25,29 +25,28 @@ class DigAction(Action): - def run(self, rand, count, nameserver, hostname, queryopts): opt_list = [] output = [] - cmd_args = ['dig'] + cmd_args = ["dig"] if nameserver: - nameserver = '@' + nameserver + nameserver = "@" + nameserver cmd_args.append(nameserver) - if isinstance(queryopts, str) and ',' in queryopts: - opt_list = queryopts.split(',') + if isinstance(queryopts, str) and "," in queryopts: + opt_list = queryopts.split(",") else: opt_list.append(queryopts) - cmd_args.extend(['+' + option for option in opt_list]) + cmd_args.extend(["+" + option for option in opt_list]) cmd_args.append(hostname) try: - raw_result = subprocess.Popen(cmd_args, - stderr=subprocess.PIPE, - stdout=subprocess.PIPE).communicate()[0] + raw_result = subprocess.Popen( + cmd_args, stderr=subprocess.PIPE, stdout=subprocess.PIPE + ).communicate()[0] if sys.version_info >= (3,): # This function might call getpreferred encoding unless we pass @@ -57,16 +56,19 @@ def run(self, rand, count, nameserver, hostname, queryopts): else: result_list_str = str(raw_result) - result_list = list(filter(None, result_list_str.split('\n'))) + result_list = list(filter(None, result_list_str.split("\n"))) # NOTE: Python3 supports the FileNotFoundError, the errono.ENOENT is for py2 compat # for Python3: # except FileNotFoundError as e: except OSError as e: if e.errno == errno.ENOENT: - return False, "Can't find dig installed in the path (usually /usr/bin/dig). If " \ - "dig isn't installed, you can install it with 'sudo yum install " \ - "bind-utils' or 'sudo apt install dnsutils'" + return ( + False, + "Can't find dig installed in the path (usually /usr/bin/dig). If " + "dig isn't installed, you can install it with 'sudo yum install " + "bind-utils' or 'sudo apt install dnsutils'", + ) else: raise e diff --git a/contrib/linux/actions/service.py b/contrib/linux/actions/service.py index 3961438431c..335e5038f6e 100644 --- a/contrib/linux/actions/service.py +++ b/contrib/linux/actions/service.py @@ -26,20 +26,23 @@ distro = platform.linux_distribution()[0] if len(sys.argv) < 3: - raise ValueError('Usage: service.py ') + raise ValueError("Usage: service.py ") -args = {'act': quote_unix(sys.argv[1]), 'service': quote_unix(sys.argv[2])} +args = {"act": quote_unix(sys.argv[1]), "service": quote_unix(sys.argv[2])} -if re.search(distro, 'Ubuntu'): - if os.path.isfile("/etc/init/%s.conf" % args['service']): - cmd_args = ['service', args['service'], args['act']] - elif os.path.isfile("/etc/init.d/%s" % args['service']): - cmd_args = ['/etc/init.d/%s' % (args['service']), args['act']] +if re.search(distro, "Ubuntu"): + if os.path.isfile("/etc/init/%s.conf" % args["service"]): + cmd_args = ["service", args["service"], args["act"]] + elif os.path.isfile("/etc/init.d/%s" % args["service"]): + cmd_args = ["/etc/init.d/%s" % (args["service"]), args["act"]] else: print("Unknown service") sys.exit(2) -elif re.search(distro, 'Redhat') or re.search(distro, 'Fedora') or \ - re.search(distro, 'CentOS Linux'): - cmd_args = ['systemctl', args['act'], args['service']] +elif ( + re.search(distro, "Redhat") + or re.search(distro, "Fedora") + or re.search(distro, "CentOS Linux") +): + cmd_args = ["systemctl", args["act"], args["service"]] subprocess.call(cmd_args, shell=False) diff --git a/contrib/linux/actions/wait_for_ssh.py b/contrib/linux/actions/wait_for_ssh.py index 4ad4a66050f..c29e91ba031 100644 --- a/contrib/linux/actions/wait_for_ssh.py +++ b/contrib/linux/actions/wait_for_ssh.py @@ -25,29 +25,47 @@ class BaseAction(Action): - def run(self, hostname, port, username, password=None, keyfile=None, ssh_timeout=5, - sleep_delay=20, retries=10): + def run( + self, + hostname, + port, + username, + password=None, + keyfile=None, + ssh_timeout=5, + sleep_delay=20, + retries=10, + ): # Note: If neither password nor key file is provided, we try to use system user # key file if not password and not keyfile: keyfile = cfg.CONF.system_user.ssh_key_file - self.logger.info('Neither "password" nor "keyfile" parameter provided, ' - 'defaulting to using "%s" key file' % (keyfile)) + self.logger.info( + 'Neither "password" nor "keyfile" parameter provided, ' + 'defaulting to using "%s" key file' % (keyfile) + ) - client = ParamikoSSHClient(hostname=hostname, port=port, username=username, - password=password, key_files=keyfile, - timeout=ssh_timeout) + client = ParamikoSSHClient( + hostname=hostname, + port=port, + username=username, + password=password, + key_files=keyfile, + timeout=ssh_timeout, + ) for index in range(retries): attempt = index + 1 try: - self.logger.debug('SSH connection attempt: %s' % (attempt)) + self.logger.debug("SSH connection attempt: %s" % (attempt)) client.connect() return True except Exception as e: - self.logger.info('Attempt %s failed (%s), sleeping for %s seconds...' % - (attempt, six.text_type(e), sleep_delay)) + self.logger.info( + "Attempt %s failed (%s), sleeping for %s seconds..." + % (attempt, six.text_type(e), sleep_delay) + ) time.sleep(sleep_delay) - raise Exception('Exceeded max retries (%s)' % (retries)) + raise Exception("Exceeded max retries (%s)" % (retries)) diff --git a/contrib/linux/sensors/file_watch_sensor.py b/contrib/linux/sensors/file_watch_sensor.py index 2597d63926b..52e29431167 100644 --- a/contrib/linux/sensors/file_watch_sensor.py +++ b/contrib/linux/sensors/file_watch_sensor.py @@ -24,8 +24,9 @@ class FileWatchSensor(Sensor): def __init__(self, sensor_service, config=None): - super(FileWatchSensor, self).__init__(sensor_service=sensor_service, - config=config) + super(FileWatchSensor, self).__init__( + sensor_service=sensor_service, config=config + ) self._trigger = None self._logger = self._sensor_service.get_logger(__name__) self._tail = None @@ -48,16 +49,16 @@ def cleanup(self): pass def add_trigger(self, trigger): - file_path = trigger['parameters'].get('file_path', None) + file_path = trigger["parameters"].get("file_path", None) if not file_path: self._logger.error('Received trigger type without "file_path" field.') return - self._trigger = trigger.get('ref', None) + self._trigger = trigger.get("ref", None) if not self._trigger: - raise Exception('Trigger %s did not contain a ref.' % trigger) + raise Exception("Trigger %s did not contain a ref." % trigger) # Wait a bit to avoid initialization race in logshipper library eventlet.sleep(1.0) @@ -69,7 +70,7 @@ def update_trigger(self, trigger): pass def remove_trigger(self, trigger): - file_path = trigger['parameters'].get('file_path', None) + file_path = trigger["parameters"].get("file_path", None) if not file_path: self._logger.error('Received trigger type without "file_path" field.') @@ -83,10 +84,11 @@ def remove_trigger(self, trigger): def _handle_line(self, file_path, line): trigger = self._trigger payload = { - 'file_path': file_path, - 'file_name': os.path.basename(file_path), - 'line': line + "file_path": file_path, + "file_name": os.path.basename(file_path), + "line": line, } - self._logger.debug('Sending payload %s for trigger %s to sensor_service.', - payload, trigger) + self._logger.debug( + "Sending payload %s for trigger %s to sensor_service.", payload, trigger + ) self.sensor_service.dispatch(trigger=trigger, payload=payload) diff --git a/contrib/linux/tests/test_action_dig.py b/contrib/linux/tests/test_action_dig.py index 4f363521d93..008cf16e76a 100644 --- a/contrib/linux/tests/test_action_dig.py +++ b/contrib/linux/tests/test_action_dig.py @@ -27,15 +27,18 @@ def test_run_with_empty_hostname(self): action = self.get_action_instance() # Use the defaults from dig.yaml - result = action.run(rand=False, count=0, nameserver=None, hostname='', queryopts='short') + result = action.run( + rand=False, count=0, nameserver=None, hostname="", queryopts="short" + ) self.assertIsInstance(result, list) self.assertEqual(len(result), 0) def test_run_with_empty_queryopts(self): action = self.get_action_instance() - results = action.run(rand=False, count=0, nameserver=None, hostname='google.com', - queryopts='') + results = action.run( + rand=False, count=0, nameserver=None, hostname="google.com", queryopts="" + ) self.assertIsInstance(results, list) for result in results: @@ -45,8 +48,13 @@ def test_run_with_empty_queryopts(self): def test_run(self): action = self.get_action_instance() - results = action.run(rand=False, count=0, nameserver=None, hostname='google.com', - queryopts='short') + results = action.run( + rand=False, + count=0, + nameserver=None, + hostname="google.com", + queryopts="short", + ) self.assertIsInstance(results, list) self.assertGreater(len(results), 0) diff --git a/contrib/packs/actions/get_config.py b/contrib/packs/actions/get_config.py index 505ef683c43..07e4654cef9 100755 --- a/contrib/packs/actions/get_config.py +++ b/contrib/packs/actions/get_config.py @@ -22,8 +22,8 @@ class RenderTemplateAction(Action): def run(self): result = { - 'pack_group': utils.get_pack_group(), - 'pack_path': utils.get_system_packs_base_path() + "pack_group": utils.get_pack_group(), + "pack_path": utils.get_system_packs_base_path(), } return result diff --git a/contrib/packs/actions/pack_mgmt/delete.py b/contrib/packs/actions/pack_mgmt/delete.py index 93bcc46044b..ca0436e834b 100644 --- a/contrib/packs/actions/pack_mgmt/delete.py +++ b/contrib/packs/actions/pack_mgmt/delete.py @@ -27,15 +27,18 @@ class UninstallPackAction(Action): def __init__(self, config=None, action_service=None): - super(UninstallPackAction, self).__init__(config=config, action_service=action_service) - self._base_virtualenvs_path = os.path.join(cfg.CONF.system.base_path, - 'virtualenvs/') + super(UninstallPackAction, self).__init__( + config=config, action_service=action_service + ) + self._base_virtualenvs_path = os.path.join( + cfg.CONF.system.base_path, "virtualenvs/" + ) def run(self, packs, abs_repo_base, delete_env=True): intersection = BLOCKED_PACKS & frozenset(packs) if len(intersection) > 0: - names = ', '.join(list(intersection)) - raise ValueError('Uninstall includes an uninstallable pack - %s.' % (names)) + names = ", ".join(list(intersection)) + raise ValueError("Uninstall includes an uninstallable pack - %s." % (names)) # 1. Delete pack content for fp in os.listdir(abs_repo_base): @@ -51,6 +54,8 @@ def run(self, packs, abs_repo_base, delete_env=True): virtualenv_path = os.path.join(self._base_virtualenvs_path, pack_name) if os.path.isdir(virtualenv_path): - self.logger.debug('Deleting virtualenv "%s" for pack "%s"' % - (virtualenv_path, pack_name)) + self.logger.debug( + 'Deleting virtualenv "%s" for pack "%s"' + % (virtualenv_path, pack_name) + ) shutil.rmtree(virtualenv_path) diff --git a/contrib/packs/actions/pack_mgmt/download.py b/contrib/packs/actions/pack_mgmt/download.py index b4d888630b4..cc0f7cd8fb0 100644 --- a/contrib/packs/actions/pack_mgmt/download.py +++ b/contrib/packs/actions/pack_mgmt/download.py @@ -21,68 +21,85 @@ from st2common.runners.base_action import Action from st2common.util.pack_management import download_pack -__all__ = [ - 'DownloadGitRepoAction' -] +__all__ = ["DownloadGitRepoAction"] class DownloadGitRepoAction(Action): def __init__(self, config=None, action_service=None): - super(DownloadGitRepoAction, self).__init__(config=config, action_service=action_service) + super(DownloadGitRepoAction, self).__init__( + config=config, action_service=action_service + ) - self.https_proxy = os.environ.get('https_proxy', self.config.get('https_proxy', None)) - self.http_proxy = os.environ.get('http_proxy', self.config.get('http_proxy', None)) + self.https_proxy = os.environ.get( + "https_proxy", self.config.get("https_proxy", None) + ) + self.http_proxy = os.environ.get( + "http_proxy", self.config.get("http_proxy", None) + ) self.proxy_ca_bundle_path = os.environ.get( - 'proxy_ca_bundle_path', - self.config.get('proxy_ca_bundle_path', None) + "proxy_ca_bundle_path", self.config.get("proxy_ca_bundle_path", None) ) - self.no_proxy = os.environ.get('no_proxy', self.config.get('no_proxy', None)) + self.no_proxy = os.environ.get("no_proxy", self.config.get("no_proxy", None)) self.proxy_config = None if self.http_proxy or self.https_proxy: - self.logger.debug('Using proxy %s', - self.http_proxy if self.http_proxy else self.https_proxy) + self.logger.debug( + "Using proxy %s", + self.http_proxy if self.http_proxy else self.https_proxy, + ) self.proxy_config = { - 'https_proxy': self.https_proxy, - 'http_proxy': self.http_proxy, - 'proxy_ca_bundle_path': self.proxy_ca_bundle_path, - 'no_proxy': self.no_proxy + "https_proxy": self.https_proxy, + "http_proxy": self.http_proxy, + "proxy_ca_bundle_path": self.proxy_ca_bundle_path, + "no_proxy": self.no_proxy, } # This is needed for git binary to work with a proxy - if self.https_proxy and not os.environ.get('https_proxy', None): - os.environ['https_proxy'] = self.https_proxy + if self.https_proxy and not os.environ.get("https_proxy", None): + os.environ["https_proxy"] = self.https_proxy - if self.http_proxy and not os.environ.get('http_proxy', None): - os.environ['http_proxy'] = self.http_proxy + if self.http_proxy and not os.environ.get("http_proxy", None): + os.environ["http_proxy"] = self.http_proxy - if self.no_proxy and not os.environ.get('no_proxy', None): - os.environ['no_proxy'] = self.no_proxy + if self.no_proxy and not os.environ.get("no_proxy", None): + os.environ["no_proxy"] = self.no_proxy - if self.proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None): - os.environ['no_proxy'] = self.no_proxy + if self.proxy_ca_bundle_path and not os.environ.get( + "proxy_ca_bundle_path", None + ): + os.environ["no_proxy"] = self.no_proxy - def run(self, packs, abs_repo_base, verifyssl=True, force=False, - dependency_list=None): + def run( + self, packs, abs_repo_base, verifyssl=True, force=False, dependency_list=None + ): result = {} pack_url = None if dependency_list: for pack_dependency in dependency_list: - pack_result = download_pack(pack=pack_dependency, abs_repo_base=abs_repo_base, - verify_ssl=verifyssl, force=force, - proxy_config=self.proxy_config, force_permissions=True, - logger=self.logger) + pack_result = download_pack( + pack=pack_dependency, + abs_repo_base=abs_repo_base, + verify_ssl=verifyssl, + force=force, + proxy_config=self.proxy_config, + force_permissions=True, + logger=self.logger, + ) pack_url, pack_ref, pack_result = pack_result result[pack_ref] = pack_result else: for pack in packs: - pack_result = download_pack(pack=pack, abs_repo_base=abs_repo_base, - verify_ssl=verifyssl, force=force, - proxy_config=self.proxy_config, - force_permissions=True, - logger=self.logger) + pack_result = download_pack( + pack=pack, + abs_repo_base=abs_repo_base, + verify_ssl=verifyssl, + force=force, + proxy_config=self.proxy_config, + force_permissions=True, + logger=self.logger, + ) pack_url, pack_ref, pack_result = pack_result result[pack_ref] = pack_result @@ -99,14 +116,16 @@ def _validate_result(result, repo_url): if not atleast_one_success: message_list = [] - message_list.append('The pack has not been downloaded from "%s".\n' % (repo_url)) - message_list.append('Errors:') + message_list.append( + 'The pack has not been downloaded from "%s".\n' % (repo_url) + ) + message_list.append("Errors:") for pack, value in result.items(): success, error = value message_list.append(error) - message = '\n'.join(message_list) + message = "\n".join(message_list) raise Exception(message) return sanitized_result diff --git a/contrib/packs/actions/pack_mgmt/get_installed.py b/contrib/packs/actions/pack_mgmt/get_installed.py index eaa88b6319e..36f2504b85e 100644 --- a/contrib/packs/actions/pack_mgmt/get_installed.py +++ b/contrib/packs/actions/pack_mgmt/get_installed.py @@ -28,6 +28,7 @@ class GetInstalled(Action): """"Get information about installed pack.""" + def run(self, pack): """ :param pack: Installed Pack Name to get info about @@ -47,46 +48,42 @@ def run(self, pack): # Pack doesn't exist, finish execution normally with empty metadata if not os.path.isdir(pack_path): - return { - 'pack': None, - 'git_status': None - } + return {"pack": None, "git_status": None} if not metadata_file: - error = ('Pack "%s" doesn\'t contain pack.yaml file.' % (pack)) + error = 'Pack "%s" doesn\'t contain pack.yaml file.' % (pack) raise Exception(error) try: details = self._parse_yaml_file(metadata_file) except Exception as e: - error = ('Pack "%s" doesn\'t contain a valid pack.yaml file: %s' % (pack, - six.text_type(e))) + error = 'Pack "%s" doesn\'t contain a valid pack.yaml file: %s' % ( + pack, + six.text_type(e), + ) raise Exception(error) try: repo = Repo(pack_path) git_status = "Status:\n%s\n\nRemotes:\n%s" % ( - repo.git.status().split('\n')[0], - "\n".join([remote.url for remote in repo.remotes]) + repo.git.status().split("\n")[0], + "\n".join([remote.url for remote in repo.remotes]), ) ahead_behind = repo.git.rev_list( - '--left-right', '--count', 'HEAD...origin/master' + "--left-right", "--count", "HEAD...origin/master" ).split() # Dear god. - if ahead_behind != [u'0', u'0']: + if ahead_behind != ["0", "0"]: git_status += "\n\n" - git_status += "%s commits ahead " if ahead_behind[0] != u'0' else "" - git_status += "and " if u'0' not in ahead_behind else "" - git_status += "%s commits behind " if ahead_behind[1] != u'0' else "" + git_status += "%s commits ahead " if ahead_behind[0] != "0" else "" + git_status += "and " if "0" not in ahead_behind else "" + git_status += "%s commits behind " if ahead_behind[1] != "0" else "" git_status += "origin/master." except InvalidGitRepositoryError: git_status = None - return { - 'pack': details, - 'git_status': git_status - } + return {"pack": details, "git_status": git_status} def _parse_yaml_file(self, file_path): with open(file_path) as data_file: diff --git a/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py b/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py index 60ab2c9503a..b9168526a22 100644 --- a/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py +++ b/contrib/packs/actions/pack_mgmt/get_pack_dependencies.py @@ -40,7 +40,7 @@ def run(self, packs_status, nested): return result for pack, status in six.iteritems(packs_status): - if 'success' not in status.lower(): + if "success" not in status.lower(): continue dependency_packs = get_dependency_list(pack) @@ -50,40 +50,51 @@ def run(self, packs_status, nested): for dep_pack in dependency_packs: name_or_url, pack_version = self.get_name_and_version(dep_pack) - if len(name_or_url.split('/')) == 1: + if len(name_or_url.split("/")) == 1: pack_name = name_or_url else: name_or_git = name_or_url.split("/")[-1] - pack_name = name_or_git if '.git' not in name_or_git else \ - name_or_git.split('.')[0] + pack_name = ( + name_or_git + if ".git" not in name_or_git + else name_or_git.split(".")[0] + ) # Check existing pack by pack name existing_pack_version = get_pack_version(pack_name) # Try one more time to get existing pack version by name if 'stackstorm-' is in # pack name - if not existing_pack_version and 'stackstorm-' in pack_name.lower(): - existing_pack_version = get_pack_version(pack_name.split('stackstorm-')[-1]) + if not existing_pack_version and "stackstorm-" in pack_name.lower(): + existing_pack_version = get_pack_version( + pack_name.split("stackstorm-")[-1] + ) if existing_pack_version: - if existing_pack_version and not existing_pack_version.startswith('v'): - existing_pack_version = 'v' + existing_pack_version - if pack_version and not pack_version.startswith('v'): - pack_version = 'v' + pack_version - if pack_version and existing_pack_version != pack_version \ - and dep_pack not in conflict_list: + if existing_pack_version and not existing_pack_version.startswith( + "v" + ): + existing_pack_version = "v" + existing_pack_version + if pack_version and not pack_version.startswith("v"): + pack_version = "v" + pack_version + if ( + pack_version + and existing_pack_version != pack_version + and dep_pack not in conflict_list + ): conflict_list.append(dep_pack) else: - conflict = self.check_dependency_list_for_conflict(name_or_url, pack_version, - dependency_list) + conflict = self.check_dependency_list_for_conflict( + name_or_url, pack_version, dependency_list + ) if conflict: conflict_list.append(dep_pack) elif dep_pack not in dependency_list: dependency_list.append(dep_pack) - result['dependency_list'] = dependency_list - result['conflict_list'] = conflict_list - result['nested'] = nested - 1 + result["dependency_list"] = dependency_list + result["conflict_list"] = conflict_list + result["nested"] = nested - 1 return result @@ -112,7 +123,7 @@ def get_pack_version(pack=None): pack_path = get_pack_base_path(pack) try: pack_metadata = get_pack_metadata(pack_dir=pack_path) - result = pack_metadata.get('version', None) + result = pack_metadata.get("version", None) except Exception: result = None finally: @@ -124,9 +135,9 @@ def get_dependency_list(pack=None): try: pack_metadata = get_pack_metadata(pack_dir=pack_path) - result = pack_metadata.get('dependencies', None) + result = pack_metadata.get("dependencies", None) except Exception: - print('Could not open pack.yaml at location %s' % pack_path) + print("Could not open pack.yaml at location %s" % pack_path) result = None finally: return result diff --git a/contrib/packs/actions/pack_mgmt/get_pack_warnings.py b/contrib/packs/actions/pack_mgmt/get_pack_warnings.py index 445a5df0c2d..e8f42dcbb66 100755 --- a/contrib/packs/actions/pack_mgmt/get_pack_warnings.py +++ b/contrib/packs/actions/pack_mgmt/get_pack_warnings.py @@ -34,7 +34,7 @@ def run(self, packs_status): return result for pack, status in six.iteritems(packs_status): - if 'success' not in status.lower(): + if "success" not in status.lower(): continue warning = get_warnings(pack) @@ -42,7 +42,7 @@ def run(self, packs_status): if warning: warning_list.append(warning) - result['warning_list'] = warning_list + result["warning_list"] = warning_list return result @@ -54,6 +54,6 @@ def get_warnings(pack=None): pack_metadata = get_pack_metadata(pack_dir=pack_path) result = get_pack_warnings(pack_metadata) except Exception: - print('Could not open pack.yaml at location %s' % pack_path) + print("Could not open pack.yaml at location %s" % pack_path) finally: return result diff --git a/contrib/packs/actions/pack_mgmt/register.py b/contrib/packs/actions/pack_mgmt/register.py index 220962f0f43..1587333d5b2 100644 --- a/contrib/packs/actions/pack_mgmt/register.py +++ b/contrib/packs/actions/pack_mgmt/register.py @@ -19,21 +19,19 @@ from st2client.models.keyvalue import KeyValuePair # pylint: disable=no-name-in-module from st2common.runners.base_action import Action -__all__ = [ - 'St2RegisterAction' -] +__all__ = ["St2RegisterAction"] COMPATIBILITY_TRANSFORMATIONS = { - 'runners': 'runner', - 'triggers': 'trigger', - 'sensors': 'sensor', - 'actions': 'action', - 'rules': 'rule', - 'rule_types': 'rule_type', - 'aliases': 'alias', - 'policiy_types': 'policy_type', - 'policies': 'policy', - 'configs': 'config', + "runners": "runner", + "triggers": "trigger", + "sensors": "sensor", + "actions": "action", + "rules": "rule", + "rule_types": "rule_type", + "aliases": "alias", + "policiy_types": "policy_type", + "policies": "policy", + "configs": "config", } @@ -63,23 +61,23 @@ def __init__(self, config): def run(self, register, packs=None): types = [] - for type in register.split(','): + for type in register.split(","): if type in COMPATIBILITY_TRANSFORMATIONS: types.append(COMPATIBILITY_TRANSFORMATIONS[type]) else: types.append(type) - method_kwargs = { - 'types': types - } + method_kwargs = {"types": types} packs.reverse() if packs: - method_kwargs['packs'] = packs + method_kwargs["packs"] = packs - result = self._run_client_method(method=self.client.packs.register, - method_kwargs=method_kwargs, - format_func=format_result) + result = self._run_client_method( + method=self.client.packs.register, + method_kwargs=method_kwargs, + format_func=format_result, + ) # TODO: make sure to return proper model return result @@ -90,42 +88,48 @@ def _get_client(self): client_kwargs = {} if cacert: - client_kwargs['cacert'] = cacert + client_kwargs["cacert"] = cacert - return self._client(base_url=base_url, api_url=api_url, - auth_url=auth_url, token=token, - **client_kwargs) + return self._client( + base_url=base_url, + api_url=api_url, + auth_url=auth_url, + token=token, + **client_kwargs, + ) def _get_st2_urls(self): # First try to use base_url from config. - base_url = self.config.get('base_url', None) - api_url = self.config.get('api_url', None) - auth_url = self.config.get('auth_url', None) + base_url = self.config.get("base_url", None) + api_url = self.config.get("api_url", None) + auth_url = self.config.get("auth_url", None) # not found look up from env vars. Assuming the pack is # configuered to work with current StackStorm instance. if not base_url: - api_url = os.environ.get('ST2_ACTION_API_URL', None) - auth_url = os.environ.get('ST2_ACTION_AUTH_URL', None) + api_url = os.environ.get("ST2_ACTION_API_URL", None) + auth_url = os.environ.get("ST2_ACTION_AUTH_URL", None) return base_url, api_url, auth_url def _get_auth_token(self): # First try to use auth_token from config. - token = self.config.get('auth_token', None) + token = self.config.get("auth_token", None) # not found look up from env vars. Assuming the pack is # configuered to work with current StackStorm instance. if not token: - token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None) + token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None) return token def _get_cacert(self): - cacert = self.config.get('cacert', None) + cacert = self.config.get("cacert", None) return cacert - def _run_client_method(self, method, method_kwargs, format_func, format_kwargs=None): + def _run_client_method( + self, method, method_kwargs, format_func, format_kwargs=None + ): """ Run the provided client method and format the result. @@ -144,8 +148,9 @@ def _run_client_method(self, method, method_kwargs, format_func, format_kwargs=N # This is a work around since the default values can only be strings method_kwargs = filter_none_values(method_kwargs) method_name = method.__name__ - self.logger.debug('Calling client method "%s" with kwargs "%s"' % (method_name, - method_kwargs)) + self.logger.debug( + 'Calling client method "%s" with kwargs "%s"' % (method_name, method_kwargs) + ) result = method(**method_kwargs) result = format_func(result, **format_kwargs or {}) diff --git a/contrib/packs/actions/pack_mgmt/search.py b/contrib/packs/actions/pack_mgmt/search.py index b7cb07f7fc6..dd732c1b29b 100644 --- a/contrib/packs/actions/pack_mgmt/search.py +++ b/contrib/packs/actions/pack_mgmt/search.py @@ -22,43 +22,51 @@ class PackSearch(Action): def __init__(self, config=None, action_service=None): super(PackSearch, self).__init__(config=config, action_service=action_service) - self.https_proxy = os.environ.get('https_proxy', self.config.get('https_proxy', None)) - self.http_proxy = os.environ.get('http_proxy', self.config.get('http_proxy', None)) + self.https_proxy = os.environ.get( + "https_proxy", self.config.get("https_proxy", None) + ) + self.http_proxy = os.environ.get( + "http_proxy", self.config.get("http_proxy", None) + ) self.proxy_ca_bundle_path = os.environ.get( - 'proxy_ca_bundle_path', - self.config.get('proxy_ca_bundle_path', None) + "proxy_ca_bundle_path", self.config.get("proxy_ca_bundle_path", None) ) - self.no_proxy = os.environ.get('no_proxy', self.config.get('no_proxy', None)) + self.no_proxy = os.environ.get("no_proxy", self.config.get("no_proxy", None)) self.proxy_config = None if self.http_proxy or self.https_proxy: - self.logger.debug('Using proxy %s', - self.http_proxy if self.http_proxy else self.https_proxy) + self.logger.debug( + "Using proxy %s", + self.http_proxy if self.http_proxy else self.https_proxy, + ) self.proxy_config = { - 'https_proxy': self.https_proxy, - 'http_proxy': self.http_proxy, - 'proxy_ca_bundle_path': self.proxy_ca_bundle_path, - 'no_proxy': self.no_proxy + "https_proxy": self.https_proxy, + "http_proxy": self.http_proxy, + "proxy_ca_bundle_path": self.proxy_ca_bundle_path, + "no_proxy": self.no_proxy, } - if self.https_proxy and not os.environ.get('https_proxy', None): - os.environ['https_proxy'] = self.https_proxy + if self.https_proxy and not os.environ.get("https_proxy", None): + os.environ["https_proxy"] = self.https_proxy - if self.http_proxy and not os.environ.get('http_proxy', None): - os.environ['http_proxy'] = self.http_proxy + if self.http_proxy and not os.environ.get("http_proxy", None): + os.environ["http_proxy"] = self.http_proxy - if self.no_proxy and not os.environ.get('no_proxy', None): - os.environ['no_proxy'] = self.no_proxy + if self.no_proxy and not os.environ.get("no_proxy", None): + os.environ["no_proxy"] = self.no_proxy - if self.proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None): - os.environ['no_proxy'] = self.no_proxy + if self.proxy_ca_bundle_path and not os.environ.get( + "proxy_ca_bundle_path", None + ): + os.environ["no_proxy"] = self.no_proxy """"Search for packs in StackStorm Exchange and other directories.""" + def run(self, query): """ :param query: A word or a phrase to search for :type query: ``str`` """ - self.logger.debug('Proxy config: %s', self.proxy_config) + self.logger.debug("Proxy config: %s", self.proxy_config) return search_pack_index(query, proxy_config=self.proxy_config) diff --git a/contrib/packs/actions/pack_mgmt/setup_virtualenv.py b/contrib/packs/actions/pack_mgmt/setup_virtualenv.py index 23f8a75ef71..bf7a32ed7ea 100644 --- a/contrib/packs/actions/pack_mgmt/setup_virtualenv.py +++ b/contrib/packs/actions/pack_mgmt/setup_virtualenv.py @@ -18,9 +18,7 @@ from st2common.runners.base_action import Action from st2common.util.virtualenvs import setup_pack_virtualenv -__all__ = [ - 'SetupVirtualEnvironmentAction' -] +__all__ = ["SetupVirtualEnvironmentAction"] class SetupVirtualEnvironmentAction(Action): @@ -37,42 +35,50 @@ class SetupVirtualEnvironmentAction(Action): creation of the virtual environment and performs an update of the current dependencies as well as an installation of new dependencies """ + def __init__(self, config=None, action_service=None): super(SetupVirtualEnvironmentAction, self).__init__( - config=config, - action_service=action_service) + config=config, action_service=action_service + ) - self.https_proxy = os.environ.get('https_proxy', self.config.get('https_proxy', None)) - self.http_proxy = os.environ.get('http_proxy', self.config.get('http_proxy', None)) + self.https_proxy = os.environ.get( + "https_proxy", self.config.get("https_proxy", None) + ) + self.http_proxy = os.environ.get( + "http_proxy", self.config.get("http_proxy", None) + ) self.proxy_ca_bundle_path = os.environ.get( - 'proxy_ca_bundle_path', - self.config.get('proxy_ca_bundle_path', None) + "proxy_ca_bundle_path", self.config.get("proxy_ca_bundle_path", None) ) - self.no_proxy = os.environ.get('no_proxy', self.config.get('no_proxy', None)) + self.no_proxy = os.environ.get("no_proxy", self.config.get("no_proxy", None)) self.proxy_config = None if self.http_proxy or self.https_proxy: - self.logger.debug('Using proxy %s', - self.http_proxy if self.http_proxy else self.https_proxy) + self.logger.debug( + "Using proxy %s", + self.http_proxy if self.http_proxy else self.https_proxy, + ) self.proxy_config = { - 'https_proxy': self.https_proxy, - 'http_proxy': self.http_proxy, - 'proxy_ca_bundle_path': self.proxy_ca_bundle_path, - 'no_proxy': self.no_proxy + "https_proxy": self.https_proxy, + "http_proxy": self.http_proxy, + "proxy_ca_bundle_path": self.proxy_ca_bundle_path, + "no_proxy": self.no_proxy, } - if self.https_proxy and not os.environ.get('https_proxy', None): - os.environ['https_proxy'] = self.https_proxy + if self.https_proxy and not os.environ.get("https_proxy", None): + os.environ["https_proxy"] = self.https_proxy - if self.http_proxy and not os.environ.get('http_proxy', None): - os.environ['http_proxy'] = self.http_proxy + if self.http_proxy and not os.environ.get("http_proxy", None): + os.environ["http_proxy"] = self.http_proxy - if self.no_proxy and not os.environ.get('no_proxy', None): - os.environ['no_proxy'] = self.no_proxy + if self.no_proxy and not os.environ.get("no_proxy", None): + os.environ["no_proxy"] = self.no_proxy - if self.proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None): - os.environ['no_proxy'] = self.no_proxy + if self.proxy_ca_bundle_path and not os.environ.get( + "proxy_ca_bundle_path", None + ): + os.environ["no_proxy"] = self.no_proxy def run(self, packs, update=False, no_download=True): """ @@ -84,10 +90,15 @@ def run(self, packs, update=False, no_download=True): """ for pack_name in packs: - setup_pack_virtualenv(pack_name=pack_name, update=update, logger=self.logger, - proxy_config=self.proxy_config, - no_download=no_download) - - message = ('Successfully set up virtualenv for the following packs: %s' % - (', '.join(packs))) + setup_pack_virtualenv( + pack_name=pack_name, + update=update, + logger=self.logger, + proxy_config=self.proxy_config, + no_download=no_download, + ) + + message = "Successfully set up virtualenv for the following packs: %s" % ( + ", ".join(packs) + ) return message diff --git a/contrib/packs/actions/pack_mgmt/show_remote.py b/contrib/packs/actions/pack_mgmt/show_remote.py index ba5bff8141e..6b2f655594d 100644 --- a/contrib/packs/actions/pack_mgmt/show_remote.py +++ b/contrib/packs/actions/pack_mgmt/show_remote.py @@ -19,11 +19,10 @@ class ShowRemote(Action): """Get detailed information about an available pack from the StackStorm Exchange index""" + def run(self, pack): """ :param pack: Pack Name to get info about :type pack: ``str`` """ - return { - 'pack': get_pack_from_index(pack) - } + return {"pack": get_pack_from_index(pack)} diff --git a/contrib/packs/actions/pack_mgmt/unload.py b/contrib/packs/actions/pack_mgmt/unload.py index c72cdf9ce10..46caf9cc7ac 100644 --- a/contrib/packs/actions/pack_mgmt/unload.py +++ b/contrib/packs/actions/pack_mgmt/unload.py @@ -36,31 +36,48 @@ class UnregisterPackAction(BaseAction): def __init__(self, config=None, action_service=None): - super(UnregisterPackAction, self).__init__(config=config, action_service=action_service) + super(UnregisterPackAction, self).__init__( + config=config, action_service=action_service + ) self.initialize() def initialize(self): # 1. Setup db connection - username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None - password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None - db_setup(cfg.CONF.database.db_name, cfg.CONF.database.host, cfg.CONF.database.port, - username=username, password=password, - ssl=cfg.CONF.database.ssl, - ssl_keyfile=cfg.CONF.database.ssl_keyfile, - ssl_certfile=cfg.CONF.database.ssl_certfile, - ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs, - ssl_ca_certs=cfg.CONF.database.ssl_ca_certs, - authentication_mechanism=cfg.CONF.database.authentication_mechanism, - ssl_match_hostname=cfg.CONF.database.ssl_match_hostname) + username = ( + cfg.CONF.database.username + if hasattr(cfg.CONF.database, "username") + else None + ) + password = ( + cfg.CONF.database.password + if hasattr(cfg.CONF.database, "password") + else None + ) + db_setup( + cfg.CONF.database.db_name, + cfg.CONF.database.host, + cfg.CONF.database.port, + username=username, + password=password, + ssl=cfg.CONF.database.ssl, + ssl_keyfile=cfg.CONF.database.ssl_keyfile, + ssl_certfile=cfg.CONF.database.ssl_certfile, + ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs, + ssl_ca_certs=cfg.CONF.database.ssl_ca_certs, + authentication_mechanism=cfg.CONF.database.authentication_mechanism, + ssl_match_hostname=cfg.CONF.database.ssl_match_hostname, + ) def run(self, packs): intersection = BLOCKED_PACKS & frozenset(packs) if len(intersection) > 0: - names = ', '.join(list(intersection)) - raise ValueError('Unregister includes an unregisterable pack - %s.' % (names)) + names = ", ".join(list(intersection)) + raise ValueError( + "Unregister includes an unregisterable pack - %s." % (names) + ) for pack in packs: - self.logger.debug('Removing pack %s.', pack) + self.logger.debug("Removing pack %s.", pack) self._unregister_sensors(pack=pack) self._unregister_trigger_types(pack=pack) self._unregister_triggers(pack=pack) @@ -69,21 +86,27 @@ def run(self, packs): self._unregister_aliases(pack=pack) self._unregister_policies(pack=pack) self._unregister_pack(pack=pack) - self.logger.info('Removed pack %s.', pack) + self.logger.info("Removed pack %s.", pack) def _unregister_sensors(self, pack): return self._delete_pack_db_objects(pack=pack, access_cls=SensorType) def _unregister_trigger_types(self, pack): - deleted_trigger_types_dbs = self._delete_pack_db_objects(pack=pack, access_cls=TriggerType) + deleted_trigger_types_dbs = self._delete_pack_db_objects( + pack=pack, access_cls=TriggerType + ) # 2. Check if deleted trigger is used by any other rules outside this pack for trigger_type_db in deleted_trigger_types_dbs: - rule_dbs = Rule.query(trigger=trigger_type_db.ref, pack__ne=trigger_type_db.pack) + rule_dbs = Rule.query( + trigger=trigger_type_db.ref, pack__ne=trigger_type_db.pack + ) for rule_db in rule_dbs: - self.logger.warning('Rule "%s" references deleted trigger "%s"' % - (rule_db.name, trigger_type_db.ref)) + self.logger.warning( + 'Rule "%s" references deleted trigger "%s"' + % (rule_db.name, trigger_type_db.ref) + ) return deleted_trigger_types_dbs @@ -136,25 +159,25 @@ def _delete_pack_db_object(self, pack): pack_db = None if not pack_db: - self.logger.exception('Pack DB object not found') + self.logger.exception("Pack DB object not found") return try: Pack.delete(pack_db) except: - self.logger.exception('Failed to remove DB object %s.', pack_db) + self.logger.exception("Failed to remove DB object %s.", pack_db) def _delete_config_schema_db_object(self, pack): try: config_schema_db = ConfigSchema.get_by_pack(value=pack) except StackStormDBObjectNotFoundError: - self.logger.exception('ConfigSchemaDB object not found') + self.logger.exception("ConfigSchemaDB object not found") return try: ConfigSchema.delete(config_schema_db) except: - self.logger.exception('Failed to remove DB object %s.', config_schema_db) + self.logger.exception("Failed to remove DB object %s.", config_schema_db) def _delete_pack_db_objects(self, pack, access_cls): db_objs = access_cls.get_all(pack=pack) @@ -166,6 +189,6 @@ def _delete_pack_db_objects(self, pack, access_cls): access_cls.delete(db_obj) deleted_objs.append(db_obj) except: - self.logger.exception('Failed to remove DB object %s.', db_obj) + self.logger.exception("Failed to remove DB object %s.", db_obj) return deleted_objs diff --git a/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py b/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py index aedc993f6bb..abde082ed3b 100644 --- a/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py +++ b/contrib/packs/actions/pack_mgmt/virtualenv_setup_prerun.py @@ -32,7 +32,7 @@ def run(self, packs_status, packs_list=None): packs = [] for pack_name, status in six.iteritems(packs_status): - if 'success' in status.lower(): + if "success" in status.lower(): packs.append(pack_name) packs_list.extend(packs) diff --git a/contrib/packs/tests/test_action_aliases.py b/contrib/packs/tests/test_action_aliases.py index 858a1677518..ecfebe8b687 100644 --- a/contrib/packs/tests/test_action_aliases.py +++ b/contrib/packs/tests/test_action_aliases.py @@ -19,73 +19,65 @@ class PackGet(BaseActionAliasTestCase): action_alias_name = "pack_get" def test_alias_pack_get(self): - format_string = self.action_alias_db.formats[0]['representation'][0] + format_string = self.action_alias_db.formats[0]["representation"][0] format_strings = self.action_alias_db.get_format_strings() command = "pack get st2" - expected_parameters = { - 'pack': "st2" - } + expected_parameters = {"pack": "st2"} - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) self.assertCommandMatchesExactlyOneFormatString( - format_strings=format_strings, - command=command) + format_strings=format_strings, command=command + ) class PackInstall(BaseActionAliasTestCase): action_alias_name = "pack_install" def test_alias_pack_install(self): - format_string = self.action_alias_db.formats[0]['representation'][0] + format_string = self.action_alias_db.formats[0]["representation"][0] command = "pack install st2" - expected_parameters = { - 'packs': "st2" - } + expected_parameters = {"packs": "st2"} - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) class PackSearch(BaseActionAliasTestCase): action_alias_name = "pack_search" def test_alias_pack_search(self): - format_string = self.action_alias_db.formats[0]['representation'][0] + format_string = self.action_alias_db.formats[0]["representation"][0] format_strings = self.action_alias_db.get_format_strings() command = "pack search st2" - expected_parameters = { - 'query': "st2" - } + expected_parameters = {"query": "st2"} - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) self.assertCommandMatchesExactlyOneFormatString( - format_strings=format_strings, - command=command) + format_strings=format_strings, command=command + ) class PackShow(BaseActionAliasTestCase): action_alias_name = "pack_show" def test_alias_pack_show(self): - format_string = self.action_alias_db.formats[0]['representation'][0] + format_string = self.action_alias_db.formats[0]["representation"][0] format_strings = self.action_alias_db.get_format_strings() command = "pack show st2" - expected_parameters = { - 'pack': "st2" - } + expected_parameters = {"pack": "st2"} - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) self.assertCommandMatchesExactlyOneFormatString( - format_strings=format_strings, - command=command) + format_strings=format_strings, command=command + ) diff --git a/contrib/packs/tests/test_action_download.py b/contrib/packs/tests/test_action_download.py index 3eeda008868..c29e95fccc2 100644 --- a/contrib/packs/tests/test_action_download.py +++ b/contrib/packs/tests/test_action_download.py @@ -22,6 +22,7 @@ import hashlib from st2common.util.monkey_patch import use_select_poll_workaround + use_select_poll_workaround() from lockfile import LockFile @@ -46,7 +47,7 @@ "author": "st2-dev", "keywords": ["some", "search", "another", "terms"], "email": "info@stackstorm.com", - "description": "st2 pack to test package management pipeline" + "description": "st2 pack to test package management pipeline", }, "test2": { "version": "0.5.0", @@ -55,7 +56,7 @@ "author": "stanley", "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", - "description": "another st2 pack to test package management pipeline" + "description": "another st2 pack to test package management pipeline", }, "test3": { "version": "0.5.0", @@ -65,16 +66,17 @@ "author": "stanley", "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", - "description": "another st2 pack to test package management pipeline" + "description": "another st2 pack to test package management pipeline", }, "test4": { "version": "0.5.0", "name": "test4", "repo_url": "https://github.com/StackStorm-Exchange/stackstorm-test4", "author": "stanley", - "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", - "description": "another st2 pack to test package management pipeline" - } + "keywords": ["some", "special", "terms"], + "email": "info@stackstorm.com", + "description": "another st2 pack to test package management pipeline", + }, } @@ -85,7 +87,7 @@ def mock_is_dir_func(path): """ Mock function which returns True if path ends with .git """ - if path.endswith('.git'): + if path.endswith(".git"): return True return original_is_dir_func(path) @@ -95,9 +97,9 @@ def mock_get_gitref(repo, ref): Mock get_gitref function which return mocked object if ref passed is PACK_INDEX['test']['version'] """ - if PACK_INDEX['test']['version'] in ref: - if ref[0] == 'v': - return mock.MagicMock(hexsha=PACK_INDEX['test']['version']) + if PACK_INDEX["test"]["version"] in ref: + if ref[0] == "v": + return mock.MagicMock(hexsha=PACK_INDEX["test"]["version"]) else: return None elif ref: @@ -106,21 +108,24 @@ def mock_get_gitref(repo, ref): return None -@mock.patch.object(pack_service, 'fetch_pack_index', mock.MagicMock(return_value=(PACK_INDEX, {}))) +@mock.patch.object( + pack_service, "fetch_pack_index", mock.MagicMock(return_value=(PACK_INDEX, {})) +) class DownloadGitRepoActionTestCase(BaseActionTestCase): action_cls = DownloadGitRepoAction def setUp(self): super(DownloadGitRepoActionTestCase, self).setUp() - clone_from = mock.patch.object(Repo, 'clone_from') + clone_from = mock.patch.object(Repo, "clone_from") self.addCleanup(clone_from.stop) self.clone_from = clone_from.start() self.expand_user_path = tempfile.mkdtemp() - expand_user = mock.patch.object(os.path, 'expanduser', - mock.MagicMock(return_value=self.expand_user_path)) + expand_user = mock.patch.object( + os.path, "expanduser", mock.MagicMock(return_value=self.expand_user_path) + ) self.addCleanup(expand_user.stop) self.expand_user = expand_user.start() @@ -132,8 +137,10 @@ def setUp(self): def side_effect(url, to_path, **kwargs): # Since we have no way to pass pack name here, we would have to derive it from repo url - fixture_name = url.split('/')[-1] - fixture_path = os.path.join(self._get_base_pack_path(), 'tests/fixtures', fixture_name) + fixture_name = url.split("/")[-1] + fixture_path = os.path.join( + self._get_base_pack_path(), "tests/fixtures", fixture_name + ) shutil.copytree(fixture_path, to_path) return self.repo_instance @@ -145,13 +152,15 @@ def tearDown(self): def test_run_pack_download(self): action = self.get_action_instance() - result = action.run(packs=['test'], abs_repo_base=self.repo_base) - temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest() + result = action.run(packs=["test"], abs_repo_base=self.repo_base) + temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest() - self.assertEqual(result, {'test': 'Success.'}) - self.clone_from.assert_called_once_with(PACK_INDEX['test']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dir)) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml'))) + self.assertEqual(result, {"test": "Success."}) + self.clone_from.assert_called_once_with( + PACK_INDEX["test"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dir), + ) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml"))) self.repo_instance.git.checkout.assert_called() self.repo_instance.git.branch.assert_called() @@ -159,65 +168,81 @@ def test_run_pack_download(self): def test_run_pack_download_dependencies(self): action = self.get_action_instance() - result = action.run(packs=['test'], dependency_list=['test2', 'test4'], - abs_repo_base=self.repo_base) + result = action.run( + packs=["test"], + dependency_list=["test2", "test4"], + abs_repo_base=self.repo_base, + ) temp_dirs = [ - hashlib.md5(PACK_INDEX['test2']['repo_url'].encode()).hexdigest(), - hashlib.md5(PACK_INDEX['test4']['repo_url'].encode()).hexdigest() + hashlib.md5(PACK_INDEX["test2"]["repo_url"].encode()).hexdigest(), + hashlib.md5(PACK_INDEX["test4"]["repo_url"].encode()).hexdigest(), ] - self.assertEqual(result, {'test2': 'Success.', 'test4': 'Success.'}) - self.clone_from.assert_any_call(PACK_INDEX['test2']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dirs[0])) - self.clone_from.assert_any_call(PACK_INDEX['test4']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dirs[1])) + self.assertEqual(result, {"test2": "Success.", "test4": "Success."}) + self.clone_from.assert_any_call( + PACK_INDEX["test2"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dirs[0]), + ) + self.clone_from.assert_any_call( + PACK_INDEX["test4"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dirs[1]), + ) self.assertEqual(self.clone_from.call_count, 2) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test2/pack.yaml'))) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test4/pack.yaml'))) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test2/pack.yaml"))) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test4/pack.yaml"))) def test_run_pack_download_existing_pack(self): action = self.get_action_instance() - action.run(packs=['test'], abs_repo_base=self.repo_base) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml'))) + action.run(packs=["test"], abs_repo_base=self.repo_base) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml"))) - result = action.run(packs=['test'], abs_repo_base=self.repo_base) + result = action.run(packs=["test"], abs_repo_base=self.repo_base) - self.assertEqual(result, {'test': 'Success.'}) + self.assertEqual(result, {"test": "Success."}) def test_run_pack_download_multiple_packs(self): action = self.get_action_instance() - result = action.run(packs=['test', 'test2'], abs_repo_base=self.repo_base) + result = action.run(packs=["test", "test2"], abs_repo_base=self.repo_base) temp_dirs = [ - hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest(), - hashlib.md5(PACK_INDEX['test2']['repo_url'].encode()).hexdigest() + hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest(), + hashlib.md5(PACK_INDEX["test2"]["repo_url"].encode()).hexdigest(), ] - self.assertEqual(result, {'test': 'Success.', 'test2': 'Success.'}) - self.clone_from.assert_any_call(PACK_INDEX['test']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dirs[0])) - self.clone_from.assert_any_call(PACK_INDEX['test2']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dirs[1])) + self.assertEqual(result, {"test": "Success.", "test2": "Success."}) + self.clone_from.assert_any_call( + PACK_INDEX["test"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dirs[0]), + ) + self.clone_from.assert_any_call( + PACK_INDEX["test2"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dirs[1]), + ) self.assertEqual(self.clone_from.call_count, 2) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml'))) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test2/pack.yaml'))) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml"))) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test2/pack.yaml"))) - @mock.patch.object(Repo, 'clone_from') + @mock.patch.object(Repo, "clone_from") def test_run_pack_download_error(self, clone_from): - clone_from.side_effect = Exception('Something went terribly wrong during the clone') + clone_from.side_effect = Exception( + "Something went terribly wrong during the clone" + ) action = self.get_action_instance() - self.assertRaises(Exception, action.run, packs=['test'], abs_repo_base=self.repo_base) + self.assertRaises( + Exception, action.run, packs=["test"], abs_repo_base=self.repo_base + ) def test_run_pack_download_no_tag(self): self.repo_instance.commit.side_effect = BadName action = self.get_action_instance() - self.assertRaises(ValueError, action.run, packs=['test=1.2.3'], - abs_repo_base=self.repo_base) + self.assertRaises( + ValueError, action.run, packs=["test=1.2.3"], abs_repo_base=self.repo_base + ) def test_run_pack_lock_is_already_acquired(self): action = self.get_action_instance() - temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest() + temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest() original_acquire = LockFile.acquire @@ -227,15 +252,20 @@ def mock_acquire(self, timeout=None): LockFile.acquire = mock_acquire try: - lock_file = LockFile('/tmp/%s' % (temp_dir)) + lock_file = LockFile("/tmp/%s" % (temp_dir)) # Acquire a lock (file) so acquire inside download will fail - with open(lock_file.lock_file, 'w') as fp: - fp.write('') - - expected_msg = 'Timeout waiting to acquire lock for' - self.assertRaisesRegexp(LockTimeout, expected_msg, action.run, packs=['test'], - abs_repo_base=self.repo_base) + with open(lock_file.lock_file, "w") as fp: + fp.write("") + + expected_msg = "Timeout waiting to acquire lock for" + self.assertRaisesRegexp( + LockTimeout, + expected_msg, + action.run, + packs=["test"], + abs_repo_base=self.repo_base, + ) finally: os.unlink(lock_file.lock_file) LockFile.acquire = original_acquire @@ -243,7 +273,7 @@ def mock_acquire(self, timeout=None): def test_run_pack_lock_is_already_acquired_force_flag(self): # Lock is already acquired but force is true so it should be deleted and released action = self.get_action_instance() - temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest() + temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest() original_acquire = LockFile.acquire @@ -253,194 +283,266 @@ def mock_acquire(self, timeout=None): LockFile.acquire = mock_acquire try: - lock_file = LockFile('/tmp/%s' % (temp_dir)) + lock_file = LockFile("/tmp/%s" % (temp_dir)) # Acquire a lock (file) so acquire inside download will fail - with open(lock_file.lock_file, 'w') as fp: - fp.write('') + with open(lock_file.lock_file, "w") as fp: + fp.write("") - result = action.run(packs=['test'], abs_repo_base=self.repo_base, force=True) + result = action.run( + packs=["test"], abs_repo_base=self.repo_base, force=True + ) finally: LockFile.acquire = original_acquire - self.assertEqual(result, {'test': 'Success.'}) + self.assertEqual(result, {"test": "Success."}) def test_run_pack_download_v_tag(self): def side_effect(ref): - if ref[0] != 'v': + if ref[0] != "v": raise BadName() - return mock.MagicMock(hexsha='abcdef') + return mock.MagicMock(hexsha="abcdef") self.repo_instance.commit.side_effect = side_effect self.repo_instance.git = mock.MagicMock( - branch=(lambda *args: 'master'), - checkout=(lambda *args: True) + branch=(lambda *args: "master"), checkout=(lambda *args: True) ) action = self.get_action_instance() - result = action.run(packs=['test=1.2.3'], abs_repo_base=self.repo_base) + result = action.run(packs=["test=1.2.3"], abs_repo_base=self.repo_base) - self.assertEqual(result, {'test': 'Success.'}) + self.assertEqual(result, {"test": "Success."}) - @mock.patch.object(st2common.util.pack_management, 'get_valid_versions_for_repo', - mock.Mock(return_value=['1.0.0', '2.0.0'])) + @mock.patch.object( + st2common.util.pack_management, + "get_valid_versions_for_repo", + mock.Mock(return_value=["1.0.0", "2.0.0"]), + ) def test_run_pack_download_invalid_version(self): self.repo_instance.commit.side_effect = lambda ref: None action = self.get_action_instance() - expected_msg = ('is not a valid version, hash, tag or branch.*?' - 'Available versions are: 1.0.0, 2.0.0.') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, - packs=['test=2.2.3'], abs_repo_base=self.repo_base) + expected_msg = ( + "is not a valid version, hash, tag or branch.*?" + "Available versions are: 1.0.0, 2.0.0." + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test=2.2.3"], + abs_repo_base=self.repo_base, + ) def test_download_pack_stackstorm_version_identifier_check(self): action = self.get_action_instance() # Version is satisfied - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '2.0.0' + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "2.0.0" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base) - self.assertEqual(result['test3'], 'Success.') + result = action.run(packs=["test3"], abs_repo_base=self.repo_base) + self.assertEqual(result["test3"], "Success.") # Pack requires a version which is not satisfied by current StackStorm version - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '2.2.0' - expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' - 'current version is "2.2.0"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'], - abs_repo_base=self.repo_base) - - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '2.3.0' - expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' - 'current version is "2.3.0"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'], - abs_repo_base=self.repo_base) - - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '1.5.9' - expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' - 'current version is "1.5.9"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'], - abs_repo_base=self.repo_base) - - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '1.5.0' - expected_msg = ('Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' - 'current version is "1.5.0"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, packs=['test3'], - abs_repo_base=self.repo_base) + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "2.2.0" + expected_msg = ( + 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' + 'current version is "2.2.0"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + ) + + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "2.3.0" + expected_msg = ( + 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' + 'current version is "2.3.0"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + ) + + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "1.5.9" + expected_msg = ( + 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' + 'current version is "1.5.9"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + ) + + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "1.5.0" + expected_msg = ( + 'Pack "test3" requires StackStorm ">=1.6.0, <2.2.0", but ' + 'current version is "1.5.0"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + ) # Version is not met, but force=true parameter is provided - st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = '1.5.0' - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=True) - self.assertEqual(result['test3'], 'Success.') + st2common.util.pack_management.CURRENT_STACKSTORM_VERSION = "1.5.0" + result = action.run(packs=["test3"], abs_repo_base=self.repo_base, force=True) + self.assertEqual(result["test3"], "Success.") def test_download_pack_python_version_check(self): action = self.get_action_instance() # No python_versions attribute specified in the metadata file - with mock.patch('st2common.util.pack_management.get_pack_metadata') as \ - mock_get_pack_metadata: + with mock.patch( + "st2common.util.pack_management.get_pack_metadata" + ) as mock_get_pack_metadata: mock_get_pack_metadata.return_value = { - 'name': 'test3', - 'stackstorm_version': '', - 'python_versions': [] + "name": "test3", + "stackstorm_version": "", + "python_versions": [], } st2common.util.pack_management.six.PY2 = True st2common.util.pack_management.six.PY3 = False - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.11' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.11" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False) - self.assertEqual(result['test3'], 'Success.') + result = action.run( + packs=["test3"], abs_repo_base=self.repo_base, force=False + ) + self.assertEqual(result["test3"], "Success.") # Pack works with Python 2.x installation is running 2.7 - with mock.patch('st2common.util.pack_management.get_pack_metadata') as \ - mock_get_pack_metadata: + with mock.patch( + "st2common.util.pack_management.get_pack_metadata" + ) as mock_get_pack_metadata: mock_get_pack_metadata.return_value = { - 'name': 'test3', - 'stackstorm_version': '', - 'python_versions': ['2'] + "name": "test3", + "stackstorm_version": "", + "python_versions": ["2"], } st2common.util.pack_management.six.PY2 = True st2common.util.pack_management.six.PY3 = False - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.5' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.5" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False) - self.assertEqual(result['test3'], 'Success.') + result = action.run( + packs=["test3"], abs_repo_base=self.repo_base, force=False + ) + self.assertEqual(result["test3"], "Success.") - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.12' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.12" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False) - self.assertEqual(result['test3'], 'Success.') + result = action.run( + packs=["test3"], abs_repo_base=self.repo_base, force=False + ) + self.assertEqual(result["test3"], "Success.") # Pack works with Python 2.x installation is running 3.5 - with mock.patch('st2common.util.pack_management.get_pack_metadata') as \ - mock_get_pack_metadata: + with mock.patch( + "st2common.util.pack_management.get_pack_metadata" + ) as mock_get_pack_metadata: mock_get_pack_metadata.return_value = { - 'name': 'test3', - 'stackstorm_version': '', - 'python_versions': ['2'] + "name": "test3", + "stackstorm_version": "", + "python_versions": ["2"], } st2common.util.pack_management.six.PY2 = False st2common.util.pack_management.six.PY3 = True - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '3.5.2' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "3.5.2" - expected_msg = (r'Pack "test3" requires Python 2.x, but current Python version is ' - '"3.5.2"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, - packs=['test3'], abs_repo_base=self.repo_base, force=False) + expected_msg = ( + r'Pack "test3" requires Python 2.x, but current Python version is ' + '"3.5.2"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + force=False, + ) # Pack works with Python 3.x installation is running 2.7 - with mock.patch('st2common.util.pack_management.get_pack_metadata') as \ - mock_get_pack_metadata: + with mock.patch( + "st2common.util.pack_management.get_pack_metadata" + ) as mock_get_pack_metadata: mock_get_pack_metadata.return_value = { - 'name': 'test3', - 'stackstorm_version': '', - 'python_versions': ['3'] + "name": "test3", + "stackstorm_version": "", + "python_versions": ["3"], } st2common.util.pack_management.six.PY2 = True st2common.util.pack_management.six.PY3 = False - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.2' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.2" - expected_msg = (r'Pack "test3" requires Python 3.x, but current Python version is ' - '"2.7.2"') - self.assertRaisesRegexp(ValueError, expected_msg, action.run, - packs=['test3'], abs_repo_base=self.repo_base, force=False) + expected_msg = ( + r'Pack "test3" requires Python 3.x, but current Python version is ' + '"2.7.2"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["test3"], + abs_repo_base=self.repo_base, + force=False, + ) # Pack works with Python 2.x and 3.x installation is running 2.7 and 3.6.1 - with mock.patch('st2common.util.pack_management.get_pack_metadata') as \ - mock_get_pack_metadata: + with mock.patch( + "st2common.util.pack_management.get_pack_metadata" + ) as mock_get_pack_metadata: mock_get_pack_metadata.return_value = { - 'name': 'test3', - 'stackstorm_version': '', - 'python_versions': ['2', '3'] + "name": "test3", + "stackstorm_version": "", + "python_versions": ["2", "3"], } st2common.util.pack_management.six.PY2 = True st2common.util.pack_management.six.PY3 = False - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '2.7.5' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "2.7.5" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False) - self.assertEqual(result['test3'], 'Success.') + result = action.run( + packs=["test3"], abs_repo_base=self.repo_base, force=False + ) + self.assertEqual(result["test3"], "Success.") st2common.util.pack_management.six.PY2 = False st2common.util.pack_management.six.PY3 = True - st2common.util.pack_management.CURRENT_PYTHON_VERSION = '3.6.1' + st2common.util.pack_management.CURRENT_PYTHON_VERSION = "3.6.1" - result = action.run(packs=['test3'], abs_repo_base=self.repo_base, force=False) - self.assertEqual(result['test3'], 'Success.') + result = action.run( + packs=["test3"], abs_repo_base=self.repo_base, force=False + ) + self.assertEqual(result["test3"], "Success.") def test_resolve_urls(self): - url = eval_repo_url( - "https://github.com/StackStorm-Exchange/stackstorm-test") + url = eval_repo_url("https://github.com/StackStorm-Exchange/stackstorm-test") self.assertEqual(url, "https://github.com/StackStorm-Exchange/stackstorm-test") url = eval_repo_url( - "https://github.com/StackStorm-Exchange/stackstorm-test.git") - self.assertEqual(url, "https://github.com/StackStorm-Exchange/stackstorm-test.git") + "https://github.com/StackStorm-Exchange/stackstorm-test.git" + ) + self.assertEqual( + url, "https://github.com/StackStorm-Exchange/stackstorm-test.git" + ) url = eval_repo_url("StackStorm-Exchange/stackstorm-test") self.assertEqual(url, "https://github.com/StackStorm-Exchange/stackstorm-test") @@ -460,11 +562,11 @@ def test_resolve_urls(self): url = eval_repo_url("file://localhost/home/vagrant/stackstorm-test") self.assertEqual(url, "file://localhost/home/vagrant/stackstorm-test") - url = eval_repo_url('ssh:///AutomationStackStorm') - self.assertEqual(url, 'ssh:///AutomationStackStorm') + url = eval_repo_url("ssh:///AutomationStackStorm") + self.assertEqual(url, "ssh:///AutomationStackStorm") - url = eval_repo_url('ssh://joe@local/AutomationStackStorm') - self.assertEqual(url, 'ssh://joe@local/AutomationStackStorm') + url = eval_repo_url("ssh://joe@local/AutomationStackStorm") + self.assertEqual(url, "ssh://joe@local/AutomationStackStorm") def test_run_pack_download_edge_cases(self): """ @@ -479,36 +581,35 @@ def test_run_pack_download_edge_cases(self): """ def side_effect(ref): - if ref[0] != 'v': + if ref[0] != "v": raise BadName() - return mock.MagicMock(hexsha='abcdeF') + return mock.MagicMock(hexsha="abcdeF") self.repo_instance.commit.side_effect = side_effect edge_cases = [ - ('master', '1.2.3'), - ('master', 'some-branch'), - ('master', 'default-branch'), - ('master', None), - ('default-branch', '1.2.3'), - ('default-branch', 'some-branch'), - ('default-branch', 'default-branch'), - ('default-branch', None) + ("master", "1.2.3"), + ("master", "some-branch"), + ("master", "default-branch"), + ("master", None), + ("default-branch", "1.2.3"), + ("default-branch", "some-branch"), + ("default-branch", "default-branch"), + ("default-branch", None), ] for default_branch, ref in edge_cases: self.repo_instance.git = mock.MagicMock( - branch=(lambda *args: default_branch), - checkout=(lambda *args: True) + branch=(lambda *args: default_branch), checkout=(lambda *args: True) ) # Set default branch self.repo_instance.active_branch.name = default_branch - self.repo_instance.active_branch.object = 'aBcdef' - self.repo_instance.head.commit = 'aBcdef' + self.repo_instance.active_branch.object = "aBcdef" + self.repo_instance.head.commit = "aBcdef" # Fake gitref object - gitref = mock.MagicMock(hexsha='abcDef') + gitref = mock.MagicMock(hexsha="abcDef") # Fool _get_gitref into working when its ref == our ref def fake_commit(arg_ref): @@ -516,30 +617,34 @@ def fake_commit(arg_ref): return gitref else: raise BadName() + self.repo_instance.commit = fake_commit self.repo_instance.active_branch.object = gitref action = self.get_action_instance() if ref: - packs = ['test=%s' % (ref)] + packs = ["test=%s" % (ref)] else: - packs = ['test'] + packs = ["test"] result = action.run(packs=packs, abs_repo_base=self.repo_base) - self.assertEqual(result, {'test': 'Success.'}) + self.assertEqual(result, {"test": "Success."}) - @mock.patch('os.path.isdir', mock_is_dir_func) + @mock.patch("os.path.isdir", mock_is_dir_func) def test_run_pack_dowload_local_git_repo_detached_head_state(self): action = self.get_action_instance() - type(self.repo_instance).active_branch = \ - mock.PropertyMock(side_effect=TypeError('detached head')) + type(self.repo_instance).active_branch = mock.PropertyMock( + side_effect=TypeError("detached head") + ) - pack_path = os.path.join(BASE_DIR, 'fixtures/stackstorm-test') + pack_path = os.path.join(BASE_DIR, "fixtures/stackstorm-test") - result = action.run(packs=['file://%s' % (pack_path)], abs_repo_base=self.repo_base) - self.assertEqual(result, {'test': 'Success.'}) + result = action.run( + packs=["file://%s" % (pack_path)], abs_repo_base=self.repo_base + ) + self.assertEqual(result, {"test": "Success."}) # Verify function has bailed out early self.repo_instance.git.checkout.assert_not_called() @@ -551,41 +656,55 @@ def test_run_pack_download_local_directory(self): # 1. Local directory doesn't exist expected_msg = r'Local pack directory ".*" doesn\'t exist' - self.assertRaisesRegexp(ValueError, expected_msg, action.run, - packs=['file://doesnt_exist'], abs_repo_base=self.repo_base) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action.run, + packs=["file://doesnt_exist"], + abs_repo_base=self.repo_base, + ) # 2. Local pack which is not a git repository - pack_path = os.path.join(BASE_DIR, 'fixtures/stackstorm-test4') + pack_path = os.path.join(BASE_DIR, "fixtures/stackstorm-test4") - result = action.run(packs=['file://%s' % (pack_path)], abs_repo_base=self.repo_base) - self.assertEqual(result, {'test4': 'Success.'}) + result = action.run( + packs=["file://%s" % (pack_path)], abs_repo_base=self.repo_base + ) + self.assertEqual(result, {"test4": "Success."}) # Verify pack contents have been copied over - destination_path = os.path.join(self.repo_base, 'test4') + destination_path = os.path.join(self.repo_base, "test4") self.assertTrue(os.path.exists(destination_path)) - self.assertTrue(os.path.exists(os.path.join(destination_path, 'pack.yaml'))) + self.assertTrue(os.path.exists(os.path.join(destination_path, "pack.yaml"))) - @mock.patch('st2common.util.pack_management.get_gitref', mock_get_gitref) + @mock.patch("st2common.util.pack_management.get_gitref", mock_get_gitref) def test_run_pack_download_with_tag(self): action = self.get_action_instance() - result = action.run(packs=['test'], abs_repo_base=self.repo_base) - temp_dir = hashlib.md5(PACK_INDEX['test']['repo_url'].encode()).hexdigest() + result = action.run(packs=["test"], abs_repo_base=self.repo_base) + temp_dir = hashlib.md5(PACK_INDEX["test"]["repo_url"].encode()).hexdigest() - self.assertEqual(result, {'test': 'Success.'}) - self.clone_from.assert_called_once_with(PACK_INDEX['test']['repo_url'], - os.path.join(os.path.expanduser('~'), temp_dir)) - self.assertTrue(os.path.isfile(os.path.join(self.repo_base, 'test/pack.yaml'))) + self.assertEqual(result, {"test": "Success."}) + self.clone_from.assert_called_once_with( + PACK_INDEX["test"]["repo_url"], + os.path.join(os.path.expanduser("~"), temp_dir), + ) + self.assertTrue(os.path.isfile(os.path.join(self.repo_base, "test/pack.yaml"))) # Check repo.git.checkout is called three times self.assertEqual(self.repo_instance.git.checkout.call_count, 3) # Check repo.git.checkout called with latest tag or branch - self.assertEqual(PACK_INDEX['test']['version'], - self.repo_instance.git.checkout.call_args_list[1][0][0]) + self.assertEqual( + PACK_INDEX["test"]["version"], + self.repo_instance.git.checkout.call_args_list[1][0][0], + ) # Check repo.git.checkout called with head - self.assertEqual(self.repo_instance.head.reference, - self.repo_instance.git.checkout.call_args_list[2][0][0]) + self.assertEqual( + self.repo_instance.head.reference, + self.repo_instance.git.checkout.call_args_list[2][0][0], + ) self.repo_instance.git.branch.assert_called_with( - '-f', self.repo_instance.head.reference, PACK_INDEX['test']['version']) + "-f", self.repo_instance.head.reference, PACK_INDEX["test"]["version"] + ) diff --git a/contrib/packs/tests/test_action_unload.py b/contrib/packs/tests/test_action_unload.py index 5e642483d4c..fc07ff87c3c 100644 --- a/contrib/packs/tests/test_action_unload.py +++ b/contrib/packs/tests/test_action_unload.py @@ -20,6 +20,7 @@ from oslo_config import cfg from st2common.util.monkey_patch import use_select_poll_workaround + use_select_poll_workaround() from st2common.content.bootstrap import register_content @@ -39,11 +40,11 @@ from pack_mgmt.unload import UnregisterPackAction -__all__ = [ - 'UnloadActionTestCase' -] +__all__ = ["UnloadActionTestCase"] -PACK_PATH_1 = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_1') +PACK_PATH_1 = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_1" +) class UnloadActionTestCase(BaseActionTestCase, CleanDbTestCase): @@ -64,13 +65,15 @@ def setUp(self): # Register the pack with all the content # TODO: Don't use pack cache - cfg.CONF.set_override(name='all', override=True, group='register') - cfg.CONF.set_override(name='pack', override=PACK_PATH_1, group='register') - cfg.CONF.set_override(name='no_fail_on_failure', override=True, group='register') + cfg.CONF.set_override(name="all", override=True, group="register") + cfg.CONF.set_override(name="pack", override=PACK_PATH_1, group="register") + cfg.CONF.set_override( + name="no_fail_on_failure", override=True, group="register" + ) register_content() def test_run(self): - pack = 'dummy_pack_1' + pack = "dummy_pack_1" # Verify all the resources are there pack_dbs = Pack.query(ref=pack) diff --git a/contrib/packs/tests/test_get_pack_dependencies.py b/contrib/packs/tests/test_get_pack_dependencies.py index e047d7fca48..a90f9406386 100644 --- a/contrib/packs/tests/test_get_pack_dependencies.py +++ b/contrib/packs/tests/test_get_pack_dependencies.py @@ -21,21 +21,20 @@ from pack_mgmt.get_pack_dependencies import GetPackDependencies -UNINSTALLED_PACK = 'uninstalled_pack' +UNINSTALLED_PACK = "uninstalled_pack" UNINSTALLED_PACKS = [ UNINSTALLED_PACK, - 'https://github.com/StackStorm-Exchange/stackstorm-pack1', - 'https://github.com/StackStorm-Exchange/stackstorm-pack2.git', - 'https://github.com/StackStorm-Exchange/stackstorm-pack3.git=v2.1.1', - 'StackStorm-Exchange/stackstorm-pack4', - 'git://StackStorm-Exchange/stackstorm-pack5=v2.1.1', - 'git://StackStorm-Exchange/stackstorm-pack6.git', - 'git@github.com:foo/pack7.git' - 'git@github.com:foo/pack8.git=v3.2.1', - 'file:///home/vagrant/stackstorm-pack9', - 'file://localhost/home/vagrant/stackstorm-pack10', - 'ssh:///AutomationStackStorm11', - 'ssh://joe@local/AutomationStackStorm12' + "https://github.com/StackStorm-Exchange/stackstorm-pack1", + "https://github.com/StackStorm-Exchange/stackstorm-pack2.git", + "https://github.com/StackStorm-Exchange/stackstorm-pack3.git=v2.1.1", + "StackStorm-Exchange/stackstorm-pack4", + "git://StackStorm-Exchange/stackstorm-pack5=v2.1.1", + "git://StackStorm-Exchange/stackstorm-pack6.git", + "git@github.com:foo/pack7.git" "git@github.com:foo/pack8.git=v3.2.1", + "file:///home/vagrant/stackstorm-pack9", + "file://localhost/home/vagrant/stackstorm-pack10", + "ssh:///AutomationStackStorm11", + "ssh://joe@local/AutomationStackStorm12", ] DOWNLOADED_OR_INSTALLED_PACK_METAdATA = { @@ -58,7 +57,7 @@ "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "dependencies": ['uninstalled_pack', 'no_dependencies'] + "dependencies": ["uninstalled_pack", "no_dependencies"], }, # List of uninstalled dependency packs. "test3": { @@ -70,7 +69,7 @@ "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "dependencies": UNINSTALLED_PACKS + "dependencies": UNINSTALLED_PACKS, }, # One conflict pack with existing pack. "test4": { @@ -82,9 +81,7 @@ "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "dependencies": [ - "test2=v0.4.0" - ] + "dependencies": ["test2=v0.4.0"], }, # One uninstalled conflict pack. "test5": { @@ -93,9 +90,10 @@ "name": "test4", "repo_url": "https://github.com/StackStorm-Exchange/stackstorm-test4", "author": "stanley", - "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", + "keywords": ["some", "special", "terms"], + "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "dependencies": ["uninstalled_pack=v0.4.0"] + "dependencies": ["uninstalled_pack=v0.4.0"], }, # One dependency pack without version. It is not checked against conflict. "test6": { @@ -104,10 +102,11 @@ "name": "test4", "repo_url": "https://github.com/StackStorm-Exchange/stackstorm-test4", "author": "stanley", - "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", + "keywords": ["some", "special", "terms"], + "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "dependencies": ["test2"] - } + "dependencies": ["test2"], + }, } @@ -119,7 +118,7 @@ def mock_get_dependency_list(pack): if pack in DOWNLOADED_OR_INSTALLED_PACK_METAdATA: metadata = DOWNLOADED_OR_INSTALLED_PACK_METAdATA[pack] - dependencies = metadata.get('dependencies', None) + dependencies = metadata.get("dependencies", None) return dependencies @@ -132,13 +131,15 @@ def mock_get_pack_version(pack): if pack in DOWNLOADED_OR_INSTALLED_PACK_METAdATA: metadata = DOWNLOADED_OR_INSTALLED_PACK_METAdATA[pack] - version = metadata.get('version', None) + version = metadata.get("version", None) return version -@mock.patch('pack_mgmt.get_pack_dependencies.get_dependency_list', mock_get_dependency_list) -@mock.patch('pack_mgmt.get_pack_dependencies.get_pack_version', mock_get_pack_version) +@mock.patch( + "pack_mgmt.get_pack_dependencies.get_dependency_list", mock_get_dependency_list +) +@mock.patch("pack_mgmt.get_pack_dependencies.get_pack_version", mock_get_pack_version) class GetPackDependenciesTestCase(BaseActionTestCase): action_cls = GetPackDependencies @@ -167,9 +168,9 @@ def test_run_get_pack_dependencies_with_failed_packs_status(self): nested = 2 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], []) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], []) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_failed_and_succeeded_packs_status(self): action = self.get_action_instance() @@ -177,9 +178,9 @@ def test_run_get_pack_dependencies_with_failed_and_succeeded_packs_status(self): nested = 2 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK]) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK]) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_no_dependency(self): action = self.get_action_instance() @@ -187,9 +188,9 @@ def test_run_get_pack_dependencies_with_no_dependency(self): nested = 3 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], []) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], []) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_dependency(self): action = self.get_action_instance() @@ -197,9 +198,9 @@ def test_run_get_pack_dependencies_with_dependency(self): nested = 1 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK]) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK]) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_dependencies(self): action = self.get_action_instance() @@ -207,9 +208,9 @@ def test_run_get_pack_dependencies_with_dependencies(self): nested = 1 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], UNINSTALLED_PACKS) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], UNINSTALLED_PACKS) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_existing_pack_conflict(self): action = self.get_action_instance() @@ -217,9 +218,9 @@ def test_run_get_pack_dependencies_with_existing_pack_conflict(self): nested = 1 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK]) - self.assertEqual(result['conflict_list'], ['test2=v0.4.0']) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK]) + self.assertEqual(result["conflict_list"], ["test2=v0.4.0"]) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_dependency_conflict(self): action = self.get_action_instance() @@ -227,9 +228,9 @@ def test_run_get_pack_dependencies_with_dependency_conflict(self): nested = 1 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], ['uninstalled_pack']) - self.assertEqual(result['conflict_list'], ['uninstalled_pack=v0.4.0']) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], ["uninstalled_pack"]) + self.assertEqual(result["conflict_list"], ["uninstalled_pack=v0.4.0"]) + self.assertEqual(result["nested"], nested - 1) def test_run_get_pack_dependencies_with_no_version(self): action = self.get_action_instance() @@ -237,6 +238,6 @@ def test_run_get_pack_dependencies_with_no_version(self): nested = 1 result = action.run(packs_status=packs_status, nested=nested) - self.assertEqual(result['dependency_list'], [UNINSTALLED_PACK]) - self.assertEqual(result['conflict_list'], []) - self.assertEqual(result['nested'], nested - 1) + self.assertEqual(result["dependency_list"], [UNINSTALLED_PACK]) + self.assertEqual(result["conflict_list"], []) + self.assertEqual(result["nested"], nested - 1) diff --git a/contrib/packs/tests/test_get_pack_warnings.py b/contrib/packs/tests/test_get_pack_warnings.py index 49e2d920a88..3eac7ba3562 100644 --- a/contrib/packs/tests/test_get_pack_warnings.py +++ b/contrib/packs/tests/test_get_pack_warnings.py @@ -29,7 +29,7 @@ "keywords": ["some", "search", "another", "terms"], "email": "info@stackstorm.com", "description": "st2 pack to test package management pipeline", - "python_versions": ["2","3"], + "python_versions": ["2", "3"], }, # Python 3 "py3": { @@ -72,10 +72,11 @@ "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", "description": "another st2 pack to test package management pipeline", - "python_versions": ["2"] - } + "python_versions": ["2"], + }, } + def mock_get_pack_basepath(pack): """ Mock get_pack_basepath function which just returns pack n ame @@ -94,8 +95,8 @@ def mock_get_pack_metadata(pack_dir): return metadata -@mock.patch('pack_mgmt.get_pack_warnings.get_pack_base_path', mock_get_pack_basepath) -@mock.patch('pack_mgmt.get_pack_warnings.get_pack_metadata', mock_get_pack_metadata) +@mock.patch("pack_mgmt.get_pack_warnings.get_pack_base_path", mock_get_pack_basepath) +@mock.patch("pack_mgmt.get_pack_warnings.get_pack_metadata", mock_get_pack_metadata) class GetPackWarningsTestCase(BaseActionTestCase): action_cls = GetPackWarnings @@ -107,15 +108,15 @@ def test_run_get_pack_warnings_py3_pack(self): packs_status = {"py3": "Success."} result = action.run(packs_status=packs_status) - self.assertEqual(result['warning_list'], []) + self.assertEqual(result["warning_list"], []) def test_run_get_pack_warnings_py2_pack(self): action = self.get_action_instance() packs_status = {"py2": "Success."} result = action.run(packs_status=packs_status) - self.assertEqual(len(result['warning_list']), 1) - warning = result['warning_list'][0] + self.assertEqual(len(result["warning_list"]), 1) + warning = result["warning_list"][0] self.assertTrue("DEPRECATION WARNING" in warning) self.assertTrue("Pack py2 only supports Python 2" in warning) @@ -124,28 +125,32 @@ def test_run_get_pack_warnings_py23_pack(self): packs_status = {"py23": "Success."} result = action.run(packs_status=packs_status) - self.assertEqual(result['warning_list'], []) + self.assertEqual(result["warning_list"], []) def test_run_get_pack_warnings_pynone_pack(self): action = self.get_action_instance() packs_status = {"pynone": "Success."} result = action.run(packs_status=packs_status) - self.assertEqual(result['warning_list'], []) + self.assertEqual(result["warning_list"], []) def test_run_get_pack_warnings_multiple_pack(self): action = self.get_action_instance() - packs_status = {"py2": "Success.", - "py23": "Success.", - "py22": "Success."} + packs_status = {"py2": "Success.", "py23": "Success.", "py22": "Success."} result = action.run(packs_status=packs_status) - self.assertEqual(len(result['warning_list']), 2) - warning0 = result['warning_list'][0] - warning1 = result['warning_list'][1] + self.assertEqual(len(result["warning_list"]), 2) + warning0 = result["warning_list"][0] + warning1 = result["warning_list"][1] self.assertTrue("DEPRECATION WARNING" in warning0) self.assertTrue("DEPRECATION WARNING" in warning1) - self.assertTrue(("Pack py2 only supports Python 2" in warning0 and - "Pack py22 only supports Python 2" in warning1) or - ("Pack py22 only supports Python 2" in warning0 and - "Pack py2 only supports Python 2" in warning1)) + self.assertTrue( + ( + "Pack py2 only supports Python 2" in warning0 + and "Pack py22 only supports Python 2" in warning1 + ) + or ( + "Pack py22 only supports Python 2" in warning0 + and "Pack py2 only supports Python 2" in warning1 + ) + ) diff --git a/contrib/packs/tests/test_virtualenv_setup_prerun.py b/contrib/packs/tests/test_virtualenv_setup_prerun.py index 63b27410f6c..0097ecd8feb 100644 --- a/contrib/packs/tests/test_virtualenv_setup_prerun.py +++ b/contrib/packs/tests/test_virtualenv_setup_prerun.py @@ -28,21 +28,26 @@ def setUp(self): def test_run_with_pack_list(self): action = self.get_action_instance() - result = action.run(packs_status={'test1': 'Success.', 'test2': 'Success.'}, - packs_list=['test3', 'test4']) + result = action.run( + packs_status={"test1": "Success.", "test2": "Success."}, + packs_list=["test3", "test4"], + ) - self.assertEqual(result, ['test3', 'test4', 'test1', 'test2']) + self.assertEqual(result, ["test3", "test4", "test1", "test2"]) def test_run_with_none_pack_list(self): action = self.get_action_instance() - result = action.run(packs_status={'test1': 'Success.', 'test2': 'Success.'}, - packs_list=None) + result = action.run( + packs_status={"test1": "Success.", "test2": "Success."}, packs_list=None + ) - self.assertEqual(result, ['test1', 'test2']) + self.assertEqual(result, ["test1", "test2"]) def test_run_with_failed_status(self): action = self.get_action_instance() - result = action.run(packs_status={'test1': 'Failed.', 'test2': 'Success.'}, - packs_list=['test3', 'test4']) + result = action.run( + packs_status={"test1": "Failed.", "test2": "Success."}, + packs_list=["test3", "test4"], + ) - self.assertEqual(result, ['test3', 'test4', 'test2']) + self.assertEqual(result, ["test3", "test4", "test2"]) diff --git a/contrib/runners/action_chain_runner/action_chain_runner/__init__.py b/contrib/runners/action_chain_runner/action_chain_runner/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/contrib/runners/action_chain_runner/action_chain_runner/__init__.py +++ b/contrib/runners/action_chain_runner/action_chain_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py b/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py index 39cb873136d..e71c12d0046 100644 --- a/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py +++ b/contrib/runners/action_chain_runner/action_chain_runner/action_chain_runner.py @@ -50,26 +50,16 @@ from st2common.util.config_loader import get_config from st2common.util.ujson import fast_deepcopy -__all__ = [ - 'ActionChainRunner', - 'ChainHolder', - - 'get_runner', - 'get_metadata' -] +__all__ = ["ActionChainRunner", "ChainHolder", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) -RESULTS_KEY = '__results' -JINJA_START_MARKERS = [ - '{{', - '{%' -] -PUBLISHED_VARS_KEY = 'published' +RESULTS_KEY = "__results" +JINJA_START_MARKERS = ["{{", "{%"] +PUBLISHED_VARS_KEY = "published" class ChainHolder(object): - def __init__(self, chainspec, chainname): self.actionchain = actionchain.ActionChain(**chainspec) self.chainname = chainname @@ -78,17 +68,21 @@ def __init__(self, chainspec, chainname): default = self._get_default(self.actionchain) self.actionchain.default = default - LOG.debug('Using %s as default for %s.', self.actionchain.default, self.chainname) + LOG.debug( + "Using %s as default for %s.", self.actionchain.default, self.chainname + ) if not self.actionchain.default: - raise Exception('Failed to find default node in %s.' % (self.chainname)) + raise Exception("Failed to find default node in %s." % (self.chainname)) self.vars = {} def init_vars(self, action_parameters, action_context=None): if self.actionchain.vars: - self.vars = self._get_rendered_vars(self.actionchain.vars, - action_parameters=action_parameters, - action_context=action_context) + self.vars = self._get_rendered_vars( + self.actionchain.vars, + action_parameters=action_parameters, + action_context=action_context, + ) def restore_vars(self, ctx_vars): self.vars.update(fast_deepcopy(ctx_vars)) @@ -107,28 +101,37 @@ def validate(self): on_failure_node_name = node.on_failure # Check "on-success" path - valid_name = self._is_valid_node_name(all_node_names=all_nodes, - node_name=on_success_node_name) + valid_name = self._is_valid_node_name( + all_node_names=all_nodes, node_name=on_success_node_name + ) if not valid_name: - msg = ('Unable to find node with name "%s" referenced in "on-success" in ' - 'task "%s".' % (on_success_node_name, node.name)) + msg = ( + 'Unable to find node with name "%s" referenced in "on-success" in ' + 'task "%s".' % (on_success_node_name, node.name) + ) raise ValueError(msg) # Check "on-failure" path - valid_name = self._is_valid_node_name(all_node_names=all_nodes, - node_name=on_failure_node_name) + valid_name = self._is_valid_node_name( + all_node_names=all_nodes, node_name=on_failure_node_name + ) if not valid_name: - msg = ('Unable to find node with name "%s" referenced in "on-failure" in ' - 'task "%s".' % (on_failure_node_name, node.name)) + msg = ( + 'Unable to find node with name "%s" referenced in "on-failure" in ' + 'task "%s".' % (on_failure_node_name, node.name) + ) raise ValueError(msg) # check if node specified in default is valid. if self.actionchain.default: - valid_name = self._is_valid_node_name(all_node_names=all_nodes, - node_name=self.actionchain.default) + valid_name = self._is_valid_node_name( + all_node_names=all_nodes, node_name=self.actionchain.default + ) if not valid_name: - msg = ('Unable to find node with name "%s" referenced in "default".' % - self.actionchain.default) + msg = ( + 'Unable to find node with name "%s" referenced in "default".' + % self.actionchain.default + ) raise ValueError(msg) return True @@ -147,8 +150,12 @@ def _get_default(action_chain): # 2. There are no fragments in the chain. all_nodes = ChainHolder._get_all_nodes(action_chain=action_chain) node_names = set(all_nodes) - on_success_nodes = ChainHolder._get_all_on_success_nodes(action_chain=action_chain) - on_failure_nodes = ChainHolder._get_all_on_failure_nodes(action_chain=action_chain) + on_success_nodes = ChainHolder._get_all_on_success_nodes( + action_chain=action_chain + ) + on_failure_nodes = ChainHolder._get_all_on_failure_nodes( + action_chain=action_chain + ) referenced_nodes = on_success_nodes | on_failure_nodes possible_default_nodes = node_names - referenced_nodes if possible_default_nodes: @@ -210,19 +217,25 @@ def _get_rendered_vars(vars, action_parameters, action_context): return {} action_context = action_context or {} - user = action_context.get('user', cfg.CONF.system_user.user) + user = action_context.get("user", cfg.CONF.system_user.user) context = {} - context.update({ - kv_constants.DATASTORE_PARENT_SCOPE: { - kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.FULL_SYSTEM_SCOPE), - kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup( - scope=kv_constants.FULL_USER_SCOPE, user=user) + context.update( + { + kv_constants.DATASTORE_PARENT_SCOPE: { + kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( + scope=kv_constants.FULL_SYSTEM_SCOPE + ), + kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup( + scope=kv_constants.FULL_USER_SCOPE, user=user + ), + } } - }) + ) context.update(action_parameters) - LOG.info('Rendering action chain vars. Mapping = %s; Context = %s', vars, context) + LOG.info( + "Rendering action chain vars. Mapping = %s; Context = %s", vars, context + ) return jinja_utils.render_values(mapping=vars, context=context) def get_node(self, node_name=None, raise_on_failure=False): @@ -233,22 +246,22 @@ def get_node(self, node_name=None, raise_on_failure=False): return node if raise_on_failure: raise runner_exc.ActionRunnerException( - 'Unable to find node with name "%s".' % (node_name)) + 'Unable to find node with name "%s".' % (node_name) + ) return None - def get_next_node(self, curr_node_name=None, condition='on-success'): + def get_next_node(self, curr_node_name=None, condition="on-success"): if not curr_node_name: return self.get_node(self.actionchain.default) current_node = self.get_node(curr_node_name) - if condition == 'on-success': + if condition == "on-success": return self.get_node(current_node.on_success, raise_on_failure=True) - elif condition == 'on-failure': + elif condition == "on-failure": return self.get_node(current_node.on_failure, raise_on_failure=True) - raise runner_exc.ActionRunnerException('Unknown condition %s.' % condition) + raise runner_exc.ActionRunnerException("Unknown condition %s." % condition) class ActionChainRunner(ActionRunner): - def __init__(self, runner_id): super(ActionChainRunner, self).__init__(runner_id=runner_id) self.chain_holder = None @@ -261,16 +274,20 @@ def pre_run(self): super(ActionChainRunner, self).pre_run() chainspec_file = self.entry_point - LOG.debug('Reading action chain from %s for action %s.', chainspec_file, - self.action) + LOG.debug( + "Reading action chain from %s for action %s.", chainspec_file, self.action + ) try: - chainspec = self._meta_loader.load(file_path=chainspec_file, - expected_type=dict) + chainspec = self._meta_loader.load( + file_path=chainspec_file, expected_type=dict + ) except Exception as e: - message = ('Failed to parse action chain definition from "%s": %s' % - (chainspec_file, six.text_type(e))) - LOG.exception('Failed to load action chain definition.') + message = 'Failed to parse action chain definition from "%s": %s' % ( + chainspec_file, + six.text_type(e), + ) + LOG.exception("Failed to load action chain definition.") raise runner_exc.ActionRunnerPreRunError(message) try: @@ -279,20 +296,22 @@ def pre_run(self): # preserve the whole nasty jsonschema message as that is better to get to the # root cause message = six.text_type(e) - LOG.exception('Failed to instantiate ActionChain.') + LOG.exception("Failed to instantiate ActionChain.") raise runner_exc.ActionRunnerPreRunError(message) except Exception as e: message = six.text_type(e) - LOG.exception('Failed to instantiate ActionChain.') + LOG.exception("Failed to instantiate ActionChain.") raise runner_exc.ActionRunnerPreRunError(message) # Runner attributes are set lazily. So these steps # should happen outside the constructor. - if getattr(self, 'liveaction', None): - self._chain_notify = getattr(self.liveaction, 'notify', None) + if getattr(self, "liveaction", None): + self._chain_notify = getattr(self.liveaction, "notify", None) if self.runner_parameters: - self._skip_notify_tasks = self.runner_parameters.get('skip_notify', []) - self._display_published = self.runner_parameters.get('display_published', True) + self._skip_notify_tasks = self.runner_parameters.get("skip_notify", []) + self._display_published = self.runner_parameters.get( + "display_published", True + ) # Perform some pre-run chain validation try: @@ -308,34 +327,38 @@ def cancel(self): # Identify the list of action executions that are workflows and cascade pause. for child_exec_id in self.execution.children: child_exec = ActionExecution.get(id=child_exec_id, raise_exception=True) - if (child_exec.runner['name'] in action_constants.WORKFLOW_RUNNER_TYPES and - child_exec.status in action_constants.LIVEACTION_CANCELABLE_STATES): + if ( + child_exec.runner["name"] in action_constants.WORKFLOW_RUNNER_TYPES + and child_exec.status in action_constants.LIVEACTION_CANCELABLE_STATES + ): action_service.request_cancellation( - LiveAction.get(id=child_exec.liveaction['id']), - self.context.get('user', None) + LiveAction.get(id=child_exec.liveaction["id"]), + self.context.get("user", None), ) return ( action_constants.LIVEACTION_STATUS_CANCELING, self.liveaction.result, - self.liveaction.context + self.liveaction.context, ) def pause(self): # Identify the list of action executions that are workflows and cascade pause. for child_exec_id in self.execution.children: child_exec = ActionExecution.get(id=child_exec_id, raise_exception=True) - if (child_exec.runner['name'] in action_constants.WORKFLOW_RUNNER_TYPES and - child_exec.status == action_constants.LIVEACTION_STATUS_RUNNING): + if ( + child_exec.runner["name"] in action_constants.WORKFLOW_RUNNER_TYPES + and child_exec.status == action_constants.LIVEACTION_STATUS_RUNNING + ): action_service.request_pause( - LiveAction.get(id=child_exec.liveaction['id']), - self.context.get('user', None) + LiveAction.get(id=child_exec.liveaction["id"]), + self.context.get("user", None), ) return ( action_constants.LIVEACTION_STATUS_PAUSING, self.liveaction.result, - self.liveaction.context + self.liveaction.context, ) def resume(self): @@ -344,7 +367,7 @@ def resume(self): self.runner_type.runner_parameters, self.action.parameters, self.liveaction.parameters, - self.liveaction.context + self.liveaction.context, ) # Assign runner parameters needed for pre-run. @@ -357,9 +380,7 @@ def resume(self): # Change the status of the liveaction from resuming to running. self.liveaction = action_service.update_status( - self.liveaction, - action_constants.LIVEACTION_STATUS_RUNNING, - publish=False + self.liveaction, action_constants.LIVEACTION_STATUS_RUNNING, publish=False ) # Run the action chain. @@ -370,13 +391,15 @@ def _run_chain(self, action_parameters, resuming=False): chain_status = action_constants.LIVEACTION_STATUS_FAILED # Result holds the final result that the chain store in the database. - result = {'tasks': []} + result = {"tasks": []} # Save published variables into the result if specified. if self._display_published: result[PUBLISHED_VARS_KEY] = {} - context_result = {} # Holds result which is used for the template context purposes + context_result = ( + {} + ) # Holds result which is used for the template context purposes top_level_error = None # Stores a reference to a top level error action_node = None last_task = None @@ -384,11 +407,12 @@ def _run_chain(self, action_parameters, resuming=False): try: # Initialize vars with the action parameters. # This allows action parameers to be referenced from vars. - self.chain_holder.init_vars(action_parameters=action_parameters, - action_context=self.context) + self.chain_holder.init_vars( + action_parameters=action_parameters, action_context=self.context + ) except Exception as e: chain_status = action_constants.LIVEACTION_STATUS_FAILED - m = 'Failed initializing ``vars`` in chain.' + m = "Failed initializing ``vars`` in chain." LOG.exception(m) top_level_error = self._format_error(e, m) result.update(top_level_error) @@ -397,28 +421,32 @@ def _run_chain(self, action_parameters, resuming=False): # Restore state on resuming an existing chain execution. if resuming: # Restore vars is any from the liveaction. - ctx_vars = self.liveaction.context.pop('vars', {}) + ctx_vars = self.liveaction.context.pop("vars", {}) self.chain_holder.restore_vars(ctx_vars) # Restore result if any from the liveaction. - if self.liveaction and hasattr(self.liveaction, 'result') and self.liveaction.result: + if ( + self.liveaction + and hasattr(self.liveaction, "result") + and self.liveaction.result + ): result = self.liveaction.result # Initialize or rebuild existing context_result from liveaction # which holds the result used for resolving context in Jinja template. - for task in result.get('tasks', []): - context_result[task['name']] = task['result'] + for task in result.get("tasks", []): + context_result[task["name"]] = task["result"] # Restore or initialize the top_level_error # that stores a reference to a top level error. - if 'error' in result or 'traceback' in result: + if "error" in result or "traceback" in result: top_level_error = { - 'error': result.get('error'), - 'traceback': result.get('traceback') + "error": result.get("error"), + "traceback": result.get("traceback"), } # If there are no executed tasks in the chain, then get the first node. - if len(result['tasks']) <= 0: + if len(result["tasks"]) <= 0: try: action_node = self.chain_holder.get_next_node() except Exception as e: @@ -433,21 +461,24 @@ def _run_chain(self, action_parameters, resuming=False): # Otherwise, figure out the last task executed and # its state to determine where to begin executing. else: - last_task = result['tasks'][-1] - action_node = self.chain_holder.get_node(last_task['name']) - liveaction = action_db_util.get_liveaction_by_id(last_task['liveaction_id']) + last_task = result["tasks"][-1] + action_node = self.chain_holder.get_node(last_task["name"]) + liveaction = action_db_util.get_liveaction_by_id(last_task["liveaction_id"]) # If the liveaction of the last task has changed, update the result entry. - if liveaction.status != last_task['state']: + if liveaction.status != last_task["state"]: updated_task_result = self._get_updated_action_exec_result( - action_node, liveaction, last_task) - del result['tasks'][-1] - result['tasks'].append(updated_task_result) + action_node, liveaction, last_task + ) + del result["tasks"][-1] + result["tasks"].append(updated_task_result) # Also need to update context_result so the updated result # is available to Jinja expressions - updated_task_name = updated_task_result['name'] - context_result[updated_task_name]['result'] = updated_task_result['result'] + updated_task_name = updated_task_result["name"] + context_result[updated_task_name]["result"] = updated_task_result[ + "result" + ] # If the last task was canceled, then canceled the chain altogether. if liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED: @@ -463,42 +494,52 @@ def _run_chain(self, action_parameters, resuming=False): if liveaction.status == action_constants.LIVEACTION_STATUS_SUCCEEDED: chain_status = action_constants.LIVEACTION_STATUS_SUCCEEDED action_node = self.chain_holder.get_next_node( - last_task['name'], condition='on-success') + last_task["name"], condition="on-success" + ) # If the last task failed, then get the next on-failure action node. if liveaction.status in action_constants.LIVEACTION_FAILED_STATES: chain_status = action_constants.LIVEACTION_STATUS_FAILED action_node = self.chain_holder.get_next_node( - last_task['name'], condition='on-failure') + last_task["name"], condition="on-failure" + ) # Setup parent context. - parent_context = { - 'execution_id': self.execution_id - } + parent_context = {"execution_id": self.execution_id} - if getattr(self.liveaction, 'context', None): + if getattr(self.liveaction, "context", None): parent_context.update(self.liveaction.context) # Run the action chain until there are no more tasks. while action_node: error = None liveaction = None - last_task = result['tasks'][-1] if len(result['tasks']) > 0 else None + last_task = result["tasks"][-1] if len(result["tasks"]) > 0 else None created_at = date_utils.get_datetime_utc_now() try: # If last task was paused, then fetch the liveaction and resume it first. - if last_task and last_task['state'] == action_constants.LIVEACTION_STATUS_PAUSED: - liveaction = action_db_util.get_liveaction_by_id(last_task['liveaction_id']) - del result['tasks'][-1] + if ( + last_task + and last_task["state"] == action_constants.LIVEACTION_STATUS_PAUSED + ): + liveaction = action_db_util.get_liveaction_by_id( + last_task["liveaction_id"] + ) + del result["tasks"][-1] else: liveaction = self._get_next_action( - action_node=action_node, parent_context=parent_context, - action_params=action_parameters, context_result=context_result) + action_node=action_node, + parent_context=parent_context, + action_params=action_parameters, + context_result=context_result, + ) except action_exc.InvalidActionReferencedException as e: chain_status = action_constants.LIVEACTION_STATUS_FAILED - m = ('Failed to run task "%s". Action with reference "%s" doesn\'t exist.' % - (action_node.name, action_node.ref)) + m = ( + 'Failed to run task "%s". Action with reference "%s" doesn\'t exist.' + % (action_node.name, action_node.ref) + ) LOG.exception(m) top_level_error = self._format_error(e, m) break @@ -506,24 +547,41 @@ def _run_chain(self, action_parameters, resuming=False): # Rendering parameters failed before we even got to running this action, # abort and fail the whole action chain chain_status = action_constants.LIVEACTION_STATUS_FAILED - m = 'Failed to run task "%s". Parameter rendering failed.' % action_node.name + m = ( + 'Failed to run task "%s". Parameter rendering failed.' + % action_node.name + ) LOG.exception(m) top_level_error = self._format_error(e, m) break except db_exc.StackStormDBObjectNotFoundError as e: chain_status = action_constants.LIVEACTION_STATUS_FAILED - m = 'Failed to resume task "%s". Unable to find liveaction.' % action_node.name + m = ( + 'Failed to resume task "%s". Unable to find liveaction.' + % action_node.name + ) LOG.exception(m) top_level_error = self._format_error(e, m) break try: # If last task was paused, then fetch the liveaction and resume it first. - if last_task and last_task['state'] == action_constants.LIVEACTION_STATUS_PAUSED: - LOG.info('Resume task %s for chain %s.', action_node.name, self.liveaction.id) + if ( + last_task + and last_task["state"] == action_constants.LIVEACTION_STATUS_PAUSED + ): + LOG.info( + "Resume task %s for chain %s.", + action_node.name, + self.liveaction.id, + ) liveaction = self._resume_action(liveaction) else: - LOG.info('Run task %s for chain %s.', action_node.name, self.liveaction.id) + LOG.info( + "Run task %s for chain %s.", + action_node.name, + self.liveaction.id, + ) liveaction = self._run_action(liveaction) except Exception as e: # Save the traceback and error message @@ -537,9 +595,12 @@ def _run_chain(self, action_parameters, resuming=False): # Render and publish variables rendered_publish_vars = ActionChainRunner._render_publish_vars( - action_node=action_node, action_parameters=action_parameters, - execution_result=liveaction.result, previous_execution_results=context_result, - chain_vars=self.chain_holder.vars) + action_node=action_node, + action_parameters=action_parameters, + execution_result=liveaction.result, + previous_execution_results=context_result, + chain_vars=self.chain_holder.vars, + ) if rendered_publish_vars: self.chain_holder.vars.update(rendered_publish_vars) @@ -550,49 +611,68 @@ def _run_chain(self, action_parameters, resuming=False): updated_at = date_utils.get_datetime_utc_now() task_result = self._format_action_exec_result( - action_node, - liveaction, - created_at, - updated_at, - error=error + action_node, liveaction, created_at, updated_at, error=error ) - result['tasks'].append(task_result) + result["tasks"].append(task_result) try: if not liveaction: chain_status = action_constants.LIVEACTION_STATUS_FAILED action_node = self.chain_holder.get_next_node( - action_node.name, condition='on-failure') - elif liveaction.status == action_constants.LIVEACTION_STATUS_TIMED_OUT: + action_node.name, condition="on-failure" + ) + elif ( + liveaction.status + == action_constants.LIVEACTION_STATUS_TIMED_OUT + ): chain_status = action_constants.LIVEACTION_STATUS_TIMED_OUT action_node = self.chain_holder.get_next_node( - action_node.name, condition='on-failure') - elif liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED: - LOG.info('Chain execution (%s) canceled because task "%s" is canceled.', - self.liveaction_id, action_node.name) + action_node.name, condition="on-failure" + ) + elif ( + liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED + ): + LOG.info( + 'Chain execution (%s) canceled because task "%s" is canceled.', + self.liveaction_id, + action_node.name, + ) chain_status = action_constants.LIVEACTION_STATUS_CANCELED action_node = None elif liveaction.status == action_constants.LIVEACTION_STATUS_PAUSED: - LOG.info('Chain execution (%s) paused because task "%s" is paused.', - self.liveaction_id, action_node.name) + LOG.info( + 'Chain execution (%s) paused because task "%s" is paused.', + self.liveaction_id, + action_node.name, + ) chain_status = action_constants.LIVEACTION_STATUS_PAUSED self._save_vars() action_node = None - elif liveaction.status == action_constants.LIVEACTION_STATUS_PENDING: - LOG.info('Chain execution (%s) paused because task "%s" is pending.', - self.liveaction_id, action_node.name) + elif ( + liveaction.status == action_constants.LIVEACTION_STATUS_PENDING + ): + LOG.info( + 'Chain execution (%s) paused because task "%s" is pending.', + self.liveaction_id, + action_node.name, + ) chain_status = action_constants.LIVEACTION_STATUS_PAUSED self._save_vars() action_node = None elif liveaction.status in action_constants.LIVEACTION_FAILED_STATES: chain_status = action_constants.LIVEACTION_STATUS_FAILED action_node = self.chain_holder.get_next_node( - action_node.name, condition='on-failure') - elif liveaction.status == action_constants.LIVEACTION_STATUS_SUCCEEDED: + action_node.name, condition="on-failure" + ) + elif ( + liveaction.status + == action_constants.LIVEACTION_STATUS_SUCCEEDED + ): chain_status = action_constants.LIVEACTION_STATUS_SUCCEEDED action_node = self.chain_holder.get_next_node( - action_node.name, condition='on-success') + action_node.name, condition="on-success" + ) else: action_node = None except Exception as e: @@ -604,12 +684,12 @@ def _run_chain(self, action_parameters, resuming=False): break if action_service.is_action_canceled_or_canceling(self.liveaction.id): - LOG.info('Chain execution (%s) canceled by user.', self.liveaction.id) + LOG.info("Chain execution (%s) canceled by user.", self.liveaction.id) chain_status = action_constants.LIVEACTION_STATUS_CANCELED return (chain_status, result, None) if action_service.is_action_paused_or_pausing(self.liveaction.id): - LOG.info('Chain execution (%s) paused by user.', self.liveaction.id) + LOG.info("Chain execution (%s) paused by user.", self.liveaction.id) chain_status = action_constants.LIVEACTION_STATUS_PAUSED self._save_vars() return (chain_status, result, self.liveaction.context) @@ -621,17 +701,22 @@ def _run_chain(self, action_parameters, resuming=False): def _format_error(self, e, msg): return { - 'error': '%s. %s' % (msg, six.text_type(e)), - 'traceback': traceback.format_exc(10) + "error": "%s. %s" % (msg, six.text_type(e)), + "traceback": traceback.format_exc(10), } def _save_vars(self): # Save the context vars in the liveaction context. - self.liveaction.context['vars'] = self.chain_holder.vars + self.liveaction.context["vars"] = self.chain_holder.vars @staticmethod - def _render_publish_vars(action_node, action_parameters, execution_result, - previous_execution_results, chain_vars): + def _render_publish_vars( + action_node, + action_parameters, + execution_result, + previous_execution_results, + chain_vars, + ): """ If no output is specified on the action_node the output is the entire execution_result. If any output is specified then only those variables are published as output of an @@ -649,36 +734,48 @@ def _render_publish_vars(action_node, action_parameters, execution_result, context.update(chain_vars) context.update({RESULTS_KEY: previous_execution_results}) - context.update({ - kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.SYSTEM_SCOPE) - }) - - context.update({ - kv_constants.DATASTORE_PARENT_SCOPE: { + context.update( + { kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.FULL_SYSTEM_SCOPE) + scope=kv_constants.SYSTEM_SCOPE + ) } - }) + ) + + context.update( + { + kv_constants.DATASTORE_PARENT_SCOPE: { + kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( + scope=kv_constants.FULL_SYSTEM_SCOPE + ) + } + } + ) try: - rendered_result = jinja_utils.render_values(mapping=action_node.publish, - context=context) + rendered_result = jinja_utils.render_values( + mapping=action_node.publish, context=context + ) except Exception as e: - key = getattr(e, 'key', None) - value = getattr(e, 'value', None) - msg = ('Failed rendering value for publish parameter "%s" in task "%s" ' - '(template string=%s): %s' % (key, action_node.name, value, six.text_type(e))) + key = getattr(e, "key", None) + value = getattr(e, "value", None) + msg = ( + 'Failed rendering value for publish parameter "%s" in task "%s" ' + "(template string=%s): %s" + % (key, action_node.name, value, six.text_type(e)) + ) raise action_exc.ParameterRenderingFailedException(msg) return rendered_result @staticmethod - def _resolve_params(action_node, original_parameters, results, chain_vars, chain_context): + def _resolve_params( + action_node, original_parameters, results, chain_vars, chain_context + ): # setup context with original parameters and the intermediate results. - chain_parent = chain_context.get('parent', {}) - pack = chain_parent.get('pack') - user = chain_parent.get('user') + chain_parent = chain_context.get("parent", {}) + pack = chain_parent.get("pack") + user = chain_parent.get("user") config = get_config(pack, user) @@ -688,34 +785,47 @@ def _resolve_params(action_node, original_parameters, results, chain_vars, chain context.update(chain_vars) context.update({RESULTS_KEY: results}) - context.update({ - kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.SYSTEM_SCOPE) - }) - - context.update({ - kv_constants.DATASTORE_PARENT_SCOPE: { + context.update( + { kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.FULL_SYSTEM_SCOPE) + scope=kv_constants.SYSTEM_SCOPE + ) } - }) + ) + + context.update( + { + kv_constants.DATASTORE_PARENT_SCOPE: { + kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( + scope=kv_constants.FULL_SYSTEM_SCOPE + ) + } + } + ) context.update({action_constants.ACTION_CONTEXT_KV_PREFIX: chain_context}) context.update({pack_constants.PACK_CONFIG_CONTEXT_KV_PREFIX: config}) try: - rendered_params = jinja_utils.render_values(mapping=action_node.get_parameters(), - context=context) + rendered_params = jinja_utils.render_values( + mapping=action_node.get_parameters(), context=context + ) except Exception as e: LOG.exception('Jinja rendering for parameter "%s" failed.' % (e.key)) - key = getattr(e, 'key', None) - value = getattr(e, 'value', None) - msg = ('Failed rendering value for action parameter "%s" in task "%s" ' - '(template string=%s): %s') % (key, action_node.name, value, six.text_type(e)) + key = getattr(e, "key", None) + value = getattr(e, "value", None) + msg = ( + 'Failed rendering value for action parameter "%s" in task "%s" ' + "(template string=%s): %s" + ) % (key, action_node.name, value, six.text_type(e)) raise action_exc.ParameterRenderingFailedException(msg) - LOG.debug('Rendered params: %s: Type: %s', rendered_params, type(rendered_params)) + LOG.debug( + "Rendered params: %s: Type: %s", rendered_params, type(rendered_params) + ) return rendered_params - def _get_next_action(self, action_node, parent_context, action_params, context_result): + def _get_next_action( + self, action_node, parent_context, action_params, context_result + ): # Verify that the referenced action exists # TODO: We do another lookup in cast_param, refactor to reduce number of lookups task_name = action_node.name @@ -723,18 +833,25 @@ def _get_next_action(self, action_node, parent_context, action_params, context_r action_db = action_db_util.get_action_by_ref(ref=action_ref) if not action_db: - error = 'Task :: %s - Action with ref %s not registered.' % (task_name, action_ref) + error = "Task :: %s - Action with ref %s not registered." % ( + task_name, + action_ref, + ) raise action_exc.InvalidActionReferencedException(error) resolved_params = ActionChainRunner._resolve_params( - action_node=action_node, original_parameters=action_params, - results=context_result, chain_vars=self.chain_holder.vars, - chain_context={'parent': parent_context}) + action_node=action_node, + original_parameters=action_params, + results=context_result, + chain_vars=self.chain_holder.vars, + chain_context={"parent": parent_context}, + ) liveaction = self._build_liveaction_object( action_node=action_node, resolved_params=resolved_params, - parent_context=parent_context) + parent_context=parent_context, + ) return liveaction @@ -747,13 +864,16 @@ def _run_action(self, liveaction, wait_for_completion=True, sleep_delay=1.0): liveaction, _ = action_service.request(liveaction) except Exception as e: liveaction.status = action_constants.LIVEACTION_STATUS_FAILED - LOG.exception('Failed to schedule liveaction.') + LOG.exception("Failed to schedule liveaction.") raise e - while (wait_for_completion and liveaction.status not in ( - action_constants.LIVEACTION_COMPLETED_STATES + - [action_constants.LIVEACTION_STATUS_PAUSED, - action_constants.LIVEACTION_STATUS_PENDING])): + while wait_for_completion and liveaction.status not in ( + action_constants.LIVEACTION_COMPLETED_STATES + + [ + action_constants.LIVEACTION_STATUS_PAUSED, + action_constants.LIVEACTION_STATUS_PENDING, + ] + ): eventlet.sleep(sleep_delay) liveaction = action_db_util.get_liveaction_by_id(liveaction.id) @@ -765,16 +885,17 @@ def _resume_action(self, liveaction, wait_for_completion=True, sleep_delay=1.0): :type sleep_delay: ``float`` """ try: - user = self.context.get('user', None) + user = self.context.get("user", None) liveaction, _ = action_service.request_resume(liveaction, user) except Exception as e: liveaction.status = action_constants.LIVEACTION_STATUS_FAILED - LOG.exception('Failed to schedule liveaction.') + LOG.exception("Failed to schedule liveaction.") raise e - while (wait_for_completion and liveaction.status not in ( - action_constants.LIVEACTION_COMPLETED_STATES + - [action_constants.LIVEACTION_STATUS_PAUSED])): + while wait_for_completion and liveaction.status not in ( + action_constants.LIVEACTION_COMPLETED_STATES + + [action_constants.LIVEACTION_STATUS_PAUSED] + ): eventlet.sleep(sleep_delay) liveaction = action_db_util.get_liveaction_by_id(liveaction.id) @@ -787,14 +908,12 @@ def _build_liveaction_object(self, action_node, resolved_params, parent_context) notify = self._get_notify(action_node) if notify: liveaction.notify = notify - LOG.debug('%s: Task notify set to: %s', action_node.name, liveaction.notify) + LOG.debug("%s: Task notify set to: %s", action_node.name, liveaction.notify) - liveaction.context = { - 'parent': parent_context, - 'chain': vars(action_node) - } - liveaction.parameters = action_param_utils.cast_params(action_ref=action_node.ref, - params=resolved_params) + liveaction.context = {"parent": parent_context, "chain": vars(action_node)} + liveaction.parameters = action_param_utils.cast_params( + action_ref=action_node.ref, params=resolved_params + ) return liveaction def _get_notify(self, action_node): @@ -807,18 +926,23 @@ def _get_notify(self, action_node): return None - def _get_updated_action_exec_result(self, action_node, liveaction, prev_task_result): + def _get_updated_action_exec_result( + self, action_node, liveaction, prev_task_result + ): if liveaction.status in action_constants.LIVEACTION_COMPLETED_STATES: - created_at = isotime.parse(prev_task_result['created_at']) + created_at = isotime.parse(prev_task_result["created_at"]) updated_at = liveaction.end_timestamp else: - created_at = isotime.parse(prev_task_result['created_at']) - updated_at = isotime.parse(prev_task_result['updated_at']) + created_at = isotime.parse(prev_task_result["created_at"]) + updated_at = isotime.parse(prev_task_result["updated_at"]) - return self._format_action_exec_result(action_node, liveaction, created_at, updated_at) + return self._format_action_exec_result( + action_node, liveaction, created_at, updated_at + ) - def _format_action_exec_result(self, action_node, liveaction_db, created_at, updated_at, - error=None): + def _format_action_exec_result( + self, action_node, liveaction_db, created_at, updated_at, error=None + ): """ Format ActionExecution result so it can be used in the final action result output. @@ -833,24 +957,24 @@ def _format_action_exec_result(self, action_node, liveaction_db, created_at, upd if liveaction_db: execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id)) - result['id'] = action_node.name - result['name'] = action_node.name - result['execution_id'] = str(execution_db.id) if execution_db else None - result['liveaction_id'] = str(liveaction_db.id) if liveaction_db else None - result['workflow'] = None + result["id"] = action_node.name + result["name"] = action_node.name + result["execution_id"] = str(execution_db.id) if execution_db else None + result["liveaction_id"] = str(liveaction_db.id) if liveaction_db else None + result["workflow"] = None - result['created_at'] = isotime.format(dt=created_at) - result['updated_at'] = isotime.format(dt=updated_at) + result["created_at"] = isotime.format(dt=created_at) + result["updated_at"] = isotime.format(dt=updated_at) if error or not liveaction_db: - result['state'] = action_constants.LIVEACTION_STATUS_FAILED + result["state"] = action_constants.LIVEACTION_STATUS_FAILED else: - result['state'] = liveaction_db.status + result["state"] = liveaction_db.status if error: - result['result'] = error + result["result"] = error else: - result['result'] = liveaction_db.result + result["result"] = liveaction_db.result return result @@ -860,4 +984,4 @@ def get_runner(): def get_metadata(): - return get_runner_metadata('action_chain_runner')[0] + return get_runner_metadata("action_chain_runner")[0] diff --git a/contrib/runners/action_chain_runner/dist_utils.py b/contrib/runners/action_chain_runner/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/contrib/runners/action_chain_runner/dist_utils.py +++ b/contrib/runners/action_chain_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/action_chain_runner/setup.py b/contrib/runners/action_chain_runner/setup.py index 6c2043505c6..7c96e1e1d1b 100644 --- a/contrib/runners/action_chain_runner/setup.py +++ b/contrib/runners/action_chain_runner/setup.py @@ -26,31 +26,33 @@ from action_chain_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-action-chain', + name="stackstorm-runner-action-chain", version=__version__, - description=('Action-Chain workflow action runner for StackStorm event-driven ' - 'automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "Action-Chain workflow action runner for StackStorm event-driven " + "automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'action_chain_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"action_chain_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'action-chain = action_chain_runner.action_chain_runner', + "st2common.runners.runner": [ + "action-chain = action_chain_runner.action_chain_runner", ], - } + }, ) diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py index 32bb5c92499..9daed4fa908 100644 --- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py +++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain.py @@ -39,99 +39,135 @@ class DummyActionExecution(object): - def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=''): + def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=""): self.id = None self.status = status self.result = result -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_MODELS = { - 'actions': ['a1.yaml', 'a2.yaml', 'action_4_action_context_param.yaml'], - 'runners': ['testrunner1.yaml'] + "actions": ["a1.yaml", "a2.yaml", "action_4_action_context_param.yaml"], + "runners": ["testrunner1.yaml"], } -MODELS = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) -ACTION_1 = MODELS['actions']['a1.yaml'] -ACTION_2 = MODELS['actions']['a2.yaml'] -ACTION_3 = MODELS['actions']['action_4_action_context_param.yaml'] -RUNNER = MODELS['runners']['testrunner1.yaml'] +MODELS = FixturesLoader().load_models( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS +) +ACTION_1 = MODELS["actions"]["a1.yaml"] +ACTION_2 = MODELS["actions"]["a2.yaml"] +ACTION_3 = MODELS["actions"]["action_4_action_context_param.yaml"] +RUNNER = MODELS["runners"]["testrunner1.yaml"] CHAIN_1_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain1.yaml') + FIXTURES_PACK, "actionchains", "chain1.yaml" +) CHAIN_2_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain2.yaml') + FIXTURES_PACK, "actionchains", "chain2.yaml" +) CHAIN_ACTION_CALL_NO_PARAMS_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_action_call_no_params.yaml') + FIXTURES_PACK, "actionchains", "chain_action_call_no_params.yaml" +) CHAIN_NO_DEFAULT = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'no_default_chain.yaml') + FIXTURES_PACK, "actionchains", "no_default_chain.yaml" +) CHAIN_NO_DEFAULT_2 = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'no_default_chain_2.yaml') + FIXTURES_PACK, "actionchains", "no_default_chain_2.yaml" +) CHAIN_BAD_DEFAULT = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'bad_default_chain.yaml') -CHAIN_BROKEN_ON_SUCCESS_PATH_STATIC_TASK_NAME = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_broken_on_success_path_static_task_name.yaml') -CHAIN_BROKEN_ON_FAILURE_PATH_STATIC_TASK_NAME = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_broken_on_failure_path_static_task_name.yaml') + FIXTURES_PACK, "actionchains", "bad_default_chain.yaml" +) +CHAIN_BROKEN_ON_SUCCESS_PATH_STATIC_TASK_NAME = ( + FixturesLoader().get_fixture_file_path_abs( + FIXTURES_PACK, + "actionchains", + "chain_broken_on_success_path_static_task_name.yaml", + ) +) +CHAIN_BROKEN_ON_FAILURE_PATH_STATIC_TASK_NAME = ( + FixturesLoader().get_fixture_file_path_abs( + FIXTURES_PACK, + "actionchains", + "chain_broken_on_failure_path_static_task_name.yaml", + ) +) CHAIN_FIRST_TASK_RENDER_FAIL_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_first_task_parameter_render_fail.yaml') + FIXTURES_PACK, "actionchains", "chain_first_task_parameter_render_fail.yaml" +) CHAIN_SECOND_TASK_RENDER_FAIL_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_second_task_parameter_render_fail.yaml') + FIXTURES_PACK, "actionchains", "chain_second_task_parameter_render_fail.yaml" +) CHAIN_LIST_TEMP_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_list_template.yaml') + FIXTURES_PACK, "actionchains", "chain_list_template.yaml" +) CHAIN_DICT_TEMP_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_dict_template.yaml') + FIXTURES_PACK, "actionchains", "chain_dict_template.yaml" +) CHAIN_DEP_INPUT = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_dependent_input.yaml') + FIXTURES_PACK, "actionchains", "chain_dependent_input.yaml" +) CHAIN_DEP_RESULTS_INPUT = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_dep_result_input.yaml') + FIXTURES_PACK, "actionchains", "chain_dep_result_input.yaml" +) MALFORMED_CHAIN_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'malformedchain.yaml') + FIXTURES_PACK, "actionchains", "malformedchain.yaml" +) CHAIN_TYPED_PARAMS = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_typed_params.yaml') + FIXTURES_PACK, "actionchains", "chain_typed_params.yaml" +) CHAIN_SYSTEM_PARAMS = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_typed_system_params.yaml') + FIXTURES_PACK, "actionchains", "chain_typed_system_params.yaml" +) CHAIN_WITH_ACTIONPARAM_VARS = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_actionparam_vars.yaml') + FIXTURES_PACK, "actionchains", "chain_with_actionparam_vars.yaml" +) CHAIN_WITH_SYSTEM_VARS = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_system_vars.yaml') + FIXTURES_PACK, "actionchains", "chain_with_system_vars.yaml" +) CHAIN_WITH_PUBLISH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_publish.yaml') + FIXTURES_PACK, "actionchains", "chain_with_publish.yaml" +) CHAIN_WITH_PUBLISH_2 = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_publish_2.yaml') + FIXTURES_PACK, "actionchains", "chain_with_publish_2.yaml" +) CHAIN_WITH_PUBLISH_PARAM_RENDERING_FAILURE = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_publish_params_rendering_failure.yaml') + FIXTURES_PACK, "actionchains", "chain_publish_params_rendering_failure.yaml" +) CHAIN_WITH_INVALID_ACTION = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_invalid_action.yaml') -CHAIN_ACTION_PARAMS_AND_PARAMETERS_ATTRIBUTE = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_action_params_and_parameters.yaml') + FIXTURES_PACK, "actionchains", "chain_with_invalid_action.yaml" +) +CHAIN_ACTION_PARAMS_AND_PARAMETERS_ATTRIBUTE = ( + FixturesLoader().get_fixture_file_path_abs( + FIXTURES_PACK, "actionchains", "chain_action_params_and_parameters.yaml" + ) +) CHAIN_ACTION_PARAMS_ATTRIBUTE = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_action_params_attribute.yaml') + FIXTURES_PACK, "actionchains", "chain_action_params_attribute.yaml" +) CHAIN_ACTION_PARAMETERS_ATTRIBUTE = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_action_parameters_attribute.yaml') + FIXTURES_PACK, "actionchains", "chain_action_parameters_attribute.yaml" +) CHAIN_ACTION_INVALID_PARAMETER_TYPE = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_invalid_parameter_type_passed_to_action.yaml') + FIXTURES_PACK, "actionchains", "chain_invalid_parameter_type_passed_to_action.yaml" +) -CHAIN_NOTIFY_API = {'notify': {'on-complete': {'message': 'foo happened.'}}} +CHAIN_NOTIFY_API = {"notify": {"on-complete": {"message": "foo happened."}}} CHAIN_NOTIFY_DB = NotificationsHelper.to_model(CHAIN_NOTIFY_API) @mock.patch.object( - action_db_util, - 'get_runnertype_by_name', - mock.MagicMock(return_value=RUNNER)) + action_db_util, "get_runnertype_by_name", mock.MagicMock(return_value=RUNNER) +) @mock.patch.object( action_service, - 'is_action_canceled_or_canceling', - mock.MagicMock(return_value=False)) + "is_action_canceled_or_canceling", + mock.MagicMock(return_value=False), +) @mock.patch.object( - action_service, - 'is_action_paused_or_pausing', - mock.MagicMock(return_value=False)) + action_service, "is_action_paused_or_pausing", mock.MagicMock(return_value=False) +) class TestActionChainRunner(ExecutionDbTestCase): - def test_runner_creation(self): runner = acr.get_runner() self.assertTrue(runner) @@ -143,18 +179,23 @@ def test_malformed_chain(self): chain_runner.entry_point = MALFORMED_CHAIN_PATH chain_runner.action = ACTION_1 chain_runner.pre_run() - self.assertTrue(False, 'Expected pre_run to fail.') + self.assertTrue(False, "Expected pre_run to fail.") except runnerexceptions.ActionRunnerPreRunError: self.assertTrue(True) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_success_path(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_1_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.liveaction.notify = CHAIN_NOTIFY_DB chain_runner.pre_run() @@ -163,9 +204,12 @@ def test_chain_runner_success_path(self, request): # based on the chain the callcount is known to be 3. Not great but works. self.assertEqual(request.call_count, 3) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_chain_second_task_times_out(self, request): # Second task in the chain times out so the action chain status should be timeout chain_runner = acr.get_runner() @@ -177,13 +221,15 @@ def test_chain_runner_chain_second_task_times_out(self, request): def mock_run_action(*args, **kwargs): original_live_action = args[0] liveaction = original_run_action(*args, **kwargs) - if original_live_action.action == 'wolfpack.a2': + if original_live_action.action == "wolfpack.a2": # Mock a timeout for second task liveaction.status = LIVEACTION_STATUS_TIMED_OUT return liveaction chain_runner._run_action = mock_run_action - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, _, _ = chain_runner.run({}) @@ -193,9 +239,12 @@ def mock_run_action(*args, **kwargs): # based on the chain the callcount is known to be 3. Not great but works. self.assertEqual(request.call_count, 3) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_task_is_canceled_while_running(self, request): # Second task in the action is CANCELED, make sure runner doesn't get stuck in an infinite # loop @@ -207,7 +256,7 @@ def test_chain_runner_task_is_canceled_while_running(self, request): def mock_run_action(*args, **kwargs): original_live_action = args[0] - if original_live_action.action == 'wolfpack.a2': + if original_live_action.action == "wolfpack.a2": status = LIVEACTION_STATUS_CANCELED else: status = LIVEACTION_STATUS_SUCCEEDED @@ -216,7 +265,9 @@ def mock_run_action(*args, **kwargs): return liveaction chain_runner._run_action = mock_run_action - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, _, _ = chain_runner.run({}) @@ -227,16 +278,21 @@ def mock_run_action(*args, **kwargs): # canceled self.assertEqual(request.call_count, 2) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_success_task_action_call_with_no_params(self, request): # Make sure that the runner doesn't explode if task definition contains # no "params" section chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_ACTION_CALL_NO_PARAMS_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.liveaction.notify = CHAIN_NOTIFY_DB chain_runner.pre_run() @@ -245,14 +301,19 @@ def test_chain_runner_success_task_action_call_with_no_params(self, request): # based on the chain the callcount is known to be 3. Not great but works. self.assertEqual(request.call_count, 3) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_no_default(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_NO_DEFAULT chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() chain_runner.run({}) @@ -264,9 +325,12 @@ def test_chain_runner_no_default(self, request): # based on the chain the callcount is known to be 3. Not great but works. self.assertEqual(request.call_count, 3) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_no_default_multiple_options(self, request): # subtle difference is that when there are multiple possible default nodes # the order per chain definition may not be preseved. This is really a @@ -274,7 +338,9 @@ def test_chain_runner_no_default_multiple_options(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_NO_DEFAULT_2 chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() chain_runner.run({}) @@ -286,29 +352,44 @@ def test_chain_runner_no_default_multiple_options(self, request): # based on the chain the callcount is known to be 2. self.assertEqual(request.call_count, 2) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_bad_default(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_BAD_DEFAULT chain_runner.action = ACTION_1 - expected_msg = 'Unable to find node with name "bad_default" referenced in "default".' - self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError, - expected_msg, chain_runner.pre_run) - - @mock.patch('eventlet.sleep', mock.MagicMock()) - @mock.patch.object(action_db_util, 'get_liveaction_by_id', mock.MagicMock( - return_value=DummyActionExecution())) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(status=LIVEACTION_STATUS_RUNNING), None)) + expected_msg = ( + 'Unable to find node with name "bad_default" referenced in "default".' + ) + self.assertRaisesRegexp( + runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run + ) + + @mock.patch("eventlet.sleep", mock.MagicMock()) + @mock.patch.object( + action_db_util, + "get_liveaction_by_id", + mock.MagicMock(return_value=DummyActionExecution()), + ) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(status=LIVEACTION_STATUS_RUNNING), None), + ) def test_chain_runner_success_path_with_wait(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_1_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() chain_runner.run({}) @@ -316,15 +397,21 @@ def test_chain_runner_success_path_with_wait(self, request): # based on the chain the callcount is known to be 3. Not great but works. self.assertEqual(request.call_count, 3) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(status=LIVEACTION_STATUS_FAILED), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(status=LIVEACTION_STATUS_FAILED), None), + ) def test_chain_runner_failure_path(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_1_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, _, _ = chain_runner.run({}) @@ -333,42 +420,57 @@ def test_chain_runner_failure_path(self, request): # based on the chain the callcount is known to be 2. Not great but works. self.assertEqual(request.call_count, 2) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_broken_on_success_path_static_task_name(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_BROKEN_ON_SUCCESS_PATH_STATIC_TASK_NAME chain_runner.action = ACTION_1 - expected_msg = ('Unable to find node with name "c5" referenced in "on-success" ' - 'in task "c2"') - self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError, - expected_msg, chain_runner.pre_run) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(), None)) + expected_msg = ( + 'Unable to find node with name "c5" referenced in "on-success" ' + 'in task "c2"' + ) + self.assertRaisesRegexp( + runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run + ) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_broken_on_failure_path_static_task_name(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_BROKEN_ON_FAILURE_PATH_STATIC_TASK_NAME chain_runner.action = ACTION_1 - expected_msg = ('Unable to find node with name "c6" referenced in "on-failure" ' - 'in task "c2"') - self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError, - expected_msg, chain_runner.pre_run) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', side_effect=RuntimeError('Test Failure.')) + expected_msg = ( + 'Unable to find node with name "c6" referenced in "on-failure" ' + 'in task "c2"' + ) + self.assertRaisesRegexp( + runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run + ) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", side_effect=RuntimeError("Test Failure.") + ) def test_chain_runner_action_exception(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_1_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, results, _ = chain_runner.run({}) @@ -379,102 +481,131 @@ def test_chain_runner_action_exception(self, request): self.assertEqual(request.call_count, 2) error_count = 0 - for task_result in results['tasks']: - if task_result['result'].get('error', None): + for task_result in results["tasks"]: + if task_result["result"].get("error", None): error_count += 1 self.assertEqual(error_count, 2) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_str_param_temp(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_FIRST_TASK_RENDER_FAIL_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4}) + chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, {"p1": "1"}) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_list_param_temp(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_LIST_TEMP_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4}) + chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, {"p1": "[2, 3, 4]"}) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_dict_param_temp(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_DICT_TEMP_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4}) + chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) expected_value = {"p1": {"p1.3": "[3, 4]", "p1.2": "2", "p1.1": "1"}} mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(result={'o1': '1'}), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(result={"o1": "1"}), None), + ) def test_chain_runner_dependent_param_temp(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_DEP_INPUT chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1, 's2': 2, 's3': 3, 's4': 4}) + chain_runner.run({"s1": 1, "s2": 2, "s3": 3, "s4": 4}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_values = [{u'p1': u'1'}, - {u'p1': u'1'}, - {u'p2': u'1', u'p3': u'1', u'p1': u'1'}] + expected_values = [{"p1": "1"}, {"p1": "1"}, {"p2": "1", "p3": "1", "p1": "1"}] # Each of the call_args must be one of for call_args in request.call_args_list: self.assertIn(call_args[0][0].parameters, expected_values) expected_values.remove(call_args[0][0].parameters) - self.assertEqual(len(expected_values), 0, 'Not all expected values received.') - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(result={'o1': '1'}), None)) + self.assertEqual(len(expected_values), 0, "Not all expected values received.") + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(result={"o1": "1"}), None), + ) def test_chain_runner_dependent_results_param(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_DEP_RESULTS_INPUT chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1}) + chain_runner.run({"s1": 1}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) if six.PY2: - expected_values = [{u'p1': u'1'}, - {u'p1': u'1'}, - {u'out': u"{'c2': {'o1': '1'}, 'c1': {'o1': '1'}}"}] + expected_values = [ + {"p1": "1"}, + {"p1": "1"}, + {"out": "{'c2': {'o1': '1'}, 'c1': {'o1': '1'}}"}, + ] else: - expected_values = [{'p1': '1'}, - {'p1': '1'}, - {'out': "{'c1': {'o1': '1'}, 'c2': {'o1': '1'}}"}] + expected_values = [ + {"p1": "1"}, + {"p1": "1"}, + {"out": "{'c1': {'o1': '1'}, 'c2': {'o1': '1'}}"}, + ] # Each of the call_args must be one of self.assertEqual(request.call_count, 3) @@ -482,104 +613,137 @@ def test_chain_runner_dependent_results_param(self, request): self.assertIn(call_args[0][0].parameters, expected_values) expected_values.remove(call_args[0][0].parameters) - self.assertEqual(len(expected_values), 0, 'Not all expected values received.') + self.assertEqual(len(expected_values), 0, "Not all expected values received.") - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(RunnerType, 'get_by_name', - mock.MagicMock(return_value=RUNNER)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object(RunnerType, "get_by_name", mock.MagicMock(return_value=RUNNER)) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_missing_param_temp(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_FIRST_TASK_RENDER_FAIL_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() chain_runner.run({}) - self.assertEqual(request.call_count, 0, 'No call expected.') - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + self.assertEqual(request.call_count, 0, "No call expected.") + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_failure_during_param_rendering_single_task(self, request): # Parameter rendering should result in a top level error which aborts # the whole chain chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_FIRST_TASK_RENDER_FAIL_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, result, _ = chain_runner.run({}) # No tasks ran because rendering of parameters for the first task failed self.assertEqual(status, LIVEACTION_STATUS_FAILED) - self.assertEqual(result['tasks'], []) - self.assertIn('error', result) - self.assertIn('traceback', result) - self.assertIn('Failed to run task "c1". Parameter rendering failed', result['error']) - self.assertIn('Traceback', result['traceback']) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + self.assertEqual(result["tasks"], []) + self.assertIn("error", result) + self.assertIn("traceback", result) + self.assertIn( + 'Failed to run task "c1". Parameter rendering failed', result["error"] + ) + self.assertIn("Traceback", result["traceback"]) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_failure_during_param_rendering_multiple_tasks(self, request): # Parameter rendering should result in a top level error which aborts # the whole chain chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_SECOND_TASK_RENDER_FAIL_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() status, result, _ = chain_runner.run({}) # Verify that only first task has ran self.assertEqual(status, LIVEACTION_STATUS_FAILED) - self.assertEqual(len(result['tasks']), 1) - self.assertEqual(result['tasks'][0]['name'], 'c1') - - expected_error = ('Failed rendering value for action parameter "p1" in ' - 'task "c2" (template string={{s1}}):') - - self.assertIn('error', result) - self.assertIn('traceback', result) - self.assertIn('Failed to run task "c2". Parameter rendering failed', result['error']) - self.assertIn(expected_error, result['error']) - self.assertIn('Traceback', result['traceback']) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + self.assertEqual(len(result["tasks"]), 1) + self.assertEqual(result["tasks"][0]["name"], "c1") + + expected_error = ( + 'Failed rendering value for action parameter "p1" in ' + 'task "c2" (template string={{s1}}):' + ) + + self.assertIn("error", result) + self.assertIn("traceback", result) + self.assertIn( + 'Failed to run task "c2". Parameter rendering failed', result["error"] + ) + self.assertIn(expected_error, result["error"]) + self.assertIn("Traceback", result["traceback"]) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_typed_params(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_TYPED_PARAMS chain_runner.action = ACTION_2 - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'s1': 1, 's2': 'two', 's3': 3.14}) + chain_runner.run({"s1": 1, "s2": "two", "s3": 3.14}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_value = {'booltype': True, - 'inttype': 1, - 'numbertype': 3.14, - 'strtype': 'two', - 'arrtype': ['1', 'two'], - 'objtype': {'s2': 'two', - 'k1': '1'}} + expected_value = { + "booltype": True, + "inttype": 1, + "numbertype": 3.14, + "strtype": "two", + "arrtype": ["1", "two"], + "objtype": {"s2": "two", "k1": "1"}, + } mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_typed_system_params(self, request): - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) kvps = [] try: - kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name='a', value='1'))) - kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name='a.b.c', value='two'))) + kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name="a", value="1"))) + kvps.append( + KeyValuePair.add_or_update(KeyValuePairDB(name="a.b.c", value="two")) + ) chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_SYSTEM_PARAMS chain_runner.action = ACTION_2 @@ -587,22 +751,28 @@ def test_chain_runner_typed_system_params(self, request): chain_runner.pre_run() chain_runner.run({}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_value = {'inttype': 1, - 'strtype': 'two'} + expected_value = {"inttype": 1, "strtype": "two"} mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) finally: for kvp in kvps: KeyValuePair.delete(kvp) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_vars_system_params(self, request): - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) kvps = [] try: - kvps.append(KeyValuePair.add_or_update(KeyValuePairDB(name='a', value='two'))) + kvps.append( + KeyValuePair.add_or_update(KeyValuePairDB(name="a", value="two")) + ) chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_SYSTEM_VARS chain_runner.action = ACTION_2 @@ -610,72 +780,88 @@ def test_chain_runner_vars_system_params(self, request): chain_runner.pre_run() chain_runner.run({}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_value = {'inttype': 1, - 'strtype': 'two', - 'booltype': True} + expected_value = {"inttype": 1, "strtype": "two", "booltype": True} mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) finally: for kvp in kvps: KeyValuePair.delete(kvp) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_vars_action_params(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_ACTIONPARAM_VARS chain_runner.action = ACTION_2 - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() - chain_runner.run({'input_a': 'two'}) + chain_runner.run({"input_a": "two"}) self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_value = {'inttype': 1, - 'strtype': 'two', - 'booltype': True} + expected_value = {"inttype": 1, "strtype": "two", "booltype": True} mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(result={'raw_out': 'published'}), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(result={"raw_out": "published"}), None), + ) def test_chain_runner_publish(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_PUBLISH chain_runner.action = ACTION_2 - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) - chain_runner.runner_parameters = {'display_published': True} + chain_runner.runner_parameters = {"display_published": True} chain_runner.pre_run() - action_parameters = {'action_param_1': 'test value 1'} + action_parameters = {"action_param_1": "test value 1"} _, result, _ = chain_runner.run(action_parameters=action_parameters) # We also assert that the action parameters are available in the # "publish" scope self.assertNotEqual(chain_runner.chain_holder.actionchain, None) - expected_value = {'inttype': 1, - 'strtype': 'published', - 'booltype': True, - 'published_action_param': action_parameters['action_param_1']} + expected_value = { + "inttype": 1, + "strtype": "published", + "booltype": True, + "published_action_param": action_parameters["action_param_1"], + } mock_args, _ = request.call_args self.assertEqual(mock_args[0].parameters, expected_value) # Assert that the variables are correctly published - self.assertEqual(result['published'], - {'published_action_param': u'test value 1', 'o1': u'published'}) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + self.assertEqual( + result["published"], + {"published_action_param": "test value 1", "o1": "published"}, + ) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_publish_param_rendering_failure(self, request): # Parameter rendering should result in a top level error which aborts # the whole chain chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_PUBLISH_PARAM_RENDERING_FAILURE chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() @@ -685,16 +871,21 @@ def test_chain_runner_publish_param_rendering_failure(self, request): # TODO: Should we treat this as task error? Right now it bubbles all # the way up and it's not really consistent with action param # rendering failure - expected_error = ('Failed rendering value for publish parameter "p1" in ' - 'task "c2" (template string={{ not_defined }}):') + expected_error = ( + 'Failed rendering value for publish parameter "p1" in ' + 'task "c2" (template string={{ not_defined }}):' + ) self.assertIn(expected_error, six.text_type(e)) pass else: - self.fail('Exception was not thrown') - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + self.fail("Exception was not thrown") + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_task_passes_invalid_parameter_type_to_action(self, mock_request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_ACTION_INVALID_PARAMETER_TYPE @@ -702,48 +893,72 @@ def test_chain_task_passes_invalid_parameter_type_to_action(self, mock_request): chain_runner.pre_run() action_parameters = {} - expected_msg = (r'Failed to cast value "stringnotanarray" \(type: str\) for parameter ' - r'"arrtype" of type "array"') - self.assertRaisesRegexp(ValueError, expected_msg, chain_runner.run, - action_parameters=action_parameters) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=None)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(result={'raw_out': 'published'}), None)) + expected_msg = ( + r'Failed to cast value "stringnotanarray" \(type: str\) for parameter ' + r'"arrtype" of type "array"' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + chain_runner.run, + action_parameters=action_parameters, + ) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=None) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(result={"raw_out": "published"}), None), + ) def test_action_chain_runner_referenced_action_doesnt_exist(self, mock_request): # Action referenced by a task doesn't exist, should result in a top level error chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_INVALID_ACTION chain_runner.action = ACTION_2 - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() action_parameters = {} status, output, _ = chain_runner.run(action_parameters=action_parameters) - expected_error = ('Failed to run task "c1". Action with reference "wolfpack.a2" ' - 'doesn\'t exist.') + expected_error = ( + 'Failed to run task "c1". Action with reference "wolfpack.a2" ' + "doesn't exist." + ) self.assertEqual(status, LIVEACTION_STATUS_FAILED) - self.assertIn(expected_error, output['error']) - self.assertIn('Traceback', output['traceback']) + self.assertIn(expected_error, output["error"]) + self.assertIn("Traceback", output["traceback"]) - def test_exception_is_thrown_if_both_params_and_parameters_attributes_are_provided(self): + def test_exception_is_thrown_if_both_params_and_parameters_attributes_are_provided( + self, + ): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_ACTION_PARAMS_AND_PARAMETERS_ATTRIBUTE chain_runner.action = ACTION_2 - expected_msg = ('Either "params" or "parameters" attribute needs to be provided, but ' - 'not both') - self.assertRaisesRegexp(runnerexceptions.ActionRunnerPreRunError, expected_msg, - chain_runner.pre_run) - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + expected_msg = ( + 'Either "params" or "parameters" attribute needs to be provided, but ' + "not both" + ) + self.assertRaisesRegexp( + runnerexceptions.ActionRunnerPreRunError, expected_msg, chain_runner.pre_run + ) + + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_params_and_parameters_attributes_both_work(self, _): - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) # "params" attribute used chain_runner = acr.get_runner() @@ -756,10 +971,12 @@ def test_params_and_parameters_attributes_both_work(self, _): def mock_build_liveaction_object(action_node, resolved_params, parent_context): # Verify parameters are correctly passed to the action - self.assertEqual(resolved_params, {'pparams': 'v1'}) - original_build_liveaction_object(action_node=action_node, - resolved_params=resolved_params, - parent_context=parent_context) + self.assertEqual(resolved_params, {"pparams": "v1"}) + original_build_liveaction_object( + action_node=action_node, + resolved_params=resolved_params, + parent_context=parent_context, + ) chain_runner._build_liveaction_object = mock_build_liveaction_object @@ -776,10 +993,12 @@ def mock_build_liveaction_object(action_node, resolved_params, parent_context): def mock_build_liveaction_object(action_node, resolved_params, parent_context): # Verify parameters are correctly passed to the action - self.assertEqual(resolved_params, {'pparameters': 'v1'}) - original_build_liveaction_object(action_node=action_node, - resolved_params=resolved_params, - parent_context=parent_context) + self.assertEqual(resolved_params, {"pparameters": "v1"}) + original_build_liveaction_object( + action_node=action_node, + resolved_params=resolved_params, + parent_context=parent_context, + ) chain_runner._build_liveaction_object = mock_build_liveaction_object @@ -787,21 +1006,27 @@ def mock_build_liveaction_object(action_node, resolved_params, parent_context): status, output, _ = chain_runner.run(action_parameters=action_parameters) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_2)) - @mock.patch.object(action_service, 'request', - return_value=(DummyActionExecution(result={'raw_out': 'published'}), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_2) + ) + @mock.patch.object( + action_service, + "request", + return_value=(DummyActionExecution(result={"raw_out": "published"}), None), + ) def test_display_published_is_true_by_default(self, _): - action_ref = ResourceReference.to_string_reference(name=ACTION_2.name, pack=ACTION_2.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_2.name, pack=ACTION_2.pack + ) expected_published_values = { - 't1_publish_param_1': 'foo1', - 't1_publish_param_2': 'foo2', - 't1_publish_param_3': 'foo3', - 't2_publish_param_1': 'foo4', - 't2_publish_param_2': 'foo5', - 't2_publish_param_3': 'foo6', - 'publish_last_wins': 'bar_last', + "t1_publish_param_1": "foo1", + "t1_publish_param_2": "foo2", + "t1_publish_param_3": "foo3", + "t2_publish_param_1": "foo4", + "t2_publish_param_2": "foo5", + "t2_publish_param_3": "foo6", + "publish_last_wins": "bar_last", } # 1. display_published is True by default @@ -816,35 +1041,35 @@ def test_display_published_is_true_by_default(self, _): _, result, _ = chain_runner.run(action_parameters=action_parameters) # Assert that the variables are correctly published - self.assertEqual(result['published'], expected_published_values) + self.assertEqual(result["published"], expected_published_values) # 2. display_published is True by default so end result should be the same chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_PUBLISH_2 chain_runner.action = ACTION_2 chain_runner.liveaction = LiveActionDB(action=action_ref) - chain_runner.runner_parameters = {'display_published': True} + chain_runner.runner_parameters = {"display_published": True} chain_runner.pre_run() action_parameters = {} _, result, _ = chain_runner.run(action_parameters=action_parameters) # Assert that the variables are correctly published - self.assertEqual(result['published'], expected_published_values) + self.assertEqual(result["published"], expected_published_values) # 3. display_published is disabled chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_WITH_PUBLISH_2 chain_runner.action = ACTION_2 chain_runner.liveaction = LiveActionDB(action=action_ref) - chain_runner.runner_parameters = {'display_published': False} + chain_runner.runner_parameters = {"display_published": False} chain_runner.pre_run() action_parameters = {} _, result, _ = chain_runner.run(action_parameters=action_parameters) - self.assertNotIn('published', result) - self.assertEqual(result.get('published', {}), {}) + self.assertNotIn("published", result) + self.assertEqual(result.get("published", {}), {}) @classmethod def tearDownClass(cls): diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py index 7bba3606d88..dca88cf8037 100644 --- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py +++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_cancel.py @@ -20,6 +20,7 @@ import tempfile from st2tests import config as test_config + test_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -40,39 +41,25 @@ TEST_FIXTURES = { - 'chains': [ - 'test_cancel.yaml', - 'test_cancel_with_subworkflow.yaml' - ], - 'actions': [ - 'test_cancel.yaml', - 'test_cancel_with_subworkflow.yaml' - ] + "chains": ["test_cancel.yaml", "test_cancel_with_subworkflow.yaml"], + "actions": ["test_cancel.yaml", "test_cancel_with_subworkflow.yaml"], } -TEST_PACK = 'action_chain_tests' -TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "action_chain_tests" +TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK -PACKS = [ - TEST_PACK_PATH, - fixturesloader.get_fixtures_packs_base_path() + '/core' -] +PACKS = [TEST_PACK_PATH, fixturesloader.get_fixtures_packs_base_path() + "/core"] -USERNAME = 'stanley' +USERNAME = "stanley" -@mock.patch.object( - CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) -@mock.patch.object( - CUDPublisher, - 'publish_create', - mock.MagicMock(return_value=None)) +@mock.patch.object(CUDPublisher, "publish_update", mock.MagicMock(return_value=None)) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) @mock.patch.object( LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state)) + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state), +) class ActionChainRunnerPauseResumeTest(ExecutionDbTestCase): temp_file_path = None @@ -86,8 +73,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -98,7 +84,7 @@ def setUp(self): # Create temporary directory used by the tests _, self.temp_file_path = tempfile.mkstemp() - os.chmod(self.temp_file_path, 0o755) # nosec + os.chmod(self.temp_file_path, 0o755) # nosec def tearDown(self): if self.temp_file_path and os.path.exists(self.temp_file_path): @@ -110,7 +96,7 @@ def _wait_for_children(self, execution, interval=0.1, retries=100): # Wait until the execution has children. for i in range(0, retries): execution = ActionExecution.get_by_id(str(execution.id)) - if len(getattr(execution, 'children', [])) <= 0: + if len(getattr(execution, "children", [])) <= 0: eventlet.sleep(interval) continue @@ -123,34 +109,42 @@ def test_chain_cancel(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_cancel' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_cancel" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) # Request action chain to cancel. - liveaction, execution = action_service.request_cancellation(liveaction, USERNAME) + liveaction, execution = action_service.request_cancellation( + liveaction, USERNAME + ) # Wait until the liveaction is canceling. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELING) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELING + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is canceled. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) def test_chain_cancel_cascade_to_subworkflow(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -159,14 +153,16 @@ def test_chain_cancel_cascade_to_subworkflow(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_cancel_with_subworkflow' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_cancel_with_subworkflow" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) # Wait for subworkflow to register. execution = self._wait_for_children(execution) @@ -174,44 +170,58 @@ def test_chain_cancel_cascade_to_subworkflow(self): # Wait until the subworkflow is running. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_RUNNING + ) # Request action chain to cancel. - liveaction, execution = action_service.request_cancellation(liveaction, USERNAME) + liveaction, execution = action_service.request_cancellation( + liveaction, USERNAME + ) # Wait until the liveaction is canceling. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELING) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELING + ) self.assertEqual(len(execution.children), 1) # Wait until the subworkflow is canceling. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_CANCELING + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is canceled. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) self.assertEqual(len(execution.children), 1) # Wait until the subworkflow is canceled. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELED) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_CANCELED + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 1) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_CANCELED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 1) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_CANCELED + ) def test_chain_cancel_cascade_to_parent_workflow(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -220,14 +230,16 @@ def test_chain_cancel_cascade_to_parent_workflow(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_cancel_with_subworkflow' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_cancel_with_subworkflow" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) # Wait for subworkflow to register. execution = self._wait_for_children(execution) @@ -235,16 +247,22 @@ def test_chain_cancel_cascade_to_parent_workflow(self): # Wait until the subworkflow is running. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_RUNNING + ) # Request subworkflow to cancel. - task1_live, task1_exec = action_service.request_cancellation(task1_live, USERNAME) + task1_live, task1_exec = action_service.request_cancellation( + task1_live, USERNAME + ) # Wait until the subworkflow is canceling. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_CANCELING + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) @@ -252,20 +270,26 @@ def test_chain_cancel_cascade_to_parent_workflow(self): # Wait until the subworkflow is canceled. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_on_status(task1_live, action_constants.LIVEACTION_STATUS_CANCELED) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_on_status( + task1_live, action_constants.LIVEACTION_STATUS_CANCELED + ) # Wait until the parent liveaction is canceled. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) self.assertEqual(len(execution.children), 1) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 1) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_CANCELED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 1) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_CANCELED + ) diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py index 193d6064a1f..7997869b130 100644 --- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py +++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_notifications.py @@ -27,51 +27,53 @@ class DummyActionExecution(object): - def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=''): + def __init__(self, status=LIVEACTION_STATUS_SUCCEEDED, result=""): self.id = None self.status = status self.result = result -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" -TEST_MODELS = { - 'actions': ['a1.yaml', 'a2.yaml'], - 'runners': ['testrunner1.yaml'] -} +TEST_MODELS = {"actions": ["a1.yaml", "a2.yaml"], "runners": ["testrunner1.yaml"]} -MODELS = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) -ACTION_1 = MODELS['actions']['a1.yaml'] -ACTION_2 = MODELS['actions']['a2.yaml'] -RUNNER = MODELS['runners']['testrunner1.yaml'] +MODELS = FixturesLoader().load_models( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS +) +ACTION_1 = MODELS["actions"]["a1.yaml"] +ACTION_2 = MODELS["actions"]["a2.yaml"] +RUNNER = MODELS["runners"]["testrunner1.yaml"] CHAIN_1_PATH = FixturesLoader().get_fixture_file_path_abs( - FIXTURES_PACK, 'actionchains', 'chain_with_notifications.yaml') + FIXTURES_PACK, "actionchains", "chain_with_notifications.yaml" +) @mock.patch.object( - action_db_util, - 'get_runnertype_by_name', - mock.MagicMock(return_value=RUNNER)) + action_db_util, "get_runnertype_by_name", mock.MagicMock(return_value=RUNNER) +) @mock.patch.object( action_service, - 'is_action_canceled_or_canceling', - mock.MagicMock(return_value=False)) + "is_action_canceled_or_canceling", + mock.MagicMock(return_value=False), +) @mock.patch.object( - action_service, - 'is_action_paused_or_pausing', - mock.MagicMock(return_value=False)) + action_service, "is_action_paused_or_pausing", mock.MagicMock(return_value=False) +) class TestActionChainNotifications(ExecutionDbTestCase): - - @mock.patch.object(action_db_util, 'get_action_by_ref', - mock.MagicMock(return_value=ACTION_1)) - @mock.patch.object(action_service, 'request', return_value=(DummyActionExecution(), None)) + @mock.patch.object( + action_db_util, "get_action_by_ref", mock.MagicMock(return_value=ACTION_1) + ) + @mock.patch.object( + action_service, "request", return_value=(DummyActionExecution(), None) + ) def test_chain_runner_success_path(self, request): chain_runner = acr.get_runner() chain_runner.entry_point = CHAIN_1_PATH chain_runner.action = ACTION_1 - action_ref = ResourceReference.to_string_reference(name=ACTION_1.name, pack=ACTION_1.pack) + action_ref = ResourceReference.to_string_reference( + name=ACTION_1.name, pack=ACTION_1.pack + ) chain_runner.liveaction = LiveActionDB(action=action_ref) chain_runner.pre_run() chain_runner.run({}) @@ -79,8 +81,8 @@ def test_chain_runner_success_path(self, request): self.assertEqual(request.call_count, 2) first_call_args = request.call_args_list[0][0] liveaction_db = first_call_args[0] - self.assertTrue(liveaction_db.notify, 'Notify property expected.') + self.assertTrue(liveaction_db.notify, "Notify property expected.") second_call_args = request.call_args_list[1][0] liveaction_db = second_call_args[0] - self.assertFalse(liveaction_db.notify, 'Notify property not expected.') + self.assertFalse(liveaction_db.notify, "Notify property not expected.") diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py index 6fa4c6b4560..d6278ca61ae 100644 --- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py +++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_params_rendering.py @@ -25,96 +25,96 @@ class ActionChainRunnerResolveParamsTests(unittest2.TestCase): - def test_render_params_action_context(self): runner = acr.get_runner() chain_context = { - 'parent': { - 'execution_id': 'some_awesome_exec_id', - 'user': 'dad' - }, - 'user': 'son', - 'k1': 'v1' + "parent": {"execution_id": "some_awesome_exec_id", "user": "dad"}, + "user": "son", + "k1": "v1", } task_params = { - 'exec_id': {'default': '{{action_context.parent.execution_id}}'}, - 'k2': {}, - 'foo': {'default': 1} + "exec_id": {"default": "{{action_context.parent.execution_id}}"}, + "k2": {}, + "foo": {"default": 1}, } - action_node = Node(name='test_action_context_params', ref='core.local', params=task_params) + action_node = Node( + name="test_action_context_params", ref="core.local", params=task_params + ) rendered_params = runner._resolve_params(action_node, {}, {}, {}, chain_context) - self.assertEqual(rendered_params['exec_id']['default'], 'some_awesome_exec_id') + self.assertEqual(rendered_params["exec_id"]["default"], "some_awesome_exec_id") def test_render_params_action_context_non_existent_member(self): runner = acr.get_runner() chain_context = { - 'parent': { - 'execution_id': 'some_awesome_exec_id', - 'user': 'dad' - }, - 'user': 'son', - 'k1': 'v1' + "parent": {"execution_id": "some_awesome_exec_id", "user": "dad"}, + "user": "son", + "k1": "v1", } task_params = { - 'exec_id': {'default': '{{action_context.parent.yo_gimme_tha_key}}'}, - 'k2': {}, - 'foo': {'default': 1} + "exec_id": {"default": "{{action_context.parent.yo_gimme_tha_key}}"}, + "k2": {}, + "foo": {"default": 1}, } - action_node = Node(name='test_action_context_params', ref='core.local', params=task_params) + action_node = Node( + name="test_action_context_params", ref="core.local", params=task_params + ) try: runner._resolve_params(action_node, {}, {}, {}, chain_context) - self.fail('Should have thrown an instance of %s' % ParameterRenderingFailedException) + self.fail( + "Should have thrown an instance of %s" + % ParameterRenderingFailedException + ) except ParameterRenderingFailedException: pass def test_render_params_with_config(self): - with mock.patch('st2common.util.config_loader.ContentPackConfigLoader') as config_loader: + with mock.patch( + "st2common.util.config_loader.ContentPackConfigLoader" + ) as config_loader: config_loader().get_config.return_value = { - 'amazing_config_value_fo_lyfe': 'no' + "amazing_config_value_fo_lyfe": "no" } runner = acr.get_runner() chain_context = { - 'parent': { - 'execution_id': 'some_awesome_exec_id', - 'user': 'dad', - 'pack': 'mom' + "parent": { + "execution_id": "some_awesome_exec_id", + "user": "dad", + "pack": "mom", }, - 'user': 'son', + "user": "son", } task_params = { - 'config_val': '{{config_context.amazing_config_value_fo_lyfe}}' + "config_val": "{{config_context.amazing_config_value_fo_lyfe}}" } action_node = Node( - name='test_action_context_params', - ref='core.local', - params=task_params + name="test_action_context_params", ref="core.local", params=task_params + ) + rendered_params = runner._resolve_params( + action_node, {}, {}, {}, chain_context ) - rendered_params = runner._resolve_params(action_node, {}, {}, {}, chain_context) - self.assertEqual(rendered_params['config_val'], 'no') + self.assertEqual(rendered_params["config_val"], "no") def test_init_params_vars_with_unicode_value(self): chain_spec = { - 'vars': { - 'unicode_var': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž', - 'unicode_var_param': u'{{ param }}' + "vars": { + "unicode_var": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž", + "unicode_var_param": "{{ param }}", }, - 'chain': [ + "chain": [ { - 'name': 'c1', - 'ref': 'core.local', - 'parameters': { - 'cmd': 'echo {{ unicode_var }}' - } + "name": "c1", + "ref": "core.local", + "parameters": {"cmd": "echo {{ unicode_var }}"}, } - ] + ], } - chain_holder = acr.ChainHolder(chainspec=chain_spec, chainname='foo') - chain_holder.init_vars(action_parameters={'param': u'٩(̾●̮̮̃̾•̃̾)۶'}) + chain_holder = acr.ChainHolder(chainspec=chain_spec, chainname="foo") + chain_holder.init_vars(action_parameters={"param": "٩(̾●̮̮̃̾•̃̾)۶"}) expected = { - 'unicode_var': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž', - 'unicode_var_param': u'٩(̾●̮̮̃̾•̃̾)۶' + "unicode_var": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž", + "unicode_var_param": "٩(̾●̮̮̃̾•̃̾)۶", } self.assertEqual(chain_holder.vars, expected) diff --git a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py index c093c2061c1..46f948d73ab 100644 --- a/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py +++ b/contrib/runners/action_chain_runner/tests/unit/test_actionchain_pause_resume.py @@ -20,6 +20,7 @@ import tempfile from st2tests import config as test_config + test_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -42,53 +43,45 @@ TEST_FIXTURES = { - 'chains': [ - 'test_pause_resume.yaml', - 'test_pause_resume_context_result', - 'test_pause_resume_with_published_vars.yaml', - 'test_pause_resume_with_error.yaml', - 'test_pause_resume_with_subworkflow.yaml', - 'test_pause_resume_with_context_access.yaml', - 'test_pause_resume_with_init_vars.yaml', - 'test_pause_resume_with_no_more_task.yaml', - 'test_pause_resume_last_task_failed_with_no_next_task.yaml' + "chains": [ + "test_pause_resume.yaml", + "test_pause_resume_context_result", + "test_pause_resume_with_published_vars.yaml", + "test_pause_resume_with_error.yaml", + "test_pause_resume_with_subworkflow.yaml", + "test_pause_resume_with_context_access.yaml", + "test_pause_resume_with_init_vars.yaml", + "test_pause_resume_with_no_more_task.yaml", + "test_pause_resume_last_task_failed_with_no_next_task.yaml", + ], + "actions": [ + "test_pause_resume.yaml", + "test_pause_resume_context_result", + "test_pause_resume_with_published_vars.yaml", + "test_pause_resume_with_error.yaml", + "test_pause_resume_with_subworkflow.yaml", + "test_pause_resume_with_context_access.yaml", + "test_pause_resume_with_init_vars.yaml", + "test_pause_resume_with_no_more_task.yaml", + "test_pause_resume_last_task_failed_with_no_next_task.yaml", ], - 'actions': [ - 'test_pause_resume.yaml', - 'test_pause_resume_context_result', - 'test_pause_resume_with_published_vars.yaml', - 'test_pause_resume_with_error.yaml', - 'test_pause_resume_with_subworkflow.yaml', - 'test_pause_resume_with_context_access.yaml', - 'test_pause_resume_with_init_vars.yaml', - 'test_pause_resume_with_no_more_task.yaml', - 'test_pause_resume_last_task_failed_with_no_next_task.yaml' - ] } -TEST_PACK = 'action_chain_tests' -TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "action_chain_tests" +TEST_PACK_PATH = fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK -PACKS = [ - TEST_PACK_PATH, - fixturesloader.get_fixtures_packs_base_path() + '/core' -] +PACKS = [TEST_PACK_PATH, fixturesloader.get_fixtures_packs_base_path() + "/core"] -USERNAME = 'stanley' +USERNAME = "stanley" -@mock.patch.object( - CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) -@mock.patch.object( - CUDPublisher, - 'publish_create', - mock.MagicMock(return_value=None)) +@mock.patch.object(CUDPublisher, "publish_update", mock.MagicMock(return_value=None)) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) @mock.patch.object( LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state)) + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state), +) class ActionChainRunnerPauseResumeTest(ExecutionDbTestCase): temp_file_path = None @@ -102,8 +95,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -114,7 +106,7 @@ def setUp(self): # Create temporary directory used by the tests _, self.temp_file_path = tempfile.mkstemp() - os.chmod(self.temp_file_path, 0o755) # nosec + os.chmod(self.temp_file_path, 0o755) # nosec def tearDown(self): if self.temp_file_path and os.path.exists(self.temp_file_path): @@ -138,7 +130,7 @@ def _wait_for_children(self, execution, interval=0.1, retries=100): # Wait until the execution has children. for i in range(0, retries): execution = ActionExecution.get_by_id(str(execution.id)) - if len(getattr(execution, 'children', [])) <= 0: + if len(getattr(execution, "children", [])) <= 0: eventlet.sleep(interval) continue @@ -151,32 +143,42 @@ def test_chain_pause_resume(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -185,15 +187,19 @@ def test_chain_pause_resume(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) def test_chain_pause_resume_with_published_vars(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -202,32 +208,42 @@ def test_chain_pause_resume_with_published_vars(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_published_vars' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_published_vars" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -236,17 +252,23 @@ def test_chain_pause_resume_with_published_vars(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) - self.assertIn('published', liveaction.result) - self.assertDictEqual({'var1': 'foobar', 'var2': 'fubar'}, liveaction.result['published']) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) + self.assertIn("published", liveaction.result) + self.assertDictEqual( + {"var1": "foobar", "var2": "fubar"}, liveaction.result["published"] + ) def test_chain_pause_resume_with_published_vars_display_false(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -255,32 +277,42 @@ def test_chain_pause_resume_with_published_vars_display_false(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_published_vars' - params = {'tempfile': path, 'message': 'foobar', 'display_published': False} + action = TEST_PACK + "." + "test_pause_resume_with_published_vars" + params = {"tempfile": path, "message": "foobar", "display_published": False} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -289,16 +321,20 @@ def test_chain_pause_resume_with_published_vars_display_false(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) - self.assertNotIn('published', liveaction.result) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) + self.assertNotIn("published", liveaction.result) def test_chain_pause_resume_with_error(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -307,32 +343,42 @@ def test_chain_pause_resume_with_error(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_error' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_error" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -341,19 +387,23 @@ def test_chain_pause_resume_with_error(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) - self.assertTrue(liveaction.result['tasks'][0]['result']['failed']) - self.assertEqual(1, liveaction.result['tasks'][0]['result']['return_code']) - self.assertTrue(liveaction.result['tasks'][1]['result']['succeeded']) - self.assertEqual(0, liveaction.result['tasks'][1]['result']['return_code']) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) + self.assertTrue(liveaction.result["tasks"][0]["result"]["failed"]) + self.assertEqual(1, liveaction.result["tasks"][0]["result"]["return_code"]) + self.assertTrue(liveaction.result["tasks"][1]["result"]["succeeded"]) + self.assertEqual(0, liveaction.result["tasks"][1]["result"]["return_code"]) def test_chain_pause_resume_cascade_to_subworkflow(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -362,14 +412,16 @@ def test_chain_pause_resume_cascade_to_subworkflow(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_subworkflow' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_subworkflow" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Wait for subworkflow to register. @@ -378,71 +430,97 @@ def test_chain_pause_resume_cascade_to_subworkflow(self): # Wait until the subworkflow is running. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) self.assertEqual(len(execution.children), 1) # Wait until the subworkflow is pausing. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(task1_live) - self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) self.assertEqual(len(execution.children), 1) # Wait until the subworkflow is paused. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSED) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(task1_live) - self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 1) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_PAUSED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 1) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_PAUSED + ) # Request action chain to resume. liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 2) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_SUCCEEDED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 2) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_SUCCEEDED + ) def test_chain_pause_resume_cascade_to_parent_workflow(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -451,14 +529,16 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_subworkflow' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_subworkflow" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Wait for subworkflow to register. @@ -467,8 +547,10 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self): # Wait until the subworkflow is running. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_RUNNING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request subworkflow to pause. @@ -476,10 +558,14 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self): # Wait until the subworkflow is pausing. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSING) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(task1_live) - self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + task1_live.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) @@ -487,39 +573,55 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self): # Wait until the subworkflow is paused. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_PAUSED) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(task1_live) - self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + task1_live.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait until the parent liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) self.assertEqual(len(execution.children), 1) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 1) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_PAUSED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 1) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_PAUSED + ) # Request subworkflow to resume. task1_live, task1_exec = action_service.request_resume(task1_live, USERNAME) # Wait until the subworkflow is paused. task1_exec = ActionExecution.get_by_id(execution.children[0]) - task1_live = LiveAction.get_by_id(task1_exec.liveaction['id']) - task1_live = self._wait_for_status(task1_live, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(task1_live.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + task1_live = LiveAction.get_by_id(task1_exec.liveaction["id"]) + task1_live = self._wait_for_status( + task1_live, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + task1_live.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # The parent workflow will stay paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED) # Wait for non-blocking threads to complete. @@ -527,30 +629,38 @@ def test_chain_pause_resume_cascade_to_parent_workflow(self): # Check liveaction result of the parent, which should stay the same # because only the subworkflow was resumed. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 1) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_PAUSED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 1) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_PAUSED + ) # Request parent workflow to resume. liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) - subworkflow = liveaction.result['tasks'][0] - self.assertEqual(len(subworkflow['result']['tasks']), 2) - self.assertEqual(subworkflow['state'], action_constants.LIVEACTION_STATUS_SUCCEEDED) + subworkflow = liveaction.result["tasks"][0] + self.assertEqual(len(subworkflow["result"]["tasks"]), 2) + self.assertEqual( + subworkflow["state"], action_constants.LIVEACTION_STATUS_SUCCEEDED + ) def test_chain_pause_resume_with_context_access(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -559,32 +669,42 @@ def test_chain_pause_resume_with_context_access(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_context_access' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_context_access" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -593,16 +713,20 @@ def test_chain_pause_resume_with_context_access(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 3) - self.assertEqual(liveaction.result['tasks'][2]['result']['stdout'], 'foobar') + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 3) + self.assertEqual(liveaction.result["tasks"][2]["result"]["stdout"], "foobar") def test_chain_pause_resume_with_init_vars(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -611,32 +735,42 @@ def test_chain_pause_resume_with_init_vars(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_init_vars' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_init_vars" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -645,16 +779,20 @@ def test_chain_pause_resume_with_init_vars(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) - self.assertEqual(liveaction.result['tasks'][1]['result']['stdout'], 'FOOBAR') + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) + self.assertEqual(liveaction.result["tasks"][1]["result"]["stdout"], "FOOBAR") def test_chain_pause_resume_with_no_more_task(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -663,32 +801,42 @@ def test_chain_pause_resume_with_no_more_task(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_with_no_more_task' - params = {'tempfile': path, 'message': 'foobar'} + action = TEST_PACK + "." + "test_pause_resume_with_no_more_task" + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -697,15 +845,19 @@ def test_chain_pause_resume_with_no_more_task(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) def test_chain_pause_resume_last_task_failed_with_no_next_task(self): # A temp file is created during test setup. Ensure the temp file exists. @@ -714,32 +866,44 @@ def test_chain_pause_resume_last_task_failed_with_no_next_task(self): path = self.temp_file_path self.assertTrue(os.path.exists(path)) - action = TEST_PACK + '.' + 'test_pause_resume_last_task_failed_with_no_next_task' - params = {'tempfile': path, 'message': 'foobar'} + action = ( + TEST_PACK + "." + "test_pause_resume_last_task_failed_with_no_next_task" + ) + params = {"tempfile": path, "message": "foobar"} liveaction = LiveActionDB(action=action, parameters=params) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is running. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_RUNNING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_RUNNING + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request action chain to pause. liveaction, execution = action_service.request_pause(liveaction, USERNAME) # Wait until the liveaction is pausing. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSING) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSING + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSING, extra_info + ) # Delete the temporary file that the action chain is waiting on. os.remove(path) self.assertFalse(os.path.exists(path)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() @@ -748,62 +912,70 @@ def test_chain_pause_resume_last_task_failed_with_no_next_task(self): liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_FAILED + ) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_FAILED) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 1) + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 1) self.assertEqual( - liveaction.result['tasks'][0]['state'], - action_constants.LIVEACTION_STATUS_FAILED + liveaction.result["tasks"][0]["state"], + action_constants.LIVEACTION_STATUS_FAILED, ) def test_chain_pause_resume_status_change(self): # Tests context_result is updated when last task's status changes between pause and resume - action = TEST_PACK + '.' + 'test_pause_resume_context_result' + action = TEST_PACK + "." + "test_pause_resume_context_result" liveaction = LiveActionDB(action=action) liveaction, execution = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) # Wait until the liveaction is paused. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) extra_info = str(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_PAUSED, extra_info + ) # Wait for non-blocking threads to complete. Ensure runner is not running. MockLiveActionPublisherNonBlocking.wait_all() - last_task_liveaction_id = liveaction.result['tasks'][-1]['liveaction_id'] + last_task_liveaction_id = liveaction.result["tasks"][-1]["liveaction_id"] action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_SUCCEEDED, end_timestamp=date_utils.get_datetime_utc_now(), - result={'foo': 'bar'}, - liveaction_id=last_task_liveaction_id + result={"foo": "bar"}, + liveaction_id=last_task_liveaction_id, ) # Request action chain to resume. liveaction, execution = action_service.request_resume(liveaction, USERNAME) # Wait until the liveaction is completed. - liveaction = self._wait_for_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_for_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) self.assertEqual( liveaction.status, action_constants.LIVEACTION_STATUS_SUCCEEDED, - str(liveaction) + str(liveaction), ) # Wait for non-blocking threads to complete. MockLiveActionPublisherNonBlocking.wait_all() # Check liveaction result. - self.assertIn('tasks', liveaction.result) - self.assertEqual(len(liveaction.result['tasks']), 2) - self.assertEqual(liveaction.result['tasks'][0]['result']['foo'], 'bar') + self.assertIn("tasks", liveaction.result) + self.assertEqual(len(liveaction.result["tasks"]), 2) + self.assertEqual(liveaction.result["tasks"][0]["result"]["foo"], "bar") diff --git a/contrib/runners/announcement_runner/announcement_runner/__init__.py b/contrib/runners/announcement_runner/announcement_runner/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/contrib/runners/announcement_runner/announcement_runner/__init__.py +++ b/contrib/runners/announcement_runner/announcement_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py b/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py index 6d219f2819c..4782544c3cd 100644 --- a/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py +++ b/contrib/runners/announcement_runner/announcement_runner/announcement_runner.py @@ -24,12 +24,7 @@ from st2common.models.api.trace import TraceContext from st2common.transport.announcement import AnnouncementDispatcher -__all__ = [ - 'AnnouncementRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["AnnouncementRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) @@ -42,28 +37,28 @@ def __init__(self, runner_id): def pre_run(self): super(AnnouncementRunner, self).pre_run() - LOG.debug('Entering AnnouncementRunner.pre_run() for liveaction_id="%s"', - self.liveaction_id) + LOG.debug( + 'Entering AnnouncementRunner.pre_run() for liveaction_id="%s"', + self.liveaction_id, + ) - if not self.runner_parameters.get('experimental'): - message = ('Experimental flag is missing for action %s' % self.action.ref) - LOG.exception('Experimental runner is called without experimental flag.') + if not self.runner_parameters.get("experimental"): + message = "Experimental flag is missing for action %s" % self.action.ref + LOG.exception("Experimental runner is called without experimental flag.") raise runnerexceptions.ActionRunnerPreRunError(message) - self._route = self.runner_parameters.get('route') + self._route = self.runner_parameters.get("route") def run(self, action_parameters): - trace_context = self.liveaction.context.get('trace_context', None) + trace_context = self.liveaction.context.get("trace_context", None) if trace_context: trace_context = TraceContext(**trace_context) - self._dispatcher.dispatch(self._route, - payload=action_parameters, - trace_context=trace_context) + self._dispatcher.dispatch( + self._route, payload=action_parameters, trace_context=trace_context + ) - result = { - "output": action_parameters - } + result = {"output": action_parameters} result.update(action_parameters) return (LIVEACTION_STATUS_SUCCEEDED, result, None) @@ -74,4 +69,4 @@ def get_runner(): def get_metadata(): - return get_runner_metadata('announcement_runner')[0] + return get_runner_metadata("announcement_runner")[0] diff --git a/contrib/runners/announcement_runner/dist_utils.py b/contrib/runners/announcement_runner/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/contrib/runners/announcement_runner/dist_utils.py +++ b/contrib/runners/announcement_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/announcement_runner/setup.py b/contrib/runners/announcement_runner/setup.py index efd60b14afe..a72469ffea4 100644 --- a/contrib/runners/announcement_runner/setup.py +++ b/contrib/runners/announcement_runner/setup.py @@ -26,30 +26,32 @@ from announcement_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-announcement', + name="stackstorm-runner-announcement", version=__version__, - description=('Announcement action runner for StackStorm event-driven automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "Announcement action runner for StackStorm event-driven automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'announcement_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"announcement_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'announcement = announcement_runner.announcement_runner', + "st2common.runners.runner": [ + "announcement = announcement_runner.announcement_runner", ], - } + }, ) diff --git a/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py b/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py index cc9c5410151..9ad56a21159 100644 --- a/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py +++ b/contrib/runners/announcement_runner/tests/unit/test_announcementrunner.py @@ -26,69 +26,63 @@ mock_dispatcher = mock.Mock() -@mock.patch('st2common.transport.announcement.AnnouncementDispatcher.dispatch') +@mock.patch("st2common.transport.announcement.AnnouncementDispatcher.dispatch") class AnnouncementRunnerTestCase(RunnerTestCase): - @classmethod def setUpClass(cls): tests_config.parse_args() def test_runner_creation(self, dispatch): runner = announcement_runner.get_runner() - self.assertIsNotNone(runner, 'Creation failed. No instance.') - self.assertEqual(type(runner), announcement_runner.AnnouncementRunner, - 'Creation failed. No instance.') + self.assertIsNotNone(runner, "Creation failed. No instance.") + self.assertEqual( + type(runner), + announcement_runner.AnnouncementRunner, + "Creation failed. No instance.", + ) self.assertEqual(runner._dispatcher.dispatch, dispatch) def test_announcement(self, dispatch): runner = announcement_runner.get_runner() - runner.runner_parameters = { - 'experimental': True, - 'route': 'general' - } + runner.runner_parameters = {"experimental": True, "route": "general"} runner.liveaction = mock.Mock(context={}) runner.pre_run() - (status, result, _) = runner.run({'test': 'passed'}) + (status, result, _) = runner.run({"test": "passed"}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(result) - self.assertEqual(result['test'], 'passed') - dispatch.assert_called_once_with('general', payload={'test': 'passed'}, - trace_context=None) + self.assertEqual(result["test"], "passed") + dispatch.assert_called_once_with( + "general", payload={"test": "passed"}, trace_context=None + ) def test_announcement_no_experimental(self, dispatch): runner = announcement_runner.get_runner() - runner.action = mock.Mock(ref='some.thing') - runner.runner_parameters = { - 'route': 'general' - } + runner.action = mock.Mock(ref="some.thing") + runner.runner_parameters = {"route": "general"} runner.liveaction = mock.Mock(context={}) - expected_msg = 'Experimental flag is missing for action some.thing' + expected_msg = "Experimental flag is missing for action some.thing" self.assertRaisesRegexp(Exception, expected_msg, runner.pre_run) - @mock.patch('st2common.models.api.trace.TraceContext.__new__') + @mock.patch("st2common.models.api.trace.TraceContext.__new__") def test_announcement_with_trace(self, context, dispatch): runner = announcement_runner.get_runner() - runner.runner_parameters = { - 'experimental': True, - 'route': 'general' - } - runner.liveaction = mock.Mock(context={ - 'trace_context': { - 'id_': 'a', - 'trace_tag': 'b' - } - }) + runner.runner_parameters = {"experimental": True, "route": "general"} + runner.liveaction = mock.Mock( + context={"trace_context": {"id_": "a", "trace_tag": "b"}} + ) runner.pre_run() - (status, result, _) = runner.run({'test': 'passed'}) + (status, result, _) = runner.run({"test": "passed"}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(result) - self.assertEqual(result['test'], 'passed') - context.assert_called_once_with(TraceContext, - **runner.liveaction.context['trace_context']) - dispatch.assert_called_once_with('general', payload={'test': 'passed'}, - trace_context=context.return_value) + self.assertEqual(result["test"], "passed") + context.assert_called_once_with( + TraceContext, **runner.liveaction.context["trace_context"] + ) + dispatch.assert_called_once_with( + "general", payload={"test": "passed"}, trace_context=context.return_value + ) diff --git a/contrib/runners/http_runner/dist_utils.py b/contrib/runners/http_runner/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/contrib/runners/http_runner/dist_utils.py +++ b/contrib/runners/http_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/http_runner/http_runner/__init__.py b/contrib/runners/http_runner/http_runner/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/contrib/runners/http_runner/http_runner/__init__.py +++ b/contrib/runners/http_runner/http_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/http_runner/http_runner/http_runner.py b/contrib/runners/http_runner/http_runner/http_runner.py index b2ff115fc64..6a02c809b9d 100644 --- a/contrib/runners/http_runner/http_runner/http_runner.py +++ b/contrib/runners/http_runner/http_runner/http_runner.py @@ -35,45 +35,36 @@ import six from six.moves import range -__all__ = [ - 'HttpRunner', - - 'HTTPClient', - - 'get_runner', - 'get_metadata' -] +__all__ = ["HttpRunner", "HTTPClient", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) SUCCESS_STATUS_CODES = [code for code in range(200, 207)] # Lookup constants for runner params -RUNNER_ON_BEHALF_USER = 'user' -RUNNER_URL = 'url' -RUNNER_HEADERS = 'headers' # Debatable whether this should be action params. -RUNNER_COOKIES = 'cookies' -RUNNER_ALLOW_REDIRECTS = 'allow_redirects' -RUNNER_HTTP_PROXY = 'http_proxy' -RUNNER_HTTPS_PROXY = 'https_proxy' -RUNNER_VERIFY_SSL_CERT = 'verify_ssl_cert' -RUNNER_USERNAME = 'username' -RUNNER_PASSWORD = 'password' -RUNNER_URL_HOSTS_BLACKLIST = 'url_hosts_blacklist' -RUNNER_URL_HOSTS_WHITELIST = 'url_hosts_whitelist' +RUNNER_ON_BEHALF_USER = "user" +RUNNER_URL = "url" +RUNNER_HEADERS = "headers" # Debatable whether this should be action params. +RUNNER_COOKIES = "cookies" +RUNNER_ALLOW_REDIRECTS = "allow_redirects" +RUNNER_HTTP_PROXY = "http_proxy" +RUNNER_HTTPS_PROXY = "https_proxy" +RUNNER_VERIFY_SSL_CERT = "verify_ssl_cert" +RUNNER_USERNAME = "username" +RUNNER_PASSWORD = "password" +RUNNER_URL_HOSTS_BLACKLIST = "url_hosts_blacklist" +RUNNER_URL_HOSTS_WHITELIST = "url_hosts_whitelist" # Lookup constants for action params -ACTION_AUTH = 'auth' -ACTION_BODY = 'body' -ACTION_TIMEOUT = 'timeout' -ACTION_METHOD = 'method' -ACTION_QUERY_PARAMS = 'params' -FILE_NAME = 'file_name' -FILE_CONTENT = 'file_content' -FILE_CONTENT_TYPE = 'file_content_type' +ACTION_AUTH = "auth" +ACTION_BODY = "body" +ACTION_TIMEOUT = "timeout" +ACTION_METHOD = "method" +ACTION_QUERY_PARAMS = "params" +FILE_NAME = "file_name" +FILE_CONTENT = "file_content" +FILE_CONTENT_TYPE = "file_content_type" -RESPONSE_BODY_PARSE_FUNCTIONS = { - 'application/json': json.loads -} +RESPONSE_BODY_PARSE_FUNCTIONS = {"application/json": json.loads} class HttpRunner(ActionRunner): @@ -85,37 +76,48 @@ def __init__(self, runner_id): def pre_run(self): super(HttpRunner, self).pre_run() - LOG.debug('Entering HttpRunner.pre_run() for liveaction_id="%s"', self.liveaction_id) - self._on_behalf_user = self.runner_parameters.get(RUNNER_ON_BEHALF_USER, - self._on_behalf_user) + LOG.debug( + 'Entering HttpRunner.pre_run() for liveaction_id="%s"', self.liveaction_id + ) + self._on_behalf_user = self.runner_parameters.get( + RUNNER_ON_BEHALF_USER, self._on_behalf_user + ) self._url = self.runner_parameters.get(RUNNER_URL, None) self._headers = self.runner_parameters.get(RUNNER_HEADERS, {}) self._cookies = self.runner_parameters.get(RUNNER_COOKIES, None) - self._allow_redirects = self.runner_parameters.get(RUNNER_ALLOW_REDIRECTS, False) + self._allow_redirects = self.runner_parameters.get( + RUNNER_ALLOW_REDIRECTS, False + ) self._username = self.runner_parameters.get(RUNNER_USERNAME, None) self._password = self.runner_parameters.get(RUNNER_PASSWORD, None) self._http_proxy = self.runner_parameters.get(RUNNER_HTTP_PROXY, None) self._https_proxy = self.runner_parameters.get(RUNNER_HTTPS_PROXY, None) self._verify_ssl_cert = self.runner_parameters.get(RUNNER_VERIFY_SSL_CERT, None) - self._url_hosts_blacklist = self.runner_parameters.get(RUNNER_URL_HOSTS_BLACKLIST, []) - self._url_hosts_whitelist = self.runner_parameters.get(RUNNER_URL_HOSTS_WHITELIST, []) + self._url_hosts_blacklist = self.runner_parameters.get( + RUNNER_URL_HOSTS_BLACKLIST, [] + ) + self._url_hosts_whitelist = self.runner_parameters.get( + RUNNER_URL_HOSTS_WHITELIST, [] + ) def run(self, action_parameters): client = self._get_http_client(action_parameters) if self._url_hosts_blacklist and self._url_hosts_whitelist: - msg = ('"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' - 'exclusive. Only one should be provided.') + msg = ( + '"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' + "exclusive. Only one should be provided." + ) raise ValueError(msg) try: result = client.run() except requests.exceptions.Timeout as e: - result = {'error': six.text_type(e)} + result = {"error": six.text_type(e)} status = LIVEACTION_STATUS_TIMED_OUT else: - status = HttpRunner._get_result_status(result.get('status_code', None)) + status = HttpRunner._get_result_status(result.get("status_code", None)) return (status, result, None) @@ -132,8 +134,8 @@ def _get_http_client(self, action_parameters): # Include our user agent and action name so requests can be tracked back headers = copy.deepcopy(self._headers) if self._headers else {} - headers['User-Agent'] = 'st2/v%s' % (st2_version) - headers['X-Stanley-Action'] = self.action_name + headers["User-Agent"] = "st2/v%s" % (st2_version) + headers["X-Stanley-Action"] = self.action_name if file_name and file_content: files = {} @@ -141,7 +143,7 @@ def _get_http_client(self, action_parameters): if file_content_type: value = (file_content, file_content_type) else: - value = (file_content) + value = file_content files[file_name] = value else: @@ -150,43 +152,72 @@ def _get_http_client(self, action_parameters): proxies = {} if self._http_proxy: - proxies['http'] = self._http_proxy + proxies["http"] = self._http_proxy if self._https_proxy: - proxies['https'] = self._https_proxy - - return HTTPClient(url=self._url, method=method, body=body, params=params, - headers=headers, cookies=self._cookies, auth=auth, - timeout=timeout, allow_redirects=self._allow_redirects, - proxies=proxies, files=files, verify=self._verify_ssl_cert, - username=self._username, password=self._password, - url_hosts_blacklist=self._url_hosts_blacklist, - url_hosts_whitelist=self._url_hosts_whitelist) + proxies["https"] = self._https_proxy + + return HTTPClient( + url=self._url, + method=method, + body=body, + params=params, + headers=headers, + cookies=self._cookies, + auth=auth, + timeout=timeout, + allow_redirects=self._allow_redirects, + proxies=proxies, + files=files, + verify=self._verify_ssl_cert, + username=self._username, + password=self._password, + url_hosts_blacklist=self._url_hosts_blacklist, + url_hosts_whitelist=self._url_hosts_whitelist, + ) @staticmethod def _get_result_status(status_code): - return LIVEACTION_STATUS_SUCCEEDED if status_code in SUCCESS_STATUS_CODES \ + return ( + LIVEACTION_STATUS_SUCCEEDED + if status_code in SUCCESS_STATUS_CODES else LIVEACTION_STATUS_FAILED + ) class HTTPClient(object): - def __init__(self, url=None, method=None, body='', params=None, headers=None, cookies=None, - auth=None, timeout=60, allow_redirects=False, proxies=None, - files=None, verify=False, username=None, password=None, - url_hosts_blacklist=None, url_hosts_whitelist=None): + def __init__( + self, + url=None, + method=None, + body="", + params=None, + headers=None, + cookies=None, + auth=None, + timeout=60, + allow_redirects=False, + proxies=None, + files=None, + verify=False, + username=None, + password=None, + url_hosts_blacklist=None, + url_hosts_whitelist=None, + ): if url is None: - raise Exception('URL must be specified.') + raise Exception("URL must be specified.") if method is None: if files or body: - method = 'POST' + method = "POST" else: - method = 'GET' + method = "GET" headers = headers or {} normalized_headers = self._normalize_headers(headers=headers) - if body and 'content-length' not in normalized_headers: - headers['Content-Length'] = str(len(body)) + if body and "content-length" not in normalized_headers: + headers["Content-Length"] = str(len(body)) self.url = url self.method = method @@ -207,8 +238,10 @@ def __init__(self, url=None, method=None, body='', params=None, headers=None, co self.url_hosts_whitelist = url_hosts_whitelist or [] if self.url_hosts_blacklist and self.url_hosts_whitelist: - msg = ('"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' - 'exclusive. Only one should be provided.') + msg = ( + '"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' + "exclusive. Only one should be provided." + ) raise ValueError(msg) def run(self): @@ -235,7 +268,7 @@ def run(self): try: data = json.dumps(data) except ValueError: - msg = 'Request body (%s) can\'t be parsed as JSON' % (data) + msg = "Request body (%s) can't be parsed as JSON" % (data) raise ValueError(msg) else: data = self.body @@ -245,7 +278,7 @@ def run(self): # Ensure data is bytes since that what request expects if isinstance(data, six.text_type): - data = data.encode('utf-8') + data = data.encode("utf-8") resp = requests.request( self.method, @@ -259,19 +292,19 @@ def run(self): allow_redirects=self.allow_redirects, proxies=self.proxies, files=self.files, - verify=self.verify + verify=self.verify, ) headers = dict(resp.headers) body, parsed = self._parse_response_body(headers=headers, body=resp.text) - results['status_code'] = resp.status_code - results['body'] = body - results['parsed'] = parsed # flag which indicates if body has been parsed - results['headers'] = headers + results["status_code"] = resp.status_code + results["body"] = body + results["parsed"] = parsed # flag which indicates if body has been parsed + results["headers"] = headers return results except Exception as e: - LOG.exception('Exception making request to remote URL: %s, %s', self.url, e) + LOG.exception("Exception making request to remote URL: %s, %s", self.url, e) raise finally: if resp: @@ -285,27 +318,27 @@ def _parse_response_body(self, headers, body): :return: (parsed body, flag which indicates if body has been parsed) :rtype: (``object``, ``bool``) """ - body = body or '' + body = body or "" headers = self._normalize_headers(headers=headers) - content_type = headers.get('content-type', None) + content_type = headers.get("content-type", None) parsed = False if not content_type: return (body, parsed) # The header can also contain charset which we simply discard - content_type = content_type.split(';')[0] + content_type = content_type.split(";")[0] parse_func = RESPONSE_BODY_PARSE_FUNCTIONS.get(content_type, None) if not parse_func: return (body, parsed) - LOG.debug('Parsing body with content type: %s', content_type) + LOG.debug("Parsing body with content type: %s", content_type) try: body = parse_func(body) except Exception: - LOG.exception('Failed to parse body') + LOG.exception("Failed to parse body") else: parsed = True @@ -323,7 +356,7 @@ def _normalize_headers(self, headers): def _is_json_content(self): normalized = self._normalize_headers(self.headers) - return normalized.get('content-type', None) == 'application/json' + return normalized.get("content-type", None) == "application/json" def _cast_object(self, value): if isinstance(value, str) or isinstance(value, six.text_type): @@ -370,10 +403,10 @@ def _get_host_from_url(self, url): parsed = urlparse.urlparse(url) # Remove port and [] - host = parsed.netloc.replace('[', '').replace(']', '') + host = parsed.netloc.replace("[", "").replace("]", "") if parsed.port is not None: - host = host.replace(':%s' % (parsed.port), '') + host = host.replace(":%s" % (parsed.port), "") return host @@ -383,4 +416,4 @@ def get_runner(): def get_metadata(): - return get_runner_metadata('http_runner')[0] + return get_runner_metadata("http_runner")[0] diff --git a/contrib/runners/http_runner/setup.py b/contrib/runners/http_runner/setup.py index 2b962da5990..2a5c9e217bc 100644 --- a/contrib/runners/http_runner/setup.py +++ b/contrib/runners/http_runner/setup.py @@ -26,30 +26,32 @@ from http_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-http', + name="stackstorm-runner-http", version=__version__, - description=('HTTP(s) action runner for StackStorm event-driven automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "HTTP(s) action runner for StackStorm event-driven automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'http_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"http_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'http-request = http_runner.http_runner', + "st2common.runners.runner": [ + "http-request = http_runner.http_runner", ], - } + }, ) diff --git a/contrib/runners/http_runner/tests/unit/test_http_runner.py b/contrib/runners/http_runner/tests/unit/test_http_runner.py index be64f6d4207..9d2d99a7c1c 100644 --- a/contrib/runners/http_runner/tests/unit/test_http_runner.py +++ b/contrib/runners/http_runner/tests/unit/test_http_runner.py @@ -28,16 +28,13 @@ import st2tests.config as tests_config -__all__ = [ - 'HTTPClientTestCase', - 'HTTPRunnerTestCase' -] +__all__ = ["HTTPClientTestCase", "HTTPRunnerTestCase"] if six.PY2: - EXPECTED_DATA = '' + EXPECTED_DATA = "" else: - EXPECTED_DATA = b'' + EXPECTED_DATA = b"" class MockResult(object): @@ -49,70 +46,70 @@ class HTTPClientTestCase(unittest2.TestCase): def setUpClass(cls): tests_config.parse_args() - @mock.patch('http_runner.http_runner.requests') + @mock.patch("http_runner.http_runner.requests") def test_parse_response_body(self, mock_requests): - client = HTTPClient(url='http://127.0.0.1') + client = HTTPClient(url="http://127.0.0.1") mock_result = MockResult() # Unknown content type, body should be returned raw - mock_result.text = 'foo bar ponies' - mock_result.headers = {'Content-Type': 'text/html'} + mock_result.text = "foo bar ponies" + mock_result.headers = {"Content-Type": "text/html"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result result = client.run() - self.assertEqual(result['body'], mock_result.text) - self.assertEqual(result['status_code'], mock_result.status_code) - self.assertEqual(result['headers'], mock_result.headers) + self.assertEqual(result["body"], mock_result.text) + self.assertEqual(result["status_code"], mock_result.status_code) + self.assertEqual(result["headers"], mock_result.headers) # Unknown content type, JSON body mock_result.text = '{"test1": "val1"}' - mock_result.headers = {'Content-Type': 'text/html'} + mock_result.headers = {"Content-Type": "text/html"} mock_requests.request.return_value = mock_result result = client.run() - self.assertEqual(result['body'], mock_result.text) + self.assertEqual(result["body"], mock_result.text) # JSON content-type and JSON body mock_result.text = '{"test1": "val1"}' - mock_result.headers = {'Content-Type': 'application/json'} + mock_result.headers = {"Content-Type": "application/json"} mock_requests.request.return_value = mock_result result = client.run() - self.assertIsInstance(result['body'], dict) - self.assertEqual(result['body'], {'test1': 'val1'}) + self.assertIsInstance(result["body"], dict) + self.assertEqual(result["body"], {"test1": "val1"}) # JSON content-type with charset and JSON body mock_result.text = '{"test1": "val1"}' - mock_result.headers = {'Content-Type': 'application/json; charset=UTF-8'} + mock_result.headers = {"Content-Type": "application/json; charset=UTF-8"} mock_requests.request.return_value = mock_result result = client.run() - self.assertIsInstance(result['body'], dict) - self.assertEqual(result['body'], {'test1': 'val1'}) + self.assertIsInstance(result["body"], dict) + self.assertEqual(result["body"], {"test1": "val1"}) # JSON content-type and invalid json body - mock_result.text = 'not json' - mock_result.headers = {'Content-Type': 'application/json'} + mock_result.text = "not json" + mock_result.headers = {"Content-Type": "application/json"} mock_requests.request.return_value = mock_result result = client.run() - self.assertNotIsInstance(result['body'], dict) - self.assertEqual(result['body'], mock_result.text) + self.assertNotIsInstance(result["body"], dict) + self.assertEqual(result["body"], mock_result.text) - @mock.patch('http_runner.http_runner.requests') + @mock.patch("http_runner.http_runner.requests") def test_https_verify(self, mock_requests): - url = 'https://127.0.0.1:8888' + url = "https://127.0.0.1:8888" client = HTTPClient(url=url, verify=True) mock_result = MockResult() - mock_result.text = 'foo bar ponies' - mock_result.headers = {'Content-Type': 'text/html'} + mock_result.text = "foo bar ponies" + mock_result.headers = {"Content-Type": "text/html"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result @@ -121,23 +118,33 @@ def test_https_verify(self, mock_requests): self.assertTrue(client.verify) if six.PY2: - data = '' + data = "" else: - data = b'' + data = b"" mock_requests.request.assert_called_with( - 'GET', url, allow_redirects=False, auth=None, cookies=None, - data=data, files=None, headers={}, params=None, proxies=None, - timeout=60, verify=True) - - @mock.patch('http_runner.http_runner.requests') + "GET", + url, + allow_redirects=False, + auth=None, + cookies=None, + data=data, + files=None, + headers={}, + params=None, + proxies=None, + timeout=60, + verify=True, + ) + + @mock.patch("http_runner.http_runner.requests") def test_https_verify_false(self, mock_requests): - url = 'https://127.0.0.1:8888' + url = "https://127.0.0.1:8888" client = HTTPClient(url=url) mock_result = MockResult() - mock_result.text = 'foo bar ponies' - mock_result.headers = {'Content-Type': 'text/html'} + mock_result.text = "foo bar ponies" + mock_result.headers = {"Content-Type": "text/html"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result @@ -146,182 +153,202 @@ def test_https_verify_false(self, mock_requests): self.assertFalse(client.verify) mock_requests.request.assert_called_with( - 'GET', url, allow_redirects=False, auth=None, cookies=None, - data=EXPECTED_DATA, files=None, headers={}, params=None, proxies=None, - timeout=60, verify=False) - - @mock.patch('http_runner.http_runner.requests') + "GET", + url, + allow_redirects=False, + auth=None, + cookies=None, + data=EXPECTED_DATA, + files=None, + headers={}, + params=None, + proxies=None, + timeout=60, + verify=False, + ) + + @mock.patch("http_runner.http_runner.requests") def test_https_auth_basic(self, mock_requests): - url = 'https://127.0.0.1:8888' - username = 'misspiggy' - password = 'kermit' + url = "https://127.0.0.1:8888" + username = "misspiggy" + password = "kermit" client = HTTPClient(url=url, username=username, password=password) mock_result = MockResult() - mock_result.text = 'muppet show' - mock_result.headers = {'Authorization': 'bWlzc3BpZ2d5Omtlcm1pdA=='} + mock_result.text = "muppet show" + mock_result.headers = {"Authorization": "bWlzc3BpZ2d5Omtlcm1pdA=="} mock_result.status_code = 200 mock_requests.request.return_value = mock_result result = client.run() - self.assertEqual(result['headers'], mock_result.headers) + self.assertEqual(result["headers"], mock_result.headers) mock_requests.request.assert_called_once_with( - 'GET', url, allow_redirects=False, auth=client.auth, cookies=None, - data=EXPECTED_DATA, files=None, headers={}, params=None, proxies=None, - timeout=60, verify=False) - - @mock.patch('http_runner.http_runner.requests') + "GET", + url, + allow_redirects=False, + auth=client.auth, + cookies=None, + data=EXPECTED_DATA, + files=None, + headers={}, + params=None, + proxies=None, + timeout=60, + verify=False, + ) + + @mock.patch("http_runner.http_runner.requests") def test_http_unicode_body_data(self, mock_requests): - url = 'http://127.0.0.1:8888' - method = 'POST' + url = "http://127.0.0.1:8888" + method = "POST" mock_result = MockResult() # 1. String data headers = {} - body = 'žžžžž' - client = HTTPClient(url=url, method=method, headers=headers, body=body, timeout=0.1) + body = "žžžžž" + client = HTTPClient( + url=url, method=method, headers=headers, body=body, timeout=0.1 + ) mock_result.text = '{"foo": "bar"}' - mock_result.headers = {'Content-Type': 'application/json'} + mock_result.headers = {"Content-Type": "application/json"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result result = client.run() - self.assertEqual(result['status_code'], 200) + self.assertEqual(result["status_code"], 200) call_kwargs = mock_requests.request.call_args_list[0][1] - expected_data = u'žžžžž'.encode('utf-8') - self.assertEqual(call_kwargs['data'], expected_data) + expected_data = "žžžžž".encode("utf-8") + self.assertEqual(call_kwargs["data"], expected_data) # 1. Object / JSON data - body = { - 'foo': u'ažž' - } - headers = { - 'Content-Type': 'application/json; charset=utf-8' - } - client = HTTPClient(url=url, method=method, headers=headers, body=body, timeout=0.1) + body = {"foo": "ažž"} + headers = {"Content-Type": "application/json; charset=utf-8"} + client = HTTPClient( + url=url, method=method, headers=headers, body=body, timeout=0.1 + ) mock_result.text = '{"foo": "bar"}' - mock_result.headers = {'Content-Type': 'application/json'} + mock_result.headers = {"Content-Type": "application/json"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result result = client.run() - self.assertEqual(result['status_code'], 200) + self.assertEqual(result["status_code"], 200) call_kwargs = mock_requests.request.call_args_list[1][1] if six.PY2: - expected_data = { - 'foo': u'a\u017e\u017e' - } + expected_data = {"foo": "a\u017e\u017e"} else: expected_data = body - self.assertEqual(call_kwargs['data'], expected_data) + self.assertEqual(call_kwargs["data"], expected_data) - @mock.patch('http_runner.http_runner.requests') + @mock.patch("http_runner.http_runner.requests") def test_blacklisted_url_url_hosts_blacklist_runner_parameter(self, mock_requests): # Black list is empty self.assertEqual(mock_requests.request.call_count, 0) - url = 'http://www.example.com' - client = HTTPClient(url=url, method='GET') + url = "http://www.example.com" + client = HTTPClient(url=url, method="GET") client.run() self.assertEqual(mock_requests.request.call_count, 1) # Blacklist is set url_hosts_blacklist = [ - 'example.com', - '127.0.0.1', - '::1', - '2001:0db8:85a3:0000:0000:8a2e:0370:7334' + "example.com", + "127.0.0.1", + "::1", + "2001:0db8:85a3:0000:0000:8a2e:0370:7334", ] # Blacklisted urls urls = [ - 'https://example.com', - 'http://example.com', - 'http://example.com:81', - 'http://example.com:80', - 'http://example.com:9000', - 'http://[::1]:80/', - 'http://[::1]', - 'http://[::1]:9000', - 'http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]', - 'https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000' + "https://example.com", + "http://example.com", + "http://example.com:81", + "http://example.com:80", + "http://example.com:9000", + "http://[::1]:80/", + "http://[::1]", + "http://[::1]:9000", + "http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]", + "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000", ] for url in urls: expected_msg = r'URL "%s" is blacklisted' % (re.escape(url)) - client = HTTPClient(url=url, method='GET', url_hosts_blacklist=url_hosts_blacklist) + client = HTTPClient( + url=url, method="GET", url_hosts_blacklist=url_hosts_blacklist + ) self.assertRaisesRegexp(ValueError, expected_msg, client.run) # Non blacklisted URLs - urls = [ - 'https://example2.com', - 'http://example3.com', - 'http://example4.com:81' - ] + urls = ["https://example2.com", "http://example3.com", "http://example4.com:81"] for url in urls: mock_requests.request.reset_mock() self.assertEqual(mock_requests.request.call_count, 0) - client = HTTPClient(url=url, method='GET', url_hosts_blacklist=url_hosts_blacklist) + client = HTTPClient( + url=url, method="GET", url_hosts_blacklist=url_hosts_blacklist + ) client.run() self.assertEqual(mock_requests.request.call_count, 1) - @mock.patch('http_runner.http_runner.requests') + @mock.patch("http_runner.http_runner.requests") def test_whitelisted_url_url_hosts_whitelist_runner_parameter(self, mock_requests): # Whitelist is empty self.assertEqual(mock_requests.request.call_count, 0) - url = 'http://www.example.com' - client = HTTPClient(url=url, method='GET') + url = "http://www.example.com" + client = HTTPClient(url=url, method="GET") client.run() self.assertEqual(mock_requests.request.call_count, 1) # Whitelist is set url_hosts_whitelist = [ - 'example.com', - '127.0.0.1', - '::1', - '2001:0db8:85a3:0000:0000:8a2e:0370:7334' + "example.com", + "127.0.0.1", + "::1", + "2001:0db8:85a3:0000:0000:8a2e:0370:7334", ] # Non whitelisted urls urls = [ - 'https://www.google.com', - 'https://www.example2.com', - 'http://127.0.0.2' + "https://www.google.com", + "https://www.example2.com", + "http://127.0.0.2", ] for url in urls: expected_msg = r'URL "%s" is not whitelisted' % (re.escape(url)) - client = HTTPClient(url=url, method='GET', url_hosts_whitelist=url_hosts_whitelist) + client = HTTPClient( + url=url, method="GET", url_hosts_whitelist=url_hosts_whitelist + ) self.assertRaisesRegexp(ValueError, expected_msg, client.run) # Whitelisted URLS urls = [ - 'https://example.com', - 'http://example.com', - 'http://example.com:81', - 'http://example.com:80', - 'http://example.com:9000', - 'http://[::1]:80/', - 'http://[::1]', - 'http://[::1]:9000', - 'http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]', - 'https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000' + "https://example.com", + "http://example.com", + "http://example.com:81", + "http://example.com:80", + "http://example.com:9000", + "http://[::1]:80/", + "http://[::1]", + "http://[::1]:9000", + "http://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]", + "https://[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8000", ] for url in urls: @@ -329,57 +356,71 @@ def test_whitelisted_url_url_hosts_whitelist_runner_parameter(self, mock_request self.assertEqual(mock_requests.request.call_count, 0) - client = HTTPClient(url=url, method='GET', url_hosts_whitelist=url_hosts_whitelist) + client = HTTPClient( + url=url, method="GET", url_hosts_whitelist=url_hosts_whitelist + ) client.run() self.assertEqual(mock_requests.request.call_count, 1) - def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive(self): - url = 'http://www.example.com' - - expected_msg = (r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' - 'exclusive.') - self.assertRaisesRegexp(ValueError, expected_msg, HTTPClient, url=url, method='GET', - url_hosts_blacklist=[url], url_hosts_whitelist=[url]) + def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive( + self, + ): + url = "http://www.example.com" + + expected_msg = ( + r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' + "exclusive." + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + HTTPClient, + url=url, + method="GET", + url_hosts_blacklist=[url], + url_hosts_whitelist=[url], + ) class HTTPRunnerTestCase(unittest2.TestCase): - @mock.patch('http_runner.http_runner.requests') + @mock.patch("http_runner.http_runner.requests") def test_get_success(self, mock_requests): mock_result = MockResult() # Unknown content type, body should be returned raw - mock_result.text = 'foo bar ponies' - mock_result.headers = {'Content-Type': 'text/html'} + mock_result.text = "foo bar ponies" + mock_result.headers = {"Content-Type": "text/html"} mock_result.status_code = 200 mock_requests.request.return_value = mock_result - runner_parameters = { - 'url': 'http://www.example.com', - 'method': 'GET' - } - runner = HttpRunner('id') + runner_parameters = {"url": "http://www.example.com", "method": "GET"} + runner = HttpRunner("id") runner.runner_parameters = runner_parameters runner.pre_run() status, result, _ = runner.run({}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['body'], 'foo bar ponies') - self.assertEqual(result['status_code'], 200) - self.assertEqual(result['parsed'], False) + self.assertEqual(result["body"], "foo bar ponies") + self.assertEqual(result["status_code"], 200) + self.assertEqual(result["parsed"], False) - def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive(self): + def test_url_host_blacklist_and_url_host_blacklist_params_are_mutually_exclusive( + self, + ): runner_parameters = { - 'url': 'http://www.example.com', - 'method': 'GET', - 'url_hosts_blacklist': ['http://127.0.0.1'], - 'url_hosts_whitelist': ['http://127.0.0.1'], + "url": "http://www.example.com", + "method": "GET", + "url_hosts_blacklist": ["http://127.0.0.1"], + "url_hosts_whitelist": ["http://127.0.0.1"], } - runner = HttpRunner('id') + runner = HttpRunner("id") runner.runner_parameters = runner_parameters runner.pre_run() - expected_msg = (r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' - 'exclusive.') + expected_msg = ( + r'"url_hosts_blacklist" and "url_hosts_whitelist" parameters are mutually ' + "exclusive." + ) self.assertRaisesRegexp(ValueError, expected_msg, runner.run, {}) diff --git a/contrib/runners/inquirer_runner/dist_utils.py b/contrib/runners/inquirer_runner/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/contrib/runners/inquirer_runner/dist_utils.py +++ b/contrib/runners/inquirer_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/inquirer_runner/inquirer_runner/__init__.py b/contrib/runners/inquirer_runner/inquirer_runner/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/__init__.py +++ b/contrib/runners/inquirer_runner/inquirer_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py b/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py index 6a5757bc44a..af0f0c6f342 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py +++ b/contrib/runners/inquirer_runner/inquirer_runner/inquirer_runner.py @@ -29,20 +29,16 @@ from st2common.util import action_db as action_utils -__all__ = [ - 'Inquirer', - 'get_runner', - 'get_metadata' -] +__all__ = ["Inquirer", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) # constants to lookup in runner_parameters. -RUNNER_SCHEMA = 'schema' -RUNNER_ROLES = 'roles' -RUNNER_USERS = 'users' -RUNNER_ROUTE = 'route' -RUNNER_TTL = 'ttl' +RUNNER_SCHEMA = "schema" +RUNNER_ROLES = "roles" +RUNNER_USERS = "users" +RUNNER_ROUTE = "route" +RUNNER_TTL = "ttl" DEFAULT_SCHEMA = { "title": "response_data", @@ -51,15 +47,14 @@ "continue": { "type": "boolean", "description": "Would you like to continue the workflow?", - "required": True + "required": True, } - } + }, } class Inquirer(runners.ActionRunner): - """This runner implements the ability to ask for more input during a workflow - """ + """This runner implements the ability to ask for more input during a workflow""" def __init__(self, runner_id): super(Inquirer, self).__init__(runner_id=runner_id) @@ -83,14 +78,11 @@ def run(self, action_parameters): # Assemble and dispatch trigger trigger_ref = sys_db_models.ResourceReference.to_string_reference( - pack=trigger_constants.INQUIRY_TRIGGER['pack'], - name=trigger_constants.INQUIRY_TRIGGER['name'] + pack=trigger_constants.INQUIRY_TRIGGER["pack"], + name=trigger_constants.INQUIRY_TRIGGER["name"], ) - trigger_payload = { - "id": str(exc.id), - "route": self.route - } + trigger_payload = {"id": str(exc.id), "route": self.route} self.trigger_dispatcher.dispatch(trigger_ref, trigger_payload) @@ -99,7 +91,7 @@ def run(self, action_parameters): "roles": self.roles_param, "users": self.users_param, "route": self.route, - "ttl": self.ttl + "ttl": self.ttl, } return (action_constants.LIVEACTION_STATUS_PENDING, result, None) @@ -110,9 +102,10 @@ def post_run(self, status, result): # is made in the run method, but because the liveaction hasn't update to pending status # yet, there is a race condition where the pause request is mishandled. if status == action_constants.LIVEACTION_STATUS_PENDING: - pause_parent = ( - self.liveaction.context.get("parent") and - not workflow_service.is_action_execution_under_workflow_context(self.liveaction) + pause_parent = self.liveaction.context.get( + "parent" + ) and not workflow_service.is_action_execution_under_workflow_context( + self.liveaction ) # For action execution under Action Chain workflows, request the entire @@ -122,7 +115,9 @@ def post_run(self, status, result): # to pause the workflow. if pause_parent: root_liveaction = action_service.get_root_liveaction(self.liveaction) - action_service.request_pause(root_liveaction, self.context.get('user', None)) + action_service.request_pause( + root_liveaction, self.context.get("user", None) + ) # Invoke post run of parent for common post run related work. super(Inquirer, self).post_run(status, result) @@ -133,4 +128,4 @@ def get_runner(): def get_metadata(): - return runners.get_metadata('inquirer_runner')[0] + return runners.get_metadata("inquirer_runner")[0] diff --git a/contrib/runners/inquirer_runner/setup.py b/contrib/runners/inquirer_runner/setup.py index 9be54704f93..44d4a4d7f7c 100644 --- a/contrib/runners/inquirer_runner/setup.py +++ b/contrib/runners/inquirer_runner/setup.py @@ -26,30 +26,32 @@ from inquirer_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-inquirer', + name="stackstorm-runner-inquirer", version=__version__, - description=('Inquirer action runner for StackStorm event-driven automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "Inquirer action runner for StackStorm event-driven automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'inquirer_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"inquirer_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'inquirer = inquirer_runner.inquirer_runner', + "st2common.runners.runner": [ + "inquirer = inquirer_runner.inquirer_runner", ], - } + }, ) diff --git a/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py b/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py index da9c70b78a8..caa47bc53a3 100644 --- a/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py +++ b/contrib/runners/inquirer_runner/tests/unit/test_inquirer_runner.py @@ -28,7 +28,7 @@ mock_exc_get = mock.Mock() -mock_exc_get.id = 'abcdef' +mock_exc_get.id = "abcdef" mock_inquiry_liveaction_db = mock.Mock() mock_inquiry_liveaction_db.result = {"response": {}} @@ -37,7 +37,7 @@ mock_action_utils.return_value = mock_inquiry_liveaction_db test_parent = mock.Mock() -test_parent.id = '1234567890' +test_parent.id = "1234567890" mock_get_root = mock.Mock() mock_get_root.return_value = test_parent @@ -45,38 +45,19 @@ mock_trigger_dispatcher = mock.Mock() mock_request_pause = mock.Mock() -test_user = 'st2admin' +test_user = "st2admin" -runner_params = { - "users": [], - "roles": [], - "route": "developers", - "schema": {} -} +runner_params = {"users": [], "roles": [], "route": "developers", "schema": {}} +@mock.patch.object(reactor_transport, "TriggerDispatcher", mock_trigger_dispatcher) +@mock.patch.object(action_utils, "get_liveaction_by_id", mock_action_utils) +@mock.patch.object(action_service, "request_pause", mock_request_pause) +@mock.patch.object(action_service, "get_root_liveaction", mock_get_root) @mock.patch.object( - reactor_transport, - 'TriggerDispatcher', - mock_trigger_dispatcher) -@mock.patch.object( - action_utils, - 'get_liveaction_by_id', - mock_action_utils) -@mock.patch.object( - action_service, - 'request_pause', - mock_request_pause) -@mock.patch.object( - action_service, - 'get_root_liveaction', - mock_get_root) -@mock.patch.object( - ex_db_access.ActionExecution, - 'get', - mock.MagicMock(return_value=mock_exc_get)) + ex_db_access.ActionExecution, "get", mock.MagicMock(return_value=mock_exc_get) +) class InquiryTestCase(st2tests.RunnerTestCase): - def tearDown(self): mock_trigger_dispatcher.reset_mock() mock_action_utils.reset_mock() @@ -85,17 +66,19 @@ def tearDown(self): def test_runner_creation(self): runner = inquirer_runner.get_runner() - self.assertIsNotNone(runner, 'Creation failed. No instance.') - self.assertEqual(type(runner), inquirer_runner.Inquirer, 'Creation failed. No instance.') + self.assertIsNotNone(runner, "Creation failed. No instance.") + self.assertEqual( + type(runner), inquirer_runner.Inquirer, "Creation failed. No instance." + ) def test_simple_inquiry(self): runner = inquirer_runner.get_runner() - runner.context = {'user': test_user} + runner.context = {"user": test_user} runner.action = self._get_mock_action_obj() runner.runner_parameters = runner_params runner.pre_run() - mock_inquiry_liveaction_db.context = {'parent': test_parent.id} + mock_inquiry_liveaction_db.context = {"parent": test_parent.id} runner.liveaction = mock_inquiry_liveaction_db (status, output, _) = runner.run({}) @@ -104,20 +87,16 @@ def test_simple_inquiry(self): self.assertEqual( output, { - 'users': [], - 'roles': [], - 'route': "developers", - 'schema': {}, - 'ttl': 1440 - } + "users": [], + "roles": [], + "route": "developers", + "schema": {}, + "ttl": 1440, + }, ) mock_trigger_dispatcher.return_value.dispatch.assert_called_once_with( - 'core.st2.generic.inquiry', - { - 'id': mock_exc_get.id, - 'route': "developers" - } + "core.st2.generic.inquiry", {"id": mock_exc_get.id, "route": "developers"} ) runner.post_run(action_constants.LIVEACTION_STATUS_PENDING, {}) @@ -125,37 +104,28 @@ def test_simple_inquiry(self): mock_request_pause.assert_called_once_with(test_parent, test_user) def test_inquiry_no_parent(self): - """Should behave like a regular execution, but without requesting a pause - """ + """Should behave like a regular execution, but without requesting a pause""" runner = inquirer_runner.get_runner() - runner.context = { - 'user': 'st2admin' - } + runner.context = {"user": "st2admin"} runner.action = self._get_mock_action_obj() runner.runner_parameters = runner_params runner.pre_run() - mock_inquiry_liveaction_db.context = { - "parent": None - } + mock_inquiry_liveaction_db.context = {"parent": None} (status, output, _) = runner.run({}) self.assertEqual(status, action_constants.LIVEACTION_STATUS_PENDING) self.assertEqual( output, { - 'users': [], - 'roles': [], - 'route': "developers", - 'schema': {}, - 'ttl': 1440 - } + "users": [], + "roles": [], + "route": "developers", + "schema": {}, + "ttl": 1440, + }, ) mock_trigger_dispatcher.return_value.dispatch.assert_called_once_with( - 'core.st2.generic.inquiry', - { - 'id': mock_exc_get.id, - 'route': "developers" - } + "core.st2.generic.inquiry", {"id": mock_exc_get.id, "route": "developers"} ) mock_request_pause.assert_not_called() diff --git a/contrib/runners/local_runner/dist_utils.py b/contrib/runners/local_runner/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/contrib/runners/local_runner/dist_utils.py +++ b/contrib/runners/local_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/local_runner/local_runner/__init__.py b/contrib/runners/local_runner/local_runner/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/contrib/runners/local_runner/local_runner/__init__.py +++ b/contrib/runners/local_runner/local_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/local_runner/local_runner/base.py b/contrib/runners/local_runner/local_runner/base.py index 5bcf20137f4..4fda9c18669 100644 --- a/contrib/runners/local_runner/local_runner/base.py +++ b/contrib/runners/local_runner/local_runner/base.py @@ -39,32 +39,36 @@ from st2common.services.action import store_execution_output_data from st2common.runners.utils import make_read_and_store_stream_func -__all__ = [ - 'BaseLocalShellRunner', - - 'RUNNER_COMMAND' -] +__all__ = ["BaseLocalShellRunner", "RUNNER_COMMAND"] LOG = logging.getLogger(__name__) -DEFAULT_KWARG_OP = '--' +DEFAULT_KWARG_OP = "--" LOGGED_USER_USERNAME = pwd.getpwuid(os.getuid())[0] # constants to lookup in runner_parameters. -RUNNER_SUDO = 'sudo' -RUNNER_SUDO_PASSWORD = 'sudo_password' -RUNNER_ON_BEHALF_USER = 'user' -RUNNER_COMMAND = 'cmd' -RUNNER_CWD = 'cwd' -RUNNER_ENV = 'env' -RUNNER_KWARG_OP = 'kwarg_op' -RUNNER_TIMEOUT = 'timeout' +RUNNER_SUDO = "sudo" +RUNNER_SUDO_PASSWORD = "sudo_password" +RUNNER_ON_BEHALF_USER = "user" +RUNNER_COMMAND = "cmd" +RUNNER_CWD = "cwd" +RUNNER_ENV = "env" +RUNNER_KWARG_OP = "kwarg_op" +RUNNER_TIMEOUT = "timeout" PROC_EXIT_CODE_TO_LIVEACTION_STATUS_MAP = { - str(exit_code_constants.SUCCESS_EXIT_CODE): action_constants.LIVEACTION_STATUS_SUCCEEDED, - str(exit_code_constants.FAILURE_EXIT_CODE): action_constants.LIVEACTION_STATUS_FAILED, - str(-1 * exit_code_constants.SIGKILL_EXIT_CODE): action_constants.LIVEACTION_STATUS_TIMED_OUT, - str(-1 * exit_code_constants.SIGTERM_EXIT_CODE): action_constants.LIVEACTION_STATUS_ABANDONED + str( + exit_code_constants.SUCCESS_EXIT_CODE + ): action_constants.LIVEACTION_STATUS_SUCCEEDED, + str( + exit_code_constants.FAILURE_EXIT_CODE + ): action_constants.LIVEACTION_STATUS_FAILED, + str( + -1 * exit_code_constants.SIGKILL_EXIT_CODE + ): action_constants.LIVEACTION_STATUS_TIMED_OUT, + str( + -1 * exit_code_constants.SIGTERM_EXIT_CODE + ): action_constants.LIVEACTION_STATUS_ABANDONED, } @@ -77,7 +81,8 @@ class BaseLocalShellRunner(ActionRunner, ShellRunnerMixin): Note: The user under which the action runner service is running (stanley user by default) needs to have pasworless sudo access set up. """ - KEYS_TO_TRANSFORM = ['stdout', 'stderr'] + + KEYS_TO_TRANSFORM = ["stdout", "stderr"] def __init__(self, runner_id): super(BaseLocalShellRunner, self).__init__(runner_id=runner_id) @@ -87,14 +92,17 @@ def pre_run(self): self._sudo = self.runner_parameters.get(RUNNER_SUDO, False) self._sudo_password = self.runner_parameters.get(RUNNER_SUDO_PASSWORD, None) - self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, LOGGED_USER_USERNAME) + self._on_behalf_user = self.context.get( + RUNNER_ON_BEHALF_USER, LOGGED_USER_USERNAME + ) self._user = cfg.CONF.system_user.user self._cwd = self.runner_parameters.get(RUNNER_CWD, None) self._env = self.runner_parameters.get(RUNNER_ENV, {}) self._env = self._env or {} self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, DEFAULT_KWARG_OP) self._timeout = self.runner_parameters.get( - RUNNER_TIMEOUT, runner_constants.LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT) + RUNNER_TIMEOUT, runner_constants.LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT + ) def _run(self, action): env_vars = self._env @@ -110,8 +118,11 @@ def _run(self, action): # For consistency with the old Fabric based runner, make sure the file is executable if script_action: script_local_path_abs = self.entry_point - args = 'chmod +x %s ; %s' % (script_local_path_abs, args) - sanitized_args = 'chmod +x %s ; %s' % (script_local_path_abs, sanitized_args) + args = "chmod +x %s ; %s" % (script_local_path_abs, args) + sanitized_args = "chmod +x %s ; %s" % ( + script_local_path_abs, + sanitized_args, + ) env = os.environ.copy() @@ -122,22 +133,38 @@ def _run(self, action): st2_env_vars = self._get_common_action_env_variables() env.update(st2_env_vars) - LOG.info('Executing action via LocalRunner: %s', self.runner_id) - LOG.info('[Action info] name: %s, Id: %s, command: %s, user: %s, sudo: %s' % - (action.name, action.action_exec_id, sanitized_args, action.user, action.sudo)) + LOG.info("Executing action via LocalRunner: %s", self.runner_id) + LOG.info( + "[Action info] name: %s, Id: %s, command: %s, user: %s, sudo: %s" + % ( + action.name, + action.action_exec_id, + sanitized_args, + action.user, + action.sudo, + ) + ) stdout = StringIO() stderr = StringIO() - store_execution_stdout_line = functools.partial(store_execution_output_data, - output_type='stdout') - store_execution_stderr_line = functools.partial(store_execution_output_data, - output_type='stderr') + store_execution_stdout_line = functools.partial( + store_execution_output_data, output_type="stdout" + ) + store_execution_stderr_line = functools.partial( + store_execution_output_data, output_type="stderr" + ) - read_and_store_stdout = make_read_and_store_stream_func(execution_db=self.execution, - action_db=self.action, store_data_func=store_execution_stdout_line) - read_and_store_stderr = make_read_and_store_stream_func(execution_db=self.execution, - action_db=self.action, store_data_func=store_execution_stderr_line) + read_and_store_stdout = make_read_and_store_stream_func( + execution_db=self.execution, + action_db=self.action, + store_data_func=store_execution_stdout_line, + ) + read_and_store_stderr = make_read_and_store_stream_func( + execution_db=self.execution, + action_db=self.action, + store_data_func=store_execution_stderr_line, + ) subprocess = concurrency.get_subprocess_module() @@ -145,9 +172,10 @@ def _run(self, action): # Note: We don't need to explicitly escape the argument because we pass command as a list # to subprocess.Popen and all the arguments are escaped by the function. if self._sudo_password: - LOG.debug('Supplying sudo password via stdin') - echo_process = concurrency.subprocess_popen(['echo', self._sudo_password + '\n'], - stdout=subprocess.PIPE) + LOG.debug("Supplying sudo password via stdin") + echo_process = concurrency.subprocess_popen( + ["echo", self._sudo_password + "\n"], stdout=subprocess.PIPE + ) stdin = echo_process.stdout else: stdin = None @@ -161,57 +189,64 @@ def _run(self, action): # Ideally os.killpg should have done the trick but for some reason that failed. # Note: pkill will set the returncode to 143 so we don't need to explicitly set # it to some non-zero value. - exit_code, stdout, stderr, timed_out = shell.run_command(cmd=args, - stdin=stdin, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True, - cwd=self._cwd, - env=env, - timeout=self._timeout, - preexec_func=os.setsid, - kill_func=kill_process, - read_stdout_func=read_and_store_stdout, - read_stderr_func=read_and_store_stderr, - read_stdout_buffer=stdout, - read_stderr_buffer=stderr) + exit_code, stdout, stderr, timed_out = shell.run_command( + cmd=args, + stdin=stdin, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + cwd=self._cwd, + env=env, + timeout=self._timeout, + preexec_func=os.setsid, + kill_func=kill_process, + read_stdout_func=read_and_store_stdout, + read_stderr_func=read_and_store_stderr, + read_stdout_buffer=stdout, + read_stderr_buffer=stderr, + ) error = None if timed_out: - error = 'Action failed to complete in %s seconds' % (self._timeout) + error = "Action failed to complete in %s seconds" % (self._timeout) exit_code = -1 * exit_code_constants.SIGKILL_EXIT_CODE # Detect if user provided an invalid sudo password or sudo is not configured for that user if self._sudo_password: - if re.search(r'sudo: \d+ incorrect password attempts', stderr): - match = re.search(r'\[sudo\] password for (.+?)\:', stderr) + if re.search(r"sudo: \d+ incorrect password attempts", stderr): + match = re.search(r"\[sudo\] password for (.+?)\:", stderr) if match: username = match.groups()[0] else: - username = 'unknown' + username = "unknown" - error = ('Invalid sudo password provided or sudo is not configured for this user ' - '(%s)' % (username)) + error = ( + "Invalid sudo password provided or sudo is not configured for this user " + "(%s)" % (username) + ) exit_code = -1 - succeeded = (exit_code == exit_code_constants.SUCCESS_EXIT_CODE) + succeeded = exit_code == exit_code_constants.SUCCESS_EXIT_CODE result = { - 'failed': not succeeded, - 'succeeded': succeeded, - 'return_code': exit_code, - 'stdout': strip_shell_chars(stdout), - 'stderr': strip_shell_chars(stderr) + "failed": not succeeded, + "succeeded": succeeded, + "return_code": exit_code, + "stdout": strip_shell_chars(stdout), + "stderr": strip_shell_chars(stderr), } if error: - result['error'] = error + result["error"] = error status = PROC_EXIT_CODE_TO_LIVEACTION_STATUS_MAP.get( - str(exit_code), - action_constants.LIVEACTION_STATUS_FAILED + str(exit_code), action_constants.LIVEACTION_STATUS_FAILED ) - return (status, jsonify.json_loads(result, BaseLocalShellRunner.KEYS_TO_TRANSFORM), None) + return ( + status, + jsonify.json_loads(result, BaseLocalShellRunner.KEYS_TO_TRANSFORM), + None, + ) diff --git a/contrib/runners/local_runner/local_runner/local_shell_command_runner.py b/contrib/runners/local_runner/local_runner/local_shell_command_runner.py index 4ae61f32259..cbf603de27a 100644 --- a/contrib/runners/local_runner/local_runner/local_shell_command_runner.py +++ b/contrib/runners/local_runner/local_runner/local_shell_command_runner.py @@ -23,28 +23,25 @@ from local_runner.base import BaseLocalShellRunner from local_runner.base import RUNNER_COMMAND -__all__ = [ - 'LocalShellCommandRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["LocalShellCommandRunner", "get_runner", "get_metadata"] class LocalShellCommandRunner(BaseLocalShellRunner): def run(self, action_parameters): if self.entry_point: - raise ValueError('entry_point is only valid for local-shell-script runner') + raise ValueError("entry_point is only valid for local-shell-script runner") command = self.runner_parameters.get(RUNNER_COMMAND, None) - action = ShellCommandAction(name=self.action_name, - action_exec_id=str(self.liveaction_id), - command=command, - user=self._user, - env_vars=self._env, - sudo=self._sudo, - timeout=self._timeout, - sudo_password=self._sudo_password) + action = ShellCommandAction( + name=self.action_name, + action_exec_id=str(self.liveaction_id), + command=command, + user=self._user, + env_vars=self._env, + sudo=self._sudo, + timeout=self._timeout, + sudo_password=self._sudo_password, + ) return self._run(action=action) @@ -54,7 +51,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('local_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("local_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/local_runner/local_runner/local_shell_script_runner.py b/contrib/runners/local_runner/local_runner/local_shell_script_runner.py index 24a0fe6ddbe..257e457ca11 100644 --- a/contrib/runners/local_runner/local_runner/local_shell_script_runner.py +++ b/contrib/runners/local_runner/local_runner/local_shell_script_runner.py @@ -23,34 +23,31 @@ from local_runner.base import BaseLocalShellRunner -__all__ = [ - 'LocalShellScriptRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["LocalShellScriptRunner", "get_runner", "get_metadata"] class LocalShellScriptRunner(BaseLocalShellRunner, GitWorktreeActionRunner): def run(self, action_parameters): if not self.entry_point: - raise ValueError('Missing entry_point action metadata attribute') + raise ValueError("Missing entry_point action metadata attribute") script_local_path_abs = self.entry_point positional_args, named_args = self._get_script_args(action_parameters) named_args = self._transform_named_args(named_args) - action = ShellScriptAction(name=self.action_name, - action_exec_id=str(self.liveaction_id), - script_local_path_abs=script_local_path_abs, - named_args=named_args, - positional_args=positional_args, - user=self._user, - env_vars=self._env, - sudo=self._sudo, - timeout=self._timeout, - cwd=self._cwd, - sudo_password=self._sudo_password) + action = ShellScriptAction( + name=self.action_name, + action_exec_id=str(self.liveaction_id), + script_local_path_abs=script_local_path_abs, + named_args=named_args, + positional_args=positional_args, + user=self._user, + env_vars=self._env, + sudo=self._sudo, + timeout=self._timeout, + cwd=self._cwd, + sudo_password=self._sudo_password, + ) return self._run(action=action) @@ -60,7 +57,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('local_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("local_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/local_runner/setup.py b/contrib/runners/local_runner/setup.py index feb1cb65541..063314ab742 100644 --- a/contrib/runners/local_runner/setup.py +++ b/contrib/runners/local_runner/setup.py @@ -26,32 +26,34 @@ from local_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-local', + name="stackstorm-runner-local", version=__version__, - description=('Local Shell Command and Script action runner for StackStorm event-driven ' - 'automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "Local Shell Command and Script action runner for StackStorm event-driven " + "automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'local_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"local_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'local-shell-cmd = local_runner.local_shell_command_runner', - 'local-shell-script = local_runner.local_shell_script_runner', + "st2common.runners.runner": [ + "local-shell-cmd = local_runner.local_shell_command_runner", + "local-shell-script = local_runner.local_shell_script_runner", ], - } + }, ) diff --git a/contrib/runners/local_runner/tests/integration/test_localrunner.py b/contrib/runners/local_runner/tests/integration/test_localrunner.py index 0e5a2f3efcf..05c241f46bc 100644 --- a/contrib/runners/local_runner/tests/integration/test_localrunner.py +++ b/contrib/runners/local_runner/tests/integration/test_localrunner.py @@ -22,6 +22,7 @@ import st2tests.config as tests_config from six.moves import range + tests_config.parse_args() from st2common.constants import action as action_constants @@ -40,13 +41,10 @@ from local_runner.local_shell_command_runner import LocalShellCommandRunner from local_runner.local_shell_script_runner import LocalShellScriptRunner -__all__ = [ - 'LocalShellCommandRunnerTestCase', - 'LocalShellScriptRunnerTestCase' -] +__all__ = ["LocalShellCommandRunnerTestCase", "LocalShellScriptRunnerTestCase"] MOCK_EXECUTION = mock.Mock() -MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b' +MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b" class LocalShellCommandRunnerTestCase(RunnerTestCase, CleanDbTestCase): @@ -56,108 +54,115 @@ def setUp(self): super(LocalShellCommandRunnerTestCase, self).setUp() # False is a default behavior so end result should be the same - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=False) + cfg.CONF.set_override( + name="stream_output", group="actionrunner", override=False + ) def test_shell_command_action_basic(self): models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] - runner = self._get_runner(action_db, cmd='echo 10') + runner = self._get_runner(action_db, cmd="echo 10") runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], 10) + self.assertEqual(result["stdout"], 10) # End result should be the same when streaming is enabled - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True) + cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True) # Verify initial state output_dbs = ActionExecutionOutput.get_all() self.assertEqual(len(output_dbs), 0) - runner = self._get_runner(action_db, cmd='echo 10') + runner = self._get_runner(action_db, cmd="echo 10") runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], 10) + self.assertEqual(result["stdout"], 10) output_dbs = ActionExecutionOutput.get_all() self.assertEqual(len(output_dbs), 1) - self.assertEqual(output_dbs[0].output_type, 'stdout') - self.assertEqual(output_dbs[0].data, '10\n') + self.assertEqual(output_dbs[0].output_type, "stdout") + self.assertEqual(output_dbs[0].data, "10\n") def test_timeout(self): models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] # smaller timeout == faster tests. - runner = self._get_runner(action_db, cmd='sleep 10', timeout=0.01) + runner = self._get_runner(action_db, cmd="sleep 10", timeout=0.01) runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_TIMED_OUT) @mock.patch.object( - shell, 'run_command', - mock.MagicMock(return_value=(-15, '', '', False))) + shell, "run_command", mock.MagicMock(return_value=(-15, "", "", False)) + ) def test_shutdown(self): models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] - runner = self._get_runner(action_db, cmd='sleep 0.1') + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] + runner = self._get_runner(action_db, cmd="sleep 0.1") runner.pre_run() status, result, _ = runner.run({}) self.assertEqual(status, action_constants.LIVEACTION_STATUS_ABANDONED) def test_common_st2_env_vars_are_available_to_the_action(self): models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] - runner = self._get_runner(action_db, cmd='echo $ST2_ACTION_API_URL') + runner = self._get_runner(action_db, cmd="echo $ST2_ACTION_API_URL") runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'].strip(), get_full_public_api_url()) + self.assertEqual(result["stdout"].strip(), get_full_public_api_url()) - runner = self._get_runner(action_db, cmd='echo $ST2_ACTION_AUTH_TOKEN') + runner = self._get_runner(action_db, cmd="echo $ST2_ACTION_AUTH_TOKEN") runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'].strip(), 'mock-token') + self.assertEqual(result["stdout"].strip(), "mock-token") def test_sudo_and_env_variable_preservation(self): # Verify that the environment environment are correctly preserved when running as a # root / non-system user # Note: This test will fail if SETENV option is not present in the sudoers file models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] - cmd = 'echo `whoami` ; echo ${VAR1}' - env = {'VAR1': 'poniesponies'} + cmd = "echo `whoami` ; echo ${VAR1}" + env = {"VAR1": "poniesponies"} runner = self._get_runner(action_db, cmd=cmd, sudo=True, env=env) runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'].strip(), 'root\nponiesponies') + self.assertEqual(result["stdout"].strip(), "root\nponiesponies") - @mock.patch('st2common.util.concurrency.subprocess_popen') - @mock.patch('st2common.util.concurrency.spawn') + @mock.patch("st2common.util.concurrency.subprocess_popen") + @mock.patch("st2common.util.concurrency.spawn") def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_popen): # Feature is enabled - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True) + cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True) # Note: We need to mock spawn function so we can test everything in single event loop # iteration @@ -165,78 +170,75 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop # No output to stdout and no result (implicit None) mock_stdout = [ - 'stdout line 1\n', - 'stdout line 2\n', - ] - mock_stderr = [ - 'stderr line 1\n', - 'stderr line 2\n', - 'stderr line 3\n' + "stdout line 1\n", + "stdout line 2\n", ] + mock_stderr = ["stderr line 1\n", "stderr line 2\n", "stderr line 3\n"] mock_process = mock.Mock() mock_process.returncode = 0 mock_popen.return_value = mock_process mock_process.stdout.closed = False mock_process.stderr.closed = False - mock_process.stdout.readline = make_mock_stream_readline(mock_process.stdout, mock_stdout, - stop_counter=2) - mock_process.stderr.readline = make_mock_stream_readline(mock_process.stderr, mock_stderr, - stop_counter=3) + mock_process.stdout.readline = make_mock_stream_readline( + mock_process.stdout, mock_stdout, stop_counter=2 + ) + mock_process.stderr.readline = make_mock_stream_readline( + mock_process.stderr, mock_stderr, stop_counter=3 + ) models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] - runner = self._get_runner(action_db, cmd='echo $ST2_ACTION_API_URL') + runner = self._get_runner(action_db, cmd="echo $ST2_ACTION_API_URL") runner.pre_run() status, result, _ = runner.run({}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], 'stdout line 1\nstdout line 2') - self.assertEqual(result['stderr'], 'stderr line 1\nstderr line 2\nstderr line 3') - self.assertEqual(result['return_code'], 0) + self.assertEqual(result["stdout"], "stdout line 1\nstdout line 2") + self.assertEqual( + result["stderr"], "stderr line 1\nstderr line 2\nstderr line 3" + ) + self.assertEqual(result["return_code"], 0) # Verify stdout and stderr lines have been correctly stored in the db - output_dbs = ActionExecutionOutput.query(output_type='stdout') + output_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(output_dbs), 2) self.assertEqual(output_dbs[0].data, mock_stdout[0]) self.assertEqual(output_dbs[1].data, mock_stdout[1]) - output_dbs = ActionExecutionOutput.query(output_type='stderr') + output_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(output_dbs), 3) self.assertEqual(output_dbs[0].data, mock_stderr[0]) self.assertEqual(output_dbs[1].data, mock_stderr[1]) self.assertEqual(output_dbs[2].data, mock_stderr[2]) - @mock.patch('st2common.util.concurrency.subprocess_popen') - @mock.patch('st2common.util.concurrency.spawn') - def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self, mock_spawn, - mock_popen): + @mock.patch("st2common.util.concurrency.subprocess_popen") + @mock.patch("st2common.util.concurrency.spawn") + def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action( + self, mock_spawn, mock_popen + ): # Verify that we correctly retrieve all the output and wait for stdout and stderr reading # threads for short running actions. models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] # Feature is enabled - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True) + cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True) # Note: We need to mock spawn function so we can test everything in single event loop # iteration mock_spawn.side_effect = blocking_eventlet_spawn # No output to stdout and no result (implicit None) - mock_stdout = [ - 'stdout line 1\n', - 'stdout line 2\n' - ] - mock_stderr = [ - 'stderr line 1\n', - 'stderr line 2\n' - ] + mock_stdout = ["stdout line 1\n", "stdout line 2\n"] + mock_stderr = ["stderr line 1\n", "stderr line 2\n"] # We add a sleep to simulate action process exiting before we finish reading data from mock_process = mock.Mock() @@ -244,11 +246,12 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self, mock_popen.return_value = mock_process mock_process.stdout.closed = False mock_process.stderr.closed = False - mock_process.stdout.readline = make_mock_stream_readline(mock_process.stdout, mock_stdout, - stop_counter=2, - sleep_delay=1) - mock_process.stderr.readline = make_mock_stream_readline(mock_process.stderr, mock_stderr, - stop_counter=2) + mock_process.stdout.readline = make_mock_stream_readline( + mock_process.stdout, mock_stdout, stop_counter=2, sleep_delay=1 + ) + mock_process.stderr.readline = make_mock_stream_readline( + mock_process.stderr, mock_stderr, stop_counter=2 + ) for index in range(1, 4): mock_process.stdout.closed = False @@ -263,12 +266,12 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self, self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], 'stdout line 1\nstdout line 2') - self.assertEqual(result['stderr'], 'stderr line 1\nstderr line 2') - self.assertEqual(result['return_code'], 0) + self.assertEqual(result["stdout"], "stdout line 1\nstdout line 2") + self.assertEqual(result["stderr"], "stderr line 1\nstderr line 2") + self.assertEqual(result["return_code"], 0) # Verify stdout and stderr lines have been correctly stored in the db - output_dbs = ActionExecutionOutput.query(output_type='stdout') + output_dbs = ActionExecutionOutput.query(output_type="stdout") if index == 1: db_index_1 = 0 @@ -287,7 +290,7 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self, self.assertEqual(output_dbs[db_index_1].data, mock_stdout[0]) self.assertEqual(output_dbs[db_index_2].data, mock_stdout[1]) - output_dbs = ActionExecutionOutput.query(output_type='stderr') + output_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(output_dbs), (index * 2)) self.assertEqual(output_dbs[db_index_1].data, mock_stderr[0]) self.assertEqual(output_dbs[db_index_2].data, mock_stderr[1]) @@ -295,16 +298,13 @@ def test_action_stdout_and_stderr_is_stored_in_the_db_short_running_action(self, def test_shell_command_sudo_password_is_passed_to_sudo_binary(self): # Verify that sudo password is correctly passed to sudo binary via stdin models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] - sudo_passwords = [ - 'pass 1', - 'sudopass', - '$sudo p@ss 2' - ] + sudo_passwords = ["pass 1", "sudopass", "$sudo p@ss 2"] - cmd = ('{ read sudopass; echo $sudopass; }') + cmd = "{ read sudopass; echo $sudopass; }" # without sudo for sudo_password in sudo_passwords: @@ -314,9 +314,8 @@ def test_shell_command_sudo_password_is_passed_to_sudo_binary(self): status, result, _ = runner.run({}) runner.post_run(status, result) - self.assertEqual(status, - action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], sudo_password) + self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual(result["stdout"], sudo_password) # with sudo for sudo_password in sudo_passwords: @@ -327,12 +326,13 @@ def test_shell_command_sudo_password_is_passed_to_sudo_binary(self): status, result, _ = runner.run({}) runner.post_run(status, result) - self.assertEqual(status, - action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['stdout'], sudo_password) + self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual(result["stdout"], sudo_password) # Verify new process which provides password via stdin to the command is created - with mock.patch('st2common.util.concurrency.subprocess_popen') as mock_subproc_popen: + with mock.patch( + "st2common.util.concurrency.subprocess_popen" + ) as mock_subproc_popen: index = 0 for sudo_password in sudo_passwords: runner = self._get_runner(action_db, cmd=cmd) @@ -349,58 +349,67 @@ def test_shell_command_sudo_password_is_passed_to_sudo_binary(self): index += 1 - self.assertEqual(call_args[0][0], ['echo', '%s\n' % (sudo_password)]) + self.assertEqual(call_args[0][0], ["echo", "%s\n" % (sudo_password)]) self.assertEqual(index, len(sudo_passwords)) def test_shell_command_invalid_stdout_password(self): # Simulate message printed to stderr by sudo when invalid sudo password is provided models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local.yaml']}) - action_db = models['actions']['local.yaml'] - - cmd = ('echo "[sudo] password for bar: Sorry, try again.\n[sudo] password for bar:' - ' Sorry, try again.\n[sudo] password for bar: \nsudo: 2 incorrect password ' - 'attempts" 1>&2; exit 1') + fixtures_pack="generic", fixtures_dict={"actions": ["local.yaml"]} + ) + action_db = models["actions"]["local.yaml"] + + cmd = ( + 'echo "[sudo] password for bar: Sorry, try again.\n[sudo] password for bar:' + " Sorry, try again.\n[sudo] password for bar: \nsudo: 2 incorrect password " + 'attempts" 1>&2; exit 1' + ) runner = self._get_runner(action_db, cmd=cmd) runner.pre_run() - runner._sudo_password = 'pass' + runner._sudo_password = "pass" status, result, _ = runner.run({}) runner.post_run(status, result) - expected_error = ('Invalid sudo password provided or sudo is not configured for this ' - 'user (bar)') + expected_error = ( + "Invalid sudo password provided or sudo is not configured for this " + "user (bar)" + ) self.assertEqual(status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(result['error'], expected_error) - self.assertEqual(result['stdout'], '') + self.assertEqual(result["error"], expected_error) + self.assertEqual(result["stdout"], "") @staticmethod - def _get_runner(action_db, - entry_point=None, - cmd=None, - on_behalf_user=None, - user=None, - kwarg_op=local_runner.DEFAULT_KWARG_OP, - timeout=LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT, - sudo=False, - env=None): + def _get_runner( + action_db, + entry_point=None, + cmd=None, + on_behalf_user=None, + user=None, + kwarg_op=local_runner.DEFAULT_KWARG_OP, + timeout=LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT, + sudo=False, + env=None, + ): runner = LocalShellCommandRunner(uuid.uuid4().hex) runner.execution = MOCK_EXECUTION runner.action = action_db runner.action_name = action_db.name runner.liveaction_id = uuid.uuid4().hex runner.entry_point = entry_point - runner.runner_parameters = {local_runner.RUNNER_COMMAND: cmd, - local_runner.RUNNER_SUDO: sudo, - local_runner.RUNNER_ENV: env, - local_runner.RUNNER_ON_BEHALF_USER: user, - local_runner.RUNNER_KWARG_OP: kwarg_op, - local_runner.RUNNER_TIMEOUT: timeout} + runner.runner_parameters = { + local_runner.RUNNER_COMMAND: cmd, + local_runner.RUNNER_SUDO: sudo, + local_runner.RUNNER_ENV: env, + local_runner.RUNNER_ON_BEHALF_USER: user, + local_runner.RUNNER_KWARG_OP: kwarg_op, + local_runner.RUNNER_TIMEOUT: timeout, + } runner.context = dict() runner.callback = dict() runner.libs_dir_path = None runner.auth_token = mock.Mock() - runner.auth_token.token = 'mock-token' + runner.auth_token.token = "mock-token" return runner @@ -411,22 +420,27 @@ def setUp(self): super(LocalShellScriptRunnerTestCase, self).setUp() # False is a default behavior so end result should be the same - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=False) + cfg.CONF.set_override( + name="stream_output", group="actionrunner", override=False + ) def test_script_with_parameters_parameter_serialization(self): models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local_script_with_params.yaml']}) - action_db = models['actions']['local_script_with_params.yaml'] - entry_point = os.path.join(get_fixtures_base_path(), - 'generic/actions/local_script_with_params.sh') + fixtures_pack="generic", + fixtures_dict={"actions": ["local_script_with_params.yaml"]}, + ) + action_db = models["actions"]["local_script_with_params.yaml"] + entry_point = os.path.join( + get_fixtures_base_path(), "generic/actions/local_script_with_params.sh" + ) action_parameters = { - 'param_string': 'test string', - 'param_integer': 1, - 'param_float': 2.55, - 'param_boolean': True, - 'param_list': ['a', 'b', 'c'], - 'param_object': {'foo': 'bar'} + "param_string": "test string", + "param_integer": 1, + "param_float": 2.55, + "param_boolean": True, + "param_list": ["a", "b", "c"], + "param_object": {"foo": "bar"}, } runner = self._get_runner(action_db=action_db, entry_point=entry_point) @@ -434,20 +448,20 @@ def test_script_with_parameters_parameter_serialization(self): status, result, _ = runner.run(action_parameters=action_parameters) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('PARAM_STRING=test string', result['stdout']) - self.assertIn('PARAM_INTEGER=1', result['stdout']) - self.assertIn('PARAM_FLOAT=2.55', result['stdout']) - self.assertIn('PARAM_BOOLEAN=1', result['stdout']) - self.assertIn('PARAM_LIST=a,b,c', result['stdout']) - self.assertIn('PARAM_OBJECT={"foo": "bar"}', result['stdout']) + self.assertIn("PARAM_STRING=test string", result["stdout"]) + self.assertIn("PARAM_INTEGER=1", result["stdout"]) + self.assertIn("PARAM_FLOAT=2.55", result["stdout"]) + self.assertIn("PARAM_BOOLEAN=1", result["stdout"]) + self.assertIn("PARAM_LIST=a,b,c", result["stdout"]) + self.assertIn('PARAM_OBJECT={"foo": "bar"}', result["stdout"]) action_parameters = { - 'param_string': 'test string', - 'param_integer': 1, - 'param_float': 2.55, - 'param_boolean': False, - 'param_list': ['a', 'b', 'c'], - 'param_object': {'foo': 'bar'} + "param_string": "test string", + "param_integer": 1, + "param_float": 2.55, + "param_boolean": False, + "param_list": ["a", "b", "c"], + "param_object": {"foo": "bar"}, } runner = self._get_runner(action_db=action_db, entry_point=entry_point) @@ -456,12 +470,12 @@ def test_script_with_parameters_parameter_serialization(self): runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('PARAM_BOOLEAN=0', result['stdout']) + self.assertIn("PARAM_BOOLEAN=0", result["stdout"]) action_parameters = { - 'param_string': '', - 'param_integer': None, - 'param_float': None, + "param_string": "", + "param_integer": None, + "param_float": None, } runner = self._get_runner(action_db=action_db, entry_point=entry_point) @@ -470,24 +484,24 @@ def test_script_with_parameters_parameter_serialization(self): runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('PARAM_STRING=\n', result['stdout']) - self.assertIn('PARAM_INTEGER=\n', result['stdout']) - self.assertIn('PARAM_FLOAT=\n', result['stdout']) + self.assertIn("PARAM_STRING=\n", result["stdout"]) + self.assertIn("PARAM_INTEGER=\n", result["stdout"]) + self.assertIn("PARAM_FLOAT=\n", result["stdout"]) # End result should be the same when streaming is enabled - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True) + cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True) # Verify initial state output_dbs = ActionExecutionOutput.get_all() self.assertEqual(len(output_dbs), 0) action_parameters = { - 'param_string': 'test string', - 'param_integer': 1, - 'param_float': 2.55, - 'param_boolean': True, - 'param_list': ['a', 'b', 'c'], - 'param_object': {'foo': 'bar'} + "param_string": "test string", + "param_integer": 1, + "param_float": 2.55, + "param_boolean": True, + "param_list": ["a", "b", "c"], + "param_object": {"foo": "bar"}, } runner = self._get_runner(action_db=action_db, entry_point=entry_point) @@ -496,26 +510,26 @@ def test_script_with_parameters_parameter_serialization(self): runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('PARAM_STRING=test string', result['stdout']) - self.assertIn('PARAM_INTEGER=1', result['stdout']) - self.assertIn('PARAM_FLOAT=2.55', result['stdout']) - self.assertIn('PARAM_BOOLEAN=1', result['stdout']) - self.assertIn('PARAM_LIST=a,b,c', result['stdout']) - self.assertIn('PARAM_OBJECT={"foo": "bar"}', result['stdout']) - - output_dbs = ActionExecutionOutput.query(output_type='stdout') + self.assertIn("PARAM_STRING=test string", result["stdout"]) + self.assertIn("PARAM_INTEGER=1", result["stdout"]) + self.assertIn("PARAM_FLOAT=2.55", result["stdout"]) + self.assertIn("PARAM_BOOLEAN=1", result["stdout"]) + self.assertIn("PARAM_LIST=a,b,c", result["stdout"]) + self.assertIn('PARAM_OBJECT={"foo": "bar"}', result["stdout"]) + + output_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(output_dbs), 6) - self.assertEqual(output_dbs[0].data, 'PARAM_STRING=test string\n') + self.assertEqual(output_dbs[0].data, "PARAM_STRING=test string\n") self.assertEqual(output_dbs[5].data, 'PARAM_OBJECT={"foo": "bar"}\n') - output_dbs = ActionExecutionOutput.query(output_type='stderr') + output_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(output_dbs), 0) - @mock.patch('st2common.util.concurrency.subprocess_popen') - @mock.patch('st2common.util.concurrency.spawn') + @mock.patch("st2common.util.concurrency.subprocess_popen") + @mock.patch("st2common.util.concurrency.spawn") def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_popen): # Feature is enabled - cfg.CONF.set_override(name='stream_output', group='actionrunner', override=True) + cfg.CONF.set_override(name="stream_output", group="actionrunner", override=True) # Note: We need to mock spawn function so we can test everything in single event loop # iteration @@ -523,40 +537,41 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop # No output to stdout and no result (implicit None) mock_stdout = [ - 'stdout line 1\n', - 'stdout line 2\n', - 'stdout line 3\n', - 'stdout line 4\n' - ] - mock_stderr = [ - 'stderr line 1\n', - 'stderr line 2\n', - 'stderr line 3\n' + "stdout line 1\n", + "stdout line 2\n", + "stdout line 3\n", + "stdout line 4\n", ] + mock_stderr = ["stderr line 1\n", "stderr line 2\n", "stderr line 3\n"] mock_process = mock.Mock() mock_process.returncode = 0 mock_popen.return_value = mock_process mock_process.stdout.closed = False mock_process.stderr.closed = False - mock_process.stdout.readline = make_mock_stream_readline(mock_process.stdout, mock_stdout, - stop_counter=4) - mock_process.stderr.readline = make_mock_stream_readline(mock_process.stderr, mock_stderr, - stop_counter=3) + mock_process.stdout.readline = make_mock_stream_readline( + mock_process.stdout, mock_stdout, stop_counter=4 + ) + mock_process.stderr.readline = make_mock_stream_readline( + mock_process.stderr, mock_stderr, stop_counter=3 + ) models = self.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['local_script_with_params.yaml']}) - action_db = models['actions']['local_script_with_params.yaml'] - entry_point = os.path.join(get_fixtures_base_path(), - 'generic/actions/local_script_with_params.sh') + fixtures_pack="generic", + fixtures_dict={"actions": ["local_script_with_params.yaml"]}, + ) + action_db = models["actions"]["local_script_with_params.yaml"] + entry_point = os.path.join( + get_fixtures_base_path(), "generic/actions/local_script_with_params.sh" + ) action_parameters = { - 'param_string': 'test string', - 'param_integer': 1, - 'param_float': 2.55, - 'param_boolean': True, - 'param_list': ['a', 'b', 'c'], - 'param_object': {'foo': 'bar'} + "param_string": "test string", + "param_integer": 1, + "param_float": 2.55, + "param_boolean": True, + "param_list": ["a", "b", "c"], + "param_object": {"foo": "bar"}, } runner = self._get_runner(action_db=action_db, entry_point=entry_point) @@ -564,20 +579,24 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop status, result, _ = runner.run(action_parameters=action_parameters) runner.post_run(status, result) - self.assertEqual(result['stdout'], - 'stdout line 1\nstdout line 2\nstdout line 3\nstdout line 4') - self.assertEqual(result['stderr'], 'stderr line 1\nstderr line 2\nstderr line 3') - self.assertEqual(result['return_code'], 0) + self.assertEqual( + result["stdout"], + "stdout line 1\nstdout line 2\nstdout line 3\nstdout line 4", + ) + self.assertEqual( + result["stderr"], "stderr line 1\nstderr line 2\nstderr line 3" + ) + self.assertEqual(result["return_code"], 0) # Verify stdout and stderr lines have been correctly stored in the db - output_dbs = ActionExecutionOutput.query(output_type='stdout') + output_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(output_dbs), 4) self.assertEqual(output_dbs[0].data, mock_stdout[0]) self.assertEqual(output_dbs[1].data, mock_stdout[1]) self.assertEqual(output_dbs[2].data, mock_stdout[2]) self.assertEqual(output_dbs[3].data, mock_stdout[3]) - output_dbs = ActionExecutionOutput.query(output_type='stderr') + output_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(output_dbs), 3) self.assertEqual(output_dbs[0].data, mock_stderr[0]) self.assertEqual(output_dbs[1].data, mock_stderr[1]) @@ -585,30 +604,36 @@ def test_action_stdout_and_stderr_is_stored_in_the_db(self, mock_spawn, mock_pop def test_shell_script_action(self): models = self.fixtures_loader.load_models( - fixtures_pack='localrunner_pack', fixtures_dict={'actions': ['text_gen.yml']}) - action_db = models['actions']['text_gen.yml'] + fixtures_pack="localrunner_pack", + fixtures_dict={"actions": ["text_gen.yml"]}, + ) + action_db = models["actions"]["text_gen.yml"] entry_point = self.fixtures_loader.get_fixture_file_path_abs( - 'localrunner_pack', 'actions', 'text_gen.py') + "localrunner_pack", "actions", "text_gen.py" + ) runner = self._get_runner(action_db, entry_point=entry_point) runner.pre_run() - status, result, _ = runner.run({'chars': 1000}) + status, result, _ = runner.run({"chars": 1000}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(len(result['stdout']), 1000) + self.assertEqual(len(result["stdout"]), 1000) def test_large_stdout(self): models = self.fixtures_loader.load_models( - fixtures_pack='localrunner_pack', fixtures_dict={'actions': ['text_gen.yml']}) - action_db = models['actions']['text_gen.yml'] + fixtures_pack="localrunner_pack", + fixtures_dict={"actions": ["text_gen.yml"]}, + ) + action_db = models["actions"]["text_gen.yml"] entry_point = self.fixtures_loader.get_fixture_file_path_abs( - 'localrunner_pack', 'actions', 'text_gen.py') + "localrunner_pack", "actions", "text_gen.py" + ) runner = self._get_runner(action_db, entry_point=entry_point) runner.pre_run() char_count = 10 ** 6 # Note 10^7 succeeds but ends up being slow. - status, result, _ = runner.run({'chars': char_count}) + status, result, _ = runner.run({"chars": char_count}) runner.post_run(status, result) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(len(result['stdout']), char_count) + self.assertEqual(len(result["stdout"]), char_count) def _get_runner(self, action_db, entry_point): runner = LocalShellScriptRunner(uuid.uuid4().hex) @@ -622,5 +647,5 @@ def _get_runner(self, action_db, entry_point): runner.callback = dict() runner.libs_dir_path = None runner.auth_token = mock.Mock() - runner.auth_token.token = 'mock-token' + runner.auth_token.token = "mock-token" return runner diff --git a/contrib/runners/noop_runner/dist_utils.py b/contrib/runners/noop_runner/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/contrib/runners/noop_runner/dist_utils.py +++ b/contrib/runners/noop_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/noop_runner/noop_runner/__init__.py b/contrib/runners/noop_runner/noop_runner/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/contrib/runners/noop_runner/noop_runner/__init__.py +++ b/contrib/runners/noop_runner/noop_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/noop_runner/noop_runner/noop_runner.py b/contrib/runners/noop_runner/noop_runner/noop_runner.py index 0eb745218a8..b4dda10fd5a 100644 --- a/contrib/runners/noop_runner/noop_runner/noop_runner.py +++ b/contrib/runners/noop_runner/noop_runner/noop_runner.py @@ -22,12 +22,7 @@ from st2common.constants.action import LIVEACTION_STATUS_SUCCEEDED import st2common.util.jsonify as jsonify -__all__ = [ - 'NoopRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["NoopRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) @@ -36,7 +31,8 @@ class NoopRunner(ActionRunner): """ Runner which does absolutely nothing. No-op action. """ - KEYS_TO_TRANSFORM = ['stdout', 'stderr'] + + KEYS_TO_TRANSFORM = ["stdout", "stderr"] def __init__(self, runner_id): super(NoopRunner, self).__init__(runner_id=runner_id) @@ -46,14 +42,15 @@ def pre_run(self): def run(self, action_parameters): - LOG.info('Executing action via NoopRunner: %s', self.runner_id) - LOG.info('[Action info] name: %s, Id: %s', - self.action_name, str(self.execution_id)) + LOG.info("Executing action via NoopRunner: %s", self.runner_id) + LOG.info( + "[Action info] name: %s, Id: %s", self.action_name, str(self.execution_id) + ) result = { - 'failed': False, - 'succeeded': True, - 'return_code': 0, + "failed": False, + "succeeded": True, + "return_code": 0, } status = LIVEACTION_STATUS_SUCCEEDED @@ -65,4 +62,4 @@ def get_runner(): def get_metadata(): - return get_runner_metadata('noop_runner')[0] + return get_runner_metadata("noop_runner")[0] diff --git a/contrib/runners/noop_runner/setup.py b/contrib/runners/noop_runner/setup.py index 30b00bd68b9..94b518c55f6 100644 --- a/contrib/runners/noop_runner/setup.py +++ b/contrib/runners/noop_runner/setup.py @@ -26,30 +26,30 @@ from noop_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-noop', + name="stackstorm-runner-noop", version=__version__, - description=('No-Op action runner for StackStorm event-driven automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=("No-Op action runner for StackStorm event-driven automation platform"), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'noop_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"noop_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'noop = noop_runner.noop_runner', + "st2common.runners.runner": [ + "noop = noop_runner.noop_runner", ], - } + }, ) diff --git a/contrib/runners/noop_runner/tests/unit/test_nooprunner.py b/contrib/runners/noop_runner/tests/unit/test_nooprunner.py index 6783404ffbb..98c66c33cd4 100644 --- a/contrib/runners/noop_runner/tests/unit/test_nooprunner.py +++ b/contrib/runners/noop_runner/tests/unit/test_nooprunner.py @@ -19,6 +19,7 @@ import mock import st2tests.config as tests_config + tests_config.parse_args() from unittest2 import TestCase @@ -33,16 +34,17 @@ class TestNoopRunner(TestCase): def test_noop_command_executes(self): models = TestNoopRunner.fixtures_loader.load_models( - fixtures_pack='generic', fixtures_dict={'actions': ['noop.yaml']}) + fixtures_pack="generic", fixtures_dict={"actions": ["noop.yaml"]} + ) - action_db = models['actions']['noop.yaml'] + action_db = models["actions"]["noop.yaml"] runner = TestNoopRunner._get_runner(action_db) status, result, _ = runner.run({}) self.assertEqual(status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(result['failed'], False) - self.assertEqual(result['succeeded'], True) - self.assertEqual(result['return_code'], 0) + self.assertEqual(result["failed"], False) + self.assertEqual(result["succeeded"], True) + self.assertEqual(result["return_code"], 0) @staticmethod def _get_runner(action_db): @@ -55,5 +57,5 @@ def _get_runner(action_db): runner.callback = dict() runner.libs_dir_path = None runner.auth_token = mock.Mock() - runner.auth_token.token = 'mock-token' + runner.auth_token.token = "mock-token" return runner diff --git a/contrib/runners/orquesta_runner/dist_utils.py b/contrib/runners/orquesta_runner/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/contrib/runners/orquesta_runner/dist_utils.py +++ b/contrib/runners/orquesta_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/orquesta_runner/orquesta_functions/runtime.py b/contrib/runners/orquesta_runner/orquesta_functions/runtime.py index f5986392d7e..e71dcafd402 100644 --- a/contrib/runners/orquesta_runner/orquesta_functions/runtime.py +++ b/contrib/runners/orquesta_runner/orquesta_functions/runtime.py @@ -33,15 +33,15 @@ def format_task_result(instances): instance = instances[-1] return { - 'task_execution_id': str(instance.id), - 'workflow_execution_id': instance.workflow_execution, - 'task_name': instance.task_id, - 'task_id': instance.task_id, - 'route': instance.task_route, - 'result': instance.result, - 'status': instance.status, - 'start_timestamp': str(instance.start_timestamp), - 'end_timestamp': str(instance.end_timestamp) + "task_execution_id": str(instance.id), + "workflow_execution_id": instance.workflow_execution, + "task_name": instance.task_id, + "task_id": instance.task_id, + "route": instance.task_route, + "result": instance.result, + "status": instance.status, + "start_timestamp": str(instance.start_timestamp), + "end_timestamp": str(instance.end_timestamp), } @@ -54,17 +54,17 @@ def task(context, task_id=None, route=None): current_task = {} if task_id is None: - task_id = current_task['id'] + task_id = current_task["id"] if route is None: - route = current_task.get('route', 0) + route = current_task.get("route", 0) try: - workflow_state = context['__state'] or {} + workflow_state = context["__state"] or {} except KeyError: workflow_state = {} - task_state_pointers = workflow_state.get('tasks') or {} + task_state_pointers = workflow_state.get("tasks") or {} task_state_entry_uid = constants.TASK_STATE_ROUTE_FORMAT % (task_id, str(route)) task_state_entry_idx = task_state_pointers.get(task_state_entry_uid) @@ -72,9 +72,11 @@ def task(context, task_id=None, route=None): # use an earlier route before the split to find the specific task. if task_state_entry_idx is None: if route > 0: - current_route_details = workflow_state['routes'][route] + current_route_details = workflow_state["routes"][route] # Reverse the list because we want to start with the next longest route. - for idx, prev_route_details in enumerate(reversed(workflow_state['routes'][:route])): + for idx, prev_route_details in enumerate( + reversed(workflow_state["routes"][:route]) + ): if len(set(prev_route_details) - set(current_route_details)) == 0: # The index is from a reversed list so need to calculate # the index of the item in the list before the reverse. @@ -83,17 +85,15 @@ def task(context, task_id=None, route=None): else: # Otherwise, get the task flow entry and use the # task id and route to query the database. - task_state_seqs = workflow_state.get('sequence') or [] + task_state_seqs = workflow_state.get("sequence") or [] task_state_entry = task_state_seqs[task_state_entry_idx] - route = task_state_entry['route'] - st2_ctx = context['__vars']['st2'] - workflow_execution_id = st2_ctx['workflow_execution_id'] + route = task_state_entry["route"] + st2_ctx = context["__vars"]["st2"] + workflow_execution_id = st2_ctx["workflow_execution_id"] # Query the database by the workflow execution ID, task ID, and task route. instances = wf_db_access.TaskExecution.query( - workflow_execution=workflow_execution_id, - task_id=task_id, - task_route=route + workflow_execution=workflow_execution_id, task_id=task_id, task_route=route ) if not instances: diff --git a/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py b/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py index 35cae92cd74..ed23507a1b4 100644 --- a/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py +++ b/contrib/runners/orquesta_runner/orquesta_functions/st2kv.py @@ -29,26 +29,28 @@ def st2kv_(context, key, **kwargs): if not isinstance(key, six.string_types): - raise TypeError('Given key is not typeof string.') + raise TypeError("Given key is not typeof string.") - decrypt = kwargs.get('decrypt', False) + decrypt = kwargs.get("decrypt", False) if not isinstance(decrypt, bool): - raise TypeError('Decrypt parameter is not typeof bool.') + raise TypeError("Decrypt parameter is not typeof bool.") try: - username = context['__vars']['st2']['user'] + username = context["__vars"]["st2"]["user"] except KeyError: - raise KeyError('Could not get user from context.') + raise KeyError("Could not get user from context.") try: user_db = auth_db_access.User.get(username) except Exception as e: - raise Exception('Failed to retrieve User object for user "%s", "%s"' % - (username, six.text_type(e))) + raise Exception( + 'Failed to retrieve User object for user "%s", "%s"' + % (username, six.text_type(e)) + ) - has_default = 'default' in kwargs - default_value = kwargs.get('default') + has_default = "default" in kwargs + default_value = kwargs.get("default") try: return kvp_util.get_key(key=key, user_db=user_db, decrypt=decrypt) diff --git a/contrib/runners/orquesta_runner/orquesta_runner/__init__.py b/contrib/runners/orquesta_runner/orquesta_runner/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/contrib/runners/orquesta_runner/orquesta_runner/__init__.py +++ b/contrib/runners/orquesta_runner/orquesta_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py b/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py index b59642609e4..62f2492ae4a 100644 --- a/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py +++ b/contrib/runners/orquesta_runner/orquesta_runner/orquesta_runner.py @@ -37,71 +37,72 @@ from st2common.util import api as api_util from st2common.util import ujson -__all__ = [ - 'OrquestaRunner', - 'get_runner', - 'get_metadata' -] +__all__ = ["OrquestaRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) class OrquestaRunner(runners.AsyncActionRunner): - @staticmethod def get_workflow_definition(entry_point): - with open(entry_point, 'r') as def_file: + with open(entry_point, "r") as def_file: return def_file.read() def _get_notify_config(self): return ( - notify_api_models.NotificationsHelper.from_model(notify_model=self.liveaction.notify) + notify_api_models.NotificationsHelper.from_model( + notify_model=self.liveaction.notify + ) if self.liveaction.notify else None ) def _construct_context(self, wf_ex): ctx = ujson.fast_deepcopy(self.context) - ctx['workflow_execution'] = str(wf_ex.id) + ctx["workflow_execution"] = str(wf_ex.id) return ctx def _construct_st2_context(self): st2_ctx = { - 'st2': { - 'action_execution_id': str(self.execution.id), - 'api_url': api_util.get_full_public_api_url(), - 'user': self.execution.context.get('user', cfg.CONF.system_user.user), - 'pack': self.execution.context.get('pack', None), - 'action': self.execution.action.get('ref', None), - 'runner': self.execution.action.get('runner_type', None) + "st2": { + "action_execution_id": str(self.execution.id), + "api_url": api_util.get_full_public_api_url(), + "user": self.execution.context.get("user", cfg.CONF.system_user.user), + "pack": self.execution.context.get("pack", None), + "action": self.execution.action.get("ref", None), + "runner": self.execution.action.get("runner_type", None), } } - if self.execution.context.get('api_user'): - st2_ctx['st2']['api_user'] = self.execution.context.get('api_user') + if self.execution.context.get("api_user"): + st2_ctx["st2"]["api_user"] = self.execution.context.get("api_user") - if self.execution.context.get('source_channel'): - st2_ctx['st2']['source_channel'] = self.execution.context.get('source_channel') + if self.execution.context.get("source_channel"): + st2_ctx["st2"]["source_channel"] = self.execution.context.get( + "source_channel" + ) if self.execution.context: - st2_ctx['parent'] = self.execution.context + st2_ctx["parent"] = self.execution.context return st2_ctx def _handle_workflow_return_value(self, wf_ex_db): if wf_ex_db.status in wf_statuses.COMPLETED_STATUSES: status = wf_ex_db.status - result = {'output': wf_ex_db.output or None} + result = {"output": wf_ex_db.output or None} if wf_ex_db.status in wf_statuses.ABENDED_STATUSES: - result['errors'] = wf_ex_db.errors + result["errors"] = wf_ex_db.errors for wf_ex_error in wf_ex_db.errors: - msg = 'Workflow execution completed with errors.' - wf_svc.update_progress(wf_ex_db, '%s %s' % (msg, str(wf_ex_error)), log=False) - LOG.error('[%s] %s', str(self.execution.id), msg, extra=wf_ex_error) + msg = "Workflow execution completed with errors." + wf_svc.update_progress( + wf_ex_db, "%s %s" % (msg, str(wf_ex_error)), log=False + ) + LOG.error("[%s] %s", str(self.execution.id), msg, extra=wf_ex_error) return (status, result, self.context) @@ -115,8 +116,8 @@ def _handle_workflow_return_value(self, wf_ex_db): def run(self, action_parameters): # If there is an action execution reference for rerun and there is task specified, # then rerun the existing workflow execution. - rerun_options = self.context.get('re-run', {}) - rerun_task_options = rerun_options.get('tasks', []) + rerun_options = self.context.get("re-run", {}) + rerun_task_options = rerun_options.get("tasks", []) if self.rerun_ex_ref and rerun_task_options: return self.rerun_workflow(self.rerun_ex_ref, options=rerun_options) @@ -131,14 +132,16 @@ def start_workflow(self, action_parameters): # Request workflow execution. st2_ctx = self._construct_st2_context() notify_cfg = self._get_notify_config() - wf_ex_db = wf_svc.request(wf_def, self.execution, st2_ctx, notify_cfg=notify_cfg) + wf_ex_db = wf_svc.request( + wf_def, self.execution, st2_ctx, notify_cfg=notify_cfg + ) except wf_exc.WorkflowInspectionError as e: status = ac_const.LIVEACTION_STATUS_FAILED - result = {'errors': e.args[1], 'output': None} + result = {"errors": e.args[1], "output": None} return (status, result, self.context) except Exception as e: status = ac_const.LIVEACTION_STATUS_FAILED - result = {'errors': [{'message': six.text_type(e)}], 'output': None} + result = {"errors": [{"message": six.text_type(e)}], "output": None} return (status, result, self.context) return self._handle_workflow_return_value(wf_ex_db) @@ -146,13 +149,13 @@ def start_workflow(self, action_parameters): def rerun_workflow(self, ac_ex_ref, options=None): try: # Request rerun of workflow execution. - wf_ex_id = ac_ex_ref.context.get('workflow_execution') + wf_ex_id = ac_ex_ref.context.get("workflow_execution") st2_ctx = self._construct_st2_context() - st2_ctx['workflow_execution_id'] = wf_ex_id + st2_ctx["workflow_execution_id"] = wf_ex_id wf_ex_db = wf_svc.request_rerun(self.execution, st2_ctx, options=options) except Exception as e: status = ac_const.LIVEACTION_STATUS_FAILED - result = {'errors': [{'message': six.text_type(e)}], 'output': None} + result = {"errors": [{"message": six.text_type(e)}], "output": None} return (status, result, self.context) return self._handle_workflow_return_value(wf_ex_db) @@ -160,8 +163,8 @@ def rerun_workflow(self, ac_ex_ref, options=None): @staticmethod def task_pauseable(ac_ex): wf_ex_pauseable = ( - ac_ex.runner['name'] in ac_const.WORKFLOW_RUNNER_TYPES and - ac_ex.status == ac_const.LIVEACTION_STATUS_RUNNING + ac_ex.runner["name"] in ac_const.WORKFLOW_RUNNER_TYPES + and ac_ex.status == ac_const.LIVEACTION_STATUS_RUNNING ) return wf_ex_pauseable @@ -175,26 +178,24 @@ def pause(self): child_ex = ex_db_access.ActionExecution.get(id=child_ex_id) if self.task_pauseable(child_ex): ac_svc.request_pause( - lv_db_access.LiveAction.get(id=child_ex.liveaction['id']), - self.context.get('user', None) + lv_db_access.LiveAction.get(id=child_ex.liveaction["id"]), + self.context.get("user", None), ) - if wf_ex_db.status == wf_statuses.PAUSING or ac_svc.is_children_active(self.liveaction.id): + if wf_ex_db.status == wf_statuses.PAUSING or ac_svc.is_children_active( + self.liveaction.id + ): status = ac_const.LIVEACTION_STATUS_PAUSING else: status = ac_const.LIVEACTION_STATUS_PAUSED - return ( - status, - self.liveaction.result, - self.liveaction.context - ) + return (status, self.liveaction.result, self.liveaction.context) @staticmethod def task_resumeable(ac_ex): wf_ex_resumeable = ( - ac_ex.runner['name'] in ac_const.WORKFLOW_RUNNER_TYPES and - ac_ex.status == ac_const.LIVEACTION_STATUS_PAUSED + ac_ex.runner["name"] in ac_const.WORKFLOW_RUNNER_TYPES + and ac_ex.status == ac_const.LIVEACTION_STATUS_PAUSED ) return wf_ex_resumeable @@ -208,26 +209,26 @@ def resume(self): child_ex = ex_db_access.ActionExecution.get(id=child_ex_id) if self.task_resumeable(child_ex): ac_svc.request_resume( - lv_db_access.LiveAction.get(id=child_ex.liveaction['id']), - self.context.get('user', None) + lv_db_access.LiveAction.get(id=child_ex.liveaction["id"]), + self.context.get("user", None), ) return ( wf_ex_db.status if wf_ex_db else ac_const.LIVEACTION_STATUS_RUNNING, self.liveaction.result, - self.liveaction.context + self.liveaction.context, ) @staticmethod def task_cancelable(ac_ex): wf_ex_cancelable = ( - ac_ex.runner['name'] in ac_const.WORKFLOW_RUNNER_TYPES and - ac_ex.status in ac_const.LIVEACTION_CANCELABLE_STATES + ac_ex.runner["name"] in ac_const.WORKFLOW_RUNNER_TYPES + and ac_ex.status in ac_const.LIVEACTION_CANCELABLE_STATES ) ac_ex_cancelable = ( - ac_ex.runner['name'] not in ac_const.WORKFLOW_RUNNER_TYPES and - ac_ex.status in ac_const.LIVEACTION_DELAYED_STATES + ac_ex.runner["name"] not in ac_const.WORKFLOW_RUNNER_TYPES + and ac_ex.status in ac_const.LIVEACTION_DELAYED_STATES ) return wf_ex_cancelable or ac_ex_cancelable @@ -242,8 +243,10 @@ def cancel(self): # If workflow execution is not found because the action execution is cancelled # before the workflow execution is created or if the workflow execution is # already completed, then ignore the exception and proceed with cancellation. - except (wf_svc_exc.WorkflowExecutionNotFoundException, - wf_svc_exc.WorkflowExecutionIsCompletedException): + except ( + wf_svc_exc.WorkflowExecutionNotFoundException, + wf_svc_exc.WorkflowExecutionIsCompletedException, + ): pass # If there is an unknown exception, then log the error. Continue with the # cancelation sequence below to cancel children and determine final status. @@ -253,19 +256,22 @@ def cancel(self): # execution will be in an unknown state. except Exception: _, ex, tb = sys.exc_info() - msg = 'Error encountered when canceling workflow execution.' - LOG.exception('[%s] %s', str(self.execution.id), msg) - msg = 'Error encountered when canceling workflow execution. %s' + msg = "Error encountered when canceling workflow execution." + LOG.exception("[%s] %s", str(self.execution.id), msg) + msg = "Error encountered when canceling workflow execution. %s" wf_svc.update_progress(wf_ex_db, msg % str(ex), log=False) - result = {'error': msg % str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))} + result = { + "error": msg % str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + } # Request cancellation of tasks that are workflows and still running. for child_ex_id in self.execution.children: child_ex = ex_db_access.ActionExecution.get(id=child_ex_id) if self.task_cancelable(child_ex): ac_svc.request_cancellation( - lv_db_access.LiveAction.get(id=child_ex.liveaction['id']), - self.context.get('user', None) + lv_db_access.LiveAction.get(id=child_ex.liveaction["id"]), + self.context.get("user", None), ) status = ( @@ -277,7 +283,7 @@ def cancel(self): return ( status, result if result else self.liveaction.result, - self.liveaction.context + self.liveaction.context, ) @@ -286,4 +292,4 @@ def get_runner(): def get_metadata(): - return runners.get_metadata('orquesta_runner')[0] + return runners.get_metadata("orquesta_runner")[0] diff --git a/contrib/runners/orquesta_runner/setup.py b/contrib/runners/orquesta_runner/setup.py index 5dac5ed34e9..859a8b60506 100644 --- a/contrib/runners/orquesta_runner/setup.py +++ b/contrib/runners/orquesta_runner/setup.py @@ -26,62 +26,64 @@ from orquesta_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-orquesta', + name="stackstorm-runner-orquesta", version=__version__, - description='Orquesta workflow runner for StackStorm event-driven automation platform', - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="Orquesta workflow runner for StackStorm event-driven automation platform", + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'orquesta_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"orquesta_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'orquesta = orquesta_runner.orquesta_runner', + "st2common.runners.runner": [ + "orquesta = orquesta_runner.orquesta_runner", ], - 'orquesta.expressions.functions': [ - 'st2kv = orquesta_functions.st2kv:st2kv_', - 'task = orquesta_functions.runtime:task', - 'basename = st2common.expressions.functions.path:basename', - 'dirname = st2common.expressions.functions.path:dirname', - 'from_json_string = st2common.expressions.functions.data:from_json_string', - 'from_yaml_string = st2common.expressions.functions.data:from_yaml_string', - 'json_dump = st2common.expressions.functions.data:to_json_string', - 'json_parse = st2common.expressions.functions.data:from_json_string', - 'json_escape = st2common.expressions.functions.data:json_escape', - 'jsonpath_query = st2common.expressions.functions.data:jsonpath_query', - 'regex_match = st2common.expressions.functions.regex:regex_match', - 'regex_replace = st2common.expressions.functions.regex:regex_replace', - 'regex_search = st2common.expressions.functions.regex:regex_search', - 'regex_substring = st2common.expressions.functions.regex:regex_substring', - ('to_human_time_from_seconds = ' - 'st2common.expressions.functions.time:to_human_time_from_seconds'), - 'to_json_string = st2common.expressions.functions.data:to_json_string', - 'to_yaml_string = st2common.expressions.functions.data:to_yaml_string', - 'use_none = st2common.expressions.functions.data:use_none', - 'version_compare = st2common.expressions.functions.version:version_compare', - 'version_more_than = st2common.expressions.functions.version:version_more_than', - 'version_less_than = st2common.expressions.functions.version:version_less_than', - 'version_equal = st2common.expressions.functions.version:version_equal', - 'version_match = st2common.expressions.functions.version:version_match', - 'version_bump_major = st2common.expressions.functions.version:version_bump_major', - 'version_bump_minor = st2common.expressions.functions.version:version_bump_minor', - 'version_bump_patch = st2common.expressions.functions.version:version_bump_patch', - 'version_strip_patch = st2common.expressions.functions.version:version_strip_patch', - 'yaml_dump = st2common.expressions.functions.data:to_yaml_string', - 'yaml_parse = st2common.expressions.functions.data:from_yaml_string' + "orquesta.expressions.functions": [ + "st2kv = orquesta_functions.st2kv:st2kv_", + "task = orquesta_functions.runtime:task", + "basename = st2common.expressions.functions.path:basename", + "dirname = st2common.expressions.functions.path:dirname", + "from_json_string = st2common.expressions.functions.data:from_json_string", + "from_yaml_string = st2common.expressions.functions.data:from_yaml_string", + "json_dump = st2common.expressions.functions.data:to_json_string", + "json_parse = st2common.expressions.functions.data:from_json_string", + "json_escape = st2common.expressions.functions.data:json_escape", + "jsonpath_query = st2common.expressions.functions.data:jsonpath_query", + "regex_match = st2common.expressions.functions.regex:regex_match", + "regex_replace = st2common.expressions.functions.regex:regex_replace", + "regex_search = st2common.expressions.functions.regex:regex_search", + "regex_substring = st2common.expressions.functions.regex:regex_substring", + ( + "to_human_time_from_seconds = " + "st2common.expressions.functions.time:to_human_time_from_seconds" + ), + "to_json_string = st2common.expressions.functions.data:to_json_string", + "to_yaml_string = st2common.expressions.functions.data:to_yaml_string", + "use_none = st2common.expressions.functions.data:use_none", + "version_compare = st2common.expressions.functions.version:version_compare", + "version_more_than = st2common.expressions.functions.version:version_more_than", + "version_less_than = st2common.expressions.functions.version:version_less_than", + "version_equal = st2common.expressions.functions.version:version_equal", + "version_match = st2common.expressions.functions.version:version_match", + "version_bump_major = st2common.expressions.functions.version:version_bump_major", + "version_bump_minor = st2common.expressions.functions.version:version_bump_minor", + "version_bump_patch = st2common.expressions.functions.version:version_bump_patch", + "version_strip_patch = st2common.expressions.functions.version:version_strip_patch", + "yaml_dump = st2common.expressions.functions.data:to_yaml_string", + "yaml_parse = st2common.expressions.functions.data:from_yaml_string", ], - } + }, ) diff --git a/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py b/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py index 8734bce072c..0e273f6e83b 100644 --- a/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py +++ b/contrib/runners/orquesta_runner/tests/integration/test_wiring_functions_st2kv.py @@ -21,78 +21,67 @@ class DatastoreFunctionTest(base.TestWorkflowExecution): @classmethod - def set_kvp(cls, name, value, scope='system', secret=False): + def set_kvp(cls, name, value, scope="system", secret=False): kvp = models.KeyValuePair( - id=name, - name=name, - value=value, - scope=scope, - secret=secret + id=name, name=name, value=value, scope=scope, secret=secret ) cls.st2client.keys.update(kvp) @classmethod - def del_kvp(cls, name, scope='system'): - kvp = models.KeyValuePair( - id=name, - name=name, - scope=scope - ) + def del_kvp(cls, name, scope="system"): + kvp = models.KeyValuePair(id=name, name=name, scope=scope) cls.st2client.keys.delete(kvp) def test_st2kv_system_scope(self): - key = 'lakshmi' - value = 'kanahansnasnasdlsajks' + key = "lakshmi" + value = "kanahansnasnasdlsajks" self.set_kvp(key, value) - wf_name = 'examples.orquesta-st2kv' - wf_input = {'key_name': 'system.%s' % key} + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": "system.%s" % key} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) self.del_kvp(key) def test_st2kv_user_scope(self): - key = 'winson' - value = 'SoDiamondEng' + key = "winson" + value = "SoDiamondEng" - self.set_kvp(key, value, 'user') - wf_name = 'examples.orquesta-st2kv' - wf_input = {'key_name': key} + self.set_kvp(key, value, "user") + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": key} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) # self.del_kvp(key) def test_st2kv_decrypt(self): - key = 'kami' - value = 'eggplant' + key = "kami" + value = "eggplant" self.set_kvp(key, value, secret=True) - wf_name = 'examples.orquesta-st2kv' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True - } + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": "system.%s" % key, "decrypt": True} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) self.del_kvp(key) diff --git a/contrib/runners/orquesta_runner/tests/unit/base.py b/contrib/runners/orquesta_runner/tests/unit/base.py index dbd28957210..d3e518fab7a 100644 --- a/contrib/runners/orquesta_runner/tests/unit/base.py +++ b/contrib/runners/orquesta_runner/tests/unit/base.py @@ -19,13 +19,13 @@ def get_wf_fixture_meta_data(fixture_pack_path, wf_meta_file_name): - wf_meta_file_path = fixture_pack_path + '/actions/' + wf_meta_file_name + wf_meta_file_path = fixture_pack_path + "/actions/" + wf_meta_file_name wf_meta_content = loader.load_meta_file(wf_meta_file_path) - wf_name = wf_meta_content['pack'] + '.' + wf_meta_content['name'] + wf_name = wf_meta_content["pack"] + "." + wf_meta_content["name"] return { - 'file_name': wf_meta_file_name, - 'file_path': wf_meta_file_path, - 'content': wf_meta_content, - 'name': wf_name + "file_name": wf_meta_file_name, + "file_path": wf_meta_file_path, + "content": wf_meta_content, + "name": wf_name, } diff --git a/contrib/runners/orquesta_runner/tests/unit/test_basic.py b/contrib/runners/orquesta_runner/tests/unit/test_basic.py index 7fc2255ed27..5f5c60a0129 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_basic.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_basic.py @@ -26,6 +26,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -51,37 +52,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerTest, cls).setUpClass() @@ -91,8 +100,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -103,14 +111,15 @@ def get_runner_class(cls, runner_name): return runners.get_runner(runner_name, runner_name).__class__ @mock.patch.object( - runners_utils, - 'invoke_post_run', - mock.MagicMock(return_value=None)) + runners_utils, "invoke_post_run", mock.MagicMock(return_value=None) + ) def test_run_workflow(self): - username = 'stanley' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + username = "stanley" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # The main action execution for this workflow is not under the context of another workflow. @@ -120,9 +129,13 @@ def test_run_workflow(self): lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertTrue(lv_ac_db.action_is_workflow) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) wf_ex_db = wf_ex_dbs[0] # Check required attributes. @@ -134,26 +147,24 @@ def test_run_workflow(self): # Check context in the workflow execution. expected_wf_ex_ctx = { - 'st2': { - 'workflow_execution_id': str(wf_ex_db.id), - 'action_execution_id': str(ac_ex_db.id), - 'api_url': 'http://127.0.0.1/v1', - 'user': username, - 'pack': 'orquesta_tests', - 'action': 'orquesta_tests.sequential', - 'runner': 'orquesta' + "st2": { + "workflow_execution_id": str(wf_ex_db.id), + "action_execution_id": str(ac_ex_db.id), + "api_url": "http://127.0.0.1/v1", + "user": username, + "pack": "orquesta_tests", + "action": "orquesta_tests.sequential", + "runner": "orquesta", }, - 'parent': { - 'pack': 'orquesta_tests' - } + "parent": {"pack": "orquesta_tests"}, } self.assertDictEqual(wf_ex_db.context, expected_wf_ex_ctx) # Check context in the liveaction. expected_lv_ac_ctx = { - 'workflow_execution': str(wf_ex_db.id), - 'pack': 'orquesta_tests' + "workflow_execution": str(wf_ex_db.id), + "pack": "orquesta_tests", } self.assertDictEqual(lv_ac_db.context, expected_lv_ac_ctx) @@ -161,24 +172,26 @@ def test_run_workflow(self): # Check graph. self.assertIsNotNone(wf_ex_db.graph) self.assertIsInstance(wf_ex_db.graph, dict) - self.assertIn('nodes', wf_ex_db.graph) - self.assertIn('adjacency', wf_ex_db.graph) + self.assertIn("nodes", wf_ex_db.graph) + self.assertIn("adjacency", wf_ex_db.graph) # Check task states. self.assertIsNotNone(wf_ex_db.state) self.assertIsInstance(wf_ex_db.state, dict) - self.assertIn('tasks', wf_ex_db.state) - self.assertIn('sequence', wf_ex_db.state) + self.assertIn("tasks", wf_ex_db.state) + self.assertIn("sequence", wf_ex_db.state) # Check input. self.assertDictEqual(wf_ex_db.input, wf_input) # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.context.get('user'), username) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual(tk1_lv_ac_db.context.get("user"), username) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk1_ac_ex_db)) @@ -192,11 +205,13 @@ def test_run_workflow(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task2 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) - self.assertEqual(tk2_lv_ac_db.context.get('user'), username) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) + self.assertEqual(tk2_lv_ac_db.context.get("user"), username) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk2_ac_ex_db)) @@ -210,11 +225,13 @@ def test_run_workflow(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task3 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) - self.assertEqual(tk3_lv_ac_db.context.get('user'), username) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) + self.assertEqual(tk3_lv_ac_db.context.get("user"), username) self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk3_ac_ex_db)) @@ -234,48 +251,60 @@ def test_run_workflow(self): self.assertEqual(runners_utils.invoke_post_run.call_count, 1) # Check workflow output. - expected_output = {'msg': '%s, All your base are belong to us!' % wf_input['who']} + expected_output = { + "msg": "%s, All your base are belong to us!" % wf_input["who"] + } self.assertDictEqual(wf_ex_db.output, expected_output) # Check liveaction and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(lv_ac_db.result, expected_result) self.assertDictEqual(ac_ex_db.result, expected_result) def test_run_workflow_with_unicode_input(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': '薩諾斯'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "薩諾斯"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) # Process task2. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk2_ac_ex_db) tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id) self.assertEqual(tk2_ex_db.status, wf_statuses.SUCCEEDED) # Process task3. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk3_ac_ex_db) tk3_ex_db = wf_db_access.TaskExecution.get_by_id(tk3_ex_db.id) @@ -290,33 +319,41 @@ def test_run_workflow_with_unicode_input(self): self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Check workflow output. - wf_input_val = wf_input['who'].decode('utf-8') if six.PY2 else wf_input['who'] - expected_output = {'msg': '%s, All your base are belong to us!' % wf_input_val} + wf_input_val = wf_input["who"].decode("utf-8") if six.PY2 else wf_input["who"] + expected_output = {"msg": "%s, All your base are belong to us!" % wf_input_val} self.assertDictEqual(wf_ex_db.output, expected_output) # Check liveaction and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(lv_ac_db.result, expected_result) self.assertDictEqual(ac_ex_db.result, expected_result) def test_run_workflow_action_config_context(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'config-context.yaml') + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "config-context.yaml") wf_input = {} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk1_ac_ex_db)) @@ -332,59 +369,77 @@ def test_run_workflow_action_config_context(self): self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Verify config_context works - self.assertEqual(wf_ex_db.output, {'msg': 'value of config key a'}) + self.assertEqual(wf_ex_db.output, {"msg": "value of config key a"}) def test_run_workflow_with_action_less_tasks(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'action-less-tasks.yaml') - wf_input = {'name': 'Thanos'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "action-less-tasks.yaml" + ) + wf_input = {"name": "Thanos"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id)) + tk1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + ) self.assertEqual(len(tk1_ac_ex_dbs), 0) self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) # Assert task2 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. wf_svc.handle_action_execution_completion(tk2_ac_ex_db) # Assert task3 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. wf_svc.handle_action_execution_completion(tk3_ac_ex_db) # Assert task4 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task4'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task4"} tk4_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk4_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk4_ex_db.id)) + tk4_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk4_ex_db.id) + ) self.assertEqual(len(tk4_ac_ex_dbs), 0) self.assertEqual(tk4_ex_db.status, wf_statuses.SUCCEEDED) # Assert task5 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task5'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task5"} tk5_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk5_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk5_ex_db.id))[0] - tk5_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk5_ac_ex_db.liveaction['id']) + tk5_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk5_ex_db.id) + )[0] + tk5_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk5_ac_ex_db.liveaction["id"]) self.assertEqual(tk5_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. @@ -399,65 +454,95 @@ def test_run_workflow_with_action_less_tasks(self): self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Check workflow output. - expected_output = {'greeting': '%s, All your base are belong to us!' % wf_input['name']} - expected_output['greeting'] = expected_output['greeting'].upper() + expected_output = { + "greeting": "%s, All your base are belong to us!" % wf_input["name"] + } + expected_output["greeting"] = expected_output["greeting"].upper() self.assertDictEqual(wf_ex_db.output, expected_output) # Check liveaction and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(lv_ac_db.result, expected_result) self.assertDictEqual(ac_ex_db.result, expected_result) @mock.patch.object( - pc_svc, 'apply_post_run_policies', - mock.MagicMock(return_value=None)) + pc_svc, "apply_post_run_policies", mock.MagicMock(return_value=None) + ) def test_handle_action_execution_completion(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) # Identify the records for the tasks. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Manually notify action execution completion for the tasks. # Assert policies are not applied in the notifier. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] notifier.get_notifier().process(t1_t1_ac_ex_db) self.assertFalse(pc_svc.apply_post_run_policies.called) - t1_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id)) + t1_tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + ) self.assertEqual(len(t1_tk_ex_dbs), 1) workflows.get_engine().process(t1_t1_ac_ex_db) self.assertTrue(pc_svc.apply_post_run_policies.called) pc_svc.apply_post_run_policies.reset_mock() - t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[1] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] notifier.get_notifier().process(t1_t2_ac_ex_db) self.assertFalse(pc_svc.apply_post_run_policies.called) - t1_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id)) + t1_tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + ) self.assertEqual(len(t1_tk_ex_dbs), 2) workflows.get_engine().process(t1_t2_ac_ex_db) self.assertTrue(pc_svc.apply_post_run_policies.called) pc_svc.apply_post_run_policies.reset_mock() - t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[2] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] notifier.get_notifier().process(t1_t3_ac_ex_db) self.assertFalse(pc_svc.apply_post_run_policies.called) - t1_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id)) + t1_tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + ) self.assertEqual(len(t1_tk_ex_dbs), 3) workflows.get_engine().process(t1_t3_ac_ex_db) self.assertTrue(pc_svc.apply_post_run_policies.called) @@ -466,19 +551,25 @@ def test_handle_action_execution_completion(self): t1_ac_ex_db = ex_db_access.ActionExecution.get_by_id(t1_ac_ex_db.id) notifier.get_notifier().process(t1_ac_ex_db) self.assertFalse(pc_svc.apply_post_run_policies.called) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) workflows.get_engine().process(t1_ac_ex_db) self.assertTrue(pc_svc.apply_post_run_policies.called) pc_svc.apply_post_run_policies.reset_mock() - t2_ex_db_qry = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + t2_ex_db_qry = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} t2_ex_db = wf_db_access.TaskExecution.query(**t2_ex_db_qry)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] self.assertEqual(t2_ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) notifier.get_notifier().process(t2_ac_ex_db) self.assertFalse(pc_svc.apply_post_run_policies.called) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) workflows.get_engine().process(t2_ac_ex_db) self.assertTrue(pc_svc.apply_post_run_policies.called) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_cancel.py b/contrib/runners/orquesta_runner/tests/unit/test_cancel.py index 145bd1f3b44..b49fd0f77bf 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_cancel.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_cancel.py @@ -24,6 +24,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -45,37 +46,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerCancelTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerCancelTest, cls).setUpClass() @@ -85,8 +94,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -96,15 +104,15 @@ def setUpClass(cls): def get_runner_class(cls, runner_name): return runners.get_runner(runner_name, runner_name).__class__ - @mock.patch.object( - ac_svc, 'is_children_active', - mock.MagicMock(return_value=True)) + @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=True)) def test_cancel(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) requester = cfg.CONF.system_user.user lv_ac_db, ac_ex_db = ac_svc.request_cancellation(lv_ac_db, requester) @@ -112,23 +120,33 @@ def test_cancel(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELING) def test_cancel_workflow_cascade_down_to_subworkflow(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the subworkflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 1) - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id)) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + ) self.assertEqual(len(tk_ac_ex_dbs), 1) - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id']) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the main workflow. @@ -145,23 +163,33 @@ def test_cancel_workflow_cascade_down_to_subworkflow(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) def test_cancel_subworkflow_cascade_up_to_workflow(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the subworkflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 1) - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id)) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + ) self.assertEqual(len(tk_ac_ex_dbs), 1) - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id']) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the subworkflow. @@ -183,34 +211,50 @@ def test_cancel_subworkflow_cascade_up_to_workflow(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) def test_cancel_subworkflow_cascade_up_to_workflow_with_other_subworkflows(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the subworkflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 2) - tk1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id)) + tk1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + ) self.assertEqual(len(tk1_ac_ex_dbs), 1) - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_dbs[0].liveaction['id']) + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk1_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) - tk2_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id)) + tk2_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + ) self.assertEqual(len(tk2_ac_ex_dbs), 1) - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_dbs[0].liveaction['id']) + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk2_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the subworkflow which should cascade up to the root. requester = cfg.CONF.system_user.user - tk1_lv_ac_db, tk1_ac_ex_db = ac_svc.request_cancellation(tk1_lv_ac_db, requester) + tk1_lv_ac_db, tk1_ac_ex_db = ac_svc.request_cancellation( + tk1_lv_ac_db, requester + ) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELING) # Assert the main workflow is canceling. @@ -239,15 +283,21 @@ def test_cancel_subworkflow_cascade_up_to_workflow_with_other_subworkflows(self) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) def test_cancel_before_wf_ex_db_created(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Delete the workfow execution to mock issue where the record has not been created yet. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False, dispatch_trigger=False) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + wf_db_access.WorkflowExecution.delete( + wf_ex_db, publish=False, dispatch_trigger=False + ) # Cancel the action execution. requester = cfg.CONF.system_user.user @@ -256,15 +306,19 @@ def test_cancel_before_wf_ex_db_created(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) def test_cancel_after_wf_ex_db_completed(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Delete the workfow execution to mock issue where the workflow is already completed # but the liveaction and action execution have not had time to be updated. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] wf_ex_db.status = wf_ex_statuses.SUCCEEDED wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False) @@ -275,14 +329,16 @@ def test_cancel_after_wf_ex_db_completed(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) @mock.patch.object( - wf_svc, 'request_cancellation', - mock.MagicMock(side_effect=Exception('foobar'))) + wf_svc, "request_cancellation", mock.MagicMock(side_effect=Exception("foobar")) + ) def test_cancel_unexpected_exception(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Cancel the action execution. requester = cfg.CONF.system_user.user @@ -297,4 +353,6 @@ def test_cancel_unexpected_exception(self): # to raise an exception and the records will be stuck in a canceling # status and user is unable to easily clean up. self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_CANCELED) - self.assertIn('Error encountered when canceling', lv_ac_db.result.get('error', '')) + self.assertIn( + "Error encountered when canceling", lv_ac_db.result.get("error", "") + ) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_context.py b/contrib/runners/orquesta_runner/tests/unit/test_context.py index 373f512e87d..bce5a508733 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_context.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_context.py @@ -24,6 +24,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -43,37 +44,45 @@ from st2tests.mocks import liveaction as mock_lv_ac_xport from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaContextTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaContextTest, cls).setUpClass() @@ -83,24 +92,31 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def test_runtime_context(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'runtime-context.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "runtime-context.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] # Complete the worklfow. wf_svc.handle_action_execution_completion(t1_ac_ex_db) @@ -113,59 +129,75 @@ def test_runtime_context(self): # Check result. expected_st2_ctx = { - 'action_execution_id': str(ac_ex_db.id), - 'api_url': 'http://127.0.0.1/v1', - 'user': 'stanley', - 'pack': 'orquesta_tests', - 'action': 'orquesta_tests.runtime-context', - 'runner': 'orquesta' + "action_execution_id": str(ac_ex_db.id), + "api_url": "http://127.0.0.1/v1", + "user": "stanley", + "pack": "orquesta_tests", + "action": "orquesta_tests.runtime-context", + "runner": "orquesta", } expected_st2_ctx_with_wf_ex_id = copy.deepcopy(expected_st2_ctx) - expected_st2_ctx_with_wf_ex_id['workflow_execution_id'] = str(wf_ex_db.id) + expected_st2_ctx_with_wf_ex_id["workflow_execution_id"] = str(wf_ex_db.id) expected_output = { - 'st2_ctx_at_input': expected_st2_ctx, - 'st2_ctx_at_vars': expected_st2_ctx, - 'st2_ctx_at_publish': expected_st2_ctx_with_wf_ex_id, - 'st2_ctx_at_output': expected_st2_ctx_with_wf_ex_id + "st2_ctx_at_input": expected_st2_ctx, + "st2_ctx_at_vars": expected_st2_ctx, + "st2_ctx_at_publish": expected_st2_ctx_with_wf_ex_id, + "st2_ctx_at_output": expected_st2_ctx_with_wf_ex_id, } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(lv_ac_db.result, expected_result) def test_action_context_sys_user(self): - wf_name = 'subworkflow-default-value-from-action-context' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_name = "subworkflow-default-value-from-action-context" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Complete subworkflow under task1. - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"} t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"} t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"} t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db) t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id)) @@ -184,44 +216,60 @@ def test_action_context_sys_user(self): # Check result. expected_result = { - 'output': { - 'msg': 'stanley, All your base are belong to us!' - } + "output": {"msg": "stanley, All your base are belong to us!"} } self.assertDictEqual(lv_ac_db.result, expected_result) def test_action_context_api_user(self): - wf_name = 'subworkflow-default-value-from-action-context' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], context={'api_user': 'Thanos'}) + wf_name = "subworkflow-default-value-from-action-context" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], context={"api_user": "Thanos"} + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Complete subworkflow under task1. - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"} t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"} t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"} t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db) t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id)) @@ -239,45 +287,57 @@ def test_action_context_api_user(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Check result. - expected_result = { - 'output': { - 'msg': 'Thanos, All your base are belong to us!' - } - } + expected_result = {"output": {"msg": "Thanos, All your base are belong to us!"}} self.assertDictEqual(lv_ac_db.result, expected_result) def test_action_context_no_channel(self): - wf_name = 'subworkflow-source-channel-from-action-context' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_name = "subworkflow-source-channel-from-action-context" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Complete subworkflow under task1. - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"} t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"} t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"} t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db) t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id)) @@ -296,45 +356,60 @@ def test_action_context_no_channel(self): # Check result. expected_result = { - 'output': { - 'msg': 'no_channel, All your base are belong to us!' - } + "output": {"msg": "no_channel, All your base are belong to us!"} } self.assertDictEqual(lv_ac_db.result, expected_result) def test_action_context_source_channel(self): - wf_name = 'subworkflow-source-channel-from-action-context' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], - context={'source_channel': 'general'}) + wf_name = "subworkflow-source-channel-from-action-context" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], context={"source_channel": "general"} + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Complete subworkflow under task1. - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task1"} t1_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t1_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task2"} t1_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t2_ac_ex_db) - query_filters = {'workflow_execution': str(t1_wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(t1_wf_ex_db.id), "task_id": "task3"} t1_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(t1_t3_ac_ex_db) t1_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t1_wf_ex_db.id)) @@ -353,9 +428,7 @@ def test_action_context_source_channel(self): # Check result. expected_result = { - 'output': { - 'msg': 'general, All your base are belong to us!' - } + "output": {"msg": "general, All your base are belong to us!"} } self.assertDictEqual(lv_ac_db.result, expected_result) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py b/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py index 00d26f01551..d1c0c249aba 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_data_flow.py @@ -26,6 +26,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -47,37 +48,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerTest, cls).setUpClass() @@ -87,8 +96,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -99,22 +107,30 @@ def get_runner_class(cls, runner_name): return runners.get_runner(runner_name, runner_name).__class__ def assert_data_flow(self, data): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'data-flow.yaml') - wf_input = {'a1': data} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "data-flow.yaml") + wf_input = {"a1": data} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. @@ -127,10 +143,12 @@ def assert_data_flow(self, data): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task2 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. @@ -143,10 +161,12 @@ def assert_data_flow(self, data): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task3 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) self.assertEqual(tk3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion. @@ -164,20 +184,20 @@ def assert_data_flow(self, data): # Check workflow output. expected_output = { - 'a5': wf_input['a1'] if six.PY3 else wf_input['a1'].decode('utf-8'), - 'b5': wf_input['a1'] if six.PY3 else wf_input['a1'].decode('utf-8') + "a5": wf_input["a1"] if six.PY3 else wf_input["a1"].decode("utf-8"), + "b5": wf_input["a1"] if six.PY3 else wf_input["a1"].decode("utf-8"), } self.assertDictEqual(wf_ex_db.output, expected_output) # Check liveaction and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(lv_ac_db.result, expected_result) self.assertDictEqual(ac_ex_db.result, expected_result) def test_string(self): - self.assert_data_flow('xyz') + self.assert_data_flow("xyz") def test_unicode_string(self): - self.assert_data_flow('床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉') + self.assert_data_flow("床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉") diff --git a/contrib/runners/orquesta_runner/tests/unit/test_delay.py b/contrib/runners/orquesta_runner/tests/unit/test_delay.py index 66834f9952c..d2535c8f03e 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_delay.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_delay.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -43,37 +44,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerDelayTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerDelayTest, cls).setUpClass() @@ -83,8 +92,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -94,17 +102,25 @@ def test_delay(self): expected_delay_sec = 1 expected_delay_msec = expected_delay_sec * 1000 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'delay.yaml') - wf_input = {'delay': expected_delay_sec} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "delay.yaml") + wf_input = {"delay": expected_delay_sec} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) - lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING) + lv_ac_db = self._wait_on_status( + lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING + ) # Identify records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))[0] # Assert delay value is rendered and assigned. @@ -116,20 +132,28 @@ def test_delay_for_with_items(self): expected_delay_sec = 1 expected_delay_msec = expected_delay_sec * 1000 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-delay.yaml') - wf_input = {'delay': expected_delay_sec} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "with-items-delay.yaml") + wf_input = {"delay": expected_delay_sec} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. - lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + lv_ac_db = self._wait_on_status( + lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process the with items task. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) t1_lv_ac_dbs = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id)) # Assert delay value is rendered and assigned. @@ -166,20 +190,30 @@ def test_delay_for_with_items_concurrency(self): expected_delay_sec = 1 expected_delay_msec = expected_delay_sec * 1000 - wf_input = {'concurrency': concurrency, 'delay': expected_delay_sec} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency-delay.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"concurrency": concurrency, "delay": expected_delay_sec} + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency-delay.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. - lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + lv_ac_db = self._wait_on_status( + lv_ac_db, action_constants.LIVEACTION_STATUS_RUNNING + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process the first set of action executions from with items concurrency. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) t1_lv_ac_dbs = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id)) # Assert the number of concurrent items is correct. @@ -211,7 +245,9 @@ def test_delay_for_with_items_concurrency(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Process the second set of action executions from with items concurrency. - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) t1_lv_ac_dbs = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id)) # Assert delay value is rendered and assigned only to the first set of action executions. diff --git a/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py b/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py index 6f140040caa..d06d3359937 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_error_handling.py @@ -24,6 +24,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -48,41 +49,50 @@ from st2common.models.db.execution_queue import ActionExecutionSchedulingQueueItemDB -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaErrorHandlingTest(st2tests.WorkflowTestCase): ensure_indexes = True ensure_indexes_models = [ WorkflowExecutionDB, TaskExecutionDB, - ActionExecutionSchedulingQueueItemDB + ActionExecutionSchedulingQueueItemDB, ] @classmethod @@ -94,8 +104,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -104,78 +113,86 @@ def setUpClass(cls): def test_fail_inspection(self): expected_errors = [ { - 'type': 'content', - 'message': 'The action "std.noop" is not registered in the database.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action', - 'spec_path': 'tasks.task3.action' + "type": "content", + "message": 'The action "std.noop" is not registered in the database.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action", + "spec_path": "tasks.task3.action", }, { - 'type': 'context', - 'language': 'yaql', - 'expression': '<% ctx().foobar %>', - 'message': 'Variable "foobar" is referenced before assignment.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input', - 'spec_path': 'tasks.task1.input', + "type": "context", + "language": "yaql", + "expression": "<% ctx().foobar %>", + "message": 'Variable "foobar" is referenced before assignment.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input", + "spec_path": "tasks.task1.input", }, { - 'type': 'expression', - 'language': 'yaql', - 'expression': '<% <% succeeded() %>', - 'message': ( - 'Parse error: unexpected \'<\' at ' - 'position 0 of expression \'<% succeeded()\'' + "type": "expression", + "language": "yaql", + "expression": "<% <% succeeded() %>", + "message": ( + "Parse error: unexpected '<' at " + "position 0 of expression '<% succeeded()'" ), - 'schema_path': ( - r'properties.tasks.patternProperties.^\w+$.' - 'properties.next.items.properties.when' + "schema_path": ( + r"properties.tasks.patternProperties.^\w+$." + "properties.next.items.properties.when" ), - 'spec_path': 'tasks.task2.next[0].when' + "spec_path": "tasks.task2.next[0].when", }, { - 'type': 'syntax', - 'message': ( - '[{\'cmd\': \'echo <% ctx().macro %>\'}] is ' - 'not valid under any of the given schemas' + "type": "syntax", + "message": ( + "[{'cmd': 'echo <% ctx().macro %>'}] is " + "not valid under any of the given schemas" ), - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input.oneOf', - 'spec_path': 'tasks.task2.input' - } + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input.oneOf", + "spec_path": "tasks.task2.input", + }, ] - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-inspection.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "fail-inspection.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertIn('errors', lv_ac_db.result) - self.assertListEqual(lv_ac_db.result['errors'], expected_errors) + self.assertIn("errors", lv_ac_db.result) + self.assertListEqual(lv_ac_db.result["errors"], expected_errors) def test_fail_input_rendering(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(4).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(4).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-input-rendering.yaml') + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-input-rendering.yaml" + ) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution for task is not started and workflow failed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 0) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -188,28 +205,36 @@ def test_fail_input_rendering(self): def test_fail_vars_rendering(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(4).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(4).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-vars-rendering.yaml') + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-vars-rendering.yaml" + ) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution for task is not started and workflow failed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 0) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -222,30 +247,38 @@ def test_fail_vars_rendering(self): def test_fail_start_task_action(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% ctx().func.value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% ctx().func.value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' ), - 'task_id': 'task1', - 'route': 0 + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-start-task-action.yaml') + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-start-task-action.yaml" + ) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution for task is not started and workflow failed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 0) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -258,31 +291,37 @@ def test_fail_start_task_action(self): def test_fail_start_task_input_expr_eval(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% ctx().msg1.value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% ctx().msg1.value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' ), - 'task_id': 'task1', - 'route': 0 + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_file = 'fail-start-task-input-expr-eval.yaml' + wf_file = "fail-start-task-input-expr-eval.yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution for task is not started and workflow failed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 0) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -294,37 +333,40 @@ def test_fail_start_task_input_expr_eval(self): def test_fail_start_task_input_value_type(self): if six.PY3: - msg = 'Value "{\'x\': \'foobar\'}" must either be a string or None. Got "dict".' + msg = "Value \"{'x': 'foobar'}\" must either be a string or None. Got \"dict\"." else: - msg = 'Value "{u\'x\': u\'foobar\'}" must either be a string or None. Got "dict".' + msg = "Value \"{u'x': u'foobar'}\" must either be a string or None. Got \"dict\"." - msg = 'ValueError: ' + msg + msg = "ValueError: " + msg expected_errors = [ - { - 'type': 'error', - 'message': msg, - 'task_id': 'task1', - 'route': 0 - } + {"type": "error", "message": msg, "task_id": "task1", "route": 0} ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_file = 'fail-start-task-input-value-type.yaml' + wf_file = "fail-start-task-input-value-type.yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - wf_input = {'var1': {'x': 'foobar'}} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"var1": {"x": "foobar"}} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert workflow and task executions failed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] self.assertEqual(tk_ex_db.status, wf_statuses.FAILED) - self.assertDictEqual(tk_ex_db.result, {'errors': expected_errors}) + self.assertDictEqual(tk_ex_db.result, {"errors": expected_errors}) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -337,29 +379,35 @@ def test_fail_start_task_input_value_type(self): def test_fail_next_task_action(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% ctx().func.value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% ctx().func.value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' ), - 'task_id': 'task2', - 'route': 0 + "task_id": "task2", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-action.yaml') + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "fail-task-action.yaml") - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion for task1 which has an error in publish. @@ -370,7 +418,9 @@ def test_fail_next_task_action(self): self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -383,29 +433,37 @@ def test_fail_next_task_action(self): def test_fail_next_task_input_expr_eval(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% ctx().msg2.value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% ctx().msg2.value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' ), - 'task_id': 'task2', - 'route': 0 + "task_id": "task2", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-input-expr-eval.yaml') + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-task-input-expr-eval.yaml" + ) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion for task1 which has an error in publish. @@ -416,7 +474,9 @@ def test_fail_next_task_input_expr_eval(self): self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -428,34 +488,37 @@ def test_fail_next_task_input_expr_eval(self): def test_fail_next_task_input_value_type(self): if six.PY3: - msg = 'Value "{\'x\': \'foobar\'}" must either be a string or None. Got "dict".' + msg = "Value \"{'x': 'foobar'}\" must either be a string or None. Got \"dict\"." else: - msg = 'Value "{u\'x\': u\'foobar\'}" must either be a string or None. Got "dict".' + msg = "Value \"{u'x': u'foobar'}\" must either be a string or None. Got \"dict\"." - msg = 'ValueError: ' + msg + msg = "ValueError: " + msg expected_errors = [ - { - 'type': 'error', - 'message': msg, - 'task_id': 'task2', - 'route': 0 - } + {"type": "error", "message": msg, "task_id": "task2", "route": 0} ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_file = 'fail-task-input-value-type.yaml' + wf_file = "fail-task-input-value-type.yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - wf_input = {'var1': {'x': 'foobar'}} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"var1": {"x": "foobar"}} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed and workflow execution is still running. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) @@ -465,11 +528,13 @@ def test_fail_next_task_input_value_type(self): # Assert workflow execution and task2 execution failed. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(wf_ex_db.id)) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) - tk2_ex_db = wf_db_access.TaskExecution.query(task_id='task2')[0] + tk2_ex_db = wf_db_access.TaskExecution.query(task_id="task2")[0] self.assertEqual(tk2_ex_db.status, wf_statuses.FAILED) - self.assertDictEqual(tk2_ex_db.result, {'errors': expected_errors}) + self.assertDictEqual(tk2_ex_db.result, {"errors": expected_errors}) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -482,37 +547,47 @@ def test_fail_next_task_input_value_type(self): def test_fail_task_execution(self): expected_errors = [ { - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'task_id': 'task1', - 'result': { - 'stdout': '', - 'stderr': 'boom!', - 'return_code': 1, - 'failed': True, - 'succeeded': False - } + "type": "error", + "message": "Execution failed. See result for details.", + "task_id": "task1", + "result": { + "stdout": "", + "stderr": "boom!", + "return_code": 1, + "failed": True, + "succeeded": False, + }, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-execution.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-task-execution.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Process task1. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) # Assert workflow state and result. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(wf_ex_db.id)) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -525,28 +600,36 @@ def test_fail_task_execution(self): def test_fail_task_transition(self): expected_errors = [ { - 'type': 'error', - 'message': ( + "type": "error", + "message": ( "YaqlEvaluationException: Unable to resolve key 'foobar' in expression " "'<% succeeded() and result().foobar %>' from context." ), - 'task_transition_id': 'task2__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "task2__t0", + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-transition.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-task-transition.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion for task1 which has an error in publish. @@ -557,7 +640,9 @@ def test_fail_task_transition(self): self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -570,29 +655,37 @@ def test_fail_task_transition(self): def test_fail_task_publish(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% foobar() %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% foobar() %>'. NoFunctionRegisteredException: " 'Unknown function "foobar"' ), - 'task_transition_id': 'task2__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "task2__t0", + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-task-publish.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-task-publish.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion for task1 which has an error in publish. @@ -603,7 +696,9 @@ def test_fail_task_publish(self): self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -616,26 +711,34 @@ def test_fail_task_publish(self): def test_fail_output_rendering(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(4).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(4).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - expected_result = {'output': None, 'errors': expected_errors} + expected_result = {"output": None, "errors": expected_errors} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-output-rendering.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-output-rendering.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert task1 is already completed. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually handle action execution completion for task1 which has an error in publish. @@ -646,7 +749,9 @@ def test_fail_output_rendering(self): self.assertEqual(tk_ex_db.status, wf_statuses.SUCCEEDED) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) @@ -657,50 +762,51 @@ def test_fail_output_rendering(self): self.assertDictEqual(ac_ex_db.result, expected_result) def test_output_on_error(self): - expected_output = { - 'progress': 25 - } + expected_output = {"progress": 25} expected_errors = [ { - 'type': 'error', - 'task_id': 'task2', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "type": "error", + "task_id": "task2", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, } ] - expected_result = { - 'errors': expected_errors, - 'output': expected_output - } + expected_result = {"errors": expected_errors, "output": expected_output} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'output-on-error.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "output-on-error.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Assert task1 is already completed and workflow execution is still running. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task2 is already completed and workflow execution has failed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.handle_action_execution_completion(tk2_ac_ex_db) @@ -718,26 +824,32 @@ def test_output_on_error(self): self.assertDictEqual(ac_ex_db.result, expected_result) def test_fail_manually(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-manually.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "fail-manually.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Assert task1 and workflow execution failed due to fail in the task transition. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.FAILED) # Assert log task is scheduled even though the workflow execution failed manually. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'log'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "log"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk2_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) @@ -746,38 +858,44 @@ def test_fail_manually(self): # Check errors and output. expected_errors = [ { - 'task_id': 'fail', - 'type': 'error', - 'message': 'Execution failed. See result for details.' + "task_id": "fail", + "type": "error", + "message": "Execution failed. See result for details.", }, { - 'task_id': 'task1', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } - } + "task_id": "task1", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, + }, ] - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) def test_fail_manually_with_recovery_failure(self): - wf_file = 'fail-manually-with-recovery-failure.yaml' + wf_file = "fail-manually-with-recovery-failure.yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Assert task1 and workflow execution failed due to fail in the task transition. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) @@ -785,10 +903,12 @@ def test_fail_manually_with_recovery_failure(self): # Assert recover task is scheduled even though the workflow execution failed manually. # The recover task in the workflow is setup to fail. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'recover'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "recover"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) self.assertEqual(tk2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.handle_action_execution_completion(tk2_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) @@ -797,61 +917,70 @@ def test_fail_manually_with_recovery_failure(self): # Check errors and output. expected_errors = [ { - 'task_id': 'fail', - 'type': 'error', - 'message': 'Execution failed. See result for details.' + "task_id": "fail", + "type": "error", + "message": "Execution failed. See result for details.", }, { - 'task_id': 'recover', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "task_id": "recover", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, }, { - 'task_id': 'task1', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } - } + "task_id": "task1", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, + }, ] - self.assertListEqual(self.sort_workflow_errors(wf_ex_db.errors), expected_errors) + self.assertListEqual( + self.sort_workflow_errors(wf_ex_db.errors), expected_errors + ) @mock.patch.object( - runners_utils, - 'invoke_post_run', - mock.MagicMock(return_value=None)) + runners_utils, "invoke_post_run", mock.MagicMock(return_value=None) + ) def test_include_result_to_error_log(self): - username = 'stanley' - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + username = "stanley" + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) wf_ex_db = wf_ex_dbs[0] # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.context.get('user'), username) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual(tk1_lv_ac_db.context.get("user"), username) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Manually override and fail the action execution and write some result. @@ -862,11 +991,13 @@ def test_include_result_to_error_log(self): tk1_lv_ac_db, ac_const.LIVEACTION_STATUS_FAILED, result=result, - publish=False + publish=False, ) - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_FAILED) self.assertDictEqual(tk1_lv_ac_db.result, result) @@ -882,14 +1013,10 @@ def test_include_result_to_error_log(self): # Assert result is included in the error log. expected_errors = [ { - 'message': 'Execution failed. See result for details.', - 'type': 'error', - 'task_id': 'task1', - 'result': { - '127.0.0.1': { - 'hostname': 'foobar' - } - } + "message": "Execution failed. See result for details.", + "type": "error", + "task_id": "task1", + "result": {"127.0.0.1": {"hostname": "foobar"}}, } ] diff --git a/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py b/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py index d8c416f13a5..faa92bd03ab 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_functions_common.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -44,37 +45,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaFunctionTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaFunctionTest, cls).setUpClass() @@ -84,30 +93,35 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def _execute_workflow(self, wf_name, expected_output): - wf_file = wf_name + '.yaml' + wf_file = wf_name + ".yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert task1 is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk1_ac_ex_db)) @@ -123,149 +137,139 @@ def _execute_workflow(self, wf_name, expected_output): self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # Check workflow output, liveaction result, and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertDictEqual(wf_ex_db.output, expected_output) self.assertDictEqual(lv_ac_db.result, expected_result) self.assertDictEqual(ac_ex_db.result, expected_result) def test_data_functions_in_yaql(self): - wf_name = 'yaql-data-functions' + wf_name = "yaql-data-functions" expected_output = { - 'data_json_str_1': '{"foo": {"bar": "foobar"}}', - 'data_json_str_2': '{"foo": {"bar": "foobar"}}', - 'data_json_str_3': '{"foo": {"bar": "foobar"}}', - 'data_json_obj_1': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_2': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_3': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_4': {'foo': {'bar': 'foobar'}}, - 'data_yaml_str_1': 'foo:\n bar: foobar\n', - 'data_yaml_str_2': 'foo:\n bar: foobar\n', - 'data_query_1': ['foobar'], - 'data_none_str': data_funcs.NONE_MAGIC_VALUE, - 'data_str': 'foobar' + "data_json_str_1": '{"foo": {"bar": "foobar"}}', + "data_json_str_2": '{"foo": {"bar": "foobar"}}', + "data_json_str_3": '{"foo": {"bar": "foobar"}}', + "data_json_obj_1": {"foo": {"bar": "foobar"}}, + "data_json_obj_2": {"foo": {"bar": "foobar"}}, + "data_json_obj_3": {"foo": {"bar": "foobar"}}, + "data_json_obj_4": {"foo": {"bar": "foobar"}}, + "data_yaml_str_1": "foo:\n bar: foobar\n", + "data_yaml_str_2": "foo:\n bar: foobar\n", + "data_query_1": ["foobar"], + "data_none_str": data_funcs.NONE_MAGIC_VALUE, + "data_str": "foobar", } self._execute_workflow(wf_name, expected_output) def test_data_functions_in_jinja(self): - wf_name = 'jinja-data-functions' + wf_name = "jinja-data-functions" expected_output = { - 'data_json_str_1': '{"foo": {"bar": "foobar"}}', - 'data_json_str_2': '{"foo": {"bar": "foobar"}}', - 'data_json_str_3': '{"foo": {"bar": "foobar"}}', - 'data_json_obj_1': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_2': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_3': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_4': {'foo': {'bar': 'foobar'}}, - 'data_yaml_str_1': 'foo:\n bar: foobar\n', - 'data_yaml_str_2': 'foo:\n bar: foobar\n', - 'data_query_1': ['foobar'], - 'data_pipe_str_1': '{"foo": {"bar": "foobar"}}', - 'data_none_str': data_funcs.NONE_MAGIC_VALUE, - 'data_str': 'foobar', - 'data_list_str': '- a: 1\n b: 2\n- x: 3\n y: 4\n' + "data_json_str_1": '{"foo": {"bar": "foobar"}}', + "data_json_str_2": '{"foo": {"bar": "foobar"}}', + "data_json_str_3": '{"foo": {"bar": "foobar"}}', + "data_json_obj_1": {"foo": {"bar": "foobar"}}, + "data_json_obj_2": {"foo": {"bar": "foobar"}}, + "data_json_obj_3": {"foo": {"bar": "foobar"}}, + "data_json_obj_4": {"foo": {"bar": "foobar"}}, + "data_yaml_str_1": "foo:\n bar: foobar\n", + "data_yaml_str_2": "foo:\n bar: foobar\n", + "data_query_1": ["foobar"], + "data_pipe_str_1": '{"foo": {"bar": "foobar"}}', + "data_none_str": data_funcs.NONE_MAGIC_VALUE, + "data_str": "foobar", + "data_list_str": "- a: 1\n b: 2\n- x: 3\n y: 4\n", } self._execute_workflow(wf_name, expected_output) def test_path_functions_in_yaql(self): - wf_name = 'yaql-path-functions' + wf_name = "yaql-path-functions" - expected_output = { - 'basename': 'file.txt', - 'dirname': '/path/to/some' - } + expected_output = {"basename": "file.txt", "dirname": "/path/to/some"} self._execute_workflow(wf_name, expected_output) def test_path_functions_in_jinja(self): - wf_name = 'jinja-path-functions' + wf_name = "jinja-path-functions" - expected_output = { - 'basename': 'file.txt', - 'dirname': '/path/to/some' - } + expected_output = {"basename": "file.txt", "dirname": "/path/to/some"} self._execute_workflow(wf_name, expected_output) def test_regex_functions_in_yaql(self): - wf_name = 'yaql-regex-functions' + wf_name = "yaql-regex-functions" expected_output = { - 'match': True, - 'replace': 'wxyz', - 'search': True, - 'substring': '668 Infinite Dr' + "match": True, + "replace": "wxyz", + "search": True, + "substring": "668 Infinite Dr", } self._execute_workflow(wf_name, expected_output) def test_regex_functions_in_jinja(self): - wf_name = 'jinja-regex-functions' + wf_name = "jinja-regex-functions" expected_output = { - 'match': True, - 'replace': 'wxyz', - 'search': True, - 'substring': '668 Infinite Dr' + "match": True, + "replace": "wxyz", + "search": True, + "substring": "668 Infinite Dr", } self._execute_workflow(wf_name, expected_output) def test_time_functions_in_yaql(self): - wf_name = 'yaql-time-functions' + wf_name = "yaql-time-functions" - expected_output = { - 'time': '3h25m45s' - } + expected_output = {"time": "3h25m45s"} self._execute_workflow(wf_name, expected_output) def test_time_functions_in_jinja(self): - wf_name = 'jinja-time-functions' + wf_name = "jinja-time-functions" - expected_output = { - 'time': '3h25m45s' - } + expected_output = {"time": "3h25m45s"} self._execute_workflow(wf_name, expected_output) def test_version_functions_in_yaql(self): - wf_name = 'yaql-version-functions' + wf_name = "yaql-version-functions" expected_output = { - 'compare_equal': 0, - 'compare_more_than': -1, - 'compare_less_than': 1, - 'equal': True, - 'more_than': False, - 'less_than': False, - 'match': True, - 'bump_major': '1.0.0', - 'bump_minor': '0.11.0', - 'bump_patch': '0.10.1', - 'strip_patch': '0.10' + "compare_equal": 0, + "compare_more_than": -1, + "compare_less_than": 1, + "equal": True, + "more_than": False, + "less_than": False, + "match": True, + "bump_major": "1.0.0", + "bump_minor": "0.11.0", + "bump_patch": "0.10.1", + "strip_patch": "0.10", } self._execute_workflow(wf_name, expected_output) def test_version_functions_in_jinja(self): - wf_name = 'jinja-version-functions' + wf_name = "jinja-version-functions" expected_output = { - 'compare_equal': 0, - 'compare_more_than': -1, - 'compare_less_than': 1, - 'equal': True, - 'more_than': False, - 'less_than': False, - 'match': True, - 'bump_major': '1.0.0', - 'bump_minor': '0.11.0', - 'bump_patch': '0.10.1', - 'strip_patch': '0.10' + "compare_equal": 0, + "compare_more_than": -1, + "compare_less_than": 1, + "equal": True, + "more_than": False, + "less_than": False, + "match": True, + "bump_major": "1.0.0", + "bump_minor": "0.11.0", + "bump_patch": "0.10.1", + "strip_patch": "0.10", } self._execute_workflow(wf_name, expected_output) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py b/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py index 846afa19f0d..3004857bee2 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_functions_st2kv.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from orquesta_functions import st2kv @@ -37,14 +38,13 @@ from st2common.util import keyvalue as kvp_util -MOCK_CTX = {'__vars': {'st2': {'user': 'stanley'}}} -MOCK_CTX_NO_USER = {'__vars': {'st2': {}}} +MOCK_CTX = {"__vars": {"st2": {"user": "stanley"}}} +MOCK_CTX_NO_USER = {"__vars": {"st2": {}}} class DatastoreFunctionTest(unittest2.TestCase): - def test_missing_user_context(self): - self.assertRaises(KeyError, st2kv.st2kv_, MOCK_CTX_NO_USER, 'foo') + self.assertRaises(KeyError, st2kv.st2kv_, MOCK_CTX_NO_USER, "foo") def test_invalid_input(self): self.assertRaises(TypeError, st2kv.st2kv_, None, 123) @@ -55,35 +55,29 @@ def test_invalid_input(self): class UserScopeDatastoreFunctionTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(UserScopeDatastoreFunctionTest, cls).setUpClass() - user = auth_db.UserDB(name='stanley') + user = auth_db.UserDB(name="stanley") user.save() scope = kvp_const.FULL_USER_SCOPE cls.kvps = {} # Plain keys - keys = { - 'stanley:foo': 'bar', - 'stanley:foo_empty': '', - 'stanley:foo_null': None - } + keys = {"stanley:foo": "bar", "stanley:foo_empty": "", "stanley:foo_null": None} for k, v in six.iteritems(keys): instance = kvp_db.KeyValuePairDB(name=k, value=v, scope=scope) cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance) # Secret key - keys = { - 'stanley:fu': 'bar', - 'stanley:fu_empty': '' - } + keys = {"stanley:fu": "bar", "stanley:fu_empty": ""} for k, v in six.iteritems(keys): value = crypto.symmetric_encrypt(kvp_api.KeyValuePairAPI.crypto_key, v) - instance = kvp_db.KeyValuePairDB(name=k, value=value, scope=scope, secret=True) + instance = kvp_db.KeyValuePairDB( + name=k, value=value, scope=scope, secret=True + ) cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance) @classmethod @@ -94,9 +88,9 @@ def tearDownClass(cls): super(UserScopeDatastoreFunctionTest, cls).tearDownClass() def test_key_exists(self): - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foo'), 'bar') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foo_empty'), '') - self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'foo_null')) + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "foo"), "bar") + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "foo_empty"), "") + self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "foo_null")) def test_key_does_not_exist(self): self.assertRaisesRegexp( @@ -104,65 +98,61 @@ def test_key_does_not_exist(self): 'The key ".*" does not exist in the StackStorm datastore.', st2kv.st2kv_, MOCK_CTX, - 'foobar' + "foobar", ) def test_key_does_not_exist_but_return_default(self): - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foobar', default='foosball'), 'foosball') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'foobar', default=''), '') - self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'foobar', default=None)) + self.assertEqual( + st2kv.st2kv_(MOCK_CTX, "foobar", default="foosball"), "foosball" + ) + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "foobar", default=""), "") + self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "foobar", default=None)) def test_key_decrypt(self): - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu'), 'bar') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu', decrypt=False), 'bar') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'fu', decrypt=True), 'bar') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu_empty'), '') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'fu_empty', decrypt=False), '') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'fu_empty', decrypt=True), '') + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu"), "bar") + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu", decrypt=False), "bar") + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "fu", decrypt=True), "bar") + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu_empty"), "") + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "fu_empty", decrypt=False), "") + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "fu_empty", decrypt=True), "") @mock.patch.object( - kvp_util, 'get_key', - mock.MagicMock(side_effect=Exception('Mock failure.'))) + kvp_util, "get_key", mock.MagicMock(side_effect=Exception("Mock failure.")) + ) def test_get_key_exception(self): self.assertRaisesRegexp( exc.ExpressionEvaluationException, - 'Mock failure.', + "Mock failure.", st2kv.st2kv_, MOCK_CTX, - 'foo' + "foo", ) class SystemScopeDatastoreFunctionTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(SystemScopeDatastoreFunctionTest, cls).setUpClass() - user = auth_db.UserDB(name='stanley') + user = auth_db.UserDB(name="stanley") user.save() scope = kvp_const.FULL_SYSTEM_SCOPE cls.kvps = {} # Plain key - keys = { - 'foo': 'bar', - 'foo_empty': '', - 'foo_null': None - } + keys = {"foo": "bar", "foo_empty": "", "foo_null": None} for k, v in six.iteritems(keys): instance = kvp_db.KeyValuePairDB(name=k, value=v, scope=scope) cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance) # Secret key - keys = { - 'fu': 'bar', - 'fu_empty': '' - } + keys = {"fu": "bar", "fu_empty": ""} for k, v in six.iteritems(keys): value = crypto.symmetric_encrypt(kvp_api.KeyValuePairAPI.crypto_key, v) - instance = kvp_db.KeyValuePairDB(name=k, value=value, scope=scope, secret=True) + instance = kvp_db.KeyValuePairDB( + name=k, value=value, scope=scope, secret=True + ) cls.kvps[k] = kvp_db_access.KeyValuePair.add_or_update(instance) @classmethod @@ -173,9 +163,9 @@ def tearDownClass(cls): super(SystemScopeDatastoreFunctionTest, cls).tearDownClass() def test_key_exists(self): - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foo'), 'bar') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foo_empty'), '') - self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'system.foo_null')) + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.foo"), "bar") + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.foo_empty"), "") + self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "system.foo_null")) def test_key_does_not_exist(self): self.assertRaisesRegexp( @@ -183,30 +173,34 @@ def test_key_does_not_exist(self): 'The key ".*" does not exist in the StackStorm datastore.', st2kv.st2kv_, MOCK_CTX, - 'foo' + "foo", ) def test_key_does_not_exist_but_return_default(self): - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foobar', default='foosball'), 'foosball') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.foobar', default=''), '') - self.assertIsNone(st2kv.st2kv_(MOCK_CTX, 'system.foobar', default=None)) + self.assertEqual( + st2kv.st2kv_(MOCK_CTX, "system.foobar", default="foosball"), "foosball" + ) + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.foobar", default=""), "") + self.assertIsNone(st2kv.st2kv_(MOCK_CTX, "system.foobar", default=None)) def test_key_decrypt(self): - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu'), 'bar') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu', decrypt=False), 'bar') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu', decrypt=True), 'bar') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu_empty'), '') - self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu_empty', decrypt=False), '') - self.assertEqual(st2kv.st2kv_(MOCK_CTX, 'system.fu_empty', decrypt=True), '') + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "system.fu"), "bar") + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "system.fu", decrypt=False), "bar") + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.fu", decrypt=True), "bar") + self.assertNotEqual(st2kv.st2kv_(MOCK_CTX, "system.fu_empty"), "") + self.assertNotEqual( + st2kv.st2kv_(MOCK_CTX, "system.fu_empty", decrypt=False), "" + ) + self.assertEqual(st2kv.st2kv_(MOCK_CTX, "system.fu_empty", decrypt=True), "") @mock.patch.object( - kvp_util, 'get_key', - mock.MagicMock(side_effect=Exception('Mock failure.'))) + kvp_util, "get_key", mock.MagicMock(side_effect=Exception("Mock failure.")) + ) def test_get_key_exception(self): self.assertRaisesRegexp( exc.ExpressionEvaluationException, - 'Mock failure.', + "Mock failure.", st2kv.st2kv_, MOCK_CTX, - 'system.foo' + "system.foo", ) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py b/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py index 146e7ee39e9..46ffb861e31 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_functions_task.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -43,37 +44,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaFunctionTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaFunctionTest, cls).setUpClass() @@ -83,42 +92,57 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) - def _execute_workflow(self, wf_name, expected_task_sequence, expected_output, - expected_status=wf_statuses.SUCCEEDED, expected_errors=None): - wf_file = wf_name + '.yaml' + def _execute_workflow( + self, + wf_name, + expected_task_sequence, + expected_output, + expected_status=wf_statuses.SUCCEEDED, + expected_errors=None, + ): + wf_file = wf_name + ".yaml" wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) for task_id, route in expected_task_sequence: tk_ex_dbs = wf_db_access.TaskExecution.query( - workflow_execution=str(wf_ex_db.id), - task_id=task_id, - task_route=route + workflow_execution=str(wf_ex_db.id), task_id=task_id, task_route=route ) if len(tk_ex_dbs) <= 0: break - tk_ex_db = sorted(tk_ex_dbs, key=lambda x: x.start_timestamp)[len(tk_ex_dbs) - 1] - tk_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id))[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + tk_ex_db = sorted(tk_ex_dbs, key=lambda x: x.start_timestamp)[ + len(tk_ex_dbs) - 1 + ] + tk_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + )[0] + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk_ac_ex_db.liveaction["id"] + ) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertTrue(wf_svc.is_action_execution_under_workflow_context(tk_ac_ex_db)) + self.assertTrue( + wf_svc.is_action_execution_under_workflow_context(tk_ac_ex_db) + ) wf_svc.handle_action_execution_completion(tk_ac_ex_db) @@ -131,10 +155,10 @@ def _execute_workflow(self, wf_name, expected_task_sequence, expected_output, self.assertEqual(ac_ex_db.status, expected_status) # Check workflow output, liveaction result, and action execution result. - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} if expected_errors is not None: - expected_result['errors'] = expected_errors + expected_result["errors"] = expected_errors if expected_output is not None: self.assertDictEqual(wf_ex_db.output, expected_output) @@ -143,83 +167,81 @@ def _execute_workflow(self, wf_name, expected_task_sequence, expected_output, self.assertDictEqual(ac_ex_db.result, expected_result) def test_task_functions_in_yaql(self): - wf_name = 'yaql-task-functions' + wf_name = "yaql-task-functions" expected_task_sequence = [ - ('task1', 0), - ('task3', 0), - ('task6', 0), - ('task7', 0), - ('task2', 0), - ('task4', 0), - ('task8', 1), - ('task8', 2), - ('task4', 0), - ('task9', 1), - ('task9', 2), - ('task5', 0) + ("task1", 0), + ("task3", 0), + ("task6", 0), + ("task7", 0), + ("task2", 0), + ("task4", 0), + ("task8", 1), + ("task8", 2), + ("task4", 0), + ("task9", 1), + ("task9", 2), + ("task5", 0), ] expected_output = { - 'last_task4_result': 'False', - 'task9__1__parent': 'task8__1', - 'task9__2__parent': 'task8__2', - 'that_task_by_name': 'task1', - 'this_task_by_name': 'task1', - 'this_task_no_arg': 'task1' + "last_task4_result": "False", + "task9__1__parent": "task8__1", + "task9__2__parent": "task8__2", + "that_task_by_name": "task1", + "this_task_by_name": "task1", + "this_task_no_arg": "task1", } self._execute_workflow(wf_name, expected_task_sequence, expected_output) def test_task_functions_in_jinja(self): - wf_name = 'jinja-task-functions' + wf_name = "jinja-task-functions" expected_task_sequence = [ - ('task1', 0), - ('task3', 0), - ('task6', 0), - ('task7', 0), - ('task2', 0), - ('task4', 0), - ('task8', 1), - ('task8', 2), - ('task4', 0), - ('task9', 1), - ('task9', 2), - ('task5', 0) + ("task1", 0), + ("task3", 0), + ("task6", 0), + ("task7", 0), + ("task2", 0), + ("task4", 0), + ("task8", 1), + ("task8", 2), + ("task4", 0), + ("task9", 1), + ("task9", 2), + ("task5", 0), ] expected_output = { - 'last_task4_result': 'False', - 'task9__1__parent': 'task8__1', - 'task9__2__parent': 'task8__2', - 'that_task_by_name': 'task1', - 'this_task_by_name': 'task1', - 'this_task_no_arg': 'task1' + "last_task4_result": "False", + "task9__1__parent": "task8__1", + "task9__2__parent": "task8__2", + "that_task_by_name": "task1", + "this_task_by_name": "task1", + "this_task_no_arg": "task1", } self._execute_workflow(wf_name, expected_task_sequence, expected_output) def test_task_nonexistent_in_yaql(self): - wf_name = 'yaql-task-nonexistent' + wf_name = "yaql-task-nonexistent" - expected_task_sequence = [ - ('task1', 0) - ] + expected_task_sequence = [("task1", 0)] expected_output = None expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% task("task0") %>\'. ExpressionEvaluationException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% task(\"task0\") %>'. ExpressionEvaluationException: " 'Unable to find task execution for "task0".' ), - 'task_transition_id': 'continue__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "continue__t0", + "task_id": "task1", + "route": 0, } ] @@ -228,29 +250,27 @@ def test_task_nonexistent_in_yaql(self): expected_task_sequence, expected_output, expected_status=ac_const.LIVEACTION_STATUS_FAILED, - expected_errors=expected_errors + expected_errors=expected_errors, ) def test_task_nonexistent_in_jinja(self): - wf_name = 'jinja-task-nonexistent' + wf_name = "jinja-task-nonexistent" - expected_task_sequence = [ - ('task1', 0) - ] + expected_task_sequence = [("task1", 0)] expected_output = None expected_errors = [ { - 'type': 'error', - 'message': ( - 'JinjaEvaluationException: Unable to evaluate expression ' - '\'{{ task("task0") }}\'. ExpressionEvaluationException: ' + "type": "error", + "message": ( + "JinjaEvaluationException: Unable to evaluate expression " + "'{{ task(\"task0\") }}'. ExpressionEvaluationException: " 'Unable to find task execution for "task0".' ), - 'task_transition_id': 'continue__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "continue__t0", + "task_id": "task1", + "route": 0, } ] @@ -259,5 +279,5 @@ def test_task_nonexistent_in_jinja(self): expected_task_sequence, expected_output, expected_status=ac_const.LIVEACTION_STATUS_FAILED, - expected_errors=expected_errors + expected_errors=expected_errors, ) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py b/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py index 3e84d7bce80..8dfdf24a844 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_inquiries.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -45,37 +46,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerTest, cls).setUpClass() @@ -85,30 +94,35 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def test_inquiry(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-approval.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "ask-approval.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert start task is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + self.assertEqual( + t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t1_ac_ex_db) t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) @@ -118,10 +132,15 @@ def test_inquiry(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert get approval task is already pending. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_approval'} + query_filters = { + "workflow_execution": str(wf_ex_db.id), + "task_id": "get_approval", + } t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) @@ -133,12 +152,16 @@ def test_inquiry(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) - self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_ac_ex_db.id)) - self.assertEqual(t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_ex_db.id)) self.assertEqual(t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @@ -148,11 +171,15 @@ def test_inquiry(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert the final task is completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"} t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) - self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) + self.assertEqual( + t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id) self.assertEqual(t3_ex_db.status, wf_statuses.SUCCEEDED) @@ -162,22 +189,30 @@ def test_inquiry(self): self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED) def test_consecutive_inquiries(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-consecutive-approvals.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "ask-consecutive-approvals.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert start task is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + self.assertEqual( + t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t1_ac_ex_db) t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) @@ -187,10 +222,15 @@ def test_consecutive_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert get approval task is already pending. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_approval'} + query_filters = { + "workflow_execution": str(wf_ex_db.id), + "task_id": "get_approval", + } t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) @@ -202,12 +242,16 @@ def test_consecutive_inquiries(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) - self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_ac_ex_db.id)) - self.assertEqual(t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_ex_db.id)) self.assertEqual(t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @@ -217,10 +261,15 @@ def test_consecutive_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert the final task is completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_confirmation'} + query_filters = { + "workflow_execution": str(wf_ex_db.id), + "task_id": "get_confirmation", + } t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id) @@ -232,12 +281,16 @@ def test_consecutive_inquiries(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t3_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t3_lv_ac_db.id)) - self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t3_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t3_ac_ex_db.id)) - self.assertEqual(t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(str(t3_ex_db.id)) self.assertEqual(t3_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @@ -247,11 +300,15 @@ def test_consecutive_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert the final task is completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"} t4_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t4_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t4_ex_db.id))[0] - t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction['id']) - self.assertEqual(t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t4_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t4_ex_db.id) + )[0] + t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction["id"]) + self.assertEqual( + t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t4_ac_ex_db) t4_ex_db = wf_db_access.TaskExecution.get_by_id(t4_ex_db.id) self.assertEqual(t4_ex_db.status, wf_statuses.SUCCEEDED) @@ -261,22 +318,30 @@ def test_consecutive_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED) def test_parallel_inquiries(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-parallel-approvals.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "ask-parallel-approvals.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert start task is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + self.assertEqual( + t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t1_ac_ex_db) t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) @@ -286,10 +351,12 @@ def test_parallel_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert get approval task is already pending. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'ask_jack'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "ask_jack"} t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) @@ -300,10 +367,12 @@ def test_parallel_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.PAUSING) # Assert get approval task is already pending. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'ask_jill'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "ask_jill"} t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id) @@ -315,12 +384,16 @@ def test_parallel_inquiries(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) - self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_ac_ex_db.id)) - self.assertEqual(t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_ex_db.id)) self.assertEqual(t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @@ -332,12 +405,16 @@ def test_parallel_inquiries(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t3_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t3_lv_ac_db.id)) - self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t3_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t3_ac_ex_db.id)) - self.assertEqual(t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(str(t3_ex_db.id)) self.assertEqual(t3_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @@ -347,11 +424,15 @@ def test_parallel_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert the final task is completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"} t4_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t4_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t4_ex_db.id))[0] - t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction['id']) - self.assertEqual(t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t4_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t4_ex_db.id) + )[0] + t4_lv_ac_db = lv_db_access.LiveAction.get_by_id(t4_ac_ex_db.liveaction["id"]) + self.assertEqual( + t4_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t4_ac_ex_db) t4_ex_db = wf_db_access.TaskExecution.get_by_id(t4_ex_db.id) self.assertEqual(t4_ex_db.status, wf_statuses.SUCCEEDED) @@ -361,22 +442,30 @@ def test_parallel_inquiries(self): self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED) def test_nested_inquiry(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'ask-nested-approval.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "ask-nested-approval.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert start task is already completed. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'start'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "start"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - self.assertEqual(t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + self.assertEqual( + t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t1_ac_ex_db) t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) @@ -386,23 +475,36 @@ def test_nested_inquiry(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert the subworkflow is already started. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'get_approval'} + query_filters = { + "workflow_execution": str(wf_ex_db.id), + "task_id": "get_approval", + } t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) self.assertEqual(t2_ex_db.status, wf_statuses.RUNNING) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Process task1 of subworkflow. - query_filters = {'workflow_execution': str(t2_wf_ex_db.id), 'task_id': 'start'} + query_filters = {"workflow_execution": str(t2_wf_ex_db.id), "task_id": "start"} t2_t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] - t2_t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_t1_ac_ex_db.liveaction['id']) - self.assertEqual(t2_t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] + t2_t1_lv_ac_db = lv_db_access.LiveAction.get_by_id( + t2_t1_ac_ex_db.liveaction["id"] + ) + self.assertEqual( + t2_t1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_t1_ac_ex_db) t2_t1_ex_db = wf_db_access.TaskExecution.get_by_id(t2_t1_ex_db.id) self.assertEqual(t2_t1_ex_db.status, wf_statuses.SUCCEEDED) @@ -410,11 +512,20 @@ def test_nested_inquiry(self): self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Process inquiry task of subworkflow and assert the subworkflow is paused. - query_filters = {'workflow_execution': str(t2_wf_ex_db.id), 'task_id': 'get_approval'} + query_filters = { + "workflow_execution": str(t2_wf_ex_db.id), + "task_id": "get_approval", + } t2_t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0] - t2_t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_t2_ac_ex_db.liveaction['id']) - self.assertEqual(t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING) + t2_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t2_ex_db.id) + )[0] + t2_t2_lv_ac_db = lv_db_access.LiveAction.get_by_id( + t2_t2_ac_ex_db.liveaction["id"] + ) + self.assertEqual( + t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PENDING + ) workflows.get_engine().process(t2_t2_ac_ex_db) t2_t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_t2_ex_db.id) self.assertEqual(t2_t2_ex_db.status, wf_statuses.PENDING) @@ -422,8 +533,10 @@ def test_nested_inquiry(self): self.assertEqual(t2_wf_ex_db.status, wf_statuses.PAUSED) # Process the corresponding task in parent workflow and assert the task is paused. - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_PAUSED) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) @@ -435,34 +548,50 @@ def test_nested_inquiry(self): # Respond to the inquiry and check status. inquiry_api = inqy_api_models.InquiryAPI.from_model(t2_t2_ac_ex_db) - inquiry_response = {'approved': True} + inquiry_response = {"approved": True} inquiry_service.respond(inquiry_api, inquiry_response) t2_t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_t2_lv_ac_db.id)) - self.assertEqual(t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) t2_t2_ac_ex_db = ex_db_access.ActionExecution.get_by_id(str(t2_t2_ac_ex_db.id)) - self.assertEqual(t2_t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_t2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_t2_ac_ex_db) t2_t2_ex_db = wf_db_access.TaskExecution.get_by_id(str(t2_t2_ex_db.id)) - self.assertEqual(t2_t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + t2_t2_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Assert the main workflow is running again. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Complete the rest of the subworkflow - query_filters = {'workflow_execution': str(t2_wf_ex_db.id), 'task_id': 'finish'} + query_filters = {"workflow_execution": str(t2_wf_ex_db.id), "task_id": "finish"} t2_t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0] - t2_t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_t3_ac_ex_db.liveaction['id']) - self.assertEqual(t2_t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t2_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t3_ex_db.id) + )[0] + t2_t3_lv_ac_db = lv_db_access.LiveAction.get_by_id( + t2_t3_ac_ex_db.liveaction["id"] + ) + self.assertEqual( + t2_t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_t3_ac_ex_db) t2_t3_ex_db = wf_db_access.TaskExecution.get_by_id(t2_t3_ex_db.id) self.assertEqual(t2_t3_ex_db.status, wf_statuses.SUCCEEDED) t2_wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(t2_wf_ex_db.id)) self.assertEqual(t2_wf_ex_db.status, wf_statuses.SUCCEEDED) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - self.assertEqual(t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + self.assertEqual( + t2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) self.assertEqual(t2_ex_db.status, wf_statuses.SUCCEEDED) @@ -470,11 +599,15 @@ def test_nested_inquiry(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Complete the rest of the main workflow - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'finish'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "finish"} t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) - self.assertEqual(t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) + self.assertEqual( + t3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id) self.assertEqual(t3_ex_db.status, wf_statuses.SUCCEEDED) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_notify.py b/contrib/runners/orquesta_runner/tests/unit/test_notify.py index dc8131f1003..6ca125d8559 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_notify.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_notify.py @@ -25,6 +25,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -47,57 +48,60 @@ from st2tests.mocks import liveaction as mock_lv_ac_xport from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] MOCK_NOTIFY = { - 'on-complete': { - 'data': { - 'source_channel': 'baloney', - 'user': 'lakstorm' - }, - 'routes': [ - 'hubot' - ] + "on-complete": { + "data": {"source_channel": "baloney", "user": "lakstorm"}, + "routes": ["hubot"], } } @mock.patch.object( - notifier.Notifier, - '_post_notify_triggers', - mock.MagicMock(return_value=None)) + notifier.Notifier, "_post_notify_triggers", mock.MagicMock(return_value=None) +) @mock.patch.object( - notifier.Notifier, - '_post_generic_trigger', - mock.MagicMock(return_value=None)) + notifier.Notifier, "_post_generic_trigger", mock.MagicMock(return_value=None) +) @mock.patch.object( publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(side_effect=mock_ac_ex_xport.MockExecutionPublisher.publish_update)) + "publish_update", + mock.MagicMock(side_effect=mock_ac_ex_xport.MockExecutionPublisher.publish_update), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaNotifyTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaNotifyTest, cls).setUpClass() @@ -107,177 +111,181 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def test_no_notify(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Check that notify is setup correctly in the db record. self.assertDictEqual(wf_ex_db.notify, {}) def test_no_notify_task_list(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Check that notify is setup correctly in the db record. - expected_notify = { - 'config': MOCK_NOTIFY, - 'tasks': [] - } + expected_notify = {"config": MOCK_NOTIFY, "tasks": []} self.assertDictEqual(wf_ex_db.notify, expected_notify) def test_custom_notify_task_list(self): - wf_input = {'notify': ['task1']} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"notify": ["task1"]} + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Check that notify is setup correctly in the db record. - expected_notify = { - 'config': MOCK_NOTIFY, - 'tasks': wf_input['notify'] - } + expected_notify = {"config": MOCK_NOTIFY, "tasks": wf_input["notify"]} self.assertDictEqual(wf_ex_db.notify, expected_notify) def test_default_notify_task_list(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'notify.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "notify.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Check that notify is setup correctly in the db record. - expected_notify = { - 'config': MOCK_NOTIFY, - 'tasks': ['task1', 'task2', 'task3'] - } + expected_notify = {"config": MOCK_NOTIFY, "tasks": ["task1", "task2", "task3"]} self.assertDictEqual(wf_ex_db.notify, expected_notify) def test_notify_task_list_bad_item_value(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) expected_schema_failure_test_cases = [ - 'task1', # Notify must be type of list. - [123], # Item has to be type of string. - [''], # String value cannot be empty. - [' '], # String value cannot be just spaces. - [' '], # String value cannot be just tabs. - ['init task'], # String value cannot have space. - ['init-task'], # String value cannot have dash. - ['task1', 'task1'] # String values have to be unique. + "task1", # Notify must be type of list. + [123], # Item has to be type of string. + [""], # String value cannot be empty. + [" "], # String value cannot be just spaces. + [" "], # String value cannot be just tabs. + ["init task"], # String value cannot have space. + ["init-task"], # String value cannot have dash. + ["task1", "task1"], # String values have to be unique. ] for notify_tasks in expected_schema_failure_test_cases: - lv_ac_db.parameters = {'notify': notify_tasks} + lv_ac_db.parameters = {"notify": notify_tasks} try: self.assertRaises( - jsonschema.ValidationError, - action_service.request, - lv_ac_db + jsonschema.ValidationError, action_service.request, lv_ac_db ) except Exception as e: - raise AssertionError('%s: %s' % (six.text_type(e), notify_tasks)) + raise AssertionError("%s: %s" % (six.text_type(e), notify_tasks)) def test_notify_task_list_nonexistent_task(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) - lv_ac_db.parameters = {'notify': ['init_task']} + lv_ac_db.parameters = {"notify": ["init_task"]} lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) expected_result = { - 'output': None, - 'errors': [ + "output": None, + "errors": [ { - 'message': ( - 'The following tasks in the notify parameter do not ' - 'exist in the workflow definition: init_task.' + "message": ( + "The following tasks in the notify parameter do not " + "exist in the workflow definition: init_task." ) } - ] + ], } self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) self.assertDictEqual(lv_ac_db.result, expected_result) def test_notify_task_list_item_value(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) - expected_schema_success_test_cases = [ - [], - ['task1'], - ['task1', 'task2'] - ] + expected_schema_success_test_cases = [[], ["task1"], ["task1", "task2"]] for notify_tasks in expected_schema_success_test_cases: - lv_ac_db.parameters = {'notify': notify_tasks} + lv_ac_db.parameters = {"notify": notify_tasks} lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) + self.assertEqual( + lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING + ) def test_cascade_notify_to_tasks(self): - wf_input = {'notify': ['task2']} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"notify": ["task2"]} + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(MOCK_NOTIFY) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert task1 notify is not set. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertIsNone(tk1_lv_ac_db.notify) - self.assertEqual(tk1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + tk1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) self.assertFalse(notifier.Notifier._post_notify_triggers.called) notifier.Notifier._post_notify_triggers.reset_mock() @@ -289,13 +297,19 @@ def test_cascade_notify_to_tasks(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task2 notify is set. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) - notify = notify_api_models.NotificationsHelper.from_model(notify_model=tk2_lv_ac_db.notify) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) + notify = notify_api_models.NotificationsHelper.from_model( + notify_model=tk2_lv_ac_db.notify + ) self.assertEqual(notify, MOCK_NOTIFY) - self.assertEqual(tk2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + tk2_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) self.assertTrue(notifier.Notifier._post_notify_triggers.called) notifier.Notifier._post_notify_triggers.reset_mock() @@ -307,12 +321,16 @@ def test_cascade_notify_to_tasks(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Assert task3 notify is not set. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) self.assertIsNone(tk3_lv_ac_db.notify) - self.assertEqual(tk3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + self.assertEqual( + tk3_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) self.assertFalse(notifier.Notifier._post_notify_triggers.called) notifier.Notifier._post_notify_triggers.reset_mock() diff --git a/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py b/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py index 5bae5bab27b..f23084b5278 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_output_schema.py @@ -22,6 +22,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -45,12 +46,14 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] FAIL_SCHEMA = { @@ -61,25 +64,32 @@ @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerTest(RunnerTestCase, st2tests.ExecutionDbTestCase): @classmethod def setUpClass(cls): @@ -90,8 +100,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -102,28 +111,40 @@ def get_runner_class(cls, runner_name): return runners.get_runner(runner_name, runner_name).__class__ def test_adherence_to_output_schema(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential_with_schema.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "sequential_with_schema.yaml" + ) + wf_input = {"who": "Thanos"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) wf_ex_db = wf_ex_dbs[0] - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk2_ac_ex_db) tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk3_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) @@ -134,30 +155,39 @@ def test_adherence_to_output_schema(self): def test_fail_incorrect_output_schema(self): wf_meta = base.get_wf_fixture_meta_data( - TEST_PACK_PATH, - 'sequential_with_broken_schema.yaml' + TEST_PACK_PATH, "sequential_with_broken_schema.yaml" + ) + wf_input = {"who": "Thanos"} + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input ) - wf_input = {'who': 'Thanos'} - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) wf_ex_db = wf_ex_dbs[0] - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk2_ac_ex_db) tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] wf_svc.handle_action_execution_completion(tk3_ac_ex_db) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) @@ -167,9 +197,9 @@ def test_fail_incorrect_output_schema(self): self.assertEqual(ac_ex_db.status, ac_const.LIVEACTION_STATUS_FAILED) expected_result = { - 'error': "Additional properties are not allowed", - 'message': 'Error validating output. See error output for more details.' + "error": "Additional properties are not allowed", + "message": "Error validating output. See error output for more details.", } - self.assertIn(expected_result['error'], ac_ex_db.result['error']) - self.assertEqual(expected_result['message'], ac_ex_db.result['message']) + self.assertIn(expected_result["error"], ac_ex_db.result["error"]) + self.assertEqual(expected_result["message"], ac_ex_db.result["message"]) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py b/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py index c2021b379e3..6ade3900295 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_pause_and_resume.py @@ -24,6 +24,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -46,37 +47,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerPauseResumeTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerPauseResumeTest, cls).setUpClass() @@ -86,8 +95,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -97,56 +105,68 @@ def setUpClass(cls): def get_runner_class(cls, runner_name): return runners.get_runner(runner_name, runner_name).__class__ - @mock.patch.object( - ac_svc, 'is_children_active', - mock.MagicMock(return_value=False)) + @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=False)) def test_pause(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) lv_ac_db, ac_ex_db = ac_svc.request_pause(lv_ac_db, cfg.CONF.system_user.user) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) - @mock.patch.object( - ac_svc, 'is_children_active', - mock.MagicMock(return_value=True)) + @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=True)) def test_pause_with_active_children(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) lv_ac_db, ac_ex_db = ac_svc.request_pause(lv_ac_db, cfg.CONF.system_user.user) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) def test_pause_subworkflow_not_cascade_up_to_workflow(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the subworkflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 1) - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id)) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + ) self.assertEqual(len(tk_ac_ex_dbs), 1) - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id']) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow. - tk_lv_ac_db, tk_ac_ex_db = ac_svc.request_pause(tk_lv_ac_db, cfg.CONF.system_user.user) + tk_lv_ac_db, tk_ac_ex_db = ac_svc.request_pause( + tk_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -154,38 +174,52 @@ def test_pause_subworkflow_not_cascade_up_to_workflow(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) def test_pause_workflow_cascade_down_to_subworkflow(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) wf_ex_db = wf_ex_dbs[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) tk_ex_db = tk_ex_dbs[0] - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id)) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + ) self.assertEqual(len(tk_ac_ex_dbs), 1) tk_ac_ex_db = tk_ac_ex_dbs[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Identify the records for the subworkflow. - sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(tk_ac_ex_db.id)) + sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(tk_ac_ex_db.id) + ) self.assertEqual(len(sub_wf_ex_dbs), 1) sub_wf_ex_db = sub_wf_ex_dbs[0] - sub_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(sub_wf_ex_db.id)) + sub_tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(sub_wf_ex_db.id) + ) self.assertEqual(len(sub_tk_ex_dbs), 1) sub_tk_ex_db = sub_tk_ex_dbs[0] - sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(sub_tk_ex_db.id)) + sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(sub_tk_ex_db.id) + ) self.assertEqual(len(sub_tk_ac_ex_dbs), 1) # Pause the main workflow and assert it is pausing because subworkflow is still running. @@ -213,32 +247,48 @@ def test_pause_workflow_cascade_down_to_subworkflow(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) def test_pause_subworkflow_while_another_subworkflow_running(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) # Identify the records for the subworkflows. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Pause the subworkflow. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause( + t1_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -246,12 +296,16 @@ def test_pause_subworkflow_while_another_subworkflow_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the task in the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] workflows.get_engine().process(t1_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -267,18 +321,30 @@ def test_pause_subworkflow_while_another_subworkflow_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the other subworkflow. - t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] + t2_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[0] + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] workflows.get_engine().process(t2_t1_ac_ex_db) - t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1] - t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0] + t2_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[1] + t2_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t2_ex_db.id) + )[0] workflows.get_engine().process(t2_t2_ac_ex_db) - t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2] - t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0] + t2_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[2] + t2_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t3_ex_db.id) + )[0] workflows.get_engine().process(t2_t3_ac_ex_db) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -293,32 +359,48 @@ def test_pause_subworkflow_while_another_subworkflow_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) def test_pause_subworkflow_while_another_subworkflow_completed(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) # Identify the records for the subworkflows. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Pause the subworkflow. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause( + t1_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -326,18 +408,30 @@ def test_pause_subworkflow_while_another_subworkflow_completed(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the other subworkflow. - t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] + t2_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[0] + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] workflows.get_engine().process(t2_t1_ac_ex_db) - t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1] - t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0] + t2_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[1] + t2_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t2_ex_db.id) + )[0] workflows.get_engine().process(t2_t2_ac_ex_db) - t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2] - t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0] + t2_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[2] + t2_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t3_ex_db.id) + )[0] workflows.get_engine().process(t2_t3_ac_ex_db) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -352,12 +446,16 @@ def test_pause_subworkflow_while_another_subworkflow_completed(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the target subworkflow is still pausing. - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Manually notify action execution completion for the task in the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] workflows.get_engine().process(t1_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -372,15 +470,15 @@ def test_pause_subworkflow_while_another_subworkflow_completed(self): lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) - @mock.patch.object( - ac_svc, 'is_children_active', - mock.MagicMock(return_value=False)) + @mock.patch.object(ac_svc, "is_children_active", mock.MagicMock(return_value=False)) def test_resume(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Pause the workflow. lv_ac_db, ac_ex_db = ac_svc.request_pause(lv_ac_db, cfg.CONF.system_user.user) @@ -388,63 +486,93 @@ def test_resume(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Identify the records for the running task(s) and manually complete it. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 1) - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id)) - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_dbs[0].liveaction['id']) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + ) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id( + tk_ac_ex_dbs[0].liveaction["id"] + ) self.assertEqual(tk_ac_ex_dbs[0].status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk_ac_ex_dbs[0]) # Ensure the workflow is paused. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED, lv_ac_db.result) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED, lv_ac_db.result + ) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(wf_ex_dbs[0].status, wf_statuses.PAUSED) # Resume the workflow. lv_ac_db, ac_ex_db = ac_svc.request_resume(lv_ac_db, cfg.CONF.system_user.user) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(wf_ex_dbs[0].status, wf_statuses.RUNNING) - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_dbs[0].id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_dbs[0].id) + ) self.assertEqual(len(tk_ex_dbs), 2) def test_resume_cascade_to_subworkflow(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) wf_ex_db = wf_ex_dbs[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) tk_ex_db = tk_ex_dbs[0] - tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_db.id)) + tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_db.id) + ) self.assertEqual(len(tk_ac_ex_dbs), 1) tk_ac_ex_db = tk_ac_ex_dbs[0] - tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction['id']) + tk_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk_ac_ex_db.liveaction["id"]) self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Identify the records for the subworkflow. - sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(tk_ac_ex_db.id)) + sub_wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(tk_ac_ex_db.id) + ) self.assertEqual(len(sub_wf_ex_dbs), 1) sub_wf_ex_db = sub_wf_ex_dbs[0] - sub_tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(sub_wf_ex_db.id)) + sub_tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(sub_wf_ex_db.id) + ) self.assertEqual(len(sub_tk_ex_dbs), 1) sub_tk_ex_db = sub_tk_ex_dbs[0] - sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(sub_tk_ex_db.id)) + sub_tk_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(sub_tk_ex_db.id) + ) self.assertEqual(len(sub_tk_ac_ex_dbs), 1) # Pause the main workflow and assert it is pausing because subworkflow is still running. @@ -481,32 +609,48 @@ def test_resume_cascade_to_subworkflow(self): self.assertEqual(tk_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) def test_resume_from_each_subworkflow_when_parent_is_paused(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) # Identify the records for the subworkflows. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Pause one of the subworkflows. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause( + t1_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -514,12 +658,16 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the task in the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] workflows.get_engine().process(t1_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -535,11 +683,13 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Pause the other subworkflow. - t2_lv_ac_db, t2_ac_ex_db = ac_svc.request_pause(t2_lv_ac_db, cfg.CONF.system_user.user) + t2_lv_ac_db, t2_ac_ex_db = ac_svc.request_pause( + t2_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -547,8 +697,12 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the task in the subworkflow. - t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] + t2_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[0] + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] workflows.get_engine().process(t2_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -564,7 +718,9 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) # Resume the subworkflow and assert it is running. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume( + t1_lv_ac_db, cfg.CONF.system_user.user + ) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) @@ -573,11 +729,19 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the subworkflow. - t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[1] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] workflows.get_engine().process(t1_t2_ac_ex_db) - t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[2] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] workflows.get_engine().process(t1_t3_ac_ex_db) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -592,32 +756,48 @@ def test_resume_from_each_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) def test_resume_from_subworkflow_when_parent_is_paused(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) # Identify the records for the subworkflows. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Pause the subworkflow. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause( + t1_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -625,12 +805,16 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the task in the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] workflows.get_engine().process(t1_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -646,18 +830,30 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the other subworkflow. - t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] + t2_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[0] + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] workflows.get_engine().process(t2_t1_ac_ex_db) - t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1] - t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0] + t2_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[1] + t2_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t2_ex_db.id) + )[0] workflows.get_engine().process(t2_t2_ac_ex_db) - t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2] - t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0] + t2_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[2] + t2_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t3_ex_db.id) + )[0] workflows.get_engine().process(t2_t3_ac_ex_db) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -672,7 +868,9 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSED) # Resume the subworkflow and assert it is running. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume( + t1_lv_ac_db, cfg.CONF.system_user.user + ) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) @@ -681,11 +879,19 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the subworkflow. - t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[1] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] workflows.get_engine().process(t1_t2_ac_ex_db) - t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[2] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] workflows.get_engine().process(t1_t3_ac_ex_db) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -696,12 +902,16 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): workflows.get_engine().process(t1_ac_ex_db) # Assert task3 has started and completed. - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 3) - t3_ex_db_qry = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + t3_ex_db_qry = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} t3_ex_db = wf_db_access.TaskExecution.query(**t3_ex_db_qry)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) self.assertEqual(t3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(t3_ac_ex_db) @@ -710,32 +920,48 @@ def test_resume_from_subworkflow_when_parent_is_paused(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_resume_from_subworkflow_when_parent_is_running(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflows.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflows.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 2) # Identify the records for the subworkflows. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction['id']) - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(t1_ac_ex_db.liveaction["id"]) + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[1].id))[0] - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) - t2_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t2_ac_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[1].id) + )[0] + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) + t2_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t2_ac_ex_db.id) + )[0] self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t2_wf_ex_db.status, wf_statuses.RUNNING) # Pause the subworkflow. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_pause( + t1_lv_ac_db, cfg.CONF.system_user.user + ) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_PAUSING) # Assert the main workflow is still running. @@ -743,12 +969,16 @@ def test_resume_from_subworkflow_when_parent_is_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the task in the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] workflows.get_engine().process(t1_t1_ac_ex_db) # Assert the subworkflow is paused and manually notify the paused of the @@ -764,11 +994,13 @@ def test_resume_from_subworkflow_when_parent_is_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Resume the subworkflow and assert it is running. - t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume(t1_lv_ac_db, cfg.CONF.system_user.user) + t1_lv_ac_db, t1_ac_ex_db = ac_svc.request_resume( + t1_lv_ac_db, cfg.CONF.system_user.user + ) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) @@ -777,15 +1009,23 @@ def test_resume_from_subworkflow_when_parent_is_running(self): self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Assert the other subworkflow is still running. - t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction['id']) + t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(t2_ac_ex_db.liveaction["id"]) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING) # Manually notify action execution completion for the tasks in the subworkflow. - t1_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[1] - t1_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t2_ex_db.id))[0] + t1_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[1] + t1_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t2_ex_db.id) + )[0] workflows.get_engine().process(t1_t2_ac_ex_db) - t1_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[2] - t1_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t3_ex_db.id))[0] + t1_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[2] + t1_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t3_ex_db.id) + )[0] workflows.get_engine().process(t1_t3_ac_ex_db) t1_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t1_lv_ac_db.id)) self.assertEqual(t1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -796,14 +1036,26 @@ def test_resume_from_subworkflow_when_parent_is_running(self): workflows.get_engine().process(t1_ac_ex_db) # Manually notify action execution completion for the tasks in the other subworkflow. - t2_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[0] - t2_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t1_ex_db.id))[0] + t2_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[0] + t2_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t1_ex_db.id) + )[0] workflows.get_engine().process(t2_t1_ac_ex_db) - t2_t2_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[1] - t2_t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t2_ex_db.id))[0] + t2_t2_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[1] + t2_t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t2_ex_db.id) + )[0] workflows.get_engine().process(t2_t2_ac_ex_db) - t2_t3_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t2_wf_ex_db.id))[2] - t2_t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_t3_ex_db.id))[0] + t2_t3_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t2_wf_ex_db.id) + )[2] + t2_t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_t3_ex_db.id) + )[0] workflows.get_engine().process(t2_t3_ac_ex_db) t2_lv_ac_db = lv_db_access.LiveAction.get_by_id(str(t2_lv_ac_db.id)) self.assertEqual(t2_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) @@ -814,12 +1066,16 @@ def test_resume_from_subworkflow_when_parent_is_running(self): workflows.get_engine().process(t2_ac_ex_db) # Assert task3 has started and completed. - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 3) - t3_ex_db_qry = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + t3_ex_db_qry = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} t3_ex_db = wf_db_access.TaskExecution.query(**t3_ex_db_qry)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] - t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction['id']) + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] + t3_lv_ac_db = lv_db_access.LiveAction.get_by_id(t3_ac_ex_db.liveaction["id"]) self.assertEqual(t3_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(t3_ac_ex_db) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_policies.py b/contrib/runners/orquesta_runner/tests/unit/test_policies.py index 2595609f63b..81ab6392627 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_policies.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_policies.py @@ -23,6 +23,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -46,37 +47,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaRunnerTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaRunnerTest, cls).setUpClass() @@ -86,7 +95,7 @@ def setUpClass(cls): policiesregistrar.register_policy_types(st2common) # Register test pack(s). - registrar_options = {'use_pack_cache': False, 'fail_on_failure': True} + registrar_options = {"use_pack_cache": False, "fail_on_failure": True} actions_registrar = actionsregistrar.ActionsRegistrar(**registrar_options) policies_registrar = policiesregistrar.PolicyRegistrar(**registrar_options) @@ -106,27 +115,37 @@ def tearDown(self): ac_ex_db.delete() def test_retry_policy_applied_on_workflow_failure(self): - wf_name = 'sequential' - wf_ac_ref = TEST_PACK + '.' + wf_name - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_name = "sequential" + wf_ac_ref = TEST_PACK + "." + wf_name + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Ensure there is only one execution recorded. self.assertEqual(len(lv_db_access.LiveAction.query(action=wf_ac_ref)), 1) # Identify the records for the workflow and task. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + )[0] t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] # Manually set the status to fail. ac_svc.update_status(t1_lv_ac_db, ac_const.LIVEACTION_STATUS_FAILED) t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_ex_db.id))[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_FAILED) notifier.get_notifier().process(t1_ac_ex_db) workflows.get_engine().process(t1_ac_ex_db) @@ -140,32 +159,48 @@ def test_retry_policy_applied_on_workflow_failure(self): self.assertEqual(len(lv_db_access.LiveAction.query(action=wf_ac_ref)), 2) def test_no_retry_policy_applied_on_task_failure(self): - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'subworkflow.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "subworkflow.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) - self.assertEqual(lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result) + self.assertEqual( + lv_ac_db.status, ac_const.LIVEACTION_STATUS_RUNNING, lv_ac_db.result + ) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) # Identify the records for the tasks. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] self.assertEqual(t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_RUNNING) self.assertEqual(t1_wf_ex_db.status, wf_statuses.RUNNING) # Ensure there is only one execution for the task. - tk_ac_ref = TEST_PACK + '.' + 'sequential' + tk_ac_ref = TEST_PACK + "." + "sequential" self.assertEqual(len(lv_db_access.LiveAction.query(action=tk_ac_ref)), 1) # Fail the subtask of the subworkflow. - t1_t1_ex_db = wf_db_access.TaskExecution.query(workflow_execution=str(t1_wf_ex_db.id))[0] - t1_t1_lv_ac_db = lv_db_access.LiveAction.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ex_db = wf_db_access.TaskExecution.query( + workflow_execution=str(t1_wf_ex_db.id) + )[0] + t1_t1_lv_ac_db = lv_db_access.LiveAction.query( + task_execution=str(t1_t1_ex_db.id) + )[0] ac_svc.update_status(t1_t1_lv_ac_db, ac_const.LIVEACTION_STATUS_FAILED) - t1_t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_t1_ex_db.id))[0] + t1_t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_t1_ex_db.id) + )[0] self.assertEqual(t1_t1_ac_ex_db.status, ac_const.LIVEACTION_STATUS_FAILED) notifier.get_notifier().process(t1_t1_ac_ex_db) workflows.get_engine().process(t1_t1_ac_ex_db) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_rerun.py b/contrib/runners/orquesta_runner/tests/unit/test_rerun.py index 59f2f94d08c..191f3a06814 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_rerun.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_rerun.py @@ -20,6 +20,7 @@ import st2tests import st2tests.config as tests_config + tests_config.parse_args() from local_runner import local_shell_command_runner @@ -41,41 +42,57 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] RUNNER_RESULT_FAILED = (action_constants.LIVEACTION_STATUS_FAILED, {}, {}) -RUNNER_RESULT_RUNNING = (action_constants.LIVEACTION_STATUS_RUNNING, {'stdout': '...'}, {}) -RUNNER_RESULT_SUCCEEDED = (action_constants.LIVEACTION_STATUS_SUCCEEDED, {'stdout': 'foobar'}, {}) +RUNNER_RESULT_RUNNING = ( + action_constants.LIVEACTION_STATUS_RUNNING, + {"stdout": "..."}, + {}, +) +RUNNER_RESULT_SUCCEEDED = ( + action_constants.LIVEACTION_STATUS_SUCCEEDED, + {"stdout": "foobar"}, + {}, +) @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestRunnerTest(st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(OrquestRunnerTest, cls).setUpClass() @@ -85,28 +102,35 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]), + ) def test_rerun_workflow(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) @@ -121,18 +145,15 @@ def test_rerun_workflow(self): self.assertEqual(ac_ex_db1.status, action_constants.LIVEACTION_STATUS_FAILED) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) # Assert the workflow reran ok and is running. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db2.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db2.id) + )[0] self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -140,33 +161,45 @@ def test_rerun_workflow(self): self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process task1 and make sure it succeeds. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_dbs = wf_db_access.TaskExecution.query(**query_filters) self.assertEqual(len(tk1_ex_dbs), 2) tk1_ex_dbs = sorted(tk1_ex_dbs, key=lambda x: x.start_timestamp) tk1_ex_db = tk1_ex_dbs[-1] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]), + ) def test_rerun_with_missing_workflow_execution_id(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) @@ -184,49 +217,52 @@ def test_rerun_with_missing_workflow_execution_id(self): wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False) # Manually delete the workflow_execution_id from context of the action execution. - lv_ac_db1.context.pop('workflow_execution') + lv_ac_db1.context.pop("workflow_execution") lv_ac_db1 = lv_db_access.LiveAction.add_or_update(lv_ac_db1, publish=False) ac_ex_db1 = execution_service.update_execution(lv_ac_db1, publish=False) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) expected_error = ( - 'Unable to rerun workflow execution because ' - 'workflow_execution_id is not provided.' + "Unable to rerun workflow execution because " + "workflow_execution_id is not provided." ) # Assert the workflow rerrun fails. lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"]) ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id)) self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"]) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]), + ) def test_rerun_with_invalid_workflow_execution(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) @@ -244,45 +280,50 @@ def test_rerun_with_invalid_workflow_execution(self): wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) expected_error = ( 'Unable to rerun workflow execution "%s" because ' - 'it does not exist.' % str(wf_ex_db.id) + "it does not exist." % str(wf_ex_db.id) ) # Assert the workflow rerrun fails. lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"]) ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id)) self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"]) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_RUNNING])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_RUNNING]), + ) def test_rerun_workflow_still_running(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING + ) # Assert workflow is still running. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) @@ -293,47 +334,52 @@ def test_rerun_workflow_still_running(self): self.assertEqual(ac_ex_db1.status, action_constants.LIVEACTION_STATUS_RUNNING) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) expected_error = ( 'Unable to rerun workflow execution "%s" because ' - 'it is not in a completed state.' % str(wf_ex_db.id) + "it is not in a completed state." % str(wf_ex_db.id) ) # Assert the workflow rerrun fails. lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"]) ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id)) self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"]) @mock.patch.object( - workflow_service, 'request_rerun', - mock.MagicMock(side_effect=Exception('Unexpected.'))) + workflow_service, + "request_rerun", + mock.MagicMock(side_effect=Exception("Unexpected.")), + ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED]), + ) def test_rerun_with_unexpected_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) @@ -351,62 +397,75 @@ def test_rerun_with_unexpected_error(self): wf_db_access.WorkflowExecution.delete(wf_ex_db, publish=False) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) - expected_error = 'Unexpected.' + expected_error = "Unexpected." # Assert the workflow rerrun fails. lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, lv_ac_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, lv_ac_db2.result["errors"][0]["message"]) ac_ex_db2 = ex_db_access.ActionExecution.get_by_id(str(ac_ex_db2.id)) self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertEqual(expected_error, ac_ex_db2.result['errors'][0]['message']) + self.assertEqual(expected_error, ac_ex_db2.result["errors"][0]["message"]) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_SUCCEEDED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_SUCCEEDED), + ) def test_rerun_workflow_already_succeeded(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - wf_input = {'who': 'Thanos'} - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + wf_input = {"who": "Thanos"} + lv_ac_db1 = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db1, ac_ex_db1 = action_service.request(lv_ac_db1) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db1.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db1.id) + )[0] # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) # Process task2. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) - self.assertEqual(tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk2_ac_ex_db) tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id) self.assertEqual(tk2_ex_db.status, wf_statuses.SUCCEEDED) # Process task3. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) - self.assertEqual(tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk3_ac_ex_db) tk3_ex_db = wf_db_access.TaskExecution.get_by_id(tk3_ex_db.id) self.assertEqual(tk3_ex_db.status, wf_statuses.SUCCEEDED) @@ -420,18 +479,15 @@ def test_rerun_workflow_already_succeeded(self): self.assertEqual(ac_ex_db1.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) # Rerun the execution. - context = { - 're-run': { - 'ref': str(ac_ex_db1.id), - 'tasks': ['task1'] - } - } - - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name'], context=context) + context = {"re-run": {"ref": str(ac_ex_db1.id), "tasks": ["task1"]}} + + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"], context=context) lv_ac_db2, ac_ex_db2 = action_service.request(lv_ac_db2) # Assert the workflow reran ok and is running. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db2.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db2.id) + )[0] self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) lv_ac_db2 = lv_db_access.LiveAction.get_by_id(str(lv_ac_db2.id)) self.assertEqual(lv_ac_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -439,40 +495,52 @@ def test_rerun_workflow_already_succeeded(self): self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) # Assert there are two task1 and the last entry succeeded. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_dbs = wf_db_access.TaskExecution.query(**query_filters) self.assertEqual(len(tk1_ex_dbs), 2) tk1_ex_dbs = sorted(tk1_ex_dbs, key=lambda x: x.start_timestamp) tk1_ex_db = tk1_ex_dbs[-1] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) - self.assertEqual(tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk1_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk1_ac_ex_db) tk1_ex_db = wf_db_access.TaskExecution.get_by_id(tk1_ex_db.id) self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) # Assert there are two task2 and the last entry succeeded. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} tk2_ex_dbs = wf_db_access.TaskExecution.query(**query_filters) self.assertEqual(len(tk2_ex_dbs), 2) tk2_ex_dbs = sorted(tk2_ex_dbs, key=lambda x: x.start_timestamp) tk2_ex_db = tk2_ex_dbs[-1] - tk2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk2_ex_db.id))[0] - tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction['id']) - self.assertEqual(tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk2_ex_db.id) + )[0] + tk2_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk2_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk2_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk2_ac_ex_db) tk2_ex_db = wf_db_access.TaskExecution.get_by_id(tk2_ex_db.id) self.assertEqual(tk2_ex_db.status, wf_statuses.SUCCEEDED) # Assert there are two task3 and the last entry succeeded. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} tk3_ex_dbs = wf_db_access.TaskExecution.query(**query_filters) self.assertEqual(len(tk3_ex_dbs), 2) tk3_ex_dbs = sorted(tk3_ex_dbs, key=lambda x: x.start_timestamp) tk3_ex_db = tk3_ex_dbs[-1] - tk3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk3_ex_db.id))[0] - tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction['id']) - self.assertEqual(tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + tk3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk3_ex_db.id) + )[0] + tk3_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk3_ac_ex_db.liveaction["id"]) + self.assertEqual( + tk3_lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) workflow_service.handle_action_execution_completion(tk3_ac_ex_db) tk3_ex_db = wf_db_access.TaskExecution.get_by_id(tk3_ex_db.id) self.assertEqual(tk3_ex_db.status, wf_statuses.SUCCEEDED) diff --git a/contrib/runners/orquesta_runner/tests/unit/test_with_items.py b/contrib/runners/orquesta_runner/tests/unit/test_with_items.py index cc0846d733b..66768745862 100644 --- a/contrib/runners/orquesta_runner/tests/unit/test_with_items.py +++ b/contrib/runners/orquesta_runner/tests/unit/test_with_items.py @@ -25,6 +25,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from tests.unit import base @@ -48,37 +49,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaWithItemsTest(st2tests.ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(OrquestaWithItemsTest, cls).setUpClass() @@ -88,8 +97,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -101,35 +109,34 @@ def get_runner_class(cls, runner_name): def set_execution_status(self, lv_ac_db_id, status): lv_ac_db = action_utils.update_liveaction_status( - status=status, - liveaction_id=lv_ac_db_id, - publish=False + status=status, liveaction_id=lv_ac_db_id, publish=False ) - ac_ex_db = execution_service.update_execution( - lv_ac_db, - publish=False - ) + ac_ex_db = execution_service.update_execution(lv_ac_db, publish=False) return lv_ac_db, ac_ex_db def test_with_items(self): num_items = 3 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "with-items.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process the with items task. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(len(t1_ac_ex_dbs), num_items) @@ -155,20 +162,26 @@ def test_with_items(self): def test_with_items_failure(self): num_items = 10 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-failure.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-failure.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process the with items task. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(len(t1_ac_ex_dbs), num_items) @@ -195,52 +208,68 @@ def test_with_items_failure(self): def test_with_items_empty_list(self): items = [] num_items = len(items) - wf_input = {'members': items} + wf_input = {"members": items} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, "with-items.yaml") + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Wait for the liveaction to complete. - lv_ac_db = self._wait_on_status(lv_ac_db, action_constants.LIVEACTION_STATUS_SUCCEEDED) + lv_ac_db = self._wait_on_status( + lv_ac_db, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) # Retrieve records from database. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) # Ensure there is no action executions for the task and the task is already completed. self.assertEqual(len(t1_ac_ex_dbs), num_items) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) - self.assertDictEqual(t1_ex_db.result, {'items': []}) + self.assertDictEqual(t1_ex_db.result, {"items": []}) # Assert the main workflow is completed. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED) lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertDictEqual(lv_ac_db.result, {'output': {'items': []}}) + self.assertDictEqual(lv_ac_db.result, {"output": {"items": []}}) def test_with_items_concurrency(self): num_items = 3 concurrency = 2 - wf_input = {'concurrency': concurrency} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"concurrency": concurrency} + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process the first set of action executions from with items concurrency. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(len(t1_ac_ex_dbs), concurrency) @@ -261,7 +290,9 @@ def test_with_items_concurrency(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Process the second set of action executions from with items concurrency. - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(len(t1_ac_ex_dbs), num_items) @@ -287,30 +318,37 @@ def test_with_items_concurrency(self): def test_with_items_cancellation(self): num_items = 3 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert the workflow execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(len(t1_ac_ex_dbs), num_items) # Reset the action executions to running status. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_RUNNING + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING @@ -328,11 +366,12 @@ def test_with_items_cancellation(self): # Manually succeed the action executions and process completion. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_SUCCEEDED + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED @@ -353,31 +392,40 @@ def test_with_items_cancellation(self): def test_with_items_concurrency_cancellation(self): concurrency = 2 - wf_input = {'concurrency': concurrency} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"concurrency": concurrency} + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert the workflow execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(len(t1_ac_ex_dbs), concurrency) # Reset the action executions to running status. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_RUNNING + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING @@ -395,11 +443,12 @@ def test_with_items_concurrency_cancellation(self): # Manually succeed the action executions and process completion. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_SUCCEEDED + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED @@ -420,30 +469,37 @@ def test_with_items_concurrency_cancellation(self): def test_with_items_pause_and_resume(self): num_items = 3 - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert the workflow execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(len(t1_ac_ex_dbs), num_items) # Reset the action executions to running status. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_RUNNING + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING @@ -461,11 +517,12 @@ def test_with_items_pause_and_resume(self): # Manually succeed the action executions and process completion. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_SUCCEEDED + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED @@ -498,31 +555,40 @@ def test_with_items_concurrency_pause_and_resume(self): num_items = 3 concurrency = 2 - wf_input = {'concurrency': concurrency} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-concurrency.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"concurrency": concurrency} + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-concurrency.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert the workflow execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(t1_ex_db.status, wf_statuses.RUNNING) self.assertEqual(len(t1_ac_ex_dbs), concurrency) # Reset the action executions to running status. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_RUNNING + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_RUNNING ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_RUNNING @@ -540,11 +606,12 @@ def test_with_items_concurrency_pause_and_resume(self): # Manually succeed the action executions and process completion. for ac_ex in t1_ac_ex_dbs: self.set_execution_status( - ac_ex.liveaction['id'], - action_constants.LIVEACTION_STATUS_SUCCEEDED + ac_ex.liveaction["id"], action_constants.LIVEACTION_STATUS_SUCCEEDED ) - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) status = [ ac_ex.status == action_constants.LIVEACTION_STATUS_SUCCEEDED @@ -572,7 +639,9 @@ def test_with_items_concurrency_pause_and_resume(self): self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Check new set of action execution is scheduled. - t1_ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id)) + t1_ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + ) self.assertEqual(len(t1_ac_ex_dbs), num_items) # Manually process the last action execution. @@ -585,20 +654,34 @@ def test_with_items_concurrency_pause_and_resume(self): self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) def test_subworkflow_with_items_empty_list(self): - wf_input = {'members': []} - wf_meta = base.get_wf_fixture_meta_data(TEST_PACK_PATH, 'with-items-empty-parent.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters=wf_input) + wf_input = {"members": []} + wf_meta = base.get_wf_fixture_meta_data( + TEST_PACK_PATH, "with-items-empty-parent.yaml" + ) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters=wf_input + ) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Identify the records for the main workflow. - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] - tk_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] + tk_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(tk_ex_dbs), 1) # Identify the records for the tasks. - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk_ex_dbs[0].id))[0] - t1_wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(t1_ac_ex_db.id))[0] - self.assertEqual(t1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk_ex_dbs[0].id) + )[0] + t1_wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(t1_ac_ex_db.id) + )[0] + self.assertEqual( + t1_ac_ex_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) self.assertEqual(t1_wf_ex_db.status, wf_statuses.SUCCEEDED) # Manually processing completion of the subworkflow in task1. diff --git a/contrib/runners/python_runner/dist_utils.py b/contrib/runners/python_runner/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/contrib/runners/python_runner/dist_utils.py +++ b/contrib/runners/python_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/python_runner/python_runner/__init__.py b/contrib/runners/python_runner/python_runner/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/contrib/runners/python_runner/python_runner/__init__.py +++ b/contrib/runners/python_runner/python_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/python_runner/python_runner/python_action_wrapper.py b/contrib/runners/python_runner/python_runner/python_action_wrapper.py index 119f6bdf847..b9ae0757b3a 100644 --- a/contrib/runners/python_runner/python_runner/python_action_wrapper.py +++ b/contrib/runners/python_runner/python_runner/python_action_wrapper.py @@ -18,7 +18,8 @@ # Ignore CryptographyDeprecationWarning warnings which appear on older versions of Python 2.7 import warnings from cryptography.utils import CryptographyDeprecationWarning -warnings.filterwarnings('ignore', category=CryptographyDeprecationWarning) + +warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning) import os import sys @@ -33,8 +34,8 @@ # lives gets added to sys.path and we don't want that. # Note: We need to use just the suffix, because full path is different depending if the process # is ran in virtualenv or not -RUNNERS_PATH_SUFFIX = 'st2common/runners' -if __name__ == '__main__': +RUNNERS_PATH_SUFFIX = "st2common/runners" +if __name__ == "__main__": script_path = sys.path[0] if RUNNERS_PATH_SUFFIX in script_path: sys.path.pop(0) @@ -61,10 +62,7 @@ from st2common.constants.runners import PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE from st2common.constants.runners import PYTHON_RUNNER_DEFAULT_LOG_LEVEL -__all__ = [ - 'PythonActionWrapper', - 'ActionService' -] +__all__ = ["PythonActionWrapper", "ActionService"] LOG = logging.getLogger(__name__) @@ -104,15 +102,18 @@ def datastore_service(self): # duration of the action lifetime action_name = self._action_wrapper._class_name log_level = self._action_wrapper._log_level - logger = get_logger_for_python_runner_action(action_name=action_name, - log_level=log_level) + logger = get_logger_for_python_runner_action( + action_name=action_name, log_level=log_level + ) pack_name = self._action_wrapper._pack class_name = self._action_wrapper._class_name - auth_token = os.environ.get('ST2_ACTION_AUTH_TOKEN', None) - self._datastore_service = ActionDatastoreService(logger=logger, - pack_name=pack_name, - class_name=class_name, - auth_token=auth_token) + auth_token = os.environ.get("ST2_ACTION_AUTH_TOKEN", None) + self._datastore_service = ActionDatastoreService( + logger=logger, + pack_name=pack_name, + class_name=class_name, + auth_token=auth_token, + ) return self._datastore_service ################################## @@ -130,20 +131,32 @@ def list_values(self, local=True, prefix=None): return self.datastore_service.list_values(local=local, prefix=prefix) def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False): - return self.datastore_service.get_value(name=name, local=local, scope=scope, - decrypt=decrypt) + return self.datastore_service.get_value( + name=name, local=local, scope=scope, decrypt=decrypt + ) - def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False): - return self.datastore_service.set_value(name=name, value=value, ttl=ttl, local=local, - scope=scope, encrypt=encrypt) + def set_value( + self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False + ): + return self.datastore_service.set_value( + name=name, value=value, ttl=ttl, local=local, scope=scope, encrypt=encrypt + ) def delete_value(self, name, local=True, scope=SYSTEM_SCOPE): return self.datastore_service.delete_value(name=name, local=local, scope=scope) class PythonActionWrapper(object): - def __init__(self, pack, file_path, config=None, parameters=None, user=None, parent_args=None, - log_level=PYTHON_RUNNER_DEFAULT_LOG_LEVEL): + def __init__( + self, + pack, + file_path, + config=None, + parameters=None, + user=None, + parent_args=None, + log_level=PYTHON_RUNNER_DEFAULT_LOG_LEVEL, + ): """ :param pack: Name of the pack this action belongs to. :type pack: ``str`` @@ -173,19 +186,22 @@ def __init__(self, pack, file_path, config=None, parameters=None, user=None, par self._log_level = log_level self._class_name = None - self._logger = logging.getLogger('PythonActionWrapper') + self._logger = logging.getLogger("PythonActionWrapper") try: st2common_config.parse_args(args=self._parent_args) except Exception as e: - LOG.debug('Failed to parse config using parent args (parent_args=%s): %s' % - (str(self._parent_args), six.text_type(e))) + LOG.debug( + "Failed to parse config using parent args (parent_args=%s): %s" + % (str(self._parent_args), six.text_type(e)) + ) # Note: We can only set a default user value if one is not provided after parsing the # config if not self._user: # Note: We use late import to avoid performance overhead from oslo_config import cfg + self._user = cfg.CONF.system_user.user def run(self): @@ -201,26 +217,25 @@ def run(self): action_status = None action_result = output - action_output = { - 'result': action_result, - 'status': None - } + action_output = {"result": action_result, "status": None} if action_status is not None and not isinstance(action_status, bool): - sys.stderr.write('Status returned from the action run() method must either be ' - 'True or False, got: %s\n' % (action_status)) + sys.stderr.write( + "Status returned from the action run() method must either be " + "True or False, got: %s\n" % (action_status) + ) sys.stderr.write(INVALID_STATUS_ERROR_MESSAGE) sys.exit(PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE) if action_status is not None and isinstance(action_status, bool): - action_output['status'] = action_status + action_output["status"] = action_status # Special case if result object is not JSON serializable - aka user wanted to return a # non-simple type (e.g. class instance or other non-JSON serializable type) try: - json.dumps(action_output['result']) + json.dumps(action_output["result"]) except TypeError: - action_output['result'] = str(action_output['result']) + action_output["result"] = str(action_output["result"]) try: print_output = json.dumps(action_output) @@ -229,7 +244,7 @@ def run(self): # Print output to stdout so the parent can capture it sys.stdout.write(ACTION_OUTPUT_RESULT_DELIMITER) - sys.stdout.write(print_output + '\n') + sys.stdout.write(print_output + "\n") sys.stdout.write(ACTION_OUTPUT_RESULT_DELIMITER) sys.stdout.flush() @@ -238,17 +253,22 @@ def _get_action_instance(self): actions_cls = action_loader.register_plugin(Action, self._file_path) except Exception as e: tb_msg = traceback.format_exc() - msg = ('Failed to load action class from file "%s" (action file most likely doesn\'t ' - 'exist or contains invalid syntax): %s' % (self._file_path, six.text_type(e))) - msg += '\n\n' + tb_msg + msg = ( + 'Failed to load action class from file "%s" (action file most likely doesn\'t ' + "exist or contains invalid syntax): %s" + % (self._file_path, six.text_type(e)) + ) + msg += "\n\n" + tb_msg exc_cls = type(e) raise exc_cls(msg) action_cls = actions_cls[0] if actions_cls and len(actions_cls) > 0 else None if not action_cls: - raise Exception('File "%s" has no action class or the file doesn\'t exist.' % - (self._file_path)) + raise Exception( + 'File "%s" has no action class or the file doesn\'t exist.' + % (self._file_path) + ) # Retrieve name of the action class # Note - we need to either use cls.__name_ or inspect.getmro(cls)[0].__name__ to @@ -256,31 +276,45 @@ def _get_action_instance(self): self._class_name = action_cls.__name__ action_service = ActionService(action_wrapper=self) - action_instance = get_action_class_instance(action_cls=action_cls, - config=self._config, - action_service=action_service) + action_instance = get_action_class_instance( + action_cls=action_cls, config=self._config, action_service=action_service + ) return action_instance -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Python action runner process wrapper') - parser.add_argument('--pack', required=True, - help='Name of the pack this action belongs to') - parser.add_argument('--file-path', required=True, - help='Path to the action module') - parser.add_argument('--config', required=False, - help='Pack config serialized as JSON') - parser.add_argument('--parameters', required=False, - help='Serialized action parameters') - parser.add_argument('--stdin-parameters', required=False, action='store_true', - help='Serialized action parameters via stdin') - parser.add_argument('--user', required=False, - help='User who triggered the action execution') - parser.add_argument('--parent-args', required=False, - help='Command line arguments passed to the parent process serialized as ' - ' JSON') - parser.add_argument('--log-level', required=False, default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL, - help='Log level for actions') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Python action runner process wrapper") + parser.add_argument( + "--pack", required=True, help="Name of the pack this action belongs to" + ) + parser.add_argument("--file-path", required=True, help="Path to the action module") + parser.add_argument( + "--config", required=False, help="Pack config serialized as JSON" + ) + parser.add_argument( + "--parameters", required=False, help="Serialized action parameters" + ) + parser.add_argument( + "--stdin-parameters", + required=False, + action="store_true", + help="Serialized action parameters via stdin", + ) + parser.add_argument( + "--user", required=False, help="User who triggered the action execution" + ) + parser.add_argument( + "--parent-args", + required=False, + help="Command line arguments passed to the parent process serialized as " + " JSON", + ) + parser.add_argument( + "--log-level", + required=False, + default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL, + help="Log level for actions", + ) args = parser.parse_args() config = json.loads(args.config) if args.config else {} @@ -289,46 +323,54 @@ def _get_action_instance(self): log_level = args.log_level if not isinstance(config, dict): - raise ValueError('Pack config needs to be a dictionary') + raise ValueError("Pack config needs to be a dictionary") parameters = {} if args.parameters: - LOG.debug('Getting parameters from argument') + LOG.debug("Getting parameters from argument") args_parameters = args.parameters args_parameters = json.loads(args_parameters) if args_parameters else {} parameters.update(args_parameters) if args.stdin_parameters: - LOG.debug('Getting parameters from stdin') + LOG.debug("Getting parameters from stdin") i, _, _ = select.select([sys.stdin], [], [], READ_STDIN_INPUT_TIMEOUT) if not i: - raise ValueError(('No input received and timed out while waiting for ' - 'parameters from stdin')) + raise ValueError( + ( + "No input received and timed out while waiting for " + "parameters from stdin" + ) + ) stdin_data = sys.stdin.readline().strip() try: stdin_parameters = json.loads(stdin_data) - stdin_parameters = stdin_parameters.get('parameters', {}) + stdin_parameters = stdin_parameters.get("parameters", {}) except Exception as e: - msg = ('Failed to parse parameters from stdin. Expected a JSON object with ' - '"parameters" attribute: %s' % (six.text_type(e))) + msg = ( + "Failed to parse parameters from stdin. Expected a JSON object with " + '"parameters" attribute: %s' % (six.text_type(e)) + ) raise ValueError(msg) parameters.update(stdin_parameters) - LOG.debug('Received parameters: %s', parameters) + LOG.debug("Received parameters: %s", parameters) assert isinstance(parent_args, list) - obj = PythonActionWrapper(pack=args.pack, - file_path=args.file_path, - config=config, - parameters=parameters, - user=user, - parent_args=parent_args, - log_level=log_level) + obj = PythonActionWrapper( + pack=args.pack, + file_path=args.file_path, + config=config, + parameters=parameters, + user=user, + parent_args=parent_args, + log_level=log_level, + ) obj.run() diff --git a/contrib/runners/python_runner/python_runner/python_runner.py b/contrib/runners/python_runner/python_runner/python_runner.py index fd412c890e2..b11668e000a 100644 --- a/contrib/runners/python_runner/python_runner/python_runner.py +++ b/contrib/runners/python_runner/python_runner/python_runner.py @@ -58,34 +58,39 @@ from python_runner import python_action_wrapper __all__ = [ - 'PythonRunner', - - 'get_runner', - 'get_metadata', + "PythonRunner", + "get_runner", + "get_metadata", ] LOG = logging.getLogger(__name__) # constants to lookup in runner_parameters. -RUNNER_ENV = 'env' -RUNNER_TIMEOUT = 'timeout' -RUNNER_LOG_LEVEL = 'log_level' +RUNNER_ENV = "env" +RUNNER_TIMEOUT = "timeout" +RUNNER_LOG_LEVEL = "log_level" # Environment variables which can't be specified by the user BLACKLISTED_ENV_VARS = [ # We don't allow user to override PYTHONPATH since this would break things - 'pythonpath' + "pythonpath" ] BASE_DIR = os.path.dirname(os.path.abspath(python_action_wrapper.__file__)) -WRAPPER_SCRIPT_NAME = 'python_action_wrapper.py' +WRAPPER_SCRIPT_NAME = "python_action_wrapper.py" WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR, WRAPPER_SCRIPT_NAME) class PythonRunner(GitWorktreeActionRunner): - - def __init__(self, runner_id, config=None, timeout=PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT, - log_level=None, sandbox=True, use_parent_args=True): + def __init__( + self, + runner_id, + config=None, + timeout=PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT, + log_level=None, + sandbox=True, + use_parent_args=True, + ): """ :param timeout: Action execution timeout in seconds. @@ -123,36 +128,42 @@ def pre_run(self): self._log_level = cfg.CONF.actionrunner.python_runner_log_level def run(self, action_parameters): - LOG.debug('Running pythonrunner.') - LOG.debug('Getting pack name.') + LOG.debug("Running pythonrunner.") + LOG.debug("Getting pack name.") pack = self.get_pack_ref() - LOG.debug('Getting user.') + LOG.debug("Getting user.") user = self.get_user() - LOG.debug('Serializing parameters.') - serialized_parameters = json.dumps(action_parameters if action_parameters else {}) - LOG.debug('Getting virtualenv_path.') + LOG.debug("Serializing parameters.") + serialized_parameters = json.dumps( + action_parameters if action_parameters else {} + ) + LOG.debug("Getting virtualenv_path.") virtualenv_path = get_sandbox_virtualenv_path(pack=pack) - LOG.debug('Getting python path.') + LOG.debug("Getting python path.") if self._sandbox: python_path = get_sandbox_python_binary_path(pack=pack) else: python_path = sys.executable - LOG.debug('Checking virtualenv path.') + LOG.debug("Checking virtualenv path.") if virtualenv_path and not os.path.isdir(virtualenv_path): - format_values = {'pack': pack, 'virtualenv_path': virtualenv_path} + format_values = {"pack": pack, "virtualenv_path": virtualenv_path} msg = PACK_VIRTUALENV_DOESNT_EXIST % format_values - LOG.error('virtualenv_path set but not a directory: %s', msg) + LOG.error("virtualenv_path set but not a directory: %s", msg) raise Exception(msg) - LOG.debug('Checking entry_point.') + LOG.debug("Checking entry_point.") if not self.entry_point: - LOG.error('Action "%s" is missing entry_point attribute' % (self.action.name)) - raise Exception('Action "%s" is missing entry_point attribute' % (self.action.name)) + LOG.error( + 'Action "%s" is missing entry_point attribute' % (self.action.name) + ) + raise Exception( + 'Action "%s" is missing entry_point attribute' % (self.action.name) + ) # Note: We pass config as command line args so the actual wrapper process is standalone # and doesn't need access to db - LOG.debug('Setting args.') + LOG.debug("Setting args.") if self._use_parent_args: parent_args = json.dumps(sys.argv[1:]) @@ -161,12 +172,12 @@ def run(self, action_parameters): args = [ python_path, - '-u', # unbuffered mode so streaming mode works as expected + "-u", # unbuffered mode so streaming mode works as expected WRAPPER_SCRIPT_PATH, - '--pack=%s' % (pack), - '--file-path=%s' % (self.entry_point), - '--user=%s' % (user), - '--parent-args=%s' % (parent_args), + "--pack=%s" % (pack), + "--file-path=%s" % (self.entry_point), + "--user=%s" % (user), + "--parent-args=%s" % (parent_args), ] subprocess = concurrency.get_subprocess_module() @@ -178,35 +189,36 @@ def run(self, action_parameters): stdin_params = None if len(serialized_parameters) >= MAX_PARAM_LENGTH: stdin = subprocess.PIPE - LOG.debug('Parameters are too big...changing to stdin') + LOG.debug("Parameters are too big...changing to stdin") stdin_params = '{"parameters": %s}\n' % (serialized_parameters) - args.append('--stdin-parameters') + args.append("--stdin-parameters") else: - LOG.debug('Parameters are just right...adding them to arguments') - args.append('--parameters=%s' % (serialized_parameters)) + LOG.debug("Parameters are just right...adding them to arguments") + args.append("--parameters=%s" % (serialized_parameters)) if self._config: - args.append('--config=%s' % (json.dumps(self._config))) + args.append("--config=%s" % (json.dumps(self._config))) if self._log_level != PYTHON_RUNNER_DEFAULT_LOG_LEVEL: # We only pass --log-level parameter if non default log level value is specified - args.append('--log-level=%s' % (self._log_level)) + args.append("--log-level=%s" % (self._log_level)) # We need to ensure all the st2 dependencies are also available to the subprocess - LOG.debug('Setting env.') + LOG.debug("Setting env.") env = os.environ.copy() - env['PATH'] = get_sandbox_path(virtualenv_path=virtualenv_path) + env["PATH"] = get_sandbox_path(virtualenv_path=virtualenv_path) sandbox_python_path = get_sandbox_python_path_for_python_action( - pack=pack, - inherit_from_parent=True, - inherit_parent_virtualenv=True) + pack=pack, inherit_from_parent=True, inherit_parent_virtualenv=True + ) if self._enable_common_pack_libs: try: pack_common_libs_path = self._get_pack_common_libs_path(pack_ref=pack) except Exception as e: - LOG.debug('Failed to retrieve pack common lib path: %s' % (six.text_type(e))) + LOG.debug( + "Failed to retrieve pack common lib path: %s" % (six.text_type(e)) + ) # There is no MongoDB connection available in Lambda and pack common lib # functionality is not also mandatory for Lambda so we simply ignore those errors. # Note: We should eventually refactor this code to make runner standalone and not @@ -217,13 +229,13 @@ def run(self, action_parameters): pack_common_libs_path = None # Remove leading : (if any) - if sandbox_python_path.startswith(':'): + if sandbox_python_path.startswith(":"): sandbox_python_path = sandbox_python_path[1:] if self._enable_common_pack_libs and pack_common_libs_path: - sandbox_python_path = pack_common_libs_path + ':' + sandbox_python_path + sandbox_python_path = pack_common_libs_path + ":" + sandbox_python_path - env['PYTHONPATH'] = sandbox_python_path + env["PYTHONPATH"] = sandbox_python_path # Include user provided environment variables (if any) user_env_vars = self._get_env_vars() @@ -238,40 +250,53 @@ def run(self, action_parameters): stdout = StringIO() stderr = StringIO() - store_execution_stdout_line = functools.partial(store_execution_output_data, - output_type='stdout') - store_execution_stderr_line = functools.partial(store_execution_output_data, - output_type='stderr') - - read_and_store_stdout = make_read_and_store_stream_func(execution_db=self.execution, - action_db=self.action, store_data_func=store_execution_stdout_line) - read_and_store_stderr = make_read_and_store_stream_func(execution_db=self.execution, - action_db=self.action, store_data_func=store_execution_stderr_line) + store_execution_stdout_line = functools.partial( + store_execution_output_data, output_type="stdout" + ) + store_execution_stderr_line = functools.partial( + store_execution_output_data, output_type="stderr" + ) + + read_and_store_stdout = make_read_and_store_stream_func( + execution_db=self.execution, + action_db=self.action, + store_data_func=store_execution_stdout_line, + ) + read_and_store_stderr = make_read_and_store_stream_func( + execution_db=self.execution, + action_db=self.action, + store_data_func=store_execution_stderr_line, + ) command_string = list2cmdline(args) if stdin_params: - command_string = 'echo %s | %s' % (quote_unix(stdin_params), command_string) + command_string = "echo %s | %s" % (quote_unix(stdin_params), command_string) bufsize = cfg.CONF.actionrunner.stream_output_buffer_size - LOG.debug('Running command (bufsize=%s): PATH=%s PYTHONPATH=%s %s' % (bufsize, env['PATH'], - env['PYTHONPATH'], - command_string)) - exit_code, stdout, stderr, timed_out = run_command(cmd=args, - stdin=stdin, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=False, - env=env, - timeout=self._timeout, - read_stdout_func=read_and_store_stdout, - read_stderr_func=read_and_store_stderr, - read_stdout_buffer=stdout, - read_stderr_buffer=stderr, - stdin_value=stdin_params, - bufsize=bufsize) - LOG.debug('Returning values: %s, %s, %s, %s', exit_code, stdout, stderr, timed_out) - LOG.debug('Returning.') + LOG.debug( + "Running command (bufsize=%s): PATH=%s PYTHONPATH=%s %s" + % (bufsize, env["PATH"], env["PYTHONPATH"], command_string) + ) + exit_code, stdout, stderr, timed_out = run_command( + cmd=args, + stdin=stdin, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + env=env, + timeout=self._timeout, + read_stdout_func=read_and_store_stdout, + read_stderr_func=read_and_store_stderr, + read_stdout_buffer=stdout, + read_stderr_buffer=stderr, + stdin_value=stdin_params, + bufsize=bufsize, + ) + LOG.debug( + "Returning values: %s, %s, %s, %s", exit_code, stdout, stderr, timed_out + ) + LOG.debug("Returning.") return self._get_output_values(exit_code, stdout, stderr, timed_out) def _get_pack_common_libs_path(self, pack_ref): @@ -280,7 +305,9 @@ def _get_pack_common_libs_path(self, pack_ref): (if used). """ worktree_path = self.git_worktree_path - pack_common_libs_path = get_pack_common_libs_path_for_pack_ref(pack_ref=pack_ref) + pack_common_libs_path = get_pack_common_libs_path_for_pack_ref( + pack_ref=pack_ref + ) if not worktree_path: return pack_common_libs_path @@ -288,18 +315,20 @@ def _get_pack_common_libs_path(self, pack_ref): # Modify the path so it uses git worktree directory pack_base_path = get_pack_base_path(pack_name=pack_ref) - new_pack_common_libs_path = pack_common_libs_path.replace(pack_base_path, '') + new_pack_common_libs_path = pack_common_libs_path.replace(pack_base_path, "") # Remove leading slash (if any) - if new_pack_common_libs_path.startswith('/'): + if new_pack_common_libs_path.startswith("/"): new_pack_common_libs_path = new_pack_common_libs_path[1:] - new_pack_common_libs_path = os.path.join(worktree_path, new_pack_common_libs_path) + new_pack_common_libs_path = os.path.join( + worktree_path, new_pack_common_libs_path + ) # Check to prevent directory traversal common_prefix = os.path.commonprefix([worktree_path, new_pack_common_libs_path]) if common_prefix != worktree_path: - raise ValueError('pack libs path is not located inside the pack directory') + raise ValueError("pack libs path is not located inside the pack directory") return new_pack_common_libs_path @@ -312,7 +341,7 @@ def _get_output_values(self, exit_code, stdout, stderr, timed_out): :rtype: ``tuple`` """ if timed_out: - error = 'Action failed to complete in %s seconds' % (self._timeout) + error = "Action failed to complete in %s seconds" % (self._timeout) else: error = None @@ -335,16 +364,18 @@ def _get_output_values(self, exit_code, stdout, stderr, timed_out): action_result = json.loads(action_result) except Exception as e: # Failed to de-serialize the result, probably it contains non-simple type or similar - LOG.warning('Failed to de-serialize result "%s": %s' % (str(action_result), - six.text_type(e))) + LOG.warning( + 'Failed to de-serialize result "%s": %s' + % (str(action_result), six.text_type(e)) + ) if action_result: if isinstance(action_result, dict): - result = action_result.get('result', None) - status = action_result.get('status', None) + result = action_result.get("result", None) + status = action_result.get("status", None) else: # Failed to de-serialize action result aka result is a string - match = re.search("'result': (.*?)$", action_result or '') + match = re.search("'result': (.*?)$", action_result or "") if match: action_result = match.groups()[0] @@ -352,21 +383,22 @@ def _get_output_values(self, exit_code, stdout, stderr, timed_out): result = action_result status = None else: - result = 'None' + result = "None" status = None output = { - 'stdout': stdout, - 'stderr': stderr, - 'exit_code': exit_code, - 'result': result + "stdout": stdout, + "stderr": stderr, + "exit_code": exit_code, + "result": result, } if error: - output['error'] = error + output["error"] = error - status = self._get_final_status(action_status=status, timed_out=timed_out, - exit_code=exit_code) + status = self._get_final_status( + action_status=status, timed_out=timed_out, exit_code=exit_code + ) return (status, output, None) def _get_final_status(self, action_status, timed_out, exit_code): @@ -415,8 +447,10 @@ def _get_env_vars(self): to_delete.append(key) for key in to_delete: - LOG.debug('User specified environment variable "%s" which is being ignored...' % - (key)) + LOG.debug( + 'User specified environment variable "%s" which is being ignored...' + % (key) + ) del env_vars[key] return env_vars @@ -441,4 +475,4 @@ def get_runner(config=None): def get_metadata(): - return get_runner_metadata('python_runner')[0] + return get_runner_metadata("python_runner")[0] diff --git a/contrib/runners/python_runner/setup.py b/contrib/runners/python_runner/setup.py index c1a5d6c20a2..04e55a31c03 100644 --- a/contrib/runners/python_runner/setup.py +++ b/contrib/runners/python_runner/setup.py @@ -26,30 +26,30 @@ from python_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-python', + name="stackstorm-runner-python", version=__version__, - description='Python action runner for StackStorm event-driven automation platform', - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="Python action runner for StackStorm event-driven automation platform", + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'python_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"python_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'python-script = python_runner.python_runner', + "st2common.runners.runner": [ + "python-script = python_runner.python_runner", ], - } + }, ) diff --git a/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py b/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py index 27e42ecc5ab..e1d39361a2f 100644 --- a/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py +++ b/contrib/runners/python_runner/tests/integration/test_python_action_process_wrapper.py @@ -42,49 +42,53 @@ from st2common.util.shell import run_command from six.moves import range -__all__ = [ - 'PythonRunnerActionWrapperProcessTestCase' -] +__all__ = ["PythonRunnerActionWrapperProcessTestCase"] # Maximum limit for the process wrapper script execution time (in seconds) WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT = 0.31 -ASSERTION_ERROR_MESSAGE = (""" +ASSERTION_ERROR_MESSAGE = """ Python wrapper process script took more than %s seconds to execute (%s). This most likely means that a direct or in-direct import of a module which takes a long time to load has been added (e.g. jsonschema, pecan, kombu, etc). Please review recently changed and added code for potential slow import issues and refactor / re-organize code if possible. -""".strip()) +""".strip() BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR, - '../../../python_runner/python_runner/python_action_wrapper.py') +WRAPPER_SCRIPT_PATH = os.path.join( + BASE_DIR, "../../../python_runner/python_runner/python_action_wrapper.py" +) WRAPPER_SCRIPT_PATH = os.path.abspath(WRAPPER_SCRIPT_PATH) -TIME_BINARY_PATH = find_executable('time') +TIME_BINARY_PATH = find_executable("time") TIME_BINARY_AVAILABLE = TIME_BINARY_PATH is not None -@unittest2.skipIf(not TIME_BINARY_PATH, 'time binary not available') +@unittest2.skipIf(not TIME_BINARY_PATH, "time binary not available") class PythonRunnerActionWrapperProcessTestCase(unittest2.TestCase): def test_process_wrapper_exits_in_reasonable_timeframe(self): # 1. Verify wrapper script path is correct and file exists self.assertTrue(os.path.isfile(WRAPPER_SCRIPT_PATH)) # 2. First run it without time to verify path is valid - command_string = 'python %s --file-path=foo.py' % (WRAPPER_SCRIPT_PATH) + command_string = "python %s --file-path=foo.py" % (WRAPPER_SCRIPT_PATH) _, _, stderr = run_command(command_string, shell=True) - self.assertIn('usage: python_action_wrapper.py', stderr) + self.assertIn("usage: python_action_wrapper.py", stderr) - expected_msg_1 = 'python_action_wrapper.py: error: argument --pack is required' - expected_msg_2 = ('python_action_wrapper.py: error: the following arguments are ' - 'required: --pack') + expected_msg_1 = "python_action_wrapper.py: error: argument --pack is required" + expected_msg_2 = ( + "python_action_wrapper.py: error: the following arguments are " + "required: --pack" + ) self.assertTrue(expected_msg_1 in stderr or expected_msg_2 in stderr) # 3. Now time it - command_string = '%s -f "%%e" python %s' % (TIME_BINARY_PATH, WRAPPER_SCRIPT_PATH) + command_string = '%s -f "%%e" python %s' % ( + TIME_BINARY_PATH, + WRAPPER_SCRIPT_PATH, + ) # Do multiple runs and average it run_times = [] @@ -92,14 +96,18 @@ def test_process_wrapper_exits_in_reasonable_timeframe(self): count = 8 for i in range(0, count): _, _, stderr = run_command(command_string, shell=True) - stderr = stderr.strip().split('\n')[-1] + stderr = stderr.strip().split("\n")[-1] run_time_seconds = float(stderr) run_times.append(run_time_seconds) - avg_run_time_seconds = (sum(run_times) / count) - assertion_msg = ASSERTION_ERROR_MESSAGE % (WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT, - avg_run_time_seconds) - self.assertTrue(avg_run_time_seconds <= WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT, assertion_msg) + avg_run_time_seconds = sum(run_times) / count + assertion_msg = ASSERTION_ERROR_MESSAGE % ( + WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT, + avg_run_time_seconds, + ) + self.assertTrue( + avg_run_time_seconds <= WRAPPER_PROCESS_RUN_TIME_UPPER_LIMIT, assertion_msg + ) def test_config_with_a_lot_of_items_and_a_lot_of_parameters_work_fine(self): # Test case which verifies that actions with large configs and a lot of parameters work @@ -107,48 +115,55 @@ def test_config_with_a_lot_of_items_and_a_lot_of_parameters_work_fine(self): # upper limit on the size. config = {} for index in range(0, 50): - config['key_%s' % (index)] = 'value value foo %s' % (index) + config["key_%s" % (index)] = "value value foo %s" % (index) config = json.dumps(config) parameters = {} for index in range(0, 30): - parameters['param_foo_%s' % (index)] = 'some param value %s' % (index) + parameters["param_foo_%s" % (index)] = "some param value %s" % (index) parameters = json.dumps(parameters) - file_path = os.path.join(BASE_DIR, '../../../../examples/actions/noop.py') + file_path = os.path.join(BASE_DIR, "../../../../examples/actions/noop.py") - command_string = ('python %s --pack=dummy --file-path=%s --config=\'%s\' ' - '--parameters=\'%s\'' % - (WRAPPER_SCRIPT_PATH, file_path, config, parameters)) + command_string = ( + "python %s --pack=dummy --file-path=%s --config='%s' " + "--parameters='%s'" % (WRAPPER_SCRIPT_PATH, file_path, config, parameters) + ) exit_code, stdout, stderr = run_command(command_string, shell=True) self.assertEqual(exit_code, 0) self.assertIn('"status"', stdout) def test_stdin_params_timeout_no_stdin_data_provided(self): config = {} - file_path = os.path.join(BASE_DIR, '../../../../examples/actions/noop.py') + file_path = os.path.join(BASE_DIR, "../../../../examples/actions/noop.py") # try running in a sub-shell to ensure that the stdin is empty - command_string = ('python %s --pack=dummy --file-path=%s --config=\'%s\' ' - '--stdin-parameters' % - (WRAPPER_SCRIPT_PATH, file_path, config)) + command_string = ( + "python %s --pack=dummy --file-path=%s --config='%s' " + "--stdin-parameters" % (WRAPPER_SCRIPT_PATH, file_path, config) + ) exit_code, stdout, stderr = run_command(command_string, shell=True) - expected_msg = ('ValueError: No input received and timed out while waiting for parameters ' - 'from stdin') + expected_msg = ( + "ValueError: No input received and timed out while waiting for parameters " + "from stdin" + ) self.assertEqual(exit_code, 1) self.assertIn(expected_msg, stderr) def test_stdin_params_invalid_format_friendly_error(self): config = {} - file_path = os.path.join(BASE_DIR, '../../../contrib/examples/actions/noop.py') + file_path = os.path.join(BASE_DIR, "../../../contrib/examples/actions/noop.py") # Not a valid JSON string - command_string = ('echo "invalid" | python %s --pack=dummy --file-path=%s --config=\'%s\' ' - '--stdin-parameters' % - (WRAPPER_SCRIPT_PATH, file_path, config)) + command_string = ( + "echo \"invalid\" | python %s --pack=dummy --file-path=%s --config='%s' " + "--stdin-parameters" % (WRAPPER_SCRIPT_PATH, file_path, config) + ) exit_code, stdout, stderr = run_command(command_string, shell=True) - expected_msg = ('ValueError: Failed to parse parameters from stdin. Expected a JSON ' - 'object with "parameters" attribute') + expected_msg = ( + "ValueError: Failed to parse parameters from stdin. Expected a JSON " + 'object with "parameters" attribute' + ) self.assertEqual(exit_code, 1) self.assertIn(expected_msg, stderr) diff --git a/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py b/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py index 328a4a0fc0a..a6d300be23e 100644 --- a/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py +++ b/contrib/runners/python_runner/tests/integration/test_pythonrunner_behavior.py @@ -30,13 +30,12 @@ from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'PythonRunnerBehaviorTestCase' -] +__all__ = ["PythonRunnerBehaviorTestCase"] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR, - '../../../python_runner/python_runner/python_action_wrapper.py') +WRAPPER_SCRIPT_PATH = os.path.join( + BASE_DIR, "../../../python_runner/python_runner/python_action_wrapper.py" +) WRAPPER_SCRIPT_PATH = os.path.abspath(WRAPPER_SCRIPT_PATH) @@ -46,24 +45,24 @@ def setUp(self): config.parse_args() dir_path = tempfile.mkdtemp() - cfg.CONF.set_override(name='base_path', override=dir_path, group='system') + cfg.CONF.set_override(name="base_path", override=dir_path, group="system") self.base_path = dir_path - self.virtualenvs_path = os.path.join(self.base_path, 'virtualenvs/') + self.virtualenvs_path = os.path.join(self.base_path, "virtualenvs/") # Make sure dir is deleted on tearDown self.to_delete_directories.append(self.base_path) def test_priority_of_loading_library_after_setup_pack_virtualenv(self): - ''' + """ This test checks priority of loading library, whether the library which is specified in the 'requirements.txt' of pack is loaded when a same name module is also specified in the 'requirements.txt' of st2, at a subprocess in ActionRunner. To test above, this uses 'get_library_path.py' action in 'test_library_dependencies' pack. This action returns file-path of imported module which is specified by 'module' parameter. - ''' - pack_name = 'test_library_dependencies' + """ + pack_name = "test_library_dependencies" # Before calling action, this sets up virtualenv for test pack. This pack has # requirements.txt wihch only writes 'six' module. @@ -72,20 +71,25 @@ def test_priority_of_loading_library_after_setup_pack_virtualenv(self): # This test suite expects that loaded six module is located under the virtualenv library, # because 'six' is written in the requirements.txt of 'test_library_dependencies' pack. - (_, output, _) = self._run_action(pack_name, 'get_library_path.py', {'module': 'six'}) - self.assertEqual(output['result'].find(self.virtualenvs_path), 0) + (_, output, _) = self._run_action( + pack_name, "get_library_path.py", {"module": "six"} + ) + self.assertEqual(output["result"].find(self.virtualenvs_path), 0) # Conversely, this expects that 'mock' module file-path is not under sandbox library, # but the parent process's library path, because that is not under the pack's virtualenv. - (_, output, _) = self._run_action(pack_name, 'get_library_path.py', {'module': 'mock'}) - self.assertEqual(output['result'].find(self.virtualenvs_path), -1) + (_, output, _) = self._run_action( + pack_name, "get_library_path.py", {"module": "mock"} + ) + self.assertEqual(output["result"].find(self.virtualenvs_path), -1) # While a module which is in the pack's virtualenv library is specified at 'module' # parameter of the action, this test suite expects that file-path under the parent's # library is returned when 'sandbox' parameter of PythonRunner is False. - (_, output, _) = self._run_action(pack_name, 'get_library_path.py', {'module': 'six'}, - {'_sandbox': False}) - self.assertEqual(output['result'].find(self.virtualenvs_path), -1) + (_, output, _) = self._run_action( + pack_name, "get_library_path.py", {"module": "six"}, {"_sandbox": False} + ) + self.assertEqual(output["result"].find(self.virtualenvs_path), -1) def _run_action(self, pack, action, params, runner_params={}): action_db = mock.Mock() @@ -99,7 +103,8 @@ def _run_action(self, pack, action, params, runner_params={}): for key, value in runner_params.items(): setattr(runner, key, value) - runner.entry_point = os.path.join(get_fixtures_base_path(), - 'packs/%s/actions/%s' % (pack, action)) + runner.entry_point = os.path.join( + get_fixtures_base_path(), "packs/%s/actions/%s" % (pack, action) + ) runner.pre_run() return runner.run(params) diff --git a/contrib/runners/python_runner/tests/unit/test_output_schema.py b/contrib/runners/python_runner/tests/unit/test_output_schema.py index 218ba669a60..218a8f0732b 100644 --- a/contrib/runners/python_runner/tests/unit/test_output_schema.py +++ b/contrib/runners/python_runner/tests/unit/test_output_schema.py @@ -33,15 +33,16 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -PASCAL_ROW_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/pascal_row.py') +PASCAL_ROW_ACTION_PATH = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/pascal_row.py" +) MOCK_SYS = mock.Mock() MOCK_SYS.argv = [] MOCK_SYS.executable = sys.executable MOCK_EXECUTION = mock.Mock() -MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b' +MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b" FAIL_SCHEMA = { "notvalid": { @@ -50,7 +51,7 @@ } -@mock.patch('python_runner.python_runner.sys', MOCK_SYS) +@mock.patch("python_runner.python_runner.sys", MOCK_SYS) class PythonRunnerTestCase(RunnerTestCase, CleanDbTestCase): register_packs = True register_pack_configs = True @@ -61,29 +62,23 @@ def setUpClass(cls): assert_submodules_are_checked_out() def test_adherence_to_output_schema(self): - config = self.loader(os.path.join(BASE_DIR, '../../runner.yaml')) + config = self.loader(os.path.join(BASE_DIR, "../../runner.yaml")) runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 5}) - output_schema._validate_runner( - config[0]['output_schema'], - output - ) + (status, output, _) = runner.run({"row_index": 5}) + output_schema._validate_runner(config[0]["output_schema"], output) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(output) - self.assertEqual(output['result'], [1, 5, 10, 10, 5, 1]) + self.assertEqual(output["result"], [1, 5, 10, 10, 5, 1]) def test_fail_incorrect_output_schema(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 5}) + (status, output, _) = runner.run({"row_index": 5}) with self.assertRaises(jsonschema.ValidationError): - output_schema._validate_runner( - FAIL_SCHEMA, - output - ) + output_schema._validate_runner(FAIL_SCHEMA, output) def _get_mock_runner_obj(self, pack=None, sandbox=None): runner = python_runner.get_runner() @@ -106,10 +101,8 @@ def _get_mock_action_obj(self): Pack gets set to the system pack so the action doesn't require a separate virtualenv. """ action = mock.Mock() - action.ref = 'dummy.action' + action.ref = "dummy.action" action.pack = SYSTEM_PACK_NAME - action.entry_point = 'foo.py' - action.runner_type = { - 'name': 'python-script' - } + action.entry_point = "foo.py" + action.runner_type = {"name": "python-script"} return action diff --git a/contrib/runners/python_runner/tests/unit/test_pythonrunner.py b/contrib/runners/python_runner/tests/unit/test_pythonrunner.py index 8d55f8262d8..940d087af64 100644 --- a/contrib/runners/python_runner/tests/unit/test_pythonrunner.py +++ b/contrib/runners/python_runner/tests/unit/test_pythonrunner.py @@ -29,7 +29,10 @@ from st2common.runners.utils import get_action_class_instance from st2common.services import config as config_service from st2common.constants.action import ACTION_OUTPUT_RESULT_DELIMITER -from st2common.constants.action import LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED +from st2common.constants.action import ( + LIVEACTION_STATUS_SUCCEEDED, + LIVEACTION_STATUS_FAILED, +) from st2common.constants.action import LIVEACTION_STATUS_TIMED_OUT from st2common.constants.action import MAX_PARAM_LENGTH from st2common.constants.pack import SYSTEM_PACK_NAME @@ -43,29 +46,49 @@ import st2tests.base as tests_base -PASCAL_ROW_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/pascal_row.py') -ECHOER_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/echoer.py') -TEST_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/test.py') -PATHS_ACTION_PATH = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/python_paths.py') -ACTION_1_PATH = os.path.join(tests_base.get_fixtures_path(), - 'packs/dummy_pack_9/actions/list_repos_doesnt_exist.py') -ACTION_2_PATH = os.path.join(tests_base.get_fixtures_path(), - 'packs/dummy_pack_9/actions/invalid_syntax.py') -NON_SIMPLE_TYPE_ACTION = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/non_simple_type.py') -PRINT_VERSION_ACTION = os.path.join(tests_base.get_fixtures_path(), 'packs', - 'test_content_version/actions/print_version.py') -PRINT_VERSION_LOCAL_MODULE_ACTION = os.path.join(tests_base.get_fixtures_path(), 'packs', - 'test_content_version/actions/print_version_local_import.py') - -PRINT_CONFIG_ITEM_ACTION = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/print_config_item_doesnt_exist.py') -PRINT_TO_STDOUT_STDERR_ACTION = os.path.join(tests_base.get_resources_path(), 'packs', - 'pythonactions/actions/print_to_stdout_and_stderr.py') +PASCAL_ROW_ACTION_PATH = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/pascal_row.py" +) +ECHOER_ACTION_PATH = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/echoer.py" +) +TEST_ACTION_PATH = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/test.py" +) +PATHS_ACTION_PATH = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/python_paths.py" +) +ACTION_1_PATH = os.path.join( + tests_base.get_fixtures_path(), + "packs/dummy_pack_9/actions/list_repos_doesnt_exist.py", +) +ACTION_2_PATH = os.path.join( + tests_base.get_fixtures_path(), "packs/dummy_pack_9/actions/invalid_syntax.py" +) +NON_SIMPLE_TYPE_ACTION = os.path.join( + tests_base.get_resources_path(), "packs", "pythonactions/actions/non_simple_type.py" +) +PRINT_VERSION_ACTION = os.path.join( + tests_base.get_fixtures_path(), + "packs", + "test_content_version/actions/print_version.py", +) +PRINT_VERSION_LOCAL_MODULE_ACTION = os.path.join( + tests_base.get_fixtures_path(), + "packs", + "test_content_version/actions/print_version_local_import.py", +) + +PRINT_CONFIG_ITEM_ACTION = os.path.join( + tests_base.get_resources_path(), + "packs", + "pythonactions/actions/print_config_item_doesnt_exist.py", +) +PRINT_TO_STDOUT_STDERR_ACTION = os.path.join( + tests_base.get_resources_path(), + "packs", + "pythonactions/actions/print_to_stdout_and_stderr.py", +) # Note: runner inherits parent args which doesn't work with tests since test pass additional @@ -75,10 +98,10 @@ mock_sys.executable = sys.executable MOCK_EXECUTION = mock.Mock() -MOCK_EXECUTION.id = '598dbf0c0640fd54bffc688b' +MOCK_EXECUTION.id = "598dbf0c0640fd54bffc688b" -@mock.patch('python_runner.python_runner.sys', mock_sys) +@mock.patch("python_runner.python_runner.sys", mock_sys) class PythonRunnerTestCase(RunnerTestCase, CleanDbTestCase): register_packs = True register_pack_configs = True @@ -90,8 +113,10 @@ def setUpClass(cls): def test_runner_creation(self): runner = python_runner.get_runner() - self.assertIsNotNone(runner, 'Creation failed. No instance.') - self.assertEqual(type(runner), python_runner.PythonRunner, 'Creation failed. No instance.') + self.assertIsNotNone(runner, "Creation failed. No instance.") + self.assertEqual( + type(runner), python_runner.PythonRunner, "Creation failed. No instance." + ) def test_action_returns_non_serializable_result(self): # Actions returns non-simple type which can't be serialized, verify result is simple str() @@ -105,33 +130,37 @@ def test_action_returns_non_serializable_result(self): self.assertIsNotNone(output) if six.PY2: - expected_result_re = (r"\[{'a': '1'}, {'h': 3, 'c': 2}, {'e': " - r"}\]") + expected_result_re = ( + r"\[{'a': '1'}, {'h': 3, 'c': 2}, {'e': " + r"}\]" + ) else: - expected_result_re = (r"\[{'a': '1'}, {'c': 2, 'h': 3}, {'e': " - r"}\]") + expected_result_re = ( + r"\[{'a': '1'}, {'c': 2, 'h': 3}, {'e': " + r"}\]" + ) - match = re.match(expected_result_re, output['result']) + match = re.match(expected_result_re, output["result"]) self.assertTrue(match) def test_simple_action_with_result_no_status(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 5}) + (status, output, _) = runner.run({"row_index": 5}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(output) - self.assertEqual(output['result'], [1, 5, 10, 10, 5, 1]) + self.assertEqual(output["result"], [1, 5, 10, 10, 5, 1]) def test_simple_action_with_result_as_None_no_status(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 'b'}) + (status, output, _) = runner.run({"row_index": "b"}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(output) - self.assertEqual(output['exit_code'], 0) - self.assertEqual(output['result'], None) + self.assertEqual(output["exit_code"], 0) + self.assertEqual(output["result"], None) def test_simple_action_timeout(self): timeout = 0 @@ -139,30 +168,30 @@ def test_simple_action_timeout(self): runner.runner_parameters = {python_runner.RUNNER_TIMEOUT: timeout} runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 4}) + (status, output, _) = runner.run({"row_index": 4}) self.assertEqual(status, LIVEACTION_STATUS_TIMED_OUT) self.assertIsNotNone(output) - self.assertEqual(output['result'], 'None') - self.assertEqual(output['error'], 'Action failed to complete in 0 seconds') - self.assertEqual(output['exit_code'], -9) + self.assertEqual(output["result"], "None") + self.assertEqual(output["error"], "Action failed to complete in 0 seconds") + self.assertEqual(output["exit_code"], -9) def test_simple_action_with_status_succeeded(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 4}) + (status, output, _) = runner.run({"row_index": 4}) self.assertEqual(status, LIVEACTION_STATUS_SUCCEEDED) self.assertIsNotNone(output) - self.assertEqual(output['result'], [1, 4, 6, 4, 1]) + self.assertEqual(output["result"], [1, 4, 6, 4, 1]) def test_simple_action_with_status_failed(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 'a'}) + (status, output, _) = runner.run({"row_index": "a"}) self.assertEqual(status, LIVEACTION_STATUS_FAILED) self.assertIsNotNone(output) - self.assertEqual(output['result'], "This is suppose to fail don't worry!!") + self.assertEqual(output["result"], "This is suppose to fail don't worry!!") def test_simple_action_with_status_complex_type_returned_for_result(self): # Result containing a complex type shouldn't break the returning a tuple with status @@ -170,78 +199,79 @@ def test_simple_action_with_status_complex_type_returned_for_result(self): runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH runner.pre_run() - (status, output, _) = runner.run({'row_index': 'complex_type'}) + (status, output, _) = runner.run({"row_index": "complex_type"}) self.assertEqual(status, LIVEACTION_STATUS_FAILED) self.assertIsNotNone(output) - self.assertIn('.*" % - runner.git_worktree_path) - self.assertRegexpMatches(output['stdout'].strip(), expected_stdout) + expected_stdout = ( + ".*" + % runner.git_worktree_path + ) + self.assertRegexpMatches(output["stdout"].strip(), expected_stdout) - @mock.patch('st2common.runners.base.run_command') + @mock.patch("st2common.runners.base.run_command") def test_content_version_old_git_version(self, mock_run_command): - mock_stdout = '' - mock_stderr = ''' + mock_stdout = "" + mock_stderr = """ git: 'worktree' is not a git command. See 'git --help'. -''' +""" mock_stderr = six.text_type(mock_stderr) mock_run_command.return_value = 1, mock_stdout, mock_stderr, False runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH - runner.runner_parameters = {'content_version': 'v0.10.0'} + runner.runner_parameters = {"content_version": "v0.10.0"} - expected_msg = (r'Failed to create git worktree for pack "core": Installed git version ' - 'doesn\'t support git worktree command. To be able to utilize this ' - 'functionality you need to use git >= 2.5.0.') + expected_msg = ( + r'Failed to create git worktree for pack "core": Installed git version ' + "doesn't support git worktree command. To be able to utilize this " + "functionality you need to use git >= 2.5.0." + ) self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run) - @mock.patch('st2common.runners.base.run_command') + @mock.patch("st2common.runners.base.run_command") def test_content_version_pack_repo_not_git_repository(self, mock_run_command): - mock_stdout = '' - mock_stderr = ''' + mock_stdout = "" + mock_stderr = """ fatal: Not a git repository (or any parent up to mount point /home) Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set). -''' +""" mock_stderr = six.text_type(mock_stderr) mock_run_command.return_value = 1, mock_stdout, mock_stderr, False runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH - runner.runner_parameters = {'content_version': 'v0.10.0'} - - expected_msg = (r'Failed to create git worktree for pack "core": Pack directory ' - '".*" is not a ' - 'git repository. To utilize this functionality, pack directory needs to ' - 'be a git repository.') + runner.runner_parameters = {"content_version": "v0.10.0"} + + expected_msg = ( + r'Failed to create git worktree for pack "core": Pack directory ' + '".*" is not a ' + "git repository. To utilize this functionality, pack directory needs to " + "be a git repository." + ) self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run) - @mock.patch('st2common.runners.base.run_command') + @mock.patch("st2common.runners.base.run_command") def test_content_version_invalid_git_revision(self, mock_run_command): - mock_stdout = '' - mock_stderr = ''' + mock_stdout = "" + mock_stderr = """ fatal: invalid reference: vinvalid -''' +""" mock_stderr = six.text_type(mock_stderr) mock_run_command.return_value = 1, mock_stdout, mock_stderr, False runner = self._get_mock_runner_obj() runner.entry_point = PASCAL_ROW_ACTION_PATH - runner.runner_parameters = {'content_version': 'vinvalid'} + runner.runner_parameters = {"content_version": "vinvalid"} - expected_msg = (r'Failed to create git worktree for pack "core": Invalid content_version ' - '"vinvalid" provided. Make sure that git repository is up ' - 'to date and contains that revision.') + expected_msg = ( + r'Failed to create git worktree for pack "core": Invalid content_version ' + '"vinvalid" provided. Make sure that git repository is up ' + "to date and contains that revision." + ) self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run) def test_missing_config_item_user_friendly_error(self): @@ -953,10 +1051,12 @@ def test_missing_config_item_user_friendly_error(self): self.assertEqual(status, LIVEACTION_STATUS_FAILED) self.assertIsNotNone(output) - self.assertIn('{}', output['stdout']) - self.assertIn('default_value', output['stdout']) - self.assertIn('Config for pack "core" is missing key "key"', output['stderr']) - self.assertIn('make sure you run "st2ctl reload --register-configs"', output['stderr']) + self.assertIn("{}", output["stdout"]) + self.assertIn("default_value", output["stdout"]) + self.assertIn('Config for pack "core" is missing key "key"', output["stderr"]) + self.assertIn( + 'make sure you run "st2ctl reload --register-configs"', output["stderr"] + ) def _get_mock_runner_obj(self, pack=None, sandbox=None): runner = python_runner.get_runner() @@ -972,22 +1072,25 @@ def _get_mock_runner_obj(self, pack=None, sandbox=None): return runner - @mock.patch('st2actions.container.base.ActionExecution.get', mock.Mock()) + @mock.patch("st2actions.container.base.ActionExecution.get", mock.Mock()) def _get_mock_runner_obj_from_container(self, pack, user, sandbox=None): container = RunnerContainer() runnertype_db = mock.Mock() - runnertype_db.name = 'python-script' - runnertype_db.runner_package = 'python_runner' - runnertype_db.runner_module = 'python_runner' + runnertype_db.name = "python-script" + runnertype_db.runner_package = "python_runner" + runnertype_db.runner_module = "python_runner" action_db = mock.Mock() action_db.pack = pack - action_db.entry_point = 'foo.py' + action_db.entry_point = "foo.py" liveaction_db = mock.Mock() - liveaction_db.id = '123' - liveaction_db.context = {'user': user} - runner = container._get_runner(runner_type_db=runnertype_db, action_db=action_db, - liveaction_db=liveaction_db) + liveaction_db.id = "123" + liveaction_db.context = {"user": user} + runner = container._get_runner( + runner_type_db=runnertype_db, + action_db=action_db, + liveaction_db=liveaction_db, + ) runner.execution = MOCK_EXECUTION runner.action = action_db runner.runner_parameters = {} @@ -1004,10 +1107,8 @@ def _get_mock_action_obj(self): Pack gets set to the system pack so the action doesn't require a separate virtualenv. """ action = mock.Mock() - action.ref = 'dummy.action' + action.ref = "dummy.action" action.pack = SYSTEM_PACK_NAME - action.entry_point = 'foo.py' - action.runner_type = { - 'name': 'python-script' - } + action.entry_point = "foo.py" + action.runner_type = {"name": "python-script"} return action diff --git a/contrib/runners/remote_runner/dist_utils.py b/contrib/runners/remote_runner/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/contrib/runners/remote_runner/dist_utils.py +++ b/contrib/runners/remote_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/remote_runner/remote_runner/__init__.py b/contrib/runners/remote_runner/remote_runner/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/contrib/runners/remote_runner/remote_runner/__init__.py +++ b/contrib/runners/remote_runner/remote_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/remote_runner/remote_runner/remote_command_runner.py b/contrib/runners/remote_runner/remote_runner/remote_command_runner.py index 09382d91257..60880a0431f 100644 --- a/contrib/runners/remote_runner/remote_runner/remote_command_runner.py +++ b/contrib/runners/remote_runner/remote_runner/remote_command_runner.py @@ -24,12 +24,7 @@ from st2common.runners.base import get_metadata as get_runner_metadata from st2common.models.system.paramiko_command_action import ParamikoRemoteCommandAction -__all__ = [ - 'ParamikoRemoteCommandRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["ParamikoRemoteCommandRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) @@ -38,42 +33,52 @@ class ParamikoRemoteCommandRunner(BaseParallelSSHRunner): def run(self, action_parameters): remote_action = self._get_remote_action(action_parameters) - LOG.debug('Executing remote command action.', extra={'_action_params': remote_action}) + LOG.debug( + "Executing remote command action.", extra={"_action_params": remote_action} + ) result = self._run(remote_action) - LOG.debug('Executed remote_action.', extra={'_result': result}) - status = self._get_result_status(result, cfg.CONF.ssh_runner.allow_partial_failure) + LOG.debug("Executed remote_action.", extra={"_result": result}) + status = self._get_result_status( + result, cfg.CONF.ssh_runner.allow_partial_failure + ) return (status, result, None) def _run(self, remote_action): command = remote_action.get_full_command_string() - return self._parallel_ssh_client.run(command, timeout=remote_action.get_timeout()) + return self._parallel_ssh_client.run( + command, timeout=remote_action.get_timeout() + ) def _get_remote_action(self, action_paramaters): # remote script actions with entry_point don't make sense, user probably wanted to use # "remote-shell-script" action if self.entry_point: - msg = ('Action "%s" specified "entry_point" attribute. Perhaps wanted to use ' - '"remote-shell-script" runner?' % (self.action_name)) + msg = ( + 'Action "%s" specified "entry_point" attribute. Perhaps wanted to use ' + '"remote-shell-script" runner?' % (self.action_name) + ) raise Exception(msg) command = self.runner_parameters.get(RUNNER_COMMAND, None) env_vars = self._get_env_vars() - return ParamikoRemoteCommandAction(self.action_name, - str(self.liveaction_id), - command, - env_vars=env_vars, - on_behalf_user=self._on_behalf_user, - user=self._username, - password=self._password, - private_key=self._private_key, - passphrase=self._passphrase, - hosts=self._hosts, - parallel=self._parallel, - sudo=self._sudo, - sudo_password=self._sudo_password, - timeout=self._timeout, - cwd=self._cwd) + return ParamikoRemoteCommandAction( + self.action_name, + str(self.liveaction_id), + command, + env_vars=env_vars, + on_behalf_user=self._on_behalf_user, + user=self._username, + password=self._password, + private_key=self._private_key, + passphrase=self._passphrase, + hosts=self._hosts, + parallel=self._parallel, + sudo=self._sudo, + sudo_password=self._sudo_password, + timeout=self._timeout, + cwd=self._cwd, + ) def get_runner(): @@ -81,7 +86,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('remote_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("remote_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/remote_runner/remote_runner/remote_script_runner.py b/contrib/runners/remote_runner/remote_runner/remote_script_runner.py index 292f391850c..e71e8f63146 100644 --- a/contrib/runners/remote_runner/remote_runner/remote_script_runner.py +++ b/contrib/runners/remote_runner/remote_runner/remote_script_runner.py @@ -27,12 +27,7 @@ from st2common.runners.base import get_metadata as get_runner_metadata from st2common.models.system.paramiko_script_action import ParamikoRemoteScriptAction -__all__ = [ - 'ParamikoRemoteScriptRunner', - - 'get_runner', - 'get_metadata' -] +__all__ = ["ParamikoRemoteScriptRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) @@ -41,10 +36,12 @@ class ParamikoRemoteScriptRunner(BaseParallelSSHRunner): def run(self, action_parameters): remote_action = self._get_remote_action(action_parameters) - LOG.debug('Executing remote action.', extra={'_action_params': remote_action}) + LOG.debug("Executing remote action.", extra={"_action_params": remote_action}) result = self._run(remote_action) - LOG.debug('Executed remote action.', extra={'_result': result}) - status = self._get_result_status(result, cfg.CONF.ssh_runner.allow_partial_failure) + LOG.debug("Executed remote action.", extra={"_result": result}) + status = self._get_result_status( + result, cfg.CONF.ssh_runner.allow_partial_failure + ) return (status, result, None) @@ -54,109 +51,133 @@ def _run(self, remote_action): except: # If for whatever reason there is a top level exception, # we just bail here. - error = 'Failed copying content to remote boxes.' + error = "Failed copying content to remote boxes." LOG.exception(error) _, ex, tb = sys.exc_info() - copy_results = self._generate_error_results(' '.join([error, str(ex)]), tb) + copy_results = self._generate_error_results(" ".join([error, str(ex)]), tb) return copy_results try: exec_results = self._run_script_on_remote_host(remote_action) try: remote_dir = remote_action.get_remote_base_dir() - LOG.debug('Deleting remote execution dir.', extra={'_remote_dir': remote_dir}) - delete_results = self._parallel_ssh_client.delete_dir(path=remote_dir, - force=True) - LOG.debug('Deleted remote execution dir.', extra={'_result': delete_results}) + LOG.debug( + "Deleting remote execution dir.", extra={"_remote_dir": remote_dir} + ) + delete_results = self._parallel_ssh_client.delete_dir( + path=remote_dir, force=True + ) + LOG.debug( + "Deleted remote execution dir.", extra={"_result": delete_results} + ) except: - LOG.exception('Failed deleting remote dir.', extra={'_remote_dir': remote_dir}) + LOG.exception( + "Failed deleting remote dir.", extra={"_remote_dir": remote_dir} + ) finally: return exec_results except: - error = 'Failed executing script on remote boxes.' - LOG.exception(error, extra={'_action_params': remote_action}) + error = "Failed executing script on remote boxes." + LOG.exception(error, extra={"_action_params": remote_action}) _, ex, tb = sys.exc_info() - exec_results = self._generate_error_results(' '.join([error, str(ex)]), tb) + exec_results = self._generate_error_results(" ".join([error, str(ex)]), tb) return exec_results def _copy_artifacts(self, remote_action): # First create remote execution directory. remote_dir = remote_action.get_remote_base_dir() - LOG.debug('Creating remote execution dir.', extra={'_path': remote_dir}) - mkdir_result = self._parallel_ssh_client.mkdir(path=remote_action.get_remote_base_dir()) + LOG.debug("Creating remote execution dir.", extra={"_path": remote_dir}) + mkdir_result = self._parallel_ssh_client.mkdir( + path=remote_action.get_remote_base_dir() + ) # Copy the script to remote dir in remote host. local_script_abs_path = remote_action.get_local_script_abs_path() remote_script_abs_path = remote_action.get_remote_script_abs_path() file_mode = 0o744 - extra = {'_local_script': local_script_abs_path, '_remote_script': remote_script_abs_path, - 'mode': file_mode} - LOG.debug('Copying local script to remote box.', extra=extra) - put_result_1 = self._parallel_ssh_client.put(local_path=local_script_abs_path, - remote_path=remote_script_abs_path, - mirror_local_mode=False, mode=file_mode) + extra = { + "_local_script": local_script_abs_path, + "_remote_script": remote_script_abs_path, + "mode": file_mode, + } + LOG.debug("Copying local script to remote box.", extra=extra) + put_result_1 = self._parallel_ssh_client.put( + local_path=local_script_abs_path, + remote_path=remote_script_abs_path, + mirror_local_mode=False, + mode=file_mode, + ) # If `lib` exist for the script, copy that to remote host. local_libs_path = remote_action.get_local_libs_path_abs() if os.path.exists(local_libs_path): - extra = {'_local_libs': local_libs_path, '_remote_path': remote_dir} - LOG.debug('Copying libs to remote host.', extra=extra) - put_result_2 = self._parallel_ssh_client.put(local_path=local_libs_path, - remote_path=remote_dir, - mirror_local_mode=True) + extra = {"_local_libs": local_libs_path, "_remote_path": remote_dir} + LOG.debug("Copying libs to remote host.", extra=extra) + put_result_2 = self._parallel_ssh_client.put( + local_path=local_libs_path, + remote_path=remote_dir, + mirror_local_mode=True, + ) result = mkdir_result or put_result_1 or put_result_2 return result def _run_script_on_remote_host(self, remote_action): command = remote_action.get_full_command_string() - LOG.info('Command to run: %s', command) - results = self._parallel_ssh_client.run(command, timeout=remote_action.get_timeout()) - LOG.debug('Results from script: %s', results) + LOG.info("Command to run: %s", command) + results = self._parallel_ssh_client.run( + command, timeout=remote_action.get_timeout() + ) + LOG.debug("Results from script: %s", results) return results def _get_remote_action(self, action_parameters): # remote script actions without entry_point don't make sense, user probably wanted to use # "remote-shell-cmd" action if not self.entry_point: - msg = ('Action "%s" is missing "entry_point" attribute. Perhaps wanted to use ' - '"remote-shell-script" runner?' % (self.action_name)) + msg = ( + 'Action "%s" is missing "entry_point" attribute. Perhaps wanted to use ' + '"remote-shell-script" runner?' % (self.action_name) + ) raise Exception(msg) script_local_path_abs = self.entry_point pos_args, named_args = self._get_script_args(action_parameters) named_args = self._transform_named_args(named_args) env_vars = self._get_env_vars() - remote_dir = self.runner_parameters.get(RUNNER_REMOTE_DIR, - cfg.CONF.ssh_runner.remote_dir) + remote_dir = self.runner_parameters.get( + RUNNER_REMOTE_DIR, cfg.CONF.ssh_runner.remote_dir + ) remote_dir = os.path.join(remote_dir, self.liveaction_id) - return ParamikoRemoteScriptAction(self.action_name, - str(self.liveaction_id), - script_local_path_abs, - self.libs_dir_path, - named_args=named_args, - positional_args=pos_args, - env_vars=env_vars, - on_behalf_user=self._on_behalf_user, - user=self._username, - password=self._password, - private_key=self._private_key, - remote_dir=remote_dir, - hosts=self._hosts, - parallel=self._parallel, - sudo=self._sudo, - sudo_password=self._sudo_password, - timeout=self._timeout, - cwd=self._cwd) + return ParamikoRemoteScriptAction( + self.action_name, + str(self.liveaction_id), + script_local_path_abs, + self.libs_dir_path, + named_args=named_args, + positional_args=pos_args, + env_vars=env_vars, + on_behalf_user=self._on_behalf_user, + user=self._username, + password=self._password, + private_key=self._private_key, + remote_dir=remote_dir, + hosts=self._hosts, + parallel=self._parallel, + sudo=self._sudo, + sudo_password=self._sudo_password, + timeout=self._timeout, + cwd=self._cwd, + ) @staticmethod def _generate_error_results(error, tb): error_dict = { - 'error': error, - 'traceback': ''.join(traceback.format_tb(tb, 20)) if tb else '', - 'failed': True, - 'succeeded': False, - 'return_code': 255 + "error": error, + "traceback": "".join(traceback.format_tb(tb, 20)) if tb else "", + "failed": True, + "succeeded": False, + "return_code": 255, } return error_dict @@ -166,7 +187,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('remote_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("remote_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/remote_runner/setup.py b/contrib/runners/remote_runner/setup.py index cdd61b68b14..3e83437aff1 100644 --- a/contrib/runners/remote_runner/setup.py +++ b/contrib/runners/remote_runner/setup.py @@ -26,32 +26,34 @@ from remote_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-remote', + name="stackstorm-runner-remote", version=__version__, - description=('Remote SSH shell command and script action runner for StackStorm event-driven ' - 'automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "Remote SSH shell command and script action runner for StackStorm event-driven " + "automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'remote_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"remote_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'remote-shell-cmd = remote_runner.remote_command_runner', - 'remote-shell-script = remote_runner.remote_script_runner', + "st2common.runners.runner": [ + "remote-shell-cmd = remote_runner.remote_command_runner", + "remote-shell-script = remote_runner.remote_script_runner", ], - } + }, ) diff --git a/contrib/runners/winrm_runner/dist_utils.py b/contrib/runners/winrm_runner/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/contrib/runners/winrm_runner/dist_utils.py +++ b/contrib/runners/winrm_runner/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/contrib/runners/winrm_runner/setup.py b/contrib/runners/winrm_runner/setup.py index f3f014277b7..53d7b952e1b 100644 --- a/contrib/runners/winrm_runner/setup.py +++ b/contrib/runners/winrm_runner/setup.py @@ -26,33 +26,35 @@ from winrm_runner import __version__ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() setup( - name='stackstorm-runner-winrm', + name="stackstorm-runner-winrm", version=__version__, - description=('WinRM shell command and PowerShell script action runner for' - ' the StackStorm event-driven automation platform'), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description=( + "WinRM shell command and PowerShell script action runner for" + " the StackStorm event-driven automation platform" + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, - test_suite='tests', + test_suite="tests", zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - package_data={'winrm_runner': ['runner.yaml']}, + packages=find_packages(exclude=["setuptools", "tests"]), + package_data={"winrm_runner": ["runner.yaml"]}, scripts=[], entry_points={ - 'st2common.runners.runner': [ - 'winrm-cmd = winrm_runner.winrm_command_runner', - 'winrm-ps-cmd = winrm_runner.winrm_ps_command_runner', - 'winrm-ps-script = winrm_runner.winrm_ps_script_runner', + "st2common.runners.runner": [ + "winrm-cmd = winrm_runner.winrm_command_runner", + "winrm-ps-cmd = winrm_runner.winrm_ps_command_runner", + "winrm-ps-script = winrm_runner.winrm_ps_script_runner", ], - } + }, ) diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py index 0803b3e25af..1ff9f2ce1dc 100644 --- a/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py +++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_base.py @@ -32,157 +32,170 @@ class WinRmBaseTestCase(RunnerTestCase): - def setUp(self): super(WinRmBaseTestCase, self).setUpClass() self._runner = winrm_ps_command_runner.get_runner() def _init_runner(self): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'xyz987'} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "xyz987", + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() def test_win_rm_runner_timout_error(self): - error = WinRmRunnerTimoutError('test_response') + error = WinRmRunnerTimoutError("test_response") self.assertIsInstance(error, Exception) - self.assertEqual(error.response, 'test_response') + self.assertEqual(error.response, "test_response") with self.assertRaises(WinRmRunnerTimoutError): - raise WinRmRunnerTimoutError('test raising') + raise WinRmRunnerTimoutError("test raising") def test_init(self): - runner = winrm_ps_command_runner.WinRmPsCommandRunner('abcdef') + runner = winrm_ps_command_runner.WinRmPsCommandRunner("abcdef") self.assertIsInstance(runner, WinRmBaseRunner) self.assertIsInstance(runner, ActionRunner) self.assertEqual(runner.runner_id, "abcdef") - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123', - 'timeout': 99, - 'port': 1234, - 'scheme': 'http', - 'transport': 'ntlm', - 'verify_ssl_cert': False, - 'cwd': 'C:\\Test', - 'env': {'TEST_VAR': 'TEST_VALUE'}, - 'kwarg_op': '/'} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + "timeout": 99, + "port": 1234, + "scheme": "http", + "transport": "ntlm", + "verify_ssl_cert": False, + "cwd": "C:\\Test", + "env": {"TEST_VAR": "TEST_VALUE"}, + "kwarg_op": "/", + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() self.assertEqual(self._runner._session, None) - self.assertEqual(self._runner._host, 'host@domain.tld') - self.assertEqual(self._runner._username, 'user@domain.tld') - self.assertEqual(self._runner._password, 'abc123') + self.assertEqual(self._runner._host, "host@domain.tld") + self.assertEqual(self._runner._username, "user@domain.tld") + self.assertEqual(self._runner._password, "abc123") self.assertEqual(self._runner._timeout, 99) self.assertEqual(self._runner._read_timeout, 100) self.assertEqual(self._runner._port, 1234) - self.assertEqual(self._runner._scheme, 'http') - self.assertEqual(self._runner._transport, 'ntlm') - self.assertEqual(self._runner._winrm_url, 'http://host@domain.tld:1234/wsman') + self.assertEqual(self._runner._scheme, "http") + self.assertEqual(self._runner._transport, "ntlm") + self.assertEqual(self._runner._winrm_url, "http://host@domain.tld:1234/wsman") self.assertEqual(self._runner._verify_ssl, False) - self.assertEqual(self._runner._server_cert_validation, 'ignore') - self.assertEqual(self._runner._cwd, 'C:\\Test') - self.assertEqual(self._runner._env, {'TEST_VAR': 'TEST_VALUE'}) - self.assertEqual(self._runner._kwarg_op, '/') + self.assertEqual(self._runner._server_cert_validation, "ignore") + self.assertEqual(self._runner._cwd, "C:\\Test") + self.assertEqual(self._runner._env, {"TEST_VAR": "TEST_VALUE"}) + self.assertEqual(self._runner._kwarg_op, "/") - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run_defaults(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123'} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() - self.assertEqual(self._runner._host, 'host@domain.tld') - self.assertEqual(self._runner._username, 'user@domain.tld') - self.assertEqual(self._runner._password, 'abc123') + self.assertEqual(self._runner._host, "host@domain.tld") + self.assertEqual(self._runner._username, "user@domain.tld") + self.assertEqual(self._runner._password, "abc123") self.assertEqual(self._runner._timeout, 60) self.assertEqual(self._runner._read_timeout, 61) self.assertEqual(self._runner._port, 5986) - self.assertEqual(self._runner._scheme, 'https') - self.assertEqual(self._runner._transport, 'ntlm') - self.assertEqual(self._runner._winrm_url, 'https://host@domain.tld:5986/wsman') + self.assertEqual(self._runner._scheme, "https") + self.assertEqual(self._runner._transport, "ntlm") + self.assertEqual(self._runner._winrm_url, "https://host@domain.tld:5986/wsman") self.assertEqual(self._runner._verify_ssl, True) - self.assertEqual(self._runner._server_cert_validation, 'validate') + self.assertEqual(self._runner._server_cert_validation, "validate") self.assertEqual(self._runner._cwd, None) self.assertEqual(self._runner._env, {}) - self.assertEqual(self._runner._kwarg_op, '-') + self.assertEqual(self._runner._kwarg_op, "-") - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run_5985_force_http(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123', - 'port': 5985, - 'scheme': 'https'} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + "port": 5985, + "scheme": "https", + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() - self.assertEqual(self._runner._host, 'host@domain.tld') - self.assertEqual(self._runner._username, 'user@domain.tld') - self.assertEqual(self._runner._password, 'abc123') + self.assertEqual(self._runner._host, "host@domain.tld") + self.assertEqual(self._runner._username, "user@domain.tld") + self.assertEqual(self._runner._password, "abc123") self.assertEqual(self._runner._timeout, 60) self.assertEqual(self._runner._read_timeout, 61) # ensure port is still 5985 self.assertEqual(self._runner._port, 5985) # ensure scheme is set back to http - self.assertEqual(self._runner._scheme, 'http') - self.assertEqual(self._runner._transport, 'ntlm') - self.assertEqual(self._runner._winrm_url, 'http://host@domain.tld:5985/wsman') + self.assertEqual(self._runner._scheme, "http") + self.assertEqual(self._runner._transport, "ntlm") + self.assertEqual(self._runner._winrm_url, "http://host@domain.tld:5985/wsman") self.assertEqual(self._runner._verify_ssl, True) - self.assertEqual(self._runner._server_cert_validation, 'validate') + self.assertEqual(self._runner._server_cert_validation, "validate") self.assertEqual(self._runner._cwd, None) self.assertEqual(self._runner._env, {}) - self.assertEqual(self._runner._kwarg_op, '-') + self.assertEqual(self._runner._kwarg_op, "-") - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run_none_env(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123', - 'env': None} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + "env": None, + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() # ensure that env is set to {} even though we passed in None self.assertEqual(self._runner._env, {}) - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run_ssl_verify_true(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123', - 'verify_ssl_cert': True} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + "verify_ssl_cert": True, + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() self.assertEqual(self._runner._verify_ssl, True) - self.assertEqual(self._runner._server_cert_validation, 'validate') + self.assertEqual(self._runner._server_cert_validation, "validate") - @mock.patch('winrm_runner.winrm_base.ActionRunner.pre_run') + @mock.patch("winrm_runner.winrm_base.ActionRunner.pre_run") def test_pre_run_ssl_verify_false(self, mock_pre_run): - runner_parameters = {'host': 'host@domain.tld', - 'username': 'user@domain.tld', - 'password': 'abc123', - 'verify_ssl_cert': False} + runner_parameters = { + "host": "host@domain.tld", + "username": "user@domain.tld", + "password": "abc123", + "verify_ssl_cert": False, + } self._runner.runner_parameters = runner_parameters self._runner.pre_run() mock_pre_run.assert_called_with() self.assertEqual(self._runner._verify_ssl, False) - self.assertEqual(self._runner._server_cert_validation, 'ignore') + self.assertEqual(self._runner._server_cert_validation, "ignore") - @mock.patch('winrm_runner.winrm_base.Session') + @mock.patch("winrm_runner.winrm_base.Session") def test_get_session(self, mock_session): self._runner._session = None - self._runner._winrm_url = 'https://host@domain.tld:5986/wsman' - self._runner._username = 'user@domain.tld' - self._runner._password = 'abc123' - self._runner._transport = 'ntlm' - self._runner._server_cert_validation = 'validate' + self._runner._winrm_url = "https://host@domain.tld:5986/wsman" + self._runner._username = "user@domain.tld" + self._runner._password = "abc123" + self._runner._transport = "ntlm" + self._runner._server_cert_validation = "validate" self._runner._timeout = 60 self._runner._read_timeout = 61 mock_session.return_value = "session" @@ -190,12 +203,14 @@ def test_get_session(self, mock_session): result = self._runner._get_session() self.assertEqual(result, "session") self.assertEqual(result, self._runner._session) - mock_session.assert_called_with('https://host@domain.tld:5986/wsman', - auth=('user@domain.tld', 'abc123'), - transport='ntlm', - server_cert_validation='validate', - operation_timeout_sec=60, - read_timeout_sec=61) + mock_session.assert_called_with( + "https://host@domain.tld:5986/wsman", + auth=("user@domain.tld", "abc123"), + transport="ntlm", + server_cert_validation="validate", + operation_timeout_sec=60, + read_timeout_sec=61, + ) # ensure calling _get_session again doesn't create a new one, it reuses the existing old_session = self._runner._session @@ -206,18 +221,18 @@ def test_winrm_get_command_output(self): self._runner._timeout = 0 mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 123, False), - (b'output2', b'error2', 456, False), - (b'output3', b'error3', 789, True) + (b"output1", b"error1", 123, False), + (b"output2", b"error2", 456, False), + (b"output3", b"error3", 789, True), ] result = self._runner._winrm_get_command_output(mock_protocol, 567, 890) - self.assertEqual(result, (b'output1output2output3', b'error1error2error3', 789)) + self.assertEqual(result, (b"output1output2output3", b"error1error2error3", 789)) mock_protocol._raw_get_command_output.assert_has_calls = [ mock.call(567, 890), mock.call(567, 890), - mock.call(567, 890) + mock.call(567, 890), ] def test_winrm_get_command_output_timeout(self): @@ -227,7 +242,7 @@ def test_winrm_get_command_output_timeout(self): def sleep_for_timeout(*args, **kwargs): time.sleep(0.2) - return (b'output1', b'error1', 123, False) + return (b"output1", b"error1", 123, False) mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout @@ -235,9 +250,11 @@ def sleep_for_timeout(*args, **kwargs): self._runner._winrm_get_command_output(mock_protocol, 567, 890) timeout_exception = cm.exception - self.assertEqual(timeout_exception.response.std_out, b'output1') - self.assertEqual(timeout_exception.response.std_err, b'error1') - self.assertEqual(timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE) + self.assertEqual(timeout_exception.response.std_out, b"output1") + self.assertEqual(timeout_exception.response.std_err, b"error1") + self.assertEqual( + timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE + ) mock_protocol._raw_get_command_output.assert_called_with(567, 890) def test_winrm_get_command_output_operation_timeout(self): @@ -255,292 +272,354 @@ def sleep_for_timeout_then_raise(*args, **kwargs): self._runner._winrm_get_command_output(mock_protocol, 567, 890) timeout_exception = cm.exception - self.assertEqual(timeout_exception.response.std_out, b'') - self.assertEqual(timeout_exception.response.std_err, b'') - self.assertEqual(timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE) + self.assertEqual(timeout_exception.response.std_out, b"") + self.assertEqual(timeout_exception.response.std_err, b"") + self.assertEqual( + timeout_exception.response.status_code, WINRM_TIMEOUT_EXIT_CODE + ) mock_protocol._raw_get_command_output.assert_called_with(567, 890) def test_winrm_run_cmd(self): mock_protocol = mock.MagicMock() mock_protocol.open_shell.return_value = 123 mock_protocol.run_command.return_value = 456 - mock_protocol._raw_get_command_output.return_value = (b'output', b'error', 9, True) + mock_protocol._raw_get_command_output.return_value = ( + b"output", + b"error", + 9, + True, + ) mock_session = mock.MagicMock(protocol=mock_protocol) self._init_runner() - result = self._runner._winrm_run_cmd(mock_session, "fake-command", - args=['arg1', 'arg2'], - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') - expected_response = Response((b'output', b'error', 9)) + result = self._runner._winrm_run_cmd( + mock_session, + "fake-command", + args=["arg1", "arg2"], + env={"PATH": "C:\\st2\\bin"}, + cwd="C:\\st2", + ) + expected_response = Response((b"output", b"error", 9)) expected_response.timeout = False self.assertEqual(result.__dict__, expected_response.__dict__) - mock_protocol.open_shell.assert_called_with(env_vars={'PATH': 'C:\\st2\\bin'}, - working_directory='C:\\st2') - mock_protocol.run_command.assert_called_with(123, 'fake-command', ['arg1', 'arg2']) + mock_protocol.open_shell.assert_called_with( + env_vars={"PATH": "C:\\st2\\bin"}, working_directory="C:\\st2" + ) + mock_protocol.run_command.assert_called_with( + 123, "fake-command", ["arg1", "arg2"] + ) mock_protocol._raw_get_command_output.assert_called_with(123, 456) mock_protocol.cleanup_command.assert_called_with(123, 456) mock_protocol.close_shell.assert_called_with(123) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_get_command_output') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_get_command_output") def test_winrm_run_cmd_timeout(self, mock_get_command_output): mock_protocol = mock.MagicMock() mock_protocol.open_shell.return_value = 123 mock_protocol.run_command.return_value = 456 mock_session = mock.MagicMock(protocol=mock_protocol) - mock_get_command_output.side_effect = WinRmRunnerTimoutError(Response(('', '', 5))) + mock_get_command_output.side_effect = WinRmRunnerTimoutError( + Response(("", "", 5)) + ) self._init_runner() - result = self._runner._winrm_run_cmd(mock_session, "fake-command", - args=['arg1', 'arg2'], - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') - expected_response = Response(('', '', 5)) + result = self._runner._winrm_run_cmd( + mock_session, + "fake-command", + args=["arg1", "arg2"], + env={"PATH": "C:\\st2\\bin"}, + cwd="C:\\st2", + ) + expected_response = Response(("", "", 5)) expected_response.timeout = True self.assertEqual(result.__dict__, expected_response.__dict__) - mock_protocol.open_shell.assert_called_with(env_vars={'PATH': 'C:\\st2\\bin'}, - working_directory='C:\\st2') - mock_protocol.run_command.assert_called_with(123, 'fake-command', ['arg1', 'arg2']) + mock_protocol.open_shell.assert_called_with( + env_vars={"PATH": "C:\\st2\\bin"}, working_directory="C:\\st2" + ) + mock_protocol.run_command.assert_called_with( + 123, "fake-command", ["arg1", "arg2"] + ) mock_protocol.cleanup_command.assert_called_with(123, 456) mock_protocol.close_shell.assert_called_with(123) def test_winrm_encode(self): - result = self._runner._winrm_encode('hello world') + result = self._runner._winrm_encode("hello world") # result translated into UTF-16 little-endian - self.assertEqual(result, 'aABlAGwAbABvACAAdwBvAHIAbABkAA==') + self.assertEqual(result, "aABlAGwAbABvACAAdwBvAHIAbABkAA==") def test_winrm_ps_cmd(self): - result = self._runner._winrm_ps_cmd('abc123==') - self.assertEqual(result, 'powershell -encodedcommand abc123==') + result = self._runner._winrm_ps_cmd("abc123==") + self.assertEqual(result, "powershell -encodedcommand abc123==") - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd") def test_winrm_run_ps(self, mock_run_cmd): - mock_run_cmd.return_value = Response(('output', '', 3)) + mock_run_cmd.return_value = Response(("output", "", 3)) script = "Get-ADUser stanley" - result = self._runner._winrm_run_ps("session", script, - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') + result = self._runner._winrm_run_ps( + "session", script, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2" + ) - self.assertEqual(result.__dict__, - Response(('output', '', 3)).__dict__) - expected_ps = ('powershell -encodedcommand ' + - b64encode("Get-ADUser stanley".encode('utf_16_le')).decode('ascii')) - mock_run_cmd.assert_called_with("session", - expected_ps, - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') + self.assertEqual(result.__dict__, Response(("output", "", 3)).__dict__) + expected_ps = "powershell -encodedcommand " + b64encode( + "Get-ADUser stanley".encode("utf_16_le") + ).decode("ascii") + mock_run_cmd.assert_called_with( + "session", expected_ps, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2" + ) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_cmd") def test_winrm_run_ps_clean_stderr(self, mock_run_cmd): - mock_run_cmd.return_value = Response(('output', 'error', 3)) + mock_run_cmd.return_value = Response(("output", "error", 3)) mock_session = mock.MagicMock() - mock_session._clean_error_msg.return_value = 'e' + mock_session._clean_error_msg.return_value = "e" script = "Get-ADUser stanley" - result = self._runner._winrm_run_ps(mock_session, script, - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') + result = self._runner._winrm_run_ps( + mock_session, script, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2" + ) - self.assertEqual(result.__dict__, - Response(('output', 'e', 3)).__dict__) - expected_ps = ('powershell -encodedcommand ' + - b64encode("Get-ADUser stanley".encode('utf_16_le')).decode('ascii')) - mock_run_cmd.assert_called_with(mock_session, - expected_ps, - env={'PATH': 'C:\\st2\\bin'}, - cwd='C:\\st2') - mock_session._clean_error_msg.assert_called_with('error') + self.assertEqual(result.__dict__, Response(("output", "e", 3)).__dict__) + expected_ps = "powershell -encodedcommand " + b64encode( + "Get-ADUser stanley".encode("utf_16_le") + ).decode("ascii") + mock_run_cmd.assert_called_with( + mock_session, expected_ps, env={"PATH": "C:\\st2\\bin"}, cwd="C:\\st2" + ) + mock_session._clean_error_msg.assert_called_with("error") def test_translate_response_success(self): - response = Response(('output1', 'error1', 0)) + response = Response(("output1", "error1", 0)) response.timeout = False result = self._runner._translate_response(response) - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) def test_translate_response_failure(self): - response = Response(('output1', 'error1', 123)) + response = Response(("output1", "error1", 123)) response.timeout = False result = self._runner._translate_response(response) - self.assertEqual(result, ('failed', - {'failed': True, - 'succeeded': False, - 'return_code': 123, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) + self.assertEqual( + result, + ( + "failed", + { + "failed": True, + "succeeded": False, + "return_code": 123, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) def test_translate_response_timeout(self): - response = Response(('output1', 'error1', 123)) + response = Response(("output1", "error1", 123)) response.timeout = True result = self._runner._translate_response(response) - self.assertEqual(result, ('timeout', - {'failed': True, - 'succeeded': False, - 'return_code': -1, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise') + self.assertEqual( + result, + ( + "timeout", + { + "failed": True, + "succeeded": False, + "return_code": -1, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise") def test_make_tmp_dir(self, mock_run_ps_or_raise): - mock_run_ps_or_raise.return_value = {'stdout': ' expected \n'} + mock_run_ps_or_raise.return_value = {"stdout": " expected \n"} - result = self._runner._make_tmp_dir('C:\\Windows\\Temp') - self.assertEqual(result, 'expected') - mock_run_ps_or_raise.assert_called_with('''$parent = C:\\Windows\\Temp + result = self._runner._make_tmp_dir("C:\\Windows\\Temp") + self.assertEqual(result, "expected") + mock_run_ps_or_raise.assert_called_with( + """$parent = C:\\Windows\\Temp $name = [System.IO.Path]::GetRandomFileName() $path = Join-Path $parent $name New-Item -ItemType Directory -Path $path | Out-Null -$path''', - ("Unable to make temporary directory for" - " powershell script")) +$path""", + ("Unable to make temporary directory for" " powershell script"), + ) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise") def test_rm_dir(self, mock_run_ps_or_raise): - self._runner._rm_dir('C:\\Windows\\Temp\\testtmpdir') + self._runner._rm_dir("C:\\Windows\\Temp\\testtmpdir") mock_run_ps_or_raise.assert_called_with( 'Remove-Item -Force -Recurse -Path "C:\\Windows\\Temp\\testtmpdir"', - "Unable to remove temporary directory for powershell script") + "Unable to remove temporary directory for powershell script", + ) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk') - @mock.patch('winrm_runner.winrm_base.open') - @mock.patch('os.path.exists') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk") + @mock.patch("winrm_runner.winrm_base.open") + @mock.patch("os.path.exists") def test_upload_chunk_file(self, mock_os_path_exists, mock_open, mock_upload_chunk): mock_os_path_exists.return_value = True mock_src_file = mock.MagicMock() mock_src_file.read.return_value = "test data" mock_open.return_value.__enter__.return_value = mock_src_file - self._runner._upload('/opt/data/test.ps1', 'C:\\Windows\\Temp\\test.ps1') - mock_os_path_exists.assert_called_with('/opt/data/test.ps1') - mock_open.assert_called_with('/opt/data/test.ps1', 'r') + self._runner._upload("/opt/data/test.ps1", "C:\\Windows\\Temp\\test.ps1") + mock_os_path_exists.assert_called_with("/opt/data/test.ps1") + mock_open.assert_called_with("/opt/data/test.ps1", "r") mock_src_file.read.assert_called_with() - mock_upload_chunk.assert_has_calls([ - mock.call('C:\\Windows\\Temp\\test.ps1', 'test data') - ]) + mock_upload_chunk.assert_has_calls( + [mock.call("C:\\Windows\\Temp\\test.ps1", "test data")] + ) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk') - @mock.patch('os.path.exists') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk") + @mock.patch("os.path.exists") def test_upload_chunk_data(self, mock_os_path_exists, mock_upload_chunk): mock_os_path_exists.return_value = False - self._runner._upload('test data', 'C:\\Windows\\Temp\\test.ps1') - mock_os_path_exists.assert_called_with('test data') - mock_upload_chunk.assert_has_calls([ - mock.call('C:\\Windows\\Temp\\test.ps1', 'test data') - ]) + self._runner._upload("test data", "C:\\Windows\\Temp\\test.ps1") + mock_os_path_exists.assert_called_with("test data") + mock_upload_chunk.assert_has_calls( + [mock.call("C:\\Windows\\Temp\\test.ps1", "test data")] + ) - @mock.patch('winrm_runner.winrm_base.WINRM_UPLOAD_CHUNK_SIZE_BYTES', 2) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk') - @mock.patch('os.path.exists') + @mock.patch("winrm_runner.winrm_base.WINRM_UPLOAD_CHUNK_SIZE_BYTES", 2) + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload_chunk") + @mock.patch("os.path.exists") def test_upload_chunk_multiple_chunks(self, mock_os_path_exists, mock_upload_chunk): mock_os_path_exists.return_value = False - self._runner._upload('test data', 'C:\\Windows\\Temp\\test.ps1') - mock_os_path_exists.assert_called_with('test data') - mock_upload_chunk.assert_has_calls([ - mock.call('C:\\Windows\\Temp\\test.ps1', 'te'), - mock.call('C:\\Windows\\Temp\\test.ps1', 'st'), - mock.call('C:\\Windows\\Temp\\test.ps1', ' d'), - mock.call('C:\\Windows\\Temp\\test.ps1', 'at'), - mock.call('C:\\Windows\\Temp\\test.ps1', 'a'), - ]) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise') + self._runner._upload("test data", "C:\\Windows\\Temp\\test.ps1") + mock_os_path_exists.assert_called_with("test data") + mock_upload_chunk.assert_has_calls( + [ + mock.call("C:\\Windows\\Temp\\test.ps1", "te"), + mock.call("C:\\Windows\\Temp\\test.ps1", "st"), + mock.call("C:\\Windows\\Temp\\test.ps1", " d"), + mock.call("C:\\Windows\\Temp\\test.ps1", "at"), + mock.call("C:\\Windows\\Temp\\test.ps1", "a"), + ] + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_or_raise") def test_upload_chunk(self, mock_run_ps_or_raise): - self._runner._upload_chunk('C:\\Windows\\Temp\\testtmp.ps1', 'hello world') + self._runner._upload_chunk("C:\\Windows\\Temp\\testtmp.ps1", "hello world") mock_run_ps_or_raise.assert_called_with( - '''$filePath = "C:\\Windows\\Temp\\testtmp.ps1" + """$filePath = "C:\\Windows\\Temp\\testtmp.ps1" $s = @" aGVsbG8gd29ybGQ= "@ $data = [System.Convert]::FromBase64String($s) Add-Content -value $data -encoding byte -path $filePath -''', - "Failed to upload chunk of powershell script") +""", + "Failed to upload chunk of powershell script", + ) - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._rm_dir') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._rm_dir") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir") def test_tmp_script(self, mock_make_tmp_dir, mock_upload, mock_rm_dir): - mock_make_tmp_dir.return_value = 'C:\\Windows\\Temp\\abc123' - - with self._runner._tmp_script('C:\\Windows\\Temp', 'Get-ChildItem') as tmp: - self.assertEqual(tmp, 'C:\\Windows\\Temp\\abc123\\script.ps1') - mock_make_tmp_dir.assert_called_with('C:\\Windows\\Temp') - mock_upload.assert_called_with('Get-ChildItem', - 'C:\\Windows\\Temp\\abc123\\script.ps1') - mock_rm_dir.assert_called_with('C:\\Windows\\Temp\\abc123') - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._rm_dir') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._upload') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir') - def test_tmp_script_cleans_up_when_raises(self, mock_make_tmp_dir, mock_upload, - mock_rm_dir): - mock_make_tmp_dir.return_value = 'C:\\Windows\\Temp\\abc123' + mock_make_tmp_dir.return_value = "C:\\Windows\\Temp\\abc123" + + with self._runner._tmp_script("C:\\Windows\\Temp", "Get-ChildItem") as tmp: + self.assertEqual(tmp, "C:\\Windows\\Temp\\abc123\\script.ps1") + mock_make_tmp_dir.assert_called_with("C:\\Windows\\Temp") + mock_upload.assert_called_with( + "Get-ChildItem", "C:\\Windows\\Temp\\abc123\\script.ps1" + ) + mock_rm_dir.assert_called_with("C:\\Windows\\Temp\\abc123") + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._rm_dir") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._upload") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._make_tmp_dir") + def test_tmp_script_cleans_up_when_raises( + self, mock_make_tmp_dir, mock_upload, mock_rm_dir + ): + mock_make_tmp_dir.return_value = "C:\\Windows\\Temp\\abc123" mock_upload.side_effect = RuntimeError with self.assertRaises(RuntimeError): - with self._runner._tmp_script('C:\\Windows\\Temp', 'Get-ChildItem') as tmp: + with self._runner._tmp_script("C:\\Windows\\Temp", "Get-ChildItem") as tmp: self.assertEqual(tmp, "can never get here") - mock_make_tmp_dir.assert_called_with('C:\\Windows\\Temp') - mock_upload.assert_called_with('Get-ChildItem', - 'C:\\Windows\\Temp\\abc123\\script.ps1') - mock_rm_dir.assert_called_with('C:\\Windows\\Temp\\abc123') + mock_make_tmp_dir.assert_called_with("C:\\Windows\\Temp") + mock_upload.assert_called_with( + "Get-ChildItem", "C:\\Windows\\Temp\\abc123\\script.ps1" + ) + mock_rm_dir.assert_called_with("C:\\Windows\\Temp\\abc123") - @mock.patch('winrm.Protocol') + @mock.patch("winrm.Protocol") def test_run_cmd(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 0, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 0, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner.run_cmd("ipconfig /all") - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test_run_cmd_failed(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 1, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 1, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner.run_cmd("ipconfig /all") - self.assertEqual(result, ('failed', - {'failed': True, - 'succeeded': False, - 'return_code': 1, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "failed", + { + "failed": True, + "succeeded": False, + "return_code": 1, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test_run_cmd_timeout(self, mock_protocol_init): mock_protocol = mock.MagicMock() self._init_runner() @@ -548,61 +627,82 @@ def test_run_cmd_timeout(self, mock_protocol_init): def sleep_for_timeout_then_raise(*args, **kwargs): time.sleep(0.2) - return (b'output1', b'error1', 123, False) + return (b"output1", b"error1", 123, False) mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout_then_raise mock_protocol_init.return_value = mock_protocol result = self._runner.run_cmd("ipconfig /all") - self.assertEqual(result, ('timeout', - {'failed': True, - 'succeeded': False, - 'return_code': -1, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "timeout", + { + "failed": True, + "succeeded": False, + "return_code": -1, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test_run_ps(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 0, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 0, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner.run_ps("Get-Location") - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test_run_ps_failed(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 1, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 1, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner.run_ps("Get-Location") - self.assertEqual(result, ('failed', - {'failed': True, - 'succeeded': False, - 'return_code': 1, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "failed", + { + "failed": True, + "succeeded": False, + "return_code": 1, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test_run_ps_timeout(self, mock_protocol_init): mock_protocol = mock.MagicMock() self._init_runner() @@ -610,91 +710,113 @@ def test_run_ps_timeout(self, mock_protocol_init): def sleep_for_timeout_then_raise(*args, **kwargs): time.sleep(0.2) - return (b'output1', b'error1', 123, False) + return (b"output1", b"error1", 123, False) mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout_then_raise mock_protocol_init.return_value = mock_protocol result = self._runner.run_ps("Get-Location") - self.assertEqual(result, ('timeout', - {'failed': True, - 'succeeded': False, - 'return_code': -1, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_encode') + self.assertEqual( + result, + ( + "timeout", + { + "failed": True, + "succeeded": False, + "return_code": -1, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_encode") def test_run_ps_params(self, mock_winrm_encode, mock_run_ps): - mock_winrm_encode.return_value = 'xyz123==' + mock_winrm_encode.return_value = "xyz123==" mock_run_ps.return_value = "expected" self._init_runner() - result = self._runner.run_ps("Get-Location", '-param1 value1 arg1') + result = self._runner.run_ps("Get-Location", "-param1 value1 arg1") self.assertEqual(result, "expected") - mock_winrm_encode.assert_called_with('& {Get-Location} -param1 value1 arg1') - mock_run_ps.assert_called_with('xyz123==', is_b64=True) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_ps_cmd') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps_script') - def test_run_ps_large_command_convert_to_script(self, mock_run_ps_script, - mock_winrm_ps_cmd): + mock_winrm_encode.assert_called_with("& {Get-Location} -param1 value1 arg1") + mock_run_ps.assert_called_with("xyz123==", is_b64=True) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_ps_cmd") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps_script") + def test_run_ps_large_command_convert_to_script( + self, mock_run_ps_script, mock_winrm_ps_cmd + ): mock_run_ps_script.return_value = "expected" # max length of a command in powershelll - script = 'powershell -encodedcommand ' - script += '#' * (WINRM_MAX_CMD_LENGTH + 1 - len(script)) + script = "powershell -encodedcommand " + script += "#" * (WINRM_MAX_CMD_LENGTH + 1 - len(script)) mock_winrm_ps_cmd.return_value = script self._init_runner() - result = self._runner.run_ps('$PSVersionTable') + result = self._runner.run_ps("$PSVersionTable") self.assertEqual(result, "expected") - mock_run_ps_script.assert_called_with('$PSVersionTable', None) + mock_run_ps_script.assert_called_with("$PSVersionTable", None) - @mock.patch('winrm.Protocol') + @mock.patch("winrm.Protocol") def test__run_ps(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 0, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 0, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner._run_ps("Get-Location") - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test__run_ps_failed(self, mock_protocol_init): mock_protocol = mock.MagicMock() mock_protocol._raw_get_command_output.side_effect = [ - (b'output1', b'error1', 0, False), - (b'output2', b'error2', 0, False), - (b'output3', b'error3', 1, True) + (b"output1", b"error1", 0, False), + (b"output2", b"error2", 0, False), + (b"output3", b"error3", 1, True), ] mock_protocol_init.return_value = mock_protocol self._init_runner() result = self._runner._run_ps("Get-Location") - self.assertEqual(result, ('failed', - {'failed': True, - 'succeeded': False, - 'return_code': 1, - 'stdout': 'output1output2output3', - 'stderr': 'error1error2error3'}, - None)) - - @mock.patch('winrm.Protocol') + self.assertEqual( + result, + ( + "failed", + { + "failed": True, + "succeeded": False, + "return_code": 1, + "stdout": "output1output2output3", + "stderr": "error1error2error3", + }, + None, + ), + ) + + @mock.patch("winrm.Protocol") def test__run_ps_timeout(self, mock_protocol_init): mock_protocol = mock.MagicMock() self._init_runner() @@ -702,238 +824,236 @@ def test__run_ps_timeout(self, mock_protocol_init): def sleep_for_timeout_then_raise(*args, **kwargs): time.sleep(0.2) - return (b'output1', b'error1', 123, False) + return (b"output1", b"error1", 123, False) mock_protocol._raw_get_command_output.side_effect = sleep_for_timeout_then_raise mock_protocol_init.return_value = mock_protocol result = self._runner._run_ps("Get-Location") - self.assertEqual(result, ('timeout', - {'failed': True, - 'succeeded': False, - 'return_code': -1, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps') + self.assertEqual( + result, + ( + "timeout", + { + "failed": True, + "succeeded": False, + "return_code": -1, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps") def test__run_ps_b64_default(self, mock_winrm_run_ps): - mock_winrm_run_ps.return_value = mock.MagicMock(status_code=0, - timeout=False, - std_out='output1', - std_err='error1') + mock_winrm_run_ps.return_value = mock.MagicMock( + status_code=0, timeout=False, std_out="output1", std_err="error1" + ) self._init_runner() result = self._runner._run_ps("$PSVersionTable") - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - mock_winrm_run_ps.assert_called_with(self._runner._session, - '$PSVersionTable', - env={}, - cwd=None, - is_b64=False) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps') + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + mock_winrm_run_ps.assert_called_with( + self._runner._session, "$PSVersionTable", env={}, cwd=None, is_b64=False + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._winrm_run_ps") def test__run_ps_b64_true(self, mock_winrm_run_ps): - mock_winrm_run_ps.return_value = mock.MagicMock(status_code=0, - timeout=False, - std_out='output1', - std_err='error1') + mock_winrm_run_ps.return_value = mock.MagicMock( + status_code=0, timeout=False, std_out="output1", std_err="error1" + ) self._init_runner() result = self._runner._run_ps("xyz123", is_b64=True) - self.assertEqual(result, ('succeeded', - {'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output1', - 'stderr': 'error1'}, - None)) - mock_winrm_run_ps.assert_called_with(self._runner._session, - 'xyz123', - env={}, - cwd=None, - is_b64=True) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._tmp_script') + self.assertEqual( + result, + ( + "succeeded", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output1", + "stderr": "error1", + }, + None, + ), + ) + mock_winrm_run_ps.assert_called_with( + self._runner._session, "xyz123", env={}, cwd=None, is_b64=True + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._tmp_script") def test__run_ps_script(self, mock_tmp_script, mock_run_ps): - mock_tmp_script.return_value.__enter__.return_value = 'C:\\tmpscript.ps1' - mock_run_ps.return_value = 'expected' + mock_tmp_script.return_value.__enter__.return_value = "C:\\tmpscript.ps1" + mock_run_ps.return_value = "expected" self._init_runner() result = self._runner._run_ps_script("$PSVersionTable") - self.assertEqual(result, 'expected') - mock_tmp_script.assert_called_with('[System.IO.Path]::GetTempPath()', - '$PSVersionTable') - mock_run_ps.assert_called_with('& {C:\\tmpscript.ps1}') + self.assertEqual(result, "expected") + mock_tmp_script.assert_called_with( + "[System.IO.Path]::GetTempPath()", "$PSVersionTable" + ) + mock_run_ps.assert_called_with("& {C:\\tmpscript.ps1}") - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps') - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._tmp_script') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps") + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._tmp_script") def test__run_ps_script_with_params(self, mock_tmp_script, mock_run_ps): - mock_tmp_script.return_value.__enter__.return_value = 'C:\\tmpscript.ps1' - mock_run_ps.return_value = 'expected' + mock_tmp_script.return_value.__enter__.return_value = "C:\\tmpscript.ps1" + mock_run_ps.return_value = "expected" self._init_runner() - result = self._runner._run_ps_script("Get-ChildItem", '-param1 value1 arg1') - self.assertEqual(result, 'expected') - mock_tmp_script.assert_called_with('[System.IO.Path]::GetTempPath()', - 'Get-ChildItem') - mock_run_ps.assert_called_with('& {C:\\tmpscript.ps1} -param1 value1 arg1') + result = self._runner._run_ps_script("Get-ChildItem", "-param1 value1 arg1") + self.assertEqual(result, "expected") + mock_tmp_script.assert_called_with( + "[System.IO.Path]::GetTempPath()", "Get-ChildItem" + ) + mock_run_ps.assert_called_with("& {C:\\tmpscript.ps1} -param1 value1 arg1") - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps') + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps") def test__run_ps_or_raise(self, mock_run_ps): - mock_run_ps.return_value = ('success', - { - 'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output', - 'stderr': 'error', - }, - None) + mock_run_ps.return_value = ( + "success", + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output", + "stderr": "error", + }, + None, + ) self._init_runner() - result = self._runner._run_ps_or_raise('Get-ChildItem', 'my error message') - self.assertEqual(result, { - 'failed': False, - 'succeeded': True, - 'return_code': 0, - 'stdout': 'output', - 'stderr': 'error', - }) - - @mock.patch('winrm_runner.winrm_base.WinRmBaseRunner._run_ps') + result = self._runner._run_ps_or_raise("Get-ChildItem", "my error message") + self.assertEqual( + result, + { + "failed": False, + "succeeded": True, + "return_code": 0, + "stdout": "output", + "stderr": "error", + }, + ) + + @mock.patch("winrm_runner.winrm_base.WinRmBaseRunner._run_ps") def test__run_ps_or_raise_raises_on_failure(self, mock_run_ps): - mock_run_ps.return_value = ('success', - { - 'failed': True, - 'succeeded': False, - 'return_code': 1, - 'stdout': 'output', - 'stderr': 'error', - }, - None) + mock_run_ps.return_value = ( + "success", + { + "failed": True, + "succeeded": False, + "return_code": 1, + "stdout": "output", + "stderr": "error", + }, + None, + ) self._init_runner() with self.assertRaises(RuntimeError): - self._runner._run_ps_or_raise('Get-ChildItem', 'my error message') + self._runner._run_ps_or_raise("Get-ChildItem", "my error message") def test_multireplace(self): - multireplace_map = {'a': 'x', - 'c': 'y', - 'aaa': 'z'} - result = self._runner._multireplace('aaaccaa', multireplace_map) - self.assertEqual(result, 'zyyxx') + multireplace_map = {"a": "x", "c": "y", "aaa": "z"} + result = self._runner._multireplace("aaaccaa", multireplace_map) + self.assertEqual(result, "zyyxx") def test_multireplace_powershell(self): - param_str = ( - '\n' - '\r' - '\t' - '\a' - '\b' - '\f' - '\v' - '"' - '\'' - '`' - '\0' - '$' - ) + param_str = "\n" "\r" "\t" "\a" "\b" "\f" "\v" '"' "'" "`" "\0" "$" result = self._runner._multireplace(param_str, PS_ESCAPE_SEQUENCES) - self.assertEqual(result, ( - '`n' - '`r' - '`t' - '`a' - '`b' - '`f' - '`v' - '`"' - '`\'' - '``' - '`0' - '`$' - )) + self.assertEqual( + result, ("`n" "`r" "`t" "`a" "`b" "`f" "`v" '`"' "`'" "``" "`0" "`$") + ) def test_param_to_ps_none(self): # test None/null param = None result = self._runner._param_to_ps(param) - self.assertEqual(result, '$null') + self.assertEqual(result, "$null") def test_param_to_ps_string(self): # test ascii - param_str = 'StackStorm 1234' + param_str = "StackStorm 1234" result = self._runner._param_to_ps(param_str) self.assertEqual(result, '"StackStorm 1234"') # test escaped - param_str = '\n\r\t' + param_str = "\n\r\t" result = self._runner._param_to_ps(param_str) self.assertEqual(result, '"`n`r`t"') def test_param_to_ps_bool(self): # test True result = self._runner._param_to_ps(True) - self.assertEqual(result, '$true') + self.assertEqual(result, "$true") # test False result = self._runner._param_to_ps(False) - self.assertEqual(result, '$false') + self.assertEqual(result, "$false") def test_param_to_ps_integer(self): result = self._runner._param_to_ps(9876) - self.assertEqual(result, '9876') + self.assertEqual(result, "9876") result = self._runner._param_to_ps(-765) - self.assertEqual(result, '-765') + self.assertEqual(result, "-765") def test_param_to_ps_float(self): result = self._runner._param_to_ps(98.76) - self.assertEqual(result, '98.76') + self.assertEqual(result, "98.76") result = self._runner._param_to_ps(-76.5) - self.assertEqual(result, '-76.5') + self.assertEqual(result, "-76.5") def test_param_to_ps_list(self): - input_list = ['StackStorm Test String', - '`\0$', - True, - 99] + input_list = ["StackStorm Test String", "`\0$", True, 99] result = self._runner._param_to_ps(input_list) self.assertEqual(result, '@("StackStorm Test String", "```0`$", $true, 99)') def test_param_to_ps_list_nested(self): - input_list = [['a'], ['b'], [['c']]] + input_list = [["a"], ["b"], [["c"]]] result = self._runner._param_to_ps(input_list) self.assertEqual(result, '@(@("a"), @("b"), @(@("c")))') def test_param_to_ps_dict(self): input_list = collections.OrderedDict( - [('str key', 'Value String'), - ('esc str\n', '\b\f\v"'), - (False, True), - (11, 99), - (18.3, 12.34)]) + [ + ("str key", "Value String"), + ("esc str\n", '\b\f\v"'), + (False, True), + (11, 99), + (18.3, 12.34), + ] + ) result = self._runner._param_to_ps(input_list) expected_str = ( '@{"str key" = "Value String"; ' - '"esc str`n" = "`b`f`v`\""; ' - '$false = $true; ' - '11 = 99; ' - '18.3 = 12.34}' + '"esc str`n" = "`b`f`v`""; ' + "$false = $true; " + "11 = 99; " + "18.3 = 12.34}" ) self.assertEqual(result, expected_str) def test_param_to_ps_dict_nexted(self): input_list = collections.OrderedDict( - [('a', {'deep_a': 'value'}), - ('b', {'deep_b': {'deep_deep_b': 'value'}})]) + [("a", {"deep_a": "value"}), ("b", {"deep_b": {"deep_deep_b": "value"}})] + ) result = self._runner._param_to_ps(input_list) expected_str = ( '@{"a" = @{"deep_a" = "value"}; ' @@ -945,21 +1065,22 @@ def test_param_to_ps_deep_nested_dict_outer(self): #### # dict as outer container input_dict = collections.OrderedDict( - [('a', [{'deep_a': 'value'}, - {'deep_b': ['a', 'b', 'c']}])]) + [("a", [{"deep_a": "value"}, {"deep_b": ["a", "b", "c"]}])] + ) result = self._runner._param_to_ps(input_dict) expected_str = ( - '@{"a" = @(@{"deep_a" = "value"}, ' - '@{"deep_b" = @("a", "b", "c")})}' + '@{"a" = @(@{"deep_a" = "value"}, ' '@{"deep_b" = @("a", "b", "c")})}' ) self.assertEqual(result, expected_str) def test_param_to_ps_deep_nested_list_outer(self): #### # list as outer container - input_list = [{'deep_a': 'value'}, - {'deep_b': ['a', 'b', 'c']}, - {'deep_c': [{'x': 'y'}]}] + input_list = [ + {"deep_a": "value"}, + {"deep_b": ["a", "b", "c"]}, + {"deep_c": [{"x": "y"}]}, + ] result = self._runner._param_to_ps(input_list) expected_str = ( '@(@{"deep_a" = "value"}, ' @@ -969,45 +1090,48 @@ def test_param_to_ps_deep_nested_list_outer(self): self.assertEqual(result, expected_str) def test_transform_params_to_ps(self): - positional_args = [1, 'a', '\n'] + positional_args = [1, "a", "\n"] named_args = collections.OrderedDict( - [('a', 'value1'), - ('b', True), - ('c', ['x', 'y']), - ('d', {'z': 'w'})] + [("a", "value1"), ("b", True), ("c", ["x", "y"]), ("d", {"z": "w"})] ) - result_pos, result_named = self._runner._transform_params_to_ps(positional_args, - named_args) - self.assertEqual(result_pos, ['1', '"a"', '"`n"']) - self.assertEqual(result_named, collections.OrderedDict([ - ('a', '"value1"'), - ('b', '$true'), - ('c', '@("x", "y")'), - ('d', '@{"z" = "w"}')])) + result_pos, result_named = self._runner._transform_params_to_ps( + positional_args, named_args + ) + self.assertEqual(result_pos, ["1", '"a"', '"`n"']) + self.assertEqual( + result_named, + collections.OrderedDict( + [ + ("a", '"value1"'), + ("b", "$true"), + ("c", '@("x", "y")'), + ("d", '@{"z" = "w"}'), + ] + ), + ) def test_transform_params_to_ps_none(self): positional_args = None named_args = None - result_pos, result_named = self._runner._transform_params_to_ps(positional_args, - named_args) + result_pos, result_named = self._runner._transform_params_to_ps( + positional_args, named_args + ) self.assertEqual(result_pos, None) self.assertEqual(result_named, None) def test_create_ps_params_string(self): - positional_args = [1, 'a', '\n'] + positional_args = [1, "a", "\n"] named_args = collections.OrderedDict( - [('-a', 'value1'), - ('-b', True), - ('-c', ['x', 'y']), - ('-d', {'z': 'w'})] + [("-a", "value1"), ("-b", True), ("-c", ["x", "y"]), ("-d", {"z": "w"})] ) result = self._runner.create_ps_params_string(positional_args, named_args) - self.assertEqual(result, - '-a "value1" -b $true -c @("x", "y") -d @{"z" = "w"} 1 "a" "`n"') + self.assertEqual( + result, '-a "value1" -b $true -c @("x", "y") -d @{"z" = "w"} 1 "a" "`n"' + ) def test_create_ps_params_string_none(self): positional_args = None diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py index 9ff36a1b47c..78365a333b6 100644 --- a/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py +++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_command_runner.py @@ -23,23 +23,22 @@ class WinRmCommandRunnerTestCase(RunnerTestCase): - def setUp(self): super(WinRmCommandRunnerTestCase, self).setUpClass() self._runner = winrm_command_runner.get_runner() def test_init(self): - runner = winrm_command_runner.WinRmCommandRunner('abcdef') + runner = winrm_command_runner.WinRmCommandRunner("abcdef") self.assertIsInstance(runner, WinRmBaseRunner) self.assertIsInstance(runner, ActionRunner) - self.assertEqual(runner.runner_id, 'abcdef') + self.assertEqual(runner.runner_id, "abcdef") - @mock.patch('winrm_runner.winrm_command_runner.WinRmCommandRunner.run_cmd') + @mock.patch("winrm_runner.winrm_command_runner.WinRmCommandRunner.run_cmd") def test_run(self, mock_run_cmd): - mock_run_cmd.return_value = 'expected' + mock_run_cmd.return_value = "expected" - self._runner.runner_parameters = {'cmd': 'ipconfig /all'} + self._runner.runner_parameters = {"cmd": "ipconfig /all"} result = self._runner.run({}) - self.assertEqual(result, 'expected') - mock_run_cmd.assert_called_with('ipconfig /all') + self.assertEqual(result, "expected") + mock_run_cmd.assert_called_with("ipconfig /all") diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py index d6bae23e2cf..90d9e95abd6 100644 --- a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py +++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_command_runner.py @@ -23,23 +23,22 @@ class WinRmPsCommandRunnerTestCase(RunnerTestCase): - def setUp(self): super(WinRmPsCommandRunnerTestCase, self).setUpClass() self._runner = winrm_ps_command_runner.get_runner() def test_init(self): - runner = winrm_ps_command_runner.WinRmPsCommandRunner('abcdef') + runner = winrm_ps_command_runner.WinRmPsCommandRunner("abcdef") self.assertIsInstance(runner, WinRmBaseRunner) self.assertIsInstance(runner, ActionRunner) - self.assertEqual(runner.runner_id, 'abcdef') + self.assertEqual(runner.runner_id, "abcdef") - @mock.patch('winrm_runner.winrm_ps_command_runner.WinRmPsCommandRunner.run_ps') + @mock.patch("winrm_runner.winrm_ps_command_runner.WinRmPsCommandRunner.run_ps") def test_run(self, mock_run_ps): - mock_run_ps.return_value = 'expected' + mock_run_ps.return_value = "expected" - self._runner.runner_parameters = {'cmd': 'Get-ADUser stanley'} + self._runner.runner_parameters = {"cmd": "Get-ADUser stanley"} result = self._runner.run({}) - self.assertEqual(result, 'expected') - mock_run_ps.assert_called_with('Get-ADUser stanley') + self.assertEqual(result, "expected") + mock_run_ps.assert_called_with("Get-ADUser stanley") diff --git a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py index b3c1e140349..c1414c25e7e 100644 --- a/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py +++ b/contrib/runners/winrm_runner/tests/unit/test_winrm_ps_script_runner.py @@ -22,39 +22,41 @@ from winrm_runner import winrm_ps_script_runner from winrm_runner.winrm_base import WinRmBaseRunner -FIXTURES_PATH = os.path.join(os.path.dirname(__file__), 'fixtures') +FIXTURES_PATH = os.path.join(os.path.dirname(__file__), "fixtures") POWERSHELL_SCRIPT_PATH = os.path.join(FIXTURES_PATH, "TestScript.ps1") class WinRmPsScriptRunnerTestCase(RunnerTestCase): - def setUp(self): super(WinRmPsScriptRunnerTestCase, self).setUpClass() self._runner = winrm_ps_script_runner.get_runner() def test_init(self): - runner = winrm_ps_script_runner.WinRmPsScriptRunner('abcdef') + runner = winrm_ps_script_runner.WinRmPsScriptRunner("abcdef") self.assertIsInstance(runner, WinRmBaseRunner) self.assertIsInstance(runner, ActionRunner) - self.assertEqual(runner.runner_id, 'abcdef') + self.assertEqual(runner.runner_id, "abcdef") - @mock.patch('winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner._get_script_args') - @mock.patch('winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner.run_ps') + @mock.patch( + "winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner._get_script_args" + ) + @mock.patch("winrm_runner.winrm_ps_script_runner.WinRmPsScriptRunner.run_ps") def test_run(self, mock_run_ps, mock_get_script_args): - mock_run_ps.return_value = 'expected' - pos_args = [1, 'abc'] + mock_run_ps.return_value = "expected" + pos_args = [1, "abc"] named_args = {"d": {"test": ["\r", True, 3]}} mock_get_script_args.return_value = (pos_args, named_args) self._runner.entry_point = POWERSHELL_SCRIPT_PATH self._runner.runner_parameters = {} - self._runner._kwarg_op = '-' + self._runner._kwarg_op = "-" result = self._runner.run({}) - self.assertEqual(result, 'expected') - mock_run_ps.assert_called_with('''[CmdletBinding()] + self.assertEqual(result, "expected") + mock_run_ps.assert_called_with( + """[CmdletBinding()] Param( [bool]$p_bool, [int]$p_integer, @@ -77,5 +79,6 @@ def test_run(self, mock_run_ps, mock_get_script_args): Write-Output "p_obj = $($p_obj | ConvertTo-Json -Compress)" Write-Output "p_pos0 = $p_pos0" Write-Output "p_pos1 = $p_pos1" -''', - '-d @{"test" = @("`r", $true, 3)} 1 "abc"') +""", + '-d @{"test" = @("`r", $true, 3)} 1 "abc"', + ) diff --git a/contrib/runners/winrm_runner/winrm_runner/__init__.py b/contrib/runners/winrm_runner/winrm_runner/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/contrib/runners/winrm_runner/winrm_runner/__init__.py +++ b/contrib/runners/winrm_runner/winrm_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_base.py b/contrib/runners/winrm_runner/winrm_runner/winrm_base.py index fb26e49db60..9bebbedc7bc 100644 --- a/contrib/runners/winrm_runner/winrm_runner/winrm_base.py +++ b/contrib/runners/winrm_runner/winrm_runner/winrm_base.py @@ -32,7 +32,7 @@ from winrm.exceptions import WinRMOperationTimeoutError __all__ = [ - 'WinRmBaseRunner', + "WinRmBaseRunner", ] LOG = logging.getLogger(__name__) @@ -49,7 +49,7 @@ RUNNER_USERNAME = "username" RUNNER_VERIFY_SSL = "verify_ssl_cert" -WINRM_DEFAULT_TMP_DIR_PS = '[System.IO.Path]::GetTempPath()' +WINRM_DEFAULT_TMP_DIR_PS = "[System.IO.Path]::GetTempPath()" # maximum cmdline length for systems >= Windows XP # https://support.microsoft.com/en-us/help/830473/command-prompt-cmd-exe-command-line-string-limitation WINRM_MAX_CMD_LENGTH = 8191 @@ -76,28 +76,28 @@ # Compiled list from the following sources: # https://ss64.com/ps/syntax-esc.html # https://www.techotopia.com/index.php/Windows_PowerShell_1.0_String_Quoting_and_Escape_Sequences#PowerShell_Special_Escape_Sequences -PS_ESCAPE_SEQUENCES = {'\n': '`n', - '\r': '`r', - '\t': '`t', - '\a': '`a', - '\b': '`b', - '\f': '`f', - '\v': '`v', - '"': '`"', - '\'': '`\'', - '`': '``', - '\0': '`0', - '$': '`$'} +PS_ESCAPE_SEQUENCES = { + "\n": "`n", + "\r": "`r", + "\t": "`t", + "\a": "`a", + "\b": "`b", + "\f": "`f", + "\v": "`v", + '"': '`"', + "'": "`'", + "`": "``", + "\0": "`0", + "$": "`$", +} class WinRmRunnerTimoutError(Exception): - def __init__(self, response): self.response = response class WinRmBaseRunner(ActionRunner): - def pre_run(self): super(WinRmBaseRunner, self).pre_run() @@ -107,12 +107,16 @@ def pre_run(self): self._username = self.runner_parameters[RUNNER_USERNAME] self._password = self.runner_parameters[RUNNER_PASSWORD] self._timeout = self.runner_parameters.get(RUNNER_TIMEOUT, DEFAULT_TIMEOUT) - self._read_timeout = self._timeout + 1 # read_timeout must be > operation_timeout + self._read_timeout = ( + self._timeout + 1 + ) # read_timeout must be > operation_timeout # default to https port 5986 over ntlm self._port = self.runner_parameters.get(RUNNER_PORT, DEFAULT_PORT) self._scheme = self.runner_parameters.get(RUNNER_SCHEME, DEFAULT_SCHEME) - self._transport = self.runner_parameters.get(RUNNER_TRANSPORT, DEFAULT_TRANSPORT) + self._transport = self.runner_parameters.get( + RUNNER_TRANSPORT, DEFAULT_TRANSPORT + ) # if connecting to the HTTP port then we must use "http" as the scheme # in the URL @@ -120,10 +124,14 @@ def pre_run(self): self._scheme = "http" # construct the URL for connecting to WinRM on the host - self._winrm_url = "{}://{}:{}/wsman".format(self._scheme, self._host, self._port) + self._winrm_url = "{}://{}:{}/wsman".format( + self._scheme, self._host, self._port + ) # default to verifying SSL certs - self._verify_ssl = self.runner_parameters.get(RUNNER_VERIFY_SSL, DEFAULT_VERIFY_SSL) + self._verify_ssl = self.runner_parameters.get( + RUNNER_VERIFY_SSL, DEFAULT_VERIFY_SSL + ) self._server_cert_validation = "validate" if self._verify_ssl else "ignore" # additional parameters @@ -136,12 +144,14 @@ def _get_session(self): # cache session (only create if it doesn't exist yet) if not self._session: LOG.debug("Connecting via WinRM to url: {}".format(self._winrm_url)) - self._session = Session(self._winrm_url, - auth=(self._username, self._password), - transport=self._transport, - server_cert_validation=self._server_cert_validation, - operation_timeout_sec=self._timeout, - read_timeout_sec=self._read_timeout) + self._session = Session( + self._winrm_url, + auth=(self._username, self._password), + transport=self._transport, + server_cert_validation=self._server_cert_validation, + operation_timeout_sec=self._timeout, + read_timeout_sec=self._read_timeout, + ) return self._session def _winrm_get_command_output(self, protocol, shell_id, command_id): @@ -154,37 +164,46 @@ def _winrm_get_command_output(self, protocol, shell_id, command_id): while not command_done: # check if we need to timeout (StackStorm custom) current_time = time.time() - elapsed_time = (current_time - start_time) + elapsed_time = current_time - start_time if self._timeout and (elapsed_time > self._timeout): - raise WinRmRunnerTimoutError(Response((b''.join(stdout_buffer), - b''.join(stderr_buffer), - WINRM_TIMEOUT_EXIT_CODE))) + raise WinRmRunnerTimoutError( + Response( + ( + b"".join(stdout_buffer), + b"".join(stderr_buffer), + WINRM_TIMEOUT_EXIT_CODE, + ) + ) + ) # end stackstorm custom try: - stdout, stderr, return_code, command_done = \ - protocol._raw_get_command_output(shell_id, command_id) + ( + stdout, + stderr, + return_code, + command_done, + ) = protocol._raw_get_command_output(shell_id, command_id) stdout_buffer.append(stdout) stderr_buffer.append(stderr) except WinRMOperationTimeoutError: # this is an expected error when waiting for a long-running process, # just silently retry pass - return b''.join(stdout_buffer), b''.join(stderr_buffer), return_code + return b"".join(stdout_buffer), b"".join(stderr_buffer), return_code def _winrm_run_cmd(self, session, command, args=(), env=None, cwd=None): # NOTE: this is copied from pywinrm because it doesn't support # passing env and working_directory from the Session.run_cmd. # It also doesn't support timeouts. All of these things have been # added - shell_id = session.protocol.open_shell(env_vars=env, - working_directory=cwd) + shell_id = session.protocol.open_shell(env_vars=env, working_directory=cwd) command_id = session.protocol.run_command(shell_id, command, args) # try/catch is for custom timeout handing (StackStorm custom) try: - rs = Response(self._winrm_get_command_output(session.protocol, - shell_id, - command_id)) + rs = Response( + self._winrm_get_command_output(session.protocol, shell_id, command_id) + ) rs.timeout = False except WinRmRunnerTimoutError as e: rs = e.response @@ -195,37 +214,34 @@ def _winrm_run_cmd(self, session, command, args=(), env=None, cwd=None): return rs def _winrm_encode(self, script): - return b64encode(script.encode('utf_16_le')).decode('ascii') + return b64encode(script.encode("utf_16_le")).decode("ascii") def _winrm_ps_cmd(self, encoded_ps): - return 'powershell -encodedcommand {0}'.format(encoded_ps) + return "powershell -encodedcommand {0}".format(encoded_ps) def _winrm_run_ps(self, session, script, env=None, cwd=None, is_b64=False): # NOTE: this is copied from pywinrm because it doesn't support # passing env and working_directory from the Session.run_ps # encode the script in UTF only if it isn't passed in encoded - LOG.debug('_winrm_run_ps() - script size = {}'.format(len(script))) + LOG.debug("_winrm_run_ps() - script size = {}".format(len(script))) encoded_ps = script if is_b64 else self._winrm_encode(script) ps_cmd = self._winrm_ps_cmd(encoded_ps) - LOG.debug('_winrm_run_ps() - ps cmd size = {}'.format(len(ps_cmd))) - rs = self._winrm_run_cmd(session, - ps_cmd, - env=env, - cwd=cwd) + LOG.debug("_winrm_run_ps() - ps cmd size = {}".format(len(ps_cmd))) + rs = self._winrm_run_cmd(session, ps_cmd, env=env, cwd=cwd) if len(rs.std_err): # if there was an error message, clean it it up and make it human # readable if isinstance(rs.std_err, bytes): # decode bytes into utf-8 because of a bug in pywinrm # real fix is here: https://github.com/diyan/pywinrm/pull/222/files - rs.std_err = rs.std_err.decode('utf-8') + rs.std_err = rs.std_err.decode("utf-8") rs.std_err = session._clean_error_msg(rs.std_err) return rs def _translate_response(self, response): # check exit status for errors - succeeded = (response.status_code == exit_code_constants.SUCCESS_EXIT_CODE) + succeeded = response.status_code == exit_code_constants.SUCCESS_EXIT_CODE status = action_constants.LIVEACTION_STATUS_SUCCEEDED status_code = response.status_code if response.timeout: @@ -236,39 +252,46 @@ def _translate_response(self, response): # create result result = { - 'failed': not succeeded, - 'succeeded': succeeded, - 'return_code': status_code, - 'stdout': response.std_out, - 'stderr': response.std_err + "failed": not succeeded, + "succeeded": succeeded, + "return_code": status_code, + "stdout": response.std_out, + "stderr": response.std_err, } # Ensure stdout and stderr is always a string - if isinstance(result['stdout'], six.binary_type): - result['stdout'] = result['stdout'].decode('utf-8') + if isinstance(result["stdout"], six.binary_type): + result["stdout"] = result["stdout"].decode("utf-8") - if isinstance(result['stderr'], six.binary_type): - result['stderr'] = result['stderr'].decode('utf-8') + if isinstance(result["stderr"], six.binary_type): + result["stderr"] = result["stderr"].decode("utf-8") # automatically convert result stdout/stderr from JSON strings to # objects so they can be used natively return (status, jsonify.json_loads(result, RESULT_KEYS_TO_TRANSFORM), None) def _make_tmp_dir(self, parent): - LOG.debug("Creating temporary directory for WinRM script in parent: {}".format(parent)) + LOG.debug( + "Creating temporary directory for WinRM script in parent: {}".format(parent) + ) ps = """$parent = {parent} $name = [System.IO.Path]::GetRandomFileName() $path = Join-Path $parent $name New-Item -ItemType Directory -Path $path | Out-Null -$path""".format(parent=parent) - result = self._run_ps_or_raise(ps, ("Unable to make temporary directory for" - " powershell script")) +$path""".format( + parent=parent + ) + result = self._run_ps_or_raise( + ps, ("Unable to make temporary directory for" " powershell script") + ) # strip to remove trailing newline and whitespace (if any) - return result['stdout'].strip() + return result["stdout"].strip() def _rm_dir(self, directory): ps = 'Remove-Item -Force -Recurse -Path "{}"'.format(directory) - self._run_ps_or_raise(ps, "Unable to remove temporary directory for powershell script") + self._run_ps_or_raise( + ps, "Unable to remove temporary directory for powershell script" + ) def _upload(self, src_path_or_data, dst_path): src_data = None @@ -276,7 +299,7 @@ def _upload(self, src_path_or_data, dst_path): # if this is a path, then read the data from the path if os.path.exists(src_path_or_data): LOG.debug("WinRM uploading local file: {}".format(src_path_or_data)) - with open(src_path_or_data, 'r') as src_file: + with open(src_path_or_data, "r") as src_file: src_data = src_file.read() else: LOG.debug("WinRM uploading data from a string") @@ -285,14 +308,19 @@ def _upload(self, src_path_or_data, dst_path): # upload the data in chunks such that each chunk doesn't exceed the # max command size of the windows command line for i in range(0, len(src_data), WINRM_UPLOAD_CHUNK_SIZE_BYTES): - LOG.debug("WinRM uploading data bytes: {}-{}". - format(i, (i + WINRM_UPLOAD_CHUNK_SIZE_BYTES))) - self._upload_chunk(dst_path, src_data[i:(i + WINRM_UPLOAD_CHUNK_SIZE_BYTES)]) + LOG.debug( + "WinRM uploading data bytes: {}-{}".format( + i, (i + WINRM_UPLOAD_CHUNK_SIZE_BYTES) + ) + ) + self._upload_chunk( + dst_path, src_data[i : (i + WINRM_UPLOAD_CHUNK_SIZE_BYTES)] + ) def _upload_chunk(self, dst_path, src_data): # adapted from https://github.com/diyan/pywinrm/issues/18 if not isinstance(src_data, six.binary_type): - src_data = src_data.encode('utf-8') + src_data = src_data.encode("utf-8") ps = """$filePath = "{dst_path}" $s = @" @@ -300,10 +328,11 @@ def _upload_chunk(self, dst_path, src_data): "@ $data = [System.Convert]::FromBase64String($s) Add-Content -value $data -encoding byte -path $filePath -""".format(dst_path=dst_path, - b64_data=base64.b64encode(src_data).decode('utf-8')) +""".format( + dst_path=dst_path, b64_data=base64.b64encode(src_data).decode("utf-8") + ) - LOG.debug('WinRM uploading chunk, size = {}'.format(len(ps))) + LOG.debug("WinRM uploading chunk, size = {}".format(len(ps))) self._run_ps_or_raise(ps, "Failed to upload chunk of powershell script") @contextmanager @@ -335,7 +364,7 @@ def run_cmd(self, cmd): def run_ps(self, script, params=None): # temporary directory for the powershell script if params: - powershell = '& {%s} %s' % (script, params) + powershell = "& {%s} %s" % (script, params) else: powershell = script encoded_ps = self._winrm_encode(powershell) @@ -346,9 +375,12 @@ def run_ps(self, script, params=None): # else we need to upload the script to a temporary file and execute it, # then remove the temporary file if len(ps_cmd) <= WINRM_MAX_CMD_LENGTH: - LOG.info(("WinRM powershell command size {} is > {}, the max size of a" - " powershell command. Converting to a script execution.") - .format(WINRM_MAX_CMD_LENGTH, len(ps_cmd))) + LOG.info( + ( + "WinRM powershell command size {} is > {}, the max size of a" + " powershell command. Converting to a script execution." + ).format(WINRM_MAX_CMD_LENGTH, len(ps_cmd)) + ) return self._run_ps(encoded_ps, is_b64=True) else: return self._run_ps_script(script, params) @@ -360,8 +392,9 @@ def _run_ps(self, powershell, is_b64=False): # connect session = self._get_session() # execute - response = self._winrm_run_ps(session, powershell, env=self._env, cwd=self._cwd, - is_b64=is_b64) + response = self._winrm_run_ps( + session, powershell, env=self._env, cwd=self._cwd, is_b64=is_b64 + ) # create triplet from WinRM response return self._translate_response(response) @@ -383,12 +416,12 @@ def _run_ps_or_raise(self, ps, error_msg): response = self._run_ps(ps) # response is a tuple: (status, result, None) result = response[1] - if result['failed']: - raise RuntimeError(("{}:\n" - "stdout = {}\n\n" - "stderr = {}").format(error_msg, - result['stdout'], - result['stderr'])) + if result["failed"]: + raise RuntimeError( + ("{}:\n" "stdout = {}\n\n" "stderr = {}").format( + error_msg, result["stdout"], result["stderr"] + ) + ) return result def _multireplace(self, string, replacements): @@ -407,7 +440,7 @@ def _multireplace(self, string, replacements): substrs = sorted(replacements, key=len, reverse=True) # Create a big OR regex that matches any of the substrings to replace - regexp = re.compile('|'.join([re.escape(s) for s in substrs])) + regexp = re.compile("|".join([re.escape(s) for s in substrs])) # For each match, look up the new string in the replacements return regexp.sub(lambda match: replacements[match.group(0)], string) @@ -426,8 +459,12 @@ def _param_to_ps(self, param): ps_str += ")" elif isinstance(param, dict): ps_str = "@{" - ps_str += "; ".join([(self._param_to_ps(k) + ' = ' + self._param_to_ps(v)) - for k, v in six.iteritems(param)]) + ps_str += "; ".join( + [ + (self._param_to_ps(k) + " = " + self._param_to_ps(v)) + for k, v in six.iteritems(param) + ] + ) ps_str += "}" else: ps_str = str(param) @@ -446,12 +483,15 @@ def _transform_params_to_ps(self, positional_args, named_args): def create_ps_params_string(self, positional_args, named_args): # convert the script parameters into powershell strings - positional_args, named_args = self._transform_params_to_ps(positional_args, - named_args) + positional_args, named_args = self._transform_params_to_ps( + positional_args, named_args + ) # concatenate them into a long string ps_params_str = "" if named_args: - ps_params_str += " " .join([(k + " " + v) for k, v in six.iteritems(named_args)]) + ps_params_str += " ".join( + [(k + " " + v) for k, v in six.iteritems(named_args)] + ) ps_params_str += " " if positional_args: ps_params_str += " ".join(positional_args) diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py b/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py index d09e5ce7d66..1239f3efd54 100644 --- a/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py +++ b/contrib/runners/winrm_runner/winrm_runner/winrm_command_runner.py @@ -20,19 +20,14 @@ from st2common.runners.base import get_metadata as get_runner_metadata from winrm_runner.winrm_base import WinRmBaseRunner -__all__ = [ - 'WinRmCommandRunner', - 'get_runner', - 'get_metadata' -] +__all__ = ["WinRmCommandRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) -RUNNER_COMMAND = 'cmd' +RUNNER_COMMAND = "cmd" class WinRmCommandRunner(WinRmBaseRunner): - def run(self, action_parameters): cmd_command = self.runner_parameters[RUNNER_COMMAND] @@ -45,7 +40,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('winrm_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("winrm_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py index f49db2b09e2..e6d0a37e2f8 100644 --- a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py +++ b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_command_runner.py @@ -20,19 +20,14 @@ from st2common.runners.base import get_metadata as get_runner_metadata from winrm_runner.winrm_base import WinRmBaseRunner -__all__ = [ - 'WinRmPsCommandRunner', - 'get_runner', - 'get_metadata' -] +__all__ = ["WinRmPsCommandRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) -RUNNER_COMMAND = 'cmd' +RUNNER_COMMAND = "cmd" class WinRmPsCommandRunner(WinRmBaseRunner): - def run(self, action_parameters): powershell_command = self.runner_parameters[RUNNER_COMMAND] @@ -45,7 +40,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('winrm_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("winrm_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py index 9f156bd8c9f..ff162b7aee3 100644 --- a/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py +++ b/contrib/runners/winrm_runner/winrm_runner/winrm_ps_script_runner.py @@ -21,23 +21,18 @@ from st2common.runners.base import get_metadata as get_runner_metadata from winrm_runner.winrm_base import WinRmBaseRunner -__all__ = [ - 'WinRmPsScriptRunner', - 'get_runner', - 'get_metadata' -] +__all__ = ["WinRmPsScriptRunner", "get_runner", "get_metadata"] LOG = logging.getLogger(__name__) class WinRmPsScriptRunner(WinRmBaseRunner, ShellRunnerMixin): - def run(self, action_parameters): if not self.entry_point: - raise ValueError('Missing entry_point action metadata attribute') + raise ValueError("Missing entry_point action metadata attribute") # read in the script contents from the local file - with open(self.entry_point, 'r') as script_file: + with open(self.entry_point, "r") as script_file: ps_script = script_file.read() # extract script parameters specified in the action metadata file @@ -57,7 +52,10 @@ def get_runner(): def get_metadata(): - metadata = get_runner_metadata('winrm_runner') - metadata = [runner for runner in metadata if - runner['runner_module'] == __name__.split('.')[-1]][0] + metadata = get_runner_metadata("winrm_runner") + metadata = [ + runner + for runner in metadata + if runner["runner_module"] == __name__.split(".")[-1] + ][0] return metadata diff --git a/lint-configs/python/.flake8 b/lint-configs/python/.flake8 index f3cc01b3197..4edeebe1621 100644 --- a/lint-configs/python/.flake8 +++ b/lint-configs/python/.flake8 @@ -2,7 +2,10 @@ max-line-length = 100 # L102 - apache license header enable-extensions = L101,L102 -ignore = E128,E402,E722,W504 +# We ignore some rules which conflict with black +# E203 - whitespace before ':' - in direct conflict with black rule +# W503 line break before binary operator - in direct conflict with black rule +ignore = E128,E402,E722,W504,E203,W503 exclude=*.egg/*,build,dist # Configuration for flake8-copyright extension diff --git a/pylint_plugins/api_models.py b/pylint_plugins/api_models.py index 398a664d403..4e14095f714 100644 --- a/pylint_plugins/api_models.py +++ b/pylint_plugins/api_models.py @@ -29,9 +29,7 @@ from astroid import scoped_nodes # A list of class names for which we want to skip the checks -CLASS_NAME_BLACKLIST = [ - 'ExecutionSpecificationAPI' -] +CLASS_NAME_BLACKLIST = ["ExecutionSpecificationAPI"] def register(linter): @@ -42,11 +40,11 @@ def transform(cls): if cls.name in CLASS_NAME_BLACKLIST: return - if cls.name.endswith('API') or 'schema' in cls.locals: + if cls.name.endswith("API") or "schema" in cls.locals: # This is a class which defines attributes in "schema" variable using json schema. # Those attributes are then assigned during run time inside the constructor fqdn = cls.qname() - module_name, class_name = fqdn.rsplit('.', 1) + module_name, class_name = fqdn.rsplit(".", 1) module = __import__(module_name, fromlist=[class_name]) actual_cls = getattr(module, class_name) @@ -57,29 +55,31 @@ def transform(cls): # Not a class we are interested in return - properties = schema.get('properties', {}) + properties = schema.get("properties", {}) for property_name, property_data in six.iteritems(properties): - property_name = property_name.replace('-', '_') # Note: We do the same in Python code - property_type = property_data.get('type', None) + property_name = property_name.replace( + "-", "_" + ) # Note: We do the same in Python code + property_type = property_data.get("type", None) if isinstance(property_type, (list, tuple)): # Hack for attributes with multiple types (e.g. string, null) property_type = property_type[0] - if property_type == 'object': + if property_type == "object": node = nodes.Dict() - elif property_type == 'array': + elif property_type == "array": node = nodes.List() - elif property_type == 'integer': - node = scoped_nodes.builtin_lookup('int')[1][0] - elif property_type == 'number': - node = scoped_nodes.builtin_lookup('float')[1][0] - elif property_type == 'string': - node = scoped_nodes.builtin_lookup('str')[1][0] - elif property_type == 'boolean': - node = scoped_nodes.builtin_lookup('bool')[1][0] - elif property_type == 'null': - node = scoped_nodes.builtin_lookup('None')[1][0] + elif property_type == "integer": + node = scoped_nodes.builtin_lookup("int")[1][0] + elif property_type == "number": + node = scoped_nodes.builtin_lookup("float")[1][0] + elif property_type == "string": + node = scoped_nodes.builtin_lookup("str")[1][0] + elif property_type == "boolean": + node = scoped_nodes.builtin_lookup("bool")[1][0] + elif property_type == "null": + node = scoped_nodes.builtin_lookup("None")[1][0] else: # Unknown type node = astroid.ClassDef(property_name, None) diff --git a/pylint_plugins/db_models.py b/pylint_plugins/db_models.py index 241e9ea5829..da9251462e4 100644 --- a/pylint_plugins/db_models.py +++ b/pylint_plugins/db_models.py @@ -23,8 +23,7 @@ from astroid import nodes # A list of class names for which we want to skip the checks -CLASS_NAME_BLACKLIST = [ -] +CLASS_NAME_BLACKLIST = [] def register(linter): @@ -35,14 +34,14 @@ def transform(cls): if cls.name in CLASS_NAME_BLACKLIST: return - if cls.name == 'StormFoundationDB': + if cls.name == "StormFoundationDB": # _fields get added automagically by mongoengine - if '_fields' not in cls.locals: - cls.locals['_fields'] = [nodes.Dict()] + if "_fields" not in cls.locals: + cls.locals["_fields"] = [nodes.Dict()] - if cls.name.endswith('DB'): + if cls.name.endswith("DB"): # mongoengine explicitly declared "id" field on each class so we teach pylint about that - property_name = 'id' + property_name = "id" node = astroid.ClassDef(property_name, None) cls.locals[property_name] = [node] diff --git a/scripts/dist_utils.py b/scripts/dist_utils.py index ba73f554c6e..c0af527b6bf 100644 --- a/scripts/dist_utils.py +++ b/scripts/dist_utils.py @@ -47,17 +47,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -68,15 +68,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -85,10 +85,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -102,30 +104,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -135,8 +139,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -150,7 +154,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -159,14 +163,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/scripts/dist_utils_old.py b/scripts/dist_utils_old.py index 5dfadb1bef5..da38f6edbf4 100644 --- a/scripts/dist_utils_old.py +++ b/scripts/dist_utils_old.py @@ -35,17 +35,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" try: import pip from pip import __version__ as pip_version except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) try: @@ -57,28 +57,30 @@ try: from pip._internal.req.req_file import parse_requirements except ImportError as e: - print('Failed to import parse_requirements from pip: %s' % (text_type(e))) - print('Using pip: %s' % (str(pip_version))) + print("Failed to import parse_requirements from pip: %s" % (text_type(e))) + print("Using pip: %s" % (str(pip_version))) sys.exit(1) __all__ = [ - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) @@ -90,7 +92,7 @@ def fetch_requirements(requirements_file_path): reqs = [] for req in parse_requirements(requirements_file_path, session=False): # Note: req.url was used before 9.0.0 and req.link is used in all the recent versions - link = getattr(req, 'link', getattr(req, 'url', None)) + link = getattr(req, "link", getattr(req, "url", None)) if link: links.append(str(link)) reqs.append(str(req.req)) @@ -104,7 +106,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -113,14 +115,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/scripts/fixate-requirements.py b/scripts/fixate-requirements.py index dd5c8d25053..4277c986f83 100755 --- a/scripts/fixate-requirements.py +++ b/scripts/fixate-requirements.py @@ -43,18 +43,18 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 OSCWD = os.path.abspath(os.curdir) -GET_PIP = ' curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = " curl https://bootstrap.pypa.io/get-pip.py | python" try: import pip from pip import __version__ as pip_version except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) try: @@ -66,24 +66,43 @@ try: from pip._internal.req.req_file import parse_requirements except ImportError as e: - print('Failed to import parse_requirements from pip: %s' % (text_type(e))) - print('Using pip: %s' % (str(pip_version))) + print("Failed to import parse_requirements from pip: %s" % (text_type(e))) + print("Using pip: %s" % (str(pip_version))) sys.exit(1) def parse_args(): - parser = argparse.ArgumentParser(description='Tool for requirements.txt generation.') - parser.add_argument('-s', '--source-requirements', nargs='+', - required=True, - help='Specify paths to requirements file(s). ' - 'In case several requirements files are given their content is merged.') - parser.add_argument('-f', '--fixed-requirements', required=True, - help='Specify path to fixed-requirements.txt file.') - parser.add_argument('-o', '--output-file', default='requirements.txt', - help='Specify path to the resulting requirements file.') - parser.add_argument('--skip', default=None, - help=('Comma delimited list of requirements to not ' - 'include in the generated file.')) + parser = argparse.ArgumentParser( + description="Tool for requirements.txt generation." + ) + parser.add_argument( + "-s", + "--source-requirements", + nargs="+", + required=True, + help="Specify paths to requirements file(s). " + "In case several requirements files are given their content is merged.", + ) + parser.add_argument( + "-f", + "--fixed-requirements", + required=True, + help="Specify path to fixed-requirements.txt file.", + ) + parser.add_argument( + "-o", + "--output-file", + default="requirements.txt", + help="Specify path to the resulting requirements file.", + ) + parser.add_argument( + "--skip", + default=None, + help=( + "Comma delimited list of requirements to not " + "include in the generated file." + ), + ) if len(sys.argv) < 2: parser.print_help() sys.exit(1) @@ -91,9 +110,11 @@ def parse_args(): def check_pip_version(): - if StrictVersion(pip.__version__) < StrictVersion('6.1.0'): - print("Upgrade pip, your version `{0}' " - "is outdated:\n".format(pip.__version__), GET_PIP) + if StrictVersion(pip.__version__) < StrictVersion("6.1.0"): + print( + "Upgrade pip, your version `{0}' " "is outdated:\n".format(pip.__version__), + GET_PIP, + ) sys.exit(1) @@ -129,13 +150,14 @@ def merge_source_requirements(sources): elif req.link: merged_requirements.append(req) else: - raise RuntimeError('Unexpected requirement {0}'.format(req)) + raise RuntimeError("Unexpected requirement {0}".format(req)) return merged_requirements -def write_requirements(sources=None, fixed_requirements=None, output_file=None, - skip=None): +def write_requirements( + sources=None, fixed_requirements=None, output_file=None, skip=None +): """ Write resulting requirements taking versions from the fixed_requirements. """ @@ -153,7 +175,9 @@ def write_requirements(sources=None, fixed_requirements=None, output_file=None, continue if project_name in fixedreq_hash: - raise ValueError('Duplicate definition for dependency "%s"' % (project_name)) + raise ValueError( + 'Duplicate definition for dependency "%s"' % (project_name) + ) fixedreq_hash[project_name] = req @@ -169,7 +193,7 @@ def write_requirements(sources=None, fixed_requirements=None, output_file=None, rline = str(req.link) if req.editable: - rline = '-e %s' % (rline) + rline = "-e %s" % (rline) elif req.req: project = req.name req_obj = fixedreq_hash.get(project, req) @@ -184,30 +208,40 @@ def write_requirements(sources=None, fixed_requirements=None, output_file=None, # Sort the lines to guarantee a stable order lines_to_write = sorted(lines_to_write) - data = '\n'.join(lines_to_write) + '\n' - with open(output_file, 'w') as fp: - fp.write('# Don\'t edit this file. It\'s generated automatically!\n') - fp.write('# If you want to update global dependencies, modify fixed-requirements.txt\n') - fp.write('# and then run \'make requirements\' to update requirements.txt for all\n') - fp.write('# components.\n') - fp.write('# If you want to update depdencies for a single component, modify the\n') - fp.write('# in-requirements.txt for that component and then run \'make requirements\' to\n') - fp.write('# update the component requirements.txt\n') + data = "\n".join(lines_to_write) + "\n" + with open(output_file, "w") as fp: + fp.write("# Don't edit this file. It's generated automatically!\n") + fp.write( + "# If you want to update global dependencies, modify fixed-requirements.txt\n" + ) + fp.write( + "# and then run 'make requirements' to update requirements.txt for all\n" + ) + fp.write("# components.\n") + fp.write( + "# If you want to update depdencies for a single component, modify the\n" + ) + fp.write( + "# in-requirements.txt for that component and then run 'make requirements' to\n" + ) + fp.write("# update the component requirements.txt\n") fp.write(data) - print('Requirements written to: {0}'.format(output_file)) + print("Requirements written to: {0}".format(output_file)) -if __name__ == '__main__': +if __name__ == "__main__": check_pip_version() args = parse_args() - if args['skip']: - skip = args['skip'].split(',') + if args["skip"]: + skip = args["skip"].split(",") else: skip = None - write_requirements(sources=args['source_requirements'], - fixed_requirements=args['fixed_requirements'], - output_file=args['output_file'], - skip=skip) + write_requirements( + sources=args["source_requirements"], + fixed_requirements=args["fixed_requirements"], + output_file=args["output_file"], + skip=skip, + ) diff --git a/st2actions/dist_utils.py b/st2actions/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/st2actions/dist_utils.py +++ b/st2actions/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2actions/setup.py b/st2actions/setup.py index 6fcb2cde924..a4e8c127901 100644 --- a/st2actions/setup.py +++ b/st2actions/setup.py @@ -23,9 +23,9 @@ from dist_utils import apply_vagrant_workaround from st2actions import __version__ -ST2_COMPONENT = 'st2actions' +ST2_COMPONENT = "st2actions" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -33,21 +33,23 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), + packages=find_packages(exclude=["setuptools", "tests"]), scripts=[ - 'bin/st2actionrunner', - 'bin/st2notifier', - 'bin/st2workflowengine', - 'bin/st2scheduler', - ] + "bin/st2actionrunner", + "bin/st2notifier", + "bin/st2workflowengine", + "bin/st2scheduler", + ], ) diff --git a/st2actions/st2actions/__init__.py b/st2actions/st2actions/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/st2actions/st2actions/__init__.py +++ b/st2actions/st2actions/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2actions/st2actions/cmd/actionrunner.py b/st2actions/st2actions/cmd/actionrunner.py index 457bf45e033..6aa339115a8 100644 --- a/st2actions/st2actions/cmd/actionrunner.py +++ b/st2actions/st2actions/cmd/actionrunner.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -30,15 +31,12 @@ from st2common.service_setup import setup as common_setup from st2common.service_setup import teardown as common_teardown -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _setup_sigterm_handler(): - def sigterm_handler(signum=None, frame=None): # This will cause SystemExit to be throw and allow for component cleanup. sys.exit(0) @@ -49,18 +47,22 @@ def sigterm_handler(signum=None, frame=None): def _setup(): - capabilities = { - 'name': 'actionrunner', - 'type': 'passive' - } - common_setup(service='actionrunner', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, service_registry=True, capabilities=capabilities) + capabilities = {"name": "actionrunner", "type": "passive"} + common_setup( + service="actionrunner", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + service_registry=True, + capabilities=capabilities, + ) _setup_sigterm_handler() def _run_worker(): - LOG.info('(PID=%s) Worker started.', os.getpid()) + LOG.info("(PID=%s) Worker started.", os.getpid()) action_worker = worker.get_worker() @@ -68,20 +70,20 @@ def _run_worker(): action_worker.start() action_worker.wait() except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) Worker stopped.', os.getpid()) + LOG.info("(PID=%s) Worker stopped.", os.getpid()) errors = False try: action_worker.shutdown() except: - LOG.exception('Unable to shutdown worker.') + LOG.exception("Unable to shutdown worker.") errors = True if errors: return 1 except: - LOG.exception('(PID=%s) Worker unexpectedly stopped.', os.getpid()) + LOG.exception("(PID=%s) Worker unexpectedly stopped.", os.getpid()) return 1 return 0 @@ -98,7 +100,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except: - LOG.exception('(PID=%s) Worker quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) Worker quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2actions/st2actions/cmd/scheduler.py b/st2actions/st2actions/cmd/scheduler.py index b3c972b6543..df6dd768db8 100644 --- a/st2actions/st2actions/cmd/scheduler.py +++ b/st2actions/st2actions/cmd/scheduler.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -28,9 +29,7 @@ from st2common.service_setup import teardown as common_teardown from st2common.service_setup import setup as common_setup -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) @@ -46,23 +45,27 @@ def sigterm_handler(signum=None, frame=None): def _setup(): - capabilities = { - 'name': 'scheduler', - 'type': 'passive' - } - common_setup(service='scheduler', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, service_registry=True, capabilities=capabilities) + capabilities = {"name": "scheduler", "type": "passive"} + common_setup( + service="scheduler", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + service_registry=True, + capabilities=capabilities, + ) _setup_sigterm_handler() def _run_scheduler(): - LOG.info('(PID=%s) Scheduler started.', os.getpid()) + LOG.info("(PID=%s) Scheduler started.", os.getpid()) # Lazy load these so that decorator metrics are in place from st2actions.scheduler import ( handler as scheduler_handler, - entrypoint as scheduler_entrypoint + entrypoint as scheduler_entrypoint, ) handler = scheduler_handler.get_handler() @@ -73,14 +76,18 @@ def _run_scheduler(): try: handler._cleanup_policy_delayed() except Exception: - LOG.exception('(PID=%s) Scheduler unable to perform migration cleanup.', os.getpid()) + LOG.exception( + "(PID=%s) Scheduler unable to perform migration cleanup.", os.getpid() + ) # TODO: Remove this try block for _fix_missing_action_execution_id in v3.2. # This is a temporary fix to auto-populate action_execution_id. try: handler._fix_missing_action_execution_id() except Exception: - LOG.exception('(PID=%s) Scheduler unable to populate action_execution_id.', os.getpid()) + LOG.exception( + "(PID=%s) Scheduler unable to populate action_execution_id.", os.getpid() + ) try: handler.start() @@ -89,7 +96,7 @@ def _run_scheduler(): # Wait on handler first since entrypoint is more durable. handler.wait() or entrypoint.wait() except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) Scheduler stopped.', os.getpid()) + LOG.info("(PID=%s) Scheduler stopped.", os.getpid()) errors = False @@ -97,13 +104,13 @@ def _run_scheduler(): handler.shutdown() entrypoint.shutdown() except: - LOG.exception('Unable to shutdown scheduler.') + LOG.exception("Unable to shutdown scheduler.") errors = True if errors: return 1 except: - LOG.exception('(PID=%s) Scheduler unexpectedly stopped.', os.getpid()) + LOG.exception("(PID=%s) Scheduler unexpectedly stopped.", os.getpid()) try: handler.shutdown() @@ -127,7 +134,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except: - LOG.exception('(PID=%s) Scheduler quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) Scheduler quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2actions/st2actions/cmd/st2notifier.py b/st2actions/st2actions/cmd/st2notifier.py index fdf74f5bf1a..7f1ccc72224 100644 --- a/st2actions/st2actions/cmd/st2notifier.py +++ b/st2actions/st2actions/cmd/st2notifier.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -27,29 +28,31 @@ from st2actions.notifier import config from st2actions.notifier import notifier -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _setup(): - capabilities = { - 'name': 'notifier', - 'type': 'passive' - } - common_setup(service='notifier', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, service_registry=True, capabilities=capabilities) + capabilities = {"name": "notifier", "type": "passive"} + common_setup( + service="notifier", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + service_registry=True, + capabilities=capabilities, + ) def _run_worker(): - LOG.info('(PID=%s) Actions notifier started.', os.getpid()) + LOG.info("(PID=%s) Actions notifier started.", os.getpid()) actions_notifier = notifier.get_notifier() try: actions_notifier.start(wait=True) except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) Actions notifier stopped.', os.getpid()) + LOG.info("(PID=%s) Actions notifier stopped.", os.getpid()) actions_notifier.shutdown() return 0 @@ -65,7 +68,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except: - LOG.exception('(PID=%s) Results tracker quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) Results tracker quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2actions/st2actions/cmd/workflow_engine.py b/st2actions/st2actions/cmd/workflow_engine.py index 361d6ce9e14..f51296b4b0b 100644 --- a/st2actions/st2actions/cmd/workflow_engine.py +++ b/st2actions/st2actions/cmd/workflow_engine.py @@ -19,6 +19,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -32,15 +33,12 @@ from st2common.service_setup import setup as common_setup from st2common.service_setup import teardown as common_teardown -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def setup_sigterm_handler(): - def sigterm_handler(signum=None, frame=None): # This will cause SystemExit to be throw and allow for component cleanup. sys.exit(0) @@ -51,35 +49,32 @@ def sigterm_handler(signum=None, frame=None): def setup(): - capabilities = { - 'name': 'workflowengine', - 'type': 'passive' - } + capabilities = {"name": "workflowengine", "type": "passive"} common_setup( - service='workflow_engine', + service="workflow_engine", config=config, setup_db=True, register_mq_exchanges=True, register_signal_handlers=True, service_registry=True, - capabilities=capabilities + capabilities=capabilities, ) setup_sigterm_handler() def run_server(): - LOG.info('(PID=%s) Workflow engine started.', os.getpid()) + LOG.info("(PID=%s) Workflow engine started.", os.getpid()) engine = workflows.get_engine() try: engine.start(wait=True) except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) Workflow engine stopped.', os.getpid()) + LOG.info("(PID=%s) Workflow engine stopped.", os.getpid()) engine.shutdown() except: - LOG.exception('(PID=%s) Workflow engine unexpectedly stopped.', os.getpid()) + LOG.exception("(PID=%s) Workflow engine unexpectedly stopped.", os.getpid()) return 1 return 0 @@ -97,7 +92,7 @@ def main(): sys.exit(exit_code) except Exception: traceback.print_exc() - LOG.exception('(PID=%s) Workflow engine quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) Workflow engine quit due to exception.", os.getpid()) return 1 finally: teardown() diff --git a/st2actions/st2actions/config.py b/st2actions/st2actions/config.py index b4e83a5306d..14dc2c4f58a 100644 --- a/st2actions/st2actions/config.py +++ b/st2actions/st2actions/config.py @@ -28,8 +28,11 @@ def parse_args(args=None): - CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): diff --git a/st2actions/st2actions/container/base.py b/st2actions/st2actions/container/base.py index a350a3dd696..7f2b50f0c77 100644 --- a/st2actions/st2actions/container/base.py +++ b/st2actions/st2actions/container/base.py @@ -30,8 +30,8 @@ from st2common.models.system.action import ResolvedActionParameters from st2common.persistence.execution import ActionExecution from st2common.services import access, executions, queries -from st2common.util.action_db import (get_action_by_ref, get_runnertype_by_name) -from st2common.util.action_db import (update_liveaction_status, get_liveaction_by_id) +from st2common.util.action_db import get_action_by_ref, get_runnertype_by_name +from st2common.util.action_db import update_liveaction_status, get_liveaction_by_id from st2common.util import param as param_utils from st2common.util.config_loader import ContentPackConfigLoader from st2common.metrics.base import CounterWithTimer @@ -42,30 +42,28 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'RunnerContainer', - 'get_runner_container' -] +__all__ = ["RunnerContainer", "get_runner_container"] class RunnerContainer(object): - def dispatch(self, liveaction_db): action_db = get_action_by_ref(liveaction_db.action) if not action_db: - raise Exception('Action %s not found in DB.' % (liveaction_db.action)) + raise Exception("Action %s not found in DB." % (liveaction_db.action)) - liveaction_db.context['pack'] = action_db.pack + liveaction_db.context["pack"] = action_db.pack - runner_type_db = get_runnertype_by_name(action_db.runner_type['name']) + runner_type_db = get_runnertype_by_name(action_db.runner_type["name"]) - extra = {'liveaction_db': liveaction_db, 'runner_type_db': runner_type_db} - LOG.info('Dispatching Action to a runner', extra=extra) + extra = {"liveaction_db": liveaction_db, "runner_type_db": runner_type_db} + LOG.info("Dispatching Action to a runner", extra=extra) # Get runner instance. runner = self._get_runner(runner_type_db, action_db, liveaction_db) - LOG.debug('Runner instance for RunnerType "%s" is: %s', runner_type_db.name, runner) + LOG.debug( + 'Runner instance for RunnerType "%s" is: %s', runner_type_db.name, runner + ) # Process the request. funcs = { @@ -74,12 +72,12 @@ def dispatch(self, liveaction_db): action_constants.LIVEACTION_STATUS_RUNNING: self._do_run, action_constants.LIVEACTION_STATUS_CANCELING: self._do_cancel, action_constants.LIVEACTION_STATUS_PAUSING: self._do_pause, - action_constants.LIVEACTION_STATUS_RESUMING: self._do_resume + action_constants.LIVEACTION_STATUS_RESUMING: self._do_resume, } if liveaction_db.status not in funcs: raise actionrunner.ActionRunnerDispatchError( - 'Action runner is unable to dispatch the liveaction because it is ' + "Action runner is unable to dispatch the liveaction because it is " 'in an unsupported status of "%s".' % liveaction_db.status ) @@ -94,7 +92,8 @@ def _do_run(self, runner): runner.auth_token = self._create_auth_token( context=runner.context, action_db=runner.action, - liveaction_db=runner.liveaction) + liveaction_db=runner.liveaction, + ) try: # Finalized parameters are resolved and then rendered. This process could @@ -104,13 +103,14 @@ def _do_run(self, runner): runner.runner_type.runner_parameters, runner.action.parameters, runner.liveaction.parameters, - runner.liveaction.context) + runner.liveaction.context, + ) runner.runner_parameters = runner_params except ParamException as e: raise actionrunner.ActionRunnerException(six.text_type(e)) - LOG.debug('Performing pre-run for runner: %s', runner.runner_id) + LOG.debug("Performing pre-run for runner: %s", runner.runner_id) runner.pre_run() # Mask secret parameters in the log context @@ -118,90 +118,117 @@ def _do_run(self, runner): action_db=runner.action, runner_type_db=runner.runner_type, runner_parameters=runner_params, - action_parameters=action_params) + action_parameters=action_params, + ) - extra = {'runner': runner, 'parameters': resolved_action_params} - LOG.debug('Performing run for runner: %s' % (runner.runner_id), extra=extra) + extra = {"runner": runner, "parameters": resolved_action_params} + LOG.debug("Performing run for runner: %s" % (runner.runner_id), extra=extra) - with CounterWithTimer(key='action.executions'): - with CounterWithTimer(key='action.%s.executions' % (runner.action.ref)): + with CounterWithTimer(key="action.executions"): + with CounterWithTimer(key="action.%s.executions" % (runner.action.ref)): (status, result, context) = runner.run(action_params) result = jsonify.try_loads(result) action_completed = status in action_constants.LIVEACTION_COMPLETED_STATES - if (isinstance(runner, PollingAsyncActionRunner) and - runner.is_polling_enabled() and not action_completed): + if ( + isinstance(runner, PollingAsyncActionRunner) + and runner.is_polling_enabled() + and not action_completed + ): queries.setup_query(runner.liveaction.id, runner.runner_type, context) except: - LOG.exception('Failed to run action.') + LOG.exception("Failed to run action.") _, ex, tb = sys.exc_info() # mark execution as failed. status = action_constants.LIVEACTION_STATUS_FAILED # include the error message and traceback to try and provide some hints. - result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))} + result = { + "error": str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + } context = None finally: # Log action completion - extra = {'result': result, 'status': status} + extra = {"result": result, "status": status} LOG.debug('Action "%s" completed.' % (runner.action.name), extra=extra) # Update the final status of liveaction and corresponding action execution. - runner.liveaction = self._update_status(runner.liveaction.id, status, result, context) + runner.liveaction = self._update_status( + runner.liveaction.id, status, result, context + ) # Always clean-up the auth_token # This method should be called in the finally block to ensure post_run is not impacted. self._clean_up_auth_token(runner=runner, status=status) - LOG.debug('Performing post_run for runner: %s', runner.runner_id) + LOG.debug("Performing post_run for runner: %s", runner.runner_id) runner.post_run(status=status, result=result) - LOG.debug('Runner do_run result', extra={'result': runner.liveaction.result}) - LOG.audit('Liveaction completed', extra={'liveaction_db': runner.liveaction}) + LOG.debug("Runner do_run result", extra={"result": runner.liveaction.result}) + LOG.audit("Liveaction completed", extra={"liveaction_db": runner.liveaction}) return runner.liveaction def _do_cancel(self, runner): try: - extra = {'runner': runner} - LOG.debug('Performing cancel for runner: %s', (runner.runner_id), extra=extra) + extra = {"runner": runner} + LOG.debug( + "Performing cancel for runner: %s", (runner.runner_id), extra=extra + ) (status, result, context) = runner.cancel() # Update the final status of liveaction and corresponding action execution. # The status is updated here because we want to keep the workflow running # as is if the cancel operation failed. - runner.liveaction = self._update_status(runner.liveaction.id, status, result, context) + runner.liveaction = self._update_status( + runner.liveaction.id, status, result, context + ) except: _, ex, tb = sys.exc_info() # include the error message and traceback to try and provide some hints. - result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))} - LOG.exception('Failed to cancel action %s.' % (runner.liveaction.id), extra=result) + result = { + "error": str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + } + LOG.exception( + "Failed to cancel action %s." % (runner.liveaction.id), extra=result + ) finally: # Always clean-up the auth_token # This method should be called in the finally block to ensure post_run is not impacted. self._clean_up_auth_token(runner=runner, status=runner.liveaction.status) - LOG.debug('Performing post_run for runner: %s', runner.runner_id) - result = {'error': 'Execution canceled by user.'} + LOG.debug("Performing post_run for runner: %s", runner.runner_id) + result = {"error": "Execution canceled by user."} runner.post_run(status=runner.liveaction.status, result=result) return runner.liveaction def _do_pause(self, runner): try: - extra = {'runner': runner} - LOG.debug('Performing pause for runner: %s', (runner.runner_id), extra=extra) + extra = {"runner": runner} + LOG.debug( + "Performing pause for runner: %s", (runner.runner_id), extra=extra + ) (status, result, context) = runner.pause() except: _, ex, tb = sys.exc_info() # include the error message and traceback to try and provide some hints. status = action_constants.LIVEACTION_STATUS_FAILED - result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))} + result = { + "error": str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + } context = runner.liveaction.context - LOG.exception('Failed to pause action %s.' % (runner.liveaction.id), extra=result) + LOG.exception( + "Failed to pause action %s." % (runner.liveaction.id), extra=result + ) finally: # Update the final status of liveaction and corresponding action execution. - runner.liveaction = self._update_status(runner.liveaction.id, status, result, context) + runner.liveaction = self._update_status( + runner.liveaction.id, status, result, context + ) # Always clean-up the auth_token self._clean_up_auth_token(runner=runner, status=runner.liveaction.status) @@ -210,35 +237,47 @@ def _do_pause(self, runner): def _do_resume(self, runner): try: - extra = {'runner': runner} - LOG.debug('Performing resume for runner: %s', (runner.runner_id), extra=extra) + extra = {"runner": runner} + LOG.debug( + "Performing resume for runner: %s", (runner.runner_id), extra=extra + ) (status, result, context) = runner.resume() result = jsonify.try_loads(result) action_completed = status in action_constants.LIVEACTION_COMPLETED_STATES - if (isinstance(runner, PollingAsyncActionRunner) and - runner.is_polling_enabled() and not action_completed): + if ( + isinstance(runner, PollingAsyncActionRunner) + and runner.is_polling_enabled() + and not action_completed + ): queries.setup_query(runner.liveaction.id, runner.runner_type, context) except: _, ex, tb = sys.exc_info() # include the error message and traceback to try and provide some hints. status = action_constants.LIVEACTION_STATUS_FAILED - result = {'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))} + result = { + "error": str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + } context = runner.liveaction.context - LOG.exception('Failed to resume action %s.' % (runner.liveaction.id), extra=result) + LOG.exception( + "Failed to resume action %s." % (runner.liveaction.id), extra=result + ) finally: # Update the final status of liveaction and corresponding action execution. - runner.liveaction = self._update_status(runner.liveaction.id, status, result, context) + runner.liveaction = self._update_status( + runner.liveaction.id, status, result, context + ) # Always clean-up the auth_token # This method should be called in the finally block to ensure post_run is not impacted. self._clean_up_auth_token(runner=runner, status=runner.liveaction.status) - LOG.debug('Performing post_run for runner: %s', runner.runner_id) + LOG.debug("Performing post_run for runner: %s", runner.runner_id) runner.post_run(status=status, result=result) - LOG.debug('Runner do_run result', extra={'result': runner.liveaction.result}) - LOG.audit('Liveaction completed', extra={'liveaction_db': runner.liveaction}) + LOG.debug("Runner do_run result", extra={"result": runner.liveaction.result}) + LOG.audit("Liveaction completed", extra={"liveaction_db": runner.liveaction}) return runner.liveaction @@ -260,7 +299,7 @@ def _clean_up_auth_token(self, runner, status): try: self._delete_auth_token(runner.auth_token) except: - LOG.exception('Unable to clean-up auth_token.') + LOG.exception("Unable to clean-up auth_token.") return True @@ -273,8 +312,8 @@ def _update_live_action_db(self, liveaction_id, status, result, context): liveaction_db = get_liveaction_by_id(liveaction_id) state_changed = ( - liveaction_db.status != status and - liveaction_db.status not in action_constants.LIVEACTION_COMPLETED_STATES + liveaction_db.status != status + and liveaction_db.status not in action_constants.LIVEACTION_COMPLETED_STATES ) if status in action_constants.LIVEACTION_COMPLETED_STATES: @@ -287,64 +326,69 @@ def _update_live_action_db(self, liveaction_id, status, result, context): result=result, context=context, end_timestamp=end_timestamp, - liveaction_db=liveaction_db + liveaction_db=liveaction_db, ) return (liveaction_db, state_changed) def _update_status(self, liveaction_id, status, result, context): try: - LOG.debug('Setting status: %s for liveaction: %s', status, liveaction_id) + LOG.debug("Setting status: %s for liveaction: %s", status, liveaction_id) liveaction_db, state_changed = self._update_live_action_db( - liveaction_id, status, result, context) + liveaction_id, status, result, context + ) except Exception as e: LOG.exception( - 'Cannot update liveaction ' - '(id: %s, status: %s, result: %s).' % ( - liveaction_id, status, result) + "Cannot update liveaction " + "(id: %s, status: %s, result: %s)." % (liveaction_id, status, result) ) raise e try: executions.update_execution(liveaction_db, publish=state_changed) - extra = {'liveaction_db': liveaction_db} - LOG.debug('Updated liveaction after run', extra=extra) + extra = {"liveaction_db": liveaction_db} + LOG.debug("Updated liveaction after run", extra=extra) except Exception as e: LOG.exception( - 'Cannot update action execution for liveaction ' - '(id: %s, status: %s, result: %s).' % ( - liveaction_id, status, result) + "Cannot update action execution for liveaction " + "(id: %s, status: %s, result: %s)." % (liveaction_id, status, result) ) raise e return liveaction_db def _get_entry_point_abs_path(self, pack, entry_point): - return content_utils.get_entry_point_abs_path(pack=pack, entry_point=entry_point) + return content_utils.get_entry_point_abs_path( + pack=pack, entry_point=entry_point + ) def _get_action_libs_abs_path(self, pack, entry_point): - return content_utils.get_action_libs_abs_path(pack=pack, entry_point=entry_point) + return content_utils.get_action_libs_abs_path( + pack=pack, entry_point=entry_point + ) def _get_rerun_reference(self, context): - execution_id = context.get('re-run', {}).get('ref') + execution_id = context.get("re-run", {}).get("ref") return ActionExecution.get_by_id(execution_id) if execution_id else None def _get_runner(self, runner_type_db, action_db, liveaction_db): - resolved_entry_point = self._get_entry_point_abs_path(action_db.pack, action_db.entry_point) - context = getattr(liveaction_db, 'context', dict()) - user = context.get('user', cfg.CONF.system_user.user) + resolved_entry_point = self._get_entry_point_abs_path( + action_db.pack, action_db.entry_point + ) + context = getattr(liveaction_db, "context", dict()) + user = context.get("user", cfg.CONF.system_user.user) config = None # Note: Right now configs are only supported by the Python runner actions - if (runner_type_db.name == 'python-script' or - runner_type_db.runner_module == 'python_runner'): - LOG.debug('Loading config from pack for python runner.') + if ( + runner_type_db.name == "python-script" + or runner_type_db.runner_module == "python_runner" + ): + LOG.debug("Loading config from pack for python runner.") config_loader = ContentPackConfigLoader(pack_name=action_db.pack, user=user) config = config_loader.get_config() - runner = get_runner( - name=runner_type_db.name, - config=config) + runner = get_runner(name=runner_type_db.name, config=config) # TODO: Pass those arguments to the constructor instead of late # assignment, late assignment is awful @@ -357,13 +401,16 @@ def _get_runner(self, runner_type_db, action_db, liveaction_db): runner.execution_id = str(runner.execution.id) runner.entry_point = resolved_entry_point runner.context = context - runner.callback = getattr(liveaction_db, 'callback', dict()) - runner.libs_dir_path = self._get_action_libs_abs_path(action_db.pack, - action_db.entry_point) + runner.callback = getattr(liveaction_db, "callback", dict()) + runner.libs_dir_path = self._get_action_libs_abs_path( + action_db.pack, action_db.entry_point + ) # For re-run, get the ActionExecutionDB in which the re-run is based on. - rerun_ref_id = runner.context.get('re-run', {}).get('ref') - runner.rerun_ex_ref = ActionExecution.get(id=rerun_ref_id) if rerun_ref_id else None + rerun_ref_id = runner.context.get("re-run", {}).get("ref") + runner.rerun_ex_ref = ( + ActionExecution.get(id=rerun_ref_id) if rerun_ref_id else None + ) return runner @@ -371,19 +418,20 @@ def _create_auth_token(self, context, action_db, liveaction_db): if not context: return None - user = context.get('user', None) + user = context.get("user", None) if not user: return None metadata = { - 'service': 'actions_container', - 'action_name': action_db.name, - 'live_action_id': str(liveaction_db.id) - + "service": "actions_container", + "action_name": action_db.name, + "live_action_id": str(liveaction_db.id), } ttl = cfg.CONF.auth.service_token_ttl - token_db = access.create_token(username=user, ttl=ttl, metadata=metadata, service=True) + token_db = access.create_token( + username=user, ttl=ttl, metadata=metadata, service=True + ) return token_db def _delete_auth_token(self, auth_token): diff --git a/st2actions/st2actions/notifier/config.py b/st2actions/st2actions/notifier/config.py index 6c0162f3103..0322179bbc9 100644 --- a/st2actions/st2actions/notifier/config.py +++ b/st2actions/st2actions/notifier/config.py @@ -27,8 +27,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -47,11 +50,13 @@ def _register_common_opts(): def _register_notifier_opts(): notifier_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.notifier.conf', - help='Location of the logging configuration file.') + "logging", + default="/etc/st2/logging.notifier.conf", + help="Location of the logging configuration file.", + ) ] - CONF.register_opts(notifier_opts, group='notifier') + CONF.register_opts(notifier_opts, group="notifier") register_opts() diff --git a/st2actions/st2actions/notifier/notifier.py b/st2actions/st2actions/notifier/notifier.py index 37db830e520..ea1a5377331 100644 --- a/st2actions/st2actions/notifier/notifier.py +++ b/st2actions/st2actions/notifier/notifier.py @@ -42,22 +42,23 @@ from st2common.constants.action import ACTION_CONTEXT_KV_PREFIX from st2common.constants.action import ACTION_PARAMETERS_KV_PREFIX from st2common.constants.action import ACTION_RESULTS_KV_PREFIX -from st2common.constants.keyvalue import FULL_SYSTEM_SCOPE, SYSTEM_SCOPE, DATASTORE_PARENT_SCOPE +from st2common.constants.keyvalue import ( + FULL_SYSTEM_SCOPE, + SYSTEM_SCOPE, + DATASTORE_PARENT_SCOPE, +) from st2common.services.keyvalues import KeyValueLookup from st2common.transport.queues import NOTIFIER_ACTIONUPDATE_WORK_QUEUE from st2common.metrics.base import CounterWithTimer from st2common.metrics.base import Timer -__all__ = [ - 'Notifier', - 'get_notifier' -] +__all__ = ["Notifier", "get_notifier"] LOG = logging.getLogger(__name__) # XXX: Fix this nasty positional dependency. -ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][0] -NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][1] +ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][0] +NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][1] class Notifier(consumers.MessageHandler): @@ -69,35 +70,40 @@ def __init__(self, connection, queues, trigger_dispatcher=None): trigger_dispatcher = TriggerDispatcher(LOG) self._trigger_dispatcher = trigger_dispatcher self._notify_trigger = ResourceReference.to_string_reference( - pack=NOTIFY_TRIGGER_TYPE['pack'], - name=NOTIFY_TRIGGER_TYPE['name']) + pack=NOTIFY_TRIGGER_TYPE["pack"], name=NOTIFY_TRIGGER_TYPE["name"] + ) self._action_trigger = ResourceReference.to_string_reference( - pack=ACTION_TRIGGER_TYPE['pack'], - name=ACTION_TRIGGER_TYPE['name']) + pack=ACTION_TRIGGER_TYPE["pack"], name=ACTION_TRIGGER_TYPE["name"] + ) - @CounterWithTimer(key='notifier.action.executions') + @CounterWithTimer(key="notifier.action.executions") def process(self, execution_db): execution_id = str(execution_db.id) - extra = {'execution': execution_db} + extra = {"execution": execution_db} LOG.debug('Processing action execution "%s".', execution_id, extra=extra) # Get the corresponding liveaction record. - liveaction_db = LiveAction.get_by_id(execution_db.liveaction['id']) + liveaction_db = LiveAction.get_by_id(execution_db.liveaction["id"]) if execution_db.status in LIVEACTION_COMPLETED_STATES: # If the action execution is executed under an orquesta workflow, policies for the # action execution will be applied by the workflow engine. A policy may affect the # final state of the action execution thereby impacting the state of the workflow. - if not workflow_service.is_action_execution_under_workflow_context(execution_db): - with CounterWithTimer(key='notifier.apply_post_run_policies'): + if not workflow_service.is_action_execution_under_workflow_context( + execution_db + ): + with CounterWithTimer(key="notifier.apply_post_run_policies"): policy_service.apply_post_run_policies(liveaction_db) if liveaction_db.notify: - with CounterWithTimer(key='notifier.notify_trigger.post'): - self._post_notify_triggers(liveaction_db=liveaction_db, - execution_db=execution_db) + with CounterWithTimer(key="notifier.notify_trigger.post"): + self._post_notify_triggers( + liveaction_db=liveaction_db, execution_db=execution_db + ) - self._post_generic_trigger(liveaction_db=liveaction_db, execution_db=execution_db) + self._post_generic_trigger( + liveaction_db=liveaction_db, execution_db=execution_db + ) def _get_execution_for_liveaction(self, liveaction): execution = ActionExecution.get(liveaction__id=str(liveaction.id)) @@ -108,39 +114,52 @@ def _get_execution_for_liveaction(self, liveaction): return execution def _post_notify_triggers(self, liveaction_db=None, execution_db=None): - notify = getattr(liveaction_db, 'notify', None) + notify = getattr(liveaction_db, "notify", None) if not notify: return if notify.on_complete: self._post_notify_subsection_triggers( - liveaction_db=liveaction_db, execution_db=execution_db, + liveaction_db=liveaction_db, + execution_db=execution_db, notify_subsection=notify.on_complete, - default_message_suffix='completed.') + default_message_suffix="completed.", + ) if liveaction_db.status == LIVEACTION_STATUS_SUCCEEDED and notify.on_success: self._post_notify_subsection_triggers( - liveaction_db=liveaction_db, execution_db=execution_db, + liveaction_db=liveaction_db, + execution_db=execution_db, notify_subsection=notify.on_success, - default_message_suffix='succeeded.') + default_message_suffix="succeeded.", + ) if liveaction_db.status in LIVEACTION_FAILED_STATES and notify.on_failure: self._post_notify_subsection_triggers( - liveaction_db=liveaction_db, execution_db=execution_db, + liveaction_db=liveaction_db, + execution_db=execution_db, notify_subsection=notify.on_failure, - default_message_suffix='failed.') + default_message_suffix="failed.", + ) - def _post_notify_subsection_triggers(self, liveaction_db=None, execution_db=None, - notify_subsection=None, - default_message_suffix=None): - routes = (getattr(notify_subsection, 'routes') or - getattr(notify_subsection, 'channels', [])) or [] + def _post_notify_subsection_triggers( + self, + liveaction_db=None, + execution_db=None, + notify_subsection=None, + default_message_suffix=None, + ): + routes = ( + getattr(notify_subsection, "routes") + or getattr(notify_subsection, "channels", []) + ) or [] execution_id = str(execution_db.id) if routes and len(routes) >= 1: payload = {} message = notify_subsection.message or ( - 'Action ' + liveaction_db.action + ' ' + default_message_suffix) + "Action " + liveaction_db.action + " " + default_message_suffix + ) data = notify_subsection.data or {} jinja_context = self._build_jinja_context( @@ -148,17 +167,18 @@ def _post_notify_subsection_triggers(self, liveaction_db=None, execution_db=None ) try: - with Timer(key='notifier.transform_message'): - message = self._transform_message(message=message, - context=jinja_context) + with Timer(key="notifier.transform_message"): + message = self._transform_message( + message=message, context=jinja_context + ) except: - LOG.exception('Failed (Jinja) transforming `message`.') + LOG.exception("Failed (Jinja) transforming `message`.") try: - with Timer(key='notifier.transform_data'): + with Timer(key="notifier.transform_data"): data = self._transform_data(data=data, context=jinja_context) except: - LOG.exception('Failed (Jinja) transforming `data`.') + LOG.exception("Failed (Jinja) transforming `data`.") # At this point convert result to a string. This restricts the rulesengines # ability to introspect the result. On the other handle atleast a json usable @@ -166,69 +186,82 @@ def _post_notify_subsection_triggers(self, liveaction_db=None, execution_db=None # to a string representation it uses str(...) which make it impossible to # parse the result as json any longer. # TODO: Use to_serializable_dict - data['result'] = json.dumps(liveaction_db.result) + data["result"] = json.dumps(liveaction_db.result) - payload['message'] = message - payload['data'] = data - payload['execution_id'] = execution_id - payload['status'] = liveaction_db.status - payload['start_timestamp'] = isotime.format(liveaction_db.start_timestamp) + payload["message"] = message + payload["data"] = data + payload["execution_id"] = execution_id + payload["status"] = liveaction_db.status + payload["start_timestamp"] = isotime.format(liveaction_db.start_timestamp) try: - payload['end_timestamp'] = isotime.format(liveaction_db.end_timestamp) + payload["end_timestamp"] = isotime.format(liveaction_db.end_timestamp) except AttributeError: # This can be raised if liveaction.end_timestamp is None, which is caused # when policy cancels a request due to concurrency # In this case, use datetime.now() instead - payload['end_timestamp'] = isotime.format(datetime.utcnow()) + payload["end_timestamp"] = isotime.format(datetime.utcnow()) - payload['action_ref'] = liveaction_db.action - payload['runner_ref'] = self._get_runner_ref(liveaction_db.action) + payload["action_ref"] = liveaction_db.action + payload["runner_ref"] = self._get_runner_ref(liveaction_db.action) trace_context = self._get_trace_context(execution_id=execution_id) failed_routes = [] for route in routes: try: - payload['route'] = route + payload["route"] = route # Deprecated. Only for backward compatibility reasons. - payload['channel'] = route - LOG.debug('POSTing %s for %s. Payload - %s.', NOTIFY_TRIGGER_TYPE['name'], - liveaction_db.id, payload) - - with CounterWithTimer(key='notifier.notify_trigger.dispatch'): - self._trigger_dispatcher.dispatch(self._notify_trigger, payload=payload, - trace_context=trace_context) + payload["channel"] = route + LOG.debug( + "POSTing %s for %s. Payload - %s.", + NOTIFY_TRIGGER_TYPE["name"], + liveaction_db.id, + payload, + ) + + with CounterWithTimer(key="notifier.notify_trigger.dispatch"): + self._trigger_dispatcher.dispatch( + self._notify_trigger, + payload=payload, + trace_context=trace_context, + ) except: failed_routes.append(route) if len(failed_routes) > 0: - raise Exception('Failed notifications to routes: %s' % ', '.join(failed_routes)) + raise Exception( + "Failed notifications to routes: %s" % ", ".join(failed_routes) + ) def _build_jinja_context(self, liveaction_db, execution_db): context = {} - context.update({ - DATASTORE_PARENT_SCOPE: { - SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + context.update( + { + DATASTORE_PARENT_SCOPE: { + SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + } } - }) + ) context.update({ACTION_PARAMETERS_KV_PREFIX: liveaction_db.parameters}) context.update({ACTION_CONTEXT_KV_PREFIX: liveaction_db.context}) context.update({ACTION_RESULTS_KV_PREFIX: execution_db.result}) return context def _transform_message(self, message, context=None): - mapping = {'message': message} + mapping = {"message": message} context = context or {} - return (jinja_utils.render_values(mapping=mapping, context=context)).get('message', - message) + return (jinja_utils.render_values(mapping=mapping, context=context)).get( + "message", message + ) def _transform_data(self, data, context=None): return jinja_utils.render_values(mapping=data, context=context) def _get_trace_context(self, execution_id): trace_db = trace_service.get_trace_db_by_action_execution( - action_execution_id=execution_id) + action_execution_id=execution_id + ) if trace_db: return TraceContext(id_=str(trace_db.id), trace_tag=trace_db.trace_tag) # If no trace_context is found then do not create a new one here. If necessary @@ -237,38 +270,48 @@ def _get_trace_context(self, execution_id): def _post_generic_trigger(self, liveaction_db=None, execution_db=None): if not cfg.CONF.action_sensor.enable: - LOG.debug('Action trigger is disabled, skipping trigger dispatch...') + LOG.debug("Action trigger is disabled, skipping trigger dispatch...") return execution_id = str(execution_db.id) - extra = {'execution': execution_db} + extra = {"execution": execution_db} target_statuses = cfg.CONF.action_sensor.emit_when if execution_db.status not in target_statuses: msg = 'Skip action execution "%s" because state "%s" is not in %s' - LOG.debug(msg % (execution_id, execution_db.status, target_statuses), extra=extra) + LOG.debug( + msg % (execution_id, execution_db.status, target_statuses), extra=extra + ) return - with CounterWithTimer(key='notifier.generic_trigger.post'): - payload = {'execution_id': execution_id, - 'status': liveaction_db.status, - 'start_timestamp': str(liveaction_db.start_timestamp), - # deprecate 'action_name' at some point and switch to 'action_ref' - 'action_name': liveaction_db.action, - 'action_ref': liveaction_db.action, - 'runner_ref': self._get_runner_ref(liveaction_db.action), - 'parameters': liveaction_db.get_masked_parameters(), - 'result': liveaction_db.result} + with CounterWithTimer(key="notifier.generic_trigger.post"): + payload = { + "execution_id": execution_id, + "status": liveaction_db.status, + "start_timestamp": str(liveaction_db.start_timestamp), + # deprecate 'action_name' at some point and switch to 'action_ref' + "action_name": liveaction_db.action, + "action_ref": liveaction_db.action, + "runner_ref": self._get_runner_ref(liveaction_db.action), + "parameters": liveaction_db.get_masked_parameters(), + "result": liveaction_db.result, + } # Use execution_id to extract trace rather than liveaction. execution_id # will look-up an exact TraceDB while liveaction depending on context # may not end up going to the DB. trace_context = self._get_trace_context(execution_id=execution_id) - LOG.debug('POSTing %s for %s. Payload - %s. TraceContext - %s', - ACTION_TRIGGER_TYPE['name'], liveaction_db.id, payload, trace_context) + LOG.debug( + "POSTing %s for %s. Payload - %s. TraceContext - %s", + ACTION_TRIGGER_TYPE["name"], + liveaction_db.id, + payload, + trace_context, + ) - with CounterWithTimer(key='notifier.generic_trigger.dispatch'): - self._trigger_dispatcher.dispatch(self._action_trigger, payload=payload, - trace_context=trace_context) + with CounterWithTimer(key="notifier.generic_trigger.dispatch"): + self._trigger_dispatcher.dispatch( + self._action_trigger, payload=payload, trace_context=trace_context + ) def _get_runner_ref(self, action_ref): """ @@ -277,10 +320,13 @@ def _get_runner_ref(self, action_ref): :rtype: ``str`` """ action = Action.get_by_ref(action_ref) - return action['runner_type']['name'] + return action["runner_type"]["name"] def get_notifier(): with transport_utils.get_connection() as conn: - return Notifier(conn, [NOTIFIER_ACTIONUPDATE_WORK_QUEUE], - trigger_dispatcher=TriggerDispatcher(LOG)) + return Notifier( + conn, + [NOTIFIER_ACTIONUPDATE_WORK_QUEUE], + trigger_dispatcher=TriggerDispatcher(LOG), + ) diff --git a/st2actions/st2actions/policies/concurrency.py b/st2actions/st2actions/policies/concurrency.py index 4f98b093c73..cf47ed0b691 100644 --- a/st2actions/st2actions/policies/concurrency.py +++ b/st2actions/st2actions/policies/concurrency.py @@ -22,53 +22,64 @@ from st2common.services import action as action_service -__all__ = [ - 'ConcurrencyApplicator' -] +__all__ = ["ConcurrencyApplicator"] LOG = logging.getLogger(__name__) class ConcurrencyApplicator(BaseConcurrencyApplicator): - - def __init__(self, policy_ref, policy_type, threshold=0, action='delay'): - super(ConcurrencyApplicator, self).__init__(policy_ref=policy_ref, policy_type=policy_type, - threshold=threshold, - action=action) + def __init__(self, policy_ref, policy_type, threshold=0, action="delay"): + super(ConcurrencyApplicator, self).__init__( + policy_ref=policy_ref, + policy_type=policy_type, + threshold=threshold, + action=action, + ) def _get_lock_uid(self, target): - values = {'policy_type': self._policy_type, 'action': target.action} + values = {"policy_type": self._policy_type, "action": target.action} return self._get_lock_name(values=values) def _apply_before(self, target): # Get the count of scheduled instances of the action. scheduled = action_access.LiveAction.count( - action=target.action, status=action_constants.LIVEACTION_STATUS_SCHEDULED) + action=target.action, status=action_constants.LIVEACTION_STATUS_SCHEDULED + ) # Get the count of running instances of the action. running = action_access.LiveAction.count( - action=target.action, status=action_constants.LIVEACTION_STATUS_RUNNING) + action=target.action, status=action_constants.LIVEACTION_STATUS_RUNNING + ) count = scheduled + running # Mark the execution as scheduled if threshold is not reached or delayed otherwise. if count < self.threshold: - LOG.debug('There are %s instances of %s in scheduled or running status. ' - 'Threshold of %s is not reached. Action execution will be scheduled.', - count, target.action, self._policy_ref) + LOG.debug( + "There are %s instances of %s in scheduled or running status. " + "Threshold of %s is not reached. Action execution will be scheduled.", + count, + target.action, + self._policy_ref, + ) status = action_constants.LIVEACTION_STATUS_REQUESTED else: - action = 'delayed' if self.policy_action == 'delay' else 'canceled' - LOG.debug('There are %s instances of %s in scheduled or running status. ' - 'Threshold of %s is reached. Action execution will be %s.', - count, target.action, self._policy_ref, action) + action = "delayed" if self.policy_action == "delay" else "canceled" + LOG.debug( + "There are %s instances of %s in scheduled or running status. " + "Threshold of %s is reached. Action execution will be %s.", + count, + target.action, + self._policy_ref, + action, + ) status = self._get_status_for_policy_action(action=self.policy_action) # Update the status in the database. Publish status for cancellation so the # appropriate runner can cancel the execution. Other statuses are not published # because they will be picked up by the worker(s) to be processed again, # leading to duplicate action executions. - publish = (status == action_constants.LIVEACTION_STATUS_CANCELING) + publish = status == action_constants.LIVEACTION_STATUS_CANCELING target = action_service.update_status(target, status, publish=publish) return target @@ -78,13 +89,17 @@ def apply_before(self, target): valid_states = [ action_constants.LIVEACTION_STATUS_REQUESTED, - action_constants.LIVEACTION_STATUS_DELAYED + action_constants.LIVEACTION_STATUS_DELAYED, ] # Exit if target not in valid state. if target.status not in valid_states: - LOG.debug('The live action is not in a valid state therefore the policy ' - '"%s" cannot be applied. %s', self._policy_ref, target) + LOG.debug( + "The live action is not in a valid state therefore the policy " + '"%s" cannot be applied. %s', + self._policy_ref, + target, + ) return target target = self._apply_before(target) diff --git a/st2actions/st2actions/policies/concurrency_by_attr.py b/st2actions/st2actions/policies/concurrency_by_attr.py index 7c9ee1dabc4..ea3f9cd4218 100644 --- a/st2actions/st2actions/policies/concurrency_by_attr.py +++ b/st2actions/st2actions/policies/concurrency_by_attr.py @@ -25,38 +25,41 @@ from st2common.policies.concurrency import BaseConcurrencyApplicator from st2common.services import coordination -__all__ = [ - 'ConcurrencyByAttributeApplicator' -] +__all__ = ["ConcurrencyByAttributeApplicator"] LOG = logging.getLogger(__name__) class ConcurrencyByAttributeApplicator(BaseConcurrencyApplicator): - - def __init__(self, policy_ref, policy_type, threshold=0, action='delay', attributes=None): - super(ConcurrencyByAttributeApplicator, self).__init__(policy_ref=policy_ref, - policy_type=policy_type, - threshold=threshold, - action=action) + def __init__( + self, policy_ref, policy_type, threshold=0, action="delay", attributes=None + ): + super(ConcurrencyByAttributeApplicator, self).__init__( + policy_ref=policy_ref, + policy_type=policy_type, + threshold=threshold, + action=action, + ) self.attributes = attributes or [] def _get_lock_uid(self, target): meta = { - 'policy_type': self._policy_type, - 'action': target.action, - 'attributes': self.attributes + "policy_type": self._policy_type, + "action": target.action, + "attributes": self.attributes, } return json.dumps(meta) def _get_filters(self, target): - filters = {('parameters__%s' % k): v - for k, v in six.iteritems(target.parameters) - if k in self.attributes} + filters = { + ("parameters__%s" % k): v + for k, v in six.iteritems(target.parameters) + if k in self.attributes + } - filters['action'] = target.action - filters['status'] = None + filters["action"] = target.action + filters["status"] = None return filters @@ -65,54 +68,71 @@ def _apply_before(self, target): filters = self._get_filters(target) # Get the count of scheduled instances of the action. - filters['status'] = action_constants.LIVEACTION_STATUS_SCHEDULED + filters["status"] = action_constants.LIVEACTION_STATUS_SCHEDULED scheduled = action_access.LiveAction.count(**filters) # Get the count of running instances of the action. - filters['status'] = action_constants.LIVEACTION_STATUS_RUNNING + filters["status"] = action_constants.LIVEACTION_STATUS_RUNNING running = action_access.LiveAction.count(**filters) count = scheduled + running # Mark the execution as scheduled if threshold is not reached or delayed otherwise. if count < self.threshold: - LOG.debug('There are %s instances of %s in scheduled or running status. ' - 'Threshold of %s is not reached. Action execution will be scheduled.', - count, target.action, self._policy_ref) + LOG.debug( + "There are %s instances of %s in scheduled or running status. " + "Threshold of %s is not reached. Action execution will be scheduled.", + count, + target.action, + self._policy_ref, + ) status = action_constants.LIVEACTION_STATUS_REQUESTED else: - action = 'delayed' if self.policy_action == 'delay' else 'canceled' - LOG.debug('There are %s instances of %s in scheduled or running status. ' - 'Threshold of %s is reached. Action execution will be %s.', - count, target.action, self._policy_ref, action) + action = "delayed" if self.policy_action == "delay" else "canceled" + LOG.debug( + "There are %s instances of %s in scheduled or running status. " + "Threshold of %s is reached. Action execution will be %s.", + count, + target.action, + self._policy_ref, + action, + ) status = self._get_status_for_policy_action(action=self.policy_action) # Update the status in the database. Publish status for cancellation so the # appropriate runner can cancel the execution. Other statuses are not published # because they will be picked up by the worker(s) to be processed again, # leading to duplicate action executions. - publish = (status == action_constants.LIVEACTION_STATUS_CANCELING) + publish = status == action_constants.LIVEACTION_STATUS_CANCELING target = action_service.update_status(target, status, publish=publish) return target def apply_before(self, target): - target = super(ConcurrencyByAttributeApplicator, self).apply_before(target=target) + target = super(ConcurrencyByAttributeApplicator, self).apply_before( + target=target + ) valid_states = [ action_constants.LIVEACTION_STATUS_REQUESTED, - action_constants.LIVEACTION_STATUS_DELAYED + action_constants.LIVEACTION_STATUS_DELAYED, ] # Exit if target not in valid state. if target.status not in valid_states: - LOG.debug('The live action is not schedulable therefore the policy ' - '"%s" cannot be applied. %s', self._policy_ref, target) + LOG.debug( + "The live action is not schedulable therefore the policy " + '"%s" cannot be applied. %s', + self._policy_ref, + target, + ) return target # Warn users that the coordination service is not configured. if not coordination.configured(): - LOG.warn('Coordination service is not configured. Policy enforcement is best effort.') + LOG.warn( + "Coordination service is not configured. Policy enforcement is best effort." + ) target = self._apply_before(target) diff --git a/st2actions/st2actions/policies/retry.py b/st2actions/st2actions/policies/retry.py index 85775d4f13f..abbbd70453c 100644 --- a/st2actions/st2actions/policies/retry.py +++ b/st2actions/st2actions/policies/retry.py @@ -27,22 +27,16 @@ from st2common.util.enum import Enum from st2common.policies.base import ResourcePolicyApplicator -__all__ = [ - 'RetryOnPolicy', - 'ExecutionRetryPolicyApplicator' -] +__all__ = ["RetryOnPolicy", "ExecutionRetryPolicyApplicator"] LOG = logging.getLogger(__name__) -VALID_RETRY_STATUSES = [ - LIVEACTION_STATUS_FAILED, - LIVEACTION_STATUS_TIMED_OUT -] +VALID_RETRY_STATUSES = [LIVEACTION_STATUS_FAILED, LIVEACTION_STATUS_TIMED_OUT] class RetryOnPolicy(Enum): - FAILURE = 'failure' # Retry on execution failure - TIMEOUT = 'timeout' # Retry on execution timeout + FAILURE = "failure" # Retry on execution failure + TIMEOUT = "timeout" # Retry on execution timeout class ExecutionRetryPolicyApplicator(ResourcePolicyApplicator): @@ -57,8 +51,9 @@ def __init__(self, policy_ref, policy_type, retry_on, max_retry_count=2, delay=0 :param delay: How long to wait before retrying an execution. :type delay: ``float`` """ - super(ExecutionRetryPolicyApplicator, self).__init__(policy_ref=policy_ref, - policy_type=policy_type) + super(ExecutionRetryPolicyApplicator, self).__init__( + policy_ref=policy_ref, policy_type=policy_type + ) self.retry_on = retry_on self.max_retry_count = max_retry_count @@ -71,27 +66,33 @@ def apply_after(self, target): if self._is_live_action_part_of_workflow_action(live_action_db): LOG.warning( - 'Retry cannot be applied to this liveaction because it is executed under a ' - 'workflow. Use workflow specific retry functionality where applicable. %s', - live_action_db + "Retry cannot be applied to this liveaction because it is executed under a " + "workflow. Use workflow specific retry functionality where applicable. %s", + live_action_db, ) return target retry_count = self._get_live_action_retry_count(live_action_db=live_action_db) - extra = {'live_action_db': live_action_db, 'policy_ref': self._policy_ref, - 'retry_on': self.retry_on, 'max_retry_count': self.max_retry_count, - 'current_retry_count': retry_count} + extra = { + "live_action_db": live_action_db, + "policy_ref": self._policy_ref, + "retry_on": self.retry_on, + "max_retry_count": self.max_retry_count, + "current_retry_count": retry_count, + } if live_action_db.status not in VALID_RETRY_STATUSES: # Currently we only support retrying on failed action - LOG.debug('Liveaction not in a valid retry state, not checking retry policy', - extra=extra) + LOG.debug( + "Liveaction not in a valid retry state, not checking retry policy", + extra=extra, + ) return target if (retry_count + 1) > self.max_retry_count: - LOG.info('Maximum retry count has been reached, not retrying', extra=extra) + LOG.info("Maximum retry count has been reached, not retrying", extra=extra) return target has_failed = live_action_db.status == LIVEACTION_STATUS_FAILED @@ -100,34 +101,50 @@ def apply_after(self, target): # TODO: This is not crash and restart safe, switch to using "DELAYED" # status if self.delay > 0: - re_run_live_action = functools.partial(eventlet.spawn_after, self.delay, - self._re_run_live_action, - live_action_db=live_action_db) + re_run_live_action = functools.partial( + eventlet.spawn_after, + self.delay, + self._re_run_live_action, + live_action_db=live_action_db, + ) else: # Even if delay is 0, use a small delay (0.1 seconds) to prevent busy wait - re_run_live_action = functools.partial(eventlet.spawn_after, 0.1, - self._re_run_live_action, - live_action_db=live_action_db) + re_run_live_action = functools.partial( + eventlet.spawn_after, + 0.1, + self._re_run_live_action, + live_action_db=live_action_db, + ) - re_run_live_action = functools.partial(self._re_run_live_action, - live_action_db=live_action_db) + re_run_live_action = functools.partial( + self._re_run_live_action, live_action_db=live_action_db + ) if has_failed and self.retry_on == RetryOnPolicy.FAILURE: - extra['failure'] = True - LOG.info('Policy matched (failure), retrying action execution in %s seconds...' % - (self.delay), extra=extra) + extra["failure"] = True + LOG.info( + "Policy matched (failure), retrying action execution in %s seconds..." + % (self.delay), + extra=extra, + ) re_run_live_action() return target if has_timed_out and self.retry_on == RetryOnPolicy.TIMEOUT: - extra['timeout'] = True - LOG.info('Policy matched (timeout), retrying action execution in %s seconds...' % - (self.delay), extra=extra) + extra["timeout"] = True + LOG.info( + "Policy matched (timeout), retrying action execution in %s seconds..." + % (self.delay), + extra=extra, + ) re_run_live_action() return target - LOG.info('Invalid status "%s" for live action "%s", wont retry' % - (live_action_db.status, str(live_action_db.id)), extra=extra) + LOG.info( + 'Invalid status "%s" for live action "%s", wont retry' + % (live_action_db.status, str(live_action_db.id)), + extra=extra, + ) return target @@ -137,9 +154,9 @@ def _is_live_action_part_of_workflow_action(self, live_action_db): :rtype: ``dict`` """ - context = getattr(live_action_db, 'context', {}) - parent = context.get('parent', {}) - is_wf_action = (parent is not None and parent != {}) + context = getattr(live_action_db, "context", {}) + parent = context.get("parent", {}) + is_wf_action = parent is not None and parent != {} return is_wf_action @@ -151,8 +168,8 @@ def _get_live_action_retry_count(self, live_action_db): """ # TODO: Ideally we would store retry_count in zookeeper or similar and use locking so we # can run multiple instances of st2notififer - context = getattr(live_action_db, 'context', {}) - retry_count = context.get('policies', {}).get('retry', {}).get('retry_count', 0) + context = getattr(live_action_db, "context", {}) + retry_count = context.get("policies", {}).get("retry", {}).get("retry_count", 0) return retry_count @@ -160,17 +177,18 @@ def _re_run_live_action(self, live_action_db): retry_count = self._get_live_action_retry_count(live_action_db=live_action_db) # Add additional policy specific info to the context - context = getattr(live_action_db, 'context', {}) + context = getattr(live_action_db, "context", {}) new_context = copy.deepcopy(context) - new_context['policies'] = {} - new_context['policies']['retry'] = { - 'applied_policy': self._policy_ref, - 'retry_count': (retry_count + 1), - 'retried_liveaction_id': str(live_action_db.id) + new_context["policies"] = {} + new_context["policies"]["retry"] = { + "applied_policy": self._policy_ref, + "retry_count": (retry_count + 1), + "retried_liveaction_id": str(live_action_db.id), } action_ref = live_action_db.action parameters = live_action_db.parameters - new_live_action_db = LiveActionDB(action=action_ref, parameters=parameters, - context=new_context) + new_live_action_db = LiveActionDB( + action=action_ref, parameters=parameters, context=new_context + ) _, action_execution_db = action_services.request(new_live_action_db) return action_execution_db diff --git a/st2actions/st2actions/runners/pythonrunner.py b/st2actions/st2actions/runners/pythonrunner.py index 215edd83c86..33a3f3ec39e 100644 --- a/st2actions/st2actions/runners/pythonrunner.py +++ b/st2actions/st2actions/runners/pythonrunner.py @@ -16,6 +16,4 @@ from __future__ import absolute_import from st2common.runners.base_action import Action -__all__ = [ - 'Action' -] +__all__ = ["Action"] diff --git a/st2actions/st2actions/scheduler/config.py b/st2actions/st2actions/scheduler/config.py index a991403a9bc..8df6c3ff3e0 100644 --- a/st2actions/st2actions/scheduler/config.py +++ b/st2actions/st2actions/scheduler/config.py @@ -27,8 +27,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=sys_constants.VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=sys_constants.VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -47,36 +50,48 @@ def _register_common_opts(): def _register_service_opts(): scheduler_opts = [ cfg.StrOpt( - 'logging', - default='/etc/st2/logging.scheduler.conf', - help='Location of the logging configuration file.' + "logging", + default="/etc/st2/logging.scheduler.conf", + help="Location of the logging configuration file.", ), cfg.FloatOpt( - 'execution_scheduling_timeout_threshold_min', default=1, - help='How long GC to search back in minutes for orphaned scheduled actions'), + "execution_scheduling_timeout_threshold_min", + default=1, + help="How long GC to search back in minutes for orphaned scheduled actions", + ), cfg.IntOpt( - 'pool_size', default=10, - help='The size of the pool used by the scheduler for scheduling executions.'), + "pool_size", + default=10, + help="The size of the pool used by the scheduler for scheduling executions.", + ), cfg.FloatOpt( - 'sleep_interval', default=0.10, - help='How long (in seconds) to sleep between each action scheduler main loop run ' - 'interval.'), + "sleep_interval", + default=0.10, + help="How long (in seconds) to sleep between each action scheduler main loop run " + "interval.", + ), cfg.FloatOpt( - 'gc_interval', default=10, - help='How often (in seconds) to look for zombie execution requests before rescheduling ' - 'them.'), + "gc_interval", + default=10, + help="How often (in seconds) to look for zombie execution requests before rescheduling " + "them.", + ), cfg.IntOpt( - 'retry_max_attempt', default=10, - help='The maximum number of attempts that the scheduler retries on error.'), + "retry_max_attempt", + default=10, + help="The maximum number of attempts that the scheduler retries on error.", + ), cfg.IntOpt( - 'retry_wait_msec', default=3000, - help='The number of milliseconds to wait in between retries.') + "retry_wait_msec", + default=3000, + help="The number of milliseconds to wait in between retries.", + ), ] - cfg.CONF.register_opts(scheduler_opts, group='scheduler') + cfg.CONF.register_opts(scheduler_opts, group="scheduler") try: register_opts() except cfg.DuplicateOptError: - LOG.exception('The scheduler configuration options are already parsed and loaded.') + LOG.exception("The scheduler configuration options are already parsed and loaded.") diff --git a/st2actions/st2actions/scheduler/entrypoint.py b/st2actions/st2actions/scheduler/entrypoint.py index ee8a76f2d1a..14d816ded39 100644 --- a/st2actions/st2actions/scheduler/entrypoint.py +++ b/st2actions/st2actions/scheduler/entrypoint.py @@ -29,10 +29,7 @@ from st2common.persistence.execution_queue import ActionExecutionSchedulingQueue from st2common.models.db.execution_queue import ActionExecutionSchedulingQueueItemDB -__all__ = [ - 'SchedulerEntrypoint', - 'get_scheduler_entrypoint' -] +__all__ = ["SchedulerEntrypoint", "get_scheduler_entrypoint"] LOG = logging.getLogger(__name__) @@ -43,6 +40,7 @@ class SchedulerEntrypoint(consumers.MessageHandler): SchedulerEntrypoint subscribes to the Action scheduler request queue and places new Live Actions into the scheduling queue collection for scheduling on action runners. """ + message_type = LiveActionDB def process(self, request): @@ -53,18 +51,25 @@ def process(self, request): :type request: ``st2common.models.db.liveaction.LiveActionDB`` """ if request.status != action_constants.LIVEACTION_STATUS_REQUESTED: - LOG.info('%s is ignoring %s (id=%s) with "%s" status.', - self.__class__.__name__, type(request), request.id, request.status) + LOG.info( + '%s is ignoring %s (id=%s) with "%s" status.', + self.__class__.__name__, + type(request), + request.id, + request.status, + ) return try: liveaction_db = action_utils.get_liveaction_by_id(str(request.id)) except StackStormDBObjectNotFoundError: - LOG.exception('Failed to find liveaction %s in the database.', str(request.id)) + LOG.exception( + "Failed to find liveaction %s in the database.", str(request.id) + ) raise query = { - 'liveaction_id': str(liveaction_db.id), + "liveaction_id": str(liveaction_db.id), } queued_requests = ActionExecutionSchedulingQueue.query(**query) @@ -75,17 +80,16 @@ def process(self, request): if liveaction_db.delay and liveaction_db.delay > 0: liveaction_db = action_service.update_status( - liveaction_db, - action_constants.LIVEACTION_STATUS_DELAYED, - publish=False + liveaction_db, action_constants.LIVEACTION_STATUS_DELAYED, publish=False ) execution_queue_item_db = self._create_execution_queue_item_db_from_liveaction( - liveaction_db, - delay=liveaction_db.delay + liveaction_db, delay=liveaction_db.delay ) - ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False) + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) return execution_queue_item_db @@ -99,9 +103,8 @@ def _create_execution_queue_item_db_from_liveaction(self, liveaction, delay=None execution_queue_item_db.action_execution_id = str(execution.id) execution_queue_item_db.liveaction_id = str(liveaction.id) execution_queue_item_db.original_start_timestamp = liveaction.start_timestamp - execution_queue_item_db.scheduled_start_timestamp = date.append_milliseconds_to_time( - liveaction.start_timestamp, - delay or 0 + execution_queue_item_db.scheduled_start_timestamp = ( + date.append_milliseconds_to_time(liveaction.start_timestamp, delay or 0) ) execution_queue_item_db.delay = delay diff --git a/st2actions/st2actions/scheduler/handler.py b/st2actions/st2actions/scheduler/handler.py index 76d54066a96..e39871db3eb 100644 --- a/st2actions/st2actions/scheduler/handler.py +++ b/st2actions/st2actions/scheduler/handler.py @@ -37,10 +37,7 @@ from st2common.metrics import base as metrics from st2common.exceptions import db as db_exc -__all__ = [ - 'ActionExecutionSchedulingQueueHandler', - 'get_handler' -] +__all__ = ["ActionExecutionSchedulingQueueHandler", "get_handler"] LOG = logging.getLogger(__name__) @@ -61,14 +58,15 @@ def __init__(self): # fast (< 5 seconds). If an item is still being marked as processing it likely indicates # that the scheduler process which was processing that item crashed or similar so we need # to mark it as "handling=False" so some other scheduler process can pick it up. - self._execution_scheduling_timeout_threshold_ms = \ + self._execution_scheduling_timeout_threshold_ms = ( cfg.CONF.scheduler.execution_scheduling_timeout_threshold_min * 60 * 1000 + ) self._coordinator = coordination_service.get_coordinator(start_heart=True) self._main_thread = None self._cleanup_thread = None def run(self): - LOG.debug('Starting scheduler handler...') + LOG.debug("Starting scheduler handler...") while not self._shutdown: eventlet.greenthread.sleep(cfg.CONF.scheduler.sleep_interval) @@ -77,7 +75,8 @@ def run(self): @retrying.retry( retry_on_exception=service_utils.retry_on_exceptions, stop_max_attempt_number=cfg.CONF.scheduler.retry_max_attempt, - wait_fixed=cfg.CONF.scheduler.retry_wait_msec) + wait_fixed=cfg.CONF.scheduler.retry_wait_msec, + ) def process(self): execution_queue_item_db = self._get_next_execution() @@ -85,7 +84,7 @@ def process(self): self._pool.spawn(self._handle_execution, execution_queue_item_db) def cleanup(self): - LOG.debug('Starting scheduler garbage collection...') + LOG.debug("Starting scheduler garbage collection...") while not self._shutdown: eventlet.greenthread.sleep(cfg.CONF.scheduler.gc_interval) @@ -99,11 +98,11 @@ def _reset_handling_flag(self): False so other scheduler can pick it up. """ query = { - 'scheduled_start_timestamp__lte': date.append_milliseconds_to_time( + "scheduled_start_timestamp__lte": date.append_milliseconds_to_time( date.get_datetime_utc_now(), - -self._execution_scheduling_timeout_threshold_ms + -self._execution_scheduling_timeout_threshold_ms, ), - 'handling': True + "handling": True, } execution_queue_item_dbs = ActionExecutionSchedulingQueue.query(**query) or [] @@ -112,17 +111,19 @@ def _reset_handling_flag(self): execution_queue_item_db.handling = False try: - ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False) + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) LOG.info( '[%s] Removing lock for orphaned execution queue item "%s".', execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id) + str(execution_queue_item_db.id), ) except db_exc.StackStormDBObjectWriteConflictError: LOG.info( '[%s] Execution queue item "%s" updated during garbage collection.', execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id) + str(execution_queue_item_db.id), ) # TODO: Remove this function for fixing missing action_execution_id in v3.2. @@ -132,7 +133,9 @@ def _fix_missing_action_execution_id(self): """ Auto-populate the action_execution_id in ActionExecutionSchedulingQueue if empty. """ - for entry in ActionExecutionSchedulingQueue.query(action_execution_id__in=['', None]): + for entry in ActionExecutionSchedulingQueue.query( + action_execution_id__in=["", None] + ): execution_db = ActionExecution.get(liveaction__id=entry.liveaction_id) if not execution_db: @@ -152,23 +155,27 @@ def _cleanup_policy_delayed(self): moved back into requested status. """ - policy_delayed_liveaction_dbs = LiveAction.query(status='policy-delayed') or [] + policy_delayed_liveaction_dbs = LiveAction.query(status="policy-delayed") or [] for liveaction_db in policy_delayed_liveaction_dbs: - ex_que_qry = {'liveaction_id': str(liveaction_db.id), 'handling': False} - execution_queue_item_dbs = ActionExecutionSchedulingQueue.query(**ex_que_qry) or [] + ex_que_qry = {"liveaction_id": str(liveaction_db.id), "handling": False} + execution_queue_item_dbs = ( + ActionExecutionSchedulingQueue.query(**ex_que_qry) or [] + ) for execution_queue_item_db in execution_queue_item_dbs: # Mark the entry in the scheduling queue for handling. try: execution_queue_item_db.handling = True - execution_queue_item_db = ActionExecutionSchedulingQueue.add_or_update( - execution_queue_item_db, publish=False) + execution_queue_item_db = ( + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) + ) except db_exc.StackStormDBObjectWriteConflictError: - msg = ( - '[%s] Item "%s" is currently being processed by another scheduler.' % - (execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id)) + msg = '[%s] Item "%s" is currently being processed by another scheduler.' % ( + execution_queue_item_db.action_execution_id, + str(execution_queue_item_db.id), ) LOG.error(msg) raise Exception(msg) @@ -177,7 +184,7 @@ def _cleanup_policy_delayed(self): LOG.info( '[%s] Removing policy-delayed entry "%s" from the scheduling queue.', execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id) + str(execution_queue_item_db.id), ) ActionExecutionSchedulingQueue.delete(execution_queue_item_db) @@ -186,18 +193,20 @@ def _cleanup_policy_delayed(self): LOG.info( '[%s] Removing policy-delayed entry "%s" from the scheduling queue.', execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id) + str(execution_queue_item_db.id), ) liveaction_db = action_service.update_status( - liveaction_db, action_constants.LIVEACTION_STATUS_REQUESTED) + liveaction_db, action_constants.LIVEACTION_STATUS_REQUESTED + ) execution_service.update_execution(liveaction_db) @retrying.retry( retry_on_exception=service_utils.retry_on_exceptions, stop_max_attempt_number=cfg.CONF.scheduler.retry_max_attempt, - wait_fixed=cfg.CONF.scheduler.retry_wait_msec) + wait_fixed=cfg.CONF.scheduler.retry_wait_msec, + ) def _handle_garbage_collection(self): self._reset_handling_flag() @@ -212,13 +221,10 @@ def _get_next_execution(self): due to a policy. """ query = { - 'scheduled_start_timestamp__lte': date.get_datetime_utc_now(), - 'handling': False, - 'limit': 1, - 'order_by': [ - '+scheduled_start_timestamp', - '+original_start_timestamp' - ] + "scheduled_start_timestamp__lte": date.get_datetime_utc_now(), + "handling": False, + "limit": 1, + "order_by": ["+scheduled_start_timestamp", "+original_start_timestamp"], } execution_queue_item_db = ActionExecutionSchedulingQueue.query(**query).first() @@ -229,45 +235,52 @@ def _get_next_execution(self): # Mark that this scheduler process is currently handling (processing) that request # NOTE: This operation is atomic (CAS) msg = '[%s] Retrieved item "%s" from scheduling queue.' - LOG.info(msg, execution_queue_item_db.action_execution_id, execution_queue_item_db.id) + LOG.info( + msg, execution_queue_item_db.action_execution_id, execution_queue_item_db.id + ) execution_queue_item_db.handling = True try: - ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False) + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) return execution_queue_item_db except db_exc.StackStormDBObjectWriteConflictError: LOG.info( '[%s] Item "%s" is already handled by another scheduler.', execution_queue_item_db.action_execution_id, - str(execution_queue_item_db.id) + str(execution_queue_item_db.id), ) return None - @metrics.CounterWithTimer(key='scheduler.handle_execution') + @metrics.CounterWithTimer(key="scheduler.handle_execution") def _handle_execution(self, execution_queue_item_db): action_execution_id = str(execution_queue_item_db.action_execution_id) liveaction_id = str(execution_queue_item_db.liveaction_id) queue_item_id = str(execution_queue_item_db.id) - extra = {'queue_item_id': queue_item_id} + extra = {"queue_item_id": queue_item_id} LOG.info( '[%s] Scheduling Liveaction "%s".', - action_execution_id, liveaction_id, extra=extra + action_execution_id, + liveaction_id, + extra=extra, ) try: liveaction_db = action_utils.get_liveaction_by_id(liveaction_id) except StackStormDBObjectNotFoundError: msg = '[%s] Failed to find liveaction "%s" in the database (queue_item_id=%s).' - LOG.exception(msg, action_execution_id, liveaction_id, queue_item_id, extra=extra) + LOG.exception( + msg, action_execution_id, liveaction_id, queue_item_id, extra=extra + ) ActionExecutionSchedulingQueue.delete(execution_queue_item_db) raise # Identify if the action has policies that require locking. action_has_policies_require_lock = policy_service.has_policies( - liveaction_db, - policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK + liveaction_db, policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK ) # Acquire a distributed lock if the referenced action has specific policies attached. @@ -275,9 +288,9 @@ def _handle_execution(self, execution_queue_item_db): # Warn users that the coordination service is not configured. if not coordination_service.configured(): LOG.warn( - '[%s] Coordination backend is not configured. ' - 'Policy enforcement is best effort.', - action_execution_id + "[%s] Coordination backend is not configured. " + "Policy enforcement is best effort.", + action_execution_id, ) # Acquire a distributed lock before querying the database to make sure that only one @@ -304,11 +317,14 @@ def _regulate_and_schedule(self, liveaction_db, execution_queue_item_db): action_execution_id = str(execution_queue_item_db.action_execution_id) liveaction_id = str(execution_queue_item_db.liveaction_id) queue_item_id = str(execution_queue_item_db.id) - extra = {'queue_item_id': queue_item_id} + extra = {"queue_item_id": queue_item_id} LOG.info( '[%s] Liveaction "%s" has status "%s" before applying policies.', - action_execution_id, liveaction_id, liveaction_db.status, extra=extra + action_execution_id, + liveaction_id, + liveaction_db.status, + extra=extra, ) # Apply policies defined for the action. @@ -316,13 +332,18 @@ def _regulate_and_schedule(self, liveaction_db, execution_queue_item_db): LOG.info( '[%s] Liveaction "%s" has status "%s" after applying policies.', - action_execution_id, liveaction_id, liveaction_db.status, extra=extra + action_execution_id, + liveaction_id, + liveaction_db.status, + extra=extra, ) if liveaction_db.status == action_constants.LIVEACTION_STATUS_DELAYED: LOG.info( '[%s] Liveaction "%s" is delayed and scheduling queue is updated.', - action_execution_id, liveaction_id, extra=extra + action_execution_id, + liveaction_id, + extra=extra, ) liveaction_db = action_service.update_status( @@ -330,23 +351,30 @@ def _regulate_and_schedule(self, liveaction_db, execution_queue_item_db): ) execution_queue_item_db.handling = False - execution_queue_item_db.scheduled_start_timestamp = date.append_milliseconds_to_time( - date.get_datetime_utc_now(), - POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS + execution_queue_item_db.scheduled_start_timestamp = ( + date.append_milliseconds_to_time( + date.get_datetime_utc_now(), + POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS, + ) ) try: - ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False) + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) except db_exc.StackStormDBObjectWriteConflictError: LOG.warning( - '[%s] Database write conflict on updating scheduling queue.', - action_execution_id, extra=extra + "[%s] Database write conflict on updating scheduling queue.", + action_execution_id, + extra=extra, ) return - if (liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES or - liveaction_db.status in action_constants.LIVEACTION_CANCEL_STATES): + if ( + liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES + or liveaction_db.status in action_constants.LIVEACTION_CANCEL_STATES + ): ActionExecutionSchedulingQueue.delete(execution_queue_item_db) return @@ -356,33 +384,41 @@ def _delay(self, liveaction_db, execution_queue_item_db): action_execution_id = str(execution_queue_item_db.action_execution_id) liveaction_id = str(execution_queue_item_db.liveaction_id) queue_item_id = str(execution_queue_item_db.id) - extra = {'queue_item_id': queue_item_id} + extra = {"queue_item_id": queue_item_id} LOG.info( '[%s] Liveaction "%s" is delayed and scheduling queue is updated.', - action_execution_id, liveaction_id, extra=extra + action_execution_id, + liveaction_id, + extra=extra, ) liveaction_db = action_service.update_status( liveaction_db, action_constants.LIVEACTION_STATUS_DELAYED, publish=False ) - execution_queue_item_db.scheduled_start_timestamp = date.append_milliseconds_to_time( - date.get_datetime_utc_now(), - POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS + execution_queue_item_db.scheduled_start_timestamp = ( + date.append_milliseconds_to_time( + date.get_datetime_utc_now(), POLICY_DELAYED_EXECUTION_RESCHEDULE_TIME_MS + ) ) try: execution_queue_item_db.handling = False - ActionExecutionSchedulingQueue.add_or_update(execution_queue_item_db, publish=False) + ActionExecutionSchedulingQueue.add_or_update( + execution_queue_item_db, publish=False + ) except db_exc.StackStormDBObjectWriteConflictError: LOG.warning( - '[%s] Database write conflict on updating scheduling queue.', - action_execution_id, extra=extra + "[%s] Database write conflict on updating scheduling queue.", + action_execution_id, + extra=extra, ) def _schedule(self, liveaction_db, execution_queue_item_db): - if self._is_execution_queue_item_runnable(liveaction_db, execution_queue_item_db): + if self._is_execution_queue_item_runnable( + liveaction_db, execution_queue_item_db + ): self._update_to_scheduled(liveaction_db, execution_queue_item_db) @staticmethod @@ -396,7 +432,7 @@ def _is_execution_queue_item_runnable(liveaction_db, execution_queue_item_db): valid_status = [ action_constants.LIVEACTION_STATUS_REQUESTED, action_constants.LIVEACTION_STATUS_SCHEDULED, - action_constants.LIVEACTION_STATUS_DELAYED + action_constants.LIVEACTION_STATUS_DELAYED, ] if liveaction_db.status in valid_status: @@ -405,11 +441,14 @@ def _is_execution_queue_item_runnable(liveaction_db, execution_queue_item_db): action_execution_id = str(execution_queue_item_db.action_execution_id) liveaction_id = str(execution_queue_item_db.liveaction_id) queue_item_id = str(execution_queue_item_db.id) - extra = {'queue_item_id': queue_item_id} + extra = {"queue_item_id": queue_item_id} LOG.info( '[%s] Ignoring Liveaction "%s" with status "%s" after policies are applied.', - action_execution_id, liveaction_id, liveaction_db.status, extra=extra + action_execution_id, + liveaction_id, + liveaction_db.status, + extra=extra, ) ActionExecutionSchedulingQueue.delete(execution_queue_item_db) @@ -421,18 +460,26 @@ def _update_to_scheduled(liveaction_db, execution_queue_item_db): action_execution_id = str(execution_queue_item_db.action_execution_id) liveaction_id = str(execution_queue_item_db.liveaction_id) queue_item_id = str(execution_queue_item_db.id) - extra = {'queue_item_id': queue_item_id} + extra = {"queue_item_id": queue_item_id} # Update liveaction status to "scheduled". LOG.info( '[%s] Liveaction "%s" with status "%s" is updated to status "scheduled."', - action_execution_id, liveaction_id, liveaction_db.status, extra=extra + action_execution_id, + liveaction_id, + liveaction_db.status, + extra=extra, ) - if liveaction_db.status in [action_constants.LIVEACTION_STATUS_REQUESTED, - action_constants.LIVEACTION_STATUS_DELAYED]: + if liveaction_db.status in [ + action_constants.LIVEACTION_STATUS_REQUESTED, + action_constants.LIVEACTION_STATUS_DELAYED, + ]: liveaction_db = action_service.update_status( - liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED, publish=False) + liveaction_db, + action_constants.LIVEACTION_STATUS_SCHEDULED, + publish=False, + ) # Publish the "scheduled" status here manually. Otherwise, there could be a # race condition with the update of the action_execution_db if the execution diff --git a/st2actions/st2actions/worker.py b/st2actions/st2actions/worker.py index 3147ce1aae3..1741d607249 100644 --- a/st2actions/st2actions/worker.py +++ b/st2actions/st2actions/worker.py @@ -34,10 +34,7 @@ from st2common.transport import queues -__all__ = [ - 'ActionExecutionDispatcher', - 'get_worker' -] +__all__ = ["ActionExecutionDispatcher", "get_worker"] LOG = logging.getLogger(__name__) @@ -46,14 +43,14 @@ queues.ACTIONRUNNER_WORK_QUEUE, queues.ACTIONRUNNER_CANCEL_QUEUE, queues.ACTIONRUNNER_PAUSE_QUEUE, - queues.ACTIONRUNNER_RESUME_QUEUE + queues.ACTIONRUNNER_RESUME_QUEUE, ] ACTIONRUNNER_DISPATCHABLE_STATES = [ action_constants.LIVEACTION_STATUS_SCHEDULED, action_constants.LIVEACTION_STATUS_CANCELING, action_constants.LIVEACTION_STATUS_PAUSING, - action_constants.LIVEACTION_STATUS_RESUMING + action_constants.LIVEACTION_STATUS_RESUMING, ] @@ -83,41 +80,54 @@ def process(self, liveaction): """ if liveaction.status == action_constants.LIVEACTION_STATUS_CANCELED: - LOG.info('%s is not executing %s (id=%s) with "%s" status.', - self.__class__.__name__, type(liveaction), liveaction.id, liveaction.status) + LOG.info( + '%s is not executing %s (id=%s) with "%s" status.', + self.__class__.__name__, + type(liveaction), + liveaction.id, + liveaction.status, + ) if not liveaction.result: updated_liveaction = action_utils.update_liveaction_status( status=liveaction.status, - result={'message': 'Action execution canceled by user.'}, - liveaction_id=liveaction.id) + result={"message": "Action execution canceled by user."}, + liveaction_id=liveaction.id, + ) executions.update_execution(updated_liveaction) return if liveaction.status not in ACTIONRUNNER_DISPATCHABLE_STATES: - LOG.info('%s is not dispatching %s (id=%s) with "%s" status.', - self.__class__.__name__, type(liveaction), liveaction.id, liveaction.status) + LOG.info( + '%s is not dispatching %s (id=%s) with "%s" status.', + self.__class__.__name__, + type(liveaction), + liveaction.id, + liveaction.status, + ) return try: liveaction_db = action_utils.get_liveaction_by_id(liveaction.id) except StackStormDBObjectNotFoundError: - LOG.exception('Failed to find liveaction %s in the database.', liveaction.id) + LOG.exception( + "Failed to find liveaction %s in the database.", liveaction.id + ) raise if liveaction.status != liveaction_db.status: LOG.warning( - 'The status of liveaction %s has changed from %s to %s ' - 'while in the queue waiting for processing.', + "The status of liveaction %s has changed from %s to %s " + "while in the queue waiting for processing.", liveaction.id, liveaction.status, - liveaction_db.status + liveaction_db.status, ) dispatchers = { action_constants.LIVEACTION_STATUS_SCHEDULED: self._run_action, action_constants.LIVEACTION_STATUS_CANCELING: self._cancel_action, action_constants.LIVEACTION_STATUS_PAUSING: self._pause_action, - action_constants.LIVEACTION_STATUS_RESUMING: self._resume_action + action_constants.LIVEACTION_STATUS_RESUMING: self._resume_action, } return dispatchers[liveaction.status](liveaction) @@ -130,7 +140,7 @@ def shutdown(self): try: executions.abandon_execution_if_incomplete(liveaction_id=liveaction_id) except: - LOG.exception('Failed to abandon liveaction %s.', liveaction_id) + LOG.exception("Failed to abandon liveaction %s.", liveaction_id) def _run_action(self, liveaction_db): # stamp liveaction with process_info @@ -140,35 +150,49 @@ def _run_action(self, liveaction_db): liveaction_db = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_RUNNING, runner_info=runner_info, - liveaction_id=liveaction_db.id) + liveaction_id=liveaction_db.id, + ) self._running_liveactions.add(liveaction_db.id) action_execution_db = executions.update_execution(liveaction_db) # Launch action - extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db} - LOG.audit('Launching action execution.', extra=extra) + extra = { + "action_execution_db": action_execution_db, + "liveaction_db": liveaction_db, + } + LOG.audit("Launching action execution.", extra=extra) # the extra field will not be shown in non-audit logs so temporarily log at info. - LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', - action_execution_db.id, liveaction_db.id, liveaction_db.status) - - extra = {'liveaction_db': liveaction_db} + LOG.info( + 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', + action_execution_db.id, + liveaction_db.id, + liveaction_db.status, + ) + + extra = {"liveaction_db": liveaction_db} try: result = self.container.dispatch(liveaction_db) - LOG.debug('Runner dispatch produced result: %s', result) + LOG.debug("Runner dispatch produced result: %s", result) if not result and not liveaction_db.action_is_workflow: - raise ActionRunnerException('Failed to execute action.') + raise ActionRunnerException("Failed to execute action.") except: _, ex, tb = sys.exc_info() - extra['error'] = str(ex) - LOG.info('Action "%s" failed: %s' % (liveaction_db.action, str(ex)), extra=extra) + extra["error"] = str(ex) + LOG.info( + 'Action "%s" failed: %s' % (liveaction_db.action, str(ex)), extra=extra + ) liveaction_db = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_FAILED, liveaction_id=liveaction_db.id, - result={'error': str(ex), 'traceback': ''.join(traceback.format_tb(tb, 20))}) + result={ + "error": str(ex), + "traceback": "".join(traceback.format_tb(tb, 20)), + }, + ) executions.update_execution(liveaction_db) raise finally: @@ -182,66 +206,98 @@ def _run_action(self, liveaction_db): def _cancel_action(self, liveaction_db): action_execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id)) - extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db} - LOG.audit('Canceling action execution.', extra=extra) + extra = { + "action_execution_db": action_execution_db, + "liveaction_db": liveaction_db, + } + LOG.audit("Canceling action execution.", extra=extra) # the extra field will not be shown in non-audit logs so temporarily log at info. - LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', - action_execution_db.id, liveaction_db.id, liveaction_db.status) + LOG.info( + 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', + action_execution_db.id, + liveaction_db.id, + liveaction_db.status, + ) try: result = self.container.dispatch(liveaction_db) - LOG.debug('Runner dispatch produced result: %s', result) + LOG.debug("Runner dispatch produced result: %s", result) except: _, ex, tb = sys.exc_info() - extra['error'] = str(ex) - LOG.info('Failed to cancel action execution %s.' % (liveaction_db.id), extra=extra) + extra["error"] = str(ex) + LOG.info( + "Failed to cancel action execution %s." % (liveaction_db.id), + extra=extra, + ) raise return result def _pause_action(self, liveaction_db): action_execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id)) - extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db} - LOG.audit('Pausing action execution.', extra=extra) + extra = { + "action_execution_db": action_execution_db, + "liveaction_db": liveaction_db, + } + LOG.audit("Pausing action execution.", extra=extra) # the extra field will not be shown in non-audit logs so temporarily log at info. - LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', - action_execution_db.id, liveaction_db.id, liveaction_db.status) + LOG.info( + 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', + action_execution_db.id, + liveaction_db.id, + liveaction_db.status, + ) try: result = self.container.dispatch(liveaction_db) - LOG.debug('Runner dispatch produced result: %s', result) + LOG.debug("Runner dispatch produced result: %s", result) except: _, ex, tb = sys.exc_info() - extra['error'] = str(ex) - LOG.info('Failed to pause action execution %s.' % (liveaction_db.id), extra=extra) + extra["error"] = str(ex) + LOG.info( + "Failed to pause action execution %s." % (liveaction_db.id), extra=extra + ) raise return result def _resume_action(self, liveaction_db): action_execution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id)) - extra = {'action_execution_db': action_execution_db, 'liveaction_db': liveaction_db} - LOG.audit('Resuming action execution.', extra=extra) + extra = { + "action_execution_db": action_execution_db, + "liveaction_db": liveaction_db, + } + LOG.audit("Resuming action execution.", extra=extra) # the extra field will not be shown in non-audit logs so temporarily log at info. - LOG.info('Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', - action_execution_db.id, liveaction_db.id, liveaction_db.status) + LOG.info( + 'Dispatched {~}action_execution: %s / {~}live_action: %s with "%s" status.', + action_execution_db.id, + liveaction_db.id, + liveaction_db.status, + ) try: result = self.container.dispatch(liveaction_db) - LOG.debug('Runner dispatch produced result: %s', result) + LOG.debug("Runner dispatch produced result: %s", result) except: _, ex, tb = sys.exc_info() - extra['error'] = str(ex) - LOG.info('Failed to resume action execution %s.' % (liveaction_db.id), extra=extra) + extra["error"] = str(ex) + LOG.info( + "Failed to resume action execution %s." % (liveaction_db.id), + extra=extra, + ) raise # Cascade the resume upstream if action execution is child of an orquesta workflow. # The action service request_resume function is not used here because we do not want # other peer subworkflows to be resumed. - if 'orquesta' in action_execution_db.context and 'parent' in action_execution_db.context: + if ( + "orquesta" in action_execution_db.context + and "parent" in action_execution_db.context + ): wf_svc.handle_action_execution_resume(action_execution_db) return result diff --git a/st2actions/st2actions/workflows/config.py b/st2actions/st2actions/workflows/config.py index 0d2556f67aa..6854323ddd5 100644 --- a/st2actions/st2actions/workflows/config.py +++ b/st2actions/st2actions/workflows/config.py @@ -23,8 +23,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=sys_constants.VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=sys_constants.VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -43,13 +46,13 @@ def _register_common_opts(): def _register_service_opts(): wf_engine_opts = [ cfg.StrOpt( - 'logging', - default='/etc/st2/logging.workflowengine.conf', - help='Location of the logging configuration file.' + "logging", + default="/etc/st2/logging.workflowengine.conf", + help="Location of the logging configuration file.", ) ] - cfg.CONF.register_opts(wf_engine_opts, group='workflow_engine') + cfg.CONF.register_opts(wf_engine_opts, group="workflow_engine") register_opts() diff --git a/st2actions/st2actions/workflows/workflows.py b/st2actions/st2actions/workflows/workflows.py index 0351998025e..2151c7d4407 100644 --- a/st2actions/st2actions/workflows/workflows.py +++ b/st2actions/st2actions/workflows/workflows.py @@ -37,17 +37,16 @@ WORKFLOW_EXECUTION_QUEUES = [ queues.WORKFLOW_EXECUTION_WORK_QUEUE, queues.WORKFLOW_EXECUTION_RESUME_QUEUE, - queues.WORKFLOW_ACTION_EXECUTION_UPDATE_QUEUE + queues.WORKFLOW_ACTION_EXECUTION_UPDATE_QUEUE, ] class WorkflowExecutionHandler(consumers.VariableMessageHandler): - def __init__(self, connection, queues): super(WorkflowExecutionHandler, self).__init__(connection, queues) def handle_workflow_execution_with_instrumentation(wf_ex_db): - with metrics.CounterWithTimer(key='orquesta.workflow.executions'): + with metrics.CounterWithTimer(key="orquesta.workflow.executions"): return self.handle_workflow_execution(wf_ex_db=wf_ex_db) def handle_action_execution_with_instrumentation(ac_ex_db): @@ -55,27 +54,27 @@ def handle_action_execution_with_instrumentation(ac_ex_db): if not wf_svc.is_action_execution_under_workflow_context(ac_ex_db): return - with metrics.CounterWithTimer(key='orquesta.action.executions'): + with metrics.CounterWithTimer(key="orquesta.action.executions"): return self.handle_action_execution(ac_ex_db=ac_ex_db) self.message_types = { wf_db_models.WorkflowExecutionDB: handle_workflow_execution_with_instrumentation, - ex_db_models.ActionExecutionDB: handle_action_execution_with_instrumentation + ex_db_models.ActionExecutionDB: handle_action_execution_with_instrumentation, } def get_queue_consumer(self, connection, queues): # We want to use a special ActionsQueueConsumer which uses 2 dispatcher pools return consumers.VariableMessageQueueConsumer( - connection=connection, - queues=queues, - handler=self + connection=connection, queues=queues, handler=self ) def process(self, message): handler_function = self.message_types.get(type(message), None) if not handler_function: - msg = 'Handler function for message type "%s" is not defined.' % type(message) + msg = 'Handler function for message type "%s" is not defined.' % type( + message + ) raise ValueError(msg) try: @@ -90,43 +89,45 @@ def process(self, message): def fail_workflow_execution(self, message, exception): # Prepare attributes based on message type. if isinstance(message, wf_db_models.WorkflowExecutionDB): - msg_type = 'workflow' + msg_type = "workflow" wf_ex_db = message wf_ex_id = str(wf_ex_db.id) task = None else: - msg_type = 'task' + msg_type = "task" ac_ex_db = message - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id) - task = {'id': task_ex_db.task_id, 'route': task_ex_db.task_route} + task = {"id": task_ex_db.task_id, "route": task_ex_db.task_route} # Log the error. - msg = 'Unknown error while processing %s execution. %s: %s' + msg = "Unknown error while processing %s execution. %s: %s" wf_svc.update_progress( wf_ex_db, msg % (msg_type, exception.__class__.__name__, str(exception)), - severity='error' + severity="error", ) # Fail the task execution so it's marked correctly in the # conductor state to allow for task rerun if needed. if isinstance(message, ex_db_models.ActionExecutionDB): msg = 'Unknown error while processing %s execution. Failing task execution "%s".' - wf_svc.update_progress(wf_ex_db, msg % (msg_type, task_ex_id), severity='error') + wf_svc.update_progress( + wf_ex_db, msg % (msg_type, task_ex_id), severity="error" + ) wf_svc.update_task_execution(task_ex_id, ac_const.LIVEACTION_STATUS_FAILED) wf_svc.update_task_state(task_ex_id, ac_const.LIVEACTION_STATUS_FAILED) # Fail the workflow execution. msg = 'Unknown error while processing %s execution. Failing workflow execution "%s".' - wf_svc.update_progress(wf_ex_db, msg % (msg_type, wf_ex_id), severity='error') + wf_svc.update_progress(wf_ex_db, msg % (msg_type, wf_ex_id), severity="error") wf_svc.fail_workflow_execution(wf_ex_id, exception, task=task) def handle_workflow_execution(self, wf_ex_db): # Request the next set of tasks to execute. - wf_svc.update_progress(wf_ex_db, 'Processing request for workflow execution.') + wf_svc.update_progress(wf_ex_db, "Processing request for workflow execution.") wf_svc.request_next_tasks(wf_ex_db) def handle_action_execution(self, ac_ex_db): @@ -135,16 +136,17 @@ def handle_action_execution(self, ac_ex_db): return # Get related record identifiers. - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] # Get execution records for logging purposes. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id) - msg = ( - 'Action execution "%s" for task "%s" is updated and in "%s" state.' % - (str(ac_ex_db.id), task_ex_db.task_id, ac_ex_db.status) + msg = 'Action execution "%s" for task "%s" is updated and in "%s" state.' % ( + str(ac_ex_db.id), + task_ex_db.task_id, + ac_ex_db.status, ) wf_svc.update_progress(wf_ex_db, msg) @@ -152,9 +154,13 @@ def handle_action_execution(self, ac_ex_db): if task_ex_db.status in statuses.COMPLETED_STATUSES: msg = ( 'Action execution "%s" for task "%s", route "%s", is not processed ' - 'because task execution "%s" is already in completed state "%s".' % ( - str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route), - str(task_ex_db.id), task_ex_db.status + 'because task execution "%s" is already in completed state "%s".' + % ( + str(ac_ex_db.id), + task_ex_db.task_id, + str(task_ex_db.task_route), + str(task_ex_db.id), + task_ex_db.status, ) ) wf_svc.update_progress(wf_ex_db, msg) @@ -175,7 +181,7 @@ def handle_action_execution(self, ac_ex_db): return # Apply post run policies. - lv_ac_db = lv_db_access.LiveAction.get_by_id(ac_ex_db.liveaction['id']) + lv_ac_db = lv_db_access.LiveAction.get_by_id(ac_ex_db.liveaction["id"]) pc_svc.apply_post_run_policies(lv_ac_db) # Process completion of the action execution. diff --git a/st2actions/tests/unit/policies/test_base.py b/st2actions/tests/unit/policies/test_base.py index fcf3aef40d4..2e5003d89c6 100644 --- a/st2actions/tests/unit/policies/test_base.py +++ b/st2actions/tests/unit/policies/test_base.py @@ -17,6 +17,7 @@ import mock from st2tests import config as test_config + test_config.parse_args() import st2common @@ -32,28 +33,21 @@ from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'SchedulerPoliciesTestCase', - 'NotifierPoliciesTestCase' -] +__all__ = ["SchedulerPoliciesTestCase", "NotifierPoliciesTestCase"] -PACK = 'generic' +PACK = "generic" TEST_FIXTURES_1 = { - 'actions': [ - 'action1.yaml' + "actions": ["action1.yaml"], + "policies": [ + "policy_4.yaml", ], - 'policies': [ - 'policy_4.yaml', - ] } TEST_FIXTURES_2 = { - 'actions': [ - 'action1.yaml' + "actions": ["action1.yaml"], + "policies": [ + "policy_1.yaml", ], - 'policies': [ - 'policy_1.yaml', - ] } @@ -73,15 +67,14 @@ def setUp(self): register_policy_types(st2common) loader = FixturesLoader() - models = loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES_2) + models = loader.save_fixtures_to_db( + fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES_2 + ) # Policy with "post_run" application - self.policy_db = models['policies']['policy_1.yaml'] + self.policy_db = models["policies"]["policy_1.yaml"] - @mock.patch.object( - policies, 'get_driver', - mock.MagicMock(return_value=None)) + @mock.patch.object(policies, "get_driver", mock.MagicMock(return_value=None)) def test_disabled_policy_not_applied_on_pre_run(self): ########## # First test a scenario where policy is enabled @@ -91,7 +84,9 @@ def test_disabled_policy_not_applied_on_pre_run(self): # Post run hasn't been called yet, call count should be 0 self.assertEqual(policies.get_driver.call_count, 0) - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) policy_service.apply_pre_run_policies(live_action_db) @@ -108,7 +103,9 @@ def test_disabled_policy_not_applied_on_pre_run(self): self.assertEqual(policies.get_driver.call_count, 0) - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) policy_service.apply_pre_run_policies(live_action_db) @@ -133,15 +130,14 @@ def setUp(self): register_policy_types(st2common) loader = FixturesLoader() - models = loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES_1) + models = loader.save_fixtures_to_db( + fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES_1 + ) # Policy with "post_run" application - self.policy_db = models['policies']['policy_4.yaml'] + self.policy_db = models["policies"]["policy_4.yaml"] - @mock.patch.object( - policies, 'get_driver', - mock.MagicMock(return_value=None)) + @mock.patch.object(policies, "get_driver", mock.MagicMock(return_value=None)) def test_disabled_policy_not_applied_on_post_run(self): ########## # First test a scenario where policy is enabled @@ -151,7 +147,9 @@ def test_disabled_policy_not_applied_on_post_run(self): # Post run hasn't been called yet, call count should be 0 self.assertEqual(policies.get_driver.call_count, 0) - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) policy_service.apply_post_run_policies(live_action_db) @@ -168,7 +166,9 @@ def test_disabled_policy_not_applied_on_post_run(self): self.assertEqual(policies.get_driver.call_count, 0) - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) policy_service.apply_post_run_policies(live_action_db) diff --git a/st2actions/tests/unit/policies/test_concurrency.py b/st2actions/tests/unit/policies/test_concurrency.py index 670c38d8397..f22a0303cde 100644 --- a/st2actions/tests/unit/policies/test_concurrency.py +++ b/st2actions/tests/unit/policies/test_concurrency.py @@ -42,40 +42,40 @@ from st2tests.mocks.runners import runner -__all__ = [ - 'ConcurrencyPolicyTestCase' -] +__all__ = ["ConcurrencyPolicyTestCase"] -PACK = 'generic' +PACK = "generic" TEST_FIXTURES = { - 'actions': [ - 'action1.yaml', - 'action2.yaml' - ], - 'policies': [ - 'policy_1.yaml', - 'policy_5.yaml' - ] + "actions": ["action1.yaml", "action2.yaml"], + "policies": ["policy_1.yaml", "policy_5.yaml"], } -NON_EMPTY_RESULT = 'non-empty' -MOCK_RUN_RETURN_VALUE = (action_constants.LIVEACTION_STATUS_RUNNING, NON_EMPTY_RESULT, None) +NON_EMPTY_RESULT = "non-empty" +MOCK_RUN_RETURN_VALUE = ( + action_constants.LIVEACTION_STATUS_RUNNING, + NON_EMPTY_RESULT, + None, +) SCHEDULED_STATES = [ action_constants.LIVEACTION_STATUS_SCHEDULED, action_constants.LIVEACTION_STATUS_RUNNING, - action_constants.LIVEACTION_STATUS_SUCCEEDED + action_constants.LIVEACTION_STATUS_SUCCEEDED, ] -@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch.object( - CUDPublisher, 'publish_update', - mock.MagicMock(side_effect=MockExecutionPublisher.publish_update)) +@mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch( + "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, + "publish_update", + mock.MagicMock(side_effect=MockExecutionPublisher.publish_update), +) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) class ConcurrencyPolicyTestCase(EventletTestCase, ExecutionDbTestCase): @classmethod def setUpClass(cls): @@ -93,8 +93,7 @@ def setUpClass(cls): register_policy_types(st2common) loader = FixturesLoader() - loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) @classmethod def tearDownClass(cls): @@ -106,10 +105,15 @@ def tearDownClass(cls): # NOTE: This monkey patch needs to happen again here because during tests for some reason this # method gets unpatched (test doing reload() or similar) - @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) + @mock.patch( + "st2actions.container.base.get_runner", + mock.Mock(return_value=runner.get_runner()), + ) def tearDown(self): for liveaction in LiveAction.get_all(): - action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + action_service.update_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) @staticmethod def _process_scheduling_queue(): @@ -117,64 +121,82 @@ def _process_scheduling_queue(): scheduling_queue.get_handler()._handle_execution(queued_req) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_over_threshold_delay_executions(self): # Ensure the concurrency policy is accurate. - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency') - self.assertGreater(policy_db.parameters['threshold'], 0) + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency") + self.assertGreater(policy_db.parameters["threshold"], 0) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - parameters = {'actionstr': 'foo-' + str(i)} - liveaction = LiveActionDB(action='wolfpack.action-1', parameters=parameters) + for i in range(0, policy_db.parameters["threshold"]): + parameters = {"actionstr": "foo-" + str(i)} + liveaction = LiveActionDB(action="wolfpack.action-1", parameters=parameters) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # Assert the correct number of published states and action executions. This is to avoid # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be delayed since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo-last'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo-last"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Since states are being processed async, wait for the liveaction to go into delayed state. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) expected_num_exec += 0 # This request will not be scheduled for execution. expected_num_pubs += 0 # The delayed status change should not be published. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Mark one of the scheduled/running execution as completed. action_service.update_status( - scheduled[0], - action_constants.LIVEACTION_STATUS_SUCCEEDED, - publish=True + scheduled[0], action_constants.LIVEACTION_STATUS_SUCCEEDED, publish=True ) expected_num_pubs += 1 # Tally succeeded state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -185,52 +207,74 @@ def test_over_threshold_delay_executions(self): # Since states are being processed async, wait for the liveaction to be scheduled. liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES) - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Check the status changes. execution = ActionExecution.get(liveaction__id=str(liveaction.id)) - expected_status_changes = ['requested', 'delayed', 'requested', 'scheduled', 'running'] - actual_status_changes = [entry['status'] for entry in execution.log] + expected_status_changes = [ + "requested", + "delayed", + "requested", + "scheduled", + "running", + ] + actual_status_changes = [entry["status"] for entry in execution.log] self.assertListEqual(actual_status_changes, expected_status_changes) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_over_threshold_cancel_executions(self): - policy_db = Policy.get_by_ref('wolfpack.action-2.concurrency.cancel') - self.assertEqual(policy_db.parameters['action'], 'cancel') - self.assertGreater(policy_db.parameters['threshold'], 0) + policy_db = Policy.get_by_ref("wolfpack.action-2.concurrency.cancel") + self.assertEqual(policy_db.parameters["action"], "cancel") + self.assertGreater(policy_db.parameters["threshold"], 0) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - parameters = {'actionstr': 'foo-' + str(i)} - liveaction = LiveActionDB(action='wolfpack.action-2', parameters=parameters) + for i in range(0, policy_db.parameters["threshold"]): + parameters = {"actionstr": "foo-" + str(i)} + liveaction = LiveActionDB(action="wolfpack.action-2", parameters=parameters) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be canceled since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-2', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-2", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -240,67 +284,91 @@ def test_over_threshold_cancel_executions(self): LiveActionPublisher.publish_state.assert_has_calls(calls) expected_num_pubs += 2 # Tally canceling and canceled state changes. expected_num_exec += 0 # This request will not be scheduled for execution. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Assert the action is canceled. liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELED) - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_on_cancellation(self): - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency') - self.assertGreater(policy_db.parameters['threshold'], 0) + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency") + self.assertGreater(policy_db.parameters["threshold"], 0) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - parameters = {'actionstr': 'foo-' + str(i)} - liveaction = LiveActionDB(action='wolfpack.action-1', parameters=parameters) + for i in range(0, policy_db.parameters["threshold"]): + parameters = {"actionstr": "foo-" + str(i)} + liveaction = LiveActionDB(action="wolfpack.action-1", parameters=parameters) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be delayed since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Since states are being processed async, wait for the liveaction to go into delayed state. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) expected_num_exec += 0 # This request will not be scheduled for execution. expected_num_pubs += 0 # The delayed status change should not be published. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Cancel execution. - action_service.request_cancellation(scheduled[0], 'stanley') + action_service.request_cancellation(scheduled[0], "stanley") expected_num_pubs += 2 # Tally the canceling and canceled states. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -312,5 +380,7 @@ def test_on_cancellation(self): # Execution is expected to be rescheduled. liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertIn(liveaction.status, SCHEDULED_STATES) - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) diff --git a/st2actions/tests/unit/policies/test_concurrency_by_attr.py b/st2actions/tests/unit/policies/test_concurrency_by_attr.py index b576e3a669d..98cfc3a4dc3 100644 --- a/st2actions/tests/unit/policies/test_concurrency_by_attr.py +++ b/st2actions/tests/unit/policies/test_concurrency_by_attr.py @@ -39,42 +39,41 @@ from st2tests.mocks.runners import runner from six.moves import range -__all__ = [ - 'ConcurrencyByAttributePolicyTestCase' -] +__all__ = ["ConcurrencyByAttributePolicyTestCase"] -PACK = 'generic' +PACK = "generic" TEST_FIXTURES = { - 'actions': [ - 'action1.yaml', - 'action2.yaml' - ], - 'policies': [ - 'policy_3.yaml', - 'policy_7.yaml' - ] + "actions": ["action1.yaml", "action2.yaml"], + "policies": ["policy_3.yaml", "policy_7.yaml"], } -NON_EMPTY_RESULT = 'non-empty' -MOCK_RUN_RETURN_VALUE = (action_constants.LIVEACTION_STATUS_RUNNING, NON_EMPTY_RESULT, None) +NON_EMPTY_RESULT = "non-empty" +MOCK_RUN_RETURN_VALUE = ( + action_constants.LIVEACTION_STATUS_RUNNING, + NON_EMPTY_RESULT, + None, +) SCHEDULED_STATES = [ action_constants.LIVEACTION_STATUS_SCHEDULED, action_constants.LIVEACTION_STATUS_RUNNING, - action_constants.LIVEACTION_STATUS_SUCCEEDED + action_constants.LIVEACTION_STATUS_SUCCEEDED, ] -@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) +@mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch( + "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) @mock.patch.object( - CUDPublisher, 'publish_update', - mock.MagicMock(side_effect=MockExecutionPublisher.publish_update)) -@mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, + "publish_update", + mock.MagicMock(side_effect=MockExecutionPublisher.publish_update), +) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) class ConcurrencyByAttributePolicyTestCase(EventletTestCase, ExecutionDbTestCase): - @classmethod def setUpClass(cls): EventletTestCase.setUpClass() @@ -91,8 +90,7 @@ def setUpClass(cls): register_policy_types(st2common) loader = FixturesLoader() - loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) @classmethod def tearDownClass(cls): @@ -104,10 +102,15 @@ def tearDownClass(cls): # NOTE: This monkey patch needs to happen again here because during tests for some reason this # method gets unpatched (test doing reload() or similar) - @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) + @mock.patch( + "st2actions.container.base.get_runner", + mock.Mock(return_value=runner.get_runner()), + ) def tearDown(self): for liveaction in LiveAction.get_all(): - action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + action_service.update_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) @staticmethod def _process_scheduling_queue(): @@ -115,58 +118,80 @@ def _process_scheduling_queue(): scheduling_queue.get_handler()._handle_execution(queued_req) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_over_threshold_delay_executions(self): - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency.attr') - self.assertGreater(policy_db.parameters['threshold'], 0) - self.assertIn('actionstr', policy_db.parameters['attributes']) + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency.attr") + self.assertGreater(policy_db.parameters["threshold"], 0) + self.assertIn("actionstr", policy_db.parameters["attributes"]) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + for i in range(0, policy_db.parameters["threshold"]): + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # Assert the correct number of published states and action executions. This is to avoid # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be delayed since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Since states are being processed asynchronously, wait for the # liveaction to go into delayed state. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) expected_num_exec += 0 # This request will not be scheduled for execution. expected_num_pubs += 0 # The delayed status change should not be published. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be scheduled since concurrency threshold is not reached. # The execution with actionstr "fu" is over the threshold but actionstr "bar" is not. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'bar'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "bar"} + ) liveaction, _ = action_service.request(liveaction) # Run the scheduler to schedule action executions. @@ -177,18 +202,20 @@ def test_over_threshold_delay_executions(self): liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES) expected_num_exec += 1 # This request is expected to be executed. expected_num_pubs += 3 # Tally requested, scheduled, and running state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Mark one of the execution as completed. action_service.update_status( - scheduled[0], - action_constants.LIVEACTION_STATUS_SUCCEEDED, - publish=True + scheduled[0], action_constants.LIVEACTION_STATUS_SUCCEEDED, publish=True ) expected_num_pubs += 1 # Tally succeeded state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -197,47 +224,65 @@ def test_over_threshold_delay_executions(self): liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES) expected_num_exec += 1 # The delayed request is expected to be executed. expected_num_pubs += 2 # Tally scheduled and running state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_over_threshold_cancel_executions(self): - policy_db = Policy.get_by_ref('wolfpack.action-2.concurrency.attr.cancel') - self.assertEqual(policy_db.parameters['action'], 'cancel') - self.assertGreater(policy_db.parameters['threshold'], 0) - self.assertIn('actionstr', policy_db.parameters['attributes']) + policy_db = Policy.get_by_ref("wolfpack.action-2.concurrency.attr.cancel") + self.assertEqual(policy_db.parameters["action"], "cancel") + self.assertGreater(policy_db.parameters["threshold"], 0) + self.assertIn("actionstr", policy_db.parameters["attributes"]) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - liveaction = LiveActionDB(action='wolfpack.action-2', parameters={'actionstr': 'foo'}) + for i in range(0, policy_db.parameters["threshold"]): + liveaction = LiveActionDB( + action="wolfpack.action-2", parameters={"actionstr": "foo"} + ) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # Assert the correct number of published states and action executions. This is to avoid # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be delayed since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-2', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-2", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -247,7 +292,9 @@ def test_over_threshold_cancel_executions(self): LiveActionPublisher.publish_state.assert_has_calls(calls) expected_num_pubs += 2 # Tally canceling and canceled state changes. expected_num_exec += 0 # This request will not be scheduled for execution. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Assert the action is canceled. @@ -255,58 +302,80 @@ def test_over_threshold_cancel_executions(self): self.assertEqual(canceled.status, action_constants.LIVEACTION_STATUS_CANCELED) @mock.patch.object( - runner.MockActionRunner, 'run', - mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE)) + runner.MockActionRunner, + "run", + mock.MagicMock(return_value=MOCK_RUN_RETURN_VALUE), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), + ) def test_on_cancellation(self): - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency.attr') - self.assertGreater(policy_db.parameters['threshold'], 0) - self.assertIn('actionstr', policy_db.parameters['attributes']) + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency.attr") + self.assertGreater(policy_db.parameters["threshold"], 0) + self.assertIn("actionstr", policy_db.parameters["attributes"]) # Launch action executions until the expected threshold is reached. - for i in range(0, policy_db.parameters['threshold']): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + for i in range(0, policy_db.parameters["threshold"]): + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) action_service.request(liveaction) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Check the number of action executions in scheduled state. - scheduled = [item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES] - self.assertEqual(len(scheduled), policy_db.parameters['threshold']) + scheduled = [ + item for item in LiveAction.get_all() if item.status in SCHEDULED_STATES + ] + self.assertEqual(len(scheduled), policy_db.parameters["threshold"]) # duplicate executions caused by accidental publishing of state in the concurrency policies. # num_state_changes = len(scheduled) * len(['requested', 'scheduled', 'running']) expected_num_exec = len(scheduled) expected_num_pubs = expected_num_exec * 3 - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be delayed since concurrency threshold is reached. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) expected_num_pubs += 1 # Tally requested state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() # Since states are being processed asynchronously, wait for the # liveaction to go into delayed state. - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) delayed = liveaction expected_num_exec += 0 # This request will not be scheduled for execution. expected_num_pubs += 0 # The delayed status change should not be published. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Execution is expected to be scheduled since concurrency threshold is not reached. # The execution with actionstr "fu" is over the threshold but actionstr "bar" is not. - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'bar'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "bar"} + ) liveaction, _ = action_service.request(liveaction) # Run the scheduler to schedule action executions. @@ -317,13 +386,17 @@ def test_on_cancellation(self): liveaction = self._wait_on_statuses(liveaction, SCHEDULED_STATES) expected_num_exec += 1 # This request is expected to be executed. expected_num_pubs += 3 # Tally requested, scheduled, and running states. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Cancel execution. - action_service.request_cancellation(scheduled[0], 'stanley') + action_service.request_cancellation(scheduled[0], "stanley") expected_num_pubs += 2 # Tally the canceling and canceled states. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) # Run the scheduler to schedule action executions. self._process_scheduling_queue() @@ -331,7 +404,9 @@ def test_on_cancellation(self): # Once capacity freed up, the delayed execution is published as requested again. expected_num_exec += 1 # The delayed request is expected to be executed. expected_num_pubs += 2 # Tally scheduled and running state. - self.assertEqual(expected_num_pubs, LiveActionPublisher.publish_state.call_count) + self.assertEqual( + expected_num_pubs, LiveActionPublisher.publish_state.call_count + ) self.assertEqual(expected_num_exec, runner.MockActionRunner.run.call_count) # Since states are being processed asynchronously, wait for the diff --git a/st2actions/tests/unit/policies/test_retry_policy.py b/st2actions/tests/unit/policies/test_retry_policy.py index 6b6f0f0cc48..21371c6a028 100644 --- a/st2actions/tests/unit/policies/test_retry_policy.py +++ b/st2actions/tests/unit/policies/test_retry_policy.py @@ -35,19 +35,10 @@ from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'RetryPolicyTestCase' -] +__all__ = ["RetryPolicyTestCase"] -PACK = 'generic' -TEST_FIXTURES = { - 'actions': [ - 'action1.yaml' - ], - 'policies': [ - 'policy_4.yaml' - ] -} +PACK = "generic" +TEST_FIXTURES = {"actions": ["action1.yaml"], "policies": ["policy_4.yaml"]} class RetryPolicyTestCase(CleanDbTestCase): @@ -66,18 +57,21 @@ def setUp(self): register_policy_types(st2actions) loader = FixturesLoader() - models = loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + models = loader.save_fixtures_to_db( + fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES + ) # Instantiate policy applicator we will use in the tests - policy_db = models['policies']['policy_4.yaml'] - retry_on = policy_db.parameters['retry_on'] - max_retry_count = policy_db.parameters['max_retry_count'] - self.policy = ExecutionRetryPolicyApplicator(policy_ref='test_policy', - policy_type='action.retry', - retry_on=retry_on, - max_retry_count=max_retry_count, - delay=0) + policy_db = models["policies"]["policy_4.yaml"] + retry_on = policy_db.parameters["retry_on"] + max_retry_count = policy_db.parameters["max_retry_count"] + self.policy = ExecutionRetryPolicyApplicator( + policy_ref="test_policy", + policy_type="action.retry", + retry_on=retry_on, + max_retry_count=max_retry_count, + delay=0, + ) def test_retry_on_timeout_no_retry_since_no_timeout_reached(self): # Verify initial state @@ -85,7 +79,9 @@ def test_retry_on_timeout_no_retry_since_no_timeout_reached(self): self.assertSequenceEqual(ActionExecution.get_all(), []) # Start a mock action which succeeds - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) live_action_db.status = LIVEACTION_STATUS_SUCCEEDED @@ -110,7 +106,9 @@ def test_retry_on_timeout_first_retry_is_successful(self): self.assertSequenceEqual(ActionExecution.get_all(), []) # Start a mock action which times out - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) live_action_db.status = LIVEACTION_STATUS_TIMED_OUT @@ -130,14 +128,16 @@ def test_retry_on_timeout_first_retry_is_successful(self): self.assertEqual(action_execution_dbs[1].status, LIVEACTION_STATUS_REQUESTED) # Verify retried execution contains policy related context - original_liveaction_id = action_execution_dbs[0].liveaction['id'] + original_liveaction_id = action_execution_dbs[0].liveaction["id"] context = action_execution_dbs[1].context - self.assertIn('policies', context) - self.assertEqual(context['policies']['retry']['retry_count'], 1) - self.assertEqual(context['policies']['retry']['applied_policy'], 'test_policy') - self.assertEqual(context['policies']['retry']['retried_liveaction_id'], - original_liveaction_id) + self.assertIn("policies", context) + self.assertEqual(context["policies"]["retry"]["retry_count"], 1) + self.assertEqual(context["policies"]["retry"]["applied_policy"], "test_policy") + self.assertEqual( + context["policies"]["retry"]["retried_liveaction_id"], + original_liveaction_id, + ) # Simulate success of second action so no it shouldn't be retried anymore live_action_db = live_action_dbs[1] @@ -161,7 +161,9 @@ def test_retry_on_timeout_policy_is_retried_twice(self): self.assertSequenceEqual(ActionExecution.get_all(), []) # Start a mock action which times out - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) live_action_db.status = LIVEACTION_STATUS_TIMED_OUT @@ -181,14 +183,16 @@ def test_retry_on_timeout_policy_is_retried_twice(self): self.assertEqual(action_execution_dbs[1].status, LIVEACTION_STATUS_REQUESTED) # Verify retried execution contains policy related context - original_liveaction_id = action_execution_dbs[0].liveaction['id'] + original_liveaction_id = action_execution_dbs[0].liveaction["id"] context = action_execution_dbs[1].context - self.assertIn('policies', context) - self.assertEqual(context['policies']['retry']['retry_count'], 1) - self.assertEqual(context['policies']['retry']['applied_policy'], 'test_policy') - self.assertEqual(context['policies']['retry']['retried_liveaction_id'], - original_liveaction_id) + self.assertIn("policies", context) + self.assertEqual(context["policies"]["retry"]["retry_count"], 1) + self.assertEqual(context["policies"]["retry"]["applied_policy"], "test_policy") + self.assertEqual( + context["policies"]["retry"]["retried_liveaction_id"], + original_liveaction_id, + ) # Simulate timeout of second action which should cause another retry live_action_db = live_action_dbs[1] @@ -212,14 +216,16 @@ def test_retry_on_timeout_policy_is_retried_twice(self): self.assertEqual(action_execution_dbs[2].status, LIVEACTION_STATUS_REQUESTED) # Verify retried execution contains policy related context - original_liveaction_id = action_execution_dbs[1].liveaction['id'] + original_liveaction_id = action_execution_dbs[1].liveaction["id"] context = action_execution_dbs[2].context - self.assertIn('policies', context) - self.assertEqual(context['policies']['retry']['retry_count'], 2) - self.assertEqual(context['policies']['retry']['applied_policy'], 'test_policy') - self.assertEqual(context['policies']['retry']['retried_liveaction_id'], - original_liveaction_id) + self.assertIn("policies", context) + self.assertEqual(context["policies"]["retry"]["retry_count"], 2) + self.assertEqual(context["policies"]["retry"]["applied_policy"], "test_policy") + self.assertEqual( + context["policies"]["retry"]["retried_liveaction_id"], + original_liveaction_id, + ) def test_retry_on_timeout_max_retries_reached(self): # Verify initial state @@ -227,12 +233,14 @@ def test_retry_on_timeout_max_retries_reached(self): self.assertSequenceEqual(ActionExecution.get_all(), []) # Start a mock action which times out - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) live_action_db.status = LIVEACTION_STATUS_TIMED_OUT - live_action_db.context['policies'] = {} - live_action_db.context['policies']['retry'] = {'retry_count': 2} + live_action_db.context["policies"] = {} + live_action_db.context["policies"]["retry"] = {"retry_count": 2} execution_db.status = LIVEACTION_STATUS_TIMED_OUT LiveAction.add_or_update(live_action_db) ActionExecution.add_or_update(execution_db) @@ -248,8 +256,10 @@ def test_retry_on_timeout_max_retries_reached(self): self.assertEqual(action_execution_dbs[0].status, LIVEACTION_STATUS_TIMED_OUT) @mock.patch.object( - trace_service, 'get_trace_db_by_live_action', - mock.MagicMock(return_value=(None, None))) + trace_service, + "get_trace_db_by_live_action", + mock.MagicMock(return_value=(None, None)), + ) def test_no_retry_on_workflow_task(self): # Verify initial state self.assertSequenceEqual(LiveAction.get_all(), []) @@ -257,9 +267,9 @@ def test_no_retry_on_workflow_task(self): # Start a mock action which times out live_action_db = LiveActionDB( - action='wolfpack.action-1', - parameters={'actionstr': 'foo'}, - context={'parent': {'execution_id': 'abcde'}} + action="wolfpack.action-1", + parameters={"actionstr": "foo"}, + context={"parent": {"execution_id": "abcde"}}, ) live_action_db, execution_db = action_service.request(live_action_db) @@ -268,7 +278,7 @@ def test_no_retry_on_workflow_task(self): # Expire the workflow instance. live_action_db.status = LIVEACTION_STATUS_TIMED_OUT - live_action_db.context['policies'] = {} + live_action_db.context["policies"] = {} execution_db.status = LIVEACTION_STATUS_TIMED_OUT LiveAction.add_or_update(live_action_db) ActionExecution.add_or_update(execution_db) @@ -297,10 +307,12 @@ def test_no_retry_on_non_applicable_statuses(self): LIVEACTION_STATUS_CANCELED, ] - action_ref = 'wolfpack.action-1' + action_ref = "wolfpack.action-1" for status in non_retry_statuses: - liveaction = LiveActionDB(action=action_ref, parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action=action_ref, parameters={"actionstr": "foo"} + ) live_action_db, execution_db = action_service.request(liveaction) live_action_db.status = status diff --git a/st2actions/tests/unit/test_action_runner_worker.py b/st2actions/tests/unit/test_action_runner_worker.py index 4f2494a4316..1d0c7bbbd07 100644 --- a/st2actions/tests/unit/test_action_runner_worker.py +++ b/st2actions/tests/unit/test_action_runner_worker.py @@ -21,11 +21,10 @@ from st2common.models.db.liveaction import LiveActionDB from st2tests import config as test_config + test_config.parse_args() -__all__ = [ - 'ActionsQueueConsumerTestCase' -] +__all__ = ["ActionsQueueConsumerTestCase"] class ActionsQueueConsumerTestCase(TestCase): @@ -38,7 +37,9 @@ def test_process_right_dispatcher_is_used(self): consumer._workflows_dispatcher = Mock() consumer._actions_dispatcher = Mock() - body = LiveActionDB(status='scheduled', action='core.local', action_is_workflow=False) + body = LiveActionDB( + status="scheduled", action="core.local", action_is_workflow=False + ) message = Mock() consumer.process(body=body, message=message) @@ -49,7 +50,9 @@ def test_process_right_dispatcher_is_used(self): consumer._workflows_dispatcher = Mock() consumer._actions_dispatcher = Mock() - body = LiveActionDB(status='scheduled', action='core.local', action_is_workflow=True) + body = LiveActionDB( + status="scheduled", action="core.local", action_is_workflow=True + ) message = Mock() consumer.process(body=body, message=message) diff --git a/st2actions/tests/unit/test_actions_registrar.py b/st2actions/tests/unit/test_actions_registrar.py index c4d2771268d..cc9da332995 100644 --- a/st2actions/tests/unit/test_actions_registrar.py +++ b/st2actions/tests/unit/test_actions_registrar.py @@ -31,18 +31,24 @@ import st2tests.fixturesloader as fixtures_loader from st2tests.fixturesloader import get_fixtures_base_path -MOCK_RUNNER_TYPE_DB = RunnerTypeDB(name='run-local', runner_module='st2.runners.local') +MOCK_RUNNER_TYPE_DB = RunnerTypeDB(name="run-local", runner_module="st2.runners.local") # NOTE: We need to perform this patching because test fixtures are located outside of the packs # base paths directory. This will never happen outside the context of test fixtures. -@mock.patch('st2common.content.utils.get_pack_base_path', - mock.Mock(return_value=os.path.join(get_fixtures_base_path(), 'generic'))) +@mock.patch( + "st2common.content.utils.get_pack_base_path", + mock.Mock(return_value=os.path.join(get_fixtures_base_path(), "generic")), +) class ActionsRegistrarTest(tests_base.DbTestCase): - - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_register_all_actions(self): try: packs_base_path = fixtures_loader.get_fixtures_base_path() @@ -50,111 +56,157 @@ def test_register_all_actions(self): actions_registrar.register_actions(packs_base_paths=[packs_base_path]) except Exception as e: print(six.text_type(e)) - self.fail('All actions must be registered without exceptions.') + self.fail("All actions must be registered without exceptions.") else: all_actions_in_db = Action.get_all() self.assertTrue(len(all_actions_in_db) > 0) # Assert metadata_file field is populated - expected_path = 'actions/action-with-no-parameters.yaml' + expected_path = "actions/action-with-no-parameters.yaml" self.assertEqual(all_actions_in_db[0].metadata_file, expected_path) def test_register_actions_from_bad_pack(self): packs_base_path = tests_base.get_fixtures_path() try: actions_registrar.register_actions(packs_base_paths=[packs_base_path]) - self.fail('Should have thrown.') + self.fail("Should have thrown.") except: pass - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_pack_name_missing(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action_3_pack_missing.yaml') - registrar._register_action('dummy', action_file) + "generic", "actions", "action_3_pack_missing.yaml" + ) + registrar._register_action("dummy", action_file) action_name = None - with open(action_file, 'r') as fd: + with open(action_file, "r") as fd: content = yaml.safe_load(fd) - action_name = str(content['name']) + action_name = str(content["name"]) action_db = Action.get_by_name(action_name) - expected_msg = 'Content pack must be set to dummy' - self.assertEqual(action_db.pack, 'dummy', expected_msg) + expected_msg = "Content pack must be set to dummy" + self.assertEqual(action_db.pack, "dummy", expected_msg) Action.delete(action_db) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_register_action_with_no_params(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action-with-no-parameters.yaml') - - self.assertEqual(registrar._register_action('dummy', action_file), None) - - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + "generic", "actions", "action-with-no-parameters.yaml" + ) + + self.assertEqual(registrar._register_action("dummy", action_file), None) + + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_register_action_invalid_parameter_type_attribute(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action_invalid_param_type.yaml') - - expected_msg = '\'list\' is not valid under any of the given schema' - self.assertRaisesRegexp(jsonschema.ValidationError, expected_msg, - registrar._register_action, - 'dummy', action_file) - - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + "generic", "actions", "action_invalid_param_type.yaml" + ) + + expected_msg = "'list' is not valid under any of the given schema" + self.assertRaisesRegexp( + jsonschema.ValidationError, + expected_msg, + registrar._register_action, + "dummy", + action_file, + ) + + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_register_action_invalid_parameter_name(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action_invalid_parameter_name.yaml') - - expected_msg = ('Parameter name "action-name" is invalid. Valid characters for ' - 'parameter name are') - self.assertRaisesRegexp(jsonschema.ValidationError, expected_msg, - registrar._register_action, - 'generic', action_file) - - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + "generic", "actions", "action_invalid_parameter_name.yaml" + ) + + expected_msg = ( + 'Parameter name "action-name" is invalid. Valid characters for ' + "parameter name are" + ) + self.assertRaisesRegexp( + jsonschema.ValidationError, + expected_msg, + registrar._register_action, + "generic", + action_file, + ) + + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_invalid_params_schema(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action-invalid-schema-params.yaml') + "generic", "actions", "action-invalid-schema-params.yaml" + ) try: - registrar._register_action('generic', action_file) - self.fail('Invalid action schema. Should have failed.') + registrar._register_action("generic", action_file) + self.fail("Invalid action schema. Should have failed.") except jsonschema.ValidationError: pass - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock(return_value=True)) - @mock.patch.object(action_validator, 'get_runner_model', - mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + action_validator, + "get_runner_model", + mock.MagicMock(return_value=MOCK_RUNNER_TYPE_DB), + ) def test_action_update(self): registrar = actions_registrar.ActionsRegistrar() loader = fixtures_loader.FixturesLoader() action_file = loader.get_fixture_file_path_abs( - 'generic', 'actions', 'action1.yaml') - registrar._register_action('wolfpack', action_file) + "generic", "actions", "action1.yaml" + ) + registrar._register_action("wolfpack", action_file) # try registering again. this should not throw errors. - registrar._register_action('wolfpack', action_file) + registrar._register_action("wolfpack", action_file) action_name = None - with open(action_file, 'r') as fd: + with open(action_file, "r") as fd: content = yaml.safe_load(fd) - action_name = str(content['name']) + action_name = str(content["name"]) action_db = Action.get_by_name(action_name) - expected_msg = 'Content pack must be set to wolfpack' - self.assertEqual(action_db.pack, 'wolfpack', expected_msg) + expected_msg = "Content pack must be set to wolfpack" + self.assertEqual(action_db.pack, "wolfpack", expected_msg) Action.delete(action_db) diff --git a/st2actions/tests/unit/test_async_runner.py b/st2actions/tests/unit/test_async_runner.py index 04092029038..31258fae4ec 100644 --- a/st2actions/tests/unit/test_async_runner.py +++ b/st2actions/tests/unit/test_async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import AsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class AsyncTestRunner(AsyncActionRunner): def __init__(self): - super(AsyncTestRunner, self).__init__(runner_id='1') + super(AsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2actions/tests/unit/test_execution_cancellation.py b/st2actions/tests/unit/test_execution_cancellation.py index 6a130e2fe76..e6c51159ef1 100644 --- a/st2actions/tests/unit/test_execution_cancellation.py +++ b/st2actions/tests/unit/test_execution_cancellation.py @@ -22,6 +22,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from st2common.constants import action as action_constants @@ -42,35 +43,32 @@ from st2tests.mocks.liveaction import MockLiveActionPublisherNonBlocking from st2tests.mocks.runners import runner -__all__ = [ - 'ExecutionCancellationTestCase' -] +__all__ = ["ExecutionCancellationTestCase"] -TEST_FIXTURES = { - 'actions': [ - 'action1.yaml' - ] -} +TEST_FIXTURES = {"actions": ["action1.yaml"]} -PACK = 'generic' +PACK = "generic" LOADER = FixturesLoader() FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) -@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch.object( - CUDPublisher, 'publish_update', - mock.MagicMock(side_effect=MockExecutionPublisher.publish_update)) +@mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch( + "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, + "publish_update", + mock.MagicMock(side_effect=MockExecutionPublisher.publish_update), +) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) class ExecutionCancellationTestCase(ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(ExecutionCancellationTestCase, cls).setUpClass() - for _, fixture in six.iteritems(FIXTURES['actions']): + for _, fixture in six.iteritems(FIXTURES["actions"]): instance = ActionAPI(**fixture) Action.add_or_update(ActionAPI.to_model(instance)) @@ -80,62 +78,84 @@ def tearDown(self): # Ensure all liveactions are canceled at end of each test. for liveaction in LiveAction.get_all(): action_service.update_status( - liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) @classmethod def get_runner_class(cls, runner_name): return runners.get_runner(runner_name).__class__ @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state)) - @mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) - @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) + LiveActionPublisher, + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisherNonBlocking.publish_state), + ) + @mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) + ) + @mock.patch( + "st2actions.container.base.get_runner", + mock.Mock(return_value=runner.get_runner()), + ) def test_basic_cancel(self): - runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, 'foobar', None) + runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, "foobar", None) mock_runner_run = mock.Mock(return_value=runner_run_result) - with mock.patch.object(runner.MockActionRunner, 'run', mock_runner_run): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + with mock.patch.object(runner.MockActionRunner, "run", mock_runner_run): + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) liveaction = self._wait_on_status( - liveaction, - action_constants.LIVEACTION_STATUS_RUNNING + liveaction, action_constants.LIVEACTION_STATUS_RUNNING ) # Cancel execution. action_service.request_cancellation(liveaction, cfg.CONF.system_user.user) liveaction = self._wait_on_status( - liveaction, - action_constants.LIVEACTION_STATUS_CANCELED + liveaction, action_constants.LIVEACTION_STATUS_CANCELED ) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create)) + CUDPublisher, + "publish_create", + mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create), + ) @mock.patch.object( - CUDPublisher, 'publish_update', - mock.MagicMock(side_effect=MockExecutionPublisher.publish_update)) + CUDPublisher, + "publish_update", + mock.MagicMock(side_effect=MockExecutionPublisher.publish_update), + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state), + ) @mock.patch.object( - runners.ActionRunner, 'cancel', - mock.MagicMock(side_effect=Exception('Mock cancellation failure.'))) - @mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) - @mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) + runners.ActionRunner, + "cancel", + mock.MagicMock(side_effect=Exception("Mock cancellation failure.")), + ) + @mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) + ) + @mock.patch( + "st2actions.container.base.get_runner", + mock.Mock(return_value=runner.get_runner()), + ) def test_failed_cancel(self): - runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, 'foobar', None) + runner_run_result = (action_constants.LIVEACTION_STATUS_RUNNING, "foobar", None) mock_runner_run = mock.Mock(return_value=runner_run_result) - with mock.patch.object(runner.MockActionRunner, 'run', mock_runner_run): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + with mock.patch.object(runner.MockActionRunner, "run", mock_runner_run): + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) liveaction = self._wait_on_status( - liveaction, - action_constants.LIVEACTION_STATUS_RUNNING + liveaction, action_constants.LIVEACTION_STATUS_RUNNING ) # Cancel execution. @@ -144,22 +164,28 @@ def test_failed_cancel(self): # Cancellation failed and execution state remains "canceling". runners.ActionRunner.cancel.assert_called_once_with() liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING + ) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, "publish_create", mock.MagicMock(return_value=None) + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(return_value=None)) + LiveActionPublisher, "publish_state", mock.MagicMock(return_value=None) + ) @mock.patch.object( - runners.ActionRunner, 'cancel', - mock.MagicMock(return_value=None)) + runners.ActionRunner, "cancel", mock.MagicMock(return_value=None) + ) def test_noop_cancel(self): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) # Cancel execution. action_service.request_cancellation(liveaction, cfg.CONF.system_user.user) @@ -171,22 +197,28 @@ def test_noop_cancel(self): self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELED) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, "publish_create", mock.MagicMock(return_value=None) + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(return_value=None)) + LiveActionPublisher, "publish_state", mock.MagicMock(return_value=None) + ) @mock.patch.object( - runners.ActionRunner, 'cancel', - mock.MagicMock(return_value=None)) + runners.ActionRunner, "cancel", mock.MagicMock(return_value=None) + ) def test_cancel_delayed_execution(self): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) # Manually update the liveaction from requested to delayed to mock concurrency policy. - action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + action_service.update_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_DELAYED) @@ -200,27 +232,33 @@ def test_cancel_delayed_execution(self): self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELED) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) + CUDPublisher, "publish_create", mock.MagicMock(return_value=None) + ) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(return_value=None)) + LiveActionPublisher, "publish_state", mock.MagicMock(return_value=None) + ) @mock.patch.object( - trace_service, 'get_trace_db_by_live_action', - mock.MagicMock(return_value=(None, None))) + trace_service, + "get_trace_db_by_live_action", + mock.MagicMock(return_value=(None, None)), + ) def test_cancel_delayed_execution_with_parent(self): liveaction = LiveActionDB( - action='wolfpack.action-1', - parameters={'actionstr': 'foo'}, - context={'parent': {'execution_id': uuid.uuid4().hex}} + action="wolfpack.action-1", + parameters={"actionstr": "foo"}, + context={"parent": {"execution_id": uuid.uuid4().hex}}, ) liveaction, _ = action_service.request(liveaction) liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) # Manually update the liveaction from requested to delayed to mock concurrency policy. - action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_DELAYED) + action_service.update_status( + liveaction, action_constants.LIVEACTION_STATUS_DELAYED + ) liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_DELAYED) @@ -230,4 +268,6 @@ def test_cancel_delayed_execution_with_parent(self): # Cancel is only called when liveaction is still in running state. # Otherwise, the cancellation is only a state change. liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_CANCELING + ) diff --git a/st2actions/tests/unit/test_executions.py b/st2actions/tests/unit/test_executions.py index f143631e422..64bde6b654d 100644 --- a/st2actions/tests/unit/test_executions.py +++ b/st2actions/tests/unit/test_executions.py @@ -20,6 +20,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() import st2common.bootstrap.runnersregistrar as runners_registrar @@ -53,47 +54,57 @@ @mock.patch.object( - LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=(action_constants.LIVEACTION_STATUS_FAILED, 'Non-empty', None))) + LocalShellCommandRunner, + "run", + mock.MagicMock( + return_value=(action_constants.LIVEACTION_STATUS_FAILED, "Non-empty", None) + ), +) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create)) + CUDPublisher, + "publish_create", + mock.MagicMock(side_effect=MockLiveActionPublisher.publish_create), +) @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state), +) class TestActionExecutionHistoryWorker(ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(TestActionExecutionHistoryWorker, cls).setUpClass() runners_registrar.register_runners() - action_local = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS['actions']['local'])) + action_local = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS["actions"]["local"])) Action.add_or_update(ActionAPI.to_model(action_local)) - action_chain = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS['actions']['chain'])) - action_chain.entry_point = fixture.PATH + '/chain.yaml' + action_chain = ActionAPI(**copy.deepcopy(fixture.ARTIFACTS["actions"]["chain"])) + action_chain.entry_point = fixture.PATH + "/chain.yaml" Action.add_or_update(ActionAPI.to_model(action_chain)) def tearDown(self): - MOCK_FAIL_EXECUTION_CREATE = False # noqa + MOCK_FAIL_EXECUTION_CREATE = False # noqa super(TestActionExecutionHistoryWorker, self).tearDown() def test_basic_execution(self): - liveaction = LiveActionDB(action='executions.local', parameters={'cmd': 'uname -a'}) + liveaction = LiveActionDB( + action="executions.local", parameters={"cmd": "uname -a"} + ) liveaction, _ = action_service.request(liveaction) - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_FAILED + ) execution = self._get_action_execution( - liveaction__id=str(liveaction.id), - raise_exception=True + liveaction__id=str(liveaction.id), raise_exception=True ) self.assertDictEqual(execution.trigger, {}) self.assertDictEqual(execution.trigger_type, {}) self.assertDictEqual(execution.trigger_instance, {}) self.assertDictEqual(execution.rule, {}) - action = action_utils.get_action_by_ref('executions.local') + action = action_utils.get_action_by_ref("executions.local") self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action))) - runner = RunnerType.get_by_name(action.runner_type['name']) + runner = RunnerType.get_by_name(action.runner_type["name"]) self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner))) liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(execution.start_timestamp, liveaction.start_timestamp) @@ -101,26 +112,27 @@ def test_basic_execution(self): self.assertEqual(execution.result, liveaction.result) self.assertEqual(execution.status, liveaction.status) self.assertEqual(execution.context, liveaction.context) - self.assertEqual(execution.liveaction['callback'], liveaction.callback) - self.assertEqual(execution.liveaction['action'], liveaction.action) + self.assertEqual(execution.liveaction["callback"], liveaction.callback) + self.assertEqual(execution.liveaction["action"], liveaction.action) def test_basic_execution_history_create_failed(self): - MOCK_FAIL_EXECUTION_CREATE = True # noqa + MOCK_FAIL_EXECUTION_CREATE = True # noqa self.test_basic_execution() def test_chained_executions(self): - liveaction = LiveActionDB(action='executions.chain') + liveaction = LiveActionDB(action="executions.chain") liveaction, _ = action_service.request(liveaction) - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_FAILED + ) execution = self._get_action_execution( - liveaction__id=str(liveaction.id), - raise_exception=True + liveaction__id=str(liveaction.id), raise_exception=True ) - action = action_utils.get_action_by_ref('executions.chain') + action = action_utils.get_action_by_ref("executions.chain") self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action))) - runner = RunnerType.get_by_name(action.runner_type['name']) + runner = RunnerType.get_by_name(action.runner_type["name"]) self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner))) liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(execution.start_timestamp, liveaction.start_timestamp) @@ -128,56 +140,69 @@ def test_chained_executions(self): self.assertEqual(execution.result, liveaction.result) self.assertEqual(execution.status, liveaction.status) self.assertEqual(execution.context, liveaction.context) - self.assertEqual(execution.liveaction['callback'], liveaction.callback) - self.assertEqual(execution.liveaction['action'], liveaction.action) + self.assertEqual(execution.liveaction["callback"], liveaction.callback) + self.assertEqual(execution.liveaction["action"], liveaction.action) self.assertGreater(len(execution.children), 0) for child in execution.children: record = ActionExecution.get(id=child, raise_exception=True) self.assertEqual(record.parent, str(execution.id)) - self.assertEqual(record.action['name'], 'local') - self.assertEqual(record.runner['name'], 'local-shell-cmd') + self.assertEqual(record.action["name"], "local") + self.assertEqual(record.runner["name"], "local-shell-cmd") def test_triggered_execution(self): docs = { - 'trigger_type': copy.deepcopy(fixture.ARTIFACTS['trigger_type']), - 'trigger': copy.deepcopy(fixture.ARTIFACTS['trigger']), - 'rule': copy.deepcopy(fixture.ARTIFACTS['rule']), - 'trigger_instance': copy.deepcopy(fixture.ARTIFACTS['trigger_instance'])} + "trigger_type": copy.deepcopy(fixture.ARTIFACTS["trigger_type"]), + "trigger": copy.deepcopy(fixture.ARTIFACTS["trigger"]), + "rule": copy.deepcopy(fixture.ARTIFACTS["rule"]), + "trigger_instance": copy.deepcopy(fixture.ARTIFACTS["trigger_instance"]), + } # Trigger an action execution. trigger_type = TriggerType.add_or_update( - TriggerTypeAPI.to_model(TriggerTypeAPI(**docs['trigger_type']))) - trigger = Trigger.add_or_update(TriggerAPI.to_model(TriggerAPI(**docs['trigger']))) - rule = RuleAPI.to_model(RuleAPI(**docs['rule'])) + TriggerTypeAPI.to_model(TriggerTypeAPI(**docs["trigger_type"])) + ) + trigger = Trigger.add_or_update( + TriggerAPI.to_model(TriggerAPI(**docs["trigger"])) + ) + rule = RuleAPI.to_model(RuleAPI(**docs["rule"])) rule.trigger = reference.get_str_resource_ref_from_model(trigger) rule = Rule.add_or_update(rule) trigger_instance = TriggerInstance.add_or_update( - TriggerInstanceAPI.to_model(TriggerInstanceAPI(**docs['trigger_instance']))) + TriggerInstanceAPI.to_model(TriggerInstanceAPI(**docs["trigger_instance"])) + ) trace_service.add_or_update_given_trace_context( - trace_context={'trace_tag': 'test_triggered_execution_trace'}, - trigger_instances=[str(trigger_instance.id)]) + trace_context={"trace_tag": "test_triggered_execution_trace"}, + trigger_instances=[str(trigger_instance.id)], + ) enforcer = RuleEnforcer(trigger_instance, rule) enforcer.enforce() # Wait for the action execution to complete and then confirm outcome. - liveaction = LiveAction.get(context__trigger_instance__id=str(trigger_instance.id)) + liveaction = LiveAction.get( + context__trigger_instance__id=str(trigger_instance.id) + ) self.assertIsNotNone(liveaction) - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_FAILED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_FAILED + ) execution = self._get_action_execution( - liveaction__id=str(liveaction.id), - raise_exception=True + liveaction__id=str(liveaction.id), raise_exception=True ) self.assertDictEqual(execution.trigger, vars(TriggerAPI.from_model(trigger))) - self.assertDictEqual(execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type))) - self.assertDictEqual(execution.trigger_instance, - vars(TriggerInstanceAPI.from_model(trigger_instance))) + self.assertDictEqual( + execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type)) + ) + self.assertDictEqual( + execution.trigger_instance, + vars(TriggerInstanceAPI.from_model(trigger_instance)), + ) self.assertDictEqual(execution.rule, vars(RuleAPI.from_model(rule))) action = action_utils.get_action_by_ref(liveaction.action) self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action))) - runner = RunnerType.get_by_name(action.runner_type['name']) + runner = RunnerType.get_by_name(action.runner_type["name"]) self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner))) liveaction = LiveAction.get_by_id(str(liveaction.id)) self.assertEqual(execution.start_timestamp, liveaction.start_timestamp) @@ -185,8 +210,8 @@ def test_triggered_execution(self): self.assertEqual(execution.result, liveaction.result) self.assertEqual(execution.status, liveaction.status) self.assertEqual(execution.context, liveaction.context) - self.assertEqual(execution.liveaction['callback'], liveaction.callback) - self.assertEqual(execution.liveaction['action'], liveaction.action) + self.assertEqual(execution.liveaction["callback"], liveaction.callback) + self.assertEqual(execution.liveaction["action"], liveaction.action) def _get_action_execution(self, **kwargs): return ActionExecution.get(**kwargs) diff --git a/st2actions/tests/unit/test_notifier.py b/st2actions/tests/unit/test_notifier.py index fa1af31ca80..b648d7fad30 100644 --- a/st2actions/tests/unit/test_notifier.py +++ b/st2actions/tests/unit/test_notifier.py @@ -20,6 +20,7 @@ import mock import st2tests.config as tests_config + tests_config.parse_args() from st2actions.notifier.notifier import Notifier @@ -41,77 +42,96 @@ from st2common.util import isotime from st2tests.base import CleanDbTestCase -ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][0] -NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES['action'][1] -MOCK_EXECUTION = ActionExecutionDB(id=bson.ObjectId(), result={'stdout': 'stuff happens'}) +ACTION_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][0] +NOTIFY_TRIGGER_TYPE = INTERNAL_TRIGGER_TYPES["action"][1] +MOCK_EXECUTION = ActionExecutionDB( + id=bson.ObjectId(), result={"stdout": "stuff happens"} +) class NotifierTestCase(CleanDbTestCase): - class MockDispatcher(object): def __init__(self, tester): self.tester = tester self.notify_trigger = ResourceReference.to_string_reference( - pack=NOTIFY_TRIGGER_TYPE['pack'], - name=NOTIFY_TRIGGER_TYPE['name']) + pack=NOTIFY_TRIGGER_TYPE["pack"], name=NOTIFY_TRIGGER_TYPE["name"] + ) self.action_trigger = ResourceReference.to_string_reference( - pack=ACTION_TRIGGER_TYPE['pack'], - name=ACTION_TRIGGER_TYPE['name']) + pack=ACTION_TRIGGER_TYPE["pack"], name=ACTION_TRIGGER_TYPE["name"] + ) def dispatch(self, *args, **kwargs): try: self.tester.assertEqual(len(args), 1) - self.tester.assertTrue('payload' in kwargs) - payload = kwargs['payload'] + self.tester.assertTrue("payload" in kwargs) + payload = kwargs["payload"] if args[0] == self.notify_trigger: - self.tester.assertEqual(payload['status'], 'succeeded') - self.tester.assertTrue('execution_id' in payload) - self.tester.assertEqual(payload['execution_id'], str(MOCK_EXECUTION.id)) - self.tester.assertTrue('start_timestamp' in payload) - self.tester.assertTrue('end_timestamp' in payload) - self.tester.assertEqual('core.local', payload['action_ref']) - self.tester.assertEqual('Action succeeded.', payload['message']) - self.tester.assertTrue('data' in payload) - self.tester.assertTrue('local-shell-cmd', payload['runner_ref']) + self.tester.assertEqual(payload["status"], "succeeded") + self.tester.assertTrue("execution_id" in payload) + self.tester.assertEqual( + payload["execution_id"], str(MOCK_EXECUTION.id) + ) + self.tester.assertTrue("start_timestamp" in payload) + self.tester.assertTrue("end_timestamp" in payload) + self.tester.assertEqual("core.local", payload["action_ref"]) + self.tester.assertEqual("Action succeeded.", payload["message"]) + self.tester.assertTrue("data" in payload) + self.tester.assertTrue("local-shell-cmd", payload["runner_ref"]) if args[0] == self.action_trigger: - self.tester.assertEqual(payload['status'], 'succeeded') - self.tester.assertTrue('execution_id' in payload) - self.tester.assertEqual(payload['execution_id'], str(MOCK_EXECUTION.id)) - self.tester.assertTrue('start_timestamp' in payload) - self.tester.assertEqual('core.local', payload['action_name']) - self.tester.assertEqual('core.local', payload['action_ref']) - self.tester.assertTrue('result' in payload) - self.tester.assertTrue('parameters' in payload) - self.tester.assertTrue('local-shell-cmd', payload['runner_ref']) + self.tester.assertEqual(payload["status"], "succeeded") + self.tester.assertTrue("execution_id" in payload) + self.tester.assertEqual( + payload["execution_id"], str(MOCK_EXECUTION.id) + ) + self.tester.assertTrue("start_timestamp" in payload) + self.tester.assertEqual("core.local", payload["action_name"]) + self.tester.assertEqual("core.local", payload["action_ref"]) + self.tester.assertTrue("result" in payload) + self.tester.assertTrue("parameters" in payload) + self.tester.assertTrue("local-shell-cmd", payload["runner_ref"]) except Exception: - self.tester.fail('Test failed') - - @mock.patch('st2common.util.action_db.get_action_by_ref', mock.MagicMock( - return_value=ActionDB(pack='core', name='local', runner_type={'name': 'local-shell-cmd'}, - parameters={}))) - @mock.patch('st2common.util.action_db.get_runnertype_by_name', mock.MagicMock( - return_value=RunnerTypeDB(name='foo', runner_parameters={}))) - @mock.patch.object(Action, 'get_by_ref', mock.MagicMock( - return_value={'runner_type': {'name': 'local-shell-cmd'}})) - @mock.patch.object(Policy, 'query', mock.MagicMock( - return_value=[])) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(return_value={})) + self.tester.fail("Test failed") + + @mock.patch( + "st2common.util.action_db.get_action_by_ref", + mock.MagicMock( + return_value=ActionDB( + pack="core", + name="local", + runner_type={"name": "local-shell-cmd"}, + parameters={}, + ) + ), + ) + @mock.patch( + "st2common.util.action_db.get_runnertype_by_name", + mock.MagicMock(return_value=RunnerTypeDB(name="foo", runner_parameters={})), + ) + @mock.patch.object( + Action, + "get_by_ref", + mock.MagicMock(return_value={"runner_type": {"name": "local-shell-cmd"}}), + ) + @mock.patch.object(Policy, "query", mock.MagicMock(return_value=[])) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) def test_notify_triggers(self): - liveaction_db = LiveActionDB(action='core.local') + liveaction_db = LiveActionDB(action="core.local") liveaction_db.id = bson.ObjectId() - liveaction_db.description = '' - liveaction_db.status = 'succeeded' + liveaction_db.description = "" + liveaction_db.status = "succeeded" liveaction_db.parameters = {} - on_success = NotificationSubSchema(message='Action succeeded.') - on_failure = NotificationSubSchema(message='Action failed.') - liveaction_db.notify = NotificationSchema(on_success=on_success, - on_failure=on_failure) + on_success = NotificationSubSchema(message="Action succeeded.") + on_failure = NotificationSubSchema(message="Action failed.") + liveaction_db.notify = NotificationSchema( + on_success=on_success, on_failure=on_failure + ) liveaction_db.start_timestamp = date_utils.get_datetime_utc_now() - liveaction_db.end_timestamp = \ - (liveaction_db.start_timestamp + datetime.timedelta(seconds=50)) + liveaction_db.end_timestamp = ( + liveaction_db.start_timestamp + datetime.timedelta(seconds=50) + ) LiveAction.add_or_update(liveaction_db) execution = MOCK_EXECUTION @@ -122,26 +142,39 @@ def test_notify_triggers(self): notifier = Notifier(connection=None, queues=[], trigger_dispatcher=dispatcher) notifier.process(execution) - @mock.patch('st2common.util.action_db.get_action_by_ref', mock.MagicMock( - return_value=ActionDB(pack='core', name='local', runner_type={'name': 'local-shell-cmd'}, - parameters={}))) - @mock.patch('st2common.util.action_db.get_runnertype_by_name', mock.MagicMock( - return_value=RunnerTypeDB(name='foo', runner_parameters={}))) - @mock.patch.object(Action, 'get_by_ref', mock.MagicMock( - return_value={'runner_type': {'name': 'local-shell-cmd'}})) - @mock.patch.object(Policy, 'query', mock.MagicMock( - return_value=[])) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(return_value={})) + @mock.patch( + "st2common.util.action_db.get_action_by_ref", + mock.MagicMock( + return_value=ActionDB( + pack="core", + name="local", + runner_type={"name": "local-shell-cmd"}, + parameters={}, + ) + ), + ) + @mock.patch( + "st2common.util.action_db.get_runnertype_by_name", + mock.MagicMock(return_value=RunnerTypeDB(name="foo", runner_parameters={})), + ) + @mock.patch.object( + Action, + "get_by_ref", + mock.MagicMock(return_value={"runner_type": {"name": "local-shell-cmd"}}), + ) + @mock.patch.object(Policy, "query", mock.MagicMock(return_value=[])) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) def test_notify_triggers_end_timestamp_none(self): - liveaction_db = LiveActionDB(action='core.local') + liveaction_db = LiveActionDB(action="core.local") liveaction_db.id = bson.ObjectId() - liveaction_db.description = '' - liveaction_db.status = 'succeeded' + liveaction_db.description = "" + liveaction_db.status = "succeeded" liveaction_db.parameters = {} - on_success = NotificationSubSchema(message='Action succeeded.') - on_failure = NotificationSubSchema(message='Action failed.') - liveaction_db.notify = NotificationSchema(on_success=on_success, - on_failure=on_failure) + on_success = NotificationSubSchema(message="Action succeeded.") + on_failure = NotificationSubSchema(message="Action failed.") + liveaction_db.notify = NotificationSchema( + on_success=on_success, on_failure=on_failure + ) liveaction_db.start_timestamp = date_utils.get_datetime_utc_now() # This tests for end_timestamp being set to None, which can happen when a policy cancels @@ -159,30 +192,48 @@ def test_notify_triggers_end_timestamp_none(self): notifier = Notifier(connection=None, queues=[], trigger_dispatcher=dispatcher) notifier.process(execution) - @mock.patch('st2common.util.action_db.get_action_by_ref', mock.MagicMock( - return_value=ActionDB(pack='core', name='local', runner_type={'name': 'local-shell-cmd'}))) - @mock.patch('st2common.util.action_db.get_runnertype_by_name', mock.MagicMock( - return_value=RunnerTypeDB(name='foo', runner_parameters={'runner_foo': 'foo'}))) - @mock.patch.object(Action, 'get_by_ref', mock.MagicMock( - return_value={'runner_type': {'name': 'local-shell-cmd'}})) - @mock.patch.object(Policy, 'query', mock.MagicMock( - return_value=[])) - @mock.patch.object(Notifier, '_post_generic_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock(return_value={})) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch( + "st2common.util.action_db.get_action_by_ref", + mock.MagicMock( + return_value=ActionDB( + pack="core", name="local", runner_type={"name": "local-shell-cmd"} + ) + ), + ) + @mock.patch( + "st2common.util.action_db.get_runnertype_by_name", + mock.MagicMock( + return_value=RunnerTypeDB( + name="foo", runner_parameters={"runner_foo": "foo"} + ) + ), + ) + @mock.patch.object( + Action, + "get_by_ref", + mock.MagicMock(return_value={"runner_type": {"name": "local-shell-cmd"}}), + ) + @mock.patch.object(Policy, "query", mock.MagicMock(return_value=[])) + @mock.patch.object( + Notifier, "_post_generic_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_notify_triggers_jinja_patterns(self, dispatch): - liveaction_db = LiveActionDB(action='core.local') + liveaction_db = LiveActionDB(action="core.local") liveaction_db.id = bson.ObjectId() - liveaction_db.description = '' - liveaction_db.status = 'succeeded' - liveaction_db.parameters = {'cmd': 'mamma mia', 'runner_foo': 'foo'} - on_success = NotificationSubSchema(message='Command {{action_parameters.cmd}} succeeded.', - data={'stdout': '{{action_results.stdout}}'}) + liveaction_db.description = "" + liveaction_db.status = "succeeded" + liveaction_db.parameters = {"cmd": "mamma mia", "runner_foo": "foo"} + on_success = NotificationSubSchema( + message="Command {{action_parameters.cmd}} succeeded.", + data={"stdout": "{{action_results.stdout}}"}, + ) liveaction_db.notify = NotificationSchema(on_success=on_success) liveaction_db.start_timestamp = date_utils.get_datetime_utc_now() - liveaction_db.end_timestamp = \ - (liveaction_db.start_timestamp + datetime.timedelta(seconds=50)) + liveaction_db.end_timestamp = ( + liveaction_db.start_timestamp + datetime.timedelta(seconds=50) + ) LiveAction.add_or_update(liveaction_db) @@ -192,26 +243,31 @@ def test_notify_triggers_jinja_patterns(self, dispatch): notifier = Notifier(connection=None, queues=[]) notifier.process(execution) - exp = {'status': 'succeeded', - 'start_timestamp': isotime.format(liveaction_db.start_timestamp), - 'route': 'notify.default', 'runner_ref': 'local-shell-cmd', - 'channel': 'notify.default', 'message': u'Command mamma mia succeeded.', - 'data': {'result': '{}', 'stdout': 'stuff happens'}, - 'action_ref': u'core.local', - 'execution_id': str(MOCK_EXECUTION.id), - 'end_timestamp': isotime.format(liveaction_db.end_timestamp)} - dispatch.assert_called_once_with('core.st2.generic.notifytrigger', payload=exp, - trace_context={}) + exp = { + "status": "succeeded", + "start_timestamp": isotime.format(liveaction_db.start_timestamp), + "route": "notify.default", + "runner_ref": "local-shell-cmd", + "channel": "notify.default", + "message": "Command mamma mia succeeded.", + "data": {"result": "{}", "stdout": "stuff happens"}, + "action_ref": "core.local", + "execution_id": str(MOCK_EXECUTION.id), + "end_timestamp": isotime.format(liveaction_db.end_timestamp), + } + dispatch.assert_called_once_with( + "core.st2.generic.notifytrigger", payload=exp, trace_context={} + ) notifier.process(execution) - @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock( - return_value='local-shell-cmd')) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock( - return_value={})) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch.object( + Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd") + ) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_post_generic_trigger_emit_when_default_value_is_used(self, dispatch): for status in LIVEACTION_STATUSES: - liveaction_db = LiveActionDB(action='core.local') + liveaction_db = LiveActionDB(action="core.local") liveaction_db.status = status execution = MOCK_EXECUTION execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db)) @@ -221,28 +277,34 @@ def test_post_generic_trigger_emit_when_default_value_is_used(self, dispatch): notifier._post_generic_trigger(liveaction_db, execution) if status in LIVEACTION_COMPLETED_STATES: - exp = {'status': status, - 'start_timestamp': str(liveaction_db.start_timestamp), - 'result': {}, 'parameters': {}, - 'action_ref': u'core.local', - 'runner_ref': 'local-shell-cmd', - 'execution_id': str(MOCK_EXECUTION.id), - 'action_name': u'core.local'} - dispatch.assert_called_with('core.st2.generic.actiontrigger', - payload=exp, trace_context={}) + exp = { + "status": status, + "start_timestamp": str(liveaction_db.start_timestamp), + "result": {}, + "parameters": {}, + "action_ref": "core.local", + "runner_ref": "local-shell-cmd", + "execution_id": str(MOCK_EXECUTION.id), + "action_name": "core.local", + } + dispatch.assert_called_with( + "core.st2.generic.actiontrigger", payload=exp, trace_context={} + ) self.assertEqual(dispatch.call_count, len(LIVEACTION_COMPLETED_STATES)) - @mock.patch('oslo_config.cfg.CONF.action_sensor', mock.MagicMock( - emit_when=['scheduled', 'pending', 'abandoned'])) - @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock( - return_value='local-shell-cmd')) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock( - return_value={})) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch( + "oslo_config.cfg.CONF.action_sensor", + mock.MagicMock(emit_when=["scheduled", "pending", "abandoned"]), + ) + @mock.patch.object( + Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd") + ) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_post_generic_trigger_with_emit_condition(self, dispatch): for status in LIVEACTION_STATUSES: - liveaction_db = LiveActionDB(action='core.local') + liveaction_db = LiveActionDB(action="core.local") liveaction_db.status = status execution = MOCK_EXECUTION execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db)) @@ -251,36 +313,45 @@ def test_post_generic_trigger_with_emit_condition(self, dispatch): notifier = Notifier(connection=None, queues=[]) notifier._post_generic_trigger(liveaction_db, execution) - if status in ['scheduled', 'pending', 'abandoned']: - exp = {'status': status, - 'start_timestamp': str(liveaction_db.start_timestamp), - 'result': {}, 'parameters': {}, - 'action_ref': u'core.local', - 'runner_ref': 'local-shell-cmd', - 'execution_id': str(MOCK_EXECUTION.id), - 'action_name': u'core.local'} - dispatch.assert_called_with('core.st2.generic.actiontrigger', - payload=exp, trace_context={}) + if status in ["scheduled", "pending", "abandoned"]: + exp = { + "status": status, + "start_timestamp": str(liveaction_db.start_timestamp), + "result": {}, + "parameters": {}, + "action_ref": "core.local", + "runner_ref": "local-shell-cmd", + "execution_id": str(MOCK_EXECUTION.id), + "action_name": "core.local", + } + dispatch.assert_called_with( + "core.st2.generic.actiontrigger", payload=exp, trace_context={} + ) self.assertEqual(dispatch.call_count, 3) - @mock.patch('oslo_config.cfg.CONF.action_sensor.enable', mock.MagicMock( - return_value=True)) - @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock( - return_value='local-shell-cmd')) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock( - return_value={})) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') - @mock.patch('st2actions.notifier.notifier.LiveAction') - @mock.patch('st2actions.notifier.notifier.policy_service.apply_post_run_policies', mock.Mock()) - def test_process_post_generic_notify_trigger_on_completed_state_default(self, - mock_LiveAction, mock_dispatch): + @mock.patch( + "oslo_config.cfg.CONF.action_sensor.enable", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd") + ) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") + @mock.patch("st2actions.notifier.notifier.LiveAction") + @mock.patch( + "st2actions.notifier.notifier.policy_service.apply_post_run_policies", + mock.Mock(), + ) + def test_process_post_generic_notify_trigger_on_completed_state_default( + self, mock_LiveAction, mock_dispatch + ): # Verify that generic action trigger is posted on all completed states when action sensor # is enabled for status in LIVEACTION_STATUSES: notifier = Notifier(connection=None, queues=[]) - liveaction_db = LiveActionDB(id=bson.ObjectId(), action='core.local') + liveaction_db = LiveActionDB(id=bson.ObjectId(), action="core.local") liveaction_db.status = status execution = MOCK_EXECUTION execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db)) @@ -292,35 +363,45 @@ def test_process_post_generic_notify_trigger_on_completed_state_default(self, notifier.process(execution) if status in LIVEACTION_COMPLETED_STATES: - exp = {'status': status, - 'start_timestamp': str(liveaction_db.start_timestamp), - 'result': {}, 'parameters': {}, - 'action_ref': u'core.local', - 'runner_ref': 'local-shell-cmd', - 'execution_id': str(MOCK_EXECUTION.id), - 'action_name': u'core.local'} - mock_dispatch.assert_called_with('core.st2.generic.actiontrigger', - payload=exp, trace_context={}) + exp = { + "status": status, + "start_timestamp": str(liveaction_db.start_timestamp), + "result": {}, + "parameters": {}, + "action_ref": "core.local", + "runner_ref": "local-shell-cmd", + "execution_id": str(MOCK_EXECUTION.id), + "action_name": "core.local", + } + mock_dispatch.assert_called_with( + "core.st2.generic.actiontrigger", payload=exp, trace_context={} + ) self.assertEqual(mock_dispatch.call_count, len(LIVEACTION_COMPLETED_STATES)) - @mock.patch('oslo_config.cfg.CONF.action_sensor', mock.MagicMock( - enable=True, emit_when=['scheduled', 'pending', 'abandoned'])) - @mock.patch.object(Notifier, '_get_runner_ref', mock.MagicMock( - return_value='local-shell-cmd')) - @mock.patch.object(Notifier, '_get_trace_context', mock.MagicMock( - return_value={})) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') - @mock.patch('st2actions.notifier.notifier.LiveAction') - @mock.patch('st2actions.notifier.notifier.policy_service.apply_post_run_policies', mock.Mock()) - def test_process_post_generic_notify_trigger_on_custom_emit_when_states(self, - mock_LiveAction, mock_dispatch): + @mock.patch( + "oslo_config.cfg.CONF.action_sensor", + mock.MagicMock(enable=True, emit_when=["scheduled", "pending", "abandoned"]), + ) + @mock.patch.object( + Notifier, "_get_runner_ref", mock.MagicMock(return_value="local-shell-cmd") + ) + @mock.patch.object(Notifier, "_get_trace_context", mock.MagicMock(return_value={})) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") + @mock.patch("st2actions.notifier.notifier.LiveAction") + @mock.patch( + "st2actions.notifier.notifier.policy_service.apply_post_run_policies", + mock.Mock(), + ) + def test_process_post_generic_notify_trigger_on_custom_emit_when_states( + self, mock_LiveAction, mock_dispatch + ): # Verify that generic action trigger is posted on all completed states when action sensor # is enabled for status in LIVEACTION_STATUSES: notifier = Notifier(connection=None, queues=[]) - liveaction_db = LiveActionDB(id=bson.ObjectId(), action='core.local') + liveaction_db = LiveActionDB(id=bson.ObjectId(), action="core.local") liveaction_db.status = status execution = MOCK_EXECUTION execution.liveaction = vars(LiveActionAPI.from_model(liveaction_db)) @@ -331,15 +412,19 @@ def test_process_post_generic_notify_trigger_on_custom_emit_when_states(self, notifier = Notifier(connection=None, queues=[]) notifier.process(execution) - if status in ['scheduled', 'pending', 'abandoned']: - exp = {'status': status, - 'start_timestamp': str(liveaction_db.start_timestamp), - 'result': {}, 'parameters': {}, - 'action_ref': u'core.local', - 'runner_ref': 'local-shell-cmd', - 'execution_id': str(MOCK_EXECUTION.id), - 'action_name': u'core.local'} - mock_dispatch.assert_called_with('core.st2.generic.actiontrigger', - payload=exp, trace_context={}) + if status in ["scheduled", "pending", "abandoned"]: + exp = { + "status": status, + "start_timestamp": str(liveaction_db.start_timestamp), + "result": {}, + "parameters": {}, + "action_ref": "core.local", + "runner_ref": "local-shell-cmd", + "execution_id": str(MOCK_EXECUTION.id), + "action_name": "core.local", + } + mock_dispatch.assert_called_with( + "core.st2.generic.actiontrigger", payload=exp, trace_context={} + ) self.assertEqual(mock_dispatch.call_count, 3) diff --git a/st2actions/tests/unit/test_parallel_ssh.py b/st2actions/tests/unit/test_parallel_ssh.py index bf8c1df87bd..67052a53e09 100644 --- a/st2actions/tests/unit/test_parallel_ssh.py +++ b/st2actions/tests/unit/test_parallel_ssh.py @@ -17,13 +17,14 @@ import json import os -from mock import (patch, Mock, MagicMock) +from mock import patch, Mock, MagicMock import unittest2 from st2common.runners.parallel_ssh import ParallelSSHClient from st2common.runners.paramiko_ssh import ParamikoSSHClient from st2common.runners.paramiko_ssh import SSHCommandTimeoutError import st2tests.config as tests_config + tests_config.parse_args() MOCK_STDERR_SUDO_PASSWORD_ERROR = """ @@ -35,251 +36,294 @@ class ParallelSSHTests(unittest2.TestCase): - - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_connect_with_password(self): - hosts = ['localhost', '127.0.0.1'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - password='ubuntu', - connect=False) + hosts = ["localhost", "127.0.0.1"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", password="ubuntu", connect=False + ) client.connect() expected_conn = { - 'allow_agent': False, - 'look_for_keys': False, - 'password': 'ubuntu', - 'username': 'ubuntu', - 'timeout': 60, - 'port': 22 + "allow_agent": False, + "look_for_keys": False, + "password": "ubuntu", + "username": "ubuntu", + "timeout": 60, + "port": 22, } for host in hosts: - expected_conn['hostname'] = host - client._hosts_client[host].client.connect.assert_called_once_with(**expected_conn) + expected_conn["hostname"] = host + client._hosts_client[host].client.connect.assert_called_once_with( + **expected_conn + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_connect_with_random_ports(self): - hosts = ['localhost:22', '127.0.0.1:55', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - password='ubuntu', - connect=False) + hosts = ["localhost:22", "127.0.0.1:55", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", password="ubuntu", connect=False + ) client.connect() expected_conn = { - 'allow_agent': False, - 'look_for_keys': False, - 'password': 'ubuntu', - 'username': 'ubuntu', - 'timeout': 60, - 'port': 22 + "allow_agent": False, + "look_for_keys": False, + "password": "ubuntu", + "username": "ubuntu", + "timeout": 60, + "port": 22, } for host in hosts: hostname, port = client._get_host_port_info(host) - expected_conn['hostname'] = hostname - expected_conn['port'] = port - client._hosts_client[hostname].client.connect.assert_called_once_with(**expected_conn) + expected_conn["hostname"] = hostname + expected_conn["port"] = port + client._hosts_client[hostname].client.connect.assert_called_once_with( + **expected_conn + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_connect_with_key(self): - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=False) + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=False + ) client.connect() expected_conn = { - 'allow_agent': False, - 'look_for_keys': False, - 'key_filename': '~/.ssh/id_rsa', - 'username': 'ubuntu', - 'timeout': 60, - 'port': 22 + "allow_agent": False, + "look_for_keys": False, + "key_filename": "~/.ssh/id_rsa", + "username": "ubuntu", + "timeout": 60, + "port": 22, } for host in hosts: hostname, port = client._get_host_port_info(host) - expected_conn['hostname'] = hostname - expected_conn['port'] = port - client._hosts_client[hostname].client.connect.assert_called_once_with(**expected_conn) + expected_conn["hostname"] = hostname + expected_conn["port"] = port + client._hosts_client[hostname].client.connect.assert_called_once_with( + **expected_conn + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_connect_with_bastion(self): - hosts = ['localhost', '127.0.0.1'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - bastion_host='testing_bastion_host', - connect=False) + hosts = ["localhost", "127.0.0.1"] + client = ParallelSSHClient( + hosts=hosts, + user="ubuntu", + pkey_file="~/.ssh/id_rsa", + bastion_host="testing_bastion_host", + connect=False, + ) client.connect() for host in hosts: hostname, _ = client._get_host_port_info(host) - self.assertEqual(client._hosts_client[hostname].bastion_host, 'testing_bastion_host') + self.assertEqual( + client._hosts_client[hostname].bastion_host, "testing_bastion_host" + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'run', MagicMock(return_value=('/home/ubuntu', '', 0))) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "run", MagicMock(return_value=("/home/ubuntu", "", 0)) + ) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_run_command(self): - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - client.run('pwd', timeout=60) - expected_kwargs = { - 'timeout': 60, - 'call_line_handler_func': True - } + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + client.run("pwd", timeout=60) + expected_kwargs = {"timeout": 60, "call_line_handler_func": True} for host in hosts: hostname, _ = client._get_host_port_info(host) - client._hosts_client[hostname].run.assert_called_with('pwd', **expected_kwargs) + client._hosts_client[hostname].run.assert_called_with( + "pwd", **expected_kwargs + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_run_command_timeout(self): # Make sure stdout and stderr is included on timeout - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - mock_run = Mock(side_effect=SSHCommandTimeoutError(cmd='pwd', timeout=10, - stdout='a', - stderr='b', - ssh_connect_timeout=30)) + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + mock_run = Mock( + side_effect=SSHCommandTimeoutError( + cmd="pwd", timeout=10, stdout="a", stderr="b", ssh_connect_timeout=30 + ) + ) for host in hosts: hostname, _ = client._get_host_port_info(host) host_client = client._hosts_client[host] host_client.run = mock_run - results = client.run('pwd') + results = client.run("pwd") for host in hosts: result = results[host] - self.assertEqual(result['failed'], True) - self.assertEqual(result['stdout'], 'a') - self.assertEqual(result['stderr'], 'b') - self.assertEqual(result['return_code'], -9) + self.assertEqual(result["failed"], True) + self.assertEqual(result["stdout"], "a") + self.assertEqual(result["stderr"], "b") + self.assertEqual(result["return_code"], -9) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'put', MagicMock(return_value={})) - @patch.object(os.path, 'exists', MagicMock(return_value=True)) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object(ParamikoSSHClient, "put", MagicMock(return_value={})) + @patch.object(os.path, "exists", MagicMock(return_value=True)) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_put(self): - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - client.put('/local/stuff', '/remote/stuff', mode=0o744) - expected_kwargs = { - 'mode': 0o744, - 'mirror_local_mode': False - } + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + client.put("/local/stuff", "/remote/stuff", mode=0o744) + expected_kwargs = {"mode": 0o744, "mirror_local_mode": False} for host in hosts: hostname, _ = client._get_host_port_info(host) - client._hosts_client[hostname].put.assert_called_with('/local/stuff', '/remote/stuff', - **expected_kwargs) + client._hosts_client[hostname].put.assert_called_with( + "/local/stuff", "/remote/stuff", **expected_kwargs + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'delete_file', MagicMock(return_value={})) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object(ParamikoSSHClient, "delete_file", MagicMock(return_value={})) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_delete_file(self): - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - client.delete_file('/remote/stuff') + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + client.delete_file("/remote/stuff") for host in hosts: hostname, _ = client._get_host_port_info(host) - client._hosts_client[hostname].delete_file.assert_called_with('/remote/stuff') + client._hosts_client[hostname].delete_file.assert_called_with( + "/remote/stuff" + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'delete_dir', MagicMock(return_value={})) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object(ParamikoSSHClient, "delete_dir", MagicMock(return_value={})) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_delete_dir(self): - hosts = ['localhost', '127.0.0.1', 'st2build001'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - client.delete_dir('/remote/stuff/', force=True) - expected_kwargs = { - 'force': True, - 'timeout': None - } + hosts = ["localhost", "127.0.0.1", "st2build001"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + client.delete_dir("/remote/stuff/", force=True) + expected_kwargs = {"force": True, "timeout": None} for host in hosts: hostname, _ = client._get_host_port_info(host) - client._hosts_client[hostname].delete_dir.assert_called_with('/remote/stuff/', - **expected_kwargs) + client._hosts_client[hostname].delete_dir.assert_called_with( + "/remote/stuff/", **expected_kwargs + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_host_port_info(self): - client = ParallelSSHClient(hosts=['dummy'], - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) + client = ParallelSSHClient( + hosts=["dummy"], user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) # No port case. Port should be 22. - host_str = '1.2.3.4' + host_str = "1.2.3.4" host, port = client._get_host_port_info(host_str) self.assertEqual(host, host_str) self.assertEqual(port, 22) # IPv6 with square brackets with port specified. - host_str = '[fec2::10]:55' + host_str = "[fec2::10]:55" host, port = client._get_host_port_info(host_str) - self.assertEqual(host, 'fec2::10') + self.assertEqual(host, "fec2::10") self.assertEqual(port, 55) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'run', MagicMock( - return_value=(json.dumps({'foo': 'bar'}), '', 0)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "run", + MagicMock(return_value=(json.dumps({"foo": "bar"}), "", 0)), + ) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), ) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) def test_run_command_json_output_transformed_to_object(self): - hosts = ['127.0.0.1'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True) - results = client.run('stuff', timeout=60) - self.assertIn('127.0.0.1', results) - self.assertDictEqual(results['127.0.0.1']['stdout'], {'foo': 'bar'}) + hosts = ["127.0.0.1"] + client = ParallelSSHClient( + hosts=hosts, user="ubuntu", pkey_file="~/.ssh/id_rsa", connect=True + ) + results = client.run("stuff", timeout=60) + self.assertIn("127.0.0.1", results) + self.assertDictEqual(results["127.0.0.1"]["stdout"], {"foo": "bar"}) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, 'run', MagicMock( - return_value=('', MOCK_STDERR_SUDO_PASSWORD_ERROR, 0)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "run", + MagicMock(return_value=("", MOCK_STDERR_SUDO_PASSWORD_ERROR, 0)), + ) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), ) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) def test_run_sudo_password_user_friendly_error(self): - hosts = ['127.0.0.1'] - client = ParallelSSHClient(hosts=hosts, - user='ubuntu', - pkey_file='~/.ssh/id_rsa', - connect=True, - sudo_password=True) - results = client.run('stuff', timeout=60) + hosts = ["127.0.0.1"] + client = ParallelSSHClient( + hosts=hosts, + user="ubuntu", + pkey_file="~/.ssh/id_rsa", + connect=True, + sudo_password=True, + ) + results = client.run("stuff", timeout=60) - expected_error = ('Failed executing command "stuff" on host "127.0.0.1" ' - 'Invalid sudo password provided or sudo is not configured for ' - 'this user (bar)') + expected_error = ( + 'Failed executing command "stuff" on host "127.0.0.1" ' + "Invalid sudo password provided or sudo is not configured for " + "this user (bar)" + ) - self.assertIn('127.0.0.1', results) - self.assertEqual(results['127.0.0.1']['succeeded'], False) - self.assertEqual(results['127.0.0.1']['failed'], True) - self.assertIn(expected_error, results['127.0.0.1']['error']) + self.assertIn("127.0.0.1", results) + self.assertEqual(results["127.0.0.1"]["succeeded"], False) + self.assertEqual(results["127.0.0.1"]["failed"], True) + self.assertIn(expected_error, results["127.0.0.1"]["error"]) diff --git a/st2actions/tests/unit/test_paramiko_remote_script_runner.py b/st2actions/tests/unit/test_paramiko_remote_script_runner.py index 1246f1cbe2d..1bf67a95036 100644 --- a/st2actions/tests/unit/test_paramiko_remote_script_runner.py +++ b/st2actions/tests/unit/test_paramiko_remote_script_runner.py @@ -21,6 +21,7 @@ # XXX: There is an import dependency. Config needs to setup # before importing remote_script_runner classes. import st2tests.config as tests_config + tests_config.parse_args() from st2common.util import jsonify @@ -35,234 +36,254 @@ from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'ParamikoScriptRunnerTestCase' -] +__all__ = ["ParamikoScriptRunnerTestCase"] -FIXTURES_PACK = 'generic' -TEST_MODELS = { - 'actions': ['a1.yaml'] -} +FIXTURES_PACK = "generic" +TEST_MODELS = {"actions": ["a1.yaml"]} -MODELS = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) -ACTION_1 = MODELS['actions']['a1.yaml'] +MODELS = FixturesLoader().load_models( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS +) +ACTION_1 = MODELS["actions"]["a1.yaml"] class ParamikoScriptRunnerTestCase(unittest2.TestCase): - @patch('st2common.runners.parallel_ssh.ParallelSSHClient', Mock) - @patch.object(jsonify, 'json_loads', MagicMock(return_value={})) - @patch.object(ParallelSSHClient, 'run', MagicMock(return_value={})) - @patch.object(ParallelSSHClient, 'connect', MagicMock(return_value={})) + @patch("st2common.runners.parallel_ssh.ParallelSSHClient", Mock) + @patch.object(jsonify, "json_loads", MagicMock(return_value={})) + @patch.object(ParallelSSHClient, "run", MagicMock(return_value={})) + @patch.object(ParallelSSHClient, "connect", MagicMock(return_value={})) def test_cwd_used_correctly(self): remote_action = ParamikoRemoteScriptAction( - 'foo-script', bson.ObjectId(), - script_local_path_abs='/home/stanley/shiz_storm.py', + "foo-script", + bson.ObjectId(), + script_local_path_abs="/home/stanley/shiz_storm.py", script_local_libs_path_abs=None, - named_args={}, positional_args=['blank space'], env_vars={}, - on_behalf_user='svetlana', user='stanley', - private_key='---SOME RSA KEY---', - remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/' + named_args={}, + positional_args=["blank space"], + env_vars={}, + on_behalf_user="svetlana", + user="stanley", + private_key="---SOME RSA KEY---", + remote_dir="/tmp", + hosts=["127.0.0.1"], + cwd="/test/cwd/", + ) + paramiko_runner = ParamikoRemoteScriptRunner("runner_1") + paramiko_runner._parallel_ssh_client = ParallelSSHClient( + ["127.0.0.1"], "stanley" ) - paramiko_runner = ParamikoRemoteScriptRunner('runner_1') - paramiko_runner._parallel_ssh_client = ParallelSSHClient(['127.0.0.1'], 'stanley') paramiko_runner._run_script_on_remote_host(remote_action) exp_cmd = "cd /test/cwd/ && /tmp/shiz_storm.py 'blank space'" - ParallelSSHClient.run.assert_called_with(exp_cmd, - timeout=None) + ParallelSSHClient.run.assert_called_with(exp_cmd, timeout=None) def test_username_invalid_private_key(self): - paramiko_runner = ParamikoRemoteScriptRunner('runner_1') + paramiko_runner = ParamikoRemoteScriptRunner("runner_1") paramiko_runner.runner_parameters = { - 'username': 'test_user', - 'hosts': '127.0.0.1', - 'private_key': 'invalid private key', + "username": "test_user", + "hosts": "127.0.0.1", + "private_key": "invalid private key", } paramiko_runner.context = {} self.assertRaises(NoHostsConnectedToException, paramiko_runner.pre_run) - @patch('st2common.runners.parallel_ssh.ParallelSSHClient', Mock) - @patch.object(ParallelSSHClient, 'run', MagicMock(return_value={})) - @patch.object(ParallelSSHClient, 'connect', MagicMock(return_value={})) + @patch("st2common.runners.parallel_ssh.ParallelSSHClient", Mock) + @patch.object(ParallelSSHClient, "run", MagicMock(return_value={})) + @patch.object(ParallelSSHClient, "connect", MagicMock(return_value={})) def test_top_level_error_is_correctly_reported(self): # Verify that a top-level error doesn't cause an exception to be thrown. # In a top-level error case, result dict doesn't contain entry per host - paramiko_runner = ParamikoRemoteScriptRunner('runner_1') + paramiko_runner = ParamikoRemoteScriptRunner("runner_1") paramiko_runner.runner_parameters = { - 'username': 'test_user', - 'hosts': '127.0.0.1' + "username": "test_user", + "hosts": "127.0.0.1", } paramiko_runner.action = ACTION_1 - paramiko_runner.liveaction_id = 'foo' - paramiko_runner.entry_point = 'foo' + paramiko_runner.liveaction_id = "foo" + paramiko_runner.entry_point = "foo" paramiko_runner.context = {} - paramiko_runner._cwd = '/tmp' - paramiko_runner._copy_artifacts = Mock(side_effect=Exception('fail!')) + paramiko_runner._cwd = "/tmp" + paramiko_runner._copy_artifacts = Mock(side_effect=Exception("fail!")) status, result, _ = paramiko_runner.run(action_parameters={}) self.assertEqual(status, LIVEACTION_STATUS_FAILED) - self.assertEqual(result['failed'], True) - self.assertEqual(result['succeeded'], False) - self.assertIn('Failed copying content to remote boxes', result['error']) + self.assertEqual(result["failed"], True) + self.assertEqual(result["succeeded"], False) + self.assertIn("Failed copying content to remote boxes", result["error"]) def test_command_construction_correct_default_parameter_values_are_used(self): runner_parameters = {} action_db_parameters = { - 'project': { - 'type': 'string', - 'default': 'st2', - 'position': 0, - }, - 'version': { - 'type': 'string', - 'position': 1, - 'required': True + "project": { + "type": "string", + "default": "st2", + "position": 0, }, - 'fork': { - 'type': 'string', - 'position': 2, - 'default': 'StackStorm', + "version": {"type": "string", "position": 1, "required": True}, + "fork": { + "type": "string", + "position": 2, + "default": "StackStorm", }, - 'branch': { - 'type': 'string', - 'position': 3, - 'default': 'master', + "branch": { + "type": "string", + "position": 3, + "default": "master", }, - 'update_changelog': { - 'type': 'boolean', - 'position': 4, - 'default': False + "update_changelog": {"type": "boolean", "position": 4, "default": False}, + "local_repo": { + "type": "string", + "position": 5, }, - 'local_repo': { - 'type': 'string', - 'position': 5, - } } context = {} - action_db = ActionDB(pack='dummy', name='action') + action_db = ActionDB(pack="dummy", name="action") - runner = ParamikoRemoteScriptRunner('id') + runner = ParamikoRemoteScriptRunner("id") runner.runner_parameters = {} runner.action = action_db # 1. All default values used live_action_db_parameters = { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'local_repo': '/tmp/repo' + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "local_repo": "/tmp/repo", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) - self.assertDictEqual(action_params, { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repo' - }) + self.assertDictEqual( + action_params, + { + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repo", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) remote_action = ParamikoRemoteScriptAction( - 'foo-script', 'id', - script_local_path_abs='/tmp/script.sh', + "foo-script", + "id", + script_local_path_abs="/tmp/script.sh", script_local_libs_path_abs=None, - named_args=named_args, positional_args=positional_args, env_vars={}, - on_behalf_user='svetlana', user='stanley', - remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/' + named_args=named_args, + positional_args=positional_args, + env_vars={}, + on_behalf_user="svetlana", + user="stanley", + remote_dir="/tmp", + hosts=["127.0.0.1"], + cwd="/test/cwd/", ) command_string = remote_action.get_full_command_string() - expected = 'cd /test/cwd/ && /tmp/script.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo' + expected = "cd /test/cwd/ && /tmp/script.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo" self.assertEqual(command_string, expected) # 2. Some default values used live_action_db_parameters = { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'update_changelog': True, - 'local_repo': '/tmp/repob' + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "update_changelog": True, + "local_repo": "/tmp/repob", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) - self.assertDictEqual(action_params, { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'branch': 'master', # default value used - 'update_changelog': True, # default value used - 'local_repo': '/tmp/repob' - }) + self.assertDictEqual( + action_params, + { + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "branch": "master", # default value used + "update_changelog": True, # default value used + "local_repo": "/tmp/repob", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) remote_action = ParamikoRemoteScriptAction( - 'foo-script', 'id', - script_local_path_abs='/tmp/script.sh', + "foo-script", + "id", + script_local_path_abs="/tmp/script.sh", script_local_libs_path_abs=None, - named_args=named_args, positional_args=positional_args, env_vars={}, - on_behalf_user='svetlana', user='stanley', - remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/' + named_args=named_args, + positional_args=positional_args, + env_vars={}, + on_behalf_user="svetlana", + user="stanley", + remote_dir="/tmp", + hosts=["127.0.0.1"], + cwd="/test/cwd/", ) command_string = remote_action.get_full_command_string() - expected = 'cd /test/cwd/ && /tmp/script.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob' + expected = "cd /test/cwd/ && /tmp/script.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob" self.assertEqual(command_string, expected) # 3. None is specified for a boolean parameter, should use a default live_action_db_parameters = { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'update_changelog': None, - 'local_repo': '/tmp/repoc' + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "update_changelog": None, + "local_repo": "/tmp/repoc", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) - self.assertDictEqual(action_params, { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repoc' - }) + self.assertDictEqual( + action_params, + { + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repoc", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) remote_action = ParamikoRemoteScriptAction( - 'foo-script', 'id', - script_local_path_abs='/tmp/script.sh', + "foo-script", + "id", + script_local_path_abs="/tmp/script.sh", script_local_libs_path_abs=None, - named_args=named_args, positional_args=positional_args, env_vars={}, - on_behalf_user='svetlana', user='stanley', - remote_dir='/tmp', hosts=['127.0.0.1'], cwd='/test/cwd/' + named_args=named_args, + positional_args=positional_args, + env_vars={}, + on_behalf_user="svetlana", + user="stanley", + remote_dir="/tmp", + hosts=["127.0.0.1"], + cwd="/test/cwd/", ) command_string = remote_action.get_full_command_string() - expected = 'cd /test/cwd/ && /tmp/script.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc' + expected = "cd /test/cwd/ && /tmp/script.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc" self.assertEqual(command_string, expected) diff --git a/st2actions/tests/unit/test_paramiko_ssh.py b/st2actions/tests/unit/test_paramiko_ssh.py index 7335f11a7eb..eadc4a477a8 100644 --- a/st2actions/tests/unit/test_paramiko_ssh.py +++ b/st2actions/tests/unit/test_paramiko_ssh.py @@ -28,363 +28,456 @@ from st2common.runners.paramiko_ssh import ParamikoSSHClient from st2tests.fixturesloader import get_resources_base_path import st2tests.config as tests_config + tests_config.parse_args() -__all__ = [ - 'ParamikoSSHClientTestCase' -] +__all__ = ["ParamikoSSHClientTestCase"] class ParamikoSSHClientTestCase(unittest2.TestCase): - - @patch('paramiko.SSHClient', Mock) + @patch("paramiko.SSHClient", Mock) def setUp(self): """ Creates the object patching the actual connection. """ - cfg.CONF.set_override(name='ssh_key_file', override=None, group='system_user') - cfg.CONF.set_override(name='use_ssh_config', override=False, group='ssh_runner') - cfg.CONF.set_override(name='ssh_connect_timeout', override=30, group='ssh_runner') - - conn_params = {'hostname': 'dummy.host.org', - 'port': 8822, - 'username': 'ubuntu', - 'key_files': '~/.ssh/ubuntu_ssh', - 'timeout': 30} + cfg.CONF.set_override(name="ssh_key_file", override=None, group="system_user") + cfg.CONF.set_override(name="use_ssh_config", override=False, group="ssh_runner") + cfg.CONF.set_override( + name="ssh_connect_timeout", override=30, group="ssh_runner" + ) + + conn_params = { + "hostname": "dummy.host.org", + "port": 8822, + "username": "ubuntu", + "key_files": "~/.ssh/ubuntu_ssh", + "timeout": 30, + } self.ssh_cli = ParamikoSSHClient(**conn_params) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) - @patch('paramiko.ProxyCommand') + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) + @patch("paramiko.ProxyCommand") def test_set_proxycommand(self, mock_ProxyCommand): """ Loads proxy commands from ssh config file """ - ssh_config_file_path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_ssh_config') - cfg.CONF.set_override(name='ssh_config_file_path', - override=ssh_config_file_path, - group='ssh_runner') - cfg.CONF.set_override(name='use_ssh_config', override=True, - group='ssh_runner') - - conn_params = {'hostname': 'dummy.host.org', 'username': 'ubuntu', 'password': 'foo'} + ssh_config_file_path = os.path.join( + get_resources_base_path(), "ssh", "dummy_ssh_config" + ) + cfg.CONF.set_override( + name="ssh_config_file_path", + override=ssh_config_file_path, + group="ssh_runner", + ) + cfg.CONF.set_override(name="use_ssh_config", override=True, group="ssh_runner") + + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "foo", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - mock_ProxyCommand.assert_called_once_with('ssh -q -W dummy.host.org:22 dummy_bastion') - - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) - @patch('paramiko.ProxyCommand') + mock_ProxyCommand.assert_called_once_with( + "ssh -q -W dummy.host.org:22 dummy_bastion" + ) + + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) + @patch("paramiko.ProxyCommand") def test_fail_set_proxycommand(self, mock_ProxyCommand): """ Loads proxy commands from ssh config file """ - ssh_config_file_path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_ssh_config_fail') - cfg.CONF.set_override(name='ssh_config_file_path', - override=ssh_config_file_path, - group='ssh_runner') - cfg.CONF.set_override(name='use_ssh_config', - override=True, group='ssh_runner') - - conn_params = {'hostname': 'dummy.host.org'} + ssh_config_file_path = os.path.join( + get_resources_base_path(), "ssh", "dummy_ssh_config_fail" + ) + cfg.CONF.set_override( + name="ssh_config_file_path", + override=ssh_config_file_path, + group="ssh_runner", + ) + cfg.CONF.set_override(name="use_ssh_config", override=True, group="ssh_runner") + + conn_params = {"hostname": "dummy.host.org"} mock = ParamikoSSHClient(**conn_params) self.assertRaises(Exception, mock.connect) mock_ProxyCommand.assert_not_called() - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_with_password(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'ubuntu'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "ubuntu", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'password': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "password": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_deprecated_key_argument(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_files': 'id_rsa'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_files": "id_rsa", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'key_filename': 'id_rsa', - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "key_filename": "id_rsa", + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) def test_key_files_and_key_material_arguments_are_mutual_exclusive(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_files': 'id_rsa', - 'key_material': 'key'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_files": "id_rsa", + "key_material": "key", + } - expected_msg = ('key_files and key_material arguments are mutually exclusive. ' - 'Supply only one.') + expected_msg = ( + "key_files and key_material arguments are mutually exclusive. " + "Supply only one." + ) client = ParamikoSSHClient(**conn_params) - self.assertRaisesRegexp(ValueError, expected_msg, - client.connect) + self.assertRaisesRegexp(ValueError, expected_msg, client.connect) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_key_material_argument(self): - path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_rsa') + path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa") - with open(path, 'r') as fp: + with open(path, "r") as fp: private_key = fp.read() - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_material': private_key} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_material": private_key, + } mock = ParamikoSSHClient(**conn_params) mock.connect() pkey = paramiko.RSAKey.from_private_key(StringIO(private_key)) - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'pkey': pkey, - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "pkey": pkey, + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_key_material_argument_invalid_key(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_material': 'id_rsa'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_material": "id_rsa", + } mock = ParamikoSSHClient(**conn_params) - expected_msg = 'Invalid or unsupported key type' - self.assertRaisesRegexp(paramiko.ssh_exception.SSHException, - expected_msg, mock.connect) + expected_msg = "Invalid or unsupported key type" + self.assertRaisesRegexp( + paramiko.ssh_exception.SSHException, expected_msg, mock.connect + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=True)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True) + ) def test_passphrase_no_key_provided(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'passphrase': 'testphrase'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "passphrase": "testphrase", + } - expected_msg = 'passphrase should accompany private key material' + expected_msg = "passphrase should accompany private key material" client = ParamikoSSHClient(**conn_params) self.assertRaisesRegexp(ValueError, expected_msg, client.connect) - @patch('paramiko.SSHClient', Mock) + @patch("paramiko.SSHClient", Mock) def test_passphrase_not_provided_for_encrypted_key_file(self): - path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_rsa_passphrase') - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_files': path} + path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa_passphrase") + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_files": path, + } mock = ParamikoSSHClient(**conn_params) - self.assertRaises(paramiko.ssh_exception.PasswordRequiredException, mock.connect) - - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=True)) + self.assertRaises( + paramiko.ssh_exception.PasswordRequiredException, mock.connect + ) + + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True) + ) def test_key_with_passphrase_success(self): - path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_rsa_passphrase') + path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa_passphrase") - with open(path, 'r') as fp: + with open(path, "r") as fp: private_key = fp.read() # Key material provided - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_material': private_key, - 'passphrase': 'testphrase'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_material": private_key, + "passphrase": "testphrase", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - pkey = paramiko.RSAKey.from_private_key(StringIO(private_key), 'testphrase') - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'pkey': pkey, - 'timeout': 30, - 'port': 22} + pkey = paramiko.RSAKey.from_private_key(StringIO(private_key), "testphrase") + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "pkey": pkey, + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) # Path to private key file provided - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_files': path, - 'passphrase': 'testphrase'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_files": path, + "passphrase": "testphrase", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'key_filename': path, - 'password': 'testphrase', - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "key_filename": path, + "password": "testphrase", + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=True)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True) + ) def test_passphrase_and_no_key(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'passphrase': 'testphrase'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "passphrase": "testphrase", + } - expected_msg = 'passphrase should accompany private key material' + expected_msg = "passphrase should accompany private key material" client = ParamikoSSHClient(**conn_params) - self.assertRaisesRegexp(ValueError, expected_msg, - client.connect) + self.assertRaisesRegexp(ValueError, expected_msg, client.connect) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=True)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_is_key_file_needs_passphrase", MagicMock(return_value=True) + ) def test_incorrect_passphrase(self): - path = os.path.join(get_resources_base_path(), - 'ssh', 'dummy_rsa_passphrase') + path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa_passphrase") - with open(path, 'r') as fp: + with open(path, "r") as fp: private_key = fp.read() - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_material': private_key, - 'passphrase': 'incorrect'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_material": private_key, + "passphrase": "incorrect", + } mock = ParamikoSSHClient(**conn_params) - expected_msg = 'Invalid passphrase or invalid/unsupported key type' - self.assertRaisesRegexp(paramiko.ssh_exception.SSHException, - expected_msg, mock.connect) - - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + expected_msg = "Invalid passphrase or invalid/unsupported key type" + self.assertRaisesRegexp( + paramiko.ssh_exception.SSHException, expected_msg, mock.connect + ) + + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_key_material_contains_path_not_contents(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu'} - key_materials = [ - '~/.ssh/id_rsa', - '/tmp/id_rsa', - 'C:\\id_rsa' - ] + conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"} + key_materials = ["~/.ssh/id_rsa", "/tmp/id_rsa", "C:\\id_rsa"] - expected_msg = ('"private_key" parameter needs to contain private key data / content and ' - 'not a path') + expected_msg = ( + '"private_key" parameter needs to contain private key data / content and ' + "not a path" + ) for key_material in key_materials: conn_params = conn_params.copy() - conn_params['key_material'] = key_material + conn_params["key_material"] = key_material mock = ParamikoSSHClient(**conn_params) - self.assertRaisesRegexp(paramiko.ssh_exception.SSHException, - expected_msg, mock.connect) + self.assertRaisesRegexp( + paramiko.ssh_exception.SSHException, expected_msg, mock.connect + ) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_with_key(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'key_files': 'id_rsa'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "key_files": "id_rsa", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'key_filename': 'id_rsa', - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "key_filename": "id_rsa", + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_with_key_via_bastion(self): - conn_params = {'hostname': 'dummy.host.org', - 'bastion_host': 'bastion.host.org', - 'username': 'ubuntu', - 'key_files': 'id_rsa'} + conn_params = { + "hostname": "dummy.host.org", + "bastion_host": "bastion.host.org", + "username": "ubuntu", + "key_files": "id_rsa", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_bastion_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'bastion.host.org', - 'look_for_keys': False, - 'key_filename': 'id_rsa', - 'timeout': 30, - 'port': 22} + expected_bastion_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "bastion.host.org", + "look_for_keys": False, + "key_filename": "id_rsa", + "timeout": 30, + "port": 22, + } mock.bastion_client.connect.assert_called_once_with(**expected_bastion_conn) - expected_conn = {'username': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'key_filename': 'id_rsa', - 'timeout': 30, - 'port': 22, - 'sock': mock.bastion_socket} + expected_conn = { + "username": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "key_filename": "id_rsa", + "timeout": 30, + "port": 22, + "sock": mock.bastion_socket, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_with_password_and_key(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'ubuntu', - 'key_files': 'id_rsa'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "ubuntu", + "key_files": "id_rsa", + } mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'password': 'ubuntu', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'key_filename': 'id_rsa', - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "password": "ubuntu", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "key_filename": "id_rsa", + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_without_credentials(self): """ Initialize object with no credentials. @@ -394,44 +487,54 @@ def test_create_without_credentials(self): the final parameters at the last moment when we explicitly try to connect, all the credentials should be set to None. """ - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu'} + conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"} mock = ParamikoSSHClient(**conn_params) self.assertEqual(mock.password, None) self.assertEqual(mock.key_material, None) self.assertEqual(mock.key_files, None) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_create_without_credentials_use_default_key(self): # No credentials are provided by default stanley ssh key exists so it should use that - cfg.CONF.set_override(name='ssh_key_file', override='stanley_rsa', group='system_user') + cfg.CONF.set_override( + name="ssh_key_file", override="stanley_rsa", group="system_user" + ) - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu'} + conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"} mock = ParamikoSSHClient(**conn_params) mock.connect() - expected_conn = {'username': 'ubuntu', - 'hostname': 'dummy.host.org', - 'key_filename': 'stanley_rsa', - 'allow_agent': False, - 'look_for_keys': False, - 'timeout': 30, - 'port': 22} + expected_conn = { + "username": "ubuntu", + "hostname": "dummy.host.org", + "key_filename": "stanley_rsa", + "allow_agent": False, + "look_for_keys": False, + "timeout": 30, + "port": 22, + } mock.client.connect.assert_called_once_with(**expected_conn) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_consume_stdout', - MagicMock(return_value=StringIO(''))) - @patch.object(ParamikoSSHClient, '_consume_stderr', - MagicMock(return_value=StringIO(''))) - @patch.object(os.path, 'exists', MagicMock(return_value=True)) - @patch.object(os, 'stat', MagicMock(return_value=None)) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_consume_stdout", MagicMock(return_value=StringIO("")) + ) + @patch.object( + ParamikoSSHClient, "_consume_stderr", MagicMock(return_value=StringIO("")) + ) + @patch.object(os.path, "exists", MagicMock(return_value=True)) + @patch.object(os, "stat", MagicMock(return_value=None)) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_basic_usage_absolute_path(self): """ Basic execution. @@ -443,13 +546,15 @@ def test_basic_usage_absolute_path(self): # Connect behavior mock.connect() mock_cli = mock.client # The actual mocked object: SSHClient - expected_conn = {'username': 'ubuntu', - 'key_filename': '~/.ssh/ubuntu_ssh', - 'allow_agent': False, - 'hostname': 'dummy.host.org', - 'look_for_keys': False, - 'timeout': 28, - 'port': 8822} + expected_conn = { + "username": "ubuntu", + "key_filename": "~/.ssh/ubuntu_ssh", + "allow_agent": False, + "hostname": "dummy.host.org", + "look_for_keys": False, + "timeout": 28, + "port": 8822, + } mock_cli.connect.assert_called_once_with(**expected_conn) mock.put(sd, sd, mirror_local_mode=False) @@ -458,21 +563,23 @@ def test_basic_usage_absolute_path(self): mock.run(sd) # Make assertions over 'run' method - mock_cli.get_transport().open_session().exec_command \ - .assert_called_once_with(sd) + mock_cli.get_transport().open_session().exec_command.assert_called_once_with(sd) mock.close() - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_delete_script(self): """ Provide a basic test with 'delete' action. """ mock = self.ssh_cli # script to execute - sd = '/root/random_script.sh' + sd = "/root/random_script.sh" mock.connect() @@ -482,91 +589,110 @@ def test_delete_script(self): mock.close() - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) - @patch.object(ParamikoSSHClient, 'exists', return_value=False) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) + @patch.object(ParamikoSSHClient, "exists", return_value=False) def test_put_dir(self, *args): mock = self.ssh_cli mock.connect() - local_dir = os.path.join(get_resources_base_path(), 'packs') - mock.put_dir(local_path=local_dir, remote_path='/tmp') + local_dir = os.path.join(get_resources_base_path(), "packs") + mock.put_dir(local_path=local_dir, remote_path="/tmp") mock_cli = mock.client # The actual mocked object: SSHClient # Assert that expected dirs are created on remote box. - calls = [call('/tmp/packs/pythonactions'), call('/tmp/packs/pythonactions/actions')] + calls = [ + call("/tmp/packs/pythonactions"), + call("/tmp/packs/pythonactions/actions"), + ] mock_cli.open_sftp().mkdir.assert_has_calls(calls, any_order=True) # Assert that expected files are copied to remote box. - local_file = os.path.join(get_resources_base_path(), - 'packs/pythonactions/actions/pascal_row.py') - remote_file = os.path.join('/tmp', 'packs/pythonactions/actions/pascal_row.py') + local_file = os.path.join( + get_resources_base_path(), "packs/pythonactions/actions/pascal_row.py" + ) + remote_file = os.path.join("/tmp", "packs/pythonactions/actions/pascal_row.py") calls = [call(local_file, remote_file)] mock_cli.open_sftp().put.assert_has_calls(calls, any_order=True) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_consume_stdout(self): # Test utf-8 decoding of ``stdout`` still works fine when reading CHUNK_SIZE splits a # multi-byte utf-8 character in the middle. We should wait to collect all bytes # and finally decode. - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu'} + conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"} mock = ParamikoSSHClient(**conn_params) mock.CHUNK_SIZE = 1 chan = Mock() chan.recv_ready.side_effect = [True, True, True, True, False] - chan.recv.side_effect = [b'\xF0', b'\x90', b'\x8D', b'\x88'] + chan.recv.side_effect = [b"\xF0", b"\x90", b"\x8D", b"\x88"] try: - b'\xF0'.decode('utf-8') - self.fail('Test fixture is not right.') + b"\xF0".decode("utf-8") + self.fail("Test fixture is not right.") except UnicodeDecodeError: pass stdout = mock._consume_stdout(chan) - self.assertEqual(u'\U00010348', stdout.getvalue()) + self.assertEqual("\U00010348", stdout.getvalue()) - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_consume_stderr(self): # Test utf-8 decoding of ``stderr`` still works fine when reading CHUNK_SIZE splits a # multi-byte utf-8 character in the middle. We should wait to collect all bytes # and finally decode. - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu'} + conn_params = {"hostname": "dummy.host.org", "username": "ubuntu"} mock = ParamikoSSHClient(**conn_params) mock.CHUNK_SIZE = 1 chan = Mock() chan.recv_stderr_ready.side_effect = [True, True, True, True, False] - chan.recv_stderr.side_effect = [b'\xF0', b'\x90', b'\x8D', b'\x88'] + chan.recv_stderr.side_effect = [b"\xF0", b"\x90", b"\x8D", b"\x88"] try: - b'\xF0'.decode('utf-8') - self.fail('Test fixture is not right.') + b"\xF0".decode("utf-8") + self.fail("Test fixture is not right.") except UnicodeDecodeError: pass stderr = mock._consume_stderr(chan) - self.assertEqual(u'\U00010348', stderr.getvalue()) - - @patch('paramiko.SSHClient', Mock) - @patch.object(ParamikoSSHClient, '_consume_stdout', - MagicMock(return_value=StringIO(''))) - @patch.object(ParamikoSSHClient, '_consume_stderr', - MagicMock(return_value=StringIO(''))) - @patch.object(os.path, 'exists', MagicMock(return_value=True)) - @patch.object(os, 'stat', MagicMock(return_value=None)) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + self.assertEqual("\U00010348", stderr.getvalue()) + + @patch("paramiko.SSHClient", Mock) + @patch.object( + ParamikoSSHClient, "_consume_stdout", MagicMock(return_value=StringIO("")) + ) + @patch.object( + ParamikoSSHClient, "_consume_stderr", MagicMock(return_value=StringIO("")) + ) + @patch.object(os.path, "exists", MagicMock(return_value=True)) + @patch.object(os, "stat", MagicMock(return_value=None)) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_sftp_connection_is_only_established_if_required(self): # Verify that SFTP connection is lazily established only if and when needed. - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', 'password': 'ubuntu'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "ubuntu", + } # Verify sftp connection and client hasn't been established yet client = ParamikoSSHClient(**conn_params) @@ -577,7 +703,7 @@ def test_sftp_connection_is_only_established_if_required(self): # run method doesn't require sftp access so it shouldn't establish connection client = ParamikoSSHClient(**conn_params) client.connect() - client.run(cmd='whoami') + client.run(cmd="whoami") self.assertIsNone(client.sftp_client) @@ -585,7 +711,7 @@ def test_sftp_connection_is_only_established_if_required(self): # put client = ParamikoSSHClient(**conn_params) client.connect() - path = '/root/random_script.sh' + path = "/root/random_script.sh" client.put(path, path, mirror_local_mode=False) self.assertIsNotNone(client.sftp_client) @@ -593,14 +719,14 @@ def test_sftp_connection_is_only_established_if_required(self): # exists client = ParamikoSSHClient(**conn_params) client.connect() - client.exists('/root/somepath.txt') + client.exists("/root/somepath.txt") self.assertIsNotNone(client.sftp_client) # mkdir client = ParamikoSSHClient(**conn_params) client.connect() - client.mkdir('/root/somedirfoo') + client.mkdir("/root/somedirfoo") self.assertIsNotNone(client.sftp_client) @@ -614,26 +740,26 @@ def test_sftp_connection_is_only_established_if_required(self): # Verify SFTP connection is closed if it's opened client = ParamikoSSHClient(**conn_params) client.connect() - client.mkdir('/root/somedirfoo') + client.mkdir("/root/somedirfoo") self.assertIsNotNone(client.sftp_client) client.close() self.assertEqual(client.sftp_client.close.call_count, 1) - @patch('paramiko.SSHClient', Mock) - @patch.object(os.path, 'exists', MagicMock(return_value=True)) - @patch.object(os, 'stat', MagicMock(return_value=None)) + @patch("paramiko.SSHClient", Mock) + @patch.object(os.path, "exists", MagicMock(return_value=True)) + @patch.object(os, "stat", MagicMock(return_value=None)) def test_handle_stdout_and_stderr_line_funcs(self): mock_handle_stdout_line_func = mock.Mock() mock_handle_stderr_line_func = mock.Mock() conn_params = { - 'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'ubuntu', - 'handle_stdout_line_func': mock_handle_stdout_line_func, - 'handle_stderr_line_func': mock_handle_stderr_line_func + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "ubuntu", + "handle_stdout_line_func": mock_handle_stdout_line_func, + "handle_stderr_line_func": mock_handle_stderr_line_func, } client = ParamikoSSHClient(**conn_params) client.connect() @@ -654,6 +780,7 @@ def mock_recv_ready(): return True return False + return mock_recv_ready def mock_recv_stderr_ready_factory(chan): @@ -665,12 +792,13 @@ def mock_recv_stderr_ready(): return True return False + return mock_recv_stderr_ready mock_chan.recv_ready = mock_recv_ready_factory(mock_chan) mock_chan.recv_stderr_ready = mock_recv_stderr_ready_factory(mock_chan) - mock_chan.recv.return_value = 'stdout 1\nstdout 2\nstdout 3' - mock_chan.recv_stderr.return_value = 'stderr 1\nstderr 2\nstderr 3' + mock_chan.recv.return_value = "stdout 1\nstdout 2\nstdout 3" + mock_chan.recv_stderr.return_value = "stderr 1\nstderr 2\nstderr 3" # call_line_handler_func is False so handler functions shouldn't be called client.run(cmd='echo "test"', call_line_handler_func=False) @@ -686,132 +814,176 @@ def mock_recv_stderr_ready(): client.run(cmd='echo "test"', call_line_handler_func=True) self.assertEqual(mock_handle_stdout_line_func.call_count, 3) - self.assertEqual(mock_handle_stdout_line_func.call_args_list[0][1]['line'], 'stdout 1\n') - self.assertEqual(mock_handle_stdout_line_func.call_args_list[1][1]['line'], 'stdout 2\n') - self.assertEqual(mock_handle_stdout_line_func.call_args_list[2][1]['line'], 'stdout 3\n') + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[0][1]["line"], "stdout 1\n" + ) + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[1][1]["line"], "stdout 2\n" + ) + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[2][1]["line"], "stdout 3\n" + ) self.assertEqual(mock_handle_stderr_line_func.call_count, 3) - self.assertEqual(mock_handle_stdout_line_func.call_args_list[0][1]['line'], 'stdout 1\n') - self.assertEqual(mock_handle_stdout_line_func.call_args_list[1][1]['line'], 'stdout 2\n') - self.assertEqual(mock_handle_stdout_line_func.call_args_list[2][1]['line'], 'stdout 3\n') - - @patch('paramiko.SSHClient') + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[0][1]["line"], "stdout 1\n" + ) + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[1][1]["line"], "stdout 2\n" + ) + self.assertEqual( + mock_handle_stdout_line_func.call_args_list[2][1]["line"], "stdout 3\n" + ) + + @patch("paramiko.SSHClient") def test_use_ssh_config_port_value_provided_in_the_config(self, mock_sshclient): - cfg.CONF.set_override(name='use_ssh_config', override=True, group='ssh_runner') + cfg.CONF.set_override(name="use_ssh_config", override=True, group="ssh_runner") - ssh_config_file_path = os.path.join(get_resources_base_path(), 'ssh', 'empty_config') - cfg.CONF.set_override(name='ssh_config_file_path', override=ssh_config_file_path, - group='ssh_runner') + ssh_config_file_path = os.path.join( + get_resources_base_path(), "ssh", "empty_config" + ) + cfg.CONF.set_override( + name="ssh_config_file_path", + override=ssh_config_file_path, + group="ssh_runner", + ) # 1. Default port is used (not explicitly provided) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 22) + self.assertEqual(call_kwargs["port"], 22) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'port': None, - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "port": None, + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 22) + self.assertEqual(call_kwargs["port"], 22) # 2. Default port is used (explicitly provided) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'port': DEFAULT_SSH_PORT, - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "port": DEFAULT_SSH_PORT, + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], DEFAULT_SSH_PORT) - self.assertEqual(call_kwargs['port'], 22) + self.assertEqual(call_kwargs["port"], DEFAULT_SSH_PORT) + self.assertEqual(call_kwargs["port"], 22) # 3. Custom port is used (explicitly provided) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'port': 5555, - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "port": 5555, + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 5555) + self.assertEqual(call_kwargs["port"], 5555) # 4. Custom port is specified in the ssh config (it has precedence over default port) - ssh_config_file_path = os.path.join(get_resources_base_path(), 'ssh', - 'ssh_config_custom_port') - cfg.CONF.set_override(name='ssh_config_file_path', override=ssh_config_file_path, - group='ssh_runner') + ssh_config_file_path = os.path.join( + get_resources_base_path(), "ssh", "ssh_config_custom_port" + ) + cfg.CONF.set_override( + name="ssh_config_file_path", + override=ssh_config_file_path, + group="ssh_runner", + ) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 6677) + self.assertEqual(call_kwargs["port"], 6677) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'port': DEFAULT_SSH_PORT} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "port": DEFAULT_SSH_PORT, + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 6677) + self.assertEqual(call_kwargs["port"], 6677) # 5. Custom port is specified in ssh config, but one is also provided via runner parameter # (runner parameter one has precedence) - ssh_config_file_path = os.path.join(get_resources_base_path(), 'ssh', - 'ssh_config_custom_port') - cfg.CONF.set_override(name='ssh_config_file_path', override=ssh_config_file_path, - group='ssh_runner') + ssh_config_file_path = os.path.join( + get_resources_base_path(), "ssh", "ssh_config_custom_port" + ) + cfg.CONF.set_override( + name="ssh_config_file_path", + override=ssh_config_file_path, + group="ssh_runner", + ) mock_client = mock.Mock() mock_sshclient.return_value = mock_client - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'port': 9999} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "port": 9999, + } ssh_client = ParamikoSSHClient(**conn_params) ssh_client.connect() call_kwargs = mock_client.connect.call_args[1] - self.assertEqual(call_kwargs['port'], 9999) + self.assertEqual(call_kwargs["port"], 9999) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_socket_closed(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) # Make sure .close() doesn't actually call anything real @@ -840,13 +1012,18 @@ def test_socket_closed(self): self.assertEqual(ssh_client.bastion_socket.close.call_count, 1) self.assertEqual(ssh_client.bastion_client.close.call_count, 1) - @patch.object(ParamikoSSHClient, '_is_key_file_needs_passphrase', - MagicMock(return_value=False)) + @patch.object( + ParamikoSSHClient, + "_is_key_file_needs_passphrase", + MagicMock(return_value=False), + ) def test_socket_not_closed_if_none(self): - conn_params = {'hostname': 'dummy.host.org', - 'username': 'ubuntu', - 'password': 'pass', - 'timeout': '600'} + conn_params = { + "hostname": "dummy.host.org", + "username": "ubuntu", + "password": "pass", + "timeout": "600", + } ssh_client = ParamikoSSHClient(**conn_params) # Make sure .close() doesn't actually call anything real diff --git a/st2actions/tests/unit/test_paramiko_ssh_runner.py b/st2actions/tests/unit/test_paramiko_ssh_runner.py index 8467264c5b7..f42746d6026 100644 --- a/st2actions/tests/unit/test_paramiko_ssh_runner.py +++ b/st2actions/tests/unit/test_paramiko_ssh_runner.py @@ -29,6 +29,7 @@ import st2tests.config as tests_config from st2tests.fixturesloader import get_resources_base_path + tests_config.parse_args() @@ -38,195 +39,192 @@ def run(self): class ParamikoSSHRunnerTestCase(unittest2.TestCase): - @mock.patch('st2common.runners.paramiko_ssh_runner.ParallelSSHClient') + @mock.patch("st2common.runners.paramiko_ssh_runner.ParallelSSHClient") def test_pre_run(self, mock_client): # Test case which verifies that ParamikoSSHClient is instantiated with the correct arguments - private_key_path = os.path.join(get_resources_base_path(), 'ssh', 'dummy_rsa') + private_key_path = os.path.join(get_resources_base_path(), "ssh", "dummy_rsa") - with open(private_key_path, 'r') as fp: + with open(private_key_path, "r") as fp: private_key = fp.read() # Username and password provided - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost', - RUNNER_USERNAME: 'someuser1', - RUNNER_PASSWORD: 'somepassword' + RUNNER_HOSTS: "localhost", + RUNNER_USERNAME: "someuser1", + RUNNER_PASSWORD: "somepassword", } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost'], - 'user': 'someuser1', - 'password': 'somepassword', - 'port': None, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost"], + "user": "someuser1", + "password": "somepassword", + "port": None, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) # Private key provided as raw key material - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost', - RUNNER_USERNAME: 'someuser2', + RUNNER_HOSTS: "localhost", + RUNNER_USERNAME: "someuser2", RUNNER_PRIVATE_KEY: private_key, - RUNNER_SSH_PORT: 22 + RUNNER_SSH_PORT: 22, } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost'], - 'user': 'someuser2', - 'pkey_material': private_key, - 'port': 22, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost"], + "user": "someuser2", + "pkey_material": private_key, + "port": 22, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) # Private key provided as raw key material + passphrase - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost21', - RUNNER_USERNAME: 'someuser21', + RUNNER_HOSTS: "localhost21", + RUNNER_USERNAME: "someuser21", RUNNER_PRIVATE_KEY: private_key, - RUNNER_PASSPHRASE: 'passphrase21', - RUNNER_SSH_PORT: 22 + RUNNER_PASSPHRASE: "passphrase21", + RUNNER_SSH_PORT: 22, } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost21'], - 'user': 'someuser21', - 'pkey_material': private_key, - 'passphrase': 'passphrase21', - 'port': 22, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost21"], + "user": "someuser21", + "pkey_material": private_key, + "passphrase": "passphrase21", + "port": 22, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) # Private key provided as path to the private key file - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost', - RUNNER_USERNAME: 'someuser3', + RUNNER_HOSTS: "localhost", + RUNNER_USERNAME: "someuser3", RUNNER_PRIVATE_KEY: private_key_path, - RUNNER_SSH_PORT: 22 + RUNNER_SSH_PORT: 22, } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost'], - 'user': 'someuser3', - 'pkey_file': private_key_path, - 'port': 22, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost"], + "user": "someuser3", + "pkey_file": private_key_path, + "port": 22, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) # Private key provided as path to the private key file + passphrase - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost31', - RUNNER_USERNAME: 'someuser31', + RUNNER_HOSTS: "localhost31", + RUNNER_USERNAME: "someuser31", RUNNER_PRIVATE_KEY: private_key_path, - RUNNER_PASSPHRASE: 'passphrase31', - RUNNER_SSH_PORT: 22 + RUNNER_PASSPHRASE: "passphrase31", + RUNNER_SSH_PORT: 22, } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost31'], - 'user': 'someuser31', - 'pkey_file': private_key_path, - 'passphrase': 'passphrase31', - 'port': 22, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost31"], + "user": "someuser31", + "pkey_file": private_key_path, + "passphrase": "passphrase31", + "port": 22, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) # No password or private key provided, should default to system user private key - runner = Runner('id') + runner = Runner("id") runner.context = {} - runner_parameters = { - RUNNER_HOSTS: 'localhost4', - RUNNER_SSH_PORT: 22 - } + runner_parameters = {RUNNER_HOSTS: "localhost4", RUNNER_SSH_PORT: 22} runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost4'], - 'user': None, - 'pkey_file': None, - 'port': 22, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost4"], + "user": None, + "pkey_file": None, + "port": 22, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) - @mock.patch('st2common.runners.paramiko_ssh_runner.ParallelSSHClient') + @mock.patch("st2common.runners.paramiko_ssh_runner.ParallelSSHClient") def test_post_run(self, mock_client): # Verify that the SSH connections are closed on post_run - runner = Runner('id') + runner = Runner("id") runner.context = {} runner_parameters = { - RUNNER_HOSTS: 'localhost', - RUNNER_USERNAME: 'someuser1', - RUNNER_PASSWORD: 'somepassword' + RUNNER_HOSTS: "localhost", + RUNNER_USERNAME: "someuser1", + RUNNER_PASSWORD: "somepassword", } runner.runner_parameters = runner_parameters runner.pre_run() expected_kwargs = { - 'hosts': ['localhost'], - 'user': 'someuser1', - 'password': 'somepassword', - 'port': None, - 'concurrency': 1, - 'bastion_host': None, - 'raise_on_any_error': False, - 'connect': True, - 'handle_stdout_line_func': mock.ANY, - 'handle_stderr_line_func': mock.ANY + "hosts": ["localhost"], + "user": "someuser1", + "password": "somepassword", + "port": None, + "concurrency": 1, + "bastion_host": None, + "raise_on_any_error": False, + "connect": True, + "handle_stdout_line_func": mock.ANY, + "handle_stderr_line_func": mock.ANY, } mock_client.assert_called_with(**expected_kwargs) self.assertEqual(runner._parallel_ssh_client.close.call_count, 0) diff --git a/st2actions/tests/unit/test_policies.py b/st2actions/tests/unit/test_policies.py index 4be7af59e5e..f16ffbcb0b5 100644 --- a/st2actions/tests/unit/test_policies.py +++ b/st2actions/tests/unit/test_policies.py @@ -37,37 +37,34 @@ TEST_FIXTURES = { - 'actions': [ - 'action1.yaml' - ], - 'policytypes': [ - 'fake_policy_type_1.yaml', - 'fake_policy_type_2.yaml' - ], - 'policies': [ - 'policy_1.yaml', - 'policy_2.yaml' - ] + "actions": ["action1.yaml"], + "policytypes": ["fake_policy_type_1.yaml", "fake_policy_type_2.yaml"], + "policies": ["policy_1.yaml", "policy_2.yaml"], } -PACK = 'generic' +PACK = "generic" LOADER = FixturesLoader() FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) @mock.patch.object( - CUDPublisher, 'publish_update', - mock.MagicMock(side_effect=MockExecutionPublisher.publish_update)) + CUDPublisher, + "publish_update", + mock.MagicMock(side_effect=MockExecutionPublisher.publish_update), +) +@mock.patch.object(CUDPublisher, "publish_create", mock.MagicMock(return_value=None)) @mock.patch.object( - CUDPublisher, 'publish_create', - mock.MagicMock(return_value=None)) -@mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state)) -@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) + LiveActionPublisher, + "publish_state", + mock.MagicMock(side_effect=MockLiveActionPublisher.publish_state), +) +@mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch( + "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) class SchedulingPolicyTest(ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(SchedulingPolicyTest, cls).setUpClass() @@ -75,15 +72,15 @@ def setUpClass(cls): # Register runners runners_registrar.register_runners() - for _, fixture in six.iteritems(FIXTURES['actions']): + for _, fixture in six.iteritems(FIXTURES["actions"]): instance = ActionAPI(**fixture) Action.add_or_update(ActionAPI.to_model(instance)) - for _, fixture in six.iteritems(FIXTURES['policytypes']): + for _, fixture in six.iteritems(FIXTURES["policytypes"]): instance = PolicyTypeAPI(**fixture) PolicyType.add_or_update(PolicyTypeAPI.to_model(instance)) - for _, fixture in six.iteritems(FIXTURES['policies']): + for _, fixture in six.iteritems(FIXTURES["policies"]): instance = PolicyAPI(**fixture) Policy.add_or_update(PolicyAPI.to_model(instance)) @@ -91,35 +88,54 @@ def tearDown(self): # Ensure all liveactions are canceled at end of each test. for liveaction in LiveAction.get_all(): action_service.update_status( - liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) @mock.patch.object( - FakeConcurrencyApplicator, 'apply_before', + FakeConcurrencyApplicator, + "apply_before", mock.MagicMock( - side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_before)) + side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_before + ), + ) @mock.patch.object( - RaiseExceptionApplicator, 'apply_before', - mock.MagicMock( - side_effect=RaiseExceptionApplicator(None, None).apply_before)) + RaiseExceptionApplicator, + "apply_before", + mock.MagicMock(side_effect=RaiseExceptionApplicator(None, None).apply_before), + ) @mock.patch.object( - FakeConcurrencyApplicator, 'apply_after', + FakeConcurrencyApplicator, + "apply_after", mock.MagicMock( - side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_after)) + side_effect=FakeConcurrencyApplicator(None, None, threshold=3).apply_after + ), + ) @mock.patch.object( - RaiseExceptionApplicator, 'apply_after', - mock.MagicMock( - side_effect=RaiseExceptionApplicator(None, None).apply_after)) + RaiseExceptionApplicator, + "apply_after", + mock.MagicMock(side_effect=RaiseExceptionApplicator(None, None).apply_after), + ) def test_apply(self): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_SUCCEEDED + ) FakeConcurrencyApplicator.apply_before.assert_called_once_with(liveaction) RaiseExceptionApplicator.apply_before.assert_called_once_with(liveaction) FakeConcurrencyApplicator.apply_after.assert_called_once_with(liveaction) RaiseExceptionApplicator.apply_after.assert_called_once_with(liveaction) - @mock.patch.object(FakeConcurrencyApplicator, 'get_threshold', mock.MagicMock(return_value=0)) + @mock.patch.object( + FakeConcurrencyApplicator, "get_threshold", mock.MagicMock(return_value=0) + ) def test_enforce(self): - liveaction = LiveActionDB(action='wolfpack.action-1', parameters={'actionstr': 'foo'}) + liveaction = LiveActionDB( + action="wolfpack.action-1", parameters={"actionstr": "foo"} + ) liveaction, _ = action_service.request(liveaction) - liveaction = self._wait_on_status(liveaction, action_constants.LIVEACTION_STATUS_CANCELED) + liveaction = self._wait_on_status( + liveaction, action_constants.LIVEACTION_STATUS_CANCELED + ) diff --git a/st2actions/tests/unit/test_polling_async_runner.py b/st2actions/tests/unit/test_polling_async_runner.py index 435f7eb9b68..c48bb9aa675 100644 --- a/st2actions/tests/unit/test_polling_async_runner.py +++ b/st2actions/tests/unit/test_polling_async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import PollingAsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class PollingAsyncTestRunner(PollingAsyncActionRunner): def __init__(self): - super(PollingAsyncTestRunner, self).__init__(runner_id='1') + super(PollingAsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2actions/tests/unit/test_queue_consumers.py b/st2actions/tests/unit/test_queue_consumers.py index 1550a82e224..80d3a09c260 100644 --- a/st2actions/tests/unit/test_queue_consumers.py +++ b/st2actions/tests/unit/test_queue_consumers.py @@ -18,6 +18,7 @@ import st2tests import st2tests.config as tests_config + tests_config.parse_args() import mock @@ -39,16 +40,13 @@ from st2tests.base import ExecutionDbTestCase -PACKS = [ - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' -] +PACKS = [st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core"] -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) -@mock.patch.object(executions, 'update_execution', mock.MagicMock()) -@mock.patch.object(Message, 'ack', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) +@mock.patch.object(executions, "update_execution", mock.MagicMock()) +@mock.patch.object(Message, "ack", mock.MagicMock()) class QueueConsumerTest(ExecutionDbTestCase): - @classmethod def setUpClass(cls): super(QueueConsumerTest, cls).setUpClass() @@ -58,8 +56,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -71,14 +68,16 @@ def __init__(self, *args, **kwargs): self.scheduling_queue = scheduling_queue.get_handler() self.dispatcher = worker.get_worker() - def _create_liveaction_db(self, status=action_constants.LIVEACTION_STATUS_REQUESTED): - action_db = action_utils.get_action_by_ref('core.noop') + def _create_liveaction_db( + self, status=action_constants.LIVEACTION_STATUS_REQUESTED + ): + action_db = action_utils.get_action_by_ref("core.noop") liveaction_db = LiveActionDB( action=action_db.ref, parameters=None, start_timestamp=date_utils.get_datetime_utc_now(), - status=status + status=status, ) liveaction_db = action.LiveAction.add_or_update(liveaction_db, publish=False) @@ -91,15 +90,16 @@ def _process_request(self, liveaction_db): queued_request = self.scheduling_queue._get_next_execution() self.scheduling_queue._handle_execution(queued_request) - @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(return_value={'key': 'value'})) + @mock.patch.object( + RunnerContainer, "dispatch", mock.MagicMock(return_value={"key": "value"}) + ) def test_execute(self): liveaction_db = self._create_liveaction_db() self._process_request(liveaction_db) scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) scheduled_liveaction_db = self._wait_on_status( - scheduled_liveaction_db, - action_constants.LIVEACTION_STATUS_SCHEDULED + scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED ) self.assertDictEqual(scheduled_liveaction_db.runner_info, {}) @@ -107,54 +107,56 @@ def test_execute(self): dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) self.assertGreater(len(list(dispatched_liveaction_db.runner_info.keys())), 0) self.assertEqual( - dispatched_liveaction_db.status, - action_constants.LIVEACTION_STATUS_RUNNING + dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_RUNNING ) - @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(side_effect=Exception('Boom!'))) + @mock.patch.object( + RunnerContainer, "dispatch", mock.MagicMock(side_effect=Exception("Boom!")) + ) def test_execute_failure(self): liveaction_db = self._create_liveaction_db() self._process_request(liveaction_db) scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) scheduled_liveaction_db = self._wait_on_status( - scheduled_liveaction_db, - action_constants.LIVEACTION_STATUS_SCHEDULED + scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED ) self.dispatcher._queue_consumer._process_message(scheduled_liveaction_db) dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) - self.assertEqual(dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED) + self.assertEqual( + dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED + ) - @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(return_value=None)) + @mock.patch.object(RunnerContainer, "dispatch", mock.MagicMock(return_value=None)) def test_execute_no_result(self): liveaction_db = self._create_liveaction_db() self._process_request(liveaction_db) scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) scheduled_liveaction_db = self._wait_on_status( - scheduled_liveaction_db, - action_constants.LIVEACTION_STATUS_SCHEDULED + scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED ) self.dispatcher._queue_consumer._process_message(scheduled_liveaction_db) dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) - self.assertEqual(dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED) + self.assertEqual( + dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED + ) - @mock.patch.object(RunnerContainer, 'dispatch', mock.MagicMock(return_value=None)) + @mock.patch.object(RunnerContainer, "dispatch", mock.MagicMock(return_value=None)) def test_execute_cancelation(self): liveaction_db = self._create_liveaction_db() self._process_request(liveaction_db) scheduled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) scheduled_liveaction_db = self._wait_on_status( - scheduled_liveaction_db, - action_constants.LIVEACTION_STATUS_SCHEDULED + scheduled_liveaction_db, action_constants.LIVEACTION_STATUS_SCHEDULED ) action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_CANCELED, - liveaction_id=liveaction_db.id + liveaction_id=liveaction_db.id, ) canceled_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) @@ -162,11 +164,10 @@ def test_execute_cancelation(self): dispatched_liveaction_db = action_utils.get_liveaction_by_id(liveaction_db.id) self.assertEqual( - dispatched_liveaction_db.status, - action_constants.LIVEACTION_STATUS_CANCELED + dispatched_liveaction_db.status, action_constants.LIVEACTION_STATUS_CANCELED ) self.assertDictEqual( dispatched_liveaction_db.result, - {'message': 'Action execution canceled by user.'} + {"message": "Action execution canceled by user."}, ) diff --git a/st2actions/tests/unit/test_remote_runners.py b/st2actions/tests/unit/test_remote_runners.py index 26d75cb5dc5..7f84165dbb4 100644 --- a/st2actions/tests/unit/test_remote_runners.py +++ b/st2actions/tests/unit/test_remote_runners.py @@ -16,6 +16,7 @@ # XXX: FabricRunner import depends on config being setup. from __future__ import absolute_import import st2tests.config as tests_config + tests_config.parse_args() from unittest2 import TestCase @@ -26,12 +27,20 @@ class RemoteScriptActionTestCase(TestCase): def test_parameter_formatting(self): # Only named args - named_args = {'--foo1': 'bar1', '--foo2': 'bar2', '--foo3': True, - '--foo4': False} + named_args = { + "--foo1": "bar1", + "--foo2": "bar2", + "--foo3": True, + "--foo4": False, + } - action = RemoteScriptAction(name='foo', action_exec_id='dummy', - script_local_path_abs='test.py', - script_local_libs_path_abs='/', - remote_dir='/tmp', - named_args=named_args, positional_args=None) - self.assertEqual(action.command, '/tmp/test.py --foo1=bar1 --foo2=bar2 --foo3') + action = RemoteScriptAction( + name="foo", + action_exec_id="dummy", + script_local_path_abs="test.py", + script_local_libs_path_abs="/", + remote_dir="/tmp", + named_args=named_args, + positional_args=None, + ) + self.assertEqual(action.command, "/tmp/test.py --foo1=bar1 --foo2=bar2 --foo3") diff --git a/st2actions/tests/unit/test_runner_container.py b/st2actions/tests/unit/test_runner_container.py index 3ccfb7a4ea1..f17eeceb713 100644 --- a/st2actions/tests/unit/test_runner_container.py +++ b/st2actions/tests/unit/test_runner_container.py @@ -21,7 +21,10 @@ from st2common.constants import action as action_constants from st2common.runners.base import get_runner -from st2common.exceptions.actionrunner import ActionRunnerCreateError, ActionRunnerDispatchError +from st2common.exceptions.actionrunner import ( + ActionRunnerCreateError, + ActionRunnerDispatchError, +) from st2common.models.system.common import ResourceReference from st2common.models.db.liveaction import LiveActionDB from st2common.models.db.runner import RunnerTypeDB @@ -34,6 +37,7 @@ from st2tests.base import DbTestCase import st2tests.config as tests_config + tests_config.parse_args() from st2tests.fixturesloader import FixturesLoader @@ -44,39 +48,43 @@ from st2actions.container.base import get_runner_container TEST_FIXTURES = { - 'runners': [ - 'run-local.yaml', - 'testrunner1.yaml', - 'testfailingrunner1.yaml', - 'testasyncrunner1.yaml', - 'testasyncrunner2.yaml' + "runners": [ + "run-local.yaml", + "testrunner1.yaml", + "testfailingrunner1.yaml", + "testasyncrunner1.yaml", + "testasyncrunner2.yaml", + ], + "actions": [ + "local.yaml", + "action1.yaml", + "async_action1.yaml", + "async_action2.yaml", + "action-invalid-runner.yaml", ], - 'actions': [ - 'local.yaml', - 'action1.yaml', - 'async_action1.yaml', - 'async_action2.yaml', - 'action-invalid-runner.yaml' - ] } -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" NON_UTF8_RESULT = { - 'stderr': '', - 'stdout': '\x82\n', - 'succeeded': True, - 'failed': False, - 'return_code': 0 + "stderr": "", + "stdout": "\x82\n", + "succeeded": True, + "failed": False, + "return_code": 0, } from st2tests.mocks.runners import runner from st2tests.mocks.runners import polling_async_runner -@mock.patch('st2common.runners.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch('st2actions.container.base.get_runner', mock.Mock(return_value=runner.get_runner())) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch( + "st2common.runners.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch( + "st2actions.container.base.get_runner", mock.Mock(return_value=runner.get_runner()) +) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class RunnerContainerTest(DbTestCase): action_db = None async_action_db = None @@ -88,30 +96,38 @@ class RunnerContainerTest(DbTestCase): def setUpClass(cls): super(RunnerContainerTest, cls).setUpClass() - cfg.CONF.set_override(name='validate_output_schema', override=False, group='system') + cfg.CONF.set_override( + name="validate_output_schema", override=False, group="system" + ) models = RunnerContainerTest.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) - RunnerContainerTest.runnertype_db = models['runners']['testrunner1.yaml'] - RunnerContainerTest.action_db = models['actions']['action1.yaml'] - RunnerContainerTest.local_action_db = models['actions']['local.yaml'] - RunnerContainerTest.async_action_db = models['actions']['async_action1.yaml'] - RunnerContainerTest.polling_async_action_db = models['actions']['async_action2.yaml'] - RunnerContainerTest.failingaction_db = models['actions']['action-invalid-runner.yaml'] + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + RunnerContainerTest.runnertype_db = models["runners"]["testrunner1.yaml"] + RunnerContainerTest.action_db = models["actions"]["action1.yaml"] + RunnerContainerTest.local_action_db = models["actions"]["local.yaml"] + RunnerContainerTest.async_action_db = models["actions"]["async_action1.yaml"] + RunnerContainerTest.polling_async_action_db = models["actions"][ + "async_action2.yaml" + ] + RunnerContainerTest.failingaction_db = models["actions"][ + "action-invalid-runner.yaml" + ] @classmethod def tearDownClass(cls): RunnerContainerTest.fixtures_loader.delete_fixtures_from_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) super(RunnerContainerTest, cls).tearDownClass() def test_get_runner_module(self): - runner = get_runner(name='local-shell-script') - self.assertIsNotNone(runner, 'TestRunner must be valid.') + runner = get_runner(name="local-shell-script") + self.assertIsNotNone(runner, "TestRunner must be valid.") def test_pre_run_runner_is_disabled(self): runnertype_db = RunnerContainerTest.runnertype_db - runner = get_runner(name='local-shell-cmd') + runner = get_runner(name="local-shell-cmd") runner.runner_type = runnertype_db runner.runner_type.enabled = False @@ -119,10 +135,12 @@ def test_pre_run_runner_is_disabled(self): expected_msg = 'Runner "test-runner-1" has been disabled by the administrator' self.assertRaisesRegexp(ValueError, expected_msg, runner.pre_run) - def test_created_temporary_auth_token_is_correctly_scoped_to_user_who_ran_the_action(self): + def test_created_temporary_auth_token_is_correctly_scoped_to_user_who_ran_the_action( + self, + ): params = { - 'actionstr': 'bar', - 'mock_status': action_constants.LIVEACTION_STATUS_SUCCEEDED + "actionstr": "bar", + "mock_status": action_constants.LIVEACTION_STATUS_SUCCEEDED, } global global_runner @@ -141,15 +159,17 @@ def mock_get_runner(*args, **kwargs): liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) - liveaction_db.context = {'user': 'user_joe_1'} + liveaction_db.context = {"user": "user_joe_1"} executions.create_execution_object(liveaction_db) runner_container._get_runner = mock_get_runner - self.assertEqual(getattr(global_runner, 'auth_token', None), None) + self.assertEqual(getattr(global_runner, "auth_token", None), None) runner_container.dispatch(liveaction_db) - self.assertEqual(global_runner.auth_token.user, 'user_joe_1') - self.assertEqual(global_runner.auth_token.metadata['service'], 'actions_container') + self.assertEqual(global_runner.auth_token.user, "user_joe_1") + self.assertEqual( + global_runner.auth_token.metadata["service"], "actions_container" + ) runner_container._get_runner = original_get_runner @@ -160,23 +180,25 @@ def mock_get_runner(*args, **kwargs): liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) - liveaction_db.context = {'user': 'user_mark_2'} + liveaction_db.context = {"user": "user_mark_2"} executions.create_execution_object(liveaction_db) original_get_runner = runner_container._get_runner runner_container._get_runner = mock_get_runner - self.assertEqual(getattr(global_runner, 'auth_token', None), None) + self.assertEqual(getattr(global_runner, "auth_token", None), None) runner_container.dispatch(liveaction_db) - self.assertEqual(global_runner.auth_token.user, 'user_mark_2') - self.assertEqual(global_runner.auth_token.metadata['service'], 'actions_container') + self.assertEqual(global_runner.auth_token.user, "user_mark_2") + self.assertEqual( + global_runner.auth_token.metadata["service"], "actions_container" + ) def test_post_run_is_always_called_after_run(self): # 1. post_run should be called on success, failure, etc. runner_container = get_runner_container() params = { - 'actionstr': 'bar', - 'mock_status': action_constants.LIVEACTION_STATUS_SUCCEEDED + "actionstr": "bar", + "mock_status": action_constants.LIVEACTION_STATUS_SUCCEEDED, } liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -191,6 +213,7 @@ def mock_get_runner(*args, **kwargs): runner = original_get_runner(*args, **kwargs) global_runner = runner return runner + runner_container._get_runner = mock_get_runner # Note: We can't assert here that post_run hasn't been called yet because runner instance @@ -200,10 +223,7 @@ def mock_get_runner(*args, **kwargs): # 2. Verify post_run is called if run() throws runner_container = get_runner_container() - params = { - 'actionstr': 'bar', - 'raise': True - } + params = {"actionstr": "bar", "raise": True} liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) @@ -216,6 +236,7 @@ def mock_get_runner(*args, **kwargs): runner = original_get_runner(*args, **kwargs) global_runner = runner return runner + runner_container._get_runner = mock_get_runner # Note: We can't assert here that post_run hasn't been called yet because runner instance @@ -225,10 +246,10 @@ def mock_get_runner(*args, **kwargs): # 2. Verify post_run is also called if _delete_auth_token throws runner_container = get_runner_container() - runner_container._delete_auth_token = mock.Mock(side_effect=ValueError('throw')) + runner_container._delete_auth_token = mock.Mock(side_effect=ValueError("throw")) params = { - 'actionstr': 'bar', - 'mock_status': action_constants.LIVEACTION_STATUS_SUCCEEDED + "actionstr": "bar", + "mock_status": action_constants.LIVEACTION_STATUS_SUCCEEDED, } liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -242,6 +263,7 @@ def mock_get_runner(*args, **kwargs): runner = original_get_runner(*args, **kwargs) global_runner = runner return runner + runner_container._get_runner = mock_get_runner # Note: We can't assert here that post_run hasn't been called yet because runner instance @@ -250,43 +272,42 @@ def mock_get_runner(*args, **kwargs): self.assertTrue(global_runner.post_run_called) def test_get_runner_module_fail(self): - runnertype_db = RunnerTypeDB(name='dummy', runner_module='absent.module') + runnertype_db = RunnerTypeDB(name="dummy", runner_module="absent.module") runner = None try: - runner = get_runner(runnertype_db.runner_module, runnertype_db.runner_module) + runner = get_runner( + runnertype_db.runner_module, runnertype_db.runner_module + ) except ActionRunnerCreateError: pass - self.assertFalse(runner, 'TestRunner must be valid.') + self.assertFalse(runner, "TestRunner must be valid.") def test_dispatch(self): runner_container = get_runner_container() - params = { - 'actionstr': 'bar' - } - liveaction_db = self._get_liveaction_model(RunnerContainerTest.action_db, params) + params = {"actionstr": "bar"} + liveaction_db = self._get_liveaction_model( + RunnerContainerTest.action_db, params + ) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) # Assert that execution ran successfully. runner_container.dispatch(liveaction_db) liveaction_db = LiveAction.get_by_id(liveaction_db.id) result = liveaction_db.result - self.assertTrue(result.get('action_params').get('actionint') == 10) - self.assertTrue(result.get('action_params').get('actionstr') == 'bar') + self.assertTrue(result.get("action_params").get("actionint") == 10) + self.assertTrue(result.get("action_params").get("actionstr") == "bar") # Assert that context is written correctly. - context = { - 'user': 'stanley', - 'third_party_system': { - 'ref_id': '1234' - } - } + context = {"user": "stanley", "third_party_system": {"ref_id": "1234"}} self.assertDictEqual(liveaction_db.context, context) def test_dispatch_unsupported_status(self): runner_container = get_runner_container() - params = {'actionstr': 'bar'} - liveaction_db = self._get_liveaction_model(RunnerContainerTest.action_db, params) + params = {"actionstr": "bar"} + liveaction_db = self._get_liveaction_model( + RunnerContainerTest.action_db, params + ) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) @@ -295,86 +316,74 @@ def test_dispatch_unsupported_status(self): # Assert exception is raised on dispatch. self.assertRaises( - ActionRunnerDispatchError, - runner_container.dispatch, - liveaction_db + ActionRunnerDispatchError, runner_container.dispatch, liveaction_db ) def test_dispatch_runner_failure(self): runner_container = get_runner_container() - params = { - 'actionstr': 'bar' - } + params = {"actionstr": "bar"} liveaction_db = self._get_failingaction_exec_db_model(params) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) runner_container.dispatch(liveaction_db) # pickup updated liveaction_db liveaction_db = LiveAction.get_by_id(liveaction_db.id) - self.assertIn('error', liveaction_db.result) - self.assertIn('traceback', liveaction_db.result) + self.assertIn("error", liveaction_db.result) + self.assertIn("traceback", liveaction_db.result) def test_dispatch_override_default_action_params(self): runner_container = get_runner_container() - params = { - 'actionstr': 'foo', - 'actionint': 20 - } - liveaction_db = self._get_liveaction_model(RunnerContainerTest.action_db, params) + params = {"actionstr": "foo", "actionint": 20} + liveaction_db = self._get_liveaction_model( + RunnerContainerTest.action_db, params + ) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) # Assert that execution ran successfully. runner_container.dispatch(liveaction_db) liveaction_db = LiveAction.get_by_id(liveaction_db.id) result = liveaction_db.result - self.assertTrue(result.get('action_params').get('actionint') == 20) - self.assertTrue(result.get('action_params').get('actionstr') == 'foo') + self.assertTrue(result.get("action_params").get("actionint") == 20) + self.assertTrue(result.get("action_params").get("actionstr") == "foo") def test_state_db_created_for_polling_async_actions(self): runner_container = get_runner_container() - params = { - 'actionstr': 'foo', - 'actionint': 20, - 'async_test': True - } + params = {"actionstr": "foo", "actionint": 20, "async_test": True} liveaction_db = self._get_liveaction_model( - RunnerContainerTest.polling_async_action_db, - params + RunnerContainerTest.polling_async_action_db, params ) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) # Assert that execution ran without exceptions. - with mock.patch('st2actions.container.base.get_runner', - mock.Mock(return_value=polling_async_runner.get_runner())): + with mock.patch( + "st2actions.container.base.get_runner", + mock.Mock(return_value=polling_async_runner.get_runner()), + ): runner_container.dispatch(liveaction_db) states = ActionExecutionState.get_all() found = [state for state in states if state.execution_id == liveaction_db.id] - self.assertTrue(len(found) > 0, 'There should be a state db object.') - self.assertTrue(len(found) == 1, 'There should only be one state db object.') + self.assertTrue(len(found) > 0, "There should be a state db object.") + self.assertTrue(len(found) == 1, "There should only be one state db object.") self.assertIsNotNone(found[0].query_context) self.assertIsNotNone(found[0].query_module) @mock.patch.object( PollingAsyncActionRunner, - 'is_polling_enabled', - mock.MagicMock(return_value=False)) + "is_polling_enabled", + mock.MagicMock(return_value=False), + ) def test_state_db_not_created_for_disabled_polling_async_actions(self): runner_container = get_runner_container() - params = { - 'actionstr': 'foo', - 'actionint': 20, - 'async_test': True - } + params = {"actionstr": "foo", "actionint": 20, "async_test": True} liveaction_db = self._get_liveaction_model( - RunnerContainerTest.polling_async_action_db, - params + RunnerContainerTest.polling_async_action_db, params ) liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -385,20 +394,15 @@ def test_state_db_not_created_for_disabled_polling_async_actions(self): states = ActionExecutionState.get_all() found = [state for state in states if state.execution_id == liveaction_db.id] - self.assertTrue(len(found) == 0, 'There should not be a state db object.') + self.assertTrue(len(found) == 0, "There should not be a state db object.") def test_state_db_not_created_for_async_actions(self): runner_container = get_runner_container() - params = { - 'actionstr': 'foo', - 'actionint': 20, - 'async_test': True - } + params = {"actionstr": "foo", "actionint": 20, "async_test": True} liveaction_db = self._get_liveaction_model( - RunnerContainerTest.async_action_db, - params + RunnerContainerTest.async_action_db, params ) liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -409,17 +413,21 @@ def test_state_db_not_created_for_async_actions(self): states = ActionExecutionState.get_all() found = [state for state in states if state.execution_id == liveaction_db.id] - self.assertTrue(len(found) == 0, 'There should not be a state db object.') + self.assertTrue(len(found) == 0, "There should not be a state db object.") def _get_liveaction_model(self, action_db, params): status = action_constants.LIVEACTION_STATUS_REQUESTED start_timestamp = date_utils.get_datetime_utc_now() action_ref = ResourceReference(name=action_db.name, pack=action_db.pack).ref parameters = params - context = {'user': cfg.CONF.system_user.user} - liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp, - action=action_ref, parameters=parameters, - context=context) + context = {"user": cfg.CONF.system_user.user} + liveaction_db = LiveActionDB( + status=status, + start_timestamp=start_timestamp, + action=action_ref, + parameters=parameters, + context=context, + ) return liveaction_db def _get_failingaction_exec_db_model(self, params): @@ -427,12 +435,17 @@ def _get_failingaction_exec_db_model(self, params): start_timestamp = date_utils.get_datetime_utc_now() action_ref = ResourceReference( name=RunnerContainerTest.failingaction_db.name, - pack=RunnerContainerTest.failingaction_db.pack).ref + pack=RunnerContainerTest.failingaction_db.pack, + ).ref parameters = params - context = {'user': cfg.CONF.system_user.user} - liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp, - action=action_ref, parameters=parameters, - context=context) + context = {"user": cfg.CONF.system_user.user} + liveaction_db = LiveActionDB( + status=status, + start_timestamp=start_timestamp, + action=action_ref, + parameters=parameters, + context=context, + ) return liveaction_db def _get_output_schema_exec_db_model(self, params): @@ -440,10 +453,15 @@ def _get_output_schema_exec_db_model(self, params): start_timestamp = date_utils.get_datetime_utc_now() action_ref = ResourceReference( name=RunnerContainerTest.schema_output_action_db.name, - pack=RunnerContainerTest.schema_output_action_db.pack).ref + pack=RunnerContainerTest.schema_output_action_db.pack, + ).ref parameters = params - context = {'user': cfg.CONF.system_user.user} - liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp, - action=action_ref, parameters=parameters, - context=context) + context = {"user": cfg.CONF.system_user.user} + liveaction_db = LiveActionDB( + status=status, + start_timestamp=start_timestamp, + action=action_ref, + parameters=parameters, + context=context, + ) return liveaction_db diff --git a/st2actions/tests/unit/test_scheduler.py b/st2actions/tests/unit/test_scheduler.py index c23568eea16..1a7d4b9bebb 100644 --- a/st2actions/tests/unit/test_scheduler.py +++ b/st2actions/tests/unit/test_scheduler.py @@ -20,6 +20,7 @@ import eventlet from st2tests import config as test_config + test_config.parse_args() import st2common @@ -45,31 +46,28 @@ LIVE_ACTION = { - 'parameters': { - 'cmd': 'echo ":dat_face:"', + "parameters": { + "cmd": 'echo ":dat_face:"', }, - 'action': 'core.local', - 'status': 'requested' + "action": "core.local", + "status": "requested", } -PACK = 'generic' +PACK = "generic" TEST_FIXTURES = { - 'actions': [ - 'action1.yaml', - 'action2.yaml' - ], - 'policies': [ - 'policy_3.yaml', - 'policy_7.yaml' - ] + "actions": ["action1.yaml", "action2.yaml"], + "policies": ["policy_3.yaml", "policy_7.yaml"], } @mock.patch.object( - LiveActionPublisher, 'publish_state', - mock.MagicMock(side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state)) + LiveActionPublisher, + "publish_state", + mock.MagicMock( + side_effect=MockLiveActionPublisherSchedulingQueueOnly.publish_state + ), +) class ActionExecutionSchedulingQueueItemDBTest(ExecutionDbTestCase): - @classmethod def setUpClass(cls): ExecutionDbTestCase.setUpClass() @@ -81,18 +79,21 @@ def setUpClass(cls): register_policy_types(st2common) loader = FixturesLoader() - loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) def setUp(self): super(ActionExecutionSchedulingQueueItemDBTest, self).setUp() self.scheduler = scheduling.get_scheduler_entrypoint() self.scheduling_queue = scheduling_queue.get_handler() - def _create_liveaction_db(self, status=action_constants.LIVEACTION_STATUS_REQUESTED): - action_ref = 'wolfpack.action-1' - parameters = {'actionstr': 'fu'} - liveaction_db = LiveActionDB(action=action_ref, parameters=parameters, status=status) + def _create_liveaction_db( + self, status=action_constants.LIVEACTION_STATUS_REQUESTED + ): + action_ref = "wolfpack.action-1" + parameters = {"actionstr": "fu"} + liveaction_db = LiveActionDB( + action=action_ref, parameters=parameters, status=status + ) liveaction_db = LiveAction.add_or_update(liveaction_db) execution_service.create_execution_object(liveaction_db, publish=False) @@ -108,7 +109,9 @@ def test_create_from_liveaction(self): delay, ) - delay_date = date.append_milliseconds_to_time(liveaction_db.start_timestamp, delay) + delay_date = date.append_milliseconds_to_time( + liveaction_db.start_timestamp, delay + ) self.assertIsInstance(schedule_q_db, ActionExecutionSchedulingQueueItemDB) self.assertEqual(schedule_q_db.scheduled_start_timestamp, delay_date) @@ -125,12 +128,14 @@ def test_next_execution(self): for delay in delays: liveaction_db = self._create_liveaction_db() - delayed_start = date.append_milliseconds_to_time(liveaction_db.start_timestamp, delay) + delayed_start = date.append_milliseconds_to_time( + liveaction_db.start_timestamp, delay + ) test_case = { - 'liveaction': liveaction_db, - 'delay': delay, - 'delayed_start': delayed_start + "liveaction": liveaction_db, + "delay": delay, + "delayed_start": delayed_start, } test_cases.append(test_case) @@ -139,8 +144,8 @@ def test_next_execution(self): schedule_q_dbs.append( ActionExecutionSchedulingQueue.add_or_update( self.scheduler._create_execution_queue_item_db_from_liveaction( - test_case['liveaction'], - test_case['delay'], + test_case["liveaction"], + test_case["delay"], ) ) ) @@ -152,22 +157,24 @@ def test_next_execution(self): test_case = test_cases[index] date_mock = mock.MagicMock() - date_mock.get_datetime_utc_now.return_value = test_case['delayed_start'] + date_mock.get_datetime_utc_now.return_value = test_case["delayed_start"] date_mock.append_milliseconds_to_time = date.append_milliseconds_to_time - with mock.patch('st2actions.scheduler.handler.date', date_mock): + with mock.patch("st2actions.scheduler.handler.date", date_mock): schedule_q_db = self.scheduling_queue._get_next_execution() ActionExecutionSchedulingQueue.delete(schedule_q_db) self.assertIsInstance(schedule_q_db, ActionExecutionSchedulingQueueItemDB) - self.assertEqual(schedule_q_db.delay, test_case['delay']) - self.assertEqual(schedule_q_db.liveaction_id, str(test_case['liveaction'].id)) + self.assertEqual(schedule_q_db.delay, test_case["delay"]) + self.assertEqual( + schedule_q_db.liveaction_id, str(test_case["liveaction"].id) + ) # NOTE: We can't directly assert on the timestamp due to the delays on the code and # timing variance scheduled_start_timestamp = schedule_q_db.scheduled_start_timestamp - test_case_start_timestamp = test_case['delayed_start'] - start_timestamp_diff = (scheduled_start_timestamp - test_case_start_timestamp) + test_case_start_timestamp = test_case["delayed_start"] + start_timestamp_diff = scheduled_start_timestamp - test_case_start_timestamp self.assertTrue(start_timestamp_diff <= datetime.timedelta(seconds=1)) def test_next_executions_empty(self): @@ -227,9 +234,11 @@ def test_garbage_collection(self): schedule_q_db = self.scheduling_queue._get_next_execution() self.assertIsNotNone(schedule_q_db) - @mock.patch('st2actions.scheduler.handler.action_service') - @mock.patch('st2actions.scheduler.handler.ActionExecutionSchedulingQueue.delete') - def test_processing_when_task_completed(self, mock_execution_queue_delete, mock_action_service): + @mock.patch("st2actions.scheduler.handler.action_service") + @mock.patch("st2actions.scheduler.handler.ActionExecutionSchedulingQueue.delete") + def test_processing_when_task_completed( + self, mock_execution_queue_delete, mock_action_service + ): self.reset() liveaction_db = self._create_liveaction_db() @@ -245,7 +254,7 @@ def test_processing_when_task_completed(self, mock_execution_queue_delete, mock_ mock_execution_queue_delete.assert_called_once() ActionExecutionSchedulingQueue.delete(schedule_q_db) - @mock.patch('st2actions.scheduler.handler.LOG') + @mock.patch("st2actions.scheduler.handler.LOG") def test_failed_next_item(self, mocked_logger): self.reset() @@ -258,15 +267,17 @@ def test_failed_next_item(self, mocked_logger): schedule_q_db = ActionExecutionSchedulingQueue.add_or_update(schedule_q_db) with mock.patch( - 'st2actions.scheduler.handler.ActionExecutionSchedulingQueue.add_or_update', - side_effect=db_exc.StackStormDBObjectWriteConflictError(schedule_q_db) + "st2actions.scheduler.handler.ActionExecutionSchedulingQueue.add_or_update", + side_effect=db_exc.StackStormDBObjectWriteConflictError(schedule_q_db), ): schedule_q_db = self.scheduling_queue._get_next_execution() self.assertIsNone(schedule_q_db) self.assertEqual(mocked_logger.info.call_count, 2) call_args = mocked_logger.info.call_args_list[1][0] - self.assertEqual(r'[%s] Item "%s" is already handled by another scheduler.', call_args[0]) + self.assertEqual( + r'[%s] Item "%s" is already handled by another scheduler.', call_args[0] + ) schedule_q_db = self.scheduling_queue._get_next_execution() self.assertIsNotNone(schedule_q_db) @@ -288,33 +299,39 @@ def test_cleanup_policy_delayed(self): # Manually update the liveaction to policy-delayed status. # Using action_service.update_status will throw an exception on the # deprecated action_constants.LIVEACTION_STATUS_POLICY_DELAYED. - liveaction_db.status = 'policy-delayed' + liveaction_db.status = "policy-delayed" liveaction_db = LiveAction.add_or_update(liveaction_db) execution_db = execution_service.update_execution(liveaction_db) # Check that the execution status is set to policy-delayed. liveaction_db = LiveAction.get_by_id(str(liveaction_db.id)) - self.assertEqual(liveaction_db.status, 'policy-delayed') + self.assertEqual(liveaction_db.status, "policy-delayed") execution_db = ActionExecution.get_by_id(str(execution_db.id)) - self.assertEqual(execution_db.status, 'policy-delayed') + self.assertEqual(execution_db.status, "policy-delayed") # Run the clean up logic. self.scheduling_queue._cleanup_policy_delayed() # Check that the execution status is reset to requested. liveaction_db = LiveAction.get_by_id(str(liveaction_db.id)) - self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + liveaction_db.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) execution_db = ActionExecution.get_by_id(str(execution_db.id)) - self.assertEqual(execution_db.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + execution_db.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) # The old entry should have been deleted. Since the execution is # reset to requested, there should be a new scheduling entry. new_schedule_q_db = self.scheduling_queue._get_next_execution() self.assertIsNotNone(new_schedule_q_db) self.assertNotEqual(str(schedule_q_db.id), str(new_schedule_q_db.id)) - self.assertEqual(schedule_q_db.action_execution_id, new_schedule_q_db.action_execution_id) + self.assertEqual( + schedule_q_db.action_execution_id, new_schedule_q_db.action_execution_id + ) self.assertEqual(schedule_q_db.liveaction_id, new_schedule_q_db.liveaction_id) # TODO: Remove this test case for populating action_execution_id in v3.2. diff --git a/st2actions/tests/unit/test_scheduler_entrypoint.py b/st2actions/tests/unit/test_scheduler_entrypoint.py index ddcba287e71..2bc535d99df 100644 --- a/st2actions/tests/unit/test_scheduler_entrypoint.py +++ b/st2actions/tests/unit/test_scheduler_entrypoint.py @@ -17,6 +17,7 @@ import mock from st2tests import config as test_config + test_config.parse_args() from st2actions.cmd.scheduler import _run_scheduler @@ -25,32 +26,30 @@ from st2tests.base import CleanDbTestCase -__all__ = [ - 'SchedulerServiceEntryPointTestCase' -] +__all__ = ["SchedulerServiceEntryPointTestCase"] def mock_handler_run(self): # NOTE: We use eventlet.sleep to emulate async nature of this process eventlet.sleep(0.2) - raise Exception('handler run exception') + raise Exception("handler run exception") def mock_handler_cleanup(self): # NOTE: We use eventlet.sleep to emulate async nature of this process eventlet.sleep(0.2) - raise Exception('handler clean exception') + raise Exception("handler clean exception") def mock_entrypoint_start(self): # NOTE: We use eventlet.sleep to emulate async nature of this process eventlet.sleep(0.2) - raise Exception('entrypoint start exception') + raise Exception("entrypoint start exception") class SchedulerServiceEntryPointTestCase(CleanDbTestCase): - @mock.patch.object(ActionExecutionSchedulingQueueHandler, 'run', mock_handler_run) - @mock.patch('st2actions.cmd.scheduler.LOG') + @mock.patch.object(ActionExecutionSchedulingQueueHandler, "run", mock_handler_run) + @mock.patch("st2actions.cmd.scheduler.LOG") def test_service_exits_correctly_on_fatal_exception_in_handler_run(self, mock_log): run_thread = eventlet.spawn(_run_scheduler) result = run_thread.wait() @@ -58,26 +57,32 @@ def test_service_exits_correctly_on_fatal_exception_in_handler_run(self, mock_lo self.assertEqual(result, 1) mock_log_exception_call = mock_log.exception.call_args_list[0][0][0] - self.assertIn('Scheduler unexpectedly stopped', mock_log_exception_call) - - @mock.patch.object(ActionExecutionSchedulingQueueHandler, 'cleanup', mock_handler_cleanup) - @mock.patch('st2actions.cmd.scheduler.LOG') - def test_service_exits_correctly_on_fatal_exception_in_handler_cleanup(self, mock_log): + self.assertIn("Scheduler unexpectedly stopped", mock_log_exception_call) + + @mock.patch.object( + ActionExecutionSchedulingQueueHandler, "cleanup", mock_handler_cleanup + ) + @mock.patch("st2actions.cmd.scheduler.LOG") + def test_service_exits_correctly_on_fatal_exception_in_handler_cleanup( + self, mock_log + ): run_thread = eventlet.spawn(_run_scheduler) result = run_thread.wait() self.assertEqual(result, 1) mock_log_exception_call = mock_log.exception.call_args_list[0][0][0] - self.assertIn('Scheduler unexpectedly stopped', mock_log_exception_call) + self.assertIn("Scheduler unexpectedly stopped", mock_log_exception_call) - @mock.patch.object(SchedulerEntrypoint, 'start', mock_entrypoint_start) - @mock.patch('st2actions.cmd.scheduler.LOG') - def test_service_exits_correctly_on_fatal_exception_in_entrypoint_start(self, mock_log): + @mock.patch.object(SchedulerEntrypoint, "start", mock_entrypoint_start) + @mock.patch("st2actions.cmd.scheduler.LOG") + def test_service_exits_correctly_on_fatal_exception_in_entrypoint_start( + self, mock_log + ): run_thread = eventlet.spawn(_run_scheduler) result = run_thread.wait() self.assertEqual(result, 1) mock_log_exception_call = mock_log.exception.call_args_list[0][0][0] - self.assertIn('Scheduler unexpectedly stopped', mock_log_exception_call) + self.assertIn("Scheduler unexpectedly stopped", mock_log_exception_call) diff --git a/st2actions/tests/unit/test_scheduler_retry.py b/st2actions/tests/unit/test_scheduler_retry.py index e47a2ad3eb8..ad1f221df1c 100644 --- a/st2actions/tests/unit/test_scheduler_retry.py +++ b/st2actions/tests/unit/test_scheduler_retry.py @@ -19,6 +19,7 @@ import uuid from st2tests import config as test_config + test_config.parse_args() from st2actions.scheduler import handler @@ -27,22 +28,23 @@ from st2tests.base import CleanDbTestCase -__all__ = [ - 'SchedulerHandlerRetryTestCase' -] +__all__ = ["SchedulerHandlerRetryTestCase"] -MOCK_QUEUE_ITEM = ex_q_db.ActionExecutionSchedulingQueueItemDB(liveaction_id=uuid.uuid4().hex) +MOCK_QUEUE_ITEM = ex_q_db.ActionExecutionSchedulingQueueItemDB( + liveaction_id=uuid.uuid4().hex +) class SchedulerHandlerRetryTestCase(CleanDbTestCase): - - @mock.patch.object( - handler.ActionExecutionSchedulingQueueHandler, '_get_next_execution', - mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure(), MOCK_QUEUE_ITEM])) @mock.patch.object( - eventlet.GreenPool, 'spawn', - mock.MagicMock(return_value=None)) + handler.ActionExecutionSchedulingQueueHandler, + "_get_next_execution", + mock.MagicMock( + side_effect=[pymongo.errors.ConnectionFailure(), MOCK_QUEUE_ITEM] + ), + ) + @mock.patch.object(eventlet.GreenPool, "spawn", mock.MagicMock(return_value=None)) def test_handler_retry_connection_error(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() scheduling_queue_handler.process() @@ -52,69 +54,88 @@ def test_handler_retry_connection_error(self): eventlet.GreenPool.spawn.assert_has_calls(calls) @mock.patch.object( - handler.ActionExecutionSchedulingQueueHandler, '_get_next_execution', - mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3)) - @mock.patch.object( - eventlet.GreenPool, 'spawn', - mock.MagicMock(return_value=None)) + handler.ActionExecutionSchedulingQueueHandler, + "_get_next_execution", + mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3), + ) + @mock.patch.object(eventlet.GreenPool, "spawn", mock.MagicMock(return_value=None)) def test_handler_retries_exhausted(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() - self.assertRaises(pymongo.errors.ConnectionFailure, scheduling_queue_handler.process) + self.assertRaises( + pymongo.errors.ConnectionFailure, scheduling_queue_handler.process + ) self.assertEqual(eventlet.GreenPool.spawn.call_count, 0) @mock.patch.object( - handler.ActionExecutionSchedulingQueueHandler, '_get_next_execution', - mock.MagicMock(side_effect=KeyError())) - @mock.patch.object( - eventlet.GreenPool, 'spawn', - mock.MagicMock(return_value=None)) + handler.ActionExecutionSchedulingQueueHandler, + "_get_next_execution", + mock.MagicMock(side_effect=KeyError()), + ) + @mock.patch.object(eventlet.GreenPool, "spawn", mock.MagicMock(return_value=None)) def test_handler_retry_unexpected_error(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() self.assertRaises(KeyError, scheduling_queue_handler.process) self.assertEqual(eventlet.GreenPool.spawn.call_count, 0) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'query', - mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure(), [MOCK_QUEUE_ITEM]])) + ex_q_db_access.ActionExecutionSchedulingQueue, + "query", + mock.MagicMock( + side_effect=[pymongo.errors.ConnectionFailure(), [MOCK_QUEUE_ITEM]] + ), + ) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'add_or_update', - mock.MagicMock(return_value=None)) + ex_q_db_access.ActionExecutionSchedulingQueue, + "add_or_update", + mock.MagicMock(return_value=None), + ) def test_handler_gc_retry_connection_error(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() scheduling_queue_handler._handle_garbage_collection() # Make sure retry occurs and that _handle_execution in process is called. calls = [mock.call(MOCK_QUEUE_ITEM, publish=False)] - ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.assert_has_calls(calls) + ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.assert_has_calls( + calls + ) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'query', - mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3)) + ex_q_db_access.ActionExecutionSchedulingQueue, + "query", + mock.MagicMock(side_effect=[pymongo.errors.ConnectionFailure()] * 3), + ) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'add_or_update', - mock.MagicMock(return_value=None)) + ex_q_db_access.ActionExecutionSchedulingQueue, + "add_or_update", + mock.MagicMock(return_value=None), + ) def test_handler_gc_retries_exhausted(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() self.assertRaises( pymongo.errors.ConnectionFailure, - scheduling_queue_handler._handle_garbage_collection + scheduling_queue_handler._handle_garbage_collection, ) - self.assertEqual(ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0) + self.assertEqual( + ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0 + ) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'query', - mock.MagicMock(side_effect=KeyError())) + ex_q_db_access.ActionExecutionSchedulingQueue, + "query", + mock.MagicMock(side_effect=KeyError()), + ) @mock.patch.object( - ex_q_db_access.ActionExecutionSchedulingQueue, 'add_or_update', - mock.MagicMock(return_value=None)) + ex_q_db_access.ActionExecutionSchedulingQueue, + "add_or_update", + mock.MagicMock(return_value=None), + ) def test_handler_gc_unexpected_error(self): scheduling_queue_handler = handler.ActionExecutionSchedulingQueueHandler() - self.assertRaises( - KeyError, - scheduling_queue_handler._handle_garbage_collection - ) + self.assertRaises(KeyError, scheduling_queue_handler._handle_garbage_collection) - self.assertEqual(ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0) + self.assertEqual( + ex_q_db_access.ActionExecutionSchedulingQueue.add_or_update.call_count, 0 + ) diff --git a/st2actions/tests/unit/test_worker.py b/st2actions/tests/unit/test_worker.py index 19ffd695537..d8637b9ac74 100644 --- a/st2actions/tests/unit/test_worker.py +++ b/st2actions/tests/unit/test_worker.py @@ -36,16 +36,20 @@ from st2tests.fixturesloader import FixturesLoader import st2tests.config as tests_config from six.moves import range + tests_config.parse_args() -TEST_FIXTURES = { - 'actions': ['local.yaml'] -} +TEST_FIXTURES = {"actions": ["local.yaml"]} -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" -NON_UTF8_RESULT = {'stderr': '', 'stdout': '\x82\n', 'succeeded': True, 'failed': False, - 'return_code': 0} +NON_UTF8_RESULT = { + "stderr": "", + "stdout": "\x82\n", + "succeeded": True, + "failed": False, + "return_code": 0, +} class WorkerTestCase(DbTestCase): @@ -58,28 +62,42 @@ def setUpClass(cls): runners_registrar.register_runners() models = WorkerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) - WorkerTestCase.local_action_db = models['actions']['local.yaml'] + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + WorkerTestCase.local_action_db = models["actions"]["local.yaml"] def _get_liveaction_model(self, action_db, params): status = action_constants.LIVEACTION_STATUS_REQUESTED start_timestamp = date_utils.get_datetime_utc_now() action_ref = ResourceReference(name=action_db.name, pack=action_db.pack).ref parameters = params - context = {'user': cfg.CONF.system_user.user} - liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp, - action=action_ref, parameters=parameters, - context=context) + context = {"user": cfg.CONF.system_user.user} + liveaction_db = LiveActionDB( + status=status, + start_timestamp=start_timestamp, + action=action_ref, + parameters=parameters, + context=context, + ) return liveaction_db - @mock.patch.object(LocalShellCommandRunner, 'run', mock.MagicMock( - return_value=(action_constants.LIVEACTION_STATUS_SUCCEEDED, NON_UTF8_RESULT, None))) + @mock.patch.object( + LocalShellCommandRunner, + "run", + mock.MagicMock( + return_value=( + action_constants.LIVEACTION_STATUS_SUCCEEDED, + NON_UTF8_RESULT, + None, + ) + ), + ) def test_non_utf8_action_result_string(self): action_worker = actions_worker.get_worker() - params = { - 'cmd': "python -c 'print \"\\x82\"'" - } - liveaction_db = self._get_liveaction_model(WorkerTestCase.local_action_db, params) + params = {"cmd": "python -c 'print \"\\x82\"'"} + liveaction_db = self._get_liveaction_model( + WorkerTestCase.local_action_db, params + ) liveaction_db = LiveAction.add_or_update(liveaction_db) execution_db = executions.create_execution_object(liveaction_db) @@ -87,11 +105,15 @@ def test_non_utf8_action_result_string(self): action_worker._run_action(liveaction_db) except InvalidStringData: liveaction_db = LiveAction.get_by_id(liveaction_db.id) - self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED) - self.assertIn('error', liveaction_db.result) - self.assertIn('traceback', liveaction_db.result) + self.assertEqual( + liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED + ) + self.assertIn("error", liveaction_db.result) + self.assertIn("traceback", liveaction_db.result) execution_db = ActionExecution.get_by_id(execution_db.id) - self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED) + self.assertEqual( + liveaction_db.status, action_constants.LIVEACTION_STATUS_FAILED + ) def test_worker_shutdown(self): action_worker = actions_worker.get_worker() @@ -107,8 +129,10 @@ def test_worker_shutdown(self): self.assertTrue(os.path.isfile(temp_file)) # Launch the action execution in a separate thread. - params = {'cmd': 'while [ -e \'%s\' ]; do sleep 0.1; done' % temp_file} - liveaction_db = self._get_liveaction_model(WorkerTestCase.local_action_db, params) + params = {"cmd": "while [ -e '%s' ]; do sleep 0.1; done" % temp_file} + liveaction_db = self._get_liveaction_model( + WorkerTestCase.local_action_db, params + ) liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) runner_thread = eventlet.spawn(action_worker._run_action, liveaction_db) @@ -127,8 +151,11 @@ def test_worker_shutdown(self): # Verify that _running_liveactions is empty and the liveaction is abandoned. self.assertEqual(len(action_worker._running_liveactions), 0) - self.assertEqual(liveaction_db.status, action_constants.LIVEACTION_STATUS_ABANDONED, - str(liveaction_db)) + self.assertEqual( + liveaction_db.status, + action_constants.LIVEACTION_STATUS_ABANDONED, + str(liveaction_db), + ) # Make sure the temporary file has been deleted. self.assertFalse(os.path.isfile(temp_file)) diff --git a/st2actions/tests/unit/test_workflow_engine.py b/st2actions/tests/unit/test_workflow_engine.py index 916682d5695..b8e4fae83f6 100644 --- a/st2actions/tests/unit/test_workflow_engine.py +++ b/st2actions/tests/unit/test_workflow_engine.py @@ -26,6 +26,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from st2actions.workflows import workflows @@ -46,37 +47,45 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class WorkflowExecutionHandlerTest(st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(WorkflowExecutionHandlerTest, cls).setUpClass() @@ -86,50 +95,57 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def test_process(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] workflows.get_engine().process(t1_ac_ex_db) t1_ex_db = wf_db_access.TaskExecution.get_by_id(t1_ex_db.id) self.assertEqual(t1_ex_db.status, wf_statuses.SUCCEEDED) # Process task2. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task2'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task2"} t2_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t2_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t2_ex_db.id))[0] + t2_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t2_ex_db.id) + )[0] workflows.get_engine().process(t2_ac_ex_db) t2_ex_db = wf_db_access.TaskExecution.get_by_id(t2_ex_db.id) self.assertEqual(t2_ex_db.status, wf_statuses.SUCCEEDED) # Process task3. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task3'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task3"} t3_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t3_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t3_ex_db.id))[0] + t3_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t3_ex_db.id) + )[0] workflows.get_engine().process(t3_ac_ex_db) t3_ex_db = wf_db_access.TaskExecution.get_by_id(t3_ex_db.id) self.assertEqual(t3_ex_db.status, wf_statuses.SUCCEEDED) # Assert the workflow has completed successfully with expected output. - expected_output = {'msg': 'Stanley, All your base are belong to us!'} + expected_output = {"msg": "Stanley, All your base are belong to us!"} wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) self.assertEqual(wf_ex_db.status, wf_statuses.SUCCEEDED) self.assertDictEqual(wf_ex_db.output, expected_output) @@ -137,37 +153,43 @@ def test_process(self): self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_SUCCEEDED) @mock.patch.object( - coordination_service.NoOpDriver, 'get_lock', - mock.MagicMock(side_effect=coordination.ToozConnectionError('foobar'))) + coordination_service.NoOpDriver, + "get_lock", + mock.MagicMock(side_effect=coordination.ToozConnectionError("foobar")), + ) def test_process_error_handling(self): expected_errors = [ { - 'message': 'Execution failed. See result for details.', - 'type': 'error', - 'task_id': 'task1' + "message": "Execution failed. See result for details.", + "type": "error", + "task_id": "task1", }, { - 'type': 'error', - 'message': 'ToozConnectionError: foobar', - 'task_id': 'task1', - 'route': 0 - } + "type": "error", + "message": "ToozConnectionError: foobar", + "task_id": "task1", + "route": 0, + }, ] - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] workflows.get_engine().process(t1_ac_ex_db) # Assert the task is marked as failed. @@ -182,36 +204,42 @@ def test_process_error_handling(self): self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_FAILED) @mock.patch.object( - coordination_service.NoOpDriver, 'get_lock', - mock.MagicMock(side_effect=coordination.ToozConnectionError('foobar'))) + coordination_service.NoOpDriver, + "get_lock", + mock.MagicMock(side_effect=coordination.ToozConnectionError("foobar")), + ) @mock.patch.object( workflows.WorkflowExecutionHandler, - 'fail_workflow_execution', - mock.MagicMock(side_effect=Exception('Unexpected error.'))) + "fail_workflow_execution", + mock.MagicMock(side_effect=Exception("Unexpected error.")), + ) def test_process_error_handling_has_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.request(lv_ac_db) # Assert action execution is running. lv_ac_db = lv_db_access.LiveAction.get_by_id(str(lv_ac_db.id)) self.assertEqual(lv_ac_db.status, action_constants.LIVEACTION_STATUS_RUNNING) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] self.assertEqual(wf_ex_db.status, action_constants.LIVEACTION_STATUS_RUNNING) # Process task1. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} t1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - t1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(t1_ex_db.id))[0] + t1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(t1_ex_db.id) + )[0] self.assertRaisesRegexp( - Exception, - 'Unexpected error.', - workflows.get_engine().process, - t1_ac_ex_db + Exception, "Unexpected error.", workflows.get_engine().process, t1_ac_ex_db ) - self.assertTrue(workflows.WorkflowExecutionHandler.fail_workflow_execution.called) + self.assertTrue( + workflows.WorkflowExecutionHandler.fail_workflow_execution.called + ) # Since error handling failed, the workflow will have status of running. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_db.id) diff --git a/st2api/dist_utils.py b/st2api/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/st2api/dist_utils.py +++ b/st2api/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2api/setup.py b/st2api/setup.py index 932f2e90f42..b0cfa240679 100644 --- a/st2api/setup.py +++ b/st2api/setup.py @@ -22,9 +22,9 @@ from dist_utils import apply_vagrant_workaround from st2api import __version__ -ST2_COMPONENT = 'st2api' +ST2_COMPONENT = "st2api" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -32,18 +32,18 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - scripts=[ - 'bin/st2api' - ] + packages=find_packages(exclude=["setuptools", "tests"]), + scripts=["bin/st2api"], ) diff --git a/st2api/st2api/__init__.py b/st2api/st2api/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/st2api/st2api/__init__.py +++ b/st2api/st2api/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2api/st2api/app.py b/st2api/st2api/app.py index 5b10e58c3fd..2483d0ef9ea 100644 --- a/st2api/st2api/app.py +++ b/st2api/st2api/app.py @@ -36,55 +36,60 @@ def setup_app(config=None): config = config or {} - LOG.info('Creating st2api: %s as OpenAPI app.', VERSION_STRING) + LOG.info("Creating st2api: %s as OpenAPI app.", VERSION_STRING) - is_gunicorn = config.get('is_gunicorn', False) + is_gunicorn = config.get("is_gunicorn", False) if is_gunicorn: # NOTE: We only want to perform this logic in the WSGI worker st2api_config.register_opts() capabilities = { - 'name': 'api', - 'listen_host': cfg.CONF.api.host, - 'listen_port': cfg.CONF.api.port, - 'type': 'active' + "name": "api", + "listen_host": cfg.CONF.api.host, + "listen_port": cfg.CONF.api.port, + "type": "active", } # This should be called in gunicorn case because we only want # workers to connect to db, rabbbitmq etc. In standalone HTTP # server case, this setup would have already occurred. - common_setup(service='api', config=st2api_config, setup_db=True, - register_mq_exchanges=True, - register_signal_handlers=True, - register_internal_trigger_types=True, - run_migrations=True, - service_registry=True, - capabilities=capabilities, - config_args=config.get('config_args', None)) + common_setup( + service="api", + config=st2api_config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=True, + run_migrations=True, + service_registry=True, + capabilities=capabilities, + config_args=config.get("config_args", None), + ) # Additional pre-run time checks validate_rbac_is_correctly_configured() - router = Router(debug=cfg.CONF.api.debug, auth=cfg.CONF.auth.enable, - is_gunicorn=is_gunicorn) + router = Router( + debug=cfg.CONF.api.debug, auth=cfg.CONF.auth.enable, is_gunicorn=is_gunicorn + ) - spec = spec_loader.load_spec('st2common', 'openapi.yaml.j2') + spec = spec_loader.load_spec("st2common", "openapi.yaml.j2") transforms = { - '^/api/v1/$': ['/v1'], - '^/api/v1/': ['/', '/v1/'], - '^/api/v1/executions': ['/actionexecutions', '/v1/actionexecutions'], - '^/api/exp/': ['/exp/'] + "^/api/v1/$": ["/v1"], + "^/api/v1/": ["/", "/v1/"], + "^/api/v1/executions": ["/actionexecutions", "/v1/actionexecutions"], + "^/api/exp/": ["/exp/"], } router.add_spec(spec, transforms=transforms) app = router.as_wsgi # Order is important. Check middleware for detailed explanation. - app = StreamingMiddleware(app, path_whitelist=['/v1/executions/*/output*']) + app = StreamingMiddleware(app, path_whitelist=["/v1/executions/*/output*"]) app = ErrorHandlingMiddleware(app) app = CorsMiddleware(app) app = LoggingMiddleware(app, router) - app = ResponseInstrumentationMiddleware(app, router, service_name='api') + app = ResponseInstrumentationMiddleware(app, router, service_name="api") app = RequestIDMiddleware(app) - app = RequestInstrumentationMiddleware(app, router, service_name='api') + app = RequestInstrumentationMiddleware(app, router, service_name="api") return app diff --git a/st2api/st2api/cmd/__init__.py b/st2api/st2api/cmd/__init__.py index 4e28bca4330..0b9307922ae 100644 --- a/st2api/st2api/cmd/__init__.py +++ b/st2api/st2api/cmd/__init__.py @@ -15,4 +15,4 @@ from st2api.cmd import api -__all__ = ['api'] +__all__ = ["api"] diff --git a/st2api/st2api/cmd/api.py b/st2api/st2api/cmd/api.py index 73d35204446..1cf01d0544b 100644 --- a/st2api/st2api/cmd/api.py +++ b/st2api/st2api/cmd/api.py @@ -21,6 +21,7 @@ # See https://github.com/StackStorm/st2/issues/4832 and https://github.com/gevent/gevent/issues/1016 # for details. from st2common.util.monkey_patch import monkey_patch + monkey_patch() import eventlet @@ -31,14 +32,13 @@ from st2common.service_setup import setup as common_setup from st2common.service_setup import teardown as common_teardown from st2api import config + config.register_opts() from st2api import app from st2api.validation import validate_rbac_is_correctly_configured -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) @@ -48,15 +48,22 @@ def _setup(): capabilities = { - 'name': 'api', - 'listen_host': cfg.CONF.api.host, - 'listen_port': cfg.CONF.api.port, - 'type': 'active' + "name": "api", + "listen_host": cfg.CONF.api.host, + "listen_port": cfg.CONF.api.port, + "type": "active", } - common_setup(service='api', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, register_internal_trigger_types=True, - service_registry=True, capabilities=capabilities) + common_setup( + service="api", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=True, + service_registry=True, + capabilities=capabilities, + ) # Additional pre-run time checks validate_rbac_is_correctly_configured() @@ -66,13 +73,15 @@ def _run_server(): host = cfg.CONF.api.host port = cfg.CONF.api.port - LOG.info('(PID=%s) ST2 API is serving on http://%s:%s.', os.getpid(), host, port) + LOG.info("(PID=%s) ST2 API is serving on http://%s:%s.", os.getpid(), host, port) max_pool_size = eventlet.wsgi.DEFAULT_MAX_SIMULTANEOUS_REQUESTS worker_pool = eventlet.GreenPool(max_pool_size) sock = eventlet.listen((host, port)) - wsgi.server(sock, app.setup_app(), custom_pool=worker_pool, log=LOG, log_output=False) + wsgi.server( + sock, app.setup_app(), custom_pool=worker_pool, log=LOG, log_output=False + ) return 0 @@ -87,7 +96,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except Exception: - LOG.exception('(PID=%s) ST2 API quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) ST2 API quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2api/st2api/config.py b/st2api/st2api/config.py index 71378da9ad6..35a21d87d5b 100644 --- a/st2api/st2api/config.py +++ b/st2api/st2api/config.py @@ -32,8 +32,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -52,32 +55,38 @@ def get_logging_config_path(): def _register_app_opts(): # Note "host", "port", "allow_origin", "mask_secrets" options are registered as part of # st2common config since they are also used outside st2api - static_root = os.path.join(cfg.CONF.system.base_path, 'static') - template_path = os.path.join(BASE_DIR, 'templates/') + static_root = os.path.join(cfg.CONF.system.base_path, "static") + template_path = os.path.join(BASE_DIR, "templates/") pecan_opts = [ cfg.StrOpt( - 'root', default='st2api.controllers.root.RootController', - help='Action root controller'), - cfg.StrOpt('static_root', default=static_root), - cfg.StrOpt('template_path', default=template_path), - cfg.ListOpt('modules', default=['st2api']), - cfg.BoolOpt('debug', default=False), - cfg.BoolOpt('auth_enable', default=True), - cfg.DictOpt('errors', default={'__force_dict__': True}) + "root", + default="st2api.controllers.root.RootController", + help="Action root controller", + ), + cfg.StrOpt("static_root", default=static_root), + cfg.StrOpt("template_path", default=template_path), + cfg.ListOpt("modules", default=["st2api"]), + cfg.BoolOpt("debug", default=False), + cfg.BoolOpt("auth_enable", default=True), + cfg.DictOpt("errors", default={"__force_dict__": True}), ] - CONF.register_opts(pecan_opts, group='api_pecan') + CONF.register_opts(pecan_opts, group="api_pecan") logging_opts = [ - cfg.BoolOpt('debug', default=False), + cfg.BoolOpt("debug", default=False), cfg.StrOpt( - 'logging', default='/etc/st2/logging.api.conf', - help='location of the logging.conf file'), + "logging", + default="/etc/st2/logging.api.conf", + help="location of the logging.conf file", + ), cfg.IntOpt( - 'max_page_size', default=100, - help='Maximum limit (page size) argument which can be ' - 'specified by the user in a query string.') + "max_page_size", + default=100, + help="Maximum limit (page size) argument which can be " + "specified by the user in a query string.", + ), ] - CONF.register_opts(logging_opts, group='api') + CONF.register_opts(logging_opts, group="api") diff --git a/st2api/st2api/controllers/base.py b/st2api/st2api/controllers/base.py index a3f24e2f0fe..e4f13d8f1e9 100644 --- a/st2api/st2api/controllers/base.py +++ b/st2api/st2api/controllers/base.py @@ -20,9 +20,7 @@ from st2api.controllers.controller_transforms import transform_to_bool from st2common.rbac.backends import get_rbac_backend -__all__ = [ - 'BaseRestControllerMixin' -] +__all__ = ["BaseRestControllerMixin"] class BaseRestControllerMixin(object): @@ -41,7 +39,9 @@ def _parse_query_params(self, request): return query_params - def _get_query_param_value(self, request, param_name, param_type, default_value=None): + def _get_query_param_value( + self, request, param_name, param_type, default_value=None + ): """ Return a value for the provided query param and optionally cast it for boolean types. @@ -61,7 +61,7 @@ def _get_query_param_value(self, request, param_name, param_type, default_value= query_params = self._parse_query_params(request=request) value = query_params.get(param_name, default_value) - if param_type == 'bool' and isinstance(value, six.string_types): + if param_type == "bool" and isinstance(value, six.string_types): value = transform_to_bool(value) return value diff --git a/st2api/st2api/controllers/controller_transforms.py b/st2api/st2api/controllers/controller_transforms.py index 8afff88da6b..0ca51a0a753 100644 --- a/st2api/st2api/controllers/controller_transforms.py +++ b/st2api/st2api/controllers/controller_transforms.py @@ -14,9 +14,7 @@ # limitations under the License. -__all__ = [ - 'transform_to_bool' -] +__all__ = ["transform_to_bool"] def transform_to_bool(value): @@ -27,8 +25,8 @@ def transform_to_bool(value): Any other representation will be rejected. """ - if value in ['1', 'true', 'True', True]: + if value in ["1", "true", "True", True]: return True - elif value in ['0', 'false', 'False', False]: + elif value in ["0", "false", "False", False]: return False raise ValueError('Invalid bool representation "%s" provided.' % value) diff --git a/st2api/st2api/controllers/resource.py b/st2api/st2api/controllers/resource.py index 72611a90dc8..a2391ff9aab 100644 --- a/st2api/st2api/controllers/resource.py +++ b/st2api/st2api/controllers/resource.py @@ -35,21 +35,19 @@ LOG = logging.getLogger(__name__) -RESERVED_QUERY_PARAMS = { - 'id': 'id', - 'name': 'name', - 'sort': 'order_by' -} +RESERVED_QUERY_PARAMS = {"id": "id", "name": "name", "sort": "order_by"} def split_id_value(value): if not value or isinstance(value, (list, tuple)): return value - split = value.split(',') + split = value.split(",") if len(split) > 100: - raise ValueError('Maximum of 100 items can be provided for a query parameter value') + raise ValueError( + "Maximum of 100 items can be provided for a query parameter value" + ) return split @@ -57,7 +55,7 @@ def split_id_value(value): DEFAULT_FILTER_TRANSFORM_FUNCTIONS = { # Support for filtering on multiple ids when a commona delimited string is provided # (e.g. ?id=1,2,3) - 'id': split_id_value + "id": split_id_value } @@ -65,14 +63,14 @@ def parameter_validation(validator, properties, instance, schema): parameter_specific_schema = { "description": "Input parameters for the action.", "type": "object", - "patternProperties": { - r"^\w+$": util_schema.get_action_parameters_schema() - }, - 'additionalProperties': False, - "default": {} + "patternProperties": {r"^\w+$": util_schema.get_action_parameters_schema()}, + "additionalProperties": False, + "default": {}, } - parameter_specific_validator = util_schema.CustomValidator(parameter_specific_schema) + parameter_specific_validator = util_schema.CustomValidator( + parameter_specific_schema + ) for error in parameter_specific_validator.iter_errors(instance=instance): yield error @@ -91,18 +89,16 @@ class ResourceController(object): # ?include_attributes filter. Those attributes need to be included because a lot of code # depends on compound references and primary keys. In addition to that, it's needed for secrets # masking to work, etc. - mandatory_include_fields_retrieve = ['id'] + mandatory_include_fields_retrieve = ["id"] # A list of fields which are always included in the response when ?include_attributes filter is # used. Those are things such as primary keys and similar. - mandatory_include_fields_response = ['id'] + mandatory_include_fields_response = ["id"] # Default number of items returned per page if no limit is explicitly provided default_limit = 100 - query_options = { - 'sort': [] - } + query_options = {"sort": []} # A list of optional transformation functions for user provided filter values filter_transform_functions = {} @@ -120,7 +116,9 @@ def __init__(self): self.supported_filters = copy.deepcopy(self.__class__.supported_filters) self.supported_filters.update(RESERVED_QUERY_PARAMS) - self.filter_transform_functions = copy.deepcopy(self.__class__.filter_transform_functions) + self.filter_transform_functions = copy.deepcopy( + self.__class__.filter_transform_functions + ) self.filter_transform_functions.update(DEFAULT_FILTER_TRANSFORM_FUNCTIONS) self.get_one_db_method = self._get_by_name_or_id @@ -130,9 +128,19 @@ def __init__(self): def max_limit(self): return cfg.CONF.api.max_page_size - def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=None, - sort=None, offset=0, limit=None, query_options=None, - from_model_kwargs=None, raw_filters=None, requester_user=None): + def _get_all( + self, + exclude_fields=None, + include_fields=None, + advanced_filters=None, + sort=None, + offset=0, + limit=None, + query_options=None, + from_model_kwargs=None, + raw_filters=None, + requester_user=None, + ): """ :param exclude_fields: A list of object fields to exclude. :type exclude_fields: ``list`` @@ -144,8 +152,10 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No query_options = query_options if query_options else self.query_options if exclude_fields and include_fields: - msg = ('exclude_fields and include_fields arguments are mutually exclusive. ' - 'You need to provide either one or another, but not both.') + msg = ( + "exclude_fields and include_fields arguments are mutually exclusive. " + "You need to provide either one or another, but not both." + ) raise ValueError(msg) exclude_fields = self._validate_exclude_fields(exclude_fields=exclude_fields) @@ -153,18 +163,18 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No # TODO: Why do we use comma delimited string, user can just specify # multiple values using ?sort=foo&sort=bar and we get a list back - sort = sort.split(',') if sort else [] + sort = sort.split(",") if sort else [] db_sort_values = [] for sort_key in sort: - if sort_key.startswith('-'): - direction = '-' + if sort_key.startswith("-"): + direction = "-" sort_key = sort_key[1:] - elif sort_key.startswith('+'): - direction = '+' + elif sort_key.startswith("+"): + direction = "+" sort_key = sort_key[1:] else: - direction = '' + direction = "" if sort_key not in self.supported_filters: # Skip unsupported sort key @@ -173,12 +183,12 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No sort_value = direction + self.supported_filters[sort_key] db_sort_values.append(sort_value) - default_sort_values = copy.copy(query_options.get('sort')) - raw_filters['sort'] = db_sort_values if db_sort_values else default_sort_values + default_sort_values = copy.copy(query_options.get("sort")) + raw_filters["sort"] = db_sort_values if db_sort_values else default_sort_values # TODO: To protect us from DoS, we need to make max_limit mandatory offset = int(offset) - if offset >= 2**31: + if offset >= 2 ** 31: raise ValueError('Offset "%s" specified is more than 32-bit int' % (offset)) limit = validate_limit_query_param(limit=limit, requester_user=requester_user) @@ -195,32 +205,35 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No value_transform_function = value_transform_function or (lambda value: value) filter_value = value_transform_function(value=filter_value) - if k in ['id', 'name'] and isinstance(filter_value, list): - filters[k + '__in'] = filter_value + if k in ["id", "name"] and isinstance(filter_value, list): + filters[k + "__in"] = filter_value else: - field_name_split = v.split('.') + field_name_split = v.split(".") # Make sure filter value is a list when using "in" filter - if field_name_split[-1] == 'in' and not isinstance(filter_value, (list, tuple)): + if field_name_split[-1] == "in" and not isinstance( + filter_value, (list, tuple) + ): filter_value = [filter_value] - filters['__'.join(field_name_split)] = filter_value + filters["__".join(field_name_split)] = filter_value if advanced_filters: - for token in advanced_filters.split(' '): + for token in advanced_filters.split(" "): try: - [k, v] = token.split(':', 1) + [k, v] = token.split(":", 1) except ValueError: raise ValueError('invalid format for filter "%s"' % token) - path = k.split('.') + path = k.split(".") try: self.model.model._lookup_field(path) - filters['__'.join(path)] = v + filters["__".join(path)] = v except LookUpError as e: raise ValueError(six.text_type(e)) - instances = self.access.query(exclude_fields=exclude_fields, only_fields=include_fields, - **filters) + instances = self.access.query( + exclude_fields=exclude_fields, only_fields=include_fields, **filters + ) if limit == 1: # Perform the filtering on the DB side instances = instances.limit(limit) @@ -228,44 +241,65 @@ def _get_all(self, exclude_fields=None, include_fields=None, advanced_filters=No from_model_kwargs = from_model_kwargs or {} from_model_kwargs.update(self.from_model_kwargs) - result = self.resources_model_filter(model=self.model, - instances=instances, - offset=offset, - eop=eop, - requester_user=requester_user, - **from_model_kwargs) + result = self.resources_model_filter( + model=self.model, + instances=instances, + offset=offset, + eop=eop, + requester_user=requester_user, + **from_model_kwargs, + ) resp = Response(json=result) - resp.headers['X-Total-Count'] = str(instances.count()) + resp.headers["X-Total-Count"] = str(instances.count()) if limit: - resp.headers['X-Limit'] = str(limit) + resp.headers["X-Limit"] = str(limit) return resp - def resources_model_filter(self, model, instances, requester_user=None, offset=0, eop=0, - **from_model_kwargs): + def resources_model_filter( + self, + model, + instances, + requester_user=None, + offset=0, + eop=0, + **from_model_kwargs, + ): """ Method which converts DB objects to API objects and performs any additional filtering. """ result = [] for instance in instances[offset:eop]: - item = self.resource_model_filter(model=model, instance=instance, - requester_user=requester_user, - **from_model_kwargs) + item = self.resource_model_filter( + model=model, + instance=instance, + requester_user=requester_user, + **from_model_kwargs, + ) result.append(item) return result - def resource_model_filter(self, model, instance, requester_user=None, **from_model_kwargs): + def resource_model_filter( + self, model, instance, requester_user=None, **from_model_kwargs + ): """ Method which converts DB object to API object and performs any additional filtering. """ item = model.from_model(instance, **from_model_kwargs) return item - def _get_one_by_id(self, id, requester_user, permission_type, exclude_fields=None, - include_fields=None, from_model_kwargs=None): + def _get_one_by_id( + self, + id, + requester_user, + permission_type, + exclude_fields=None, + include_fields=None, + from_model_kwargs=None, + ): """ :param exclude_fields: A list of object fields to exclude. :type exclude_fields: ``list`` @@ -273,14 +307,17 @@ def _get_one_by_id(self, id, requester_user, permission_type, exclude_fields=Non :type include_fields: ``list`` """ - instance = self._get_by_id(resource_id=id, exclude_fields=exclude_fields, - include_fields=include_fields) + instance = self._get_by_id( + resource_id=id, exclude_fields=exclude_fields, include_fields=include_fields + ) if permission_type: rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=permission_type, + ) if not instance: msg = 'Unable to identify resource with id "%s".' % id @@ -289,21 +326,35 @@ def _get_one_by_id(self, id, requester_user, permission_type, exclude_fields=Non from_model_kwargs = from_model_kwargs or {} from_model_kwargs.update(self.from_model_kwargs) - result = self.resource_model_filter(model=self.model, instance=instance, - requester_user=requester_user, - **from_model_kwargs) + result = self.resource_model_filter( + model=self.model, + instance=instance, + requester_user=requester_user, + **from_model_kwargs, + ) if not result: - LOG.debug('Not returning the result because RBAC resource isolation is enabled and ' - 'current user doesn\'t match the resource user') - raise ResourceAccessDeniedPermissionIsolationError(user_db=requester_user, - resource_api_or_db=instance, - permission_type=permission_type) + LOG.debug( + "Not returning the result because RBAC resource isolation is enabled and " + "current user doesn't match the resource user" + ) + raise ResourceAccessDeniedPermissionIsolationError( + user_db=requester_user, + resource_api_or_db=instance, + permission_type=permission_type, + ) return result - def _get_one_by_name_or_id(self, name_or_id, requester_user, permission_type, - exclude_fields=None, include_fields=None, from_model_kwargs=None): + def _get_one_by_name_or_id( + self, + name_or_id, + requester_user, + permission_type, + exclude_fields=None, + include_fields=None, + from_model_kwargs=None, + ): """ :param exclude_fields: A list of object fields to exclude. :type exclude_fields: ``list`` @@ -311,14 +362,19 @@ def _get_one_by_name_or_id(self, name_or_id, requester_user, permission_type, :type include_fields: ``list`` """ - instance = self._get_by_name_or_id(name_or_id=name_or_id, exclude_fields=exclude_fields, - include_fields=include_fields) + instance = self._get_by_name_or_id( + name_or_id=name_or_id, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) if permission_type: rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=permission_type, + ) if not instance: msg = 'Unable to identify resource with name_or_id "%s".' % (name_or_id) @@ -330,10 +386,14 @@ def _get_one_by_name_or_id(self, name_or_id, requester_user, permission_type, return result - def _get_one_by_pack_ref(self, pack_ref, exclude_fields=None, include_fields=None, - from_model_kwargs=None): - instance = self._get_by_pack_ref(pack_ref=pack_ref, exclude_fields=exclude_fields, - include_fields=include_fields) + def _get_one_by_pack_ref( + self, pack_ref, exclude_fields=None, include_fields=None, from_model_kwargs=None + ): + instance = self._get_by_pack_ref( + pack_ref=pack_ref, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) if not instance: msg = 'Unable to identify resource with pack_ref "%s".' % (pack_ref) @@ -347,8 +407,11 @@ def _get_one_by_pack_ref(self, pack_ref, exclude_fields=None, include_fields=Non def _get_by_id(self, resource_id, exclude_fields=None, include_fields=None): try: - resource_db = self.access.get(id=resource_id, exclude_fields=exclude_fields, - only_fields=include_fields) + resource_db = self.access.get( + id=resource_id, + exclude_fields=exclude_fields, + only_fields=include_fields, + ) except ValidationError: resource_db = None @@ -356,8 +419,11 @@ def _get_by_id(self, resource_id, exclude_fields=None, include_fields=None): def _get_by_name(self, resource_name, exclude_fields=None, include_fields=None): try: - resource_db = self.access.get(name=resource_name, exclude_fields=exclude_fields, - only_fields=include_fields) + resource_db = self.access.get( + name=resource_name, + exclude_fields=exclude_fields, + only_fields=include_fields, + ) except Exception: resource_db = None @@ -365,8 +431,9 @@ def _get_by_name(self, resource_name, exclude_fields=None, include_fields=None): def _get_by_pack_ref(self, pack_ref, exclude_fields=None, include_fields=None): try: - resource_db = self.access.get(pack=pack_ref, exclude_fields=exclude_fields, - only_fields=include_fields) + resource_db = self.access.get( + pack=pack_ref, exclude_fields=exclude_fields, only_fields=include_fields + ) except Exception: resource_db = None @@ -376,13 +443,17 @@ def _get_by_name_or_id(self, name_or_id, exclude_fields=None, include_fields=Non """ Retrieve resource object by an id of a name. """ - resource_db = self._get_by_id(resource_id=name_or_id, exclude_fields=exclude_fields, - include_fields=include_fields) + resource_db = self._get_by_id( + resource_id=name_or_id, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) if not resource_db: # Try name - resource_db = self._get_by_name(resource_name=name_or_id, - exclude_fields=exclude_fields) + resource_db = self._get_by_name( + resource_name=name_or_id, exclude_fields=exclude_fields + ) if not resource_db: msg = 'Resource with a name or id "%s" not found' % (name_or_id) @@ -402,11 +473,16 @@ def _get_one_by_scope_and_name(self, scope, name, from_model_kwargs=None): """ instance = self.access.get_by_scope_and_name(scope=scope, name=name) if not instance: - msg = 'KeyValuePair with name: %s and scope: %s not found in db.' % (name, scope) + msg = "KeyValuePair with name: %s and scope: %s not found in db." % ( + name, + scope, + ) raise StackStormDBObjectNotFoundError(msg) from_model_kwargs = from_model_kwargs or {} result = self.model.from_model(instance, **from_model_kwargs) - LOG.debug('GET with scope=%s and name=%s, client_result=%s', scope, name, result) + LOG.debug( + "GET with scope=%s and name=%s, client_result=%s", scope, name, result + ) return result @@ -422,7 +498,7 @@ def _validate_exclude_fields(self, exclude_fields): for field in exclude_fields: if field not in self.valid_exclude_attributes: - msg = ('Invalid or unsupported exclude attribute specified: %s' % (field)) + msg = "Invalid or unsupported exclude attribute specified: %s" % (field) raise ValueError(msg) return exclude_fields @@ -438,7 +514,7 @@ def _validate_include_fields(self, include_fields): for field in self.mandatory_include_fields_retrieve: # Don't add mandatory field if user already requested the whole dict object (e.g. user # requests action and action.parameters is a mandatory field) - partial_field = field.split('.')[0] + partial_field = field.split(".")[0] if partial_field in include_fields: continue @@ -456,20 +532,38 @@ class BaseResourceIsolationControllerMixin(object): users). """ - def resources_model_filter(self, model, instances, requester_user=None, offset=0, eop=0, - **from_model_kwargs): + def resources_model_filter( + self, + model, + instances, + requester_user=None, + offset=0, + eop=0, + **from_model_kwargs, + ): # RBAC or permission isolation is disabled, bail out if not (cfg.CONF.rbac.enable and cfg.CONF.rbac.permission_isolation): - result = super(BaseResourceIsolationControllerMixin, self).resources_model_filter( - model=model, instances=instances, requester_user=requester_user, - offset=offset, eop=eop, **from_model_kwargs) + result = super( + BaseResourceIsolationControllerMixin, self + ).resources_model_filter( + model=model, + instances=instances, + requester_user=requester_user, + offset=offset, + eop=eop, + **from_model_kwargs, + ) return result result = [] for instance in instances[offset:eop]: - item = self.resource_model_filter(model=model, instance=instance, - requester_user=requester_user, **from_model_kwargs) + item = self.resource_model_filter( + model=model, + instance=instance, + requester_user=requester_user, + **from_model_kwargs, + ) if not item: continue @@ -478,18 +572,25 @@ def resources_model_filter(self, model, instances, requester_user=None, offset=0 return result - def resource_model_filter(self, model, instance, requester_user=None, **from_model_kwargs): + def resource_model_filter( + self, model, instance, requester_user=None, **from_model_kwargs + ): # RBAC or permission isolation is disabled, bail out if not (cfg.CONF.rbac.enable and cfg.CONF.rbac.permission_isolation): - result = super(BaseResourceIsolationControllerMixin, self).resource_model_filter( - model=model, instance=instance, requester_user=requester_user, - **from_model_kwargs) + result = super( + BaseResourceIsolationControllerMixin, self + ).resource_model_filter( + model=model, + instance=instance, + requester_user=requester_user, + **from_model_kwargs, + ) return result rbac_utils = get_rbac_backend().get_utils_class() user_is_admin = rbac_utils.user_is_admin(user_db=requester_user) - user_is_system_user = (requester_user.name == cfg.CONF.system_user.user) + user_is_system_user = requester_user.name == cfg.CONF.system_user.user item = model.from_model(instance, **from_model_kwargs) @@ -497,7 +598,7 @@ def resource_model_filter(self, model, instance, requester_user=None, **from_mod if user_is_admin or user_is_system_user: return item - user = item.context.get('user', None) + user = item.context.get("user", None) if user and (user == requester_user.name): return item @@ -506,21 +607,31 @@ def resource_model_filter(self, model, instance, requester_user=None, **from_mod class ContentPackResourceController(ResourceController): # name and pack are mandatory because they compromise primary key - reference (.) - mandatory_include_fields_retrieve = ['pack', 'name'] + mandatory_include_fields_retrieve = ["pack", "name"] # A list of fields which are always included in the response. Those are things such as primary # keys and similar - mandatory_include_fields_response = ['id', 'ref'] + mandatory_include_fields_response = ["id", "ref"] def __init__(self): super(ContentPackResourceController, self).__init__() self.get_one_db_method = self._get_by_ref_or_id - def _get_one(self, ref_or_id, requester_user, permission_type, exclude_fields=None, - include_fields=None, from_model_kwargs=None): + def _get_one( + self, + ref_or_id, + requester_user, + permission_type, + exclude_fields=None, + include_fields=None, + from_model_kwargs=None, + ): try: - instance = self._get_by_ref_or_id(ref_or_id=ref_or_id, exclude_fields=exclude_fields, - include_fields=include_fields) + instance = self._get_by_ref_or_id( + ref_or_id=ref_or_id, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) except Exception as e: LOG.exception(six.text_type(e)) abort(http_client.NOT_FOUND, six.text_type(e)) @@ -528,40 +639,59 @@ def _get_one(self, ref_or_id, requester_user, permission_type, exclude_fields=No if permission_type: rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=permission_type, + ) # Perform resource isolation check (if supported) from_model_kwargs = from_model_kwargs or {} from_model_kwargs.update(self.from_model_kwargs) - result = self.resource_model_filter(model=self.model, instance=instance, - requester_user=requester_user, - **from_model_kwargs) + result = self.resource_model_filter( + model=self.model, + instance=instance, + requester_user=requester_user, + **from_model_kwargs, + ) if not result: - LOG.debug('Not returning the result because RBAC resource isolation is enabled and ' - 'current user doesn\'t match the resource user') - raise ResourceAccessDeniedPermissionIsolationError(user_db=requester_user, - resource_api_or_db=instance, - permission_type=permission_type) + LOG.debug( + "Not returning the result because RBAC resource isolation is enabled and " + "current user doesn't match the resource user" + ) + raise ResourceAccessDeniedPermissionIsolationError( + user_db=requester_user, + resource_api_or_db=instance, + permission_type=permission_type, + ) return Response(json=result) - def _get_all(self, exclude_fields=None, include_fields=None, - sort=None, offset=0, limit=None, query_options=None, - from_model_kwargs=None, raw_filters=None, requester_user=None): - resp = super(ContentPackResourceController, - self)._get_all(exclude_fields=exclude_fields, - include_fields=include_fields, - sort=sort, - offset=offset, - limit=limit, - query_options=query_options, - from_model_kwargs=from_model_kwargs, - raw_filters=raw_filters, - requester_user=requester_user) + def _get_all( + self, + exclude_fields=None, + include_fields=None, + sort=None, + offset=0, + limit=None, + query_options=None, + from_model_kwargs=None, + raw_filters=None, + requester_user=None, + ): + resp = super(ContentPackResourceController, self)._get_all( + exclude_fields=exclude_fields, + include_fields=include_fields, + sort=sort, + offset=offset, + limit=limit, + query_options=query_options, + from_model_kwargs=from_model_kwargs, + raw_filters=raw_filters, + requester_user=requester_user, + ) return resp @@ -574,8 +704,10 @@ def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None, include_fields=None) """ if exclude_fields and include_fields: - msg = ('exclude_fields and include_fields arguments are mutually exclusive. ' - 'You need to provide either one or another, but not both.') + msg = ( + "exclude_fields and include_fields arguments are mutually exclusive. " + "You need to provide either one or another, but not both." + ) raise ValueError(msg) if ResourceReference.is_resource_reference(ref_or_id): @@ -585,11 +717,17 @@ def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None, include_fields=None) is_reference = False if is_reference: - resource_db = self._get_by_ref(resource_ref=ref_or_id, exclude_fields=exclude_fields, - include_fields=include_fields) + resource_db = self._get_by_ref( + resource_ref=ref_or_id, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) else: - resource_db = self._get_by_id(resource_id=ref_or_id, exclude_fields=exclude_fields, - include_fields=include_fields) + resource_db = self._get_by_id( + resource_id=ref_or_id, + exclude_fields=exclude_fields, + include_fields=include_fields, + ) if not resource_db: msg = 'Resource with a reference or id "%s" not found' % (ref_or_id) @@ -599,8 +737,10 @@ def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None, include_fields=None) def _get_by_ref(self, resource_ref, exclude_fields=None, include_fields=None): if exclude_fields and include_fields: - msg = ('exclude_fields and include_fields arguments are mutually exclusive. ' - 'You need to provide either one or another, but not both.') + msg = ( + "exclude_fields and include_fields arguments are mutually exclusive. " + "You need to provide either one or another, but not both." + ) raise ValueError(msg) try: @@ -608,9 +748,12 @@ def _get_by_ref(self, resource_ref, exclude_fields=None, include_fields=None): except Exception: return None - resource_db = self.access.query(name=ref.name, pack=ref.pack, - exclude_fields=exclude_fields, - only_fields=include_fields).first() + resource_db = self.access.query( + name=ref.name, + pack=ref.pack, + exclude_fields=exclude_fields, + only_fields=include_fields, + ).first() return resource_db @@ -629,25 +772,29 @@ def validate_limit_query_param(limit, requester_user=None): if int(limit) == -1: if not user_is_admin: # Only admins can specify limit -1 - message = ('Administrator access required to be able to specify limit=-1 and ' - 'retrieve all the records') - raise AccessDeniedError(message=message, - user_db=requester_user) + message = ( + "Administrator access required to be able to specify limit=-1 and " + "retrieve all the records" + ) + raise AccessDeniedError(message=message, user_db=requester_user) return 0 elif int(limit) <= -2: msg = 'Limit, "%s" specified, must be a positive number.' % (limit) raise ValueError(msg) elif int(limit) > cfg.CONF.api.max_page_size and not user_is_admin: - msg = ('Limit "%s" specified, maximum value is "%s"' % (limit, - cfg.CONF.api.max_page_size)) + msg = 'Limit "%s" specified, maximum value is "%s"' % ( + limit, + cfg.CONF.api.max_page_size, + ) - raise AccessDeniedError(message=msg, - user_db=requester_user) + raise AccessDeniedError(message=msg, user_db=requester_user) # Disable n = 0 elif limit == 0: - msg = ('Limit, "%s" specified, must be a positive number or -1 for full result set.' % - (limit)) + msg = ( + 'Limit, "%s" specified, must be a positive number or -1 for full result set.' + % (limit) + ) raise ValueError(msg) return limit diff --git a/st2api/st2api/controllers/root.py b/st2api/st2api/controllers/root.py index c2db487b029..2d5d953afab 100644 --- a/st2api/st2api/controllers/root.py +++ b/st2api/st2api/controllers/root.py @@ -15,23 +15,21 @@ from st2common import __version__ -__all__ = [ - 'RootController' -] +__all__ = ["RootController"] class RootController(object): def index(self): data = {} - if 'dev' in __version__: - docs_url = 'http://docs.stackstorm.com/latest' + if "dev" in __version__: + docs_url = "http://docs.stackstorm.com/latest" else: - docs_version = '.'.join(__version__.split('.')[:2]) - docs_url = 'http://docs.stackstorm.com/%s' % docs_version + docs_version = ".".join(__version__.split(".")[:2]) + docs_url = "http://docs.stackstorm.com/%s" % docs_version - data['version'] = __version__ - data['docs_url'] = docs_url + data["version"] = __version__ + data["docs_url"] = docs_url return data diff --git a/st2api/st2api/controllers/v1/action_views.py b/st2api/st2api/controllers/v1/action_views.py index d1701ebfbf8..2e528b5b13d 100644 --- a/st2api/st2api/controllers/v1/action_views.py +++ b/st2api/st2api/controllers/v1/action_views.py @@ -33,11 +33,7 @@ from st2common.router import abort from st2common.router import Response -__all__ = [ - 'OverviewController', - 'ParametersViewController', - 'EntryPointController' -] +__all__ = ["OverviewController", "ParametersViewController", "EntryPointController"] http_client = six.moves.http_client @@ -45,7 +41,6 @@ class LookupUtils(object): - @staticmethod def _get_action_by_id(id): try: @@ -75,31 +70,33 @@ def _get_runner_by_name(name): class ParametersViewController(object): - def get_one(self, action_id, requester_user): return self._get_one(action_id, requester_user=requester_user) @staticmethod def _get_one(action_id, requester_user): """ - List merged action & runner parameters by action id. + List merged action & runner parameters by action id. - Handle: - GET /actions/views/parameters/1 + Handle: + GET /actions/views/parameters/1 """ action_db = LookupUtils._get_action_by_id(action_id) permission_type = PermissionType.ACTION_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) - runner_db = LookupUtils._get_runner_by_name(action_db.runner_type['name']) + runner_db = LookupUtils._get_runner_by_name(action_db.runner_type["name"]) all_params = action_param_utils.get_params_view( - action_db=action_db, runner_db=runner_db, merged_only=True) + action_db=action_db, runner_db=runner_db, merged_only=True + ) - return {'parameters': all_params} + return {"parameters": all_params} class OverviewController(resource.ContentPackResourceController): @@ -107,47 +104,54 @@ class OverviewController(resource.ContentPackResourceController): access = Action supported_filters = {} - query_options = { - 'sort': ['pack', 'name'] - } + query_options = {"sort": ["pack", "name"]} - mandatory_include_fields_retrieve = [ - 'pack', - 'name', - 'parameters', - 'runner_type' - ] + mandatory_include_fields_retrieve = ["pack", "name", "parameters", "runner_type"] def get_one(self, ref_or_id, requester_user): """ - List action by id. + List action by id. - Handle: - GET /actions/views/overview/1 + Handle: + GET /actions/views/overview/1 """ - resp = super(OverviewController, self)._get_one(ref_or_id, - requester_user=requester_user, - permission_type=PermissionType.ACTION_VIEW) + resp = super(OverviewController, self)._get_one( + ref_or_id, + requester_user=requester_user, + permission_type=PermissionType.ACTION_VIEW, + ) action_api = ActionAPI(**resp.json) - result = self._transform_action_api(action_api=action_api, requester_user=requester_user) + result = self._transform_action_api( + action_api=action_api, requester_user=requester_user + ) resp.json = result return resp - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): """ - List all actions. + List all actions. - Handles requests: - GET /actions/views/overview + Handles requests: + GET /actions/views/overview """ - resp = super(OverviewController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + resp = super(OverviewController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) runner_type_names = set([]) action_ids = [] @@ -164,9 +168,12 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o # N * 2 additional queries # 1. Retrieve all the respective runner objects - we only need parameters - runner_type_dbs = RunnerType.query(name__in=runner_type_names, - only_fields=['name', 'runner_parameters']) - runner_type_dbs = dict([(runner_db.name, runner_db) for runner_db in runner_type_dbs]) + runner_type_dbs = RunnerType.query( + name__in=runner_type_names, only_fields=["name", "runner_parameters"] + ) + runner_type_dbs = dict( + [(runner_db.name, runner_db) for runner_db in runner_type_dbs] + ) # 2. Retrieve all the respective action objects - we only need parameters action_dbs = dict([(action_db.id, action_db) for action_db in result]) @@ -174,9 +181,9 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o for action_api in result: action_db = action_dbs.get(action_api.id, None) runner_db = runner_type_dbs.get(action_api.runner_type, None) - all_params = action_param_utils.get_params_view(action_db=action_db, - runner_db=runner_db, - merged_only=True) + all_params = action_param_utils.get_params_view( + action_db=action_db, runner_db=runner_db, merged_only=True + ) action_api.parameters = all_params resp.json = result @@ -185,9 +192,10 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o @staticmethod def _transform_action_api(action_api, requester_user): action_id = action_api.id - result = ParametersViewController._get_one(action_id=action_id, - requester_user=requester_user) - action_api.parameters = result.get('parameters', {}) + result = ParametersViewController._get_one( + action_id=action_id, requester_user=requester_user + ) + action_api.parameters = result.get("parameters", {}) return action_api @@ -202,35 +210,38 @@ def get_all(self): def get_one(self, ref_or_id, requester_user): """ - Outputs the file associated with action entry_point + Outputs the file associated with action entry_point - Handles requests: - GET /actions/views/entry_point/1 + Handles requests: + GET /actions/views/entry_point/1 """ - LOG.info('GET /actions/views/entry_point with ref_or_id=%s', ref_or_id) + LOG.info("GET /actions/views/entry_point with ref_or_id=%s", ref_or_id) action_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) permission_type = PermissionType.ACTION_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) - pack = getattr(action_db, 'pack', None) - entry_point = getattr(action_db, 'entry_point', None) + pack = getattr(action_db, "pack", None) + entry_point = getattr(action_db, "entry_point", None) abs_path = utils.get_entry_point_abs_path(pack, entry_point) if not abs_path: - raise StackStormDBObjectNotFoundError('Action ref_or_id=%s has no entry_point to output' - % ref_or_id) + raise StackStormDBObjectNotFoundError( + "Action ref_or_id=%s has no entry_point to output" % ref_or_id + ) - with codecs.open(abs_path, 'r') as fp: + with codecs.open(abs_path, "r") as fp: content = fp.read() # Ensure content is utf-8 if isinstance(content, six.binary_type): - content = content.decode('utf-8') + content = content.decode("utf-8") try: content_type = mimetypes.guess_type(abs_path)[0] @@ -240,15 +251,15 @@ def get_one(self, ref_or_id, requester_user): # Special case if /etc/mime.types doesn't contain entry for yaml, py if not content_type: _, extension = os.path.splitext(abs_path) - if extension in ['.yaml', '.yml']: - content_type = 'application/x-yaml' - elif extension in ['.py']: - content_type = 'application/x-python' + if extension in [".yaml", ".yml"]: + content_type = "application/x-yaml" + elif extension in [".py"]: + content_type = "application/x-python" else: - content_type = 'text/plain' + content_type = "text/plain" response = Response() - response.headers['Content-Type'] = content_type + response.headers["Content-Type"] = content_type response.text = content return response diff --git a/st2api/st2api/controllers/v1/actionalias.py b/st2api/st2api/controllers/v1/actionalias.py index 00e58675f95..5488300d6ed 100644 --- a/st2api/st2api/controllers/v1/actionalias.py +++ b/st2api/st2api/controllers/v1/actionalias.py @@ -37,175 +37,219 @@ class ActionAliasController(resource.ContentPackResourceController): """ - Implements the RESTful interface for ActionAliases. + Implements the RESTful interface for ActionAliases. """ + model = ActionAliasAPI access = ActionAlias - supported_filters = { - 'name': 'name', - 'pack': 'pack' - } - - query_options = { - 'sort': ['pack', 'name'] - } - - _custom_actions = { - 'match': ['POST'], - 'help': ['POST'] - } - - def get_all(self, exclude_attributes=None, include_attributes=None, - sort=None, offset=0, limit=None, requester_user=None, **raw_filters): - return super(ActionAliasController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + supported_filters = {"name": "name", "pack": "pack"} + + query_options = {"sort": ["pack", "name"]} + + _custom_actions = {"match": ["POST"], "help": ["POST"]} + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(ActionAliasController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user): permission_type = PermissionType.ACTION_ALIAS_VIEW - return super(ActionAliasController, self)._get_one(ref_or_id, - requester_user=requester_user, - permission_type=permission_type) + return super(ActionAliasController, self)._get_one( + ref_or_id, requester_user=requester_user, permission_type=permission_type + ) def match(self, action_alias_match_api): """ - Find a matching action alias. + Find a matching action alias. - Handles requests: - POST /actionalias/match + Handles requests: + POST /actionalias/match """ command = action_alias_match_api.command try: format_ = get_matching_alias(command=command) except ActionAliasAmbiguityException as e: - LOG.exception('Command "%s" matched (%s) patterns.', e.command, len(e.matches)) + LOG.exception( + 'Command "%s" matched (%s) patterns.', e.command, len(e.matches) + ) return abort(http_client.BAD_REQUEST, six.text_type(e)) # Convert ActionAliasDB to API - action_alias_api = ActionAliasAPI.from_model(format_['alias']) + action_alias_api = ActionAliasAPI.from_model(format_["alias"]) return { - 'actionalias': action_alias_api, - 'display': format_['display'], - 'representation': format_['representation'], + "actionalias": action_alias_api, + "display": format_["display"], + "representation": format_["representation"], } def help(self, filter, pack, limit, offset, **kwargs): """ - Get available help strings for action aliases. + Get available help strings for action aliases. - Handles requests: - GET /actionalias/help + Handles requests: + GET /actionalias/help """ try: aliases_resp = super(ActionAliasController, self)._get_all(**kwargs) aliases = [ActionAliasAPI(**alias) for alias in aliases_resp.json] - return generate_helpstring_result(aliases, filter, pack, int(limit), int(offset)) + return generate_helpstring_result( + aliases, filter, pack, int(limit), int(offset) + ) except (TypeError) as e: - LOG.exception('Helpstring request contains an invalid data type: %s.', six.text_type(e)) + LOG.exception( + "Helpstring request contains an invalid data type: %s.", + six.text_type(e), + ) return abort(http_client.BAD_REQUEST, six.text_type(e)) def post(self, action_alias, requester_user): """ - Create a new ActionAlias. + Create a new ActionAlias. - Handles requests: - POST /actionalias/ + Handles requests: + POST /actionalias/ """ permission_type = PermissionType.ACTION_ALIAS_CREATE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user, - resource_api=action_alias, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_api_permission( + user_db=requester_user, + resource_api=action_alias, + permission_type=permission_type, + ) try: action_alias_db = ActionAliasAPI.to_model(action_alias) - LOG.debug('/actionalias/ POST verified ActionAliasAPI and formulated ActionAliasDB=%s', - action_alias_db) + LOG.debug( + "/actionalias/ POST verified ActionAliasAPI and formulated ActionAliasDB=%s", + action_alias_db, + ) action_alias_db = ActionAlias.add_or_update(action_alias_db) except (ValidationError, ValueError, ValueValidationException) as e: - LOG.exception('Validation failed for action alias data=%s.', action_alias) + LOG.exception("Validation failed for action alias data=%s.", action_alias) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'action_alias_db': action_alias_db} - LOG.audit('Action alias created. ActionAlias.id=%s' % (action_alias_db.id), extra=extra) + extra = {"action_alias_db": action_alias_db} + LOG.audit( + "Action alias created. ActionAlias.id=%s" % (action_alias_db.id), + extra=extra, + ) action_alias_api = ActionAliasAPI.from_model(action_alias_db) return Response(json=action_alias_api, status=http_client.CREATED) def put(self, action_alias, ref_or_id, requester_user): """ - Update an action alias. + Update an action alias. - Handles requests: - PUT /actionalias/1 + Handles requests: + PUT /actionalias/1 """ action_alias_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) - LOG.debug('PUT /actionalias/ lookup with id=%s found object: %s', ref_or_id, - action_alias_db) + LOG.debug( + "PUT /actionalias/ lookup with id=%s found object: %s", + ref_or_id, + action_alias_db, + ) permission_type = PermissionType.ACTION_ALIAS_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_alias_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_alias_db, + permission_type=permission_type, + ) - if not hasattr(action_alias, 'id'): + if not hasattr(action_alias, "id"): action_alias.id = None try: - if action_alias.id is not None and action_alias.id != '' and \ - action_alias.id != ref_or_id: - LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.', - action_alias.id, ref_or_id) + if ( + action_alias.id is not None + and action_alias.id != "" + and action_alias.id != ref_or_id + ): + LOG.warning( + "Discarding mismatched id=%s found in payload and using uri_id=%s.", + action_alias.id, + ref_or_id, + ) old_action_alias_db = action_alias_db action_alias_db = ActionAliasAPI.to_model(action_alias) action_alias_db.id = ref_or_id action_alias_db = ActionAlias.add_or_update(action_alias_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for action alias data=%s', action_alias) + LOG.exception("Validation failed for action alias data=%s", action_alias) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'old_action_alias_db': old_action_alias_db, 'new_action_alias_db': action_alias_db} - LOG.audit('Action alias updated. ActionAlias.id=%s.' % (action_alias_db.id), extra=extra) + extra = { + "old_action_alias_db": old_action_alias_db, + "new_action_alias_db": action_alias_db, + } + LOG.audit( + "Action alias updated. ActionAlias.id=%s." % (action_alias_db.id), + extra=extra, + ) action_alias_api = ActionAliasAPI.from_model(action_alias_db) return action_alias_api def delete(self, ref_or_id, requester_user): """ - Delete an action alias. + Delete an action alias. - Handles requests: - DELETE /actionalias/1 + Handles requests: + DELETE /actionalias/1 """ action_alias_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) - LOG.debug('DELETE /actionalias/ lookup with id=%s found object: %s', ref_or_id, - action_alias_db) + LOG.debug( + "DELETE /actionalias/ lookup with id=%s found object: %s", + ref_or_id, + action_alias_db, + ) permission_type = PermissionType.ACTION_ALIAS_DELETE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_alias_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_alias_db, + permission_type=permission_type, + ) try: ActionAlias.delete(action_alias_db) except Exception as e: - LOG.exception('Database delete encountered exception during delete of id="%s".', - ref_or_id) + LOG.exception( + 'Database delete encountered exception during delete of id="%s".', + ref_or_id, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return - extra = {'action_alias_db': action_alias_db} - LOG.audit('Action alias deleted. ActionAlias.id=%s.' % (action_alias_db.id), extra=extra) + extra = {"action_alias_db": action_alias_db} + LOG.audit( + "Action alias deleted. ActionAlias.id=%s." % (action_alias_db.id), + extra=extra, + ) return Response(status=http_client.NO_CONTENT) diff --git a/st2api/st2api/controllers/v1/actionexecutions.py b/st2api/st2api/controllers/v1/actionexecutions.py index b0aa4e9e1d4..3cc7741b2d3 100644 --- a/st2api/st2api/controllers/v1/actionexecutions.py +++ b/st2api/st2api/controllers/v1/actionexecutions.py @@ -54,18 +54,15 @@ from st2common.rbac.types import PermissionType from st2common.rbac.backends import get_rbac_backend -__all__ = [ - 'ActionExecutionsController' -] +__all__ = ["ActionExecutionsController"] LOG = logging.getLogger(__name__) # Note: We initialize filters here and not in the constructor SUPPORTED_EXECUTIONS_FILTERS = copy.deepcopy(SUPPORTED_FILTERS) -SUPPORTED_EXECUTIONS_FILTERS.update({ - 'timestamp_gt': 'start_timestamp.gt', - 'timestamp_lt': 'start_timestamp.lt' -}) +SUPPORTED_EXECUTIONS_FILTERS.update( + {"timestamp_gt": "start_timestamp.gt", "timestamp_lt": "start_timestamp.lt"} +) MONITOR_THREAD_EMPTY_Q_SLEEP_TIME = 5 MONITOR_THREAD_NO_WORKERS_SLEEP_TIME = 1 @@ -82,29 +79,24 @@ class ActionExecutionsControllerMixin(BaseRestControllerMixin): # Those two attributes are mandatory so we can correctly determine and mask secret execution # parameters mandatory_include_fields_retrieve = [ - 'action.parameters', - 'runner.runner_parameters', - 'parameters', - + "action.parameters", + "runner.runner_parameters", + "parameters", # Attributes below are mandatory for RBAC installations - 'action.pack', - 'action.uid', - + "action.pack", + "action.uid", # Required when rbac.permission_isolation is enabled - 'context' + "context", ] # A list of attributes which can be specified using ?exclude_attributes filter # NOTE: Allowing user to exclude attribute such as action and runner would break secrets # masking - valid_exclude_attributes = [ - 'result', - 'trigger_instance', - 'status' - ] + valid_exclude_attributes = ["result", "trigger_instance", "status"] - def _handle_schedule_execution(self, liveaction_api, requester_user, context_string=None, - show_secrets=False): + def _handle_schedule_execution( + self, liveaction_api, requester_user, context_string=None, show_secrets=False + ): """ :param liveaction: LiveActionAPI object. :type liveaction: :class:`LiveActionAPI` @@ -124,101 +116,129 @@ def _handle_schedule_execution(self, liveaction_api, requester_user, context_str # Assert the permissions permission_type = PermissionType.ACTION_EXECUTE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) # Validate that the authenticated user is admin if user query param is provided user = liveaction_api.user or requester_user.name - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user + ) try: - return self._schedule_execution(liveaction=liveaction_api, - requester_user=requester_user, - user=user, - context_string=context_string, - show_secrets=show_secrets, - action_db=action_db) + return self._schedule_execution( + liveaction=liveaction_api, + requester_user=requester_user, + user=user, + context_string=context_string, + show_secrets=show_secrets, + action_db=action_db, + ) except ValueError as e: - LOG.exception('Unable to execute action.') + LOG.exception("Unable to execute action.") abort(http_client.BAD_REQUEST, six.text_type(e)) except jsonschema.ValidationError as e: - LOG.exception('Unable to execute action. Parameter validation failed.') - abort(http_client.BAD_REQUEST, re.sub("u'([^']*)'", r"'\1'", - getattr(e, 'message', six.text_type(e)))) + LOG.exception("Unable to execute action. Parameter validation failed.") + abort( + http_client.BAD_REQUEST, + re.sub("u'([^']*)'", r"'\1'", getattr(e, "message", six.text_type(e))), + ) except trace_exc.TraceNotFoundException as e: abort(http_client.BAD_REQUEST, six.text_type(e)) except validation_exc.ValueValidationException as e: raise e except Exception as e: - LOG.exception('Unable to execute action. Unexpected error encountered.') + LOG.exception("Unable to execute action. Unexpected error encountered.") abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) - def _schedule_execution(self, liveaction, requester_user, action_db, user=None, - context_string=None, show_secrets=False): + def _schedule_execution( + self, + liveaction, + requester_user, + action_db, + user=None, + context_string=None, + show_secrets=False, + ): # Initialize execution context if it does not exist. - if not hasattr(liveaction, 'context'): + if not hasattr(liveaction, "context"): liveaction.context = dict() - liveaction.context['user'] = user - liveaction.context['pack'] = action_db.pack + liveaction.context["user"] = user + liveaction.context["pack"] = action_db.pack - LOG.debug('User is: %s' % liveaction.context['user']) + LOG.debug("User is: %s" % liveaction.context["user"]) # Retrieve other st2 context from request header. if context_string: context = try_loads(context_string) if not isinstance(context, dict): - raise ValueError('Unable to convert st2-context from the headers into JSON.') + raise ValueError( + "Unable to convert st2-context from the headers into JSON." + ) liveaction.context.update(context) # Include RBAC context (if RBAC is available and enabled) if cfg.CONF.rbac.enable: user_db = UserDB(name=user) rbac_service = get_rbac_backend().get_service_class() - role_dbs = rbac_service.get_roles_for_user(user_db=user_db, include_remote=True) + role_dbs = rbac_service.get_roles_for_user( + user_db=user_db, include_remote=True + ) roles = [role_db.name for role_db in role_dbs] - liveaction.context['rbac'] = { - 'user': user, - 'roles': roles - } + liveaction.context["rbac"] = {"user": user, "roles": roles} # Schedule the action execution. liveaction_db = LiveActionAPI.to_model(liveaction) - runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_utils.get_runnertype_by_name( + action_db.runner_type["name"] + ) try: liveaction_db.parameters = param_utils.render_live_params( - runnertype_db.runner_parameters, action_db.parameters, liveaction_db.parameters, - liveaction_db.context) + runnertype_db.runner_parameters, + action_db.parameters, + liveaction_db.parameters, + liveaction_db.context, + ) except param_exc.ParamException: # We still need to create a request, so liveaction_db is assigned an ID liveaction_db, actionexecution_db = action_service.create_request( liveaction=liveaction_db, action_db=action_db, - runnertype_db=runnertype_db) + runnertype_db=runnertype_db, + ) # By this point the execution is already in the DB therefore need to mark it failed. _, e, tb = sys.exc_info() action_service.update_status( liveaction=liveaction_db, new_status=action_constants.LIVEACTION_STATUS_FAILED, - result={'error': six.text_type(e), - 'traceback': ''.join(traceback.format_tb(tb, 20))}) + result={ + "error": six.text_type(e), + "traceback": "".join(traceback.format_tb(tb, 20)), + }, + ) # Might be a good idea to return the actual ActionExecution rather than bubble up # the exception. raise validation_exc.ValueValidationException(six.text_type(e)) # The request should be created after the above call to render_live_params # so any templates in live parameters have a chance to render. - liveaction_db, actionexecution_db = action_service.create_request(liveaction=liveaction_db, - action_db=action_db, - runnertype_db=runnertype_db) + liveaction_db, actionexecution_db = action_service.create_request( + liveaction=liveaction_db, action_db=action_db, runnertype_db=runnertype_db + ) - _, actionexecution_db = action_service.publish_request(liveaction_db, actionexecution_db) + _, actionexecution_db = action_service.publish_request( + liveaction_db, actionexecution_db + ) mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets) - execution_api = ActionExecutionAPI.from_model(actionexecution_db, mask_secrets=mask_secrets) + execution_api = ActionExecutionAPI.from_model( + actionexecution_db, mask_secrets=mask_secrets + ) return Response(json=execution_api, status=http_client.CREATED) @@ -231,25 +251,33 @@ def _get_result_object(self, id): :rtype: ``dict`` """ - fields = ['result'] - action_exec_db = self.access.impl.model.objects.filter(id=id).only(*fields).get() + fields = ["result"] + action_exec_db = ( + self.access.impl.model.objects.filter(id=id).only(*fields).get() + ) return action_exec_db.result - def _get_children(self, id_, requester_user, depth=-1, result_fmt=None, show_secrets=False): + def _get_children( + self, id_, requester_user, depth=-1, result_fmt=None, show_secrets=False + ): # make sure depth is int. Url encoding will make it a string and needs to # be converted back in that case. depth = int(depth) - LOG.debug('retrieving children for id: %s with depth: %s', id_, depth) - descendants = execution_service.get_descendants(actionexecution_id=id_, - descendant_depth=depth, - result_fmt=result_fmt) + LOG.debug("retrieving children for id: %s with depth: %s", id_, depth) + descendants = execution_service.get_descendants( + actionexecution_id=id_, descendant_depth=depth, result_fmt=result_fmt + ) mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets) - return [self.model.from_model(descendant, mask_secrets=mask_secrets) for - descendant in descendants] + return [ + self.model.from_model(descendant, mask_secrets=mask_secrets) + for descendant in descendants + ] -class BaseActionExecutionNestedController(ActionExecutionsControllerMixin, ResourceController): +class BaseActionExecutionNestedController( + ActionExecutionsControllerMixin, ResourceController +): # Note: We need to override "get_one" and "get_all" to return 404 since nested controller # don't implement thos methods @@ -265,24 +293,36 @@ def get_one(self, id): class ActionExecutionChildrenController(BaseActionExecutionNestedController): - def get_one(self, id, requester_user, depth=-1, result_fmt=None, show_secrets=False): + def get_one( + self, id, requester_user, depth=-1, result_fmt=None, show_secrets=False + ): """ Retrieve children for the provided action execution. :rtype: ``list`` """ - execution_db = self._get_one_by_id(id=id, requester_user=requester_user, - permission_type=PermissionType.EXECUTION_VIEW) + execution_db = self._get_one_by_id( + id=id, + requester_user=requester_user, + permission_type=PermissionType.EXECUTION_VIEW, + ) id = str(execution_db.id) - return self._get_children(id_=id, depth=depth, result_fmt=result_fmt, - requester_user=requester_user, show_secrets=show_secrets) + return self._get_children( + id_=id, + depth=depth, + result_fmt=result_fmt, + requester_user=requester_user, + show_secrets=show_secrets, + ) class ActionExecutionAttributeController(BaseActionExecutionNestedController): - valid_exclude_attributes = ['action__pack', 'action__uid'] + \ - ActionExecutionsControllerMixin.valid_exclude_attributes + valid_exclude_attributes = [ + "action__pack", + "action__uid", + ] + ActionExecutionsControllerMixin.valid_exclude_attributes def get(self, id, attribute, requester_user): """ @@ -294,76 +334,94 @@ def get(self, id, attribute, requester_user): :rtype: ``dict`` """ - fields = [attribute, 'action__pack', 'action__uid'] + fields = [attribute, "action__pack", "action__uid"] try: fields = self._validate_exclude_fields(fields) except ValueError: - valid_attributes = ', '.join(ActionExecutionsControllerMixin.valid_exclude_attributes) - msg = ('Invalid attribute "%s" specified. Valid attributes are: %s' % - (attribute, valid_attributes)) + valid_attributes = ", ".join( + ActionExecutionsControllerMixin.valid_exclude_attributes + ) + msg = 'Invalid attribute "%s" specified. Valid attributes are: %s' % ( + attribute, + valid_attributes, + ) raise ValueError(msg) - action_exec_db = self.access.impl.model.objects.filter(id=id).only(*fields).get() + action_exec_db = ( + self.access.impl.model.objects.filter(id=id).only(*fields).get() + ) permission_type = PermissionType.EXECUTION_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_exec_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_exec_db, + permission_type=permission_type, + ) result = getattr(action_exec_db, attribute, None) return Response(json=result, status=http_client.OK) -class ActionExecutionOutputController(ActionExecutionsControllerMixin, ResourceController): - supported_filters = { - 'output_type': 'output_type' - } +class ActionExecutionOutputController( + ActionExecutionsControllerMixin, ResourceController +): + supported_filters = {"output_type": "output_type"} exclude_fields = [] - def get_one(self, id, output_type='all', output_format='raw', existing_only=False, - requester_user=None): + def get_one( + self, + id, + output_type="all", + output_format="raw", + existing_only=False, + requester_user=None, + ): # Special case for id == "last" - if id == 'last': - execution_db = ActionExecution.query().order_by('-id').limit(1).first() + if id == "last": + execution_db = ActionExecution.query().order_by("-id").limit(1).first() if not execution_db: - raise ValueError('No executions found in the database') + raise ValueError("No executions found in the database") id = str(execution_db.id) - execution_db = self._get_one_by_id(id=id, requester_user=requester_user, - permission_type=PermissionType.EXECUTION_VIEW) + execution_db = self._get_one_by_id( + id=id, + requester_user=requester_user, + permission_type=PermissionType.EXECUTION_VIEW, + ) execution_id = str(execution_db.id) query_filters = {} - if output_type and output_type != 'all': - query_filters['output_type'] = output_type + if output_type and output_type != "all": + query_filters["output_type"] = output_type def existing_output_iter(): # Consume and return all of the existing lines # pylint: disable=no-member - output_dbs = ActionExecutionOutput.query(execution_id=execution_id, **query_filters) + output_dbs = ActionExecutionOutput.query( + execution_id=execution_id, **query_filters + ) - output = ''.join([output_db.data for output_db in output_dbs]) - yield six.binary_type(output.encode('utf-8')) + output = "".join([output_db.data for output_db in output_dbs]) + yield six.binary_type(output.encode("utf-8")) def make_response(): app_iter = existing_output_iter() - res = Response(content_type='text/plain', app_iter=app_iter) + res = Response(content_type="text/plain", app_iter=app_iter) return res res = make_response() return res -class ActionExecutionReRunController(ActionExecutionsControllerMixin, ResourceController): +class ActionExecutionReRunController( + ActionExecutionsControllerMixin, ResourceController +): supported_filters = {} - exclude_fields = [ - 'result', - 'trigger_instance' - ] + exclude_fields = ["result", "trigger_instance"] class ExecutionSpecificationAPI(object): def __init__(self, parameters=None, tasks=None, reset=None, user=None): @@ -374,8 +432,10 @@ def __init__(self, parameters=None, tasks=None, reset=None, user=None): def validate(self): if (self.tasks or self.reset) and self.parameters: - raise ValueError('Parameters override is not supported when ' - 're-running task(s) for a workflow.') + raise ValueError( + "Parameters override is not supported when " + "re-running task(s) for a workflow." + ) if self.parameters: assert isinstance(self.parameters, dict) @@ -387,7 +447,9 @@ def validate(self): assert isinstance(self.reset, list) if list(set(self.reset) - set(self.tasks)): - raise ValueError('List of tasks to reset does not match the tasks to rerun.') + raise ValueError( + "List of tasks to reset does not match the tasks to rerun." + ) return self @@ -401,8 +463,10 @@ def post(self, spec_api, id, requester_user, no_merge=False, show_secrets=False) """ if (spec_api.tasks or spec_api.reset) and spec_api.parameters: - raise ValueError('Parameters override is not supported when ' - 're-running task(s) for a workflow.') + raise ValueError( + "Parameters override is not supported when " + "re-running task(s) for a workflow." + ) if spec_api.parameters: assert isinstance(spec_api.parameters, dict) @@ -414,7 +478,9 @@ def post(self, spec_api, id, requester_user, no_merge=False, show_secrets=False) assert isinstance(spec_api.reset, list) if list(set(spec_api.reset) - set(spec_api.tasks)): - raise ValueError('List of tasks to reset does not match the tasks to rerun.') + raise ValueError( + "List of tasks to reset does not match the tasks to rerun." + ) delay = None @@ -422,59 +488,69 @@ def post(self, spec_api, id, requester_user, no_merge=False, show_secrets=False) delay = spec_api.delay no_merge = cast_argument_value(value_type=bool, value=no_merge) - existing_execution = self._get_one_by_id(id=id, exclude_fields=self.exclude_fields, - requester_user=requester_user, - permission_type=PermissionType.EXECUTION_VIEW) + existing_execution = self._get_one_by_id( + id=id, + exclude_fields=self.exclude_fields, + requester_user=requester_user, + permission_type=PermissionType.EXECUTION_VIEW, + ) - if spec_api.tasks and \ - existing_execution.runner['name'] != 'orquesta': - raise ValueError('Task option is only supported for Orquesta workflows.') + if spec_api.tasks and existing_execution.runner["name"] != "orquesta": + raise ValueError("Task option is only supported for Orquesta workflows.") # Merge in any parameters provided by the user new_parameters = {} if not no_merge: - new_parameters.update(getattr(existing_execution, 'parameters', {})) + new_parameters.update(getattr(existing_execution, "parameters", {})) new_parameters.update(spec_api.parameters) # Create object for the new execution - action_ref = existing_execution.action['ref'] + action_ref = existing_execution.action["ref"] # Include additional option(s) for the execution context = { - 're-run': { - 'ref': id, + "re-run": { + "ref": id, } } if spec_api.tasks: - context['re-run']['tasks'] = spec_api.tasks + context["re-run"]["tasks"] = spec_api.tasks if spec_api.reset: - context['re-run']['reset'] = spec_api.reset + context["re-run"]["reset"] = spec_api.reset # Add trace to the new execution trace = trace_service.get_trace_db_by_action_execution( - action_execution_id=existing_execution.id) + action_execution_id=existing_execution.id + ) if trace: - context['trace_context'] = {'id_': str(trace.id)} - - new_liveaction_api = LiveActionCreateAPI(action=action_ref, - context=context, - parameters=new_parameters, - user=spec_api.user, - delay=delay) - - return self._handle_schedule_execution(liveaction_api=new_liveaction_api, - requester_user=requester_user, - show_secrets=show_secrets) - - -class ActionExecutionsController(BaseResourceIsolationControllerMixin, - ActionExecutionsControllerMixin, ResourceController): + context["trace_context"] = {"id_": str(trace.id)} + + new_liveaction_api = LiveActionCreateAPI( + action=action_ref, + context=context, + parameters=new_parameters, + user=spec_api.user, + delay=delay, + ) + + return self._handle_schedule_execution( + liveaction_api=new_liveaction_api, + requester_user=requester_user, + show_secrets=show_secrets, + ) + + +class ActionExecutionsController( + BaseResourceIsolationControllerMixin, + ActionExecutionsControllerMixin, + ResourceController, +): """ - Implements the RESTful web endpoint that handles - the lifecycle of ActionExecutions in the system. + Implements the RESTful web endpoint that handles + the lifecycle of ActionExecutions in the system. """ # Nested controllers @@ -485,17 +561,25 @@ class ActionExecutionsController(BaseResourceIsolationControllerMixin, re_run = ActionExecutionReRunController() # ResourceController attributes - query_options = { - 'sort': ['-start_timestamp', 'action.ref'] - } + query_options = {"sort": ["-start_timestamp", "action.ref"]} supported_filters = SUPPORTED_EXECUTIONS_FILTERS filter_transform_functions = { - 'timestamp_gt': lambda value: isotime.parse(value=value), - 'timestamp_lt': lambda value: isotime.parse(value=value) + "timestamp_gt": lambda value: isotime.parse(value=value), + "timestamp_lt": lambda value: isotime.parse(value=value), } - def get_all(self, requester_user, exclude_attributes=None, sort=None, offset=0, limit=None, - show_secrets=False, include_attributes=None, advanced_filters=None, **raw_filters): + def get_all( + self, + requester_user, + exclude_attributes=None, + sort=None, + offset=0, + limit=None, + show_secrets=False, + include_attributes=None, + advanced_filters=None, + **raw_filters, + ): """ List all executions. @@ -508,27 +592,37 @@ def get_all(self, requester_user, exclude_attributes=None, sort=None, offset=0, # Use a custom sort order when filtering on a timestamp so we return a correct result as # expected by the user query_options = None - if raw_filters.get('timestamp_lt', None) or raw_filters.get('sort_desc', None): - query_options = {'sort': ['-start_timestamp', 'action.ref']} - elif raw_filters.get('timestamp_gt', None) or raw_filters.get('sort_asc', None): - query_options = {'sort': ['+start_timestamp', 'action.ref']} + if raw_filters.get("timestamp_lt", None) or raw_filters.get("sort_desc", None): + query_options = {"sort": ["-start_timestamp", "action.ref"]} + elif raw_filters.get("timestamp_gt", None) or raw_filters.get("sort_asc", None): + query_options = {"sort": ["+start_timestamp", "action.ref"]} from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } - return self._get_action_executions(exclude_fields=exclude_attributes, - include_fields=include_attributes, - from_model_kwargs=from_model_kwargs, - sort=sort, - offset=offset, - limit=limit, - query_options=query_options, - raw_filters=raw_filters, - advanced_filters=advanced_filters, - requester_user=requester_user) - - def get_one(self, id, requester_user, exclude_attributes=None, include_attributes=None, - show_secrets=False): + return self._get_action_executions( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + from_model_kwargs=from_model_kwargs, + sort=sort, + offset=offset, + limit=limit, + query_options=query_options, + raw_filters=raw_filters, + advanced_filters=advanced_filters, + requester_user=requester_user, + ) + + def get_one( + self, + id, + requester_user, + exclude_attributes=None, + include_attributes=None, + show_secrets=False, + ): """ Retrieve a single execution. @@ -538,33 +632,48 @@ def get_one(self, id, requester_user, exclude_attributes=None, include_attribute :param exclude_attributes: List of attributes to exclude from the object. :type exclude_attributes: ``list`` """ - exclude_fields = self._validate_exclude_fields(exclude_fields=exclude_attributes) - include_fields = self._validate_include_fields(include_fields=include_attributes) + exclude_fields = self._validate_exclude_fields( + exclude_fields=exclude_attributes + ) + include_fields = self._validate_include_fields( + include_fields=include_attributes + ) from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } # Special case for id == "last" - if id == 'last': - execution_db = ActionExecution.query().order_by('-id').limit(1).only('id').first() + if id == "last": + execution_db = ( + ActionExecution.query().order_by("-id").limit(1).only("id").first() + ) if not execution_db: - raise ValueError('No executions found in the database') + raise ValueError("No executions found in the database") id = str(execution_db.id) - return self._get_one_by_id(id=id, exclude_fields=exclude_fields, - include_fields=include_fields, - requester_user=requester_user, - from_model_kwargs=from_model_kwargs, - permission_type=PermissionType.EXECUTION_VIEW) - - def post(self, liveaction_api, requester_user, context_string=None, show_secrets=False): - return self._handle_schedule_execution(liveaction_api=liveaction_api, - requester_user=requester_user, - context_string=context_string, - show_secrets=show_secrets) + return self._get_one_by_id( + id=id, + exclude_fields=exclude_fields, + include_fields=include_fields, + requester_user=requester_user, + from_model_kwargs=from_model_kwargs, + permission_type=PermissionType.EXECUTION_VIEW, + ) + + def post( + self, liveaction_api, requester_user, context_string=None, show_secrets=False + ): + return self._handle_schedule_execution( + liveaction_api=liveaction_api, + requester_user=requester_user, + context_string=context_string, + show_secrets=show_secrets, + ) def put(self, id, liveaction_api, requester_user, show_secrets=False): """ @@ -578,76 +687,118 @@ def put(self, id, liveaction_api, requester_user, show_secrets=False): requester_user = UserDB(cfg.CONF.system_user.user) from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } - execution_api = self._get_one_by_id(id=id, requester_user=requester_user, - from_model_kwargs=from_model_kwargs, - permission_type=PermissionType.EXECUTION_STOP) + execution_api = self._get_one_by_id( + id=id, + requester_user=requester_user, + from_model_kwargs=from_model_kwargs, + permission_type=PermissionType.EXECUTION_STOP, + ) if not execution_api: - abort(http_client.NOT_FOUND, 'Execution with id %s not found.' % id) + abort(http_client.NOT_FOUND, "Execution with id %s not found." % id) - liveaction_id = execution_api.liveaction['id'] + liveaction_id = execution_api.liveaction["id"] if not liveaction_id: - abort(http_client.INTERNAL_SERVER_ERROR, - 'Execution object missing link to liveaction %s.' % liveaction_id) + abort( + http_client.INTERNAL_SERVER_ERROR, + "Execution object missing link to liveaction %s." % liveaction_id, + ) try: liveaction_db = LiveAction.get_by_id(liveaction_id) except: - abort(http_client.INTERNAL_SERVER_ERROR, - 'Execution object missing link to liveaction %s.' % liveaction_id) + abort( + http_client.INTERNAL_SERVER_ERROR, + "Execution object missing link to liveaction %s." % liveaction_id, + ) if liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES: - abort(http_client.BAD_REQUEST, 'Execution is already in completed state.') + abort(http_client.BAD_REQUEST, "Execution is already in completed state.") def update_status(liveaction_api, liveaction_db): status = liveaction_api.status - result = getattr(liveaction_api, 'result', None) + result = getattr(liveaction_api, "result", None) liveaction_db = action_service.update_status(liveaction_db, status, result) - actionexecution_db = ActionExecution.get(liveaction__id=str(liveaction_db.id)) + actionexecution_db = ActionExecution.get( + liveaction__id=str(liveaction_db.id) + ) return (liveaction_db, actionexecution_db) try: - if (liveaction_db.status == action_constants.LIVEACTION_STATUS_CANCELING and - liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED): + if ( + liveaction_db.status == action_constants.LIVEACTION_STATUS_CANCELING + and liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED + ): if action_service.is_children_active(liveaction_id): liveaction_api.status = action_constants.LIVEACTION_STATUS_CANCELING - liveaction_db, actionexecution_db = update_status(liveaction_api, liveaction_db) - elif (liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELING or - liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED): + liveaction_db, actionexecution_db = update_status( + liveaction_api, liveaction_db + ) + elif ( + liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELING + or liveaction_api.status == action_constants.LIVEACTION_STATUS_CANCELED + ): liveaction_db, actionexecution_db = action_service.request_cancellation( - liveaction_db, requester_user.name or cfg.CONF.system_user.user) - elif (liveaction_db.status == action_constants.LIVEACTION_STATUS_PAUSING and - liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED): + liveaction_db, requester_user.name or cfg.CONF.system_user.user + ) + elif ( + liveaction_db.status == action_constants.LIVEACTION_STATUS_PAUSING + and liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED + ): if action_service.is_children_active(liveaction_id): liveaction_api.status = action_constants.LIVEACTION_STATUS_PAUSING - liveaction_db, actionexecution_db = update_status(liveaction_api, liveaction_db) - elif (liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSING or - liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED): + liveaction_db, actionexecution_db = update_status( + liveaction_api, liveaction_db + ) + elif ( + liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSING + or liveaction_api.status == action_constants.LIVEACTION_STATUS_PAUSED + ): liveaction_db, actionexecution_db = action_service.request_pause( - liveaction_db, requester_user.name or cfg.CONF.system_user.user) + liveaction_db, requester_user.name or cfg.CONF.system_user.user + ) elif liveaction_api.status == action_constants.LIVEACTION_STATUS_RESUMING: liveaction_db, actionexecution_db = action_service.request_resume( - liveaction_db, requester_user.name or cfg.CONF.system_user.user) + liveaction_db, requester_user.name or cfg.CONF.system_user.user + ) else: - liveaction_db, actionexecution_db = update_status(liveaction_api, liveaction_db) + liveaction_db, actionexecution_db = update_status( + liveaction_api, liveaction_db + ) except runner_exc.InvalidActionRunnerOperationError as e: - LOG.exception('Failed updating liveaction %s. %s', liveaction_db.id, six.text_type(e)) - abort(http_client.BAD_REQUEST, 'Failed updating execution. %s' % six.text_type(e)) + LOG.exception( + "Failed updating liveaction %s. %s", liveaction_db.id, six.text_type(e) + ) + abort( + http_client.BAD_REQUEST, + "Failed updating execution. %s" % six.text_type(e), + ) except runner_exc.UnexpectedActionExecutionStatusError as e: - LOG.exception('Failed updating liveaction %s. %s', liveaction_db.id, six.text_type(e)) - abort(http_client.BAD_REQUEST, 'Failed updating execution. %s' % six.text_type(e)) + LOG.exception( + "Failed updating liveaction %s. %s", liveaction_db.id, six.text_type(e) + ) + abort( + http_client.BAD_REQUEST, + "Failed updating execution. %s" % six.text_type(e), + ) except Exception as e: - LOG.exception('Failed updating liveaction %s. %s', liveaction_db.id, six.text_type(e)) + LOG.exception( + "Failed updating liveaction %s. %s", liveaction_db.id, six.text_type(e) + ) abort( http_client.INTERNAL_SERVER_ERROR, - 'Failed updating execution due to unexpected error.' + "Failed updating execution due to unexpected error.", ) mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets) - execution_api = ActionExecutionAPI.from_model(actionexecution_db, mask_secrets=mask_secrets) + execution_api = ActionExecutionAPI.from_model( + actionexecution_db, mask_secrets=mask_secrets + ) return execution_api @@ -663,50 +814,76 @@ def delete(self, id, requester_user, show_secrets=False): requester_user = UserDB(cfg.CONF.system_user.user) from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } - execution_api = self._get_one_by_id(id=id, requester_user=requester_user, - from_model_kwargs=from_model_kwargs, - permission_type=PermissionType.EXECUTION_STOP) + execution_api = self._get_one_by_id( + id=id, + requester_user=requester_user, + from_model_kwargs=from_model_kwargs, + permission_type=PermissionType.EXECUTION_STOP, + ) if not execution_api: - abort(http_client.NOT_FOUND, 'Execution with id %s not found.' % id) + abort(http_client.NOT_FOUND, "Execution with id %s not found." % id) - liveaction_id = execution_api.liveaction['id'] + liveaction_id = execution_api.liveaction["id"] if not liveaction_id: - abort(http_client.INTERNAL_SERVER_ERROR, - 'Execution object missing link to liveaction %s.' % liveaction_id) + abort( + http_client.INTERNAL_SERVER_ERROR, + "Execution object missing link to liveaction %s." % liveaction_id, + ) try: liveaction_db = LiveAction.get_by_id(liveaction_id) except: - abort(http_client.INTERNAL_SERVER_ERROR, - 'Execution object missing link to liveaction %s.' % liveaction_id) + abort( + http_client.INTERNAL_SERVER_ERROR, + "Execution object missing link to liveaction %s." % liveaction_id, + ) if liveaction_db.status == action_constants.LIVEACTION_STATUS_CANCELED: LOG.info( 'Action %s already in "canceled" state; \ - returning execution object.' % liveaction_db.id + returning execution object.' + % liveaction_db.id ) return execution_api if liveaction_db.status not in action_constants.LIVEACTION_CANCELABLE_STATES: - abort(http_client.OK, 'Action cannot be canceled. State = %s.' % liveaction_db.status) + abort( + http_client.OK, + "Action cannot be canceled. State = %s." % liveaction_db.status, + ) try: (liveaction_db, execution_db) = action_service.request_cancellation( - liveaction_db, requester_user.name or cfg.CONF.system_user.user) + liveaction_db, requester_user.name or cfg.CONF.system_user.user + ) except: - LOG.exception('Failed requesting cancellation for liveaction %s.', liveaction_db.id) - abort(http_client.INTERNAL_SERVER_ERROR, 'Failed canceling execution.') - - return ActionExecutionAPI.from_model(execution_db, - mask_secrets=from_model_kwargs['mask_secrets']) - - def _get_action_executions(self, exclude_fields=None, include_fields=None, - sort=None, offset=0, limit=None, advanced_filters=None, - query_options=None, raw_filters=None, from_model_kwargs=None, - requester_user=None): + LOG.exception( + "Failed requesting cancellation for liveaction %s.", liveaction_db.id + ) + abort(http_client.INTERNAL_SERVER_ERROR, "Failed canceling execution.") + + return ActionExecutionAPI.from_model( + execution_db, mask_secrets=from_model_kwargs["mask_secrets"] + ) + + def _get_action_executions( + self, + exclude_fields=None, + include_fields=None, + sort=None, + offset=0, + limit=None, + advanced_filters=None, + query_options=None, + raw_filters=None, + from_model_kwargs=None, + requester_user=None, + ): """ :param exclude_fields: A list of object fields to exclude. :type exclude_fields: ``list`` @@ -717,18 +894,25 @@ def _get_action_executions(self, exclude_fields=None, include_fields=None, limit = int(limit) - LOG.debug('Retrieving all action executions with filters=%s,exclude_fields=%s,' - 'include_fields=%s', raw_filters, exclude_fields, include_fields) - return super(ActionExecutionsController, self)._get_all(exclude_fields=exclude_fields, - include_fields=include_fields, - from_model_kwargs=from_model_kwargs, - sort=sort, - offset=offset, - limit=limit, - query_options=query_options, - raw_filters=raw_filters, - advanced_filters=advanced_filters, - requester_user=requester_user) + LOG.debug( + "Retrieving all action executions with filters=%s,exclude_fields=%s," + "include_fields=%s", + raw_filters, + exclude_fields, + include_fields, + ) + return super(ActionExecutionsController, self)._get_all( + exclude_fields=exclude_fields, + include_fields=include_fields, + from_model_kwargs=from_model_kwargs, + sort=sort, + offset=offset, + limit=limit, + query_options=query_options, + raw_filters=raw_filters, + advanced_filters=advanced_filters, + requester_user=requester_user, + ) action_executions_controller = ActionExecutionsController() diff --git a/st2api/st2api/controllers/v1/actions.py b/st2api/st2api/controllers/v1/actions.py index 1746e84b838..c78667076d2 100644 --- a/st2api/st2api/controllers/v1/actions.py +++ b/st2api/st2api/controllers/v1/actions.py @@ -53,91 +53,102 @@ class ActionsController(resource.ContentPackResourceController): """ - Implements the RESTful web endpoint that handles - the lifecycle of Actions in the system. + Implements the RESTful web endpoint that handles + the lifecycle of Actions in the system. """ + views = ActionViewsController() model = ActionAPI access = Action - supported_filters = { - 'name': 'name', - 'pack': 'pack', - 'tags': 'tags.name' - } + supported_filters = {"name": "name", "pack": "pack", "tags": "tags.name"} - query_options = { - 'sort': ['pack', 'name'] - } + query_options = {"sort": ["pack", "name"]} - valid_exclude_attributes = [ - 'parameters', - 'notify' - ] + valid_exclude_attributes = ["parameters", "notify"] def __init__(self, *args, **kwargs): super(ActionsController, self).__init__(*args, **kwargs) self._trigger_dispatcher = TriggerDispatcher(LOG) - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return super(ActionsController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(ActionsController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user): - return super(ActionsController, self)._get_one(ref_or_id, requester_user=requester_user, - permission_type=PermissionType.ACTION_VIEW) + return super(ActionsController, self)._get_one( + ref_or_id, + requester_user=requester_user, + permission_type=PermissionType.ACTION_VIEW, + ) def post(self, action, requester_user): """ - Create a new action. + Create a new action. - Handles requests: - POST /actions/ + Handles requests: + POST /actions/ """ permission_type = PermissionType.ACTION_CREATE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user, - resource_api=action, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_api_permission( + user_db=requester_user, resource_api=action, permission_type=permission_type + ) try: # Perform validation validate_not_part_of_system_pack(action) action_validator.validate_action(action) - except (ValidationError, ValueError, - ValueValidationException, InvalidActionParameterException) as e: - LOG.exception('Unable to create action data=%s', action) + except ( + ValidationError, + ValueError, + ValueValidationException, + InvalidActionParameterException, + ) as e: + LOG.exception("Unable to create action data=%s", action) abort(http_client.BAD_REQUEST, six.text_type(e)) return # Write pack data files to disk (if any are provided) - data_files = getattr(action, 'data_files', []) + data_files = getattr(action, "data_files", []) written_data_files = [] if data_files: - written_data_files = self._handle_data_files(pack_ref=action.pack, - data_files=data_files) + written_data_files = self._handle_data_files( + pack_ref=action.pack, data_files=data_files + ) action_model = ActionAPI.to_model(action) - LOG.debug('/actions/ POST verified ActionAPI object=%s', action) + LOG.debug("/actions/ POST verified ActionAPI object=%s", action) action_db = Action.add_or_update(action_model) - LOG.debug('/actions/ POST saved ActionDB object=%s', action_db) + LOG.debug("/actions/ POST saved ActionDB object=%s", action_db) # Dispatch an internal trigger for each written data file. This way user # automate comitting this files to git using StackStorm rule if written_data_files: - self._dispatch_trigger_for_written_data_files(action_db=action_db, - written_data_files=written_data_files) + self._dispatch_trigger_for_written_data_files( + action_db=action_db, written_data_files=written_data_files + ) - extra = {'acion_db': action_db} - LOG.audit('Action created. Action.id=%s' % (action_db.id), extra=extra) + extra = {"acion_db": action_db} + LOG.audit("Action created. Action.id=%s" % (action_db.id), extra=extra) action_api = ActionAPI.from_model(action_db) return Response(json=action_api, status=http_client.CREATED) @@ -148,13 +159,15 @@ def put(self, action, ref_or_id, requester_user): # Assert permissions permission_type = PermissionType.ACTION_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) action_id = action_db.id - if not getattr(action, 'pack', None): + if not getattr(action, "pack", None): action.pack = action_db.pack # Perform validation @@ -162,70 +175,81 @@ def put(self, action, ref_or_id, requester_user): action_validator.validate_action(action) # Write pack data files to disk (if any are provided) - data_files = getattr(action, 'data_files', []) + data_files = getattr(action, "data_files", []) written_data_files = [] if data_files: - written_data_files = self._handle_data_files(pack_ref=action.pack, - data_files=data_files) + written_data_files = self._handle_data_files( + pack_ref=action.pack, data_files=data_files + ) try: action_db = ActionAPI.to_model(action) - LOG.debug('/actions/ PUT incoming action: %s', action_db) + LOG.debug("/actions/ PUT incoming action: %s", action_db) action_db.id = action_id action_db = Action.add_or_update(action_db) - LOG.debug('/actions/ PUT after add_or_update: %s', action_db) + LOG.debug("/actions/ PUT after add_or_update: %s", action_db) except (ValidationError, ValueError) as e: - LOG.exception('Unable to update action data=%s', action) + LOG.exception("Unable to update action data=%s", action) abort(http_client.BAD_REQUEST, six.text_type(e)) return # Dispatch an internal trigger for each written data file. This way user # automate committing this files to git using StackStorm rule if written_data_files: - self._dispatch_trigger_for_written_data_files(action_db=action_db, - written_data_files=written_data_files) + self._dispatch_trigger_for_written_data_files( + action_db=action_db, written_data_files=written_data_files + ) action_api = ActionAPI.from_model(action_db) - LOG.debug('PUT /actions/ client_result=%s', action_api) + LOG.debug("PUT /actions/ client_result=%s", action_api) return action_api def delete(self, ref_or_id, requester_user): """ - Delete an action. + Delete an action. - Handles requests: - POST /actions/1?_method=delete - DELETE /actions/1 - DELETE /actions/mypack.myaction + Handles requests: + POST /actions/1?_method=delete + DELETE /actions/1 + DELETE /actions/mypack.myaction """ action_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) action_id = action_db.id permission_type = PermissionType.ACTION_DELETE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) try: validate_not_part_of_system_pack(action_db) except ValueValidationException as e: abort(http_client.BAD_REQUEST, six.text_type(e)) - LOG.debug('DELETE /actions/ lookup with ref_or_id=%s found object: %s', - ref_or_id, action_db) + LOG.debug( + "DELETE /actions/ lookup with ref_or_id=%s found object: %s", + ref_or_id, + action_db, + ) try: Action.delete(action_db) except Exception as e: - LOG.error('Database delete encountered exception during delete of id="%s". ' - 'Exception was %s', action_id, e) + LOG.error( + 'Database delete encountered exception during delete of id="%s". ' + "Exception was %s", + action_id, + e, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return - extra = {'action_db': action_db} - LOG.audit('Action deleted. Action.id=%s' % (action_db.id), extra=extra) + extra = {"action_db": action_db} + LOG.audit("Action deleted. Action.id=%s" % (action_db.id), extra=extra) return Response(status=http_client.NO_CONTENT) def _handle_data_files(self, pack_ref, data_files): @@ -238,13 +262,17 @@ def _handle_data_files(self, pack_ref, data_files): 2. Updates affected PackDB model """ # Write files to disk - written_file_paths = self._write_data_files_to_disk(pack_ref=pack_ref, - data_files=data_files) + written_file_paths = self._write_data_files_to_disk( + pack_ref=pack_ref, data_files=data_files + ) # Update affected PackDB model (update a list of files) # Update PackDB - self._update_pack_model(pack_ref=pack_ref, data_files=data_files, - written_file_paths=written_file_paths) + self._update_pack_model( + pack_ref=pack_ref, + data_files=data_files, + written_file_paths=written_file_paths, + ) return written_file_paths @@ -255,23 +283,27 @@ def _write_data_files_to_disk(self, pack_ref, data_files): written_file_paths = [] for data_file in data_files: - file_path = data_file['file_path'] - content = data_file['content'] + file_path = data_file["file_path"] + content = data_file["content"] - file_path = get_pack_resource_file_abs_path(pack_ref=pack_ref, - resource_type='action', - file_path=file_path) + file_path = get_pack_resource_file_abs_path( + pack_ref=pack_ref, resource_type="action", file_path=file_path + ) LOG.debug('Writing data file "%s" to "%s"' % (str(data_file), file_path)) try: - self._write_data_file(pack_ref=pack_ref, file_path=file_path, content=content) + self._write_data_file( + pack_ref=pack_ref, file_path=file_path, content=content + ) except (OSError, IOError) as e: # Throw a more user-friendly exception on Permission denied error if e.errno == errno.EACCES: - msg = ('Unable to write data to "%s" (permission denied). Make sure ' - 'permissions for that pack directory are configured correctly so ' - 'st2api can write to it.' % (file_path)) + msg = ( + 'Unable to write data to "%s" (permission denied). Make sure ' + "permissions for that pack directory are configured correctly so " + "st2api can write to it." % (file_path) + ) raise ValueError(msg) raise e @@ -285,7 +317,9 @@ def _update_pack_model(self, pack_ref, data_files, written_file_paths): """ file_paths = [] # A list of paths relative to the pack directory for new files for file_path in written_file_paths: - file_path = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path) + file_path = get_relative_path_to_pack_file( + pack_ref=pack_ref, file_path=file_path + ) file_paths.append(file_path) pack_db = Pack.get_by_ref(pack_ref) @@ -314,18 +348,18 @@ def _write_data_file(self, pack_ref, file_path, content): mode = stat.S_IRWXU | stat.S_IRWXG | stat.S_IROTH | stat.S_IXOTH os.makedirs(directory, mode) - with open(file_path, 'w') as fp: + with open(file_path, "w") as fp: fp.write(content) def _dispatch_trigger_for_written_data_files(self, action_db, written_data_files): - trigger = ACTION_FILE_WRITTEN_TRIGGER['name'] + trigger = ACTION_FILE_WRITTEN_TRIGGER["name"] host_info = get_host_info() for file_path in written_data_files: payload = { - 'ref': action_db.ref, - 'file_path': file_path, - 'host_info': host_info + "ref": action_db.ref, + "file_path": file_path, + "host_info": host_info, } self._trigger_dispatcher.dispatch(trigger=trigger, payload=payload) diff --git a/st2api/st2api/controllers/v1/aliasexecution.py b/st2api/st2api/controllers/v1/aliasexecution.py index 7ecc14d62e2..ecbea2028e1 100644 --- a/st2api/st2api/controllers/v1/aliasexecution.py +++ b/st2api/st2api/controllers/v1/aliasexecution.py @@ -30,7 +30,9 @@ from st2common.models.db.liveaction import LiveActionDB from st2common.models.db.notification import NotificationSchema, NotificationSubSchema from st2common.models.utils import action_param_utils -from st2common.models.utils.action_alias_utils import extract_parameters_for_action_alias_db +from st2common.models.utils.action_alias_utils import ( + extract_parameters_for_action_alias_db, +) from st2common.models.utils.action_alias_utils import inject_immutable_parameters from st2common.persistence.actionalias import ActionAlias from st2common.services import action as action_service @@ -53,57 +55,60 @@ def cast_array(value): # Already a list, no casting needed nor wanted. return value - return [v.strip() for v in value.split(',')] + return [v.strip() for v in value.split(",")] CAST_OVERRIDES = { - 'array': cast_array, + "array": cast_array, } class ActionAliasExecutionController(BaseRestControllerMixin): def match_and_execute(self, input_api, requester_user, show_secrets=False): """ - Try to find a matching alias and if one is found, schedule a new - execution by parsing parameters from the provided command against - the matched alias. + Try to find a matching alias and if one is found, schedule a new + execution by parsing parameters from the provided command against + the matched alias. - Handles requests: - POST /aliasexecution/match_and_execute + Handles requests: + POST /aliasexecution/match_and_execute """ command = input_api.command try: format_ = get_matching_alias(command=command) except ActionAliasAmbiguityException as e: - LOG.exception('Command "%s" matched (%s) patterns.', e.command, len(e.matches)) + LOG.exception( + 'Command "%s" matched (%s) patterns.', e.command, len(e.matches) + ) return abort(http_client.BAD_REQUEST, six.text_type(e)) - action_alias_db = format_['alias'] - representation = format_['representation'] + action_alias_db = format_["alias"] + representation = format_["representation"] params = { - 'name': action_alias_db.name, - 'format': representation, - 'command': command, - 'user': input_api.user, - 'source_channel': input_api.source_channel, + "name": action_alias_db.name, + "format": representation, + "command": command, + "user": input_api.user, + "source_channel": input_api.source_channel, } # Add in any additional parameters provided by the user if input_api.notification_channel: - params['notification_channel'] = input_api.notification_channel + params["notification_channel"] = input_api.notification_channel if input_api.notification_route: - params['notification_route'] = input_api.notification_route + params["notification_route"] = input_api.notification_route alias_execution_api = AliasMatchAndExecuteInputAPI(**params) results = self._post( payload=alias_execution_api, requester_user=requester_user, show_secrets=show_secrets, - match_multiple=format_['match_multiple']) - return Response(json={'results': results}, status=http_client.CREATED) + match_multiple=format_["match_multiple"], + ) + return Response(json={"results": results}, status=http_client.CREATED) def _post(self, payload, requester_user, show_secrets=False, match_multiple=False): action_alias_name = payload.name if payload else None @@ -115,8 +120,8 @@ def _post(self, payload, requester_user, show_secrets=False, match_multiple=Fals if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) - format_str = payload.format or '' - command = payload.command or '' + format_str = payload.format or "" + command = payload.command or "" try: action_alias_db = ActionAlias.get_by_name(action_alias_name) @@ -124,7 +129,9 @@ def _post(self, payload, requester_user, show_secrets=False, match_multiple=Fals action_alias_db = None if not action_alias_db: - msg = 'Unable to identify action alias with name "%s".' % (action_alias_name) + msg = 'Unable to identify action alias with name "%s".' % ( + action_alias_name + ) abort(http_client.NOT_FOUND, msg) return @@ -138,132 +145,163 @@ def _post(self, payload, requester_user, show_secrets=False, match_multiple=Fals action_alias_db=action_alias_db, format_str=format_str, param_stream=command, - match_multiple=match_multiple) + match_multiple=match_multiple, + ) else: multiple_execution_parameters = [ extract_parameters_for_action_alias_db( action_alias_db=action_alias_db, format_str=format_str, param_stream=command, - match_multiple=match_multiple) + match_multiple=match_multiple, + ) ] notify = self._get_notify_field(payload) context = { - 'action_alias_ref': reference.get_ref_from_model(action_alias_db), - 'api_user': payload.user, - 'user': requester_user.name, - 'source_channel': payload.source_channel, + "action_alias_ref": reference.get_ref_from_model(action_alias_db), + "api_user": payload.user, + "user": requester_user.name, + "source_channel": payload.source_channel, } inject_immutable_parameters( action_alias_db=action_alias_db, multiple_execution_parameters=multiple_execution_parameters, - action_context=context) + action_context=context, + ) results = [] for execution_parameters in multiple_execution_parameters: - execution = self._schedule_execution(action_alias_db=action_alias_db, - params=execution_parameters, - notify=notify, - context=context, - show_secrets=show_secrets, - requester_user=requester_user) + execution = self._schedule_execution( + action_alias_db=action_alias_db, + params=execution_parameters, + notify=notify, + context=context, + show_secrets=show_secrets, + requester_user=requester_user, + ) result = { - 'execution': execution, - 'actionalias': ActionAliasAPI.from_model(action_alias_db) + "execution": execution, + "actionalias": ActionAliasAPI.from_model(action_alias_db), } if action_alias_db.ack: try: - if 'format' in action_alias_db.ack: - message = render({'alias': action_alias_db.ack['format']}, result)['alias'] + if "format" in action_alias_db.ack: + message = render( + {"alias": action_alias_db.ack["format"]}, result + )["alias"] - result.update({ - 'message': message - }) + result.update({"message": message}) except UndefinedError as e: - result.update({ - 'message': ('Cannot render "format" in field "ack" for alias. ' + - six.text_type(e)) - }) + result.update( + { + "message": ( + 'Cannot render "format" in field "ack" for alias. ' + + six.text_type(e) + ) + } + ) try: - if 'extra' in action_alias_db.ack: - result.update({ - 'extra': render(action_alias_db.ack['extra'], result) - }) + if "extra" in action_alias_db.ack: + result.update( + {"extra": render(action_alias_db.ack["extra"], result)} + ) except UndefinedError as e: - result.update({ - 'extra': ('Cannot render "extra" in field "ack" for alias. ' + - six.text_type(e)) - }) + result.update( + { + "extra": ( + 'Cannot render "extra" in field "ack" for alias. ' + + six.text_type(e) + ) + } + ) results.append(result) return results def post(self, payload, requester_user, show_secrets=False): - results = self._post(payload, requester_user, show_secrets, match_multiple=False) + results = self._post( + payload, requester_user, show_secrets, match_multiple=False + ) return Response(json=results[0], status=http_client.CREATED) def _tokenize_alias_execution(self, alias_execution): - tokens = alias_execution.strip().split(' ', 1) + tokens = alias_execution.strip().split(" ", 1) return (tokens[0], tokens[1] if len(tokens) > 1 else None) def _get_notify_field(self, payload): on_complete = NotificationSubSchema() - route = (getattr(payload, 'notification_route', None) or - getattr(payload, 'notification_channel', None)) + route = getattr(payload, "notification_route", None) or getattr( + payload, "notification_channel", None + ) on_complete.routes = [route] on_complete.data = { - 'user': payload.user, - 'source_channel': payload.source_channel, - 'source_context': getattr(payload, 'source_context', None), + "user": payload.user, + "source_channel": payload.source_channel, + "source_context": getattr(payload, "source_context", None), } notify = NotificationSchema() notify.on_complete = on_complete return notify - def _schedule_execution(self, action_alias_db, params, notify, context, requester_user, - show_secrets): + def _schedule_execution( + self, action_alias_db, params, notify, context, requester_user, show_secrets + ): action_ref = action_alias_db.action_ref action_db = action_utils.get_action_by_ref(action_ref) if not action_db: - raise StackStormDBObjectNotFoundError('Action with ref "%s" not found ' % (action_ref)) + raise StackStormDBObjectNotFoundError( + 'Action with ref "%s" not found ' % (action_ref) + ) rbac_utils = get_rbac_backend().get_utils_class() permission_type = PermissionType.ACTION_EXECUTE - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=action_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=action_db, + permission_type=permission_type, + ) try: # prior to shipping off the params cast them to the right type. - params = action_param_utils.cast_params(action_ref=action_alias_db.action_ref, - params=params, - cast_overrides=CAST_OVERRIDES) + params = action_param_utils.cast_params( + action_ref=action_alias_db.action_ref, + params=params, + cast_overrides=CAST_OVERRIDES, + ) if not context: context = { - 'action_alias_ref': reference.get_ref_from_model(action_alias_db), - 'user': get_system_username() + "action_alias_ref": reference.get_ref_from_model(action_alias_db), + "user": get_system_username(), } - liveaction = LiveActionDB(action=action_alias_db.action_ref, context=context, - parameters=params, notify=notify) + liveaction = LiveActionDB( + action=action_alias_db.action_ref, + context=context, + parameters=params, + notify=notify, + ) _, action_execution_db = action_service.request(liveaction) - mask_secrets = self._get_mask_secrets(requester_user, show_secrets=show_secrets) - return ActionExecutionAPI.from_model(action_execution_db, mask_secrets=mask_secrets) + mask_secrets = self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) + return ActionExecutionAPI.from_model( + action_execution_db, mask_secrets=mask_secrets + ) except ValueError as e: - LOG.exception('Unable to execute action.') + LOG.exception("Unable to execute action.") abort(http_client.BAD_REQUEST, six.text_type(e)) except jsonschema.ValidationError as e: - LOG.exception('Unable to execute action. Parameter validation failed.') + LOG.exception("Unable to execute action. Parameter validation failed.") abort(http_client.BAD_REQUEST, six.text_type(e)) except Exception as e: - LOG.exception('Unable to execute action. Unexpected error encountered.') + LOG.exception("Unable to execute action. Unexpected error encountered.") abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) diff --git a/st2api/st2api/controllers/v1/auth.py b/st2api/st2api/controllers/v1/auth.py index 909c8ff4fea..d4741c4bf1e 100644 --- a/st2api/st2api/controllers/v1/auth.py +++ b/st2api/st2api/controllers/v1/auth.py @@ -37,9 +37,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'ApiKeyController' -] +__all__ = ["ApiKeyController"] # See st2common.rbac.resolvers.ApiKeyPermissionResolver#user_has_resource_db_permission for resaon @@ -49,13 +47,9 @@ class ApiKeyController(BaseRestControllerMixin): Implements the REST endpoint for managing the key value store. """ - supported_filters = { - 'user': 'user' - } + supported_filters = {"user": "user"} - query_options = { - 'sort': ['user'] - } + query_options = {"sort": ["user"]} def __init__(self): super(ApiKeyController, self).__init__() @@ -63,31 +57,36 @@ def __init__(self): def get_one(self, api_key_id_or_key, requester_user, show_secrets=None): """ - List api keys. + List api keys. - Handle: - GET /apikeys/1 + Handle: + GET /apikeys/1 """ api_key_db = None try: api_key_db = ApiKey.get_by_key_or_id(api_key_id_or_key) except ApiKeyNotFoundError: - msg = ('ApiKey matching %s for reference and id not found.' % (api_key_id_or_key)) + msg = "ApiKey matching %s for reference and id not found." % ( + api_key_id_or_key + ) LOG.exception(msg) abort(http_client.NOT_FOUND, msg) permission_type = PermissionType.API_KEY_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=api_key_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=api_key_db, + permission_type=permission_type, + ) try: - mask_secrets = self._get_mask_secrets(show_secrets=show_secrets, - requester_user=requester_user) + mask_secrets = self._get_mask_secrets( + show_secrets=show_secrets, requester_user=requester_user + ) return ApiKeyAPI.from_model(api_key_db, mask_secrets=mask_secrets) except (ValidationError, ValueError) as e: - LOG.exception('Failed to serialize API key.') + LOG.exception("Failed to serialize API key.") abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) @property @@ -96,29 +95,34 @@ def max_limit(self): def get_all(self, requester_user, show_secrets=None, limit=None, offset=0): """ - List all keys. + List all keys. - Handles requests: - GET /apikeys/ + Handles requests: + GET /apikeys/ """ - mask_secrets = self._get_mask_secrets(show_secrets=show_secrets, - requester_user=requester_user) + mask_secrets = self._get_mask_secrets( + show_secrets=show_secrets, requester_user=requester_user + ) - limit = resource.validate_limit_query_param(limit, requester_user=requester_user) + limit = resource.validate_limit_query_param( + limit, requester_user=requester_user + ) try: api_key_dbs = ApiKey.get_all(limit=limit, offset=offset) - api_keys = [ApiKeyAPI.from_model(api_key_db, mask_secrets=mask_secrets) - for api_key_db in api_key_dbs] + api_keys = [ + ApiKeyAPI.from_model(api_key_db, mask_secrets=mask_secrets) + for api_key_db in api_key_dbs + ] except OverflowError: msg = 'Offset "%s" specified is more than 32 bit int' % (offset) raise ValueError(msg) resp = Response(json=api_keys) - resp.headers['X-Total-Count'] = str(api_key_dbs.count()) + resp.headers["X-Total-Count"] = str(api_key_dbs.count()) if limit: - resp.headers['X-Limit'] = str(limit) + resp.headers["X-Limit"] = str(limit) return resp @@ -129,14 +133,16 @@ def post(self, api_key_api, requester_user): permission_type = PermissionType.API_KEY_CREATE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user, - resource_api=api_key_api, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_api_permission( + user_db=requester_user, + resource_api=api_key_api, + permission_type=permission_type, + ) api_key_db = None api_key = None try: - if not getattr(api_key_api, 'user', None): + if not getattr(api_key_api, "user", None): if requester_user: api_key_api.user = requester_user.name else: @@ -148,22 +154,22 @@ def post(self, api_key_api, requester_user): user_db = UserDB(name=api_key_api.user) User.add_or_update(user_db) - extra = {'username': api_key_api.user, 'user': user_db} + extra = {"username": api_key_api.user, "user": user_db} LOG.audit('Registered new user "%s".' % (api_key_api.user), extra=extra) # If key_hash is provided use that and do not create a new key. The assumption # is user already has the original api-key - if not getattr(api_key_api, 'key_hash', None): + if not getattr(api_key_api, "key_hash", None): api_key, api_key_hash = auth_util.generate_api_key_and_hash() # store key_hash in DB api_key_api.key_hash = api_key_hash api_key_db = ApiKey.add_or_update(ApiKeyAPI.to_model(api_key_api)) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for api_key data=%s.', api_key_api) + LOG.exception("Validation failed for api_key data=%s.", api_key_api) abort(http_client.BAD_REQUEST, six.text_type(e)) - extra = {'api_key_db': api_key_db} - LOG.audit('ApiKey created. ApiKey.id=%s' % (api_key_db.id), extra=extra) + extra = {"api_key_db": api_key_db} + LOG.audit("ApiKey created. ApiKey.id=%s" % (api_key_db.id), extra=extra) api_key_create_response_api = ApiKeyCreateResponseAPI.from_model(api_key_db) # Return real api_key back to user. A one-way hash of the api_key is stored in the DB @@ -178,9 +184,11 @@ def put(self, api_key_api, api_key_id_or_key, requester_user): permission_type = PermissionType.API_KEY_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=api_key_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=api_key_db, + permission_type=permission_type, + ) old_api_key_db = api_key_db api_key_db = ApiKeyAPI.to_model(api_key_api) @@ -191,7 +199,7 @@ def put(self, api_key_api, api_key_id_or_key, requester_user): user_db = UserDB(name=api_key_api.user) User.add_or_update(user_db) - extra = {'username': api_key_api.user, 'user': user_db} + extra = {"username": api_key_api.user, "user": user_db} LOG.audit('Registered new user "%s".' % (api_key_api.user), extra=extra) # Passing in key_hash as MASKED_ATTRIBUTE_VALUE is expected since we do not @@ -203,36 +211,38 @@ def put(self, api_key_api, api_key_id_or_key, requester_user): # Rather than silently ignore any update to key_hash it is better to explicitly # disallow and notify user. if old_api_key_db.key_hash != api_key_db.key_hash: - raise ValueError('Update of key_hash is not allowed.') + raise ValueError("Update of key_hash is not allowed.") api_key_db.id = old_api_key_db.id api_key_db = ApiKey.add_or_update(api_key_db) - extra = {'old_api_key_db': old_api_key_db, 'new_api_key_db': api_key_db} - LOG.audit('API Key updated. ApiKey.id=%s.' % (api_key_db.id), extra=extra) + extra = {"old_api_key_db": old_api_key_db, "new_api_key_db": api_key_db} + LOG.audit("API Key updated. ApiKey.id=%s." % (api_key_db.id), extra=extra) api_key_api = ApiKeyAPI.from_model(api_key_db) return api_key_api def delete(self, api_key_id_or_key, requester_user): """ - Delete the key value pair. + Delete the key value pair. - Handles requests: - DELETE /apikeys/1 + Handles requests: + DELETE /apikeys/1 """ api_key_db = ApiKey.get_by_key_or_id(api_key_id_or_key) permission_type = PermissionType.API_KEY_DELETE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=api_key_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=api_key_db, + permission_type=permission_type, + ) ApiKey.delete(api_key_db) - extra = {'api_key_db': api_key_db} - LOG.audit('ApiKey deleted. ApiKey.id=%s' % (api_key_db.id), extra=extra) + extra = {"api_key_db": api_key_db} + LOG.audit("ApiKey deleted. ApiKey.id=%s" % (api_key_db.id), extra=extra) return Response(status=http_client.NO_CONTENT) diff --git a/st2api/st2api/controllers/v1/execution_views.py b/st2api/st2api/controllers/v1/execution_views.py index 9a61bdc321d..f4240b94abd 100644 --- a/st2api/st2api/controllers/v1/execution_views.py +++ b/st2api/st2api/controllers/v1/execution_views.py @@ -29,51 +29,51 @@ # response. Failure to do so will eventually result in Chrome hanging out while opening History # tab of st2web. SUPPORTED_FILTERS = { - 'action': 'action.ref', - 'status': 'status', - 'liveaction': 'liveaction.id', - 'parent': 'parent', - 'rule': 'rule.name', - 'runner': 'runner.name', - 'timestamp': 'start_timestamp', - 'trigger': 'trigger.name', - 'trigger_type': 'trigger_type.name', - 'trigger_instance': 'trigger_instance.id', - 'user': 'context.user' + "action": "action.ref", + "status": "status", + "liveaction": "liveaction.id", + "parent": "parent", + "rule": "rule.name", + "runner": "runner.name", + "timestamp": "start_timestamp", + "trigger": "trigger.name", + "trigger_type": "trigger_type.name", + "trigger_instance": "trigger_instance.id", + "user": "context.user", } # A list of fields for which null (None) is a valid value which we include in the list of valid # filters. FILTERS_WITH_VALID_NULL_VALUES = [ - 'parent', - 'rule', - 'trigger', - 'trigger_type', - 'trigger_instance' + "parent", + "rule", + "trigger", + "trigger_type", + "trigger_instance", ] # List of filters that are too broad to distinct by them and are very likely to represent 1 to 1 # relation between filter and particular history record. -IGNORE_FILTERS = ['parent', 'timestamp', 'liveaction', 'trigger_instance'] +IGNORE_FILTERS = ["parent", "timestamp", "liveaction", "trigger_instance"] class FiltersController(object): def get_all(self, types=None): """ - List all distinct filters. + List all distinct filters. - Handles requests: - GET /executions/views/filters[?types=action,rule] + Handles requests: + GET /executions/views/filters[?types=action,rule] - :param types: Comma delimited string of filter types to output. - :type types: ``str`` + :param types: Comma delimited string of filter types to output. + :type types: ``str`` """ filters = {} for name, field in six.iteritems(SUPPORTED_FILTERS): if name not in IGNORE_FILTERS and (not types or name in types): if name not in FILTERS_WITH_VALID_NULL_VALUES: - query = {field.replace('.', '__'): {'$ne': None}} + query = {field.replace(".", "__"): {"$ne": None}} else: query = {} diff --git a/st2api/st2api/controllers/v1/inquiries.py b/st2api/st2api/controllers/v1/inquiries.py index a8920769174..fb3bf2e3f09 100644 --- a/st2api/st2api/controllers/v1/inquiries.py +++ b/st2api/st2api/controllers/v1/inquiries.py @@ -34,13 +34,11 @@ from st2common.services import inquiry as inquiry_service -__all__ = [ - 'InquiriesController' -] +__all__ = ["InquiriesController"] LOG = logging.getLogger(__name__) -INQUIRY_RUNNER = 'inquirer' +INQUIRY_RUNNER = "inquirer" class InquiriesController(ResourceController): @@ -55,12 +53,18 @@ class InquiriesController(ResourceController): model = inqy_api_models.InquiryAPI access = ex_db_access.ActionExecution - def get_all(self, exclude_attributes=None, include_attributes=None, requester_user=None, - limit=None, **raw_filters): + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + requester_user=None, + limit=None, + **raw_filters, + ): """Retrieve multiple Inquiries - Handles requests: - GET /inquiries/ + Handles requests: + GET /inquiries/ """ # NOTE: This controller retrieves execution objects and returns a new model composed of @@ -70,13 +74,13 @@ def get_all(self, exclude_attributes=None, include_attributes=None, requester_us # filtering before returning the response. raw_inquiries = super(InquiriesController, self)._get_all( exclude_fields=[], - include_fields=['id', 'result'], + include_fields=["id", "result"], limit=limit, raw_filters={ - 'status': action_constants.LIVEACTION_STATUS_PENDING, - 'runner': INQUIRY_RUNNER + "status": action_constants.LIVEACTION_STATUS_PENDING, + "runner": INQUIRY_RUNNER, }, - requester_user=requester_user + requester_user=requester_user, ) # Since "model" is set to InquiryAPI (for good reasons), _get_all returns a list of @@ -90,18 +94,18 @@ def get_all(self, exclude_attributes=None, include_attributes=None, requester_us # Repackage into Response with correct headers resp = api_router.Response(json=inquiries) - resp.headers['X-Total-Count'] = raw_inquiries.headers['X-Total-Count'] + resp.headers["X-Total-Count"] = raw_inquiries.headers["X-Total-Count"] if limit: - resp.headers['X-Limit'] = str(limit) + resp.headers["X-Limit"] = str(limit) return resp def get_one(self, inquiry_id, requester_user=None): """Retrieve a single Inquiry - Handles requests: - GET /inquiries/ + Handles requests: + GET /inquiries/ """ # Retrieve the inquiry by id. @@ -110,7 +114,7 @@ def get_one(self, inquiry_id, requester_user=None): inquiry = self._get_one_by_id( id=inquiry_id, requester_user=requester_user, - permission_type=rbac_types.PermissionType.INQUIRY_VIEW + permission_type=rbac_types.PermissionType.INQUIRY_VIEW, ) except db_exceptions.StackStormDBObjectNotFoundError as e: LOG.exception('Unable to identify inquiry with id "%s".' % inquiry_id) @@ -132,15 +136,18 @@ def get_one(self, inquiry_id, requester_user=None): def put(self, inquiry_id, response_data, requester_user): """Provide response data to an Inquiry - In general, provided the response data validates against the provided - schema, and the user has the appropriate permissions to respond, - this will set the Inquiry execution to a successful status, and resume - the parent workflow. + In general, provided the response data validates against the provided + schema, and the user has the appropriate permissions to respond, + this will set the Inquiry execution to a successful status, and resume + the parent workflow. - Handles requests: - PUT /inquiries/ + Handles requests: + PUT /inquiries/ """ - LOG.debug("Inquiry %s received response payload: %s" % (inquiry_id, response_data.response)) + LOG.debug( + "Inquiry %s received response payload: %s" + % (inquiry_id, response_data.response) + ) # Set requester to system user if not provided. if not requester_user: @@ -151,7 +158,7 @@ def put(self, inquiry_id, response_data, requester_user): inquiry = self._get_one_by_id( id=inquiry_id, requester_user=requester_user, - permission_type=rbac_types.PermissionType.INQUIRY_RESPOND + permission_type=rbac_types.PermissionType.INQUIRY_RESPOND, ) except db_exceptions.StackStormDBObjectNotFoundError as e: LOG.exception('Unable to identify inquiry with id "%s".' % inquiry_id) @@ -186,18 +193,23 @@ def put(self, inquiry_id, response_data, requester_user): # Respond to inquiry and update if there is a partial response. try: - inquiry_service.respond(inquiry, response_data.response, requester=requester_user) + inquiry_service.respond( + inquiry, response_data.response, requester=requester_user + ) except Exception as e: LOG.exception('Fail to update response for inquiry "%s".' % inquiry_id) api_router.abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) - return { - 'id': inquiry_id, - 'response': response_data.response - } + return {"id": inquiry_id, "response": response_data.response} - def _get_one_by_id(self, id, requester_user, permission_type, - exclude_fields=None, from_model_kwargs=None): + def _get_one_by_id( + self, + id, + requester_user, + permission_type, + exclude_fields=None, + from_model_kwargs=None, + ): """Override ResourceController._get_one_by_id to contain scope of Inquiries UID hack :param exclude_fields: A list of object fields to exclude. :type exclude_fields: ``list`` @@ -215,8 +227,11 @@ def _get_one_by_id(self, id, requester_user, permission_type, # "inquiry:". # # TODO (mierdin): All of this should be removed once Inquiries get their own DB model. - if (execution_db and getattr(execution_db, 'runner', None) and - execution_db.runner.get('runner_module') == INQUIRY_RUNNER): + if ( + execution_db + and getattr(execution_db, "runner", None) + and execution_db.runner.get("runner_module") == INQUIRY_RUNNER + ): execution_db.get_uid = get_uid LOG.debug('Checking permission on inquiry "%s".' % id) @@ -226,7 +241,7 @@ def _get_one_by_id(self, id, requester_user, permission_type, rbac_utils.assert_user_has_resource_db_permission( user_db=requester_user, resource_db=execution_db, - permission_type=permission_type + permission_type=permission_type, ) from_model_kwargs = from_model_kwargs or {} @@ -237,9 +252,8 @@ def _get_one_by_id(self, id, requester_user, permission_type, def get_uid(): - """Inquiry UID hack for RBAC - """ - return 'inquiry' + """Inquiry UID hack for RBAC""" + return "inquiry" inquiries_controller = InquiriesController() diff --git a/st2api/st2api/controllers/v1/keyvalue.py b/st2api/st2api/controllers/v1/keyvalue.py index eab8cb025a4..2bd8449e248 100644 --- a/st2api/st2api/controllers/v1/keyvalue.py +++ b/st2api/st2api/controllers/v1/keyvalue.py @@ -24,7 +24,10 @@ from st2common.constants.keyvalue import ALL_SCOPE, FULL_SYSTEM_SCOPE, SYSTEM_SCOPE from st2common.constants.keyvalue import FULL_USER_SCOPE, USER_SCOPE, ALLOWED_SCOPES from st2common.exceptions.db import StackStormDBObjectNotFoundError -from st2common.exceptions.keyvalue import CryptoKeyNotSetupException, InvalidScopeException +from st2common.exceptions.keyvalue import ( + CryptoKeyNotSetupException, + InvalidScopeException, +) from st2common.models.api.keyvalue import KeyValuePairAPI from st2common.models.db.auth import UserDB from st2common.persistence.keyvalue import KeyValuePair @@ -40,9 +43,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'KeyValuePairController' -] +__all__ = ["KeyValuePairController"] class KeyValuePairController(ResourceController): @@ -52,22 +53,21 @@ class KeyValuePairController(ResourceController): model = KeyValuePairAPI access = KeyValuePair - supported_filters = { - 'prefix': 'name__startswith', - 'scope': 'scope' - } + supported_filters = {"prefix": "name__startswith", "scope": "scope"} def __init__(self): super(KeyValuePairController, self).__init__() self._coordinator = coordination.get_coordinator() self.get_one_db_method = self._get_by_name - def get_one(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decrypt=False): + def get_one( + self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decrypt=False + ): """ - List key by name. + List key by name. - Handle: - GET /keys/key1 + Handle: + GET /keys/key1 """ if not scope: # Default to system scope @@ -84,8 +84,9 @@ def get_one(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decr self._validate_scope(scope=scope) # User needs to be either admin or requesting item for itself - self._validate_decrypt_query_parameter(decrypt=decrypt, scope=scope, - requester_user=requester_user) + self._validate_decrypt_query_parameter( + decrypt=decrypt, scope=scope, requester_user=requester_user + ) user_query_param_filter = bool(user) @@ -95,45 +96,56 @@ def get_one(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None, decr rbac_utils = get_rbac_backend().get_utils_class() # Validate that the authenticated user is admin if user query param is provided - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user, - require_rbac=True) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user, require_rbac=True + ) # Additional guard to ensure there is no information leakage across users is_admin = rbac_utils.user_is_admin(user_db=requester_user) if is_admin and user_query_param_filter: # Retrieve values scoped to the provided user - user_scope_prefix = get_key_reference(name=name, scope=USER_SCOPE, user=user) + user_scope_prefix = get_key_reference( + name=name, scope=USER_SCOPE, user=user + ) else: # RBAC not enabled or user is not an admin, retrieve user scoped values for the # current user - user_scope_prefix = get_key_reference(name=name, scope=USER_SCOPE, - user=current_user) + user_scope_prefix = get_key_reference( + name=name, scope=USER_SCOPE, user=current_user + ) if scope == FULL_USER_SCOPE: key_ref = user_scope_prefix elif scope == FULL_SYSTEM_SCOPE: key_ref = get_key_reference(scope=FULL_SYSTEM_SCOPE, name=name, user=user) else: - raise ValueError('Invalid scope: %s' % (scope)) + raise ValueError("Invalid scope: %s" % (scope)) - from_model_kwargs = {'mask_secrets': not decrypt} + from_model_kwargs = {"mask_secrets": not decrypt} kvp_api = self._get_one_by_scope_and_name( - name=key_ref, - scope=scope, - from_model_kwargs=from_model_kwargs + name=key_ref, scope=scope, from_model_kwargs=from_model_kwargs ) return kvp_api - def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=None, - decrypt=False, sort=None, offset=0, limit=None, **raw_filters): + def get_all( + self, + requester_user, + prefix=None, + scope=FULL_SYSTEM_SCOPE, + user=None, + decrypt=False, + sort=None, + offset=0, + limit=None, + **raw_filters, + ): """ - List all keys. + List all keys. - Handles requests: - GET /keys/ + Handles requests: + GET /keys/ """ if not scope: # Default to system scope @@ -152,8 +164,9 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non self._validate_all_scope(scope=scope, requester_user=requester_user) # User needs to be either admin or requesting items for themselves - self._validate_decrypt_query_parameter(decrypt=decrypt, scope=scope, - requester_user=requester_user) + self._validate_decrypt_query_parameter( + decrypt=decrypt, scope=scope, requester_user=requester_user + ) user_query_param_filter = bool(user) @@ -163,15 +176,15 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non rbac_utils = get_rbac_backend().get_utils_class() # Validate that the authenticated user is admin if user query param is provided - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user, - require_rbac=True) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user, require_rbac=True + ) - from_model_kwargs = {'mask_secrets': not decrypt} + from_model_kwargs = {"mask_secrets": not decrypt} if scope and scope not in ALL_SCOPE: self._validate_scope(scope=scope) - raw_filters['scope'] = scope + raw_filters["scope"] = scope # Set prefix which will be used for user-scoped items. # NOTE: It's very important raw_filters['prefix'] is set when requesting user scoped items @@ -180,47 +193,52 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non if is_admin and user_query_param_filter: # Retrieve values scoped to the provided user - user_scope_prefix = get_key_reference(name=prefix or '', scope=USER_SCOPE, user=user) + user_scope_prefix = get_key_reference( + name=prefix or "", scope=USER_SCOPE, user=user + ) else: # RBAC not enabled or user is not an admin, retrieve user scoped values for the # current user - user_scope_prefix = get_key_reference(name=prefix or '', scope=USER_SCOPE, - user=current_user) + user_scope_prefix = get_key_reference( + name=prefix or "", scope=USER_SCOPE, user=current_user + ) if scope == ALL_SCOPE: # Special case for ALL_SCOPE # 1. Retrieve system scoped values - raw_filters['scope'] = FULL_SYSTEM_SCOPE - raw_filters['prefix'] = prefix + raw_filters["scope"] = FULL_SYSTEM_SCOPE + raw_filters["prefix"] = prefix - assert 'scope' in raw_filters + assert "scope" in raw_filters kvp_apis_system = super(KeyValuePairController, self)._get_all( from_model_kwargs=from_model_kwargs, sort=sort, offset=offset, limit=limit, raw_filters=raw_filters, - requester_user=requester_user) + requester_user=requester_user, + ) # 2. Retrieve user scoped items for current user or for all the users (depending if the # authenticated user is admin and if ?user is provided) - raw_filters['scope'] = FULL_USER_SCOPE + raw_filters["scope"] = FULL_USER_SCOPE if cfg.CONF.rbac.enable and is_admin and not user_query_param_filter: # Admin user retrieving user-scoped items for all the users - raw_filters['prefix'] = prefix or '' + raw_filters["prefix"] = prefix or "" else: - raw_filters['prefix'] = user_scope_prefix + raw_filters["prefix"] = user_scope_prefix - assert 'scope' in raw_filters - assert 'prefix' in raw_filters + assert "scope" in raw_filters + assert "prefix" in raw_filters kvp_apis_user = super(KeyValuePairController, self)._get_all( from_model_kwargs=from_model_kwargs, sort=sort, offset=offset, limit=limit, raw_filters=raw_filters, - requester_user=requester_user) + requester_user=requester_user, + ) # Combine the result kvp_apis = [] @@ -228,31 +246,33 @@ def get_all(self, requester_user, prefix=None, scope=FULL_SYSTEM_SCOPE, user=Non kvp_apis.extend(kvp_apis_user.json or []) elif scope in [USER_SCOPE, FULL_USER_SCOPE]: # Make sure we only returned values scoped to current user - prefix = get_key_reference(name=prefix or '', scope=scope, user=user) - raw_filters['prefix'] = user_scope_prefix + prefix = get_key_reference(name=prefix or "", scope=scope, user=user) + raw_filters["prefix"] = user_scope_prefix - assert 'scope' in raw_filters - assert 'prefix' in raw_filters + assert "scope" in raw_filters + assert "prefix" in raw_filters kvp_apis = super(KeyValuePairController, self)._get_all( from_model_kwargs=from_model_kwargs, sort=sort, offset=offset, limit=limit, raw_filters=raw_filters, - requester_user=requester_user) + requester_user=requester_user, + ) elif scope in [SYSTEM_SCOPE, FULL_SYSTEM_SCOPE]: - raw_filters['prefix'] = prefix + raw_filters["prefix"] = prefix - assert 'scope' in raw_filters + assert "scope" in raw_filters kvp_apis = super(KeyValuePairController, self)._get_all( from_model_kwargs=from_model_kwargs, sort=sort, offset=offset, limit=limit, raw_filters=raw_filters, - requester_user=requester_user) + requester_user=requester_user, + ) else: - raise ValueError('Invalid scope: %s' % (scope)) + raise ValueError("Invalid scope: %s" % (scope)) return kvp_apis @@ -266,42 +286,42 @@ def put(self, kvp, name, requester_user, scope=FULL_SYSTEM_SCOPE): if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) - scope = getattr(kvp, 'scope', scope) + scope = getattr(kvp, "scope", scope) scope = get_datastore_full_scope(scope) self._validate_scope(scope=scope) - user = getattr(kvp, 'user', requester_user.name) or requester_user.name + user = getattr(kvp, "user", requester_user.name) or requester_user.name # Validate that the authenticated user is admin if user query param is provided rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user, - require_rbac=True) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user, require_rbac=True + ) # Validate that encrypted option can only be used by admins - encrypted = getattr(kvp, 'encrypted', False) - self._validate_encrypted_query_parameter(encrypted=encrypted, scope=scope, - requester_user=requester_user) + encrypted = getattr(kvp, "encrypted", False) + self._validate_encrypted_query_parameter( + encrypted=encrypted, scope=scope, requester_user=requester_user + ) key_ref = get_key_reference(scope=scope, name=name, user=user) lock_name = self._get_lock_name_for_key(name=key_ref, scope=scope) - LOG.debug('PUT scope: %s, name: %s', scope, name) + LOG.debug("PUT scope: %s, name: %s", scope, name) # TODO: Custom permission check since the key doesn't need to exist here # Note: We use lock to avoid a race with self._coordinator.get_lock(lock_name): try: existing_kvp_api = self._get_one_by_scope_and_name( - scope=scope, - name=key_ref + scope=scope, name=key_ref ) except StackStormDBObjectNotFoundError: existing_kvp_api = None # st2client sends invalid id when initially setting a key so we ignore those - id_ = kvp.__dict__.get('id', None) + id_ = kvp.__dict__.get("id", None) if not existing_kvp_api and id_ and not bson.ObjectId.is_valid(id_): - del kvp.__dict__['id'] + del kvp.__dict__["id"] kvp.name = key_ref kvp.scope = scope @@ -314,7 +334,7 @@ def put(self, kvp, name, requester_user, scope=FULL_SYSTEM_SCOPE): kvp_db = KeyValuePair.add_or_update(kvp_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for key value data=%s', kvp) + LOG.exception("Validation failed for key value data=%s", kvp) abort(http_client.BAD_REQUEST, six.text_type(e)) return except CryptoKeyNotSetupException as e: @@ -325,18 +345,18 @@ def put(self, kvp, name, requester_user, scope=FULL_SYSTEM_SCOPE): LOG.exception(six.text_type(e)) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'kvp_db': kvp_db} - LOG.audit('KeyValuePair updated. KeyValuePair.id=%s' % (kvp_db.id), extra=extra) + extra = {"kvp_db": kvp_db} + LOG.audit("KeyValuePair updated. KeyValuePair.id=%s" % (kvp_db.id), extra=extra) kvp_api = KeyValuePairAPI.from_model(kvp_db) return kvp_api def delete(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None): """ - Delete the key value pair. + Delete the key value pair. - Handles requests: - DELETE /keys/1 + Handles requests: + DELETE /keys/1 """ if not scope: scope = FULL_SYSTEM_SCOPE @@ -351,37 +371,42 @@ def delete(self, name, requester_user, scope=FULL_SYSTEM_SCOPE, user=None): # Validate that the authenticated user is admin if user query param is provided rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user, - require_rbac=True) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user, require_rbac=True + ) key_ref = get_key_reference(scope=scope, name=name, user=user) lock_name = self._get_lock_name_for_key(name=key_ref, scope=scope) # Note: We use lock to avoid a race with self._coordinator.get_lock(lock_name): - from_model_kwargs = {'mask_secrets': True} + from_model_kwargs = {"mask_secrets": True} kvp_api = self._get_one_by_scope_and_name( - name=key_ref, - scope=scope, - from_model_kwargs=from_model_kwargs + name=key_ref, scope=scope, from_model_kwargs=from_model_kwargs ) kvp_db = KeyValuePairAPI.to_model(kvp_api) - LOG.debug('DELETE /keys/ lookup with scope=%s name=%s found object: %s', - scope, name, kvp_db) + LOG.debug( + "DELETE /keys/ lookup with scope=%s name=%s found object: %s", + scope, + name, + kvp_db, + ) try: KeyValuePair.delete(kvp_db) except Exception as e: - LOG.exception('Database delete encountered exception during ' - 'delete of name="%s". ', name) + LOG.exception( + "Database delete encountered exception during " + 'delete of name="%s". ', + name, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return - extra = {'kvp_db': kvp_db} - LOG.audit('KeyValuePair deleted. KeyValuePair.id=%s' % (kvp_db.id), extra=extra) + extra = {"kvp_db": kvp_db} + LOG.audit("KeyValuePair deleted. KeyValuePair.id=%s" % (kvp_db.id), extra=extra) return Response(status=http_client.NO_CONTENT) @@ -392,7 +417,7 @@ def _get_lock_name_for_key(self, name, scope=FULL_SYSTEM_SCOPE): :param name: Datastore item name (PK). :type name: ``str`` """ - lock_name = six.b('kvp-crud-%s.%s' % (scope, name)) + lock_name = six.b("kvp-crud-%s.%s" % (scope, name)) return lock_name def _validate_all_scope(self, scope, requester_user): @@ -400,7 +425,7 @@ def _validate_all_scope(self, scope, requester_user): Validate that "all" scope can only be provided by admins on RBAC installations. """ scope = get_datastore_full_scope(scope) - is_all_scope = (scope == ALL_SCOPE) + is_all_scope = scope == ALL_SCOPE rbac_utils = get_rbac_backend().get_utils_class() is_admin = rbac_utils.user_is_admin(user_db=requester_user) @@ -415,22 +440,25 @@ def _validate_decrypt_query_parameter(self, decrypt, scope, requester_user): """ rbac_utils = get_rbac_backend().get_utils_class() is_admin = rbac_utils.user_is_admin(user_db=requester_user) - is_user_scope = (scope == USER_SCOPE or scope == FULL_USER_SCOPE) + is_user_scope = scope == USER_SCOPE or scope == FULL_USER_SCOPE if decrypt and (not is_user_scope and not is_admin): - msg = 'Decrypt option requires administrator access' + msg = "Decrypt option requires administrator access" raise AccessDeniedError(message=msg, user_db=requester_user) def _validate_encrypted_query_parameter(self, encrypted, scope, requester_user): rbac_utils = get_rbac_backend().get_utils_class() is_admin = rbac_utils.user_is_admin(user_db=requester_user) if encrypted and not is_admin: - msg = 'Pre-encrypted option requires administrator access' + msg = "Pre-encrypted option requires administrator access" raise AccessDeniedError(message=msg, user_db=requester_user) def _validate_scope(self, scope): if scope not in ALLOWED_SCOPES: - msg = 'Scope %s is not in allowed scopes list: %s.' % (scope, ALLOWED_SCOPES) + msg = "Scope %s is not in allowed scopes list: %s." % ( + scope, + ALLOWED_SCOPES, + ) raise ValueError(msg) diff --git a/st2api/st2api/controllers/v1/pack_config_schemas.py b/st2api/st2api/controllers/v1/pack_config_schemas.py index 551573e12eb..933a7ab5001 100644 --- a/st2api/st2api/controllers/v1/pack_config_schemas.py +++ b/st2api/st2api/controllers/v1/pack_config_schemas.py @@ -23,9 +23,7 @@ http_client = six.moves.http_client -__all__ = [ - 'PackConfigSchemasController' -] +__all__ = ["PackConfigSchemasController"] class PackConfigSchemasController(ResourceController): @@ -40,7 +38,9 @@ def __init__(self): # this case, RBAC is checked on the parent PackDB object self.get_one_db_method = packs_service.get_pack_by_ref - def get_all(self, sort=None, offset=0, limit=None, requester_user=None, **raw_filters): + def get_all( + self, sort=None, offset=0, limit=None, requester_user=None, **raw_filters + ): """ Retrieve config schema for all the packs. @@ -48,11 +48,13 @@ def get_all(self, sort=None, offset=0, limit=None, requester_user=None, **raw_fi GET /config_schema/ """ - return super(PackConfigSchemasController, self)._get_all(sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + return super(PackConfigSchemasController, self)._get_all( + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, pack_ref, requester_user): """ @@ -61,7 +63,9 @@ def get_one(self, pack_ref, requester_user): Handles requests: GET /config_schema/ """ - packs_controller._get_one_by_ref_or_id(ref_or_id=pack_ref, requester_user=requester_user) + packs_controller._get_one_by_ref_or_id( + ref_or_id=pack_ref, requester_user=requester_user + ) return self._get_one_by_pack_ref(pack_ref=pack_ref) diff --git a/st2api/st2api/controllers/v1/pack_configs.py b/st2api/st2api/controllers/v1/pack_configs.py index 6eb18c7a342..4123a3cb227 100644 --- a/st2api/st2api/controllers/v1/pack_configs.py +++ b/st2api/st2api/controllers/v1/pack_configs.py @@ -35,9 +35,7 @@ http_client = six.moves.http_client -__all__ = [ - 'PackConfigsController' -] +__all__ = ["PackConfigsController"] LOG = logging.getLogger(__name__) @@ -54,8 +52,15 @@ def __init__(self): # this case, RBAC is checked on the parent PackDB object self.get_one_db_method = packs_service.get_pack_by_ref - def get_all(self, requester_user, sort=None, offset=0, limit=None, show_secrets=False, - **raw_filters): + def get_all( + self, + requester_user, + sort=None, + offset=0, + limit=None, + show_secrets=False, + **raw_filters, + ): """ Retrieve configs for all the packs. @@ -63,14 +68,18 @@ def get_all(self, requester_user, sort=None, offset=0, limit=None, show_secrets= GET /configs/ """ from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } - return super(PackConfigsController, self)._get_all(sort=sort, - offset=offset, - limit=limit, - from_model_kwargs=from_model_kwargs, - raw_filters=raw_filters, - requester_user=requester_user) + return super(PackConfigsController, self)._get_all( + sort=sort, + offset=offset, + limit=limit, + from_model_kwargs=from_model_kwargs, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, pack_ref, requester_user, show_secrets=False): """ @@ -80,7 +89,9 @@ def get_one(self, pack_ref, requester_user, show_secrets=False): GET /configs/ """ from_model_kwargs = { - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ) } try: instance = packs_service.get_pack_by_ref(pack_ref=pack_ref) @@ -89,18 +100,22 @@ def get_one(self, pack_ref, requester_user, show_secrets=False): abort(http_client.NOT_FOUND, msg) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=PermissionType.PACK_VIEW) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=PermissionType.PACK_VIEW, + ) - return self._get_one_by_pack_ref(pack_ref=pack_ref, from_model_kwargs=from_model_kwargs) + return self._get_one_by_pack_ref( + pack_ref=pack_ref, from_model_kwargs=from_model_kwargs + ) def put(self, pack_config_content, pack_ref, requester_user, show_secrets=False): """ - Create a new config for a pack. + Create a new config for a pack. - Handles requests: - POST /configs/ + Handles requests: + POST /configs/ """ try: @@ -121,9 +136,9 @@ def put(self, pack_config_content, pack_ref, requester_user, show_secrets=False) def _dump_config_to_disk(self, config_api): config_content = yaml.safe_dump(config_api.values, default_flow_style=False) - configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/') - config_path = os.path.join(configs_path, '%s.yaml' % config_api.pack) - with open(config_path, 'w') as f: + configs_path = os.path.join(cfg.CONF.system.base_path, "configs/") + config_path = os.path.join(configs_path, "%s.yaml" % config_api.pack) + with open(config_path, "w") as f: f.write(config_content) diff --git a/st2api/st2api/controllers/v1/pack_views.py b/st2api/st2api/controllers/v1/pack_views.py index 5e8b310c332..4fd6f9dd3a4 100644 --- a/st2api/st2api/controllers/v1/pack_views.py +++ b/st2api/st2api/controllers/v1/pack_views.py @@ -33,10 +33,7 @@ http_client = six.moves.http_client -__all__ = [ - 'FilesController', - 'FileController' -] +__all__ = ["FilesController", "FileController"] http_client = six.moves.http_client @@ -46,12 +43,10 @@ # Maximum file size in bytes. If the file on disk is larger then this value, we don't include it # in the response. This prevents DDoS / exhaustion attacks. -MAX_FILE_SIZE = (500 * 1000) +MAX_FILE_SIZE = 500 * 1000 # File paths in the file controller for which RBAC checks are not performed -WHITELISTED_FILE_PATHS = [ - 'icon.png' -] +WHITELISTED_FILE_PATHS = ["icon.png"] class BaseFileController(BasePacksController): @@ -76,7 +71,7 @@ def _get_file_stats(self, file_path): return file_stats.st_size, file_stats.st_mtime def _get_file_content(self, file_path): - with codecs.open(file_path, 'rb') as fp: + with codecs.open(file_path, "rb") as fp: content = fp.read() return content @@ -105,17 +100,19 @@ def __init__(self): def get_one(self, ref_or_id, requester_user): """ - Outputs the content of all the files inside the pack. + Outputs the content of all the files inside the pack. - Handles requests: - GET /packs/views/files/ + Handles requests: + GET /packs/views/files/ """ pack_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=pack_db, - permission_type=PermissionType.PACK_VIEW) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=pack_db, + permission_type=PermissionType.PACK_VIEW, + ) if not pack_db: msg = 'Pack with ref_or_id "%s" does not exist' % (ref_or_id) @@ -126,15 +123,19 @@ def get_one(self, ref_or_id, requester_user): result = [] for file_path in pack_files: - normalized_file_path = get_pack_file_abs_path(pack_ref=pack_ref, file_path=file_path) + normalized_file_path = get_pack_file_abs_path( + pack_ref=pack_ref, file_path=file_path + ) if not normalized_file_path or not os.path.isfile(normalized_file_path): # Ignore references to files which don't exist on disk continue file_size = self._get_file_size(file_path=normalized_file_path) if file_size is not None and file_size > MAX_FILE_SIZE: - LOG.debug('Skipping file "%s" which size exceeds max file size (%s bytes)' % - (normalized_file_path, MAX_FILE_SIZE)) + LOG.debug( + 'Skipping file "%s" which size exceeds max file size (%s bytes)' + % (normalized_file_path, MAX_FILE_SIZE) + ) continue content = self._get_file_content(file_path=normalized_file_path) @@ -144,10 +145,7 @@ def get_one(self, ref_or_id, requester_user): LOG.debug('Skipping binary file "%s"' % (normalized_file_path)) continue - item = { - 'file_path': file_path, - 'content': content - } + item = {"file_path": file_path, "content": content} result.append(item) return result @@ -173,13 +171,19 @@ class FileController(BaseFileController): Controller which allows user to retrieve content of a specific file in a pack. """ - def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None, - if_modified_since=None): + def get_one( + self, + ref_or_id, + file_path, + requester_user, + if_none_match=None, + if_modified_since=None, + ): """ - Outputs the content of a specific file in a pack. + Outputs the content of a specific file in a pack. - Handles requests: - GET /packs/views/file// + Handles requests: + GET /packs/views/file// """ pack_db = self._get_by_ref_or_id(ref_or_id=ref_or_id) @@ -188,7 +192,7 @@ def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None, raise StackStormDBObjectNotFoundError(msg) if not file_path: - raise ValueError('Missing file path') + raise ValueError("Missing file path") pack_ref = pack_db.ref @@ -196,11 +200,15 @@ def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None, permission_type = PermissionType.PACK_VIEW if file_path not in WHITELISTED_FILE_PATHS: rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=pack_db, - permission_type=permission_type) - - normalized_file_path = get_pack_file_abs_path(pack_ref=pack_ref, file_path=file_path) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=pack_db, + permission_type=permission_type, + ) + + normalized_file_path = get_pack_file_abs_path( + pack_ref=pack_ref, file_path=file_path + ) if not normalized_file_path or not os.path.isfile(normalized_file_path): # Ignore references to files which don't exist on disk raise StackStormDBObjectNotFoundError('File "%s" not found' % (file_path)) @@ -209,24 +217,28 @@ def get_one(self, ref_or_id, file_path, requester_user, if_none_match=None, response = Response() - if not self._is_file_changed(file_mtime, - if_none_match=if_none_match, - if_modified_since=if_modified_since): + if not self._is_file_changed( + file_mtime, if_none_match=if_none_match, if_modified_since=if_modified_since + ): response.status = http_client.NOT_MODIFIED else: if file_size is not None and file_size > MAX_FILE_SIZE: - msg = ('File %s exceeds maximum allowed file size (%s bytes)' % - (file_path, MAX_FILE_SIZE)) + msg = "File %s exceeds maximum allowed file size (%s bytes)" % ( + file_path, + MAX_FILE_SIZE, + ) raise ValueError(msg) - content_type = mimetypes.guess_type(normalized_file_path)[0] or \ - 'application/octet-stream' + content_type = ( + mimetypes.guess_type(normalized_file_path)[0] + or "application/octet-stream" + ) - response.headers['Content-Type'] = content_type + response.headers["Content-Type"] = content_type response.body = self._get_file_content(file_path=normalized_file_path) - response.headers['Last-Modified'] = format_date_time(file_mtime) - response.headers['ETag'] = repr(file_mtime) + response.headers["Last-Modified"] = format_date_time(file_mtime) + response.headers["ETag"] = repr(file_mtime) return response diff --git a/st2api/st2api/controllers/v1/packs.py b/st2api/st2api/controllers/v1/packs.py index 6193a3f01fb..75da16e5e5e 100644 --- a/st2api/st2api/controllers/v1/packs.py +++ b/st2api/st2api/controllers/v1/packs.py @@ -52,115 +52,119 @@ http_client = six.moves.http_client -__all__ = [ - 'PacksController', - 'BasePacksController', - 'ENTITIES' -] +__all__ = ["PacksController", "BasePacksController", "ENTITIES"] LOG = logging.getLogger(__name__) # Note: The order those are defined it's important so they are registered in # the same order as they are in st2-register-content. # We also need to use list of tuples to preserve the order. -ENTITIES = OrderedDict([ - ('trigger', (TriggersRegistrar, 'triggers')), - ('sensor', (SensorsRegistrar, 'sensors')), - ('action', (ActionsRegistrar, 'actions')), - ('rule', (RulesRegistrar, 'rules')), - ('alias', (AliasesRegistrar, 'aliases')), - ('policy', (PolicyRegistrar, 'policies')), - ('config', (ConfigsRegistrar, 'configs')) -]) +ENTITIES = OrderedDict( + [ + ("trigger", (TriggersRegistrar, "triggers")), + ("sensor", (SensorsRegistrar, "sensors")), + ("action", (ActionsRegistrar, "actions")), + ("rule", (RulesRegistrar, "rules")), + ("alias", (AliasesRegistrar, "aliases")), + ("policy", (PolicyRegistrar, "policies")), + ("config", (ConfigsRegistrar, "configs")), + ] +) def _get_proxy_config(): - LOG.debug('Loading proxy configuration from env variables %s.', os.environ) - http_proxy = os.environ.get('http_proxy', None) - https_proxy = os.environ.get('https_proxy', None) - no_proxy = os.environ.get('no_proxy', None) - proxy_ca_bundle_path = os.environ.get('proxy_ca_bundle_path', None) + LOG.debug("Loading proxy configuration from env variables %s.", os.environ) + http_proxy = os.environ.get("http_proxy", None) + https_proxy = os.environ.get("https_proxy", None) + no_proxy = os.environ.get("no_proxy", None) + proxy_ca_bundle_path = os.environ.get("proxy_ca_bundle_path", None) proxy_config = { - 'http_proxy': http_proxy, - 'https_proxy': https_proxy, - 'proxy_ca_bundle_path': proxy_ca_bundle_path, - 'no_proxy': no_proxy + "http_proxy": http_proxy, + "https_proxy": https_proxy, + "proxy_ca_bundle_path": proxy_ca_bundle_path, + "no_proxy": no_proxy, } - LOG.debug('Proxy configuration: %s', proxy_config) + LOG.debug("Proxy configuration: %s", proxy_config) return proxy_config class PackInstallController(ActionExecutionsControllerMixin): - def post(self, pack_install_request, requester_user=None): parameters = { - 'packs': pack_install_request.packs, + "packs": pack_install_request.packs, } if pack_install_request.force: - parameters['force'] = True + parameters["force"] = True if pack_install_request.skip_dependencies: - parameters['skip_dependencies'] = True + parameters["skip_dependencies"] = True if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) - new_liveaction_api = LiveActionCreateAPI(action='packs.install', - parameters=parameters, - user=requester_user.name) + new_liveaction_api = LiveActionCreateAPI( + action="packs.install", parameters=parameters, user=requester_user.name + ) - execution_resp = self._handle_schedule_execution(liveaction_api=new_liveaction_api, - requester_user=requester_user) + execution_resp = self._handle_schedule_execution( + liveaction_api=new_liveaction_api, requester_user=requester_user + ) - exec_id = PackAsyncAPI(execution_id=execution_resp.json['id']) + exec_id = PackAsyncAPI(execution_id=execution_resp.json["id"]) return Response(json=exec_id, status=http_client.ACCEPTED) class PackUninstallController(ActionExecutionsControllerMixin): - def post(self, pack_uninstall_request, ref_or_id=None, requester_user=None): if ref_or_id: - parameters = { - 'packs': [ref_or_id] - } + parameters = {"packs": [ref_or_id]} else: - parameters = { - 'packs': pack_uninstall_request.packs - } + parameters = {"packs": pack_uninstall_request.packs} if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) - new_liveaction_api = LiveActionCreateAPI(action='packs.uninstall', - parameters=parameters, - user=requester_user.name) + new_liveaction_api = LiveActionCreateAPI( + action="packs.uninstall", parameters=parameters, user=requester_user.name + ) - execution_resp = self._handle_schedule_execution(liveaction_api=new_liveaction_api, - requester_user=requester_user) + execution_resp = self._handle_schedule_execution( + liveaction_api=new_liveaction_api, requester_user=requester_user + ) - exec_id = PackAsyncAPI(execution_id=execution_resp.json['id']) + exec_id = PackAsyncAPI(execution_id=execution_resp.json["id"]) return Response(json=exec_id, status=http_client.ACCEPTED) class PackRegisterController(object): - CONTENT_TYPES = ['runner', 'action', 'trigger', 'sensor', 'rule', - 'rule_type', 'alias', 'policy_type', 'policy', 'config'] + CONTENT_TYPES = [ + "runner", + "action", + "trigger", + "sensor", + "rule", + "rule_type", + "alias", + "policy_type", + "policy", + "config", + ] def post(self, pack_register_request): - if pack_register_request and hasattr(pack_register_request, 'types'): + if pack_register_request and hasattr(pack_register_request, "types"): types = pack_register_request.types - if 'all' in types: + if "all" in types: types = PackRegisterController.CONTENT_TYPES else: types = PackRegisterController.CONTENT_TYPES - if pack_register_request and hasattr(pack_register_request, 'packs'): + if pack_register_request and hasattr(pack_register_request, "packs"): packs = list(set(pack_register_request.packs)) else: packs = None @@ -168,64 +172,80 @@ def post(self, pack_register_request): result = defaultdict(int) # Register depended resources (actions depend on runners, rules depend on rule types, etc) - if ('runner' in types or 'runners' in types) or ('action' in types or 'actions' in types): - result['runners'] = runners_registrar.register_runners(experimental=True) - if ('rule_type' in types or 'rule_types' in types) or \ - ('rule' in types or 'rules' in types): - result['rule_types'] = rule_types_registrar.register_rule_types() - if ('policy_type' in types or 'policy_types' in types) or \ - ('policy' in types or 'policies' in types): - result['policy_types'] = policies_registrar.register_policy_types(st2common) + if ("runner" in types or "runners" in types) or ( + "action" in types or "actions" in types + ): + result["runners"] = runners_registrar.register_runners(experimental=True) + if ("rule_type" in types or "rule_types" in types) or ( + "rule" in types or "rules" in types + ): + result["rule_types"] = rule_types_registrar.register_rule_types() + if ("policy_type" in types or "policy_types" in types) or ( + "policy" in types or "policies" in types + ): + result["policy_types"] = policies_registrar.register_policy_types(st2common) use_pack_cache = False - fail_on_failure = getattr(pack_register_request, 'fail_on_failure', True) + fail_on_failure = getattr(pack_register_request, "fail_on_failure", True) for type, (Registrar, name) in six.iteritems(ENTITIES): if type in types or name in types: - registrar = Registrar(use_pack_cache=use_pack_cache, - use_runners_cache=True, - fail_on_failure=fail_on_failure) + registrar = Registrar( + use_pack_cache=use_pack_cache, + use_runners_cache=True, + fail_on_failure=fail_on_failure, + ) if packs: for pack in packs: pack_path = content_utils.get_pack_base_path(pack) try: - registered_count = registrar.register_from_pack(pack_dir=pack_path) + registered_count = registrar.register_from_pack( + pack_dir=pack_path + ) result[name] += registered_count except ValueError as e: # Throw more user-friendly exception if requsted pack doesn't exist - if re.match('Directory ".*?" doesn\'t exist', six.text_type(e)): - msg = 'Pack "%s" not found on disk: %s' % (pack, six.text_type(e)) + if re.match( + 'Directory ".*?" doesn\'t exist', six.text_type(e) + ): + msg = 'Pack "%s" not found on disk: %s' % ( + pack, + six.text_type(e), + ) raise ValueError(msg) raise e else: packs_base_paths = content_utils.get_packs_base_paths() - registered_count = registrar.register_from_packs(base_dirs=packs_base_paths) + registered_count = registrar.register_from_packs( + base_dirs=packs_base_paths + ) result[name] += registered_count return result class PackSearchController(object): - def post(self, pack_search_request): proxy_config = _get_proxy_config() - if hasattr(pack_search_request, 'query'): - packs = packs_service.search_pack_index(pack_search_request.query, - case_sensitive=False, - proxy_config=proxy_config) + if hasattr(pack_search_request, "query"): + packs = packs_service.search_pack_index( + pack_search_request.query, + case_sensitive=False, + proxy_config=proxy_config, + ) return [PackAPI(**pack) for pack in packs] else: - pack = packs_service.get_pack_from_index(pack_search_request.pack, - proxy_config=proxy_config) + pack = packs_service.get_pack_from_index( + pack_search_request.pack, proxy_config=proxy_config + ) return PackAPI(**pack) if pack else [] class IndexHealthController(object): - def get(self): """ Check if all listed indexes are healthy: they should be reachable, @@ -233,7 +253,9 @@ def get(self): """ proxy_config = _get_proxy_config() - _, status = packs_service.fetch_pack_index(allow_empty=True, proxy_config=proxy_config) + _, status = packs_service.fetch_pack_index( + allow_empty=True, proxy_config=proxy_config + ) health = { "indexes": { @@ -249,13 +271,13 @@ def get(self): } for index in status: - if index['error']: - error_count = health['indexes']['errors'].get(index['error'], 0) + 1 - health['indexes']['invalid'] += 1 - health['indexes']['errors'][index['error']] = error_count + if index["error"]: + error_count = health["indexes"]["errors"].get(index["error"], 0) + 1 + health["indexes"]["invalid"] += 1 + health["indexes"]["errors"][index["error"]] = error_count else: - health['indexes']['valid'] += 1 - health['packs']['count'] += index['packs'] + health["indexes"]["valid"] += 1 + health["packs"]["count"] += index["packs"] return health @@ -265,12 +287,16 @@ class BasePacksController(ResourceController): access = Pack def _get_one_by_ref_or_id(self, ref_or_id, requester_user, exclude_fields=None): - instance = self._get_by_ref_or_id(ref_or_id=ref_or_id, exclude_fields=exclude_fields) + instance = self._get_by_ref_or_id( + ref_or_id=ref_or_id, exclude_fields=exclude_fields + ) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=PermissionType.PACK_VIEW) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=PermissionType.PACK_VIEW, + ) if not instance: msg = 'Unable to identify resource with ref_or_id "%s".' % (ref_or_id) @@ -282,7 +308,9 @@ def _get_one_by_ref_or_id(self, ref_or_id, requester_user, exclude_fields=None): return result def _get_by_ref_or_id(self, ref_or_id, exclude_fields=None): - resource_db = self._get_by_id(resource_id=ref_or_id, exclude_fields=exclude_fields) + resource_db = self._get_by_id( + resource_id=ref_or_id, exclude_fields=exclude_fields + ) if not resource_db: # Try ref @@ -302,7 +330,7 @@ def _get_by_ref(self, ref, exclude_fields=None): return resource_db -class PacksIndexController(): +class PacksIndexController: search = PackSearchController() health = IndexHealthController() @@ -311,10 +339,7 @@ def get_all(self): index, status = packs_service.fetch_pack_index(proxy_config=proxy_config) - return { - 'status': status, - 'index': index - } + return {"status": status, "index": index} class PacksController(BasePacksController): @@ -322,14 +347,9 @@ class PacksController(BasePacksController): model = PackAPI access = Pack - supported_filters = { - 'name': 'name', - 'ref': 'ref' - } + supported_filters = {"name": "name", "ref": "ref"} - query_options = { - 'sort': ['ref'] - } + query_options = {"sort": ["ref"]} # Nested controllers install = PackInstallController() @@ -342,18 +362,30 @@ def __init__(self): super(PacksController, self).__init__() self.get_one_db_method = self._get_by_ref_or_id - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return super(PacksController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(PacksController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user): - return self._get_one_by_ref_or_id(ref_or_id=ref_or_id, requester_user=requester_user) + return self._get_one_by_ref_or_id( + ref_or_id=ref_or_id, requester_user=requester_user + ) packs_controller = PacksController() diff --git a/st2api/st2api/controllers/v1/policies.py b/st2api/st2api/controllers/v1/policies.py index 3fc488708bb..aa57b7cf3da 100644 --- a/st2api/st2api/controllers/v1/policies.py +++ b/st2api/st2api/controllers/v1/policies.py @@ -37,54 +37,73 @@ class PolicyTypeController(resource.ResourceController): model = PolicyTypeAPI access = PolicyType - mandatory_include_fields_retrieve = ['id', 'name', 'resource_type'] + mandatory_include_fields_retrieve = ["id", "name", "resource_type"] - supported_filters = { - 'resource_type': 'resource_type' - } + supported_filters = {"resource_type": "resource_type"} - query_options = { - 'sort': ['resource_type', 'name'] - } + query_options = {"sort": ["resource_type", "name"]} def get_one(self, ref_or_id, requester_user): return self._get_one(ref_or_id, requester_user=requester_user) - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return self._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return self._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def _get_one(self, ref_or_id, requester_user): instance = self._get_by_ref_or_id(ref_or_id=ref_or_id) permission_type = PermissionType.POLICY_TYPE_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=instance, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=instance, + permission_type=permission_type, + ) result = self.model.from_model(instance) return result - def _get_all(self, exclude_fields=None, include_fields=None, sort=None, offset=0, limit=None, - query_options=None, from_model_kwargs=None, raw_filters=None, - requester_user=None): - - resp = super(PolicyTypeController, self)._get_all(exclude_fields=exclude_fields, - include_fields=include_fields, - sort=sort, - offset=offset, - limit=limit, - query_options=query_options, - from_model_kwargs=from_model_kwargs, - raw_filters=raw_filters, - requester_user=requester_user) + def _get_all( + self, + exclude_fields=None, + include_fields=None, + sort=None, + offset=0, + limit=None, + query_options=None, + from_model_kwargs=None, + raw_filters=None, + requester_user=None, + ): + + resp = super(PolicyTypeController, self)._get_all( + exclude_fields=exclude_fields, + include_fields=include_fields, + sort=sort, + offset=offset, + limit=limit, + query_options=query_options, + from_model_kwargs=from_model_kwargs, + raw_filters=raw_filters, + requester_user=requester_user, + ) return resp @@ -114,7 +133,9 @@ def _get_by_ref(self, resource_ref): except Exception: return None - resource_db = self.access.query(name=ref.name, resource_type=ref.resource_type).first() + resource_db = self.access.query( + name=ref.name, resource_type=ref.resource_type + ).first() return resource_db @@ -123,77 +144,93 @@ class PolicyController(resource.ContentPackResourceController): access = Policy supported_filters = { - 'pack': 'pack', - 'resource_ref': 'resource_ref', - 'policy_type': 'policy_type' - } - - query_options = { - 'sort': ['pack', 'name'] + "pack": "pack", + "resource_ref": "resource_ref", + "policy_type": "policy_type", } - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return self._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + query_options = {"sort": ["pack", "name"]} + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return self._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user): permission_type = PermissionType.POLICY_VIEW - return self._get_one(ref_or_id, permission_type=permission_type, - requester_user=requester_user) + return self._get_one( + ref_or_id, permission_type=permission_type, requester_user=requester_user + ) def post(self, instance, requester_user): """ - Create a new policy. - Handles requests: - POST /policies/ + Create a new policy. + Handles requests: + POST /policies/ """ permission_type = PermissionType.POLICY_CREATE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user, - resource_api=instance, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_api_permission( + user_db=requester_user, + resource_api=instance, + permission_type=permission_type, + ) - op = 'POST /policies/' + op = "POST /policies/" db_model = self.model.to_model(instance) - LOG.debug('%s verified object: %s', op, db_model) + LOG.debug("%s verified object: %s", op, db_model) db_model = self.access.add_or_update(db_model) - LOG.debug('%s created object: %s', op, db_model) - LOG.audit('Policy created. Policy.id=%s' % (db_model.id), extra={'policy_db': db_model}) + LOG.debug("%s created object: %s", op, db_model) + LOG.audit( + "Policy created. Policy.id=%s" % (db_model.id), + extra={"policy_db": db_model}, + ) exec_result = self.model.from_model(db_model) return Response(json=exec_result, status=http_client.CREATED) def put(self, instance, ref_or_id, requester_user): - op = 'PUT /policies/%s/' % ref_or_id + op = "PUT /policies/%s/" % ref_or_id db_model = self._get_by_ref_or_id(ref_or_id=ref_or_id) - LOG.debug('%s found object: %s', op, db_model) + LOG.debug("%s found object: %s", op, db_model) permission_type = PermissionType.POLICY_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=db_model, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=db_model, + permission_type=permission_type, + ) db_model_id = db_model.id try: validate_not_part_of_system_pack(db_model) except ValueValidationException as e: - LOG.exception('%s unable to update object from system pack.', op) + LOG.exception("%s unable to update object from system pack.", op) abort(http_client.BAD_REQUEST, six.text_type(e)) - if not getattr(instance, 'pack', None): + if not getattr(instance, "pack", None): instance.pack = db_model.pack try: @@ -201,12 +238,15 @@ def put(self, instance, ref_or_id, requester_user): db_model.id = db_model_id db_model = self.access.add_or_update(db_model) except (ValidationError, ValueError) as e: - LOG.exception('%s unable to update object: %s', op, db_model) + LOG.exception("%s unable to update object: %s", op, db_model) abort(http_client.BAD_REQUEST, six.text_type(e)) return - LOG.debug('%s updated object: %s', op, db_model) - LOG.audit('Policy updated. Policy.id=%s' % (db_model.id), extra={'policy_db': db_model}) + LOG.debug("%s updated object: %s", op, db_model) + LOG.audit( + "Policy updated. Policy.id=%s" % (db_model.id), + extra={"policy_db": db_model}, + ) exec_result = self.model.from_model(db_model) @@ -214,38 +254,43 @@ def put(self, instance, ref_or_id, requester_user): def delete(self, ref_or_id, requester_user): """ - Delete a policy. - Handles requests: - POST /policies/1?_method=delete - DELETE /policies/1 - DELETE /policies/mypack.mypolicy + Delete a policy. + Handles requests: + POST /policies/1?_method=delete + DELETE /policies/1 + DELETE /policies/mypack.mypolicy """ - op = 'DELETE /policies/%s/' % ref_or_id + op = "DELETE /policies/%s/" % ref_or_id db_model = self._get_by_ref_or_id(ref_or_id=ref_or_id) - LOG.debug('%s found object: %s', op, db_model) + LOG.debug("%s found object: %s", op, db_model) permission_type = PermissionType.POLICY_DELETE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=db_model, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=db_model, + permission_type=permission_type, + ) try: validate_not_part_of_system_pack(db_model) except ValueValidationException as e: - LOG.exception('%s unable to delete object from system pack.', op) + LOG.exception("%s unable to delete object from system pack.", op) abort(http_client.BAD_REQUEST, six.text_type(e)) try: self.access.delete(db_model) except Exception as e: - LOG.exception('%s unable to delete object: %s', op, db_model) + LOG.exception("%s unable to delete object: %s", op, db_model) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return - LOG.debug('%s deleted object: %s', op, db_model) - LOG.audit('Policy deleted. Policy.id=%s' % (db_model.id), extra={'policy_db': db_model}) + LOG.debug("%s deleted object: %s", op, db_model) + LOG.audit( + "Policy deleted. Policy.id=%s" % (db_model.id), + extra={"policy_db": db_model}, + ) # return None return Response(status=http_client.NO_CONTENT) diff --git a/st2api/st2api/controllers/v1/rbac.py b/st2api/st2api/controllers/v1/rbac.py index 0e8c1d41794..49a552f7dc2 100644 --- a/st2api/st2api/controllers/v1/rbac.py +++ b/st2api/st2api/controllers/v1/rbac.py @@ -23,78 +23,76 @@ from st2common.rbac.backends import get_rbac_backend from st2common.router import exc -__all__ = [ - 'RolesController', - 'RoleAssignmentsController', - 'PermissionTypesController' -] +__all__ = ["RolesController", "RoleAssignmentsController", "PermissionTypesController"] class RolesController(ResourceController): model = RoleAPI access = Role - supported_filters = { - 'name': 'name', - 'system': 'system' - } + supported_filters = {"name": "name", "system": "system"} - query_options = { - 'sort': ['name'] - } + query_options = {"sort": ["name"]} def get_one(self, name_or_id, requester_user): rbac_utils = get_rbac_backend().get_utils_class() rbac_utils.assert_user_is_admin(user_db=requester_user) - return self._get_one_by_name_or_id(name_or_id=name_or_id, - permission_type=None, - requester_user=requester_user) + return self._get_one_by_name_or_id( + name_or_id=name_or_id, permission_type=None, requester_user=requester_user + ) def get_all(self, requester_user, sort=None, offset=0, limit=None, **raw_filters): rbac_utils = get_rbac_backend().get_utils_class() rbac_utils.assert_user_is_admin(user_db=requester_user) - return self._get_all(sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + return self._get_all( + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) class RoleAssignmentsController(ResourceController): """ Meta controller for listing role assignments. """ + model = UserRoleAssignmentAPI access = UserRoleAssignment supported_filters = { - 'user': 'user', - 'role': 'role', - 'source': 'source', - 'remote': 'is_remote' + "user": "user", + "role": "role", + "source": "source", + "remote": "is_remote", } def get_all(self, requester_user, sort=None, offset=0, limit=None, **raw_filters): - user = raw_filters.get('user', None) + user = raw_filters.get("user", None) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_is_admin_or_operating_on_own_resource(user_db=requester_user, - user=user) - - return self._get_all(sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + rbac_utils.assert_user_is_admin_or_operating_on_own_resource( + user_db=requester_user, user=user + ) + + return self._get_all( + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, id, requester_user): - result = self._get_one_by_id(id, - requester_user=requester_user, - permission_type=None) - user = getattr(result, 'user', None) + result = self._get_one_by_id( + id, requester_user=requester_user, permission_type=None + ) + user = getattr(result, "user", None) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_is_admin_or_operating_on_own_resource(user_db=requester_user, - user=user) + rbac_utils.assert_user_is_admin_or_operating_on_own_resource( + user_db=requester_user, user=user + ) return result @@ -106,10 +104,10 @@ class PermissionTypesController(object): def get_all(self, requester_user): """ - List all the available permission types. + List all the available permission types. - Handles requests: - GET /rbac/permission_types + Handles requests: + GET /rbac/permission_types """ rbac_utils = get_rbac_backend().get_utils_class() rbac_utils.assert_user_is_admin(user_db=requester_user) @@ -119,10 +117,10 @@ def get_all(self, requester_user): def get_one(self, resource_type, requester_user): """ - List all the available permission types for a particular resource type. + List all the available permission types for a particular resource type. - Handles requests: - GET /rbac/permission_types/ + Handles requests: + GET /rbac/permission_types/ """ rbac_utils = get_rbac_backend().get_utils_class() rbac_utils.assert_user_is_admin(user_db=requester_user) @@ -131,7 +129,7 @@ def get_one(self, resource_type, requester_user): permission_types = all_permission_types.get(resource_type, None) if permission_types is None: - raise exc.HTTPNotFound('Invalid resource type: %s' % (resource_type)) + raise exc.HTTPNotFound("Invalid resource type: %s" % (resource_type)) return permission_types diff --git a/st2api/st2api/controllers/v1/rule_enforcement_views.py b/st2api/st2api/controllers/v1/rule_enforcement_views.py index 3d23d027a9e..75831a917bc 100644 --- a/st2api/st2api/controllers/v1/rule_enforcement_views.py +++ b/st2api/st2api/controllers/v1/rule_enforcement_views.py @@ -26,9 +26,7 @@ from st2api.controllers.resource import ResourceController -__all__ = [ - 'RuleEnforcementViewController' -] +__all__ = ["RuleEnforcementViewController"] class RuleEnforcementViewController(ResourceController): @@ -50,8 +48,16 @@ class RuleEnforcementViewController(ResourceController): supported_filters = SUPPORTED_FILTERS filter_transform_functions = FILTER_TRANSFORM_FUNCTIONS - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): rule_enforcement_apis = super(RuleEnforcementViewController, self)._get_all( exclude_fields=exclude_attributes, include_fields=include_attributes, @@ -59,16 +65,25 @@ def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, o offset=offset, limit=limit, raw_filters=raw_filters, - requester_user=requester_user) + requester_user=requester_user, + ) - rule_enforcement_apis.json = self._append_view_properties(rule_enforcement_apis.json) + rule_enforcement_apis.json = self._append_view_properties( + rule_enforcement_apis.json + ) return rule_enforcement_apis def get_one(self, id, requester_user): - rule_enforcement_api = super(RuleEnforcementViewController, - self)._get_one_by_id(id, requester_user=requester_user, - permission_type=PermissionType.RULE_ENFORCEMENT_VIEW) - rule_enforcement_api = self._append_view_properties([rule_enforcement_api.__json__()])[0] + rule_enforcement_api = super( + RuleEnforcementViewController, self + )._get_one_by_id( + id, + requester_user=requester_user, + permission_type=PermissionType.RULE_ENFORCEMENT_VIEW, + ) + rule_enforcement_api = self._append_view_properties( + [rule_enforcement_api.__json__()] + )[0] return rule_enforcement_api def _append_view_properties(self, rule_enforcement_apis): @@ -80,29 +95,29 @@ def _append_view_properties(self, rule_enforcement_apis): execution_ids = [] for rule_enforcement_api in rule_enforcement_apis: - if rule_enforcement_api.get('trigger_instance_id', None): - trigger_instance_ids.add(str(rule_enforcement_api['trigger_instance_id'])) + if rule_enforcement_api.get("trigger_instance_id", None): + trigger_instance_ids.add( + str(rule_enforcement_api["trigger_instance_id"]) + ) - if rule_enforcement_api.get('execution_id', None): - execution_ids.append(rule_enforcement_api['execution_id']) + if rule_enforcement_api.get("execution_id", None): + execution_ids.append(rule_enforcement_api["execution_id"]) # 1. Retrieve corresponding execution objects # NOTE: Executions contain a lot of field and could contain a lot of data so we only # retrieve fields we need only_fields = [ - 'id', - - 'action.ref', - 'action.parameters', - - 'runner.name', - 'runner.runner_parameters', - - 'parameters', - 'status' + "id", + "action.ref", + "action.parameters", + "runner.name", + "runner.runner_parameters", + "parameters", + "status", ] - execution_dbs = ActionExecution.query(id__in=execution_ids, - only_fields=only_fields) + execution_dbs = ActionExecution.query( + id__in=execution_ids, only_fields=only_fields + ) execution_dbs_by_id = {} for execution_db in execution_dbs: @@ -114,26 +129,32 @@ def _append_view_properties(self, rule_enforcement_apis): trigger_instance_dbs_by_id = {} for trigger_instance_db in trigger_instance_dbs: - trigger_instance_dbs_by_id[str(trigger_instance_db.id)] = trigger_instance_db + trigger_instance_dbs_by_id[ + str(trigger_instance_db.id) + ] = trigger_instance_db # Ammend rule enforcement objects with additional data for rule_enforcement_api in rule_enforcement_apis: - rule_enforcement_api['trigger_instance'] = {} - rule_enforcement_api['execution'] = {} + rule_enforcement_api["trigger_instance"] = {} + rule_enforcement_api["execution"] = {} - trigger_instance_id = rule_enforcement_api.get('trigger_instance_id', None) - execution_id = rule_enforcement_api.get('execution_id', None) + trigger_instance_id = rule_enforcement_api.get("trigger_instance_id", None) + execution_id = rule_enforcement_api.get("execution_id", None) - trigger_instance_db = trigger_instance_dbs_by_id.get(trigger_instance_id, None) + trigger_instance_db = trigger_instance_dbs_by_id.get( + trigger_instance_id, None + ) execution_db = execution_dbs_by_id.get(execution_id, None) if trigger_instance_db: - trigger_instance_api = TriggerInstanceAPI.from_model(trigger_instance_db) - rule_enforcement_api['trigger_instance'] = trigger_instance_api + trigger_instance_api = TriggerInstanceAPI.from_model( + trigger_instance_db + ) + rule_enforcement_api["trigger_instance"] = trigger_instance_api if execution_db: execution_api = ActionExecutionAPI.from_model(execution_db) - rule_enforcement_api['execution'] = execution_api + rule_enforcement_api["execution"] = execution_api return rule_enforcement_apis diff --git a/st2api/st2api/controllers/v1/rule_enforcements.py b/st2api/st2api/controllers/v1/rule_enforcements.py index 1c117558ca4..f1c1f4c5b7b 100644 --- a/st2api/st2api/controllers/v1/rule_enforcements.py +++ b/st2api/st2api/controllers/v1/rule_enforcements.py @@ -24,11 +24,10 @@ from st2api.controllers.resource import ResourceController __all__ = [ - 'RuleEnforcementController', - - 'SUPPORTED_FILTERS', - 'QUERY_OPTIONS', - 'FILTER_TRANSFORM_FUNCTIONS' + "RuleEnforcementController", + "SUPPORTED_FILTERS", + "QUERY_OPTIONS", + "FILTER_TRANSFORM_FUNCTIONS", ] @@ -38,23 +37,21 @@ SUPPORTED_FILTERS = { - 'rule_ref': 'rule.ref', - 'rule_id': 'rule.id', - 'execution': 'execution_id', - 'trigger_instance': 'trigger_instance_id', - 'enforced_at': 'enforced_at', - 'enforced_at_gt': 'enforced_at.gt', - 'enforced_at_lt': 'enforced_at.lt' + "rule_ref": "rule.ref", + "rule_id": "rule.id", + "execution": "execution_id", + "trigger_instance": "trigger_instance_id", + "enforced_at": "enforced_at", + "enforced_at_gt": "enforced_at.gt", + "enforced_at_lt": "enforced_at.lt", } -QUERY_OPTIONS = { - 'sort': ['-enforced_at', 'rule.ref'] -} +QUERY_OPTIONS = {"sort": ["-enforced_at", "rule.ref"]} FILTER_TRANSFORM_FUNCTIONS = { - 'enforced_at': lambda value: isotime.parse(value=value), - 'enforced_at_gt': lambda value: isotime.parse(value=value), - 'enforced_at_lt': lambda value: isotime.parse(value=value) + "enforced_at": lambda value: isotime.parse(value=value), + "enforced_at_gt": lambda value: isotime.parse(value=value), + "enforced_at_lt": lambda value: isotime.parse(value=value), } @@ -69,20 +66,32 @@ class RuleEnforcementController(ResourceController): supported_filters = SUPPORTED_FILTERS filter_transform_functions = FILTER_TRANSFORM_FUNCTIONS - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return super(RuleEnforcementController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(RuleEnforcementController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, id, requester_user): - return super(RuleEnforcementController, - self)._get_one_by_id(id, requester_user=requester_user, - permission_type=PermissionType.RULE_ENFORCEMENT_VIEW) + return super(RuleEnforcementController, self)._get_one_by_id( + id, + requester_user=requester_user, + permission_type=PermissionType.RULE_ENFORCEMENT_VIEW, + ) rule_enforcements_controller = RuleEnforcementController() diff --git a/st2api/st2api/controllers/v1/rule_views.py b/st2api/st2api/controllers/v1/rule_views.py index 70555149b7b..39b4682c526 100644 --- a/st2api/st2api/controllers/v1/rule_views.py +++ b/st2api/st2api/controllers/v1/rule_views.py @@ -32,10 +32,12 @@ LOG = logging.getLogger(__name__) -__all__ = ['RuleViewController'] +__all__ = ["RuleViewController"] -class RuleViewController(BaseResourceIsolationControllerMixin, ContentPackResourceController): +class RuleViewController( + BaseResourceIsolationControllerMixin, ContentPackResourceController +): """ Add some extras to a Rule object to make it easier for UI to render a rule. The additions do not necessarily belong in the Rule itself but are still valuable augmentations. @@ -74,64 +76,78 @@ class RuleViewController(BaseResourceIsolationControllerMixin, ContentPackResour model = RuleViewAPI access = Rule - supported_filters = { - 'name': 'name', - 'pack': 'pack', - 'user': 'context.user' - } - - query_options = { - 'sort': ['pack', 'name'] - } - - mandatory_include_fields_retrieve = ['pack', 'name', 'trigger'] - - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - rules = super(RuleViewController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + supported_filters = {"name": "name", "pack": "pack", "user": "context.user"} + + query_options = {"sort": ["pack", "name"]} + + mandatory_include_fields_retrieve = ["pack", "name", "trigger"] + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + rules = super(RuleViewController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) result = self._append_view_properties(rules.json) rules.json = result return rules def get_one(self, ref_or_id, requester_user): - from_model_kwargs = {'mask_secrets': True} - rule = self._get_one(ref_or_id, permission_type=PermissionType.RULE_VIEW, - requester_user=requester_user, from_model_kwargs=from_model_kwargs) + from_model_kwargs = {"mask_secrets": True} + rule = self._get_one( + ref_or_id, + permission_type=PermissionType.RULE_VIEW, + requester_user=requester_user, + from_model_kwargs=from_model_kwargs, + ) result = self._append_view_properties([rule.json])[0] rule.json = result return rule def _append_view_properties(self, rules): - action_by_refs, trigger_by_refs, trigger_type_by_refs = self._get_referenced_models(rules) + ( + action_by_refs, + trigger_by_refs, + trigger_type_by_refs, + ) = self._get_referenced_models(rules) for rule in rules: - action_ref = rule.get('action', {}).get('ref', None) - trigger_ref = rule.get('trigger', {}).get('ref', None) - trigger_type_ref = rule.get('trigger', {}).get('type', None) + action_ref = rule.get("action", {}).get("ref", None) + trigger_ref = rule.get("trigger", {}).get("ref", None) + trigger_type_ref = rule.get("trigger", {}).get("type", None) action_db = action_by_refs.get(action_ref, None) - if 'action' in rule: - rule['action']['description'] = action_db.description if action_db else '' + if "action" in rule: + rule["action"]["description"] = ( + action_db.description if action_db else "" + ) - if 'trigger' in rule: - rule['trigger']['description'] = '' + if "trigger" in rule: + rule["trigger"]["description"] = "" trigger_db = trigger_by_refs.get(trigger_ref, None) if trigger_db: - rule['trigger']['description'] = trigger_db.description + rule["trigger"]["description"] = trigger_db.description # If description is not found in trigger get description from TriggerType - if 'trigger' in rule and not rule['trigger']['description']: + if "trigger" in rule and not rule["trigger"]["description"]: trigger_type_db = trigger_type_by_refs.get(trigger_type_ref, None) if trigger_type_db: - rule['trigger']['description'] = trigger_type_db.description + rule["trigger"]["description"] = trigger_type_db.description return rules @@ -145,9 +161,9 @@ def _get_referenced_models(self, rules): trigger_type_refs = set() for rule in rules: - action_ref = rule.get('action', {}).get('ref', None) - trigger_ref = rule.get('trigger', {}).get('ref', None) - trigger_type_ref = rule.get('trigger', {}).get('type', None) + action_ref = rule.get("action", {}).get("ref", None) + trigger_ref = rule.get("trigger", {}).get("ref", None) + trigger_type_ref = rule.get("trigger", {}).get("type", None) if action_ref: action_refs.add(action_ref) @@ -164,27 +180,31 @@ def _get_referenced_models(self, rules): # The functions that will return args that can used to query. def ref_query_args(ref): - return {'ref': ref} + return {"ref": ref} def name_pack_query_args(ref): resource_ref = ResourceReference.from_string_reference(ref=ref) - return {'name': resource_ref.name, 'pack': resource_ref.pack} + return {"name": resource_ref.name, "pack": resource_ref.pack} - action_dbs = self._get_entities(model_persistence=Action, - refs=action_refs, - query_args=ref_query_args) + action_dbs = self._get_entities( + model_persistence=Action, refs=action_refs, query_args=ref_query_args + ) for action_db in action_dbs: action_by_refs[action_db.ref] = action_db - trigger_dbs = self._get_entities(model_persistence=Trigger, - refs=trigger_refs, - query_args=name_pack_query_args) + trigger_dbs = self._get_entities( + model_persistence=Trigger, + refs=trigger_refs, + query_args=name_pack_query_args, + ) for trigger_db in trigger_dbs: trigger_by_refs[trigger_db.get_reference().ref] = trigger_db - trigger_type_dbs = self._get_entities(model_persistence=TriggerType, - refs=trigger_type_refs, - query_args=name_pack_query_args) + trigger_type_dbs = self._get_entities( + model_persistence=TriggerType, + refs=trigger_type_refs, + query_args=name_pack_query_args, + ) for trigger_type_db in trigger_type_dbs: trigger_type_by_refs[trigger_type_db.get_reference().ref] = trigger_type_db diff --git a/st2api/st2api/controllers/v1/rules.py b/st2api/st2api/controllers/v1/rules.py index 5904f9140ee..89f9e63531a 100644 --- a/st2api/st2api/controllers/v1/rules.py +++ b/st2api/st2api/controllers/v1/rules.py @@ -34,124 +34,149 @@ from st2common.router import exc from st2common.router import abort from st2common.router import Response -from st2common.services.triggers import cleanup_trigger_db_for_rule, increment_trigger_ref_count +from st2common.services.triggers import ( + cleanup_trigger_db_for_rule, + increment_trigger_ref_count, +) http_client = six.moves.http_client LOG = logging.getLogger(__name__) -class RuleController(BaseRestControllerMixin, BaseResourceIsolationControllerMixin, - ContentPackResourceController): +class RuleController( + BaseRestControllerMixin, + BaseResourceIsolationControllerMixin, + ContentPackResourceController, +): """ - Implements the RESTful web endpoint that handles - the lifecycle of Rules in the system. + Implements the RESTful web endpoint that handles + the lifecycle of Rules in the system. """ + views = RuleViewController() model = RuleAPI access = Rule supported_filters = { - 'name': 'name', - 'pack': 'pack', - 'action': 'action.ref', - 'trigger': 'trigger', - 'enabled': 'enabled', - 'user': 'context.user' + "name": "name", + "pack": "pack", + "action": "action.ref", + "trigger": "trigger", + "enabled": "enabled", + "user": "context.user", } - filter_transform_functions = { - 'enabled': transform_to_bool - } + filter_transform_functions = {"enabled": transform_to_bool} - query_options = { - 'sort': ['pack', 'name'] - } + query_options = {"sort": ["pack", "name"]} - mandatory_include_fields_retrieve = ['pack', 'name', 'trigger'] + mandatory_include_fields_retrieve = ["pack", "name", "trigger"] - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, show_secrets=False, requester_user=None, **raw_filters): + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + show_secrets=False, + requester_user=None, + **raw_filters, + ): from_model_kwargs = { - 'ignore_missing_trigger': True, - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "ignore_missing_trigger": True, + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ), } - return super(RuleController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - from_model_kwargs=from_model_kwargs, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + return super(RuleController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + from_model_kwargs=from_model_kwargs, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user, show_secrets=False): from_model_kwargs = { - 'ignore_missing_trigger': True, - 'mask_secrets': self._get_mask_secrets(requester_user, show_secrets=show_secrets) + "ignore_missing_trigger": True, + "mask_secrets": self._get_mask_secrets( + requester_user, show_secrets=show_secrets + ), } - return super(RuleController, self)._get_one(ref_or_id, from_model_kwargs=from_model_kwargs, - requester_user=requester_user, - permission_type=PermissionType.RULE_VIEW) + return super(RuleController, self)._get_one( + ref_or_id, + from_model_kwargs=from_model_kwargs, + requester_user=requester_user, + permission_type=PermissionType.RULE_VIEW, + ) def post(self, rule, requester_user): """ - Create a new rule. + Create a new rule. - Handles requests: - POST /rules/ + Handles requests: + POST /rules/ """ rbac_utils = get_rbac_backend().get_utils_class() permission_type = PermissionType.RULE_CREATE - rbac_utils.assert_user_has_resource_api_permission(user_db=requester_user, - resource_api=rule, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_api_permission( + user_db=requester_user, resource_api=rule, permission_type=permission_type + ) if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) # Validate that the authenticated user is admin if user query param is provided user = requester_user.name - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user + ) - if not hasattr(rule, 'context'): + if not hasattr(rule, "context"): rule.context = dict() - rule.context['user'] = user + rule.context["user"] = user try: rule_db = RuleAPI.to_model(rule) - LOG.debug('/rules/ POST verified RuleAPI and formulated RuleDB=%s', rule_db) + LOG.debug("/rules/ POST verified RuleAPI and formulated RuleDB=%s", rule_db) # Check referenced trigger and action permissions # Note: This needs to happen after "to_model" call since to_model performs some # validation (trigger exists, etc.) - rbac_utils.assert_user_has_rule_trigger_and_action_permission(user_db=requester_user, - rule_api=rule) + rbac_utils.assert_user_has_rule_trigger_and_action_permission( + user_db=requester_user, rule_api=rule + ) rule_db = Rule.add_or_update(rule_db) # After the rule has been added modify the ref_count. This way a failure to add # the rule due to violated constraints will have no impact on ref_count. increment_trigger_ref_count(rule_api=rule) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for rule data=%s.', rule) + LOG.exception("Validation failed for rule data=%s.", rule) abort(http_client.BAD_REQUEST, six.text_type(e)) return except (ValueValidationException, jsonschema.ValidationError) as e: - LOG.exception('Validation failed for rule data=%s.', rule) + LOG.exception("Validation failed for rule data=%s.", rule) abort(http_client.BAD_REQUEST, six.text_type(e)) return except TriggerDoesNotExistException: - msg = ('Trigger "%s" defined in the rule does not exist in system or it\'s missing ' - 'required "parameters" attribute' % (rule.trigger['type'])) + msg = ( + 'Trigger "%s" defined in the rule does not exist in system or it\'s missing ' + 'required "parameters" attribute' % (rule.trigger["type"]) + ) LOG.exception(msg) abort(http_client.BAD_REQUEST, msg) return - extra = {'rule_db': rule_db} - LOG.audit('Rule created. Rule.id=%s' % (rule_db.id), extra=extra) + extra = {"rule_db": rule_db} + LOG.audit("Rule created. Rule.id=%s" % (rule_db.id), extra=extra) rule_api = RuleAPI.from_model(rule_db) return Response(json=rule_api, status=exc.HTTPCreated.code) @@ -161,27 +186,33 @@ def put(self, rule, rule_ref_or_id, requester_user): rbac_utils = get_rbac_backend().get_utils_class() permission_type = PermissionType.RULE_MODIFY - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=rule, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, resource_db=rule, permission_type=permission_type + ) - LOG.debug('PUT /rules/ lookup with id=%s found object: %s', rule_ref_or_id, rule_db) + LOG.debug( + "PUT /rules/ lookup with id=%s found object: %s", rule_ref_or_id, rule_db + ) if not requester_user: requester_user = UserDB(cfg.CONF.system_user.user) # Validate that the authenticated user is admin if user query param is provided user = requester_user.name - rbac_utils.assert_user_is_admin_if_user_query_param_is_provided(user_db=requester_user, - user=user) + rbac_utils.assert_user_is_admin_if_user_query_param_is_provided( + user_db=requester_user, user=user + ) - if not hasattr(rule, 'context'): + if not hasattr(rule, "context"): rule.context = dict() - rule.context['user'] = user + rule.context["user"] = user try: - if rule.id is not None and rule.id != '' and rule.id != rule_ref_or_id: - LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.', - rule.id, rule_ref_or_id) + if rule.id is not None and rule.id != "" and rule.id != rule_ref_or_id: + LOG.warning( + "Discarding mismatched id=%s found in payload and using uri_id=%s.", + rule.id, + rule_ref_or_id, + ) old_rule_db = rule_db try: @@ -193,8 +224,9 @@ def put(self, rule, rule_ref_or_id, requester_user): # Check referenced trigger and action permissions # Note: This needs to happen after "to_model" call since to_model performs some # validation (trigger exists, etc.) - rbac_utils.assert_user_has_rule_trigger_and_action_permission(user_db=requester_user, - rule_api=rule) + rbac_utils.assert_user_has_rule_trigger_and_action_permission( + user_db=requester_user, rule_api=rule + ) rule_db.id = rule_ref_or_id rule_db = Rule.add_or_update(rule_db) @@ -202,48 +234,52 @@ def put(self, rule, rule_ref_or_id, requester_user): # the rule due to violated constraints will have no impact on ref_count. increment_trigger_ref_count(rule_api=rule) except (ValueValidationException, jsonschema.ValidationError, ValueError) as e: - LOG.exception('Validation failed for rule data=%s', rule) + LOG.exception("Validation failed for rule data=%s", rule) abort(http_client.BAD_REQUEST, six.text_type(e)) return # use old_rule_db for cleanup. cleanup_trigger_db_for_rule(old_rule_db) - extra = {'old_rule_db': old_rule_db, 'new_rule_db': rule_db} - LOG.audit('Rule updated. Rule.id=%s.' % (rule_db.id), extra=extra) + extra = {"old_rule_db": old_rule_db, "new_rule_db": rule_db} + LOG.audit("Rule updated. Rule.id=%s." % (rule_db.id), extra=extra) rule_api = RuleAPI.from_model(rule_db) return rule_api def delete(self, rule_ref_or_id, requester_user): """ - Delete a rule. + Delete a rule. - Handles requests: - DELETE /rules/1 + Handles requests: + DELETE /rules/1 """ rule_db = self._get_by_ref_or_id(ref_or_id=rule_ref_or_id) permission_type = PermissionType.RULE_DELETE rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=rule_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, resource_db=rule_db, permission_type=permission_type + ) - LOG.debug('DELETE /rules/ lookup with id=%s found object: %s', rule_ref_or_id, rule_db) + LOG.debug( + "DELETE /rules/ lookup with id=%s found object: %s", rule_ref_or_id, rule_db + ) try: Rule.delete(rule_db) except Exception as e: - LOG.exception('Database delete encountered exception during delete of id="%s".', - rule_ref_or_id) + LOG.exception( + 'Database delete encountered exception during delete of id="%s".', + rule_ref_or_id, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return # use old_rule_db for cleanup. cleanup_trigger_db_for_rule(rule_db) - extra = {'rule_db': rule_db} - LOG.audit('Rule deleted. Rule.id=%s.' % (rule_db.id), extra=extra) + extra = {"rule_db": rule_db} + LOG.audit("Rule deleted. Rule.id=%s." % (rule_db.id), extra=extra) return Response(status=http_client.NO_CONTENT) diff --git a/st2api/st2api/controllers/v1/ruletypes.py b/st2api/st2api/controllers/v1/ruletypes.py index dcf62069ef3..267c1925346 100644 --- a/st2api/st2api/controllers/v1/ruletypes.py +++ b/st2api/st2api/controllers/v1/ruletypes.py @@ -28,8 +28,8 @@ class RuleTypesController(object): """ - Implements the RESTful web endpoint that handles - the lifecycle of a RuleType in the system. + Implements the RESTful web endpoint that handles + the lifecycle of a RuleType in the system. """ @staticmethod @@ -46,15 +46,17 @@ def __get_by_name(name): try: return [RuleType.get_by_name(name)] except ValueError as e: - LOG.debug('Database lookup for name="%s" resulted in exception : %s.', name, e) + LOG.debug( + 'Database lookup for name="%s" resulted in exception : %s.', name, e + ) return [] def get_one(self, id): """ - List RuleType objects by id. + List RuleType objects by id. - Handle: - GET /ruletypes/1 + Handle: + GET /ruletypes/1 """ ruletype_db = RuleTypesController.__get_by_id(id) ruletype_api = RuleTypeAPI.from_model(ruletype_db) @@ -62,14 +64,15 @@ def get_one(self, id): def get_all(self): """ - List all RuleType objects. + List all RuleType objects. - Handles requests: - GET /ruletypes/ + Handles requests: + GET /ruletypes/ """ ruletype_dbs = RuleType.get_all() - ruletype_apis = [RuleTypeAPI.from_model(runnertype_db) - for runnertype_db in ruletype_dbs] + ruletype_apis = [ + RuleTypeAPI.from_model(runnertype_db) for runnertype_db in ruletype_dbs + ] return ruletype_apis diff --git a/st2api/st2api/controllers/v1/runnertypes.py b/st2api/st2api/controllers/v1/runnertypes.py index b947babd941..1c84b4425cc 100644 --- a/st2api/st2api/controllers/v1/runnertypes.py +++ b/st2api/st2api/controllers/v1/runnertypes.py @@ -31,34 +31,42 @@ class RunnerTypesController(ResourceController): """ - Implements the RESTful web endpoint that handles - the lifecycle of an RunnerType in the system. + Implements the RESTful web endpoint that handles + the lifecycle of an RunnerType in the system. """ model = RunnerTypeAPI access = RunnerType - supported_filters = { - 'name': 'name' - } - - query_options = { - 'sort': ['name'] - } - - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return super(RunnerTypesController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + supported_filters = {"name": "name"} + + query_options = {"sort": ["name"]} + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(RunnerTypesController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, name_or_id, requester_user): - return self._get_one_by_name_or_id(name_or_id, - requester_user=requester_user, - permission_type=PermissionType.RUNNER_VIEW) + return self._get_one_by_name_or_id( + name_or_id, + requester_user=requester_user, + permission_type=PermissionType.RUNNER_VIEW, + ) def put(self, runner_type_api, name_or_id, requester_user): # Note: We only allow "enabled" attribute of the runner to be changed @@ -66,28 +74,41 @@ def put(self, runner_type_api, name_or_id, requester_user): permission_type = PermissionType.RUNNER_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=runner_type_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=runner_type_db, + permission_type=permission_type, + ) old_runner_type_db = runner_type_db - LOG.debug('PUT /runnertypes/ lookup with id=%s found object: %s', name_or_id, - runner_type_db) + LOG.debug( + "PUT /runnertypes/ lookup with id=%s found object: %s", + name_or_id, + runner_type_db, + ) try: if runner_type_api.id and runner_type_api.id != name_or_id: - LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.', - runner_type_api.id, name_or_id) + LOG.warning( + "Discarding mismatched id=%s found in payload and using uri_id=%s.", + runner_type_api.id, + name_or_id, + ) runner_type_db.enabled = runner_type_api.enabled runner_type_db = RunnerType.add_or_update(runner_type_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for runner type data=%s', runner_type_api) + LOG.exception("Validation failed for runner type data=%s", runner_type_api) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'old_runner_type_db': old_runner_type_db, 'new_runner_type_db': runner_type_db} - LOG.audit('Runner Type updated. RunnerType.id=%s.' % (runner_type_db.id), extra=extra) + extra = { + "old_runner_type_db": old_runner_type_db, + "new_runner_type_db": runner_type_db, + } + LOG.audit( + "Runner Type updated. RunnerType.id=%s." % (runner_type_db.id), extra=extra + ) runner_type_api = RunnerTypeAPI.from_model(runner_type_db) return runner_type_api diff --git a/st2api/st2api/controllers/v1/sensors.py b/st2api/st2api/controllers/v1/sensors.py index a3a71853d8b..b62b56c92d9 100644 --- a/st2api/st2api/controllers/v1/sensors.py +++ b/st2api/st2api/controllers/v1/sensors.py @@ -36,35 +36,41 @@ class SensorTypeController(resource.ContentPackResourceController): model = SensorTypeAPI access = SensorType supported_filters = { - 'name': 'name', - 'pack': 'pack', - 'enabled': 'enabled', - 'trigger': 'trigger_types' + "name": "name", + "pack": "pack", + "enabled": "enabled", + "trigger": "trigger_types", } - filter_transform_functions = { - 'enabled': transform_to_bool - } - - options = { - 'sort': ['pack', 'name'] - } - - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return super(SensorTypeController, self)._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + filter_transform_functions = {"enabled": transform_to_bool} + + options = {"sort": ["pack", "name"]} + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return super(SensorTypeController, self)._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, ref_or_id, requester_user): permission_type = PermissionType.SENSOR_VIEW - return super(SensorTypeController, self)._get_one(ref_or_id, - requester_user=requester_user, - permission_type=permission_type) + return super(SensorTypeController, self)._get_one( + ref_or_id, requester_user=requester_user, permission_type=permission_type + ) def put(self, sensor_type, ref_or_id, requester_user): # Note: Right now this function only supports updating of "enabled" @@ -76,9 +82,11 @@ def put(self, sensor_type, ref_or_id, requester_user): permission_type = PermissionType.SENSOR_MODIFY rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=sensor_type_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=sensor_type_db, + permission_type=permission_type, + ) sensor_type_id = sensor_type_db.id @@ -88,23 +96,23 @@ def put(self, sensor_type, ref_or_id, requester_user): abort(http_client.BAD_REQUEST, six.text_type(e)) return - if not getattr(sensor_type, 'pack', None): + if not getattr(sensor_type, "pack", None): sensor_type.pack = sensor_type_db.pack try: old_sensor_type_db = sensor_type_db sensor_type_db.id = sensor_type_id - sensor_type_db.enabled = getattr(sensor_type, 'enabled', False) + sensor_type_db.enabled = getattr(sensor_type, "enabled", False) sensor_type_db = SensorType.add_or_update(sensor_type_db) except (ValidationError, ValueError) as e: - LOG.exception('Unable to update sensor_type data=%s', sensor_type) + LOG.exception("Unable to update sensor_type data=%s", sensor_type) abort(http_client.BAD_REQUEST, six.text_type(e)) return extra = { - 'old_sensor_type_db': old_sensor_type_db, - 'new_sensor_type_db': sensor_type_db + "old_sensor_type_db": old_sensor_type_db, + "new_sensor_type_db": sensor_type_db, } - LOG.audit('Sensor updated. Sensor.id=%s.' % (sensor_type_db.id), extra=extra) + LOG.audit("Sensor updated. Sensor.id=%s." % (sensor_type_db.id), extra=extra) sensor_type_api = SensorTypeAPI.from_model(sensor_type_db) return sensor_type_api diff --git a/st2api/st2api/controllers/v1/service_registry.py b/st2api/st2api/controllers/v1/service_registry.py index d9ee9d542ba..3a54563b252 100644 --- a/st2api/st2api/controllers/v1/service_registry.py +++ b/st2api/st2api/controllers/v1/service_registry.py @@ -22,8 +22,8 @@ from st2common.rbac.backends import get_rbac_backend __all__ = [ - 'ServiceRegistryGroupsController', - 'ServiceRegistryGroupMembersController', + "ServiceRegistryGroupsController", + "ServiceRegistryGroupMembersController", ] @@ -35,11 +35,9 @@ def get_all(self, requester_user): coordinator = coordination.get_coordinator() group_ids = list(coordinator.get_groups().get()) - group_ids = [item.decode('utf-8') for item in group_ids] + group_ids = [item.decode("utf-8") for item in group_ids] - result = { - 'groups': group_ids - } + result = {"groups": group_ids} return result @@ -51,26 +49,26 @@ def get_one(self, group_id, requester_user): coordinator = coordination.get_coordinator() if not isinstance(group_id, six.binary_type): - group_id = group_id.encode('utf-8') + group_id = group_id.encode("utf-8") try: member_ids = list(coordinator.get_members(group_id).get()) except GroupNotCreated: - msg = ('Group with ID "%s" not found.' % (group_id.decode('utf-8'))) + msg = 'Group with ID "%s" not found.' % (group_id.decode("utf-8")) raise StackStormDBObjectNotFoundError(msg) - result = { - 'members': [] - } + result = {"members": []} for member_id in member_ids: - capabilities = coordinator.get_member_capabilities(group_id, member_id).get() + capabilities = coordinator.get_member_capabilities( + group_id, member_id + ).get() item = { - 'group_id': group_id.decode('utf-8'), - 'member_id': member_id.decode('utf-8'), - 'capabilities': capabilities + "group_id": group_id.decode("utf-8"), + "member_id": member_id.decode("utf-8"), + "capabilities": capabilities, } - result['members'].append(item) + result["members"].append(item) return result diff --git a/st2api/st2api/controllers/v1/timers.py b/st2api/st2api/controllers/v1/timers.py index c91b80fec19..541957a0999 100644 --- a/st2api/st2api/controllers/v1/timers.py +++ b/st2api/st2api/controllers/v1/timers.py @@ -30,17 +30,13 @@ from st2common.services.triggerwatcher import TriggerWatcher from st2common.router import abort -__all__ = [ - 'TimersController', - 'TimersHolder' -] +__all__ = ["TimersController", "TimersHolder"] LOG = logging.getLogger(__name__) class TimersHolder(object): - def __init__(self): self._timers = {} @@ -54,7 +50,7 @@ def get_all(self, timer_type=None): timer_triggers = [] for _, timer in iteritems(self._timers): - if not timer_type or timer['type'] == timer_type: + if not timer_type or timer["type"] == timer_type: timer_triggers.append(timer) return timer_triggers @@ -65,35 +61,37 @@ class TimersController(resource.ContentPackResourceController): access = Trigger supported_filters = { - 'type': 'type', + "type": "type", } - query_options = { - 'sort': ['type'] - } + query_options = {"sort": ["type"]} def __init__(self): self._timers = TimersHolder() self._trigger_types = TIMER_TRIGGER_TYPES.keys() queue_suffix = self.__class__.__name__ - self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger, - update_handler=self._handle_update_trigger, - delete_handler=self._handle_delete_trigger, - trigger_types=self._trigger_types, - queue_suffix=queue_suffix, - exclusive=True) + self._trigger_watcher = TriggerWatcher( + create_handler=self._handle_create_trigger, + update_handler=self._handle_update_trigger, + delete_handler=self._handle_delete_trigger, + trigger_types=self._trigger_types, + queue_suffix=queue_suffix, + exclusive=True, + ) self._trigger_watcher.start() self._register_timer_trigger_types() self._allowed_timer_types = TIMER_TRIGGER_TYPES.keys() def get_all(self, timer_type=None): if timer_type and timer_type not in self._allowed_timer_types: - msg = 'Timer type %s not in supported types - %s.' % (timer_type, - self._allowed_timer_types) + msg = "Timer type %s not in supported types - %s." % ( + timer_type, + self._allowed_timer_types, + ) abort(http_client.BAD_REQUEST, msg) t_all = self._timers.get_all(timer_type=timer_type) - LOG.debug('Got timers: %s', t_all) + LOG.debug("Got timers: %s", t_all) return t_all def get_one(self, ref_or_id, requester_user): @@ -108,9 +106,11 @@ def get_one(self, ref_or_id, requester_user): resource_db = TimerDB(pack=trigger_db.pack, name=trigger_db.name) rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=resource_db, - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=resource_db, + permission_type=permission_type, + ) result = self.model.from_model(trigger_db) return result @@ -119,7 +119,7 @@ def add_trigger(self, trigger): # Note: Permission checking for creating and deleting a timer is done during rule # creation ref = self._get_timer_ref(trigger) - LOG.info('Started timer %s with parameters %s', ref, trigger['parameters']) + LOG.info("Started timer %s with parameters %s", ref, trigger["parameters"]) self._timers.add_trigger(ref, trigger) def update_trigger(self, trigger): @@ -130,14 +130,16 @@ def remove_trigger(self, trigger): # creation ref = self._get_timer_ref(trigger) self._timers.remove_trigger(ref, trigger) - LOG.info('Stopped timer %s with parameters %s.', ref, trigger['parameters']) + LOG.info("Stopped timer %s with parameters %s.", ref, trigger["parameters"]) def _register_timer_trigger_types(self): for trigger_type in TIMER_TRIGGER_TYPES.values(): trigger_service.create_trigger_type_db(trigger_type) def _get_timer_ref(self, trigger): - return ResourceReference.to_string_reference(pack=trigger['pack'], name=trigger['name']) + return ResourceReference.to_string_reference( + pack=trigger["pack"], name=trigger["name"] + ) ############################################## # Event handler methods for the trigger events diff --git a/st2api/st2api/controllers/v1/traces.py b/st2api/st2api/controllers/v1/traces.py index 91c6e95e4fe..4ab1d02aa57 100644 --- a/st2api/st2api/controllers/v1/traces.py +++ b/st2api/st2api/controllers/v1/traces.py @@ -18,47 +18,53 @@ from st2common.persistence.trace import Trace from st2common.rbac.types import PermissionType -__all__ = [ - 'TracesController' -] +__all__ = ["TracesController"] class TracesController(ResourceController): model = TraceAPI access = Trace supported_filters = { - 'trace_tag': 'trace_tag', - 'execution': 'action_executions.object_id', - 'rule': 'rules.object_id', - 'trigger_instance': 'trigger_instances.object_id', + "trace_tag": "trace_tag", + "execution": "action_executions.object_id", + "rule": "rules.object_id", + "trigger_instance": "trigger_instances.object_id", } - query_options = { - 'sort': ['-start_timestamp', 'trace_tag'] - } + query_options = {"sort": ["-start_timestamp", "trace_tag"]} - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): # Use a custom sort order when filtering on a timestamp so we return a correct result as # expected by the user query_options = None - if 'sort_desc' in raw_filters and raw_filters['sort_desc'] == 'True': - query_options = {'sort': ['-start_timestamp', 'trace_tag']} - elif 'sort_asc' in raw_filters and raw_filters['sort_asc'] == 'True': - query_options = {'sort': ['+start_timestamp', 'trace_tag']} - return self._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - query_options=query_options, - raw_filters=raw_filters, - requester_user=requester_user) + if "sort_desc" in raw_filters and raw_filters["sort_desc"] == "True": + query_options = {"sort": ["-start_timestamp", "trace_tag"]} + elif "sort_asc" in raw_filters and raw_filters["sort_asc"] == "True": + query_options = {"sort": ["+start_timestamp", "trace_tag"]} + return self._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + query_options=query_options, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, id, requester_user): - return self._get_one_by_id(id, - requester_user=requester_user, - permission_type=PermissionType.TRACE_VIEW) + return self._get_one_by_id( + id, requester_user=requester_user, permission_type=PermissionType.TRACE_VIEW + ) traces_controller = TracesController() diff --git a/st2api/st2api/controllers/v1/triggers.py b/st2api/st2api/controllers/v1/triggers.py index 12c3f133ec1..cbdc5ca66b5 100644 --- a/st2api/st2api/controllers/v1/triggers.py +++ b/st2api/st2api/controllers/v1/triggers.py @@ -39,55 +39,64 @@ class TriggerTypeController(resource.ContentPackResourceController): """ - Implements the RESTful web endpoint that handles - the lifecycle of TriggerTypes in the system. + Implements the RESTful web endpoint that handles + the lifecycle of TriggerTypes in the system. """ + model = TriggerTypeAPI access = TriggerType - supported_filters = { - 'name': 'name', - 'pack': 'pack' - } - - options = { - 'sort': ['pack', 'name'] - } - - query_options = { - 'sort': ['ref'] - } - - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): - return self._get_all(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + supported_filters = {"name": "name", "pack": "pack"} + + options = {"sort": ["pack", "name"]} + + query_options = {"sort": ["ref"]} + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): + return self._get_all( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) def get_one(self, triggertype_ref_or_id): - return self._get_one(triggertype_ref_or_id, permission_type=None, requester_user=None) + return self._get_one( + triggertype_ref_or_id, permission_type=None, requester_user=None + ) def post(self, triggertype): """ - Create a new triggertype. + Create a new triggertype. - Handles requests: - POST /triggertypes/ + Handles requests: + POST /triggertypes/ """ try: triggertype_db = TriggerTypeAPI.to_model(triggertype) triggertype_db = TriggerType.add_or_update(triggertype_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for triggertype data=%s.', triggertype) + LOG.exception("Validation failed for triggertype data=%s.", triggertype) abort(http_client.BAD_REQUEST, six.text_type(e)) return else: - extra = {'triggertype_db': triggertype_db} - LOG.audit('TriggerType created. TriggerType.id=%s' % (triggertype_db.id), extra=extra) + extra = {"triggertype_db": triggertype_db} + LOG.audit( + "TriggerType created. TriggerType.id=%s" % (triggertype_db.id), + extra=extra, + ) if not triggertype_db.parameters_schema: TriggerTypeController._create_shadow_trigger(triggertype_db) @@ -106,34 +115,44 @@ def put(self, triggertype, triggertype_ref_or_id): try: triggertype_db = TriggerTypeAPI.to_model(triggertype) - if triggertype.id is not None and len(triggertype.id) > 0 and \ - triggertype.id != triggertype_id: - LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.', - triggertype.id, triggertype_id) + if ( + triggertype.id is not None + and len(triggertype.id) > 0 + and triggertype.id != triggertype_id + ): + LOG.warning( + "Discarding mismatched id=%s found in payload and using uri_id=%s.", + triggertype.id, + triggertype_id, + ) triggertype_db.id = triggertype_id old_triggertype_db = triggertype_db triggertype_db = TriggerType.add_or_update(triggertype_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for triggertype data=%s', triggertype) + LOG.exception("Validation failed for triggertype data=%s", triggertype) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'old_triggertype_db': old_triggertype_db, 'new_triggertype_db': triggertype_db} - LOG.audit('TriggerType updated. TriggerType.id=%s' % (triggertype_db.id), extra=extra) + extra = { + "old_triggertype_db": old_triggertype_db, + "new_triggertype_db": triggertype_db, + } + LOG.audit( + "TriggerType updated. TriggerType.id=%s" % (triggertype_db.id), extra=extra + ) triggertype_api = TriggerTypeAPI.from_model(triggertype_db) return triggertype_api def delete(self, triggertype_ref_or_id): """ - Delete a triggertype. + Delete a triggertype. - Handles requests: - DELETE /triggertypes/1 - DELETE /triggertypes/pack.name + Handles requests: + DELETE /triggertypes/1 + DELETE /triggertypes/pack.name """ - LOG.info('DELETE /triggertypes/ with ref_or_id=%s', - triggertype_ref_or_id) + LOG.info("DELETE /triggertypes/ with ref_or_id=%s", triggertype_ref_or_id) triggertype_db = self._get_by_ref_or_id(ref_or_id=triggertype_ref_or_id) triggertype_id = triggertype_db.id @@ -146,13 +165,18 @@ def delete(self, triggertype_ref_or_id): try: TriggerType.delete(triggertype_db) except Exception as e: - LOG.exception('Database delete encountered exception during delete of id="%s". ', - triggertype_id) + LOG.exception( + 'Database delete encountered exception during delete of id="%s". ', + triggertype_id, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return else: - extra = {'triggertype': triggertype_db} - LOG.audit('TriggerType deleted. TriggerType.id=%s' % (triggertype_db.id), extra=extra) + extra = {"triggertype": triggertype_db} + LOG.audit( + "TriggerType deleted. TriggerType.id=%s" % (triggertype_db.id), + extra=extra, + ) if not triggertype_db.parameters_schema: TriggerTypeController._delete_shadow_trigger(triggertype_db) @@ -162,55 +186,70 @@ def delete(self, triggertype_ref_or_id): def _create_shadow_trigger(triggertype_db): try: trigger_type_ref = triggertype_db.get_reference().ref - trigger = {'name': triggertype_db.name, - 'pack': triggertype_db.pack, - 'type': trigger_type_ref, - 'parameters': {}} + trigger = { + "name": triggertype_db.name, + "pack": triggertype_db.pack, + "type": trigger_type_ref, + "parameters": {}, + } trigger_db = TriggerService.create_or_update_trigger_db(trigger) - extra = {'trigger_db': trigger_db} - LOG.audit('Trigger created for parameter-less TriggerType. Trigger.id=%s' % - (trigger_db.id), extra=extra) + extra = {"trigger_db": trigger_db} + LOG.audit( + "Trigger created for parameter-less TriggerType. Trigger.id=%s" + % (trigger_db.id), + extra=extra, + ) except (ValidationError, ValueError): - LOG.exception('Validation failed for trigger data=%s.', trigger) + LOG.exception("Validation failed for trigger data=%s.", trigger) # Not aborting as this is convenience. return except StackStormDBObjectConflictError as e: - LOG.warn('Trigger creation of "%s" failed with uniqueness conflict. Exception: %s', - trigger, six.text_type(e)) + LOG.warn( + 'Trigger creation of "%s" failed with uniqueness conflict. Exception: %s', + trigger, + six.text_type(e), + ) # Not aborting as this is convenience. return @staticmethod def _delete_shadow_trigger(triggertype_db): # shadow Trigger's have the same name as the shadowed TriggerType. - triggertype_ref = ResourceReference(name=triggertype_db.name, pack=triggertype_db.pack) + triggertype_ref = ResourceReference( + name=triggertype_db.name, pack=triggertype_db.pack + ) trigger_db = TriggerService.get_trigger_db_by_ref(triggertype_ref.ref) if not trigger_db: - LOG.warn('No shadow trigger found for %s. Will skip delete.', triggertype_db) + LOG.warn( + "No shadow trigger found for %s. Will skip delete.", triggertype_db + ) return try: Trigger.delete(trigger_db) except Exception: - LOG.exception('Database delete encountered exception during delete of id="%s". ', - trigger_db.id) + LOG.exception( + 'Database delete encountered exception during delete of id="%s". ', + trigger_db.id, + ) - extra = {'trigger_db': trigger_db} - LOG.audit('Trigger deleted. Trigger.id=%s' % (trigger_db.id), extra=extra) + extra = {"trigger_db": trigger_db} + LOG.audit("Trigger deleted. Trigger.id=%s" % (trigger_db.id), extra=extra) class TriggerController(object): """ - Implements the RESTful web endpoint that handles - the lifecycle of Triggers in the system. + Implements the RESTful web endpoint that handles + the lifecycle of Triggers in the system. """ + def get_one(self, trigger_id): """ - List trigger by id. + List trigger by id. - Handle: - GET /triggers/1 + Handle: + GET /triggers/1 """ trigger_db = TriggerController.__get_by_id(trigger_id) trigger_api = TriggerAPI.from_model(trigger_db) @@ -218,10 +257,10 @@ def get_one(self, trigger_id): def get_all(self, requester_user=None): """ - List all triggers. + List all triggers. - Handles requests: - GET /triggers/ + Handles requests: + GET /triggers/ """ trigger_dbs = Trigger.get_all() trigger_apis = [TriggerAPI.from_model(trigger_db) for trigger_db in trigger_dbs] @@ -229,20 +268,20 @@ def get_all(self, requester_user=None): def post(self, trigger): """ - Create a new trigger. + Create a new trigger. - Handles requests: - POST /triggers/ + Handles requests: + POST /triggers/ """ try: trigger_db = TriggerService.create_trigger_db(trigger) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for trigger data=%s.', trigger) + LOG.exception("Validation failed for trigger data=%s.", trigger) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'trigger': trigger_db} - LOG.audit('Trigger created. Trigger.id=%s' % (trigger_db.id), extra=extra) + extra = {"trigger": trigger_db} + LOG.audit("Trigger created. Trigger.id=%s" % (trigger_db.id), extra=extra) trigger_api = TriggerAPI.from_model(trigger_db) return Response(json=trigger_api, status=http_client.CREATED) @@ -250,42 +289,47 @@ def post(self, trigger): def put(self, trigger, trigger_id): trigger_db = TriggerController.__get_by_id(trigger_id) try: - if trigger.id is not None and trigger.id != '' and trigger.id != trigger_id: - LOG.warning('Discarding mismatched id=%s found in payload and using uri_id=%s.', - trigger.id, trigger_id) + if trigger.id is not None and trigger.id != "" and trigger.id != trigger_id: + LOG.warning( + "Discarding mismatched id=%s found in payload and using uri_id=%s.", + trigger.id, + trigger_id, + ) trigger_db = TriggerAPI.to_model(trigger) trigger_db.id = trigger_id trigger_db = Trigger.add_or_update(trigger_db) except (ValidationError, ValueError) as e: - LOG.exception('Validation failed for trigger data=%s', trigger) + LOG.exception("Validation failed for trigger data=%s", trigger) abort(http_client.BAD_REQUEST, six.text_type(e)) return - extra = {'old_trigger_db': trigger, 'new_trigger_db': trigger_db} - LOG.audit('Trigger updated. Trigger.id=%s' % (trigger.id), extra=extra) + extra = {"old_trigger_db": trigger, "new_trigger_db": trigger_db} + LOG.audit("Trigger updated. Trigger.id=%s" % (trigger.id), extra=extra) trigger_api = TriggerAPI.from_model(trigger_db) return trigger_api def delete(self, trigger_id): """ - Delete a trigger. + Delete a trigger. - Handles requests: - DELETE /triggers/1 + Handles requests: + DELETE /triggers/1 """ - LOG.info('DELETE /triggers/ with id=%s', trigger_id) + LOG.info("DELETE /triggers/ with id=%s", trigger_id) trigger_db = TriggerController.__get_by_id(trigger_id) try: Trigger.delete(trigger_db) except Exception as e: - LOG.exception('Database delete encountered exception during delete of id="%s". ', - trigger_id) + LOG.exception( + 'Database delete encountered exception during delete of id="%s". ', + trigger_id, + ) abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) return - extra = {'trigger_db': trigger_db} - LOG.audit('Trigger deleted. Trigger.id=%s' % (trigger_db.id), extra=extra) + extra = {"trigger_db": trigger_db} + LOG.audit("Trigger deleted. Trigger.id=%s" % (trigger_db.id), extra=extra) return Response(status=http_client.NO_CONTENT) @@ -294,7 +338,9 @@ def __get_by_id(trigger_id): try: return Trigger.get_by_id(trigger_id) except (ValueError, ValidationError): - LOG.exception('Database lookup for id="%s" resulted in exception.', trigger_id) + LOG.exception( + 'Database lookup for id="%s" resulted in exception.', trigger_id + ) abort(http_client.NOT_FOUND) @staticmethod @@ -302,7 +348,11 @@ def __get_by_name(trigger_name): try: return [Trigger.get_by_name(trigger_name)] except ValueError as e: - LOG.debug('Database lookup for name="%s" resulted in exception : %s.', trigger_name, e) + LOG.debug( + 'Database lookup for name="%s" resulted in exception : %s.', + trigger_name, + e, + ) return [] @@ -311,7 +361,9 @@ class TriggerInstanceControllerMixin(object): access = TriggerInstance -class TriggerInstanceResendController(TriggerInstanceControllerMixin, resource.ResourceController): +class TriggerInstanceResendController( + TriggerInstanceControllerMixin, resource.ResourceController +): supported_filters = {} def __init__(self, *args, **kwargs): @@ -338,106 +390,130 @@ def post(self, trigger_instance_id): POST /triggerinstance//re_send """ # Note: We only really need parameters here - existing_trigger_instance = self._get_one_by_id(id=trigger_instance_id, - permission_type=None, - requester_user=None) + existing_trigger_instance = self._get_one_by_id( + id=trigger_instance_id, permission_type=None, requester_user=None + ) new_payload = copy.deepcopy(existing_trigger_instance.payload) - new_payload['__context'] = { - 'original_id': trigger_instance_id - } + new_payload["__context"] = {"original_id": trigger_instance_id} try: - self.trigger_dispatcher.dispatch(existing_trigger_instance.trigger, - new_payload) + self.trigger_dispatcher.dispatch( + existing_trigger_instance.trigger, new_payload + ) return { - 'message': 'Trigger instance %s succesfully re-sent.' % trigger_instance_id, - 'payload': new_payload + "message": "Trigger instance %s succesfully re-sent." + % trigger_instance_id, + "payload": new_payload, } except Exception as e: abort(http_client.INTERNAL_SERVER_ERROR, six.text_type(e)) -class TriggerInstanceController(TriggerInstanceControllerMixin, resource.ResourceController): +class TriggerInstanceController( + TriggerInstanceControllerMixin, resource.ResourceController +): """ - Implements the RESTful web endpoint that handles - the lifecycle of TriggerInstances in the system. + Implements the RESTful web endpoint that handles + the lifecycle of TriggerInstances in the system. """ + supported_filters = { - 'timestamp_gt': 'occurrence_time.gt', - 'timestamp_lt': 'occurrence_time.lt', - 'status': 'status', - 'trigger': 'trigger.in' + "timestamp_gt": "occurrence_time.gt", + "timestamp_lt": "occurrence_time.lt", + "status": "status", + "trigger": "trigger.in", } filter_transform_functions = { - 'timestamp_gt': lambda value: isotime.parse(value=value), - 'timestamp_lt': lambda value: isotime.parse(value=value) + "timestamp_gt": lambda value: isotime.parse(value=value), + "timestamp_lt": lambda value: isotime.parse(value=value), } - query_options = { - 'sort': ['-occurrence_time', 'trigger'] - } + query_options = {"sort": ["-occurrence_time", "trigger"]} def __init__(self): super(TriggerInstanceController, self).__init__() def get_one(self, instance_id): """ - List triggerinstance by instance_id. + List triggerinstance by instance_id. - Handle: - GET /triggerinstances/1 + Handle: + GET /triggerinstances/1 """ - return self._get_one_by_id(instance_id, permission_type=None, requester_user=None) - - def get_all(self, exclude_attributes=None, include_attributes=None, sort=None, offset=0, - limit=None, requester_user=None, **raw_filters): + return self._get_one_by_id( + instance_id, permission_type=None, requester_user=None + ) + + def get_all( + self, + exclude_attributes=None, + include_attributes=None, + sort=None, + offset=0, + limit=None, + requester_user=None, + **raw_filters, + ): """ - List all triggerinstances. + List all triggerinstances. - Handles requests: - GET /triggerinstances/ + Handles requests: + GET /triggerinstances/ """ # If trigger_type filter is provided, filter based on the TriggerType via Trigger object - trigger_type_ref = raw_filters.get('trigger_type', None) + trigger_type_ref = raw_filters.get("trigger_type", None) if trigger_type_ref: # 1. Retrieve TriggerType object id which match this trigger_type ref - trigger_dbs = Trigger.query(type=trigger_type_ref, - only_fields=['ref', 'name', 'pack', 'type']) + trigger_dbs = Trigger.query( + type=trigger_type_ref, only_fields=["ref", "name", "pack", "type"] + ) trigger_refs = [trigger_db.ref for trigger_db in trigger_dbs] - raw_filters['trigger'] = trigger_refs + raw_filters["trigger"] = trigger_refs - if trigger_type_ref and len(raw_filters.get('trigger', [])) == 0: + if trigger_type_ref and len(raw_filters.get("trigger", [])) == 0: # Empty list means trigger_type_ref filter was provided, but we matched no Triggers so # we should return back empty result return [] - trigger_instances = self._get_trigger_instances(exclude_fields=exclude_attributes, - include_fields=include_attributes, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + trigger_instances = self._get_trigger_instances( + exclude_fields=exclude_attributes, + include_fields=include_attributes, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) return trigger_instances - def _get_trigger_instances(self, exclude_fields=None, include_fields=None, sort=None, offset=0, - limit=None, raw_filters=None, requester_user=None): + def _get_trigger_instances( + self, + exclude_fields=None, + include_fields=None, + sort=None, + offset=0, + limit=None, + raw_filters=None, + requester_user=None, + ): if limit is None: limit = self.default_limit limit = int(limit) - LOG.debug('Retrieving all trigger instances with filters=%s', raw_filters) - return super(TriggerInstanceController, self)._get_all(exclude_fields=exclude_fields, - include_fields=include_fields, - sort=sort, - offset=offset, - limit=limit, - raw_filters=raw_filters, - requester_user=requester_user) + LOG.debug("Retrieving all trigger instances with filters=%s", raw_filters) + return super(TriggerInstanceController, self)._get_all( + exclude_fields=exclude_fields, + include_fields=include_fields, + sort=sort, + offset=offset, + limit=limit, + raw_filters=raw_filters, + requester_user=requester_user, + ) triggertype_controller = TriggerTypeController() diff --git a/st2api/st2api/controllers/v1/user.py b/st2api/st2api/controllers/v1/user.py index e3de60b9784..0593a133845 100644 --- a/st2api/st2api/controllers/v1/user.py +++ b/st2api/st2api/controllers/v1/user.py @@ -17,9 +17,7 @@ from st2common.rbac.backends import get_rbac_backend -__all__ = [ - 'UserController' -] +__all__ = ["UserController"] class UserController(object): @@ -43,21 +41,21 @@ def get(self, requester_user, auth_info): roles = [] data = { - 'username': requester_user.name, - 'authentication': { - 'method': auth_info['method'], - 'location': auth_info['location'] + "username": requester_user.name, + "authentication": { + "method": auth_info["method"], + "location": auth_info["location"], + }, + "rbac": { + "enabled": cfg.CONF.rbac.enable, + "roles": roles, + "is_admin": rbac_utils.user_is_admin(user_db=requester_user), }, - 'rbac': { - 'enabled': cfg.CONF.rbac.enable, - 'roles': roles, - 'is_admin': rbac_utils.user_is_admin(user_db=requester_user) - } } - if auth_info.get('token_expire', None): - token_expire = auth_info['token_expire'].strftime('%Y-%m-%dT%H:%M:%SZ') - data['authentication']['token_expire'] = token_expire + if auth_info.get("token_expire", None): + token_expire = auth_info["token_expire"].strftime("%Y-%m-%dT%H:%M:%SZ") + data["authentication"]["token_expire"] = token_expire return data diff --git a/st2api/st2api/controllers/v1/webhooks.py b/st2api/st2api/controllers/v1/webhooks.py index 35af0c8337b..1985bb4dad2 100644 --- a/st2api/st2api/controllers/v1/webhooks.py +++ b/st2api/st2api/controllers/v1/webhooks.py @@ -19,7 +19,10 @@ from six.moves import http_client from st2common import log as logging -from st2common.constants.auth import HEADER_API_KEY_ATTRIBUTE_NAME, HEADER_ATTRIBUTE_NAME +from st2common.constants.auth import ( + HEADER_API_KEY_ATTRIBUTE_NAME, + HEADER_ATTRIBUTE_NAME, +) from st2common.constants.triggers import WEBHOOK_TRIGGER_TYPES from st2common.models.api.trace import TraceContext from st2common.models.api.trigger import TriggerAPI @@ -35,13 +38,14 @@ LOG = logging.getLogger(__name__) -TRACE_TAG_HEADER = 'St2-Trace-Tag' +TRACE_TAG_HEADER = "St2-Trace-Tag" class HooksHolder(object): """ Maintains a hook to TriggerDB mapping. """ + def __init__(self): self._triggers_by_hook = {} @@ -58,7 +62,7 @@ def remove_hook(self, hook, trigger): return False remove_index = -1 for idx, item in enumerate(self._triggers_by_hook[hook]): - if item['id'] == trigger['id']: + if item["id"] == trigger["id"]: remove_index = idx break if remove_index < 0: @@ -81,17 +85,19 @@ def get_all(self): class WebhooksController(object): def __init__(self, *args, **kwargs): self._hooks = HooksHolder() - self._base_url = '/webhooks/' + self._base_url = "/webhooks/" self._trigger_types = list(WEBHOOK_TRIGGER_TYPES.keys()) self._trigger_dispatcher_service = TriggerDispatcherService(LOG) queue_suffix = self.__class__.__name__ - self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger, - update_handler=self._handle_update_trigger, - delete_handler=self._handle_delete_trigger, - trigger_types=self._trigger_types, - queue_suffix=queue_suffix, - exclusive=True) + self._trigger_watcher = TriggerWatcher( + create_handler=self._handle_create_trigger, + update_handler=self._handle_update_trigger, + delete_handler=self._handle_delete_trigger, + trigger_types=self._trigger_types, + queue_suffix=queue_suffix, + exclusive=True, + ) self._trigger_watcher.start() self._register_webhook_trigger_types() @@ -108,9 +114,11 @@ def get_one(self, url, requester_user): permission_type = PermissionType.WEBHOOK_VIEW rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=WebhookDB(name=url), - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=WebhookDB(name=url), + permission_type=permission_type, + ) # For demonstration purpose return 1st return triggers[0] @@ -120,55 +128,62 @@ def post(self, hook, webhook_body_api, headers, requester_user): permission_type = PermissionType.WEBHOOK_SEND rbac_utils = get_rbac_backend().get_utils_class() - rbac_utils.assert_user_has_resource_db_permission(user_db=requester_user, - resource_db=WebhookDB(name=hook), - permission_type=permission_type) + rbac_utils.assert_user_has_resource_db_permission( + user_db=requester_user, + resource_db=WebhookDB(name=hook), + permission_type=permission_type, + ) headers = self._get_headers_as_dict(headers) headers = self._filter_authentication_headers(headers) # If webhook contains a trace-tag use that else create create a unique trace-tag. - trace_context = self._create_trace_context(trace_tag=headers.pop(TRACE_TAG_HEADER, None), - hook=hook) + trace_context = self._create_trace_context( + trace_tag=headers.pop(TRACE_TAG_HEADER, None), hook=hook + ) - if hook == 'st2' or hook == 'st2/': + if hook == "st2" or hook == "st2/": # When using st2 or system webhook, body needs to always be a dict if not isinstance(body, dict): type_string = get_json_type_for_python_value(body) - msg = ('Webhook body needs to be an object, got: %s' % (type_string)) + msg = "Webhook body needs to be an object, got: %s" % (type_string) raise ValueError(msg) - trigger = body.get('trigger', None) - payload = body.get('payload', None) + trigger = body.get("trigger", None) + payload = body.get("payload", None) if not trigger: - msg = 'Trigger not specified.' + msg = "Trigger not specified." return abort(http_client.BAD_REQUEST, msg) - self._trigger_dispatcher_service.dispatch_with_context(trigger=trigger, - payload=payload, - trace_context=trace_context, - throw_on_validation_error=True) + self._trigger_dispatcher_service.dispatch_with_context( + trigger=trigger, + payload=payload, + trace_context=trace_context, + throw_on_validation_error=True, + ) else: if not self._is_valid_hook(hook): - self._log_request('Invalid hook.', headers, body) - msg = 'Webhook %s not registered with st2' % hook + self._log_request("Invalid hook.", headers, body) + msg = "Webhook %s not registered with st2" % hook return abort(http_client.NOT_FOUND, msg) triggers = self._hooks.get_triggers_for_hook(hook) payload = {} - payload['headers'] = headers - payload['body'] = body + payload["headers"] = headers + payload["body"] = body # Dispatch trigger instance for each of the trigger found for trigger_dict in triggers: # TODO: Instead of dispatching the whole dict we should just # dispatch TriggerDB.ref or similar - self._trigger_dispatcher_service.dispatch_with_context(trigger=trigger_dict, - payload=payload, - trace_context=trace_context, - throw_on_validation_error=True) + self._trigger_dispatcher_service.dispatch_with_context( + trigger=trigger_dict, + payload=payload, + trace_context=trace_context, + throw_on_validation_error=True, + ) return Response(json=body, status=http_client.ACCEPTED) @@ -183,7 +198,7 @@ def _register_webhook_trigger_types(self): def _create_trace_context(self, trace_tag, hook): # if no trace_tag then create a unique one if not trace_tag: - trace_tag = 'webhook-%s-%s' % (hook, uuid.uuid4().hex) + trace_tag = "webhook-%s-%s" % (hook, uuid.uuid4().hex) return TraceContext(trace_tag=trace_tag) def add_trigger(self, trigger): @@ -191,7 +206,7 @@ def add_trigger(self, trigger): # Note: Permission checking for creating and deleting a webhook is done during rule # creation url = self._get_normalized_url(trigger) - LOG.info('Listening to endpoint: %s', urlparse.urljoin(self._base_url, url)) + LOG.info("Listening to endpoint: %s", urlparse.urljoin(self._base_url, url)) self._hooks.add_hook(url, trigger) def update_trigger(self, trigger): @@ -204,14 +219,16 @@ def remove_trigger(self, trigger): removed = self._hooks.remove_hook(url, trigger) if removed: - LOG.info('Stop listening to endpoint: %s', urlparse.urljoin(self._base_url, url)) + LOG.info( + "Stop listening to endpoint: %s", urlparse.urljoin(self._base_url, url) + ) def _get_normalized_url(self, trigger): """ remove the trailing and leading / so that the hook url and those coming from trigger parameters end up being the same. """ - return trigger['parameters']['url'].strip('/') + return trigger["parameters"]["url"].strip("/") def _get_headers_as_dict(self, headers): headers_dict = {} @@ -220,13 +237,13 @@ def _get_headers_as_dict(self, headers): return headers_dict def _filter_authentication_headers(self, headers): - auth_headers = [HEADER_API_KEY_ATTRIBUTE_NAME, HEADER_ATTRIBUTE_NAME, 'Cookie'] + auth_headers = [HEADER_API_KEY_ATTRIBUTE_NAME, HEADER_ATTRIBUTE_NAME, "Cookie"] return {key: value for key, value in headers.items() if key not in auth_headers} def _log_request(self, msg, headers, body, log_method=LOG.debug): headers = self._get_headers_as_dict(headers) body = str(body) - log_method('%s\n\trequest.header: %s.\n\trequest.body: %s.', msg, headers, body) + log_method("%s\n\trequest.header: %s.\n\trequest.body: %s.", msg, headers, body) ############################################## # Event handler methods for the trigger events diff --git a/st2api/st2api/controllers/v1/workflow_inspection.py b/st2api/st2api/controllers/v1/workflow_inspection.py index 1e5ee53d852..04d60dd2b1f 100644 --- a/st2api/st2api/controllers/v1/workflow_inspection.py +++ b/st2api/st2api/controllers/v1/workflow_inspection.py @@ -30,13 +30,12 @@ class WorkflowInspectionController(object): - def mock_st2_ctx(self): st2_ctx = { - 'st2': { - 'api_url': api_utils.get_full_public_api_url(), - 'action_execution_id': uuid.uuid4().hex, - 'user': cfg.CONF.system_user.user + "st2": { + "api_url": api_utils.get_full_public_api_url(), + "action_execution_id": uuid.uuid4().hex, + "user": cfg.CONF.system_user.user, } } @@ -44,7 +43,7 @@ def mock_st2_ctx(self): def post(self, wf_def): # Load workflow definition into workflow spec model. - spec_module = specs_loader.get_spec_module('native') + spec_module = specs_loader.get_spec_module("native") wf_spec = spec_module.instantiate(wf_def) # Mock the st2 context that is typically passed to the workflow engine. diff --git a/st2api/st2api/validation.py b/st2api/st2api/validation.py index ae92d1d9cb3..42120c57bfb 100644 --- a/st2api/st2api/validation.py +++ b/st2api/st2api/validation.py @@ -15,9 +15,7 @@ from oslo_config import cfg -__all__ = [ - 'validate_rbac_is_correctly_configured' -] +__all__ = ["validate_rbac_is_correctly_configured"] def validate_rbac_is_correctly_configured(): @@ -28,24 +26,29 @@ def validate_rbac_is_correctly_configured(): return True from st2common.rbac.backends import get_available_backends + available_rbac_backends = get_available_backends() # 1. Verify auth is enabled if not cfg.CONF.auth.enable: - msg = ('Authentication is not enabled. RBAC only works when authentication is enabled. ' - 'You can either enable authentication or disable RBAC.') + msg = ( + "Authentication is not enabled. RBAC only works when authentication is enabled. " + "You can either enable authentication or disable RBAC." + ) raise ValueError(msg) # 2. Verify default backend is set - if cfg.CONF.rbac.backend != 'default': - msg = ('You have enabled RBAC, but RBAC backend is not set to "default". ' - 'For RBAC to work, you need to set ' - '"rbac.backend" config option to "default" and restart st2api service.') + if cfg.CONF.rbac.backend != "default": + msg = ( + 'You have enabled RBAC, but RBAC backend is not set to "default". ' + "For RBAC to work, you need to set " + '"rbac.backend" config option to "default" and restart st2api service.' + ) raise ValueError(msg) # 3. Verify default RBAC backend is available - if 'default' not in available_rbac_backends: - msg = ('"default" RBAC backend is not available.') + if "default" not in available_rbac_backends: + msg = '"default" RBAC backend is not available.' raise ValueError(msg) return True diff --git a/st2api/st2api/wsgi.py b/st2api/st2api/wsgi.py index b9c92b7bf4a..79baf0f110f 100644 --- a/st2api/st2api/wsgi.py +++ b/st2api/st2api/wsgi.py @@ -20,6 +20,7 @@ import os from st2common.util.monkey_patch import monkey_patch + # Note: We need to perform monkey patching in the worker. If we do it in # the master process (gunicorn_config.py), it breaks tons of things # including shutdown @@ -32,8 +33,11 @@ from st2api import app config = { - 'is_gunicorn': True, - 'config_args': ['--config-file', os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf')] + "is_gunicorn": True, + "config_args": [ + "--config-file", + os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf"), + ], } application = app.setup_app(config) diff --git a/st2api/tests/integration/test_gunicorn_configs.py b/st2api/tests/integration/test_gunicorn_configs.py index 65950bfa7cf..9375cf3b858 100644 --- a/st2api/tests/integration/test_gunicorn_configs.py +++ b/st2api/tests/integration/test_gunicorn_configs.py @@ -28,38 +28,44 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf') +ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf") class GunicornWSGIEntryPointTestCase(IntegrationTestCase): - @unittest2.skipIf(profiling.is_enabled(), 'Profiling is enabled') + @unittest2.skipIf(profiling.is_enabled(), "Profiling is enabled") def test_st2api_wsgi_entry_point(self): port = random.randint(10000, 30000) - cmd = ('gunicorn st2api.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1' % port) + cmd = ( + 'gunicorn st2api.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1' + % port + ) env = os.environ.copy() - env['ST2_CONFIG_PATH'] = ST2_CONFIG_PATH + env["ST2_CONFIG_PATH"] = ST2_CONFIG_PATH process = subprocess.Popen(cmd, env=env, shell=True, preexec_fn=os.setsid) try: self.add_process(process=process) eventlet.sleep(8) self.assertProcessIsRunning(process=process) - response = requests.get('http://127.0.0.1:%s/v1/actions' % (port)) + response = requests.get("http://127.0.0.1:%s/v1/actions" % (port)) self.assertEqual(response.status_code, http_client.OK) finally: kill_process(process) - @unittest2.skipIf(profiling.is_enabled(), 'Profiling is enabled') + @unittest2.skipIf(profiling.is_enabled(), "Profiling is enabled") def test_st2auth(self): port = random.randint(10000, 30000) - cmd = ('gunicorn st2auth.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1' % port) + cmd = ( + 'gunicorn st2auth.wsgi:application -k eventlet -b "127.0.0.1:%s" --workers 1' + % port + ) env = os.environ.copy() - env['ST2_CONFIG_PATH'] = ST2_CONFIG_PATH + env["ST2_CONFIG_PATH"] = ST2_CONFIG_PATH process = subprocess.Popen(cmd, env=env, shell=True, preexec_fn=os.setsid) try: self.add_process(process=process) eventlet.sleep(8) self.assertProcessIsRunning(process=process) - response = requests.post('http://127.0.0.1:%s/tokens' % (port)) + response = requests.post("http://127.0.0.1:%s/tokens" % (port)) self.assertEqual(response.status_code, http_client.UNAUTHORIZED) finally: kill_process(process) diff --git a/st2api/tests/unit/controllers/test_root.py b/st2api/tests/unit/controllers/test_root.py index d4172ce1556..db4ea017136 100644 --- a/st2api/tests/unit/controllers/test_root.py +++ b/st2api/tests/unit/controllers/test_root.py @@ -15,15 +15,13 @@ from st2tests.api import FunctionalTest -__all__ = [ - 'RootControllerTestCase' -] +__all__ = ["RootControllerTestCase"] class RootControllerTestCase(FunctionalTest): def test_get_index(self): - paths = ['/', '/v1/', '/v1'] + paths = ["/", "/v1/", "/v1"] for path in paths: resp = self.app.get(path) - self.assertIn('version', resp.json) - self.assertIn('docs_url', resp.json) + self.assertIn("version", resp.json) + self.assertIn("docs_url", resp.json) diff --git a/st2api/tests/unit/controllers/v1/test_action_alias.py b/st2api/tests/unit/controllers/v1/test_action_alias.py index 299ce530e33..208ed082be2 100644 --- a/st2api/tests/unit/controllers/v1/test_action_alias.py +++ b/st2api/tests/unit/controllers/v1/test_action_alias.py @@ -21,31 +21,33 @@ from st2tests.api import FunctionalTest from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase -FIXTURES_PACK = 'aliases' +FIXTURES_PACK = "aliases" TEST_MODELS = { - 'aliases': ['alias1.yaml', 'alias2.yaml', 'alias_with_undefined_jinja_in_ack_format.yaml'], - 'actions': ['action3.yaml', 'action4.yaml'] + "aliases": [ + "alias1.yaml", + "alias2.yaml", + "alias_with_undefined_jinja_in_ack_format.yaml", + ], + "actions": ["action3.yaml", "action4.yaml"], } TEST_LOAD_MODELS = { - 'aliases': ['alias3.yaml'], + "aliases": ["alias3.yaml"], } -GENERIC_FIXTURES_PACK = 'generic' +GENERIC_FIXTURES_PACK = "generic" -TEST_LOAD_MODELS_GENERIC = { - 'aliases': ['alias3.yaml'], - 'runners': ['testrunner1.yaml'] -} +TEST_LOAD_MODELS_GENERIC = {"aliases": ["alias3.yaml"], "runners": ["testrunner1.yaml"]} -class ActionAliasControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/actionalias' +class ActionAliasControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/actionalias" controller_cls = ActionAliasController - include_attribute_field_name = 'formats' - exclude_attribute_field_name = 'result' + include_attribute_field_name = "formats" + exclude_attribute_field_name = "result" models = None alias1 = None @@ -56,153 +58,186 @@ class ActionAliasControllerTestCase(FunctionalTest, @classmethod def setUpClass(cls): super(ActionAliasControllerTestCase, cls).setUpClass() - cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - cls.alias1 = cls.models['aliases']['alias1.yaml'] - cls.alias2 = cls.models['aliases']['alias2.yaml'] - - loaded_models = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_LOAD_MODELS) - cls.alias3 = loaded_models['aliases']['alias3.yaml'] - - FixturesLoader().save_fixtures_to_db(fixtures_pack=GENERIC_FIXTURES_PACK, - fixtures_dict={'aliases': ['alias7.yaml']}) - - loaded_models = FixturesLoader().load_models(fixtures_pack=GENERIC_FIXTURES_PACK, - fixtures_dict=TEST_LOAD_MODELS_GENERIC) - cls.alias3_generic = loaded_models['aliases']['alias3.yaml'] + cls.models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + cls.alias1 = cls.models["aliases"]["alias1.yaml"] + cls.alias2 = cls.models["aliases"]["alias2.yaml"] + + loaded_models = FixturesLoader().load_models( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_LOAD_MODELS + ) + cls.alias3 = loaded_models["aliases"]["alias3.yaml"] + + FixturesLoader().save_fixtures_to_db( + fixtures_pack=GENERIC_FIXTURES_PACK, + fixtures_dict={"aliases": ["alias7.yaml"]}, + ) + + loaded_models = FixturesLoader().load_models( + fixtures_pack=GENERIC_FIXTURES_PACK, fixtures_dict=TEST_LOAD_MODELS_GENERIC + ) + cls.alias3_generic = loaded_models["aliases"]["alias3.yaml"] def test_get_all(self): - resp = self.app.get('/v1/actionalias') + resp = self.app.get("/v1/actionalias") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 4, '/v1/actionalias did not return all aliases.') - - retrieved_names = [alias['name'] for alias in resp.json] - - self.assertEqual(retrieved_names, [self.alias1.name, self.alias2.name, - 'alias_with_undefined_jinja_in_ack_format', - 'alias7'], - 'Incorrect aliases retrieved.') + self.assertEqual( + len(resp.json), 4, "/v1/actionalias did not return all aliases." + ) + + retrieved_names = [alias["name"] for alias in resp.json] + + self.assertEqual( + retrieved_names, + [ + self.alias1.name, + self.alias2.name, + "alias_with_undefined_jinja_in_ack_format", + "alias7", + ], + "Incorrect aliases retrieved.", + ) def test_get_all_query_param_filters(self): - resp = self.app.get('/v1/actionalias?pack=doesntexist') + resp = self.app.get("/v1/actionalias?pack=doesntexist") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/actionalias?pack=aliases') + resp = self.app.get("/v1/actionalias?pack=aliases") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 3) for alias_api in resp.json: - self.assertEqual(alias_api['pack'], 'aliases') + self.assertEqual(alias_api["pack"], "aliases") - resp = self.app.get('/v1/actionalias?pack=generic') + resp = self.app.get("/v1/actionalias?pack=generic") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) for alias_api in resp.json: - self.assertEqual(alias_api['pack'], 'generic') + self.assertEqual(alias_api["pack"], "generic") - resp = self.app.get('/v1/actionalias?name=doesntexist') + resp = self.app.get("/v1/actionalias?name=doesntexist") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/actionalias?name=alias2') + resp = self.app.get("/v1/actionalias?name=alias2") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['name'], 'alias2') + self.assertEqual(resp.json[0]["name"], "alias2") def test_get_one(self): - resp = self.app.get('/v1/actionalias/%s' % self.alias1.id) + resp = self.app.get("/v1/actionalias/%s" % self.alias1.id) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['name'], self.alias1.name, - 'Incorrect aliases retrieved.') + self.assertEqual( + resp.json["name"], self.alias1.name, "Incorrect aliases retrieved." + ) def test_post_delete(self): post_resp = self._do_post(vars(ActionAliasAPI.from_model(self.alias3))) self.assertEqual(post_resp.status_int, 201) - get_resp = self.app.get('/v1/actionalias/%s' % post_resp.json['id']) + get_resp = self.app.get("/v1/actionalias/%s" % post_resp.json["id"]) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['name'], self.alias3.name, - 'Incorrect aliases retrieved.') + self.assertEqual( + get_resp.json["name"], self.alias3.name, "Incorrect aliases retrieved." + ) - del_resp = self.__do_delete(post_resp.json['id']) + del_resp = self.__do_delete(post_resp.json["id"]) self.assertEqual(del_resp.status_int, 204) - get_resp = self.app.get('/v1/actionalias/%s' % post_resp.json['id'], expect_errors=True) + get_resp = self.app.get( + "/v1/actionalias/%s" % post_resp.json["id"], expect_errors=True + ) self.assertEqual(get_resp.status_int, 404) def test_update_existing_alias(self): post_resp = self._do_post(vars(ActionAliasAPI.from_model(self.alias3))) self.assertEqual(post_resp.status_int, 201) - self.assertEqual(post_resp.json['name'], self.alias3['name']) + self.assertEqual(post_resp.json["name"], self.alias3["name"]) data = vars(ActionAliasAPI.from_model(self.alias3)) - data['name'] = 'updated-alias-name' + data["name"] = "updated-alias-name" - put_resp = self.app.put_json('/v1/actionalias/%s' % post_resp.json['id'], data) - self.assertEqual(put_resp.json['name'], data['name']) + put_resp = self.app.put_json("/v1/actionalias/%s" % post_resp.json["id"], data) + self.assertEqual(put_resp.json["name"], data["name"]) - get_resp = self.app.get('/v1/actionalias/%s' % post_resp.json['id']) - self.assertEqual(get_resp.json['name'], data['name']) + get_resp = self.app.get("/v1/actionalias/%s" % post_resp.json["id"]) + self.assertEqual(get_resp.json["name"], data["name"]) - del_resp = self.__do_delete(post_resp.json['id']) + del_resp = self.__do_delete(post_resp.json["id"]) self.assertEqual(del_resp.status_int, 204) def test_post_dup_name(self): post_resp = self._do_post(vars(ActionAliasAPI.from_model(self.alias3))) self.assertEqual(post_resp.status_int, 201) - post_resp_dup_name = self._do_post(vars(ActionAliasAPI.from_model(self.alias3_generic))) + post_resp_dup_name = self._do_post( + vars(ActionAliasAPI.from_model(self.alias3_generic)) + ) self.assertEqual(post_resp_dup_name.status_int, 201) - self.__do_delete(post_resp.json['id']) - self.__do_delete(post_resp_dup_name.json['id']) + self.__do_delete(post_resp.json["id"]) + self.__do_delete(post_resp_dup_name.json["id"]) def test_match(self): # No matching patterns - data = {'command': 'hello donny'} + data = {"command": "hello donny"} resp = self.app.post_json("/v1/actionalias/match", data, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(str(resp.json['faultstring']), - "Command 'hello donny' matched no patterns") + self.assertEqual( + str(resp.json["faultstring"]), "Command 'hello donny' matched no patterns" + ) # More than one matching pattern - data = {'command': 'Lorem ipsum banana dolor sit pineapple amet.'} + data = {"command": "Lorem ipsum banana dolor sit pineapple amet."} resp = self.app.post_json("/v1/actionalias/match", data, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(str(resp.json['faultstring']), - "Command 'Lorem ipsum banana dolor sit pineapple amet.' " - "matched more than 1 pattern") + self.assertEqual( + str(resp.json["faultstring"]), + "Command 'Lorem ipsum banana dolor sit pineapple amet.' " + "matched more than 1 pattern", + ) # Single matching pattern - success - data = {'command': 'run whoami on localhost1'} + data = {"command": "run whoami on localhost1"} resp = self.app.post_json("/v1/actionalias/match", data) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['actionalias']['name'], - 'alias_with_undefined_jinja_in_ack_format') + self.assertEqual( + resp.json["actionalias"]["name"], "alias_with_undefined_jinja_in_ack_format" + ) def test_help(self): resp = self.app.get("/v1/actionalias/help") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json.get('available'), 5) + self.assertEqual(resp.json.get("available"), 5) def test_help_args(self): - resp = self.app.get("/v1/actionalias/help?filter=.*&pack=aliases&limit=1&offset=0") + resp = self.app.get( + "/v1/actionalias/help?filter=.*&pack=aliases&limit=1&offset=0" + ) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json.get('available'), 3) - self.assertEqual(len(resp.json.get('helpstrings')), 1) + self.assertEqual(resp.json.get("available"), 3) + self.assertEqual(len(resp.json.get("helpstrings")), 1) def _insert_mock_models(self): - alias_ids = [self.alias1['id'], self.alias2['id'], self.alias3['id'], - self.alias3_generic['id']] + alias_ids = [ + self.alias1["id"], + self.alias2["id"], + self.alias3["id"], + self.alias3_generic["id"], + ] return alias_ids def _delete_mock_models(self, object_ids): return None def _do_post(self, actionalias, expect_errors=False): - return self.app.post_json('/v1/actionalias', actionalias, expect_errors=expect_errors) + return self.app.post_json( + "/v1/actionalias", actionalias, expect_errors=expect_errors + ) def __do_delete(self, actionalias_id, expect_errors=False): - return self.app.delete('/v1/actionalias/%s' % actionalias_id, expect_errors=expect_errors) + return self.app.delete( + "/v1/actionalias/%s" % actionalias_id, expect_errors=expect_errors + ) diff --git a/st2api/tests/unit/controllers/v1/test_action_views.py b/st2api/tests/unit/controllers/v1/test_action_views.py index dbb9346662a..a28219c04d5 100644 --- a/st2api/tests/unit/controllers/v1/test_action_views.py +++ b/st2api/tests/unit/controllers/v1/test_action_views.py @@ -25,42 +25,44 @@ # ACTION_1: Good action definition. ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack', - 'entry_point': 'test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "pack": "wolfpack", + "entry_point": "test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_2: Good action definition. No content pack. ACTION_2 = { - 'name': 'st2.dummy.action2', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack', - 'entry_point': 'test/action2.py', - 'runner_type': 'local-shell-script', - 'parameters': { - 'c': {'type': 'string', 'default': 'C1', 'position': 0}, - 'd': {'type': 'string', 'default': 'D1', 'immutable': True} - } + "name": "st2.dummy.action2", + "description": "test description", + "enabled": True, + "pack": "wolfpack", + "entry_point": "test/action2.py", + "runner_type": "local-shell-script", + "parameters": { + "c": {"type": "string", "default": "C1", "position": 0}, + "d": {"type": "string", "default": "D1", "immutable": True}, + }, } -class ActionViewsOverviewControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/actions/views/overview' +class ActionViewsOverviewControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/actions/views/overview" controller_cls = OverviewController - include_attribute_field_name = 'entry_point' - exclude_attribute_field_name = 'parameters' + include_attribute_field_name = "entry_point" + exclude_attribute_field_name = "parameters" - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one(self): post_resp = self._do_post(ACTION_1) action_id = self._get_action_id(post_resp) @@ -71,8 +73,9 @@ def test_get_one(self): finally: self._do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_ref(self): post_resp = self._do_post(ACTION_1) action_id = self._get_action_id(post_resp) @@ -80,66 +83,85 @@ def test_get_one_ref(self): try: get_resp = self._do_get_one(action_ref) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['ref'], action_ref) + self.assertEqual(get_resp.json["ref"], action_ref) finally: self._do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_and_limit_minus_one(self): action_1_id = self._get_action_id(self._do_post(ACTION_1)) action_2_id = self._get_action_id(self._do_post(ACTION_2)) try: - resp = self.app.get('/v1/actions/views/overview') + resp = self.app.get("/v1/actions/views/overview") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 2, - '/v1/actions/views/overview did not return all actions.') - resp = self.app.get('/v1/actions/views/overview/?limit=-1') + self.assertEqual( + len(resp.json), + 2, + "/v1/actions/views/overview did not return all actions.", + ) + resp = self.app.get("/v1/actions/views/overview/?limit=-1") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 2, - '/v1/actions/views/overview did not return all actions.') + self.assertEqual( + len(resp.json), + 2, + "/v1/actions/views/overview did not return all actions.", + ) finally: self._do_delete(action_1_id) self._do_delete(action_2_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_negative_limit(self): action_1_id = self._get_action_id(self._do_post(ACTION_1)) action_2_id = self._get_action_id(self._do_post(ACTION_2)) try: - resp = self.app.get('/v1/actions/views/overview/?limit=-22', expect_errors=True) + resp = self.app.get( + "/v1/actions/views/overview/?limit=-22", expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) finally: self._do_delete(action_1_id) self._do_delete(action_2_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_filter_by_name(self): action_1_id = self._get_action_id(self._do_post(ACTION_1)) action_2_id = self._get_action_id(self._do_post(ACTION_2)) try: - resp = self.app.get('/v1/actions/views/overview?name=%s' % str('st2.dummy.action2')) + resp = self.app.get( + "/v1/actions/views/overview?name=%s" % str("st2.dummy.action2") + ) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json[0]['id'], action_2_id, 'Filtering failed') + self.assertEqual(resp.json[0]["id"], action_2_id, "Filtering failed") finally: self._do_delete(action_1_id) self._do_delete(action_2_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_include_attributes_filter(self): - return super(ActionViewsOverviewControllerTestCase, self) \ - .test_get_all_include_attributes_filter() + return super( + ActionViewsOverviewControllerTestCase, self + ).test_get_all_include_attributes_filter() - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_exclude_attributes_filter(self): - return super(ActionViewsOverviewControllerTestCase, self) \ - .test_get_all_include_attributes_filter() + return super( + ActionViewsOverviewControllerTestCase, self + ).test_get_all_include_attributes_filter() def _insert_mock_models(self): action_1_id = self._get_action_id(self._do_post(ACTION_1)) @@ -149,115 +171,141 @@ def _insert_mock_models(self): @staticmethod def _get_action_id(resp): - return resp.json['id'] + return resp.json["id"] @staticmethod def _get_action_ref(resp): - return '.'.join((resp.json['pack'], resp.json['name'])) + return ".".join((resp.json["pack"], resp.json["name"])) @staticmethod def _get_action_name(resp): - return resp.json['name'] + return resp.json["name"] def _do_get_one(self, action_id, expect_errors=False): - return self.app.get('/v1/actions/views/overview/%s' % action_id, - expect_errors=expect_errors) + return self.app.get( + "/v1/actions/views/overview/%s" % action_id, expect_errors=expect_errors + ) def _do_post(self, action, expect_errors=False): - return self.app.post_json('/v1/actions', action, expect_errors=expect_errors) + return self.app.post_json("/v1/actions", action, expect_errors=expect_errors) def _do_delete(self, action_id, expect_errors=False): - return self.app.delete('/v1/actions/%s' % action_id, expect_errors=expect_errors) + return self.app.delete( + "/v1/actions/%s" % action_id, expect_errors=expect_errors + ) class ActionViewsParametersControllerTestCase(FunctionalTest): - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] try: - get_resp = self.app.get('/v1/actions/views/parameters/%s' % action_id) + get_resp = self.app.get("/v1/actions/views/parameters/%s" % action_id) self.assertEqual(get_resp.status_int, 200) finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) class ActionEntryPointViewControllerTestCase(FunctionalTest): - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) - @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock( - return_value='/path/to/file')) - @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + content_utils, + "get_entry_point_abs_path", + mock.MagicMock(return_value="/path/to/file"), + ) + @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True) def test_get_one(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] try: - get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_id) + get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_id) self.assertEqual(get_resp.status_int, 200) finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) - @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock( - return_value='/path/to/file')) - @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + content_utils, + "get_entry_point_abs_path", + mock.MagicMock(return_value="/path/to/file"), + ) + @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True) def test_get_one_ref(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] - action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name'])) + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] + action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"])) try: - get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref) + get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref) self.assertEqual(get_resp.status_int, 200) finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) - @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock( - return_value='/path/to/file.yaml')) - @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + content_utils, + "get_entry_point_abs_path", + mock.MagicMock(return_value="/path/to/file.yaml"), + ) + @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True) def test_get_one_ref_yaml_content_type(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] - action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name'])) + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] + action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"])) try: - get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref) + get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.headers['Content-Type'], 'application/x-yaml') + self.assertEqual(get_resp.headers["Content-Type"], "application/x-yaml") finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) - @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock( - return_value=__file__.replace('.pyc', '.py'))) - @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + content_utils, + "get_entry_point_abs_path", + mock.MagicMock(return_value=__file__.replace(".pyc", ".py")), + ) + @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True) def test_get_one_ref_python_content_type(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] - action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name'])) + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] + action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"])) try: - get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref) + get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref) self.assertEqual(get_resp.status_int, 200) - self.assertIn(get_resp.headers['Content-Type'], ['application/x-python', - 'text/x-python']) + self.assertIn( + get_resp.headers["Content-Type"], + ["application/x-python", "text/x-python"], + ) finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) - @mock.patch.object(content_utils, 'get_entry_point_abs_path', mock.MagicMock( - return_value='/file/does/not/exist')) - @mock.patch(mock_open_name, mock.mock_open(read_data='file content'), create=True) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + content_utils, + "get_entry_point_abs_path", + mock.MagicMock(return_value="/file/does/not/exist"), + ) + @mock.patch(mock_open_name, mock.mock_open(read_data="file content"), create=True) def test_get_one_ref_text_plain_content_type(self): - post_resp = self.app.post_json('/v1/actions', ACTION_1) - action_id = post_resp.json['id'] - action_ref = '.'.join((post_resp.json['pack'], post_resp.json['name'])) + post_resp = self.app.post_json("/v1/actions", ACTION_1) + action_id = post_resp.json["id"] + action_ref = ".".join((post_resp.json["pack"], post_resp.json["name"])) try: - get_resp = self.app.get('/v1/actions/views/entry_point/%s' % action_ref) + get_resp = self.app.get("/v1/actions/views/entry_point/%s" % action_ref) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.headers['Content-Type'], 'text/plain') + self.assertEqual(get_resp.headers["Content-Type"], "text/plain") finally: - self.app.delete('/v1/actions/%s' % action_id) + self.app.delete("/v1/actions/%s" % action_id) diff --git a/st2api/tests/unit/controllers/v1/test_actions.py b/st2api/tests/unit/controllers/v1/test_actions.py index c189803d2c6..40e973c0a14 100644 --- a/st2api/tests/unit/controllers/v1/test_actions.py +++ b/st2api/tests/unit/controllers/v1/test_actions.py @@ -41,257 +41,259 @@ # ACTION_1: Good action definition. ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, }, - 'tags': [ - {'name': 'tag1', 'value': 'dont-care'}, - {'name': 'tag2', 'value': 'dont-care'} - ] + "tags": [ + {"name": "tag1", "value": "dont-care"}, + {"name": "tag2", "value": "dont-care"}, + ], } # ACTION_2: Good action definition. No content pack. ACTION_2 = { - 'name': 'st2.dummy.action2', - 'description': 'test description', - 'enabled': True, - 'entry_point': '/tmp/test/action2.py', - 'runner_type': 'local-shell-script', - 'parameters': { - 'c': {'type': 'string', 'default': 'C1', 'position': 0}, - 'd': {'type': 'string', 'default': 'D1', 'immutable': True} - } + "name": "st2.dummy.action2", + "description": "test description", + "enabled": True, + "entry_point": "/tmp/test/action2.py", + "runner_type": "local-shell-script", + "parameters": { + "c": {"type": "string", "default": "C1", "position": 0}, + "d": {"type": "string", "default": "D1", "immutable": True}, + }, } # ACTION_3: No enabled field ACTION_3 = { - 'name': 'st2.dummy.action3', - 'description': 'test description', - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action3", + "description": "test description", + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_4: Enabled field is False ACTION_4 = { - 'name': 'st2.dummy.action4', - 'description': 'test description', - 'enabled': False, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action4", + "description": "test description", + "enabled": False, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_5: Invalid runner_type ACTION_5 = { - 'name': 'st2.dummy.action5', - 'description': 'test description', - 'enabled': False, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'xyzxyz', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action5", + "description": "test description", + "enabled": False, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "xyzxyz", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_6: No description field. ACTION_6 = { - 'name': 'st2.dummy.action6', - 'enabled': False, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action6", + "enabled": False, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_7: id field provided ACTION_7 = { - 'id': 'foobar', - 'name': 'st2.dummy.action7', - 'description': 'test description', - 'enabled': False, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "id": "foobar", + "name": "st2.dummy.action7", + "description": "test description", + "enabled": False, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_8: id field provided ACTION_8 = { - 'name': 'st2.dummy.action8', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'cmd': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action8", + "description": "test description", + "enabled": True, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "cmd": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # ACTION_9: Parameter dict has fields not part of JSONSchema spec. ACTION_9 = { - 'name': 'st2.dummy.action9', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1', 'dummyfield': True}, # dummyfield is invalid. - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action9", + "description": "test description", + "enabled": True, + "pack": "wolfpack", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": { + "type": "string", + "default": "A1", + "dummyfield": True, + }, # dummyfield is invalid. + "b": {"type": "string", "default": "B1"}, + }, } # Same name as ACTION_1. Different pack though. # Ensure that this remains the only action with pack == wolfpack1, # otherwise take care of the test test_get_one_using_pack_parameter ACTION_10 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'pack': 'wolfpack1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "pack": "wolfpack1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } # Good action with a system pack ACTION_11 = { - 'name': 'st2.dummy.action11', - 'pack': SYSTEM_PACK_NAME, - 'description': 'test description', - 'enabled': True, - 'entry_point': '/tmp/test/action2.py', - 'runner_type': 'local-shell-script', - 'parameters': { - 'c': {'type': 'string', 'default': 'C1', 'position': 0}, - 'd': {'type': 'string', 'default': 'D1', 'immutable': True} - } + "name": "st2.dummy.action11", + "pack": SYSTEM_PACK_NAME, + "description": "test description", + "enabled": True, + "entry_point": "/tmp/test/action2.py", + "runner_type": "local-shell-script", + "parameters": { + "c": {"type": "string", "default": "C1", "position": 0}, + "d": {"type": "string", "default": "D1", "immutable": True}, + }, } # Good action inside dummy pack ACTION_12 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'pack': 'dummy_pack_1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "pack": "dummy_pack_1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, }, - 'tags': [ - {'name': 'tag1', 'value': 'dont-care'}, - {'name': 'tag2', 'value': 'dont-care'} - ] + "tags": [ + {"name": "tag1", "value": "dont-care"}, + {"name": "tag2", "value": "dont-care"}, + ], } # Action with invalid parameter type attribute ACTION_13 = { - 'name': 'st2.dummy.action2', - 'description': 'test description', - 'enabled': True, - 'pack': 'dummy_pack_1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': ['string', 'object'], 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'} - } + "name": "st2.dummy.action2", + "description": "test description", + "enabled": True, + "pack": "dummy_pack_1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": ["string", "object"], "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + }, } ACTION_14 = { - 'name': 'st2.dummy.action14', - 'description': 'test description', - 'enabled': True, - 'pack': 'dummy_pack_1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'}, - 'sudo': {'type': 'string'} - } + "name": "st2.dummy.action14", + "description": "test description", + "enabled": True, + "pack": "dummy_pack_1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + "sudo": {"type": "string"}, + }, } ACTION_15 = { - 'name': 'st2.dummy.action15', - 'description': 'test description', - 'enabled': True, - 'pack': 'dummy_pack_1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'}, - 'sudo': {'default': True, 'immutable': True} - } + "name": "st2.dummy.action15", + "description": "test description", + "enabled": True, + "pack": "dummy_pack_1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + "sudo": {"default": True, "immutable": True}, + }, } ACTION_WITH_NOTIFY = { - 'name': 'st2.dummy.action_notify_test', - 'description': 'test description', - 'enabled': True, - 'pack': 'dummy_pack_1', - 'entry_point': '/tmp/test/action1.sh', - 'runner_type': 'local-shell-script', - 'parameters': { - 'a': {'type': 'string', 'default': 'A1'}, - 'b': {'type': 'string', 'default': 'B1'}, - 'sudo': {'default': True, 'immutable': True} + "name": "st2.dummy.action_notify_test", + "description": "test description", + "enabled": True, + "pack": "dummy_pack_1", + "entry_point": "/tmp/test/action1.sh", + "runner_type": "local-shell-script", + "parameters": { + "a": {"type": "string", "default": "A1"}, + "b": {"type": "string", "default": "B1"}, + "sudo": {"default": True, "immutable": True}, }, - 'notify': { - 'on-complete': { - 'message': 'Woohoo! I completed!!!' - } - } + "notify": {"on-complete": {"message": "Woohoo! I completed!!!"}}, } -class ActionsControllerTestCase(FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase, - CleanFilesTestCase): - get_all_path = '/v1/actions' +class ActionsControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase, CleanFilesTestCase +): + get_all_path = "/v1/actions" controller_cls = ActionsController - include_attribute_field_name = 'entry_point' - exclude_attribute_field_name = 'parameters' + include_attribute_field_name = "entry_point" + exclude_attribute_field_name = "parameters" register_packs = True to_delete_files = [ - os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1/actions/filea.txt') + os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1/actions/filea.txt") ] - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_using_id(self): post_resp = self.__do_post(ACTION_1) action_id = self.__get_action_id(post_resp) @@ -300,146 +302,169 @@ def test_get_one_using_id(self): self.assertEqual(self.__get_action_id(get_resp), action_id) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_using_ref(self): - ref = '.'.join([ACTION_1['pack'], ACTION_1['name']]) + ref = ".".join([ACTION_1["pack"], ACTION_1["name"]]) action_id = self.__get_action_id(self.__do_post(ACTION_1)) get_resp = self.__do_get_one(ref) self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_action_id(get_resp), action_id) - self.assertEqual(get_resp.json['ref'], ref) + self.assertEqual(get_resp.json["ref"], ref) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_validate_params(self): post_resp = self.__do_post(ACTION_1) action_id = self.__get_action_id(post_resp) get_resp = self.__do_get_one(action_id) self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_action_id(get_resp), action_id) - expected_args = ACTION_1['parameters'] - self.assertEqual(get_resp.json['parameters'], expected_args) + expected_args = ACTION_1["parameters"] + self.assertEqual(get_resp.json["parameters"], expected_args) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_and_with_minus_one(self): - action_1_ref = '.'.join([ACTION_1['pack'], ACTION_1['name']]) + action_1_ref = ".".join([ACTION_1["pack"], ACTION_1["name"]]) action_1_id = self.__get_action_id(self.__do_post(ACTION_1)) action_2_id = self.__get_action_id(self.__do_post(ACTION_2)) - resp = self.app.get('/v1/actions') + resp = self.app.get("/v1/actions") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 2, '/v1/actions did not return all actions.') + self.assertEqual(len(resp.json), 2, "/v1/actions did not return all actions.") - item = [i for i in resp.json if i['id'] == action_1_id][0] - self.assertEqual(item['ref'], action_1_ref) + item = [i for i in resp.json if i["id"] == action_1_id][0] + self.assertEqual(item["ref"], action_1_ref) - resp = self.app.get('/v1/actions?limit=-1') + resp = self.app.get("/v1/actions?limit=-1") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 2, '/v1/actions did not return all actions.') + self.assertEqual(len(resp.json), 2, "/v1/actions did not return all actions.") - item = [i for i in resp.json if i['id'] == action_1_id][0] - self.assertEqual(item['ref'], action_1_ref) + item = [i for i in resp.json if i["id"] == action_1_id][0] + self.assertEqual(item["ref"], action_1_ref) self.__do_delete(action_1_id) self.__do_delete(action_2_id) - @mock.patch('st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin', - mock.Mock(return_value=False)) + @mock.patch( + "st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin", + mock.Mock(return_value=False), + ) def test_get_all_invalid_limit_too_large_none_admin(self): # limit > max_page_size, but user is not admin - resp = self.app.get('/v1/actions?limit=1000', expect_errors=True) + resp = self.app.get("/v1/actions?limit=1000", expect_errors=True) self.assertEqual(resp.status_int, http_client.FORBIDDEN) - self.assertEqual(resp.json['faultstring'], 'Limit "1000" specified, maximum value is' - ' "100"') + self.assertEqual( + resp.json["faultstring"], + 'Limit "1000" specified, maximum value is' ' "100"', + ) def test_get_all_limit_negative_number(self): - resp = self.app.get('/v1/actions?limit=-22', expect_errors=True) + resp = self.app.get("/v1/actions?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') - - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) + + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_include_attributes_filter(self): - return super(ActionsControllerTestCase, self).test_get_all_include_attributes_filter() + return super( + ActionsControllerTestCase, self + ).test_get_all_include_attributes_filter() - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_all_exclude_attributes_filter(self): - return super(ActionsControllerTestCase, self).test_get_all_include_attributes_filter() + return super( + ActionsControllerTestCase, self + ).test_get_all_include_attributes_filter() - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_query(self): action_1_id = self.__get_action_id(self.__do_post(ACTION_1)) action_2_id = self.__get_action_id(self.__do_post(ACTION_2)) - resp = self.app.get('/v1/actions?name=%s' % ACTION_1['name']) + resp = self.app.get("/v1/actions?name=%s" % ACTION_1["name"]) self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, '/v1/actions did not return all actions.') + self.assertEqual(len(resp.json), 1, "/v1/actions did not return all actions.") self.__do_delete(action_1_id) self.__do_delete(action_2_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_fail(self): - resp = self.app.get('/v1/actions/1', expect_errors=True) + resp = self.app.get("/v1/actions/1", expect_errors=True) self.assertEqual(resp.status_int, 404) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_delete(self): post_resp = self.__do_post(ACTION_1) self.assertEqual(post_resp.status_int, 201) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_action_with_bad_params(self): post_resp = self.__do_post(ACTION_9, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_no_description_field(self): post_resp = self.__do_post(ACTION_6) self.assertEqual(post_resp.status_int, 201) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_no_enable_field(self): post_resp = self.__do_post(ACTION_3) self.assertEqual(post_resp.status_int, 201) - self.assertIn(b'enabled', post_resp.body) + self.assertIn(b"enabled", post_resp.body) # If enabled field is not provided it should default to True data = json.loads(post_resp.body) - self.assertDictContainsSubset({'enabled': True}, data) + self.assertDictContainsSubset({"enabled": True}, data) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_false_enable_field(self): post_resp = self.__do_post(ACTION_4) self.assertEqual(post_resp.status_int, 201) data = json.loads(post_resp.body) - self.assertDictContainsSubset({'enabled': False}, data) + self.assertDictContainsSubset({"enabled": False}, data) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_name_unicode_action_already_exists(self): # Verify that exception messages containing unicode characters don't result in internal # server errors action = copy.deepcopy(ACTION_1) # NOTE: We explicitly don't prefix this string value with u"" - action['name'] = 'žactionćšžži💩' + action["name"] = "žactionćšžži💩" # 1. Initial creation post_resp = self.__do_post(action, expect_errors=True) @@ -448,54 +473,64 @@ def test_post_name_unicode_action_already_exists(self): # 2. Action already exists post_resp = self.__do_post(action, expect_errors=True) self.assertEqual(post_resp.status_int, 409) - self.assertIn('Tried to save duplicate unique keys', post_resp.json['faultstring']) + self.assertIn( + "Tried to save duplicate unique keys", post_resp.json["faultstring"] + ) # 3. Action already exists (this time with unicode type) - action['name'] = u'žactionćšžži💩' + action["name"] = "žactionćšžži💩" post_resp = self.__do_post(action, expect_errors=True) self.assertEqual(post_resp.status_int, 409) - self.assertIn('Tried to save duplicate unique keys', post_resp.json['faultstring']) + self.assertIn( + "Tried to save duplicate unique keys", post_resp.json["faultstring"] + ) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_parameter_type_is_array_and_invalid(self): post_resp = self.__do_post(ACTION_13, expect_errors=True) self.assertEqual(post_resp.status_int, 400) if six.PY3: - expected_error = b'[\'string\', \'object\'] is not valid under any of the given schemas' + expected_error = ( + b"['string', 'object'] is not valid under any of the given schemas" + ) else: - expected_error = \ - b'[u\'string\', u\'object\'] is not valid under any of the given schemas' + expected_error = ( + b"[u'string', u'object'] is not valid under any of the given schemas" + ) self.assertIn(expected_error, post_resp.body) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_discard_id_field(self): post_resp = self.__do_post(ACTION_7) self.assertEqual(post_resp.status_int, 201) - self.assertIn(b'id', post_resp.body) + self.assertIn(b"id", post_resp.body) data = json.loads(post_resp.body) # Verify that user-provided id is discarded. - self.assertNotEquals(data['id'], ACTION_7['id']) + self.assertNotEquals(data["id"], ACTION_7["id"]) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_duplicate(self): action_ids = [] post_resp = self.__do_post(ACTION_1) self.assertEqual(post_resp.status_int, 201) - action_in_db = Action.get_by_name(ACTION_1.get('name')) - self.assertIsNotNone(action_in_db, 'Action must be in db.') + action_in_db = Action.get_by_name(ACTION_1.get("name")) + self.assertIsNotNone(action_in_db, "Action must be in db.") action_ids.append(self.__get_action_id(post_resp)) post_resp = self.__do_post(ACTION_1, expect_errors=True) # Verify name conflict self.assertEqual(post_resp.status_int, 409) - self.assertEqual(post_resp.json['conflict-id'], action_ids[0]) + self.assertEqual(post_resp.json["conflict-id"], action_ids[0]) post_resp = self.__do_post(ACTION_10) action_ids.append(self.__get_action_id(post_resp)) @@ -505,20 +540,16 @@ def test_post_duplicate(self): for i in action_ids: self.__do_delete(i) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_include_files(self): # Verify initial state - pack_db = Pack.get_by_ref(ACTION_12['pack']) - self.assertNotIn('actions/filea.txt', pack_db.files) + pack_db = Pack.get_by_ref(ACTION_12["pack"]) + self.assertNotIn("actions/filea.txt", pack_db.files) action = copy.deepcopy(ACTION_12) - action['data_files'] = [ - { - 'file_path': 'filea.txt', - 'content': 'test content' - } - ] + action["data_files"] = [{"file_path": "filea.txt", "content": "test content"}] post_resp = self.__do_post(action) # Verify file has been written on disk @@ -526,29 +557,30 @@ def test_post_include_files(self): self.assertTrue(os.path.exists(file_path)) # Verify PackDB.files has been updated - pack_db = Pack.get_by_ref(ACTION_12['pack']) - self.assertIn('actions/filea.txt', pack_db.files) + pack_db = Pack.get_by_ref(ACTION_12["pack"]) + self.assertIn("actions/filea.txt", pack_db.files) self.__do_delete(self.__get_action_id(post_resp)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_post_put_delete(self): action = copy.copy(ACTION_1) post_resp = self.__do_post(action) self.assertEqual(post_resp.status_int, 201) - self.assertIn(b'id', post_resp.body) + self.assertIn(b"id", post_resp.body) body = json.loads(post_resp.body) - action['id'] = body['id'] - action['description'] = 'some other test description' - pack = action['pack'] - del action['pack'] - self.assertNotIn('pack', action) - put_resp = self.__do_put(action['id'], action) + action["id"] = body["id"] + action["description"] = "some other test description" + pack = action["pack"] + del action["pack"] + self.assertNotIn("pack", action) + put_resp = self.__do_put(action["id"], action) self.assertEqual(put_resp.status_int, 200) - self.assertIn(b'description', put_resp.body) + self.assertIn(b"description", put_resp.body) body = json.loads(put_resp.body) - self.assertEqual(body['description'], action['description']) - self.assertEqual(body['pack'], pack) + self.assertEqual(body["description"], action["description"]) + self.assertEqual(body["pack"], pack) delete_resp = self.__do_delete(self.__get_action_id(post_resp)) self.assertEqual(delete_resp.status_int, 204) @@ -559,94 +591,107 @@ def test_post_invalid_runner_type(self): def test_post_override_runner_param_not_allowed(self): post_resp = self.__do_post(ACTION_14, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - expected = ('The attribute "type" for the runner parameter "sudo" ' - 'in action "dummy_pack_1.st2.dummy.action14" cannot be overridden.') - self.assertEqual(post_resp.json.get('faultstring'), expected) + expected = ( + 'The attribute "type" for the runner parameter "sudo" ' + 'in action "dummy_pack_1.st2.dummy.action14" cannot be overridden.' + ) + self.assertEqual(post_resp.json.get("faultstring"), expected) def test_post_override_runner_param_allowed(self): post_resp = self.__do_post(ACTION_15) self.assertEqual(post_resp.status_int, 201) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_delete(self): post_resp = self.__do_post(ACTION_1) del_resp = self.__do_delete(self.__get_action_id(post_resp)) self.assertEqual(del_resp.status_int, 204) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_action_with_tags(self): post_resp = self.__do_post(ACTION_1) action_id = self.__get_action_id(post_resp) get_resp = self.__do_get_one(action_id) self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_action_id(get_resp), action_id) - self.assertEqual(get_resp.json['tags'], ACTION_1['tags']) + self.assertEqual(get_resp.json["tags"], ACTION_1["tags"]) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_action_with_notify_update(self): post_resp = self.__do_post(ACTION_WITH_NOTIFY) action_id = self.__get_action_id(post_resp) get_resp = self.__do_get_one(action_id) self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_action_id(get_resp), action_id) - self.assertIsNotNone(get_resp.json['notify']['on-complete']) + self.assertIsNotNone(get_resp.json["notify"]["on-complete"]) # Now post the same action with no notify ACTION_WITHOUT_NOTIFY = copy.copy(ACTION_WITH_NOTIFY) - del ACTION_WITHOUT_NOTIFY['notify'] + del ACTION_WITHOUT_NOTIFY["notify"] self.__do_put(action_id, ACTION_WITHOUT_NOTIFY) # Validate that notify section has vanished get_resp = self.__do_get_one(action_id) - self.assertEqual(get_resp.json['notify'], {}) + self.assertEqual(get_resp.json["notify"], {}) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_using_name_parameter(self): action_id, action_name = self.__get_action_id_and_additional_attribute( - self.__do_post(ACTION_1), 'name') - get_resp = self.__do_get_actions_by_url_parameter('name', action_name) + self.__do_post(ACTION_1), "name" + ) + get_resp = self.__do_get_actions_by_url_parameter("name", action_name) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json[0]['id'], action_id) - self.assertEqual(get_resp.json[0]['name'], action_name) + self.assertEqual(get_resp.json[0]["id"], action_id) + self.assertEqual(get_resp.json[0]["name"], action_name) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_using_pack_parameter(self): action_id, action_pack = self.__get_action_id_and_additional_attribute( - self.__do_post(ACTION_10), 'pack') - get_resp = self.__do_get_actions_by_url_parameter('pack', action_pack) + self.__do_post(ACTION_10), "pack" + ) + get_resp = self.__do_get_actions_by_url_parameter("pack", action_pack) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json[0]['id'], action_id) - self.assertEqual(get_resp.json[0]['pack'], action_pack) + self.assertEqual(get_resp.json[0]["id"], action_id) + self.assertEqual(get_resp.json[0]["pack"], action_pack) self.__do_delete(action_id) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def test_get_one_using_tag_parameter(self): action_id, action_tags = self.__get_action_id_and_additional_attribute( - self.__do_post(ACTION_1), 'tags') - get_resp = self.__do_get_actions_by_url_parameter('tags', action_tags[0]['name']) + self.__do_post(ACTION_1), "tags" + ) + get_resp = self.__do_get_actions_by_url_parameter( + "tags", action_tags[0]["name"] + ) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json[0]['id'], action_id) - self.assertEqual(get_resp.json[0]['tags'], action_tags) + self.assertEqual(get_resp.json[0]["id"], action_id) + self.assertEqual(get_resp.json[0]["tags"], action_tags) self.__do_delete(action_id) # TODO: Re-enable those tests after we ensure DB is flushed in setUp # and each test starts in a clean state - @unittest2.skip('Skip because of test polution') + @unittest2.skip("Skip because of test polution") def test_update_action_belonging_to_system_pack(self): post_resp = self.__do_post(ACTION_11) action_id = self.__get_action_id(post_resp) put_resp = self.__do_put(action_id, ACTION_11, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - @unittest2.skip('Skip because of test polution') + @unittest2.skip("Skip because of test polution") def test_delete_action_belonging_to_system_pack(self): post_resp = self.__do_post(ACTION_11) action_id = self.__get_action_id(post_resp) @@ -664,31 +709,37 @@ def _do_delete(self, action_id, expect_errors=False): @staticmethod def __get_action_id(resp): - return resp.json['id'] + return resp.json["id"] @staticmethod def __get_action_name(resp): - return resp.json['name'] + return resp.json["name"] @staticmethod def __get_action_tags(resp): - return resp.json['tags'] + return resp.json["tags"] @staticmethod def __get_action_id_and_additional_attribute(resp, attribute): - return resp.json['id'], resp.json[attribute] + return resp.json["id"], resp.json[attribute] def __do_get_one(self, action_id, expect_errors=False): - return self.app.get('/v1/actions/%s' % action_id, expect_errors=expect_errors) + return self.app.get("/v1/actions/%s" % action_id, expect_errors=expect_errors) def __do_get_actions_by_url_parameter(self, filter, value, expect_errors=False): - return self.app.get('/v1/actions?%s=%s' % (filter, value), expect_errors=expect_errors) + return self.app.get( + "/v1/actions?%s=%s" % (filter, value), expect_errors=expect_errors + ) def __do_post(self, action, expect_errors=False): - return self.app.post_json('/v1/actions', action, expect_errors=expect_errors) + return self.app.post_json("/v1/actions", action, expect_errors=expect_errors) def __do_put(self, action_id, action, expect_errors=False): - return self.app.put_json('/v1/actions/%s' % action_id, action, expect_errors=expect_errors) + return self.app.put_json( + "/v1/actions/%s" % action_id, action, expect_errors=expect_errors + ) def __do_delete(self, action_id, expect_errors=False): - return self.app.delete('/v1/actions/%s' % action_id, expect_errors=expect_errors) + return self.app.delete( + "/v1/actions/%s" % action_id, expect_errors=expect_errors + ) diff --git a/st2api/tests/unit/controllers/v1/test_alias_execution.py b/st2api/tests/unit/controllers/v1/test_alias_execution.py index e7b1827f319..9806a468642 100644 --- a/st2api/tests/unit/controllers/v1/test_alias_execution.py +++ b/st2api/tests/unit/controllers/v1/test_alias_execution.py @@ -24,29 +24,32 @@ from st2tests.fixturesloader import FixturesLoader from st2tests.api import FunctionalTest -FIXTURES_PACK = 'aliases' +FIXTURES_PACK = "aliases" TEST_MODELS = { - 'aliases': ['alias1.yaml', 'alias2.yaml', 'alias_with_undefined_jinja_in_ack_format.yaml', - 'alias_with_immutable_list_param.yaml', - 'alias_with_immutable_list_param_str_cast.yaml', - 'alias4.yaml', 'alias5.yaml', 'alias_fixes1.yaml', 'alias_fixes2.yaml', - 'alias_match_multiple.yaml'], - 'actions': ['action1.yaml', 'action2.yaml', 'action3.yaml', 'action4.yaml'], - 'runners': ['runner1.yaml'] + "aliases": [ + "alias1.yaml", + "alias2.yaml", + "alias_with_undefined_jinja_in_ack_format.yaml", + "alias_with_immutable_list_param.yaml", + "alias_with_immutable_list_param_str_cast.yaml", + "alias4.yaml", + "alias5.yaml", + "alias_fixes1.yaml", + "alias_fixes2.yaml", + "alias_match_multiple.yaml", + ], + "actions": ["action1.yaml", "action2.yaml", "action3.yaml", "action4.yaml"], + "runners": ["runner1.yaml"], } -TEST_LOAD_MODELS = { - 'aliases': ['alias3.yaml'] -} +TEST_LOAD_MODELS = {"aliases": ["alias3.yaml"]} -EXECUTION = ActionExecutionDB(id='54e657d60640fd16887d6855', - status=LIVEACTION_STATUS_SUCCEEDED, - result='') +EXECUTION = ActionExecutionDB( + id="54e657d60640fd16887d6855", status=LIVEACTION_STATUS_SUCCEEDED, result="" +) -__all__ = [ - 'AliasExecutionTestCase' -] +__all__ = ["AliasExecutionTestCase"] class AliasExecutionTestCase(FunctionalTest): @@ -59,193 +62,217 @@ class AliasExecutionTestCase(FunctionalTest): @classmethod def setUpClass(cls): super(AliasExecutionTestCase, cls).setUpClass() - cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - - cls.runner1 = cls.models['runners']['runner1.yaml'] - cls.action1 = cls.models['actions']['action1.yaml'] - cls.alias1 = cls.models['aliases']['alias1.yaml'] - cls.alias2 = cls.models['aliases']['alias2.yaml'] - cls.alias4 = cls.models['aliases']['alias4.yaml'] - cls.alias5 = cls.models['aliases']['alias5.yaml'] - cls.alias_with_undefined_jinja_in_ack_format = \ - cls.models['aliases']['alias_with_undefined_jinja_in_ack_format.yaml'] - - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + cls.models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + + cls.runner1 = cls.models["runners"]["runner1.yaml"] + cls.action1 = cls.models["actions"]["action1.yaml"] + cls.alias1 = cls.models["aliases"]["alias1.yaml"] + cls.alias2 = cls.models["aliases"]["alias2.yaml"] + cls.alias4 = cls.models["aliases"]["alias4.yaml"] + cls.alias5 = cls.models["aliases"]["alias5.yaml"] + cls.alias_with_undefined_jinja_in_ack_format = cls.models["aliases"][ + "alias_with_undefined_jinja_in_ack_format.yaml" + ] + + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_basic_execution(self, request): command = 'Lorem ipsum value1 dolor sit "value2 value3" amet.' post_resp = self._do_post(alias_execution=self.alias1, command=command) self.assertEqual(post_resp.status_int, 201) - expected_parameters = {'param1': 'value1', 'param2': 'value2 value3'} + expected_parameters = {"param1": "value1", "param2": "value2 value3"} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_basic_execution_with_immutable_parameters(self, request): - command = 'lorem ipsum' + command = "lorem ipsum" post_resp = self._do_post(alias_execution=self.alias5, command=command) self.assertEqual(post_resp.status_int, 201) - expected_parameters = {'param1': 'value1', 'param2': 'value2'} + expected_parameters = {"param1": "value1", "param2": "value2"} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_invalid_format_string_referenced_in_request(self, request): command = 'Lorem ipsum value1 dolor sit "value2 value3" amet.' - format_str = 'some invalid not supported string' - post_resp = self._do_post(alias_execution=self.alias1, command=command, - format_str=format_str, expect_errors=True) + format_str = "some invalid not supported string" + post_resp = self._do_post( + alias_execution=self.alias1, + command=command, + format_str=format_str, + expect_errors=True, + ) self.assertEqual(post_resp.status_int, 400) - expected_msg = ('Format string "some invalid not supported string" is ' - 'not available on the alias "alias1"') - self.assertIn(expected_msg, post_resp.json['faultstring']) + expected_msg = ( + 'Format string "some invalid not supported string" is ' + 'not available on the alias "alias1"' + ) + self.assertIn(expected_msg, post_resp.json["faultstring"]) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_execution_with_array_type_single_value(self, request): - command = 'Lorem ipsum value1 dolor sit value2 amet.' + command = "Lorem ipsum value1 dolor sit value2 amet." self._do_post(alias_execution=self.alias2, command=command) - expected_parameters = {'param1': 'value1', 'param3': ['value2']} + expected_parameters = {"param1": "value1", "param3": ["value2"]} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_execution_with_array_type_multi_value(self, request): command = 'Lorem ipsum value1 dolor sit "value2, value3" amet.' post_resp = self._do_post(alias_execution=self.alias2, command=command) self.assertEqual(post_resp.status_int, 201) - expected_parameters = {'param1': 'value1', 'param3': ['value2', 'value3']} + expected_parameters = {"param1": "value1", "param3": ["value2", "value3"]} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_invalid_jinja_var_in_ack_format(self, request): - command = 'run date on localhost' + command = "run date on localhost" # print(self.alias_with_undefined_jinja_in_ack_format) post_resp = self._do_post( alias_execution=self.alias_with_undefined_jinja_in_ack_format, command=command, - expect_errors=False + expect_errors=False, ) self.assertEqual(post_resp.status_int, 201) - expected_parameters = {'cmd': 'date', 'hosts': 'localhost'} + expected_parameters = {"cmd": "date", "hosts": "localhost"} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) self.assertEqual( - post_resp.json['message'], - 'Cannot render "format" in field "ack" for alias. \'cmd\' is undefined' + post_resp.json["message"], + 'Cannot render "format" in field "ack" for alias. \'cmd\' is undefined', ) - @mock.patch.object(action_service, 'request') + @mock.patch.object(action_service, "request") def test_execution_secret_parameter(self, request): - execution = ActionExecutionDB(id='54e657d60640fd16887d6855', - status=LIVEACTION_STATUS_SUCCEEDED, - action={'parameters': self.action1.parameters}, - runner={'runner_parameters': self.runner1.runner_parameters}, - parameters={ - 'param4': SUPER_SECRET_PARAMETER - }, - result='') + execution = ActionExecutionDB( + id="54e657d60640fd16887d6855", + status=LIVEACTION_STATUS_SUCCEEDED, + action={"parameters": self.action1.parameters}, + runner={"runner_parameters": self.runner1.runner_parameters}, + parameters={"param4": SUPER_SECRET_PARAMETER}, + result="", + ) request.return_value = (None, execution) - command = 'Lorem ipsum value1 dolor sit ' + SUPER_SECRET_PARAMETER + ' amet.' + command = "Lorem ipsum value1 dolor sit " + SUPER_SECRET_PARAMETER + " amet." post_resp = self._do_post(alias_execution=self.alias4, command=command) self.assertEqual(post_resp.status_int, 201) - expected_parameters = {'param1': 'value1', 'param4': SUPER_SECRET_PARAMETER} + expected_parameters = {"param1": "value1", "param4": SUPER_SECRET_PARAMETER} self.assertEqual(request.call_args[0][0].parameters, expected_parameters) - post_resp = self._do_post(alias_execution=self.alias4, command=command, show_secrets=True, - expect_errors=True) + post_resp = self._do_post( + alias_execution=self.alias4, + command=command, + show_secrets=True, + expect_errors=True, + ) self.assertEqual(post_resp.status_int, 201) - self.assertEqual(post_resp.json['execution']['parameters']['param4'], - SUPER_SECRET_PARAMETER) + self.assertEqual( + post_resp.json["execution"]["parameters"]["param4"], SUPER_SECRET_PARAMETER + ) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_match_and_execute_doesnt_match(self, mock_request): base_data = { - 'source_channel': 'chat', - 'notification_route': 'hubot', - 'user': 'chat-user' + "source_channel": "chat", + "notification_route": "hubot", + "user": "chat-user", } # Command doesnt match any patterns data = copy.deepcopy(base_data) - data['command'] = 'hello donny' - resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data, expect_errors=True) + data["command"] = "hello donny" + resp = self.app.post_json( + "/v1/aliasexecution/match_and_execute", data, expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(str(resp.json['faultstring']), - "Command 'hello donny' matched no patterns") + self.assertEqual( + str(resp.json["faultstring"]), "Command 'hello donny' matched no patterns" + ) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_match_and_execute_matches_many(self, mock_request): base_data = { - 'source_channel': 'chat', - 'notification_route': 'hubot', - 'user': 'chat-user' + "source_channel": "chat", + "notification_route": "hubot", + "user": "chat-user", } # Command matches more than one pattern data = copy.deepcopy(base_data) - data['command'] = 'Lorem ipsum banana dolor sit pineapple amet.' - resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data, expect_errors=True) + data["command"] = "Lorem ipsum banana dolor sit pineapple amet." + resp = self.app.post_json( + "/v1/aliasexecution/match_and_execute", data, expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(str(resp.json['faultstring']), - "Command 'Lorem ipsum banana dolor sit pineapple amet.' " - "matched more than 1 pattern") + self.assertEqual( + str(resp.json["faultstring"]), + "Command 'Lorem ipsum banana dolor sit pineapple amet.' " + "matched more than 1 pattern", + ) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_match_and_execute_matches_one(self, mock_request): base_data = { - 'source_channel': 'chat-channel', - 'notification_route': 'hubot', - 'user': 'chat-user', + "source_channel": "chat-channel", + "notification_route": "hubot", + "user": "chat-user", } # Command matches - should result in action execution data = copy.deepcopy(base_data) - data['command'] = 'run date on localhost' - resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data) + data["command"] = "run date on localhost" + resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data) self.assertEqual(resp.status_int, 201) - self.assertEqual(len(resp.json['results']), 1) - self.assertEqual(resp.json['results'][0]['execution']['id'], str(EXECUTION['id'])) - self.assertEqual(resp.json['results'][0]['execution']['status'], EXECUTION['status']) + self.assertEqual(len(resp.json["results"]), 1) + self.assertEqual( + resp.json["results"][0]["execution"]["id"], str(EXECUTION["id"]) + ) + self.assertEqual( + resp.json["results"][0]["execution"]["status"], EXECUTION["status"] + ) - expected_parameters = {'cmd': 'date', 'hosts': 'localhost'} + expected_parameters = {"cmd": "date", "hosts": "localhost"} self.assertEqual(mock_request.call_args[0][0].parameters, expected_parameters) # Also check for source_channel - see # https://github.com/StackStorm/st2/issues/4650 actual_context = mock_request.call_args[0][0].context - self.assertIn('source_channel', mock_request.call_args[0][0].context.keys()) - self.assertEqual(actual_context['source_channel'], 'chat-channel') - self.assertEqual(actual_context['api_user'], 'chat-user') - self.assertEqual(actual_context['user'], 'stanley') + self.assertIn("source_channel", mock_request.call_args[0][0].context.keys()) + self.assertEqual(actual_context["source_channel"], "chat-channel") + self.assertEqual(actual_context["api_user"], "chat-user") + self.assertEqual(actual_context["user"], "stanley") - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_match_and_execute_matches_one_multiple_match(self, mock_request): base_data = { - 'source_channel': 'chat', - 'notification_route': 'hubot', - 'user': 'chat-user' + "source_channel": "chat", + "notification_route": "hubot", + "user": "chat-user", } # Command matches multiple times - should result in multiple action execution data = copy.deepcopy(base_data) - data['command'] = ('JKROWLING-4 is a duplicate of JRRTOLKIEN-24 which ' - 'is a duplicate of DRSEUSS-12') - resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data) + data["command"] = ( + "JKROWLING-4 is a duplicate of JRRTOLKIEN-24 which " + "is a duplicate of DRSEUSS-12" + ) + resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data) self.assertEqual(resp.status_int, 201) - self.assertEqual(len(resp.json['results']), 2) - self.assertEqual(resp.json['results'][0]['execution']['id'], str(EXECUTION['id'])) - self.assertEqual(resp.json['results'][0]['execution']['status'], EXECUTION['status']) - self.assertEqual(resp.json['results'][1]['execution']['id'], str(EXECUTION['id'])) - self.assertEqual(resp.json['results'][1]['execution']['status'], EXECUTION['status']) + self.assertEqual(len(resp.json["results"]), 2) + self.assertEqual( + resp.json["results"][0]["execution"]["id"], str(EXECUTION["id"]) + ) + self.assertEqual( + resp.json["results"][0]["execution"]["status"], EXECUTION["status"] + ) + self.assertEqual( + resp.json["results"][1]["execution"]["id"], str(EXECUTION["id"]) + ) + self.assertEqual( + resp.json["results"][1]["execution"]["status"], EXECUTION["status"] + ) # The mock object only stores the parameters of the _last_ time it was called, so that's # what we assert on. Luckily re.finditer() processes groups in order, so if this was the @@ -255,34 +282,39 @@ def test_match_and_execute_matches_one_multiple_match(self, mock_request): # # We've also already checked the results array # - expected_parameters = {'issue_key': 'DRSEUSS-12'} + expected_parameters = {"issue_key": "DRSEUSS-12"} self.assertEqual(mock_request.call_args[0][0].parameters, expected_parameters) - @mock.patch.object(action_service, 'request', - return_value=(None, EXECUTION)) + @mock.patch.object(action_service, "request", return_value=(None, EXECUTION)) def test_match_and_execute_matches_many_multiple_match(self, mock_request): base_data = { - 'source_channel': 'chat', - 'notification_route': 'hubot', - 'user': 'chat-user' + "source_channel": "chat", + "notification_route": "hubot", + "user": "chat-user", } # Command matches multiple times - should result in multiple action execution data = copy.deepcopy(base_data) - data['command'] = 'JKROWLING-4 fixes JRRTOLKIEN-24 which fixes DRSEUSS-12' - resp = self.app.post_json('/v1/aliasexecution/match_and_execute', data, expect_errors=True) + data["command"] = "JKROWLING-4 fixes JRRTOLKIEN-24 which fixes DRSEUSS-12" + resp = self.app.post_json( + "/v1/aliasexecution/match_and_execute", data, expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(str(resp.json['faultstring']), - "Command '{command}' " - "matched more than 1 (multi) pattern".format(command=data['command'])) + self.assertEqual( + str(resp.json["faultstring"]), + "Command '{command}' " + "matched more than 1 (multi) pattern".format(command=data["command"]), + ) def test_match_and_execute_list_action_param_str_cast_to_list(self): data = { - 'command': 'test alias list param str cast', - 'source_channel': 'hubot', - 'user': 'foo', + "command": "test alias list param str cast", + "source_channel": "hubot", + "user": "foo", } - resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data, expect_errors=True) + resp = self.app.post_json( + "/v1/aliasexecution/match_and_execute", data, expect_errors=True + ) # Param is a comma delimited string - our custom cast function should cast it to a list. # I assume that was done to make specifying complex params in chat easier. @@ -300,15 +332,19 @@ def test_match_and_execute_list_action_param_str_cast_to_list(self): self.assertEqual(live_action["parameters"]["array_param"][1], "two") self.assertEqual(live_action["parameters"]["array_param"][2], "three") self.assertEqual(live_action["parameters"]["array_param"][3], "four") - self.assertTrue(isinstance(action_alias["immutable_parameters"]["array_param"], str)) + self.assertTrue( + isinstance(action_alias["immutable_parameters"]["array_param"], str) + ) def test_match_and_execute_list_action_param_already_a_list(self): data = { - 'command': 'test alias foo', - 'source_channel': 'hubot', - 'user': 'foo', + "command": "test alias foo", + "source_channel": "hubot", + "user": "foo", } - resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data, expect_errors=True) + resp = self.app.post_json( + "/v1/aliasexecution/match_and_execute", data, expect_errors=True + ) # immutable_param is already a list - verify no casting is performed self.assertEqual(resp.status_int, 201) @@ -323,37 +359,53 @@ def test_match_and_execute_list_action_param_already_a_list(self): self.assertEqual(live_action["parameters"]["array_param"][0]["key2"], "two") self.assertEqual(live_action["parameters"]["array_param"][1]["key3"], "three") self.assertEqual(live_action["parameters"]["array_param"][1]["key4"], "four") - self.assertTrue(isinstance(action_alias["immutable_parameters"]["array_param"], list)) + self.assertTrue( + isinstance(action_alias["immutable_parameters"]["array_param"], list) + ) def test_match_and_execute_success(self): data = { - 'command': 'run whoami on localhost1', - 'source_channel': 'hubot', - 'user': "user", + "command": "run whoami on localhost1", + "source_channel": "hubot", + "user": "user", } resp = self.app.post_json("/v1/aliasexecution/match_and_execute", data) self.assertEqual(resp.status_int, 201) self.assertEqual(len(resp.json["results"]), 1) - self.assertTrue(resp.json["results"][0]["actionalias"]["ref"], - "aliases.alias_with_undefined_jinja_in_ack_format") - - def _do_post(self, alias_execution, command, format_str=None, expect_errors=False, - show_secrets=False): - if (isinstance(alias_execution.formats[0], dict) and - alias_execution.formats[0].get('representation')): - representation = alias_execution.formats[0].get('representation')[0] + self.assertTrue( + resp.json["results"][0]["actionalias"]["ref"], + "aliases.alias_with_undefined_jinja_in_ack_format", + ) + + def _do_post( + self, + alias_execution, + command, + format_str=None, + expect_errors=False, + show_secrets=False, + ): + if isinstance(alias_execution.formats[0], dict) and alias_execution.formats[ + 0 + ].get("representation"): + representation = alias_execution.formats[0].get("representation")[0] else: representation = alias_execution.formats[0] if not format_str: format_str = representation - execution = {'name': alias_execution.name, - 'format': format_str, - 'command': command, - 'user': 'stanley', - 'source_channel': 'test', - 'notification_route': 'test'} - url = show_secrets and '/v1/aliasexecution?show_secrets=true' or '/v1/aliasexecution' - return self.app.post_json(url, execution, - expect_errors=expect_errors) + execution = { + "name": alias_execution.name, + "format": format_str, + "command": command, + "user": "stanley", + "source_channel": "test", + "notification_route": "test", + } + url = ( + show_secrets + and "/v1/aliasexecution?show_secrets=true" + or "/v1/aliasexecution" + ) + return self.app.post_json(url, execution, expect_errors=expect_errors) diff --git a/st2api/tests/unit/controllers/v1/test_auth.py b/st2api/tests/unit/controllers/v1/test_auth.py index fb5a203929c..d6f3602c3c5 100644 --- a/st2api/tests/unit/controllers/v1/test_auth.py +++ b/st2api/tests/unit/controllers/v1/test_auth.py @@ -27,7 +27,7 @@ from st2tests.fixturesloader import FixturesLoader OBJ_ID = bson.ObjectId() -USER = 'stanley' +USER = "stanley" USER_DB = UserDB(name=USER) TOKEN = uuid.uuid4().hex NOW = date_utils.get_datetime_utc_now() @@ -40,67 +40,84 @@ class TestTokenBasedAuth(FunctionalTest): enable_auth = True @mock.patch.object( - Token, 'get', - mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE))) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB)) + Token, + "get", + mock.Mock( + return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE) + ), + ) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB)) def test_token_validation_token_in_headers(self): - response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN}, - expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) @mock.patch.object( - Token, 'get', - mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE))) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB)) + Token, + "get", + mock.Mock( + return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE) + ), + ) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB)) def test_token_validation_token_in_query_params(self): - response = self.app.get('/v1/actions?x-auth-token=%s' % (TOKEN), expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions?x-auth-token=%s" % (TOKEN), expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) @mock.patch.object( - Token, 'get', - mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE))) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB)) + Token, + "get", + mock.Mock( + return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE) + ), + ) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB)) def test_token_validation_token_in_cookies(self): - response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN}, - expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) - with mock.patch.object(self.app.cookiejar, 'clear', return_value=None): - response = self.app.get('/v1/actions', expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + with mock.patch.object(self.app.cookiejar, "clear", return_value=None): + response = self.app.get("/v1/actions", expect_errors=False) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) @mock.patch.object( - Token, 'get', - mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=PAST))) + Token, + "get", + mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=PAST)), + ) def test_token_expired(self): - response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN}, - expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=True + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 401) - @mock.patch.object( - Token, 'get', mock.MagicMock(side_effect=TokenNotFoundError())) + @mock.patch.object(Token, "get", mock.MagicMock(side_effect=TokenNotFoundError())) def test_token_not_found(self): - response = self.app.get('/v1/actions', headers={'X-Auth-Token': TOKEN}, - expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"X-Auth-Token": TOKEN}, expect_errors=True + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 401) def test_token_not_provided(self): - response = self.app.get('/v1/actions', expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get("/v1/actions", expect_errors=True) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 401) -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" -TEST_MODELS = { - 'apikeys': ['apikey1.yaml', 'apikey_disabled.yaml'] -} +TEST_MODELS = {"apikeys": ["apikey1.yaml", "apikey_disabled.yaml"]} # Hardcoded keys matching the fixtures. Lazy way to workound one-way hash and still use fixtures. KEY1_KEY = "1234" @@ -117,62 +134,83 @@ class TestApiKeyBasedAuth(FunctionalTest): @classmethod def setUpClass(cls): super(TestApiKeyBasedAuth, cls).setUpClass() - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - cls.apikey1 = models['apikeys']['apikey1.yaml'] - cls.apikey_disabled = models['apikeys']['apikey_disabled.yaml'] + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + cls.apikey1 = models["apikeys"]["apikey1.yaml"] + cls.apikey_disabled = models["apikeys"]["apikey_disabled.yaml"] - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=UserDB(name='bill'))) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=UserDB(name="bill"))) def test_apikey_validation_apikey_in_headers(self): - response = self.app.get('/v1/actions', headers={'St2-Api-key': KEY1_KEY}, - expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"St2-Api-key": KEY1_KEY}, expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=UserDB(name='bill'))) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=UserDB(name="bill"))) def test_apikey_validation_apikey_in_query_params(self): - response = self.app.get('/v1/actions?st2-api-key=%s' % (KEY1_KEY), expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions?st2-api-key=%s" % (KEY1_KEY), expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=UserDB(name='bill'))) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=UserDB(name="bill"))) def test_apikey_validation_apikey_in_cookies(self): - response = self.app.get('/v1/actions', headers={'St2-Api-key': KEY1_KEY}, - expect_errors=False) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"St2-Api-key": KEY1_KEY}, expect_errors=False + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) - with mock.patch.object(self.app.cookiejar, 'clear', return_value=None): - response = self.app.get('/v1/actions', expect_errors=True) + with mock.patch.object(self.app.cookiejar, "clear", return_value=None): + response = self.app.get("/v1/actions", expect_errors=True) self.assertEqual(response.status_int, 401) - self.assertEqual(response.json_body['faultstring'], - 'Unauthorized - One of Token or API key required.') + self.assertEqual( + response.json_body["faultstring"], + "Unauthorized - One of Token or API key required.", + ) def test_apikey_disabled(self): - response = self.app.get('/v1/actions', headers={'St2-Api-key': DISABLED_KEY}, - expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"St2-Api-key": DISABLED_KEY}, expect_errors=True + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 401) - self.assertEqual(response.json_body['faultstring'], 'Unauthorized - API key is disabled.') + self.assertEqual( + response.json_body["faultstring"], "Unauthorized - API key is disabled." + ) def test_apikey_not_found(self): - response = self.app.get('/v1/actions', headers={'St2-Api-key': 'UNKNOWN'}, - expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", headers={"St2-Api-key": "UNKNOWN"}, expect_errors=True + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 401) - self.assertRegexpMatches(response.json_body['faultstring'], - '^Unauthorized - ApiKey with key_hash=([a-zA-Z0-9]+) not found.$') + self.assertRegexpMatches( + response.json_body["faultstring"], + "^Unauthorized - ApiKey with key_hash=([a-zA-Z0-9]+) not found.$", + ) @mock.patch.object( - Token, 'get', - mock.Mock(return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE))) + Token, + "get", + mock.Mock( + return_value=TokenDB(id=OBJ_ID, user=USER, token=TOKEN, expiry=FUTURE) + ), + ) @mock.patch.object( - ApiKey, 'get', - mock.Mock(return_value=ApiKeyDB(user=USER, key_hash=KEY1_KEY, enabled=True))) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=USER_DB)) + ApiKey, + "get", + mock.Mock(return_value=ApiKeyDB(user=USER, key_hash=KEY1_KEY, enabled=True)), + ) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=USER_DB)) def test_multiple_auth_sources(self): - response = self.app.get('/v1/actions', - headers={'X-Auth-Token': TOKEN, 'St2-Api-key': KEY1_KEY}, - expect_errors=True) - self.assertIn('application/json', response.headers['content-type']) + response = self.app.get( + "/v1/actions", + headers={"X-Auth-Token": TOKEN, "St2-Api-key": KEY1_KEY}, + expect_errors=True, + ) + self.assertIn("application/json", response.headers["content-type"]) self.assertEqual(response.status_int, 200) diff --git a/st2api/tests/unit/controllers/v1/test_auth_api_keys.py b/st2api/tests/unit/controllers/v1/test_auth_api_keys.py index c172b224453..bf76d41276c 100644 --- a/st2api/tests/unit/controllers/v1/test_auth_api_keys.py +++ b/st2api/tests/unit/controllers/v1/test_auth_api_keys.py @@ -22,11 +22,16 @@ from st2tests.fixturesloader import FixturesLoader from st2tests.api import FunctionalTest -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_MODELS = { - 'apikeys': ['apikey1.yaml', 'apikey2.yaml', 'apikey3.yaml', 'apikey_disabled.yaml', - 'apikey_malformed.yaml'] + "apikeys": [ + "apikey1.yaml", + "apikey2.yaml", + "apikey3.yaml", + "apikey_disabled.yaml", + "apikey_malformed.yaml", + ] } # Hardcoded keys matching the fixtures. Lazy way to workound one-way hash and still use fixtures. @@ -45,205 +50,239 @@ class TestApiKeyController(FunctionalTest): def setUpClass(cls): super(TestApiKeyController, cls).setUpClass() - cfg.CONF.set_override(name='mask_secrets', override=True, group='api') - cfg.CONF.set_override(name='mask_secrets', override=True, group='log') + cfg.CONF.set_override(name="mask_secrets", override=True, group="api") + cfg.CONF.set_override(name="mask_secrets", override=True, group="log") - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - cls.apikey1 = models['apikeys']['apikey1.yaml'] - cls.apikey2 = models['apikeys']['apikey2.yaml'] - cls.apikey3 = models['apikeys']['apikey3.yaml'] - cls.apikey4 = models['apikeys']['apikey_disabled.yaml'] - cls.apikey5 = models['apikeys']['apikey_malformed.yaml'] + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + cls.apikey1 = models["apikeys"]["apikey1.yaml"] + cls.apikey2 = models["apikeys"]["apikey2.yaml"] + cls.apikey3 = models["apikeys"]["apikey3.yaml"] + cls.apikey4 = models["apikeys"]["apikey_disabled.yaml"] + cls.apikey5 = models["apikeys"]["apikey_malformed.yaml"] def test_get_all_and_minus_one(self): - resp = self.app.get('/v1/apikeys') + resp = self.app.get("/v1/apikeys") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "5") - self.assertEqual(resp.headers['X-Limit'], "50") - self.assertEqual(len(resp.json), 5, '/v1/apikeys did not return all apikeys.') - - retrieved_ids = [apikey['id'] for apikey in resp.json] - self.assertEqual(retrieved_ids, - [str(self.apikey1.id), str(self.apikey2.id), str(self.apikey3.id), - str(self.apikey4.id), str(self.apikey5.id)], - 'Incorrect api keys retrieved.') - - resp = self.app.get('/v1/apikeys/?limit=-1') + self.assertEqual(resp.headers["X-Total-Count"], "5") + self.assertEqual(resp.headers["X-Limit"], "50") + self.assertEqual(len(resp.json), 5, "/v1/apikeys did not return all apikeys.") + + retrieved_ids = [apikey["id"] for apikey in resp.json] + self.assertEqual( + retrieved_ids, + [ + str(self.apikey1.id), + str(self.apikey2.id), + str(self.apikey3.id), + str(self.apikey4.id), + str(self.apikey5.id), + ], + "Incorrect api keys retrieved.", + ) + + resp = self.app.get("/v1/apikeys/?limit=-1") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "5") - self.assertEqual(len(resp.json), 5, '/v1/apikeys did not return all apikeys.') + self.assertEqual(resp.headers["X-Total-Count"], "5") + self.assertEqual(len(resp.json), 5, "/v1/apikeys did not return all apikeys.") def test_get_all_with_pagnination_with_offset_and_limit(self): - resp = self.app.get('/v1/apikeys?offset=2&limit=1') + resp = self.app.get("/v1/apikeys?offset=2&limit=1") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "5") - self.assertEqual(resp.headers['X-Limit'], "1") + self.assertEqual(resp.headers["X-Total-Count"], "5") + self.assertEqual(resp.headers["X-Limit"], "1") self.assertEqual(len(resp.json), 1) - retrieved_ids = [apikey['id'] for apikey in resp.json] + retrieved_ids = [apikey["id"] for apikey in resp.json] self.assertEqual(retrieved_ids, [str(self.apikey3.id)]) def test_get_all_with_pagnination_with_only_offset(self): - resp = self.app.get('/v1/apikeys?offset=3') + resp = self.app.get("/v1/apikeys?offset=3") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "5") - self.assertEqual(resp.headers['X-Limit'], "50") + self.assertEqual(resp.headers["X-Total-Count"], "5") + self.assertEqual(resp.headers["X-Limit"], "50") self.assertEqual(len(resp.json), 2) - retrieved_ids = [apikey['id'] for apikey in resp.json] + retrieved_ids = [apikey["id"] for apikey in resp.json] self.assertEqual(retrieved_ids, [str(self.apikey4.id), str(self.apikey5.id)]) def test_get_all_with_pagnination_with_only_limit(self): - resp = self.app.get('/v1/apikeys?limit=2') + resp = self.app.get("/v1/apikeys?limit=2") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "5") - self.assertEqual(resp.headers['X-Limit'], "2") + self.assertEqual(resp.headers["X-Total-Count"], "5") + self.assertEqual(resp.headers["X-Limit"], "2") self.assertEqual(len(resp.json), 2) - retrieved_ids = [apikey['id'] for apikey in resp.json] + retrieved_ids = [apikey["id"] for apikey in resp.json] self.assertEqual(retrieved_ids, [str(self.apikey1.id), str(self.apikey2.id)]) - @mock.patch('st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin', - mock.Mock(return_value=False)) + @mock.patch( + "st2common.rbac.backends.noop.NoOpRBACUtils.user_is_admin", + mock.Mock(return_value=False), + ) def test_get_all_invalid_limit_too_large_none_admin(self): # limit > max_page_size, but user is not admin - resp = self.app.get('/v1/apikeys?offset=2&limit=1000', expect_errors=True) + resp = self.app.get("/v1/apikeys?offset=2&limit=1000", expect_errors=True) self.assertEqual(resp.status_int, http_client.FORBIDDEN) - self.assertEqual(resp.json['faultstring'], - 'Limit "1000" specified, maximum value is "100"') + self.assertEqual( + resp.json["faultstring"], 'Limit "1000" specified, maximum value is "100"' + ) def test_get_all_invalid_limit_negative_integer(self): - resp = self.app.get('/v1/apikeys?offset=2&limit=-22', expect_errors=True) + resp = self.app.get("/v1/apikeys?offset=2&limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - 'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_get_all_invalid_offset_too_large(self): - offset = '2141564789454123457895412237483648' - resp = self.app.get('/v1/apikeys?offset=%s&limit=1' % (offset), expect_errors=True) + offset = "2141564789454123457895412237483648" + resp = self.app.get( + "/v1/apikeys?offset=%s&limit=1" % (offset), expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - 'Offset "%s" specified is more than 32 bit int' % (offset)) + self.assertEqual( + resp.json["faultstring"], + 'Offset "%s" specified is more than 32 bit int' % (offset), + ) def test_get_one_by_id(self): - resp = self.app.get('/v1/apikeys/%s' % self.apikey1.id) + resp = self.app.get("/v1/apikeys/%s" % self.apikey1.id) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['id'], str(self.apikey1.id), - 'Incorrect api key retrieved.') - self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE, - 'Key should be masked.') + self.assertEqual( + resp.json["id"], str(self.apikey1.id), "Incorrect api key retrieved." + ) + self.assertEqual( + resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked." + ) def test_get_one_by_key(self): # key1 - resp = self.app.get('/v1/apikeys/%s' % KEY1_KEY) + resp = self.app.get("/v1/apikeys/%s" % KEY1_KEY) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['id'], str(self.apikey1.id), - 'Incorrect api key retrieved.') - self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE, - 'Key should be masked.') + self.assertEqual( + resp.json["id"], str(self.apikey1.id), "Incorrect api key retrieved." + ) + self.assertEqual( + resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked." + ) # key2 - resp = self.app.get('/v1/apikeys/%s' % KEY2_KEY) + resp = self.app.get("/v1/apikeys/%s" % KEY2_KEY) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['id'], str(self.apikey2.id), - 'Incorrect api key retrieved.') - self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE, - 'Key should be masked.') + self.assertEqual( + resp.json["id"], str(self.apikey2.id), "Incorrect api key retrieved." + ) + self.assertEqual( + resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked." + ) # key3 - resp = self.app.get('/v1/apikeys/%s' % KEY3_KEY) + resp = self.app.get("/v1/apikeys/%s" % KEY3_KEY) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['id'], str(self.apikey3.id), - 'Incorrect api key retrieved.') - self.assertEqual(resp.json['key_hash'], MASKED_ATTRIBUTE_VALUE, - 'Key should be masked.') + self.assertEqual( + resp.json["id"], str(self.apikey3.id), "Incorrect api key retrieved." + ) + self.assertEqual( + resp.json["key_hash"], MASKED_ATTRIBUTE_VALUE, "Key should be masked." + ) def test_get_show_secrets(self): - resp = self.app.get('/v1/apikeys?show_secrets=True', expect_errors=True) + resp = self.app.get("/v1/apikeys?show_secrets=True", expect_errors=True) self.assertEqual(resp.status_int, 200) for key in resp.json: - self.assertNotEqual(key['key_hash'], MASKED_ATTRIBUTE_VALUE) - self.assertNotEqual(key['uid'], MASKED_ATTRIBUTE_VALUE) + self.assertNotEqual(key["key_hash"], MASKED_ATTRIBUTE_VALUE) + self.assertNotEqual(key["uid"], MASKED_ATTRIBUTE_VALUE) def test_post_delete_key(self): - api_key = { - 'user': 'herge' - } - resp1 = self.app.post_json('/v1/apikeys', api_key) + api_key = {"user": "herge"} + resp1 = self.app.post_json("/v1/apikeys", api_key) self.assertEqual(resp1.status_int, 201) - self.assertTrue(resp1.json['key'], 'Key should be non-None.') - self.assertNotEqual(resp1.json['key'], MASKED_ATTRIBUTE_VALUE, - 'Key should not be masked.') + self.assertTrue(resp1.json["key"], "Key should be non-None.") + self.assertNotEqual( + resp1.json["key"], MASKED_ATTRIBUTE_VALUE, "Key should not be masked." + ) # should lead to creation of another key - resp2 = self.app.post_json('/v1/apikeys', api_key) + resp2 = self.app.post_json("/v1/apikeys", api_key) self.assertEqual(resp2.status_int, 201) - self.assertTrue(resp2.json['key'], 'Key should be non-None.') - self.assertNotEqual(resp2.json['key'], MASKED_ATTRIBUTE_VALUE, 'Key should not be masked.') - self.assertNotEqual(resp1.json['key'], resp2.json['key'], 'Should be different') + self.assertTrue(resp2.json["key"], "Key should be non-None.") + self.assertNotEqual( + resp2.json["key"], MASKED_ATTRIBUTE_VALUE, "Key should not be masked." + ) + self.assertNotEqual(resp1.json["key"], resp2.json["key"], "Should be different") - resp = self.app.delete('/v1/apikeys/%s' % resp1.json['id']) + resp = self.app.delete("/v1/apikeys/%s" % resp1.json["id"]) self.assertEqual(resp.status_int, 204) - resp = self.app.delete('/v1/apikeys/%s' % resp2.json['key']) + resp = self.app.delete("/v1/apikeys/%s" % resp2.json["key"]) self.assertEqual(resp.status_int, 204) # With auth disabled, use system_user - resp3 = self.app.post_json('/v1/apikeys', {}) + resp3 = self.app.post_json("/v1/apikeys", {}) self.assertEqual(resp3.status_int, 201) - self.assertTrue(resp3.json['key'], 'Key should be non-None.') - self.assertNotEqual(resp3.json['key'], MASKED_ATTRIBUTE_VALUE, - 'Key should not be masked.') - self.assertTrue(resp3.json['user'], cfg.CONF.system_user.user) + self.assertTrue(resp3.json["key"], "Key should be non-None.") + self.assertNotEqual( + resp3.json["key"], MASKED_ATTRIBUTE_VALUE, "Key should not be masked." + ) + self.assertTrue(resp3.json["user"], cfg.CONF.system_user.user) def test_post_delete_same_key_hash(self): api_key = { - 'id': '5c5dbb576cb8de06a2d79a4d', - 'user': 'herge', - 'key_hash': 'ABCDE' + "id": "5c5dbb576cb8de06a2d79a4d", + "user": "herge", + "key_hash": "ABCDE", } - resp1 = self.app.post_json('/v1/apikeys', api_key) + resp1 = self.app.post_json("/v1/apikeys", api_key) self.assertEqual(resp1.status_int, 201) - self.assertEqual(resp1.json['key'], None, 'Key should be None.') + self.assertEqual(resp1.json["key"], None, "Key should be None.") # drop into the DB since API will be masking this value. - api_key_db = ApiKey.get_by_id(resp1.json['id']) + api_key_db = ApiKey.get_by_id(resp1.json["id"]) - self.assertEqual(resp1.json['id'], api_key['id'], 'PK ID of created API should match.') - self.assertEqual(api_key_db.key_hash, api_key['key_hash'], 'Key_hash should match.') - self.assertEqual(api_key_db.user, api_key['user'], 'User should match.') + self.assertEqual( + resp1.json["id"], api_key["id"], "PK ID of created API should match." + ) + self.assertEqual( + api_key_db.key_hash, api_key["key_hash"], "Key_hash should match." + ) + self.assertEqual(api_key_db.user, api_key["user"], "User should match.") - resp = self.app.delete('/v1/apikeys/%s' % resp1.json['id']) + resp = self.app.delete("/v1/apikeys/%s" % resp1.json["id"]) self.assertEqual(resp.status_int, 204) def test_put_api_key(self): - resp = self.app.get('/v1/apikeys/%s' % self.apikey1.id) + resp = self.app.get("/v1/apikeys/%s" % self.apikey1.id) self.assertEqual(resp.status_int, 200) update_input = resp.json - update_input['enabled'] = not update_input['enabled'] - put_resp = self.app.put_json('/v1/apikeys/%s' % self.apikey1.id, update_input, - expect_errors=True) + update_input["enabled"] = not update_input["enabled"] + put_resp = self.app.put_json( + "/v1/apikeys/%s" % self.apikey1.id, update_input, expect_errors=True + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['enabled'], not resp.json['enabled']) + self.assertEqual(put_resp.json["enabled"], not resp.json["enabled"]) update_input = put_resp.json - update_input['enabled'] = not update_input['enabled'] - put_resp = self.app.put_json('/v1/apikeys/%s' % self.apikey1.id, update_input, - expect_errors=True) + update_input["enabled"] = not update_input["enabled"] + put_resp = self.app.put_json( + "/v1/apikeys/%s" % self.apikey1.id, update_input, expect_errors=True + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['enabled'], resp.json['enabled']) + self.assertEqual(put_resp.json["enabled"], resp.json["enabled"]) def test_put_api_key_fail(self): - resp = self.app.get('/v1/apikeys/%s' % self.apikey1.id) + resp = self.app.get("/v1/apikeys/%s" % self.apikey1.id) self.assertEqual(resp.status_int, 200) update_input = resp.json - update_input['key_hash'] = '1' - put_resp = self.app.put_json('/v1/apikeys/%s' % self.apikey1.id, update_input, - expect_errors=True) + update_input["key_hash"] = "1" + put_resp = self.app.put_json( + "/v1/apikeys/%s" % self.apikey1.id, update_input, expect_errors=True + ) self.assertEqual(put_resp.status_int, 400) - self.assertTrue(put_resp.json['faultstring']) + self.assertTrue(put_resp.json["faultstring"]) def test_post_no_user_fail(self): - self.app.post_json('/v1/apikeys', {}, expect_errors=True) + self.app.post_json("/v1/apikeys", {}, expect_errors=True) diff --git a/st2api/tests/unit/controllers/v1/test_base.py b/st2api/tests/unit/controllers/v1/test_base.py index fa8b4f1c92d..cbfe3e54c2d 100644 --- a/st2api/tests/unit/controllers/v1/test_base.py +++ b/st2api/tests/unit/controllers/v1/test_base.py @@ -19,77 +19,79 @@ class TestBase(FunctionalTest): def test_defaults(self): - response = self.app.get('/') + response = self.app.get("/") self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers['Access-Control-Allow-Origin'], - 'http://127.0.0.1:3000') - self.assertEqual(response.headers['Access-Control-Allow-Methods'], - 'GET,POST,PUT,DELETE,OPTIONS') - self.assertEqual(response.headers['Access-Control-Allow-Headers'], - 'Content-Type,Authorization,X-Auth-Token,St2-Api-Key,X-Request-ID') - self.assertEqual(response.headers['Access-Control-Expose-Headers'], - 'Content-Type,X-Limit,X-Total-Count,X-Request-ID') + self.assertEqual( + response.headers["Access-Control-Allow-Origin"], "http://127.0.0.1:3000" + ) + self.assertEqual( + response.headers["Access-Control-Allow-Methods"], + "GET,POST,PUT,DELETE,OPTIONS", + ) + self.assertEqual( + response.headers["Access-Control-Allow-Headers"], + "Content-Type,Authorization,X-Auth-Token,St2-Api-Key,X-Request-ID", + ) + self.assertEqual( + response.headers["Access-Control-Expose-Headers"], + "Content-Type,X-Limit,X-Total-Count,X-Request-ID", + ) def test_origin(self): - response = self.app.get('/', headers={ - 'origin': 'http://127.0.0.1:3000' - }) + response = self.app.get("/", headers={"origin": "http://127.0.0.1:3000"}) self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers['Access-Control-Allow-Origin'], - 'http://127.0.0.1:3000') + self.assertEqual( + response.headers["Access-Control-Allow-Origin"], "http://127.0.0.1:3000" + ) def test_additional_origin(self): - response = self.app.get('/', headers={ - 'origin': 'http://dev' - }) + response = self.app.get("/", headers={"origin": "http://dev"}) self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers['Access-Control-Allow-Origin'], - 'http://dev') + self.assertEqual(response.headers["Access-Control-Allow-Origin"], "http://dev") def test_wrong_origin(self): # Invalid origin (not specified in the config), we return first allowed origin specified # in the config - response = self.app.get('/', headers={ - 'origin': 'http://xss' - }) + response = self.app.get("/", headers={"origin": "http://xss"}) self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers.get('Access-Control-Allow-Origin'), - 'http://127.0.0.1:3000') + self.assertEqual( + response.headers.get("Access-Control-Allow-Origin"), "http://127.0.0.1:3000" + ) invalid_origins = [ - 'http://', - 'https://', - 'https://www.example.com', - 'null', - '*' + "http://", + "https://", + "https://www.example.com", + "null", + "*", ] for origin in invalid_origins: - response = self.app.get('/', headers={ - 'origin': origin - }) + response = self.app.get("/", headers={"origin": origin}) self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers.get('Access-Control-Allow-Origin'), - 'http://127.0.0.1:3000') + self.assertEqual( + response.headers.get("Access-Control-Allow-Origin"), + "http://127.0.0.1:3000", + ) def test_wildcard_origin(self): try: - cfg.CONF.set_override('allow_origin', ['*'], 'api') - response = self.app.get('/', headers={ - 'origin': 'http://xss' - }) + cfg.CONF.set_override("allow_origin", ["*"], "api") + response = self.app.get("/", headers={"origin": "http://xss"}) finally: - cfg.CONF.clear_override('allow_origin', 'api') + cfg.CONF.clear_override("allow_origin", "api") self.assertEqual(response.status_int, 200) - self.assertEqual(response.headers['Access-Control-Allow-Origin'], - 'http://xss') + self.assertEqual(response.headers["Access-Control-Allow-Origin"], "http://xss") def test_valid_status_code_is_returned_on_invalid_path(self): # TypeError: get_all() takes exactly 1 argument (2 given) - resp = self.app.get('/v1/executions/577f775b0640fd1451f2030b/re_run', expect_errors=True) + resp = self.app.get( + "/v1/executions/577f775b0640fd1451f2030b/re_run", expect_errors=True + ) self.assertEqual(resp.status_int, 404) # get_one() takes exactly 2 arguments (4 given) - resp = self.app.get('/v1/executions/577f775b0640fd1451f2030b/re_run/a/b', - expect_errors=True) + resp = self.app.get( + "/v1/executions/577f775b0640fd1451f2030b/re_run/a/b", expect_errors=True + ) self.assertEqual(resp.status_int, 404) diff --git a/st2api/tests/unit/controllers/v1/test_executions.py b/st2api/tests/unit/controllers/v1/test_executions.py index 57dad1f9f30..5a59f6aab57 100644 --- a/st2api/tests/unit/controllers/v1/test_executions.py +++ b/st2api/tests/unit/controllers/v1/test_executions.py @@ -55,324 +55,286 @@ from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase __all__ = [ - 'ActionExecutionControllerTestCase', - 'ActionExecutionOutputControllerTestCase' + "ActionExecutionControllerTestCase", + "ActionExecutionOutputControllerTestCase", ] ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'entry_point': '/tmp/test/action1.sh', - 'pack': 'sixpack', - 'runner_type': 'remote-shell-cmd', - 'parameters': { - 'a': { - 'type': 'string', - 'default': 'abc' - }, - 'b': { - 'type': 'number', - 'default': 123 - }, - 'c': { - 'type': 'number', - 'default': 123, - 'immutable': True - }, - 'd': { - 'type': 'string', - 'secret': True - } - } + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "entry_point": "/tmp/test/action1.sh", + "pack": "sixpack", + "runner_type": "remote-shell-cmd", + "parameters": { + "a": {"type": "string", "default": "abc"}, + "b": {"type": "number", "default": 123}, + "c": {"type": "number", "default": 123, "immutable": True}, + "d": {"type": "string", "secret": True}, + }, } ACTION_2 = { - 'name': 'st2.dummy.action2', - 'description': 'another test description', - 'enabled': True, - 'entry_point': '/tmp/test/action2.sh', - 'pack': 'familypack', - 'runner_type': 'remote-shell-cmd', - 'parameters': { - 'c': { - 'type': 'object', - 'properties': { - 'c1': { - 'type': 'string' - } - } - }, - 'd': { - 'type': 'boolean', - 'default': False - } - } + "name": "st2.dummy.action2", + "description": "another test description", + "enabled": True, + "entry_point": "/tmp/test/action2.sh", + "pack": "familypack", + "runner_type": "remote-shell-cmd", + "parameters": { + "c": {"type": "object", "properties": {"c1": {"type": "string"}}}, + "d": {"type": "boolean", "default": False}, + }, } ACTION_3 = { - 'name': 'st2.dummy.action3', - 'description': 'another test description', - 'enabled': True, - 'entry_point': '/tmp/test/action3.sh', - 'pack': 'wolfpack', - 'runner_type': 'remote-shell-cmd', - 'parameters': { - 'e': {}, - 'f': {} - } + "name": "st2.dummy.action3", + "description": "another test description", + "enabled": True, + "entry_point": "/tmp/test/action3.sh", + "pack": "wolfpack", + "runner_type": "remote-shell-cmd", + "parameters": {"e": {}, "f": {}}, } ACTION_4 = { - 'name': 'st2.dummy.action4', - 'description': 'another test description', - 'enabled': True, - 'entry_point': '/tmp/test/workflows/action4.yaml', - 'pack': 'starterpack', - 'runner_type': 'orquesta', - 'parameters': { - 'a': { - 'type': 'string', - 'default': 'abc' - }, - 'b': { - 'type': 'number', - 'default': 123 - } - } + "name": "st2.dummy.action4", + "description": "another test description", + "enabled": True, + "entry_point": "/tmp/test/workflows/action4.yaml", + "pack": "starterpack", + "runner_type": "orquesta", + "parameters": { + "a": {"type": "string", "default": "abc"}, + "b": {"type": "number", "default": 123}, + }, } ACTION_INQUIRY = { - 'name': 'st2.dummy.ask', - 'description': 'another test description', - 'enabled': True, - 'pack': 'wolfpack', - 'runner_type': 'inquirer', + "name": "st2.dummy.ask", + "description": "another test description", + "enabled": True, + "pack": "wolfpack", + "runner_type": "inquirer", } ACTION_DEFAULT_TEMPLATE = { - 'name': 'st2.dummy.default_template', - 'description': 'An action that uses a jinja template as a default value for a parameter', - 'enabled': True, - 'pack': 'starterpack', - 'runner_type': 'local-shell-cmd', - 'parameters': { - 'intparam': { - 'type': 'integer', - 'default': '{{ st2kv.system.test_int | int }}' - } - } + "name": "st2.dummy.default_template", + "description": "An action that uses a jinja template as a default value for a parameter", + "enabled": True, + "pack": "starterpack", + "runner_type": "local-shell-cmd", + "parameters": { + "intparam": {"type": "integer", "default": "{{ st2kv.system.test_int | int }}"} + }, } ACTION_DEFAULT_ENCRYPT = { - 'name': 'st2.dummy.default_encrypted_value', - 'description': 'An action that uses a jinja template with decrypt_kv filter ' - 'in default parameter', - 'enabled': True, - 'pack': 'starterpack', - 'runner_type': 'local-shell-cmd', - 'parameters': { - 'encrypted_param': { - 'type': 'string', - 'default': '{{ st2kv.system.secret | decrypt_kv }}' + "name": "st2.dummy.default_encrypted_value", + "description": "An action that uses a jinja template with decrypt_kv filter " + "in default parameter", + "enabled": True, + "pack": "starterpack", + "runner_type": "local-shell-cmd", + "parameters": { + "encrypted_param": { + "type": "string", + "default": "{{ st2kv.system.secret | decrypt_kv }}", }, - 'encrypted_user_param': { - 'type': 'string', - 'default': '{{ st2kv.user.secret | decrypt_kv }}' - } - } + "encrypted_user_param": { + "type": "string", + "default": "{{ st2kv.user.secret | decrypt_kv }}", + }, + }, } ACTION_DEFAULT_ENCRYPT_SECRET_PARAMS = { - 'name': 'st2.dummy.default_encrypted_value_secret_param', - 'description': 'An action that uses a jinja template with decrypt_kv filter ' - 'in default parameter', - 'enabled': True, - 'pack': 'starterpack', - 'runner_type': 'local-shell-cmd', - 'parameters': { - 'encrypted_param': { - 'type': 'string', - 'default': '{{ st2kv.system.secret | decrypt_kv }}', - 'secret': True + "name": "st2.dummy.default_encrypted_value_secret_param", + "description": "An action that uses a jinja template with decrypt_kv filter " + "in default parameter", + "enabled": True, + "pack": "starterpack", + "runner_type": "local-shell-cmd", + "parameters": { + "encrypted_param": { + "type": "string", + "default": "{{ st2kv.system.secret | decrypt_kv }}", + "secret": True, }, - 'encrypted_user_param': { - 'type': 'string', - 'default': '{{ st2kv.user.secret | decrypt_kv }}', - 'secret': True - } - } + "encrypted_user_param": { + "type": "string", + "default": "{{ st2kv.user.secret | decrypt_kv }}", + "secret": True, + }, + }, } LIVE_ACTION_1 = { - 'action': 'sixpack.st2.dummy.action1', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'uname -a', - 'd': SUPER_SECRET_PARAMETER - } + "action": "sixpack.st2.dummy.action1", + "parameters": { + "hosts": "localhost", + "cmd": "uname -a", + "d": SUPER_SECRET_PARAMETER, + }, } LIVE_ACTION_2 = { - 'action': 'familypack.st2.dummy.action2', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'ls -l' - } + "action": "familypack.st2.dummy.action2", + "parameters": {"hosts": "localhost", "cmd": "ls -l"}, } LIVE_ACTION_3 = { - 'action': 'wolfpack.st2.dummy.action3', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'ls -l', - 'e': 'abcde', - 'f': 12345 - } + "action": "wolfpack.st2.dummy.action3", + "parameters": {"hosts": "localhost", "cmd": "ls -l", "e": "abcde", "f": 12345}, } LIVE_ACTION_4 = { - 'action': 'starterpack.st2.dummy.action4', + "action": "starterpack.st2.dummy.action4", } LIVE_ACTION_DELAY = { - 'action': 'sixpack.st2.dummy.action1', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'uname -a', - 'd': SUPER_SECRET_PARAMETER + "action": "sixpack.st2.dummy.action1", + "parameters": { + "hosts": "localhost", + "cmd": "uname -a", + "d": SUPER_SECRET_PARAMETER, }, - 'delay': 100 + "delay": 100, } LIVE_ACTION_INQUIRY = { - 'parameters': { - 'route': 'developers', - 'schema': { - 'type': 'object', - 'properties': { - 'secondfactor': { - 'secret': True, - 'required': True, - 'type': u'string', - 'description': 'Please enter second factor for authenticating to "foo" service' + "parameters": { + "route": "developers", + "schema": { + "type": "object", + "properties": { + "secondfactor": { + "secret": True, + "required": True, + "type": "string", + "description": 'Please enter second factor for authenticating to "foo" service', } - } - } - }, - 'action': 'wolfpack.st2.dummy.ask', - 'result': { - 'users': [], - 'roles': [], - 'route': 'developers', - 'ttl': 1440, - 'response': { - 'secondfactor': 'supersecretvalue' + }, }, - 'schema': { - 'type': 'object', - 'properties': { - 'secondfactor': { - 'secret': True, - 'required': True, - 'type': 'string', - 'description': 'Please enter second factor for authenticating to "foo" service' + }, + "action": "wolfpack.st2.dummy.ask", + "result": { + "users": [], + "roles": [], + "route": "developers", + "ttl": 1440, + "response": {"secondfactor": "supersecretvalue"}, + "schema": { + "type": "object", + "properties": { + "secondfactor": { + "secret": True, + "required": True, + "type": "string", + "description": 'Please enter second factor for authenticating to "foo" service', } - } - } - } + }, + }, + }, } LIVE_ACTION_WITH_SECRET_PARAM = { - 'parameters': { + "parameters": { # action params - 'a': 'param a', - 'd': 'secretpassword1', - + "a": "param a", + "d": "secretpassword1", # runner params - 'password': 'secretpassword2', - 'hosts': 'localhost' + "password": "secretpassword2", + "hosts": "localhost", }, - 'action': 'sixpack.st2.dummy.action1' + "action": "sixpack.st2.dummy.action1", } # Do not add parameters to this. There are tests that will test first without params, # then make a copy with params. LIVE_ACTION_DEFAULT_TEMPLATE = { - 'action': 'starterpack.st2.dummy.default_template', + "action": "starterpack.st2.dummy.default_template", } LIVE_ACTION_DEFAULT_ENCRYPT = { - 'action': 'starterpack.st2.dummy.default_encrypted_value', + "action": "starterpack.st2.dummy.default_encrypted_value", } LIVE_ACTION_DEFAULT_ENCRYPT_SECRET_PARAM = { - 'action': 'starterpack.st2.dummy.default_encrypted_value_secret_param', + "action": "starterpack.st2.dummy.default_encrypted_value_secret_param", } -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_FIXTURES = { - 'runners': ['testrunner1.yaml'], - 'actions': ['action1.yaml', 'local.yaml'] + "runners": ["testrunner1.yaml"], + "actions": ["action1.yaml", "local.yaml"], } -@mock.patch.object(content_utils, 'get_pack_base_path', mock.MagicMock(return_value='/tmp/test')) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) -class ActionExecutionControllerTestCase(BaseActionExecutionControllerTestCase, FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/executions' +@mock.patch.object( + content_utils, "get_pack_base_path", mock.MagicMock(return_value="/tmp/test") +) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) +class ActionExecutionControllerTestCase( + BaseActionExecutionControllerTestCase, + FunctionalTest, + APIControllerWithIncludeAndExcludeFilterTestCase, +): + get_all_path = "/v1/executions" controller_cls = ActionExecutionsController - include_attribute_field_name = 'status' - exclude_attribute_field_name = 'status' + include_attribute_field_name = "status" + exclude_attribute_field_name = "status" test_exact_object_count = False @classmethod - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def setUpClass(cls): super(BaseActionExecutionControllerTestCase, cls).setUpClass() cls.action1 = copy.deepcopy(ACTION_1) - post_resp = cls.app.post_json('/v1/actions', cls.action1) - cls.action1['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action1) + cls.action1["id"] = post_resp.json["id"] cls.action2 = copy.deepcopy(ACTION_2) - post_resp = cls.app.post_json('/v1/actions', cls.action2) - cls.action2['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action2) + cls.action2["id"] = post_resp.json["id"] cls.action3 = copy.deepcopy(ACTION_3) - post_resp = cls.app.post_json('/v1/actions', cls.action3) - cls.action3['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action3) + cls.action3["id"] = post_resp.json["id"] cls.action4 = copy.deepcopy(ACTION_4) - post_resp = cls.app.post_json('/v1/actions', cls.action4) - cls.action4['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action4) + cls.action4["id"] = post_resp.json["id"] cls.action_inquiry = copy.deepcopy(ACTION_INQUIRY) - post_resp = cls.app.post_json('/v1/actions', cls.action_inquiry) - cls.action_inquiry['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action_inquiry) + cls.action_inquiry["id"] = post_resp.json["id"] cls.action_template = copy.deepcopy(ACTION_DEFAULT_TEMPLATE) - post_resp = cls.app.post_json('/v1/actions', cls.action_template) - cls.action_template['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action_template) + cls.action_template["id"] = post_resp.json["id"] cls.action_decrypt = copy.deepcopy(ACTION_DEFAULT_ENCRYPT) - post_resp = cls.app.post_json('/v1/actions', cls.action_decrypt) - cls.action_decrypt['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action_decrypt) + cls.action_decrypt["id"] = post_resp.json["id"] - cls.action_decrypt_secret_param = copy.deepcopy(ACTION_DEFAULT_ENCRYPT_SECRET_PARAMS) - post_resp = cls.app.post_json('/v1/actions', cls.action_decrypt_secret_param) - cls.action_decrypt_secret_param['id'] = post_resp.json['id'] + cls.action_decrypt_secret_param = copy.deepcopy( + ACTION_DEFAULT_ENCRYPT_SECRET_PARAMS + ) + post_resp = cls.app.post_json("/v1/actions", cls.action_decrypt_secret_param) + cls.action_decrypt_secret_param["id"] = post_resp.json["id"] @classmethod def tearDownClass(cls): - cls.app.delete('/v1/actions/%s' % cls.action1['id']) - cls.app.delete('/v1/actions/%s' % cls.action2['id']) - cls.app.delete('/v1/actions/%s' % cls.action3['id']) - cls.app.delete('/v1/actions/%s' % cls.action4['id']) - cls.app.delete('/v1/actions/%s' % cls.action_inquiry['id']) - cls.app.delete('/v1/actions/%s' % cls.action_template['id']) - cls.app.delete('/v1/actions/%s' % cls.action_decrypt['id']) + cls.app.delete("/v1/actions/%s" % cls.action1["id"]) + cls.app.delete("/v1/actions/%s" % cls.action2["id"]) + cls.app.delete("/v1/actions/%s" % cls.action3["id"]) + cls.app.delete("/v1/actions/%s" % cls.action4["id"]) + cls.app.delete("/v1/actions/%s" % cls.action_inquiry["id"]) + cls.app.delete("/v1/actions/%s" % cls.action_template["id"]) + cls.app.delete("/v1/actions/%s" % cls.action_decrypt["id"]) super(BaseActionExecutionControllerTestCase, cls).tearDownClass() def test_get_one(self): @@ -381,11 +343,11 @@ def test_get_one(self): get_resp = self._do_get_one(actionexecution_id) self.assertEqual(get_resp.status_int, 200) self.assertEqual(self._get_actionexecution_id(get_resp), actionexecution_id) - self.assertIn('web_url', get_resp) - if 'end_timestamp' in get_resp: - self.assertIn('elapsed_seconds', get_resp) + self.assertIn("web_url", get_resp) + if "end_timestamp" in get_resp: + self.assertIn("elapsed_seconds", get_resp) - get_resp = self._do_get_one('last') + get_resp = self._do_get_one("last") self.assertEqual(get_resp.status_int, 200) self.assertEqual(self._get_actionexecution_id(get_resp), actionexecution_id) @@ -396,13 +358,15 @@ def test_get_all_id_query_param_filtering_success(self): self.assertEqual(get_resp.status_int, 200) self.assertEqual(self._get_actionexecution_id(get_resp), actionexecution_id) - resp = self.app.get('/v1/executions?id=%s' % (actionexecution_id), expect_errors=False) + resp = self.app.get( + "/v1/executions?id=%s" % (actionexecution_id), expect_errors=False + ) self.assertEqual(resp.status_int, 200) def test_get_all_id_query_param_filtering_invalid_id(self): - resp = self.app.get('/v1/executions?id=invalidid', expect_errors=True) + resp = self.app.get("/v1/executions?id=invalidid", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertIn('not a valid ObjectId', resp.json['faultstring']) + self.assertIn("not a valid ObjectId", resp.json["faultstring"]) def test_get_all_id_query_param_filtering_multiple_ids_provided(self): post_resp = self._do_post(LIVE_ACTION_1) @@ -413,94 +377,118 @@ def test_get_all_id_query_param_filtering_multiple_ids_provided(self): self.assertEqual(post_resp.status_int, 201) id_2 = self._get_actionexecution_id(post_resp) - resp = self.app.get('/v1/executions?id=%s,%s' % (id_1, id_2), expect_errors=False) + resp = self.app.get( + "/v1/executions?id=%s,%s" % (id_1, id_2), expect_errors=False + ) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 2) def test_get_all(self): self._get_actionexecution_id(self._do_post(LIVE_ACTION_1)) self._get_actionexecution_id(self._do_post(LIVE_ACTION_2)) - resp = self.app.get('/v1/executions') + resp = self.app.get("/v1/executions") body = resp.json self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.headers['X-Total-Count'], "2") - self.assertEqual(len(resp.json), 2, - '/v1/executions did not return all ' - 'actionexecutions.') + self.assertEqual(resp.headers["X-Total-Count"], "2") + self.assertEqual( + len(resp.json), 2, "/v1/executions did not return all " "actionexecutions." + ) # Assert liveactions are sorted by timestamp. for i in range(len(body) - 1): - self.assertTrue(isotime.parse(body[i]['start_timestamp']) >= - isotime.parse(body[i + 1]['start_timestamp'])) - self.assertIn('web_url', body[i]) - if 'end_timestamp' in body[i]: - self.assertIn('elapsed_seconds', body[i]) + self.assertTrue( + isotime.parse(body[i]["start_timestamp"]) + >= isotime.parse(body[i + 1]["start_timestamp"]) + ) + self.assertIn("web_url", body[i]) + if "end_timestamp" in body[i]: + self.assertIn("elapsed_seconds", body[i]) def test_get_all_invalid_offset_too_large(self): - offset = '2141564789454123457895412237483648' - resp = self.app.get('/v1/executions?offset=%s&limit=1' % (offset), expect_errors=True) + offset = "2141564789454123457895412237483648" + resp = self.app.get( + "/v1/executions?offset=%s&limit=1" % (offset), expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Offset "%s" specified is more than 32-bit int' % (offset)) + self.assertEqual( + resp.json["faultstring"], + 'Offset "%s" specified is more than 32-bit int' % (offset), + ) def test_get_query(self): - actionexecution_1_id = self._get_actionexecution_id(self._do_post(LIVE_ACTION_1)) + actionexecution_1_id = self._get_actionexecution_id( + self._do_post(LIVE_ACTION_1) + ) - resp = self.app.get('/v1/executions?action=%s' % LIVE_ACTION_1['action']) + resp = self.app.get("/v1/executions?action=%s" % LIVE_ACTION_1["action"]) self.assertEqual(resp.status_int, 200) - matching_execution = filter(lambda ae: ae['id'] == actionexecution_1_id, resp.json) - self.assertEqual(len(list(matching_execution)), 1, - '/v1/executions did not return correct liveaction.') + matching_execution = filter( + lambda ae: ae["id"] == actionexecution_1_id, resp.json + ) + self.assertEqual( + len(list(matching_execution)), + 1, + "/v1/executions did not return correct liveaction.", + ) def test_get_query_with_limit_and_offset(self): self._get_actionexecution_id(self._do_post(LIVE_ACTION_1)) self._get_actionexecution_id(self._do_post(LIVE_ACTION_1)) - resp = self.app.get('/v1/executions') + resp = self.app.get("/v1/executions") self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) > 0) - resp = self.app.get('/v1/executions?limit=1') + resp = self.app.get("/v1/executions?limit=1") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - resp = self.app.get('/v1/executions?limit=0', expect_errors=True) + resp = self.app.get("/v1/executions?limit=0", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertTrue(resp.json['faultstring'], - u'Limit, "0" specified, must be a positive number or -1 for full \ - result set.') + self.assertTrue( + resp.json["faultstring"], + 'Limit, "0" specified, must be a positive number or -1 for full \ + result set.', + ) - resp = self.app.get('/v1/executions?limit=-1') + resp = self.app.get("/v1/executions?limit=-1") self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) > 1) - resp = self.app.get('/v1/executions?limit=-22', expect_errors=True) + resp = self.app.get("/v1/executions?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) - resp = self.app.get('/v1/executions?action=%s' % LIVE_ACTION_1['action']) + resp = self.app.get("/v1/executions?action=%s" % LIVE_ACTION_1["action"]) self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) > 1) - resp = self.app.get('/v1/executions?action=%s&limit=0' % - LIVE_ACTION_1['action'], expect_errors=True) + resp = self.app.get( + "/v1/executions?action=%s&limit=0" % LIVE_ACTION_1["action"], + expect_errors=True, + ) self.assertEqual(resp.status_int, 400) - self.assertTrue(resp.json['faultstring'], - u'Limit, "0" specified, must be a positive number or -1 for full \ - result set.') - - resp = self.app.get('/v1/executions?action=%s&limit=1' % - LIVE_ACTION_1['action']) + self.assertTrue( + resp.json["faultstring"], + 'Limit, "0" specified, must be a positive number or -1 for full \ + result set.', + ) + + resp = self.app.get( + "/v1/executions?action=%s&limit=1" % LIVE_ACTION_1["action"] + ) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - total_count = resp.headers['X-Total-Count'] + total_count = resp.headers["X-Total-Count"] - resp = self.app.get('/v1/executions?offset=%s&limit=1' % total_count) + resp = self.app.get("/v1/executions?offset=%s&limit=1" % total_count) self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json), 0) def test_get_one_fail(self): - resp = self.app.get('/v1/executions/100', expect_errors=True) + resp = self.app.get("/v1/executions/100", expect_errors=True) self.assertEqual(resp.status_int, 404) def test_post_delete(self): @@ -508,13 +496,13 @@ def test_post_delete(self): self.assertEqual(post_resp.status_int, 201) delete_resp = self._do_delete(self._get_actionexecution_id(post_resp)) self.assertEqual(delete_resp.status_int, 200) - self.assertEqual(delete_resp.json['status'], 'canceled') - expected_result = {'message': 'Action canceled by user.', 'user': 'stanley'} - self.assertDictEqual(delete_resp.json['result'], expected_result) + self.assertEqual(delete_resp.json["status"], "canceled") + expected_result = {"message": "Action canceled by user.", "user": "stanley"} + self.assertDictEqual(delete_resp.json["result"], expected_result) def test_post_delete_duplicate(self): """Cancels an execution twice, to ensure that a full execution object - is returned instead of an error message + is returned instead of an error message """ post_resp = self._do_post(LIVE_ACTION_1) @@ -524,59 +512,65 @@ def test_post_delete_duplicate(self): for i in range(2): delete_resp = self._do_delete(self._get_actionexecution_id(post_resp)) self.assertEqual(delete_resp.status_int, 200) - self.assertEqual(delete_resp.json['status'], 'canceled') - expected_result = {'message': 'Action canceled by user.', 'user': 'stanley'} - self.assertDictEqual(delete_resp.json['result'], expected_result) + self.assertEqual(delete_resp.json["status"], "canceled") + expected_result = {"message": "Action canceled by user.", "user": "stanley"} + self.assertDictEqual(delete_resp.json["result"], expected_result) def test_post_delete_trace(self): LIVE_ACTION_TRACE = copy.copy(LIVE_ACTION_1) - LIVE_ACTION_TRACE['context'] = {'trace_context': {'trace_tag': 'balleilaka'}} + LIVE_ACTION_TRACE["context"] = {"trace_context": {"trace_tag": "balleilaka"}} post_resp = self._do_post(LIVE_ACTION_TRACE) self.assertEqual(post_resp.status_int, 201) delete_resp = self._do_delete(self._get_actionexecution_id(post_resp)) self.assertEqual(delete_resp.status_int, 200) - self.assertEqual(delete_resp.json['status'], 'canceled') + self.assertEqual(delete_resp.json["status"], "canceled") trace_id = str(Trace.get_all()[0].id) - LIVE_ACTION_TRACE['context'] = {'trace_context': {'id_': trace_id}} + LIVE_ACTION_TRACE["context"] = {"trace_context": {"id_": trace_id}} post_resp = self._do_post(LIVE_ACTION_TRACE) self.assertEqual(post_resp.status_int, 201) delete_resp = self._do_delete(self._get_actionexecution_id(post_resp)) self.assertEqual(delete_resp.status_int, 200) - self.assertEqual(delete_resp.json['status'], 'canceled') + self.assertEqual(delete_resp.json["status"], "canceled") def test_post_nonexistent_action(self): live_action = copy.deepcopy(LIVE_ACTION_1) - live_action['action'] = 'mock.foobar' + live_action["action"] = "mock.foobar" post_resp = self._do_post(live_action, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - expected_error = 'Action "%s" cannot be found.' % live_action['action'] - self.assertEqual(expected_error, post_resp.json['faultstring']) + expected_error = 'Action "%s" cannot be found.' % live_action["action"] + self.assertEqual(expected_error, post_resp.json["faultstring"]) def test_post_parameter_validation_failed(self): execution = copy.deepcopy(LIVE_ACTION_1) # Runner type does not expects additional properties. - execution['parameters']['foo'] = 'bar' + execution["parameters"]["foo"] = "bar" post_resp = self._do_post(execution, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - self.assertEqual(post_resp.json['faultstring'], - "Additional properties are not allowed ('foo' was unexpected)") + self.assertEqual( + post_resp.json["faultstring"], + "Additional properties are not allowed ('foo' was unexpected)", + ) # Runner type expects parameter "hosts". - execution['parameters'] = {} + execution["parameters"] = {} post_resp = self._do_post(execution, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - self.assertEqual(post_resp.json['faultstring'], "'hosts' is a required property") + self.assertEqual( + post_resp.json["faultstring"], "'hosts' is a required property" + ) # Runner type expects parameters "cmd" to be str. - execution['parameters'] = {"hosts": "127.0.0.1", "cmd": 1000} + execution["parameters"] = {"hosts": "127.0.0.1", "cmd": 1000} post_resp = self._do_post(execution, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - self.assertIn('Value "1000" must either be a string or None. Got "int"', - post_resp.json['faultstring']) + self.assertIn( + 'Value "1000" must either be a string or None. Got "int"', + post_resp.json["faultstring"], + ) # Runner type expects parameters "cmd" to be str. - execution['parameters'] = {"hosts": "127.0.0.1", "cmd": "1000", "c": 1} + execution["parameters"] = {"hosts": "127.0.0.1", "cmd": "1000", "c": 1} post_resp = self._do_post(execution, expect_errors=True) self.assertEqual(post_resp.status_int, 400) @@ -589,53 +583,55 @@ def test_post_parameter_render_failed(self): execution = copy.deepcopy(LIVE_ACTION_1) # Runner type does not expects additional properties. - execution['parameters']['hosts'] = '{{ABSENT}}' + execution["parameters"]["hosts"] = "{{ABSENT}}" post_resp = self._do_post(execution, expect_errors=True) self.assertEqual(post_resp.status_int, 400) - self.assertEqual(post_resp.json['faultstring'], - 'Dependency unsatisfied in variable "ABSENT"') + self.assertEqual( + post_resp.json["faultstring"], 'Dependency unsatisfied in variable "ABSENT"' + ) def test_post_parameter_validation_explicit_none(self): execution = copy.deepcopy(LIVE_ACTION_1) - execution['parameters']['a'] = None + execution["parameters"]["a"] = None post_resp = self._do_post(execution) self.assertEqual(post_resp.status_int, 201) def test_post_with_st2_context_in_headers(self): resp = self._do_post(copy.deepcopy(LIVE_ACTION_1)) self.assertEqual(resp.status_int, 201) - parent_user = resp.json['context']['user'] - parent_exec_id = str(resp.json['id']) + parent_user = resp.json["context"]["user"] + parent_exec_id = str(resp.json["id"]) context = { - 'parent': { - 'execution_id': parent_exec_id, - 'user': parent_user - }, - 'user': None, - 'other': {'k1': 'v1'} + "parent": {"execution_id": parent_exec_id, "user": parent_user}, + "user": None, + "other": {"k1": "v1"}, + } + headers = { + "content-type": "application/json", + "st2-context": json.dumps(context), } - headers = {'content-type': 'application/json', 'st2-context': json.dumps(context)} resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], parent_user, 'Should use parent\'s user.') + self.assertEqual( + resp.json["context"]["user"], parent_user, "Should use parent's user." + ) expected = { - 'parent': { - 'execution_id': parent_exec_id, - 'user': parent_user - }, - 'user': parent_user, - 'pack': 'sixpack', - 'other': {'k1': 'v1'} + "parent": {"execution_id": parent_exec_id, "user": parent_user}, + "user": parent_user, + "pack": "sixpack", + "other": {"k1": "v1"}, } - self.assertDictEqual(resp.json['context'], expected) + self.assertDictEqual(resp.json["context"], expected) def test_post_with_st2_context_in_headers_failed(self): resp = self._do_post(copy.deepcopy(LIVE_ACTION_1)) self.assertEqual(resp.status_int, 201) - headers = {'content-type': 'application/json', 'st2-context': 'foobar'} - resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers, expect_errors=True) + headers = {"content-type": "application/json", "st2-context": "foobar"} + resp = self._do_post( + copy.deepcopy(LIVE_ACTION_1), headers=headers, expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertIn('Unable to convert st2-context', resp.json['faultstring']) + self.assertIn("Unable to convert st2-context", resp.json["faultstring"]) def test_re_run_success(self): # Create a new execution @@ -645,12 +641,16 @@ def test_re_run_success(self): # Re-run created execution (no parameters overrides) data = {} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) # Re-run created execution (with parameters overrides) - data = {'parameters': {'a': 'val1'}} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + data = {"parameters": {"a": "val1"}} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) def test_re_run_with_delay(self): @@ -659,21 +659,24 @@ def test_re_run_with_delay(self): execution_id = self._get_actionexecution_id(post_resp) delay_time = 100 - data = {'delay': delay_time} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + data = {"delay": delay_time} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) resp = json.loads(re_run_resp.body) - self.assertEqual(resp['delay'], delay_time) + self.assertEqual(resp["delay"], delay_time) def test_re_run_with_incorrect_delay(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - delay_time = 'sudo apt -y upgrade winson' - data = {'delay': delay_time} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + delay_time = "sudo apt -y upgrade winson" + data = {"delay": delay_time} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) def test_re_run_with_very_large_delay(self): @@ -682,8 +685,10 @@ def test_re_run_with_very_large_delay(self): execution_id = self._get_actionexecution_id(post_resp) delay_time = 10 ** 10 - data = {'delay': delay_time} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + data = {"delay": delay_time} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) def test_re_run_delayed_aciton_with_no_delay(self): @@ -692,11 +697,13 @@ def test_re_run_delayed_aciton_with_no_delay(self): execution_id = self._get_actionexecution_id(post_resp) delay_time = 0 - data = {'delay': delay_time} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + data = {"delay": delay_time} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) resp = json.loads(re_run_resp.body) - self.assertNotIn('delay', resp.keys()) + self.assertNotIn("delay", resp.keys()) def test_re_run_failure_execution_doesnt_exist(self): # Create a new execution @@ -705,8 +712,9 @@ def test_re_run_failure_execution_doesnt_exist(self): # Re-run created execution (override parameter with an invalid value) data = {} - re_run_resp = self.app.post_json('/v1/executions/doesntexist/re_run', - data, expect_errors=True) + re_run_resp = self.app.post_json( + "/v1/executions/doesntexist/re_run", data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 404) def test_re_run_failure_parameter_override_invalid_type(self): @@ -716,12 +724,15 @@ def test_re_run_failure_parameter_override_invalid_type(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (override parameter and task together) - data = {'parameters': {'a': 1000}} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"parameters": {"a": 1000}} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) - self.assertIn('Value "1000" must either be a string or None. Got "int"', - re_run_resp.json['faultstring']) + self.assertIn( + 'Value "1000" must either be a string or None. Got "int"', + re_run_resp.json["faultstring"], + ) def test_template_param(self): @@ -731,31 +742,46 @@ def test_template_param(self): # Assert that the template in the parameter default value # was rendered and st2kv was used - self.assertEqual(post_resp.json['parameters']['intparam'], 0) + self.assertEqual(post_resp.json["parameters"]["intparam"], 0) # Test with live param live_int_param = 3 livaction_with_params = copy.deepcopy(LIVE_ACTION_DEFAULT_TEMPLATE) - livaction_with_params['parameters'] = { - "intparam": live_int_param - } + livaction_with_params["parameters"] = {"intparam": live_int_param} post_resp = self._do_post(livaction_with_params) self.assertEqual(post_resp.status_int, 201) # Assert that the template in the parameter default value # was not rendered, and the provided parameter was used - self.assertEqual(post_resp.json['parameters']['intparam'], live_int_param) + self.assertEqual(post_resp.json["parameters"]["intparam"], live_int_param) def test_template_encrypted_params(self): # register datastore values which are used in this test case KeyValuePairAPI._setup_crypto() register_items = [ - {'name': 'secret', 'secret': True, - 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'foo')}, - {'name': 'stanley:secret', 'secret': True, 'scope': FULL_USER_SCOPE, - 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'bar')}, - {'name': 'user1:secret', 'secret': True, 'scope': FULL_USER_SCOPE, - 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'baz')}, + { + "name": "secret", + "secret": True, + "value": crypto_utils.symmetric_encrypt( + KeyValuePairAPI.crypto_key, "foo" + ), + }, + { + "name": "stanley:secret", + "secret": True, + "scope": FULL_USER_SCOPE, + "value": crypto_utils.symmetric_encrypt( + KeyValuePairAPI.crypto_key, "bar" + ), + }, + { + "name": "user1:secret", + "secret": True, + "scope": FULL_USER_SCOPE, + "value": crypto_utils.symmetric_encrypt( + KeyValuePairAPI.crypto_key, "baz" + ), + }, ] kvps = [KeyValuePair.add_or_update(KeyValuePairDB(**x)) for x in register_items] @@ -763,43 +789,53 @@ def test_template_encrypted_params(self): # 1. parameters are not marked as secret resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], 'stanley') - self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo') - self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'bar') + self.assertEqual(resp.json["context"]["user"], "stanley") + self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo") + self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "bar") # 2. parameters are marked as secret resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT_SECRET_PARAM) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], 'stanley') - self.assertEqual(resp.json['parameters']['encrypted_param'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json['parameters']['encrypted_user_param'], MASKED_ATTRIBUTE_VALUE) + self.assertEqual(resp.json["context"]["user"], "stanley") + self.assertEqual( + resp.json["parameters"]["encrypted_param"], MASKED_ATTRIBUTE_VALUE + ) + self.assertEqual( + resp.json["parameters"]["encrypted_user_param"], MASKED_ATTRIBUTE_VALUE + ) # After switching to the 'user1', that value will be read from switched user's scope - self.use_user(UserDB(name='user1')) + self.use_user(UserDB(name="user1")) # 1. parameters are not marked as secret resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], 'user1') - self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo') - self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'baz') + self.assertEqual(resp.json["context"]["user"], "user1") + self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo") + self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "baz") # 2. parameters are marked as secret resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT_SECRET_PARAM) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], 'user1') - self.assertEqual(resp.json['parameters']['encrypted_param'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json['parameters']['encrypted_user_param'], MASKED_ATTRIBUTE_VALUE) + self.assertEqual(resp.json["context"]["user"], "user1") + self.assertEqual( + resp.json["parameters"]["encrypted_param"], MASKED_ATTRIBUTE_VALUE + ) + self.assertEqual( + resp.json["parameters"]["encrypted_user_param"], MASKED_ATTRIBUTE_VALUE + ) # This switches to the 'user2', there is no value in that user's scope. When a request # that tries to evaluate Jinja expression to decrypt empty value is sent, a HTTP response # which has 4xx status code will be returned. - self.use_user(UserDB(name='user2')) + self.use_user(UserDB(name="user2")) resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - 'Failed to render parameter "encrypted_user_param": Referenced datastore ' - 'item "st2kv.user.secret" doesn\'t exist or it contains an empty string') + self.assertEqual( + resp.json["faultstring"], + 'Failed to render parameter "encrypted_user_param": Referenced datastore ' + 'item "st2kv.user.secret" doesn\'t exist or it contains an empty string', + ) # clean-up values that are registered at first for kvp in kvps: @@ -808,7 +844,9 @@ def test_template_encrypted_params(self): def test_template_encrypted_params_without_registering(self): resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'].index('Failed to render parameter'), 0) + self.assertEqual( + resp.json["faultstring"].index("Failed to render parameter"), 0 + ) def test_re_run_workflow_success(self): # Create a new execution @@ -818,26 +856,25 @@ def test_re_run_workflow_success(self): # Re-run created execution (tasks option for non workflow) data = {} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 201) # Get the trace - trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id) + trace = trace_service.get_trace_db_by_action_execution( + action_execution_id=execution_id + ) expected_context = { - 'user': 'stanley', - 'pack': 'starterpack', - 're-run': { - 'ref': execution_id - }, - 'trace_context': { - 'id_': str(trace.id) - } + "user": "stanley", + "pack": "starterpack", + "re-run": {"ref": execution_id}, + "trace_context": {"id_": str(trace.id)}, } - self.assertDictEqual(re_run_resp.json['context'], expected_context) + self.assertDictEqual(re_run_resp.json["context"], expected_context) def test_re_run_workflow_task_success(self): # Create a new execution @@ -846,28 +883,26 @@ def test_re_run_workflow_task_success(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (tasks option for non workflow) - data = {'tasks': ['x']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"tasks": ["x"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 201) # Get the trace - trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id) + trace = trace_service.get_trace_db_by_action_execution( + action_execution_id=execution_id + ) expected_context = { - 'pack': 'starterpack', - 'user': 'stanley', - 're-run': { - 'ref': execution_id, - 'tasks': data['tasks'] - }, - 'trace_context': { - 'id_': str(trace.id) - } + "pack": "starterpack", + "user": "stanley", + "re-run": {"ref": execution_id, "tasks": data["tasks"]}, + "trace_context": {"id_": str(trace.id)}, } - self.assertDictEqual(re_run_resp.json['context'], expected_context) + self.assertDictEqual(re_run_resp.json["context"], expected_context) def test_re_run_workflow_tasks_success(self): # Create a new execution @@ -876,28 +911,26 @@ def test_re_run_workflow_tasks_success(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (tasks option for non workflow) - data = {'tasks': ['x', 'y']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"tasks": ["x", "y"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 201) # Get the trace - trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id) + trace = trace_service.get_trace_db_by_action_execution( + action_execution_id=execution_id + ) expected_context = { - 'pack': 'starterpack', - 'user': 'stanley', - 're-run': { - 'ref': execution_id, - 'tasks': data['tasks'] - }, - 'trace_context': { - 'id_': str(trace.id) - } + "pack": "starterpack", + "user": "stanley", + "re-run": {"ref": execution_id, "tasks": data["tasks"]}, + "trace_context": {"id_": str(trace.id)}, } - self.assertDictEqual(re_run_resp.json['context'], expected_context) + self.assertDictEqual(re_run_resp.json["context"], expected_context) def test_re_run_workflow_tasks_reset_success(self): # Create a new execution @@ -906,29 +939,30 @@ def test_re_run_workflow_tasks_reset_success(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (tasks option for non workflow) - data = {'tasks': ['x', 'y'], 'reset': ['y']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"tasks": ["x", "y"], "reset": ["y"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 201) # Get the trace - trace = trace_service.get_trace_db_by_action_execution(action_execution_id=execution_id) + trace = trace_service.get_trace_db_by_action_execution( + action_execution_id=execution_id + ) expected_context = { - 'pack': 'starterpack', - 'user': 'stanley', - 're-run': { - 'ref': execution_id, - 'tasks': data['tasks'], - 'reset': data['reset'] + "pack": "starterpack", + "user": "stanley", + "re-run": { + "ref": execution_id, + "tasks": data["tasks"], + "reset": data["reset"], }, - 'trace_context': { - 'id_': str(trace.id) - } + "trace_context": {"id_": str(trace.id)}, } - self.assertDictEqual(re_run_resp.json['context'], expected_context) + self.assertDictEqual(re_run_resp.json["context"], expected_context) def test_re_run_failure_tasks_option_for_non_workflow(self): # Create a new execution @@ -937,14 +971,15 @@ def test_re_run_failure_tasks_option_for_non_workflow(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (tasks option for non workflow) - data = {'tasks': ['x']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"tasks": ["x"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) - expected_substring = 'only supported for Orquesta workflows' - self.assertIn(expected_substring, re_run_resp.json['faultstring']) + expected_substring = "only supported for Orquesta workflows" + self.assertIn(expected_substring, re_run_resp.json["faultstring"]) def test_re_run_workflow_failure_given_both_params_and_tasks(self): # Create a new execution @@ -953,13 +988,16 @@ def test_re_run_workflow_failure_given_both_params_and_tasks(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (override parameter with an invalid value) - data = {'parameters': {'a': 'xyz'}, 'tasks': ['x']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"parameters": {"a": "xyz"}, "tasks": ["x"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) - self.assertIn('not supported when re-running task(s) for a workflow', - re_run_resp.json['faultstring']) + self.assertIn( + "not supported when re-running task(s) for a workflow", + re_run_resp.json["faultstring"], + ) def test_re_run_workflow_failure_given_both_params_and_reset_tasks(self): # Create a new execution @@ -968,13 +1006,16 @@ def test_re_run_workflow_failure_given_both_params_and_reset_tasks(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (override parameter with an invalid value) - data = {'parameters': {'a': 'xyz'}, 'reset': ['x']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"parameters": {"a": "xyz"}, "reset": ["x"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) - self.assertIn('not supported when re-running task(s) for a workflow', - re_run_resp.json['faultstring']) + self.assertIn( + "not supported when re-running task(s) for a workflow", + re_run_resp.json["faultstring"], + ) def test_re_run_workflow_failure_invalid_reset_tasks(self): # Create a new execution @@ -983,13 +1024,16 @@ def test_re_run_workflow_failure_invalid_reset_tasks(self): execution_id = self._get_actionexecution_id(post_resp) # Re-run created execution (override parameter with an invalid value) - data = {'tasks': ['x'], 'reset': ['y']} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), - data, expect_errors=True) + data = {"tasks": ["x"], "reset": ["y"]} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data, expect_errors=True + ) self.assertEqual(re_run_resp.status_int, 400) - self.assertIn('tasks to reset does not match the tasks to rerun', - re_run_resp.json['faultstring']) + self.assertIn( + "tasks to reset does not match the tasks to rerun", + re_run_resp.json["faultstring"], + ) def test_re_run_secret_parameter(self): # Create a new execution @@ -999,96 +1043,100 @@ def test_re_run_secret_parameter(self): # Re-run created execution (no parameters overrides) data = {} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) execution_id = self._get_actionexecution_id(re_run_resp) - re_run_result = self._do_get_one(execution_id, - params={'show_secrets': True}, - expect_errors=True) - self.assertEqual(re_run_result.json['parameters'], LIVE_ACTION_1['parameters']) + re_run_result = self._do_get_one( + execution_id, params={"show_secrets": True}, expect_errors=True + ) + self.assertEqual(re_run_result.json["parameters"], LIVE_ACTION_1["parameters"]) # Re-run created execution (with parameters overrides) - data = {'parameters': {'a': 'val1', 'd': ANOTHER_SUPER_SECRET_PARAMETER}} - re_run_resp = self.app.post_json('/v1/executions/%s/re_run' % (execution_id), data) + data = {"parameters": {"a": "val1", "d": ANOTHER_SUPER_SECRET_PARAMETER}} + re_run_resp = self.app.post_json( + "/v1/executions/%s/re_run" % (execution_id), data + ) self.assertEqual(re_run_resp.status_int, 201) execution_id = self._get_actionexecution_id(re_run_resp) - re_run_result = self._do_get_one(execution_id, - params={'show_secrets': True}, - expect_errors=True) - self.assertEqual(re_run_result.json['parameters']['d'], data['parameters']['d']) + re_run_result = self._do_get_one( + execution_id, params={"show_secrets": True}, expect_errors=True + ) + self.assertEqual(re_run_result.json["parameters"]["d"], data["parameters"]["d"]) def test_put_status_and_result(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'succeeded', 'result': {'stdout': 'foobar'}} + updates = {"status": "succeeded", "result": {"stdout": "foobar"}} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'succeeded') - self.assertDictEqual(put_resp.json['result'], {'stdout': 'foobar'}) + self.assertEqual(put_resp.json["status"], "succeeded") + self.assertDictEqual(put_resp.json["result"], {"stdout": "foobar"}) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'succeeded') - self.assertDictEqual(get_resp.json['result'], {'stdout': 'foobar'}) + self.assertEqual(get_resp.json["status"], "succeeded") + self.assertDictEqual(get_resp.json["result"], {"stdout": "foobar"}) def test_put_bad_state(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'married'} + updates = {"status": "married"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('\'married\' is not one of', put_resp.json['faultstring']) + self.assertIn("'married' is not one of", put_resp.json["faultstring"]) def test_put_bad_result(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'result': 'foobar'} + updates = {"result": "foobar"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('is not of type \'object\'', put_resp.json['faultstring']) + self.assertIn("is not of type 'object'", put_resp.json["faultstring"]) def test_put_bad_property(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'abandoned', 'foo': 'bar'} + updates = {"status": "abandoned", "foo": "bar"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('Additional properties are not allowed', put_resp.json['faultstring']) + self.assertIn( + "Additional properties are not allowed", put_resp.json["faultstring"] + ) def test_put_status_to_completed_execution(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'succeeded', 'result': {'stdout': 'foobar'}} + updates = {"status": "succeeded", "result": {"stdout": "foobar"}} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'succeeded') - self.assertDictEqual(put_resp.json['result'], {'stdout': 'foobar'}) + self.assertEqual(put_resp.json["status"], "succeeded") + self.assertDictEqual(put_resp.json["result"], {"stdout": "foobar"}) - updates = {'status': 'abandoned'} + updates = {"status": "abandoned"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - @mock.patch.object( - LiveAction, 'get_by_id', - mock.MagicMock(return_value=None)) + @mock.patch.object(LiveAction, "get_by_id", mock.MagicMock(return_value=None)) def test_put_execution_missing_liveaction(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'succeeded', 'result': {'stdout': 'foobar'}} + updates = {"status": "succeeded", "result": {"stdout": "foobar"}} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 500) @@ -1098,19 +1146,19 @@ def test_put_pause_unsupported(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('it is not supported', put_resp.json['faultstring']) + self.assertIn("it is not supported", put_resp.json["faultstring"]) - updates = {'status': 'paused'} + updates = {"status": "paused"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('it is not supported', put_resp.json['faultstring']) + self.assertIn("it is not supported", put_resp.json["faultstring"]) def test_put_pause(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) @@ -1118,50 +1166,50 @@ def test_put_pause(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'running'} + updates = {"status": "running"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'pausing') - self.assertIsNone(put_resp.json.get('result')) + self.assertEqual(put_resp.json["status"], "pausing") + self.assertIsNone(put_resp.json.get("result")) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'pausing') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "pausing") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_put_pause_not_running(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) - self.assertEqual(post_resp.json['status'], 'requested') + self.assertEqual(post_resp.json["status"], "requested") execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('is not in a running state', put_resp.json['faultstring']) + self.assertIn("is not in a running state", put_resp.json["faultstring"]) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'requested') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "requested") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_put_pause_already_pausing(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) @@ -1169,44 +1217,46 @@ def test_put_pause_already_pausing(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'running'} + updates = {"status": "running"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'pausing') - self.assertIsNone(put_resp.json.get('result')) + self.assertEqual(put_resp.json["status"], "pausing") + self.assertIsNone(put_resp.json.get("result")) - with mock.patch.object(action_service, 'update_status', return_value=None) as mocked: - updates = {'status': 'pausing'} + with mock.patch.object( + action_service, "update_status", return_value=None + ) as mocked: + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'pausing') + self.assertEqual(put_resp.json["status"], "pausing") mocked.assert_not_called() get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'pausing') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "pausing") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_put_resume_unsupported(self): post_resp = self._do_post(LIVE_ACTION_1) self.assertEqual(post_resp.status_int, 201) execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'resuming'} + updates = {"status": "resuming"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - self.assertIn('it is not supported', put_resp.json['faultstring']) + self.assertIn("it is not supported", put_resp.json["faultstring"]) def test_put_resume(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) @@ -1214,44 +1264,46 @@ def test_put_resume(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'running'} + updates = {"status": "running"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'pausing') - self.assertIsNone(put_resp.json.get('result')) + self.assertEqual(put_resp.json["status"], "pausing") + self.assertIsNone(put_resp.json.get("result")) # Manually change the status to paused because only the runner pause method should # set the paused status directly to the liveaction and execution database objects. liveaction_id = self._get_liveaction_id(post_resp) liveaction = action_db_util.get_liveaction_by_id(liveaction_id) - action_service.update_status(liveaction, action_constants.LIVEACTION_STATUS_PAUSED) + action_service.update_status( + liveaction, action_constants.LIVEACTION_STATUS_PAUSED + ) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'paused') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "paused") + self.assertIsNone(get_resp.json.get("result")) - updates = {'status': 'resuming'} + updates = {"status": "resuming"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'resuming') - self.assertIsNone(put_resp.json.get('result')) + self.assertEqual(put_resp.json["status"], "resuming") + self.assertIsNone(put_resp.json.get("result")) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'resuming') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "resuming") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_put_resume_not_paused(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) @@ -1259,33 +1311,35 @@ def test_put_resume_not_paused(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'running'} + updates = {"status": "running"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") - updates = {'status': 'pausing'} + updates = {"status": "pausing"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'pausing') - self.assertIsNone(put_resp.json.get('result')) + self.assertEqual(put_resp.json["status"], "pausing") + self.assertIsNone(put_resp.json.get("result")) - updates = {'status': 'resuming'} + updates = {"status": "resuming"} put_resp = self._do_put(execution_id, updates, expect_errors=True) self.assertEqual(put_resp.status_int, 400) - expected_error_message = 'it is in "pausing" state and not in "paused" state' - self.assertIn(expected_error_message, put_resp.json['faultstring']) + expected_error_message = ( + 'it is in "pausing" state and not in "paused" state' + ) + self.assertIn(expected_error_message, put_resp.json["faultstring"]) get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'pausing') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "pausing") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_put_resume_already_running(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION_1["runner_type"]) try: post_resp = self._do_post(LIVE_ACTION_1) @@ -1293,24 +1347,26 @@ def test_put_resume_already_running(self): execution_id = self._get_actionexecution_id(post_resp) - updates = {'status': 'running'} + updates = {"status": "running"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") - with mock.patch.object(action_service, 'update_status', return_value=None) as mocked: - updates = {'status': 'resuming'} + with mock.patch.object( + action_service, "update_status", return_value=None + ) as mocked: + updates = {"status": "resuming"} put_resp = self._do_put(execution_id, updates) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['status'], 'running') + self.assertEqual(put_resp.json["status"], "running") mocked.assert_not_called() get_resp = self._do_get_one(execution_id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['status'], 'running') - self.assertIsNone(get_resp.json.get('result')) + self.assertEqual(get_resp.json["status"], "running") + self.assertIsNone(get_resp.json.get("result")) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION_1["runner_type"]) def test_get_inquiry_mask(self): """Ensure Inquiry responses are masked when retrieved via ActionExecution GET @@ -1327,194 +1383,213 @@ def test_get_inquiry_mask(self): self.assertEqual(get_resp.status_int, 200) resp = json.loads(get_resp.body) - self.assertEqual(resp['result']['response']['secondfactor'], MASKED_ATTRIBUTE_VALUE) + self.assertEqual( + resp["result"]["response"]["secondfactor"], MASKED_ATTRIBUTE_VALUE + ) post_resp = self._do_post(LIVE_ACTION_INQUIRY) actionexecution_id = self._get_actionexecution_id(post_resp) - get_resp = self._do_get_one(actionexecution_id, params={'show_secrets': True}) + get_resp = self._do_get_one(actionexecution_id, params={"show_secrets": True}) self.assertEqual(get_resp.status_int, 200) resp = json.loads(get_resp.body) - self.assertEqual(resp['result']['response']['secondfactor'], "supersecretvalue") + self.assertEqual(resp["result"]["response"]["secondfactor"], "supersecretvalue") def test_get_include_attributes_and_secret_parameters(self): # Verify that secret parameters are correctly masked when using ?include_attributes filter self._do_post(LIVE_ACTION_WITH_SECRET_PARAM) urls = [ - '/v1/actionexecutions?include_attributes=parameters', - '/v1/actionexecutions?include_attributes=parameters,action', - '/v1/actionexecutions?include_attributes=parameters,runner', - '/v1/actionexecutions?include_attributes=parameters,action,runner' + "/v1/actionexecutions?include_attributes=parameters", + "/v1/actionexecutions?include_attributes=parameters,action", + "/v1/actionexecutions?include_attributes=parameters,runner", + "/v1/actionexecutions?include_attributes=parameters,action,runner", ] for url in urls: - resp = self.app.get(url + '&limit=1') + resp = self.app.get(url + "&limit=1") - self.assertIn('parameters', resp.json[0]) - self.assertEqual(resp.json[0]['parameters']['a'], 'param a') - self.assertEqual(resp.json[0]['parameters']['d'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json[0]['parameters']['password'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json[0]['parameters']['hosts'], 'localhost') + self.assertIn("parameters", resp.json[0]) + self.assertEqual(resp.json[0]["parameters"]["a"], "param a") + self.assertEqual(resp.json[0]["parameters"]["d"], MASKED_ATTRIBUTE_VALUE) + self.assertEqual( + resp.json[0]["parameters"]["password"], MASKED_ATTRIBUTE_VALUE + ) + self.assertEqual(resp.json[0]["parameters"]["hosts"], "localhost") # With ?show_secrets=True urls = [ - ('/v1/actionexecutions?&include_attributes=parameters'), - ('/v1/actionexecutions?include_attributes=parameters,action'), - ('/v1/actionexecutions?include_attributes=parameters,runner'), - ('/v1/actionexecutions?include_attributes=parameters,action,runner') + ("/v1/actionexecutions?&include_attributes=parameters"), + ("/v1/actionexecutions?include_attributes=parameters,action"), + ("/v1/actionexecutions?include_attributes=parameters,runner"), + ("/v1/actionexecutions?include_attributes=parameters,action,runner"), ] for url in urls: - resp = self.app.get(url + '&limit=1&show_secrets=True') + resp = self.app.get(url + "&limit=1&show_secrets=True") - self.assertIn('parameters', resp.json[0]) - self.assertEqual(resp.json[0]['parameters']['a'], 'param a') - self.assertEqual(resp.json[0]['parameters']['d'], 'secretpassword1') - self.assertEqual(resp.json[0]['parameters']['password'], 'secretpassword2') - self.assertEqual(resp.json[0]['parameters']['hosts'], 'localhost') + self.assertIn("parameters", resp.json[0]) + self.assertEqual(resp.json[0]["parameters"]["a"], "param a") + self.assertEqual(resp.json[0]["parameters"]["d"], "secretpassword1") + self.assertEqual(resp.json[0]["parameters"]["password"], "secretpassword2") + self.assertEqual(resp.json[0]["parameters"]["hosts"], "localhost") # NOTE: We don't allow exclusion of attributes such as "action" and "runner" because # that would break secrets masking urls = [ - '/v1/actionexecutions?limit=1&exclude_attributes=action', - '/v1/actionexecutions?limit=1&exclude_attributes=runner', - '/v1/actionexecutions?limit=1&exclude_attributes=action,runner', + "/v1/actionexecutions?limit=1&exclude_attributes=action", + "/v1/actionexecutions?limit=1&exclude_attributes=runner", + "/v1/actionexecutions?limit=1&exclude_attributes=action,runner", ] for url in urls: - resp = self.app.get(url + '&limit=1', expect_errors=True) + resp = self.app.get(url + "&limit=1", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertTrue('Invalid or unsupported exclude attribute specified:' in - resp.json['faultstring']) + self.assertTrue( + "Invalid or unsupported exclude attribute specified:" + in resp.json["faultstring"] + ) def test_get_single_attribute_success(self): - exec_id = self.app.get('/v1/actionexecutions?limit=1').json[0]['id'] + exec_id = self.app.get("/v1/actionexecutions?limit=1").json[0]["id"] - resp = self.app.get('/v1/executions/%s/attribute/status' % (exec_id)) + resp = self.app.get("/v1/executions/%s/attribute/status" % (exec_id)) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, 'requested') + self.assertEqual(resp.json, "requested") - resp = self.app.get('/v1/executions/%s/attribute/result' % (exec_id)) + resp = self.app.get("/v1/executions/%s/attribute/result" % (exec_id)) self.assertEqual(resp.status_int, 200) self.assertEqual(resp.json, None) - resp = self.app.get('/v1/executions/%s/attribute/trigger_instance' % (exec_id)) + resp = self.app.get("/v1/executions/%s/attribute/trigger_instance" % (exec_id)) self.assertEqual(resp.status_int, 200) self.assertEqual(resp.json, None) data = {} - data['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - data['result'] = {'foo': 'bar'} + data["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + data["result"] = {"foo": "bar"} - resp = self.app.put_json('/v1/executions/%s' % (exec_id), data) + resp = self.app.put_json("/v1/executions/%s" % (exec_id), data) self.assertEqual(resp.status_int, 200) - resp = self.app.get('/v1/executions/%s/attribute/result' % (exec_id)) + resp = self.app.get("/v1/executions/%s/attribute/result" % (exec_id)) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, data['result']) + self.assertEqual(resp.json, data["result"]) def test_get_single_attribute_failure_invalid_attribute(self): - exec_id = self.app.get('/v1/actionexecutions?limit=1').json[0]['id'] + exec_id = self.app.get("/v1/actionexecutions?limit=1").json[0]["id"] - resp = self.app.get('/v1/executions/%s/attribute/start_timestamp' % (exec_id), - expect_errors=True) + resp = self.app.get( + "/v1/executions/%s/attribute/start_timestamp" % (exec_id), + expect_errors=True, + ) self.assertEqual(resp.status_int, 400) - self.assertTrue('Invalid attribute "start_timestamp" specified.' in - resp.json['faultstring']) + self.assertTrue( + 'Invalid attribute "start_timestamp" specified.' in resp.json["faultstring"] + ) def test_get_single_include_attributes_and_secret_parameters(self): # Verify that secret parameters are correctly masked when using ?include_attributes filter self._do_post(LIVE_ACTION_WITH_SECRET_PARAM) - exec_id = self.app.get('/v1/actionexecutions?limit=1').json[0]['id'] + exec_id = self.app.get("/v1/actionexecutions?limit=1").json[0]["id"] # FYI, the response always contains the 'id' parameter urls = [ { - 'url': '/v1/executions/%s?include_attributes=parameters' % (exec_id), - 'expected_parameters': ['id', 'parameters'], + "url": "/v1/executions/%s?include_attributes=parameters" % (exec_id), + "expected_parameters": ["id", "parameters"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,action' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'action'], + "url": "/v1/executions/%s?include_attributes=parameters,action" + % (exec_id), + "expected_parameters": ["id", "parameters", "action"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,runner' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'runner'], + "url": "/v1/executions/%s?include_attributes=parameters,runner" + % (exec_id), + "expected_parameters": ["id", "parameters", "runner"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,action,runner' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'action', 'runner'], - } + "url": "/v1/executions/%s?include_attributes=parameters,action,runner" + % (exec_id), + "expected_parameters": ["id", "parameters", "action", "runner"], + }, ] for item in urls: - url = item['url'] + url = item["url"] resp = self.app.get(url) - self.assertIn('parameters', resp.json) - self.assertEqual(resp.json['parameters']['a'], 'param a') - self.assertEqual(resp.json['parameters']['d'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json['parameters']['password'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(resp.json['parameters']['hosts'], 'localhost') + self.assertIn("parameters", resp.json) + self.assertEqual(resp.json["parameters"]["a"], "param a") + self.assertEqual(resp.json["parameters"]["d"], MASKED_ATTRIBUTE_VALUE) + self.assertEqual( + resp.json["parameters"]["password"], MASKED_ATTRIBUTE_VALUE + ) + self.assertEqual(resp.json["parameters"]["hosts"], "localhost") # ensure that the response has only the keys we epect, no more, no less resp_keys = set(resp.json.keys()) - expected_params = set(item['expected_parameters']) + expected_params = set(item["expected_parameters"]) diff = resp_keys.symmetric_difference(expected_params) self.assertEqual(diff, set()) # With ?show_secrets=True urls = [ { - 'url': '/v1/executions/%s?&include_attributes=parameters' % (exec_id), - 'expected_parameters': ['id', 'parameters'], + "url": "/v1/executions/%s?&include_attributes=parameters" % (exec_id), + "expected_parameters": ["id", "parameters"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,action' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'action'], + "url": "/v1/executions/%s?include_attributes=parameters,action" + % (exec_id), + "expected_parameters": ["id", "parameters", "action"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,runner' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'runner'], + "url": "/v1/executions/%s?include_attributes=parameters,runner" + % (exec_id), + "expected_parameters": ["id", "parameters", "runner"], }, { - 'url': '/v1/executions/%s?include_attributes=parameters,action,runner' % (exec_id), - 'expected_parameters': ['id', 'parameters', 'action', 'runner'], + "url": "/v1/executions/%s?include_attributes=parameters,action,runner" + % (exec_id), + "expected_parameters": ["id", "parameters", "action", "runner"], }, ] for item in urls: - url = item['url'] - resp = self.app.get(url + '&show_secrets=True') + url = item["url"] + resp = self.app.get(url + "&show_secrets=True") - self.assertIn('parameters', resp.json) - self.assertEqual(resp.json['parameters']['a'], 'param a') - self.assertEqual(resp.json['parameters']['d'], 'secretpassword1') - self.assertEqual(resp.json['parameters']['password'], 'secretpassword2') - self.assertEqual(resp.json['parameters']['hosts'], 'localhost') + self.assertIn("parameters", resp.json) + self.assertEqual(resp.json["parameters"]["a"], "param a") + self.assertEqual(resp.json["parameters"]["d"], "secretpassword1") + self.assertEqual(resp.json["parameters"]["password"], "secretpassword2") + self.assertEqual(resp.json["parameters"]["hosts"], "localhost") # ensure that the response has only the keys we epect, no more, no less resp_keys = set(resp.json.keys()) - expected_params = set(item['expected_parameters']) + expected_params = set(item["expected_parameters"]) diff = resp_keys.symmetric_difference(expected_params) self.assertEqual(diff, set()) # NOTE: We don't allow exclusion of attributes such as "action" and "runner" because # that would break secrets masking urls = [ - '/v1/executions/%s?limit=1&exclude_attributes=action', - '/v1/executions/%s?limit=1&exclude_attributes=runner', - '/v1/executions/%s?limit=1&exclude_attributes=action,runner', + "/v1/executions/%s?limit=1&exclude_attributes=action", + "/v1/executions/%s?limit=1&exclude_attributes=runner", + "/v1/executions/%s?limit=1&exclude_attributes=action,runner", ] for url in urls: resp = self.app.get(url, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertTrue('Invalid or unsupported exclude attribute specified:' in - resp.json['faultstring']) + self.assertTrue( + "Invalid or unsupported exclude attribute specified:" + in resp.json["faultstring"] + ) def _insert_mock_models(self): execution_1_id = self._get_actionexecution_id(self._do_post(LIVE_ACTION_1)) @@ -1522,37 +1597,44 @@ def _insert_mock_models(self): return [execution_1_id, execution_2_id] -class ActionExecutionOutputControllerTestCase(BaseActionExecutionControllerTestCase, - FunctionalTest): +class ActionExecutionOutputControllerTestCase( + BaseActionExecutionControllerTestCase, FunctionalTest +): def test_get_output_id_last_no_executions_in_the_database(self): ActionExecution.query().delete() - resp = self.app.get('/v1/executions/last/output', expect_errors=True) + resp = self.app.get("/v1/executions/last/output", expect_errors=True) self.assertEqual(resp.status_int, http_client.BAD_REQUEST) - self.assertEqual(resp.json['faultstring'], 'No executions found in the database') + self.assertEqual( + resp.json["faultstring"], "No executions found in the database" + ) def test_get_output_running_execution(self): # Only the output produced so far should be returned # Test the execution output API endpoint for execution which is running (blocking) status = action_constants.LIVEACTION_STATUS_RUNNING timestamp = date_utils.get_datetime_utc_now() - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) action_execution_db = ActionExecution.add_or_update(action_execution_db) - output_params = dict(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout before start\n') + output_params = dict( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout before start\n", + ) def insert_mock_data(data): - output_params['data'] = data + output_params["data"] = data output_db = ActionExecutionOutputDB(**output_params) ActionExecutionOutput.add_or_update(output_db) @@ -1561,45 +1643,51 @@ def insert_mock_data(data): ActionExecutionOutput.add_or_update(output_db, publish=False) # Retrieve data while execution is running - data produced so far should be retrieved - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) - lines = resp.text.strip().split('\n') + lines = resp.text.strip().split("\n") lines = [line for line in lines if line.strip()] self.assertEqual(len(lines), 1) - self.assertEqual(lines[0], 'stdout before start') + self.assertEqual(lines[0], "stdout before start") # Insert more data - insert_mock_data('stdout mid 1\n') + insert_mock_data("stdout mid 1\n") # Retrieve data while execution is running - data produced so far should be retrieved - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) - lines = resp.text.strip().split('\n') + lines = resp.text.strip().split("\n") lines = [line for line in lines if line.strip()] self.assertEqual(len(lines), 2) - self.assertEqual(lines[0], 'stdout before start') - self.assertEqual(lines[1], 'stdout mid 1') + self.assertEqual(lines[0], "stdout before start") + self.assertEqual(lines[1], "stdout mid 1") # Insert more data - insert_mock_data('stdout pre finish 1\n') + insert_mock_data("stdout pre finish 1\n") # Transition execution to completed state action_execution_db.status = action_constants.LIVEACTION_STATUS_SUCCEEDED action_execution_db = ActionExecution.add_or_update(action_execution_db) # Execution has finished - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) - lines = resp.text.strip().split('\n') + lines = resp.text.strip().split("\n") lines = [line for line in lines if line.strip()] self.assertEqual(len(lines), 3) - self.assertEqual(lines[0], 'stdout before start') - self.assertEqual(lines[1], 'stdout mid 1') - self.assertEqual(lines[2], 'stdout pre finish 1') + self.assertEqual(lines[0], "stdout before start") + self.assertEqual(lines[1], "stdout mid 1") + self.assertEqual(lines[2], "stdout pre finish 1") def test_get_output_finished_execution(self): # Test the execution output API endpoint for execution which has finished @@ -1607,42 +1695,50 @@ def test_get_output_finished_execution(self): # Insert mock execution and output objects status = action_constants.LIVEACTION_STATUS_SUCCEEDED timestamp = date_utils.get_datetime_utc_now() - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) action_execution_db = ActionExecution.add_or_update(action_execution_db) for i in range(1, 6): - stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout %s\n' % (i)) + stdout_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout %s\n" % (i), + ) ActionExecutionOutput.add_or_update(stdout_db) for i in range(10, 15): - stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stderr', - data='stderr %s\n' % (i)) + stderr_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stderr", + data="stderr %s\n" % (i), + ) ActionExecutionOutput.add_or_update(stderr_db) - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) - lines = resp.text.strip().split('\n') + lines = resp.text.strip().split("\n") self.assertEqual(len(lines), 10) - self.assertEqual(lines[0], 'stdout 1') - self.assertEqual(lines[9], 'stderr 14') + self.assertEqual(lines[0], "stdout 1") + self.assertEqual(lines[9], "stderr 14") # Verify "last" short-hand id works - resp = self.app.get('/v1/executions/last/output', expect_errors=False) + resp = self.app.get("/v1/executions/last/output", expect_errors=False) self.assertEqual(resp.status_int, 200) - lines = resp.text.strip().split('\n') + lines = resp.text.strip().split("\n") self.assertEqual(len(lines), 10) diff --git a/st2api/tests/unit/controllers/v1/test_executions_auth.py b/st2api/tests/unit/controllers/v1/test_executions_auth.py index e408d053dca..f1045a7d548 100644 --- a/st2api/tests/unit/controllers/v1/test_executions_auth.py +++ b/st2api/tests/unit/controllers/v1/test_executions_auth.py @@ -44,61 +44,48 @@ ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'entry_point': '/tmp/test/action1.sh', - 'pack': 'sixpack', - 'runner_type': 'remote-shell-cmd', - 'parameters': { - 'a': { - 'type': 'string', - 'default': 'abc' - }, - 'b': { - 'type': 'number', - 'default': 123 - }, - 'c': { - 'type': 'number', - 'default': 123, - 'immutable': True - }, - 'd': { - 'type': 'string', - 'secret': True - } - } + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "entry_point": "/tmp/test/action1.sh", + "pack": "sixpack", + "runner_type": "remote-shell-cmd", + "parameters": { + "a": {"type": "string", "default": "abc"}, + "b": {"type": "number", "default": 123}, + "c": {"type": "number", "default": 123, "immutable": True}, + "d": {"type": "string", "secret": True}, + }, } ACTION_DEFAULT_ENCRYPT = { - 'name': 'st2.dummy.default_encrypted_value', - 'description': 'An action that uses a jinja template with decrypt_kv filter ' - 'in default parameter', - 'enabled': True, - 'pack': 'starterpack', - 'runner_type': 'local-shell-cmd', - 'parameters': { - 'encrypted_param': { - 'type': 'string', - 'default': '{{ st2kv.system.secret | decrypt_kv }}' + "name": "st2.dummy.default_encrypted_value", + "description": "An action that uses a jinja template with decrypt_kv filter " + "in default parameter", + "enabled": True, + "pack": "starterpack", + "runner_type": "local-shell-cmd", + "parameters": { + "encrypted_param": { + "type": "string", + "default": "{{ st2kv.system.secret | decrypt_kv }}", }, - 'encrypted_user_param': { - 'type': 'string', - 'default': '{{ st2kv.user.secret | decrypt_kv }}' - } - } + "encrypted_user_param": { + "type": "string", + "default": "{{ st2kv.user.secret | decrypt_kv }}", + }, + }, } LIVE_ACTION_1 = { - 'action': 'sixpack.st2.dummy.action1', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'uname -a', - 'd': SUPER_SECRET_PARAMETER - } + "action": "sixpack.st2.dummy.action1", + "parameters": { + "hosts": "localhost", + "cmd": "uname -a", + "d": SUPER_SECRET_PARAMETER, + }, } LIVE_ACTION_DEFAULT_ENCRYPT = { - 'action': 'starterpack.st2.dummy.default_encrypted_value', + "action": "starterpack.st2.dummy.default_encrypted_value", } # NOTE: We use a longer expiry time because this variable is initialized on module import (aka @@ -107,19 +94,23 @@ # by that time and the tests would fail. NOW = date_utils.get_datetime_utc_now() EXPIRY = NOW + datetime.timedelta(seconds=1000) -SYS_TOKEN = TokenDB(id=bson.ObjectId(), user='system', token=uuid.uuid4().hex, expiry=EXPIRY) -USR_TOKEN = TokenDB(id=bson.ObjectId(), user='tokenuser', token=uuid.uuid4().hex, expiry=EXPIRY) +SYS_TOKEN = TokenDB( + id=bson.ObjectId(), user="system", token=uuid.uuid4().hex, expiry=EXPIRY +) +USR_TOKEN = TokenDB( + id=bson.ObjectId(), user="tokenuser", token=uuid.uuid4().hex, expiry=EXPIRY +) -FIXTURES_PACK = 'generic' -FIXTURES = { - 'users': ['system_user.yaml', 'token_user.yaml'] -} +FIXTURES_PACK = "generic" +FIXTURES = {"users": ["system_user.yaml", "token_user.yaml"]} # These parameters are used for the tests of getting value from datastore and decrypting it at # Jinja expression in a action metadata definition. -TEST_USER = UserDB(name='user1') -TEST_TOKEN = TokenDB(id=bson.ObjectId(), user=TEST_USER, token=uuid.uuid4().hex, expiry=EXPIRY) -TEST_APIKEY = ApiKeyDB(user=TEST_USER, key_hash='secret_key', enabled=True) +TEST_USER = UserDB(name="user1") +TEST_TOKEN = TokenDB( + id=bson.ObjectId(), user=TEST_USER, token=uuid.uuid4().hex, expiry=EXPIRY +) +TEST_APIKEY = ApiKeyDB(user=TEST_USER, key_hash="secret_key", enabled=True) def mock_get_token(*args, **kwargs): @@ -128,50 +119,69 @@ def mock_get_token(*args, **kwargs): return USR_TOKEN -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ActionExecutionControllerTestCaseAuthEnabled(FunctionalTest): enable_auth = True @classmethod + @mock.patch.object(Token, "get", mock.MagicMock(side_effect=mock_get_token)) + @mock.patch.object(User, "get_by_name", mock.MagicMock(side_effect=UserDB)) @mock.patch.object( - Token, 'get', - mock.MagicMock(side_effect=mock_get_token)) - @mock.patch.object(User, 'get_by_name', mock.MagicMock(side_effect=UserDB)) - @mock.patch.object(action_validator, 'validate_action', mock.MagicMock( - return_value=True)) + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def setUpClass(cls): super(ActionExecutionControllerTestCaseAuthEnabled, cls).setUpClass() cls.action = copy.deepcopy(ACTION_1) - headers = {'content-type': 'application/json', 'X-Auth-Token': str(SYS_TOKEN.token)} - post_resp = cls.app.post_json('/v1/actions', cls.action, headers=headers) - cls.action['id'] = post_resp.json['id'] + headers = { + "content-type": "application/json", + "X-Auth-Token": str(SYS_TOKEN.token), + } + post_resp = cls.app.post_json("/v1/actions", cls.action, headers=headers) + cls.action["id"] = post_resp.json["id"] cls.action_encrypt = copy.deepcopy(ACTION_DEFAULT_ENCRYPT) - post_resp = cls.app.post_json('/v1/actions', cls.action_encrypt, headers=headers) - cls.action_encrypt['id'] = post_resp.json['id'] + post_resp = cls.app.post_json( + "/v1/actions", cls.action_encrypt, headers=headers + ) + cls.action_encrypt["id"] = post_resp.json["id"] - FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=FIXTURES) + FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=FIXTURES + ) # register datastore values which are used in this tests KeyValuePairAPI._setup_crypto() register_items = [ - {'name': 'secret', 'secret': True, - 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'foo')}, - {'name': 'user1:secret', 'secret': True, 'scope': FULL_USER_SCOPE, - 'value': crypto_utils.symmetric_encrypt(KeyValuePairAPI.crypto_key, 'bar')}, + { + "name": "secret", + "secret": True, + "value": crypto_utils.symmetric_encrypt( + KeyValuePairAPI.crypto_key, "foo" + ), + }, + { + "name": "user1:secret", + "secret": True, + "scope": FULL_USER_SCOPE, + "value": crypto_utils.symmetric_encrypt( + KeyValuePairAPI.crypto_key, "bar" + ), + }, + ] + cls.kvps = [ + KeyValuePair.add_or_update(KeyValuePairDB(**x)) for x in register_items ] - cls.kvps = [KeyValuePair.add_or_update(KeyValuePairDB(**x)) for x in register_items] @classmethod - @mock.patch.object( - Token, 'get', - mock.MagicMock(side_effect=mock_get_token)) + @mock.patch.object(Token, "get", mock.MagicMock(side_effect=mock_get_token)) def tearDownClass(cls): - headers = {'content-type': 'application/json', 'X-Auth-Token': str(SYS_TOKEN.token)} - cls.app.delete('/v1/actions/%s' % cls.action['id'], headers=headers) - cls.app.delete('/v1/actions/%s' % cls.action_encrypt['id'], headers=headers) + headers = { + "content-type": "application/json", + "X-Auth-Token": str(SYS_TOKEN.token), + } + cls.app.delete("/v1/actions/%s" % cls.action["id"], headers=headers) + cls.app.delete("/v1/actions/%s" % cls.action_encrypt["id"], headers=headers) # unregister key-value pairs for tests [KeyValuePair.delete(x) for x in cls.kvps] @@ -179,49 +189,53 @@ def tearDownClass(cls): super(ActionExecutionControllerTestCaseAuthEnabled, cls).tearDownClass() def _do_post(self, liveaction, *args, **kwargs): - return self.app.post_json('/v1/executions', liveaction, *args, **kwargs) + return self.app.post_json("/v1/executions", liveaction, *args, **kwargs) - @mock.patch.object( - Token, 'get', - mock.MagicMock(side_effect=mock_get_token)) + @mock.patch.object(Token, "get", mock.MagicMock(side_effect=mock_get_token)) def test_post_with_st2_context_in_headers(self): - headers = {'content-type': 'application/json', 'X-Auth-Token': str(USR_TOKEN.token)} + headers = { + "content-type": "application/json", + "X-Auth-Token": str(USR_TOKEN.token), + } resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers) self.assertEqual(resp.status_int, 201) - token_user = resp.json['context']['user'] - self.assertEqual(token_user, 'tokenuser') - context = {'parent': {'execution_id': str(resp.json['id']), 'user': token_user}} - headers = {'content-type': 'application/json', - 'X-Auth-Token': str(SYS_TOKEN.token), - 'st2-context': json.dumps(context)} + token_user = resp.json["context"]["user"] + self.assertEqual(token_user, "tokenuser") + context = {"parent": {"execution_id": str(resp.json["id"]), "user": token_user}} + headers = { + "content-type": "application/json", + "X-Auth-Token": str(SYS_TOKEN.token), + "st2-context": json.dumps(context), + } resp = self._do_post(copy.deepcopy(LIVE_ACTION_1), headers=headers) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['context']['user'], 'tokenuser') - self.assertEqual(resp.json['context']['parent'], context['parent']) + self.assertEqual(resp.json["context"]["user"], "tokenuser") + self.assertEqual(resp.json["context"]["parent"], context["parent"]) - @mock.patch.object(ApiKey, 'get', mock.Mock(return_value=TEST_APIKEY)) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=TEST_USER)) + @mock.patch.object(ApiKey, "get", mock.Mock(return_value=TEST_APIKEY)) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=TEST_USER)) def test_template_encrypted_params_with_apikey(self): - resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, headers={ - 'St2-Api-key': 'secret_key' - }) + resp = self._do_post( + LIVE_ACTION_DEFAULT_ENCRYPT, headers={"St2-Api-key": "secret_key"} + ) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo') - self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'bar') + self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo") + self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "bar") - @mock.patch.object(Token, 'get', mock.Mock(return_value=TEST_TOKEN)) - @mock.patch.object(User, 'get_by_name', mock.Mock(return_value=TEST_USER)) + @mock.patch.object(Token, "get", mock.Mock(return_value=TEST_TOKEN)) + @mock.patch.object(User, "get_by_name", mock.Mock(return_value=TEST_USER)) def test_template_encrypted_params_with_access_token(self): - resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, headers={ - 'X-Auth-Token': str(TEST_TOKEN.token) - }) + resp = self._do_post( + LIVE_ACTION_DEFAULT_ENCRYPT, headers={"X-Auth-Token": str(TEST_TOKEN.token)} + ) self.assertEqual(resp.status_int, 201) - self.assertEqual(resp.json['parameters']['encrypted_param'], 'foo') - self.assertEqual(resp.json['parameters']['encrypted_user_param'], 'bar') + self.assertEqual(resp.json["parameters"]["encrypted_param"], "foo") + self.assertEqual(resp.json["parameters"]["encrypted_user_param"], "bar") def test_template_encrypted_params_without_auth(self): resp = self._do_post(LIVE_ACTION_DEFAULT_ENCRYPT, expect_errors=True) self.assertEqual(resp.status_int, 401) - self.assertEqual(resp.json['faultstring'], - 'Unauthorized - One of Token or API key required.') + self.assertEqual( + resp.json["faultstring"], "Unauthorized - One of Token or API key required." + ) diff --git a/st2api/tests/unit/controllers/v1/test_executions_descendants.py b/st2api/tests/unit/controllers/v1/test_executions_descendants.py index 1afbcdde2f1..945e03feeb9 100644 --- a/st2api/tests/unit/controllers/v1/test_executions_descendants.py +++ b/st2api/tests/unit/controllers/v1/test_executions_descendants.py @@ -19,64 +19,85 @@ from st2tests.api import FunctionalTest -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } class ActionExecutionControllerTestCaseDescendantsTest(FunctionalTest): - @classmethod def setUpClass(cls): super(ActionExecutionControllerTestCaseDescendantsTest, cls).setUpClass() - cls.MODELS = FixturesLoader().save_fixtures_to_db(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) + cls.MODELS = FixturesLoader().save_fixtures_to_db( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) def test_get_all_descendants(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - resp = self.app.get('/v1/executions/%s/children' % str(root_execution.id)) + root_execution = self.MODELS["executions"]["root_execution.yaml"] + resp = self.app.get("/v1/executions/%s/children" % str(root_execution.id)) self.assertEqual(resp.status_int, 200) - all_descendants_ids = [descendant['id'] for descendant in resp.json] + all_descendants_ids = [descendant["id"] for descendant in resp.json] all_descendants_ids.sort() # everything except the root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.id != root_execution.id] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.id != root_execution.id + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) def test_get_all_descendants_depth_neg_1(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - resp = self.app.get('/v1/executions/%s/children?depth=-1' % str(root_execution.id)) + root_execution = self.MODELS["executions"]["root_execution.yaml"] + resp = self.app.get( + "/v1/executions/%s/children?depth=-1" % str(root_execution.id) + ) self.assertEqual(resp.status_int, 200) - all_descendants_ids = [descendant['id'] for descendant in resp.json] + all_descendants_ids = [descendant["id"] for descendant in resp.json] all_descendants_ids.sort() # everything except the root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.id != root_execution.id] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.id != root_execution.id + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) def test_get_1_level_descendants(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - resp = self.app.get('/v1/executions/%s/children?depth=1' % str(root_execution.id)) + root_execution = self.MODELS["executions"]["root_execution.yaml"] + resp = self.app.get( + "/v1/executions/%s/children?depth=1" % str(root_execution.id) + ) self.assertEqual(resp.status_int, 200) - all_descendants_ids = [descendant['id'] for descendant in resp.json] + all_descendants_ids = [descendant["id"] for descendant in resp.json] all_descendants_ids.sort() # All children of root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.parent == str(root_execution.id)] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.parent == str(root_execution.id) + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) diff --git a/st2api/tests/unit/controllers/v1/test_executions_filters.py b/st2api/tests/unit/controllers/v1/test_executions_filters.py index e33e8bf87d5..af451ca5192 100644 --- a/st2api/tests/unit/controllers/v1/test_executions_filters.py +++ b/st2api/tests/unit/controllers/v1/test_executions_filters.py @@ -22,6 +22,7 @@ from six.moves import http_client import st2tests.config as tests_config + tests_config.parse_args() from st2tests.api import FunctionalTest @@ -36,7 +37,6 @@ class TestActionExecutionFilters(FunctionalTest): - @classmethod def testDownClass(cls): pass @@ -52,29 +52,33 @@ def setUpClass(cls): cls.start_timestamps = [] cls.fake_types = [ { - 'trigger': copy.deepcopy(fixture.ARTIFACTS['trigger']), - 'trigger_type': copy.deepcopy(fixture.ARTIFACTS['trigger_type']), - 'trigger_instance': copy.deepcopy(fixture.ARTIFACTS['trigger_instance']), - 'rule': copy.deepcopy(fixture.ARTIFACTS['rule']), - 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['chain']), - 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['action-chain']), - 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['workflow']), - 'context': copy.deepcopy(fixture.ARTIFACTS['context']), - 'children': [] + "trigger": copy.deepcopy(fixture.ARTIFACTS["trigger"]), + "trigger_type": copy.deepcopy(fixture.ARTIFACTS["trigger_type"]), + "trigger_instance": copy.deepcopy( + fixture.ARTIFACTS["trigger_instance"] + ), + "rule": copy.deepcopy(fixture.ARTIFACTS["rule"]), + "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["chain"]), + "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["action-chain"]), + "liveaction": copy.deepcopy( + fixture.ARTIFACTS["liveactions"]["workflow"] + ), + "context": copy.deepcopy(fixture.ARTIFACTS["context"]), + "children": [], }, { - 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['local']), - 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['run-local']), - 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['task1']) - } + "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["local"]), + "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["run-local"]), + "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["task1"]), + }, ] def assign_parent(child): - candidates = [v for k, v in cls.refs.items() if v.action['name'] == 'chain'] + candidates = [v for k, v in cls.refs.items() if v.action["name"] == "chain"] if candidates: parent = random.choice(candidates) - child['parent'] = str(parent.id) - parent.children.append(child['id']) + child["parent"] = str(parent.id) + parent.children.append(child["id"]) cls.refs[str(parent.id)] = ActionExecution.add_or_update(parent) for i in range(cls.num_records): @@ -82,12 +86,12 @@ def assign_parent(child): timestamp = cls.dt_base + datetime.timedelta(seconds=i) fake_type = random.choice(cls.fake_types) data = copy.deepcopy(fake_type) - data['id'] = obj_id - data['start_timestamp'] = isotime.format(timestamp, offset=False) - data['end_timestamp'] = isotime.format(timestamp, offset=False) - data['status'] = data['liveaction']['status'] - data['result'] = data['liveaction']['result'] - if fake_type['action']['name'] == 'local' and random.choice([True, False]): + data["id"] = obj_id + data["start_timestamp"] = isotime.format(timestamp, offset=False) + data["end_timestamp"] = isotime.format(timestamp, offset=False) + data["status"] = data["liveaction"]["status"] + data["result"] = data["liveaction"]["result"] + if fake_type["action"]["name"] == "local" and random.choice([True, False]): assign_parent(data) wb_obj = ActionExecutionAPI(**data) db_obj = ActionExecutionAPI.to_model(wb_obj) @@ -97,154 +101,185 @@ def assign_parent(child): cls.start_timestamps = sorted(cls.start_timestamps) def test_get_all(self): - response = self.app.get('/v1/executions') + response = self.app.get("/v1/executions") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), self.num_records) - self.assertEqual(response.headers['X-Total-Count'], str(self.num_records)) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Total-Count"], str(self.num_records)) + ids = [item["id"] for item in response.json] self.assertListEqual(sorted(ids), sorted(self.refs.keys())) def test_get_all_exclude_attributes(self): # No attributes excluded - response = self.app.get('/v1/executions?action=executions.local&limit=1') + response = self.app.get("/v1/executions?action=executions.local&limit=1") self.assertEqual(response.status_int, 200) - self.assertIn('result', response.json[0]) + self.assertIn("result", response.json[0]) # Exclude "result" attribute - path = '/v1/executions?action=executions.local&limit=1&exclude_attributes=result' + path = ( + "/v1/executions?action=executions.local&limit=1&exclude_attributes=result" + ) response = self.app.get(path) self.assertEqual(response.status_int, 200) - self.assertNotIn('result', response.json[0]) + self.assertNotIn("result", response.json[0]) def test_get_one(self): obj_id = random.choice(list(self.refs.keys())) - response = self.app.get('/v1/executions/%s' % obj_id) + response = self.app.get("/v1/executions/%s" % obj_id) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, dict) record = response.json fake_record = ActionExecutionAPI.from_model(self.refs[obj_id]) - self.assertEqual(record['id'], obj_id) - self.assertDictEqual(record['action'], fake_record.action) - self.assertDictEqual(record['runner'], fake_record.runner) - self.assertDictEqual(record['liveaction'], fake_record.liveaction) + self.assertEqual(record["id"], obj_id) + self.assertDictEqual(record["action"], fake_record.action) + self.assertDictEqual(record["runner"], fake_record.runner) + self.assertDictEqual(record["liveaction"], fake_record.liveaction) def test_get_one_failed(self): - response = self.app.get('/v1/executions/%s' % str(bson.ObjectId()), - expect_errors=True) + response = self.app.get( + "/v1/executions/%s" % str(bson.ObjectId()), expect_errors=True + ) self.assertEqual(response.status_int, http_client.NOT_FOUND) def test_limit(self): limit = 10 - refs = [k for k, v in six.iteritems(self.refs) if v.action['name'] == 'chain'] - response = self.app.get('/v1/executions?action=executions.chain&limit=%s' % - limit) + refs = [k for k, v in six.iteritems(self.refs) if v.action["name"] == "chain"] + response = self.app.get( + "/v1/executions?action=executions.chain&limit=%s" % limit + ) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), limit) - self.assertEqual(response.headers['X-Limit'], str(limit)) - self.assertEqual(response.headers['X-Total-Count'], str(len(refs)), response.json) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Limit"], str(limit)) + self.assertEqual( + response.headers["X-Total-Count"], str(len(refs)), response.json + ) + ids = [item["id"] for item in response.json] self.assertListEqual(list(set(ids) - set(refs)), []) def test_limit_minus_one(self): limit = -1 - refs = [k for k, v in six.iteritems(self.refs) if v.action['name'] == 'chain'] - response = self.app.get('/v1/executions?action=executions.chain&limit=%s' % limit) + refs = [k for k, v in six.iteritems(self.refs) if v.action["name"] == "chain"] + response = self.app.get( + "/v1/executions?action=executions.chain&limit=%s" % limit + ) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), len(refs)) - self.assertEqual(response.headers['X-Total-Count'], str(len(refs)), response.json) - ids = [item['id'] for item in response.json] + self.assertEqual( + response.headers["X-Total-Count"], str(len(refs)), response.json + ) + ids = [item["id"] for item in response.json] self.assertListEqual(list(set(ids) - set(refs)), []) def test_limit_negative(self): limit = -22 - response = self.app.get('/v1/executions?action=executions.chain&limit=%s' % limit, - expect_errors=True) + response = self.app.get( + "/v1/executions?action=executions.chain&limit=%s" % limit, + expect_errors=True, + ) self.assertEqual(response.status_int, 400) - self.assertEqual(response.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + response.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_query(self): - refs = [k for k, v in six.iteritems(self.refs) if v.action['name'] == 'chain'] - response = self.app.get('/v1/executions?action=executions.chain') + refs = [k for k, v in six.iteritems(self.refs) if v.action["name"] == "chain"] + response = self.app.get("/v1/executions?action=executions.chain") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), len(refs)) - self.assertEqual(response.headers['X-Total-Count'], str(len(refs))) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Total-Count"], str(len(refs))) + ids = [item["id"] for item in response.json] self.assertListEqual(sorted(ids), sorted(refs)) def test_filters(self): - excludes = ['parent', 'timestamp', 'action', 'liveaction', 'timestamp_gt', - 'timestamp_lt', 'status'] + excludes = [ + "parent", + "timestamp", + "action", + "liveaction", + "timestamp_gt", + "timestamp_lt", + "status", + ] for param, field in six.iteritems(ActionExecutionsController.supported_filters): if param in excludes: continue value = self.fake_types[0] - for item in field.split('.'): + for item in field.split("."): value = value[item] - response = self.app.get('/v1/executions?%s=%s' % (param, value)) + response = self.app.get("/v1/executions?%s=%s" % (param, value)) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertGreater(len(response.json), 0) - self.assertGreater(int(response.headers['X-Total-Count']), 0) + self.assertGreater(int(response.headers["X-Total-Count"]), 0) def test_advanced_filters(self): - excludes = ['parent', 'timestamp', 'action', 'liveaction', 'timestamp_gt', - 'timestamp_lt', 'status'] + excludes = [ + "parent", + "timestamp", + "action", + "liveaction", + "timestamp_gt", + "timestamp_lt", + "status", + ] for param, field in six.iteritems(ActionExecutionsController.supported_filters): if param in excludes: continue value = self.fake_types[0] - for item in field.split('.'): + for item in field.split("."): value = value[item] - response = self.app.get('/v1/executions?filter=%s:%s' % (field, value)) + response = self.app.get("/v1/executions?filter=%s:%s" % (field, value)) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertGreater(len(response.json), 0) - self.assertGreater(int(response.headers['X-Total-Count']), 0) + self.assertGreater(int(response.headers["X-Total-Count"]), 0) def test_advanced_filters_malformed(self): - response = self.app.get('/v1/executions?filter=a:b,c:d', expect_errors=True) + response = self.app.get("/v1/executions?filter=a:b,c:d", expect_errors=True) self.assertEqual(response.status_int, 400) - self.assertEqual(response.json, { - "faultstring": "Cannot resolve field \"a\"" - }) - response = self.app.get('/v1/executions?filter=action.ref', expect_errors=True) + self.assertEqual(response.json, {"faultstring": 'Cannot resolve field "a"'}) + response = self.app.get("/v1/executions?filter=action.ref", expect_errors=True) self.assertEqual(response.status_int, 400) - self.assertEqual(response.json, { - "faultstring": "invalid format for filter \"action.ref\"" - }) + self.assertEqual( + response.json, {"faultstring": 'invalid format for filter "action.ref"'} + ) def test_parent(self): - refs = [v for k, v in six.iteritems(self.refs) - if v.action['name'] == 'chain' and v.children] + refs = [ + v + for k, v in six.iteritems(self.refs) + if v.action["name"] == "chain" and v.children + ] self.assertTrue(refs) ref = random.choice(refs) - response = self.app.get('/v1/executions?parent=%s' % str(ref.id)) + response = self.app.get("/v1/executions?parent=%s" % str(ref.id)) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), len(ref.children)) - self.assertEqual(response.headers['X-Total-Count'], str(len(ref.children))) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Total-Count"], str(len(ref.children))) + ids = [item["id"] for item in response.json] self.assertListEqual(sorted(ids), sorted(ref.children)) def test_parentless(self): - refs = {k: v for k, v in six.iteritems(self.refs) if not getattr(v, 'parent', None)} + refs = { + k: v for k, v in six.iteritems(self.refs) if not getattr(v, "parent", None) + } self.assertTrue(refs) self.assertNotEqual(len(refs), self.num_records) - response = self.app.get('/v1/executions?parent=null') + response = self.app.get("/v1/executions?parent=null") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), len(refs)) - self.assertEqual(response.headers['X-Total-Count'], str(len(refs))) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Total-Count"], str(len(refs))) + ids = [item["id"] for item in response.json] self.assertListEqual(sorted(ids), sorted(refs.keys())) def test_pagination(self): @@ -253,14 +288,15 @@ def test_pagination(self): page_count = int(self.num_records / page_size) for i in range(page_count): offset = i * page_size - response = self.app.get('/v1/executions?offset=%s&limit=%s' % ( - offset, page_size)) + response = self.app.get( + "/v1/executions?offset=%s&limit=%s" % (offset, page_size) + ) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), page_size) - self.assertEqual(response.headers['X-Limit'], str(page_size)) - self.assertEqual(response.headers['X-Total-Count'], str(self.num_records)) - ids = [item['id'] for item in response.json] + self.assertEqual(response.headers["X-Limit"], str(page_size)) + self.assertEqual(response.headers["X-Total-Count"], str(self.num_records)) + ids = [item["id"] for item in response.json] self.assertListEqual(list(set(ids) - set(self.refs.keys())), []) self.assertListEqual(sorted(list(set(ids) - set(retrieved))), sorted(ids)) retrieved += ids @@ -270,60 +306,62 @@ def test_ui_history_query(self): # In this test we only care about making sure this exact query works. This query is used # by the webui for the history page so it is special and breaking this is bad. limit = 50 - history_query = '/v1/executions?limit={}&parent=null&exclude_attributes=' \ - 'result%2Ctrigger_instance&status=&action=&trigger_type=&rule=&' \ - 'offset=0'.format(limit) + history_query = ( + "/v1/executions?limit={}&parent=null&exclude_attributes=" + "result%2Ctrigger_instance&status=&action=&trigger_type=&rule=&" + "offset=0".format(limit) + ) response = self.app.get(history_query) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), limit) - self.assertTrue(int(response.headers['X-Total-Count']) > limit) + self.assertTrue(int(response.headers["X-Total-Count"]) > limit) def test_datetime_range(self): - dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z' - response = self.app.get('/v1/executions?timestamp=%s' % dt_range) + dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z" + response = self.app.get("/v1/executions?timestamp=%s" % dt_range) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), 10) - self.assertEqual(response.headers['X-Total-Count'], '10') + self.assertEqual(response.headers["X-Total-Count"], "10") - dt1 = response.json[0]['start_timestamp'] - dt2 = response.json[9]['start_timestamp'] + dt1 = response.json[0]["start_timestamp"] + dt2 = response.json[9]["start_timestamp"] self.assertLess(isotime.parse(dt1), isotime.parse(dt2)) - dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z' - response = self.app.get('/v1/executions?timestamp=%s' % dt_range) + dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z" + response = self.app.get("/v1/executions?timestamp=%s" % dt_range) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) self.assertEqual(len(response.json), 10) - self.assertEqual(response.headers['X-Total-Count'], '10') - dt1 = response.json[0]['start_timestamp'] - dt2 = response.json[9]['start_timestamp'] + self.assertEqual(response.headers["X-Total-Count"], "10") + dt1 = response.json[0]["start_timestamp"] + dt2 = response.json[9]["start_timestamp"] self.assertLess(isotime.parse(dt2), isotime.parse(dt1)) def test_default_sort(self): - response = self.app.get('/v1/executions') + response = self.app.get("/v1/executions") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) - dt1 = response.json[0]['start_timestamp'] - dt2 = response.json[len(response.json) - 1]['start_timestamp'] + dt1 = response.json[0]["start_timestamp"] + dt2 = response.json[len(response.json) - 1]["start_timestamp"] self.assertLess(isotime.parse(dt2), isotime.parse(dt1)) def test_ascending_sort(self): - response = self.app.get('/v1/executions?sort_asc=True') + response = self.app.get("/v1/executions?sort_asc=True") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) - dt1 = response.json[0]['start_timestamp'] - dt2 = response.json[len(response.json) - 1]['start_timestamp'] + dt1 = response.json[0]["start_timestamp"] + dt2 = response.json[len(response.json) - 1]["start_timestamp"] self.assertLess(isotime.parse(dt1), isotime.parse(dt2)) def test_descending_sort(self): - response = self.app.get('/v1/executions?sort_desc=True') + response = self.app.get("/v1/executions?sort_desc=True") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, list) - dt1 = response.json[0]['start_timestamp'] - dt2 = response.json[len(response.json) - 1]['start_timestamp'] + dt1 = response.json[0]["start_timestamp"] + dt2 = response.json[len(response.json) - 1]["start_timestamp"] self.assertLess(isotime.parse(dt2), isotime.parse(dt1)) def test_timestamp_lt_and_gt_filter(self): @@ -335,57 +373,81 @@ def isoformat(timestamp): # Last (largest) timestamp, there are no executions with a greater timestamp timestamp = self.start_timestamps[-1] - response = self.app.get('/v1/executions?timestamp_gt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_gt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), 0) # First (smallest) timestamp, there are no executions with a smaller timestamp timestamp = self.start_timestamps[0] - response = self.app.get('/v1/executions?timestamp_lt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_lt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), 0) # Second last, there should be one timestamp greater than it timestamp = self.start_timestamps[-2] - response = self.app.get('/v1/executions?timestamp_gt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_gt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), 1) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) > timestamp) + self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) > timestamp) # Second one, there should be one timestamp smaller than it timestamp = self.start_timestamps[1] - response = self.app.get('/v1/executions?timestamp_lt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_lt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), 1) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) < timestamp) + self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) < timestamp) # Half of the timestamps should be smaller index = (len(self.start_timestamps) - 1) // 2 timestamp = self.start_timestamps[index] - response = self.app.get('/v1/executions?timestamp_lt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_lt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), index) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) < timestamp) + self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) < timestamp) # Half of the timestamps should be greater index = (len(self.start_timestamps) - 1) // 2 timestamp = self.start_timestamps[-index] - response = self.app.get('/v1/executions?timestamp_gt=%s' % (isoformat(timestamp))) + response = self.app.get( + "/v1/executions?timestamp_gt=%s" % (isoformat(timestamp)) + ) self.assertEqual(len(response.json), (index - 1)) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) > timestamp) + self.assertTrue(isotime.parse(response.json[0]["start_timestamp"]) > timestamp) # Both, lt and gt filters, should return exactly two results timestamp_gt = self.start_timestamps[10] timestamp_lt = self.start_timestamps[13] - response = self.app.get('/v1/executions?timestamp_gt=%s×tamp_lt=%s' % - (isoformat(timestamp_gt), isoformat(timestamp_lt))) + response = self.app.get( + "/v1/executions?timestamp_gt=%s×tamp_lt=%s" + % (isoformat(timestamp_gt), isoformat(timestamp_lt)) + ) self.assertEqual(len(response.json), 2) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) > timestamp_gt) - self.assertTrue(isotime.parse(response.json[1]['start_timestamp']) > timestamp_gt) - self.assertTrue(isotime.parse(response.json[0]['start_timestamp']) < timestamp_lt) - self.assertTrue(isotime.parse(response.json[1]['start_timestamp']) < timestamp_lt) + self.assertTrue( + isotime.parse(response.json[0]["start_timestamp"]) > timestamp_gt + ) + self.assertTrue( + isotime.parse(response.json[1]["start_timestamp"]) > timestamp_gt + ) + self.assertTrue( + isotime.parse(response.json[0]["start_timestamp"]) < timestamp_lt + ) + self.assertTrue( + isotime.parse(response.json[1]["start_timestamp"]) < timestamp_lt + ) def test_filters_view(self): - response = self.app.get('/v1/executions/views/filters') + response = self.app.get("/v1/executions/views/filters") self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, dict) - self.assertEqual(len(response.json), len(history_views.ARTIFACTS['filters']['default'])) - for key, value in six.iteritems(history_views.ARTIFACTS['filters']['default']): + self.assertEqual( + len(response.json), len(history_views.ARTIFACTS["filters"]["default"]) + ) + for key, value in six.iteritems(history_views.ARTIFACTS["filters"]["default"]): filter_values = response.json[key] # Verify empty (None / null) filters are excluded @@ -399,9 +461,13 @@ def test_filters_view(self): self.assertEqual(set(filter_values), set(value)) def test_filters_view_specific_types(self): - response = self.app.get('/v1/executions/views/filters?types=action,user,nonexistent') + response = self.app.get( + "/v1/executions/views/filters?types=action,user,nonexistent" + ) self.assertEqual(response.status_int, 200) self.assertIsInstance(response.json, dict) - self.assertEqual(len(response.json), len(history_views.ARTIFACTS['filters']['specific'])) - for key, value in six.iteritems(history_views.ARTIFACTS['filters']['specific']): + self.assertEqual( + len(response.json), len(history_views.ARTIFACTS["filters"]["specific"]) + ) + for key, value in six.iteritems(history_views.ARTIFACTS["filters"]["specific"]): self.assertEqual(set(response.json[key]), set(value)) diff --git a/st2api/tests/unit/controllers/v1/test_inquiries.py b/st2api/tests/unit/controllers/v1/test_inquiries.py index 469b1c88169..173acbe4058 100644 --- a/st2api/tests/unit/controllers/v1/test_inquiries.py +++ b/st2api/tests/unit/controllers/v1/test_inquiries.py @@ -36,58 +36,50 @@ ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'pack': 'testpack', - 'runner_type': 'local-shell-cmd', + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "pack": "testpack", + "runner_type": "local-shell-cmd", } LIVE_ACTION_1 = { - 'action': 'testpack.st2.dummy.action1', - 'parameters': { - 'cmd': 'uname -a' - } + "action": "testpack.st2.dummy.action1", + "parameters": {"cmd": "uname -a"}, } INQUIRY_ACTION = { - 'name': 'st2.dummy.ask', - 'description': 'test description', - 'enabled': True, - 'pack': 'testpack', - 'runner_type': 'inquirer', + "name": "st2.dummy.ask", + "description": "test description", + "enabled": True, + "pack": "testpack", + "runner_type": "inquirer", } INQUIRY_1 = { - 'action': 'testpack.st2.dummy.ask', - 'status': 'pending', - 'parameters': {}, - 'context': { - 'parent': { - 'user': 'testu', - 'execution_id': '59b845e132ed350d396a798f', - 'pack': 'examples' + "action": "testpack.st2.dummy.ask", + "status": "pending", + "parameters": {}, + "context": { + "parent": { + "user": "testu", + "execution_id": "59b845e132ed350d396a798f", + "pack": "examples", }, - 'trace_context': {'trace_tag': 'balleilaka'} - } + "trace_context": {"trace_tag": "balleilaka"}, + }, } INQUIRY_2 = { - 'action': 'testpack.st2.dummy.ask', - 'status': 'pending', - 'parameters': { - 'route': 'superlative', - 'users': ['foo', 'bar'] - } + "action": "testpack.st2.dummy.ask", + "status": "pending", + "parameters": {"route": "superlative", "users": ["foo", "bar"]}, } INQUIRY_TIMEOUT = { - 'action': 'testpack.st2.dummy.ask', - 'status': 'timeout', - 'parameters': { - 'route': 'superlative', - 'users': ['foo', 'bar'] - } + "action": "testpack.st2.dummy.ask", + "status": "timeout", + "parameters": {"route": "superlative", "users": ["foo", "bar"]}, } SCHEMA_DEFAULT = { @@ -97,7 +89,7 @@ "continue": { "type": "boolean", "description": "Would you like to continue the workflow?", - "required": True + "required": True, } }, } @@ -109,18 +101,18 @@ "name": { "type": "string", "description": "What is your name?", - "required": True + "required": True, }, "pin": { "type": "integer", "description": "What is your PIN?", - "required": True + "required": True, }, "paradox": { "type": "boolean", "description": "This statement is False.", - "required": True - } + "required": True, + }, }, } @@ -132,7 +124,7 @@ "roles": [], "users": [], "route": "", - "ttl": 1440 + "ttl": 1440, } RESULT_2 = { @@ -140,7 +132,7 @@ "roles": [], "users": ["foo", "bar"], "route": "superlative", - "ttl": 1440 + "ttl": 1440, } RESULT_MULTIPLE = { @@ -148,58 +140,51 @@ "roles": [], "users": [], "route": "", - "ttl": 1440 + "ttl": 1440, } -RESPONSE_MULTIPLE = { - "name": "matt", - "pin": 1234, - "paradox": True -} +RESPONSE_MULTIPLE = {"name": "matt", "pin": 1234, "paradox": True} ROOT_LIVEACTION_DB = lv_db_models.LiveActionDB( - id=uuid.uuid4().hex, - status=action_constants.LIVEACTION_STATUS_PAUSED + id=uuid.uuid4().hex, status=action_constants.LIVEACTION_STATUS_PAUSED ) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) -class InquiryControllerTestCase(BaseInquiryControllerTestCase, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/inquiries' +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) +class InquiryControllerTestCase( + BaseInquiryControllerTestCase, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/inquiries" controller_cls = InquiriesController - include_attribute_field_name = 'ttl' - exclude_attribute_field_name = 'ttl' + include_attribute_field_name = "ttl" + exclude_attribute_field_name = "ttl" @mock.patch.object( - action_validator, - 'validate_action', - mock.MagicMock(return_value=True)) + action_validator, "validate_action", mock.MagicMock(return_value=True) + ) def setUp(cls): super(BaseInquiryControllerTestCase, cls).setUpClass() cls.inquiry1 = copy.deepcopy(INQUIRY_ACTION) - post_resp = cls.app.post_json('/v1/actions', cls.inquiry1) - cls.inquiry1['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.inquiry1) + cls.inquiry1["id"] = post_resp.json["id"] cls.action1 = copy.deepcopy(ACTION_1) - post_resp = cls.app.post_json('/v1/actions', cls.action1) - cls.action1['id'] = post_resp.json['id'] + post_resp = cls.app.post_json("/v1/actions", cls.action1) + cls.action1["id"] = post_resp.json["id"] def test_get_all(self): - """Test retrieval of a list of Inquiries - """ + """Test retrieval of a list of Inquiries""" inquiry_count = 5 for i in range(inquiry_count): self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT) get_all_resp = self._do_get_all() inquiries = get_all_resp.json - self.assertEqual(get_all_resp.headers['X-Total-Count'], str(len(inquiries))) + self.assertEqual(get_all_resp.headers["X-Total-Count"], str(len(inquiries))) self.assertIsInstance(inquiries, list) self.assertEqual(len(inquiries), inquiry_count) def test_get_all_empty(self): - """Test retrieval of a list of Inquiries when there are none - """ + """Test retrieval of a list of Inquiries when there are none""" inquiry_count = 0 get_all_resp = self._do_get_all() inquiries = get_all_resp.json @@ -207,8 +192,7 @@ def test_get_all_empty(self): self.assertEqual(len(inquiries), inquiry_count) def test_get_all_decrease_after_respond(self): - """Test that the inquiry list decreases when we respond to one of them - """ + """Test that the inquiry list decreases when we respond to one of them""" # Create inquiries inquiry_count = 5 @@ -221,7 +205,7 @@ def test_get_all_decrease_after_respond(self): # Respond to one of them response = {"continue": True} - self._do_respond(inquiries[0].get('id'), response) + self._do_respond(inquiries[0].get("id"), response) # Ensure the list is one smaller get_all_resp = self._do_get_all() @@ -230,8 +214,7 @@ def test_get_all_decrease_after_respond(self): self.assertEqual(len(inquiries), inquiry_count - 1) def test_get_all_limit(self): - """Test that the limit parameter works correctly - """ + """Test that the limit parameter works correctly""" # Create inquiries inquiry_count = 5 @@ -241,12 +224,11 @@ def test_get_all_limit(self): get_all_resp = self._do_get_all(limit=limit) inquiries = get_all_resp.json self.assertIsInstance(inquiries, list) - self.assertEqual(inquiry_count, int(get_all_resp.headers['X-Total-Count'])) + self.assertEqual(inquiry_count, int(get_all_resp.headers["X-Total-Count"])) self.assertEqual(len(inquiries), limit) def test_get_one(self): - """Test retrieval of a single Inquiry - """ + """Test retrieval of a single Inquiry""" post_resp = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT) inquiry_id = self._get_inquiry_id(post_resp) get_resp = self._do_get_one(inquiry_id) @@ -254,24 +236,21 @@ def test_get_one(self): self.assertEqual(self._get_inquiry_id(get_resp), inquiry_id) def test_get_one_failed(self): - """Test failed retrieval of an Inquiry - """ - inquiry_id = 'asdfeoijasdf' + """Test failed retrieval of an Inquiry""" + inquiry_id = "asdfeoijasdf" get_resp = self._do_get_one(inquiry_id, expect_errors=True) self.assertEqual(get_resp.status_int, http_client.NOT_FOUND) - self.assertIn('resource could not be found', get_resp.json['faultstring']) + self.assertIn("resource could not be found", get_resp.json["faultstring"]) def test_get_one_not_an_inquiry(self): - """Test that an attempt to retrieve a valid execution that isn't an Inquiry fails - """ - test_exec = json.loads(self.app.post_json('/v1/executions', LIVE_ACTION_1).body) - get_resp = self._do_get_one(test_exec.get('id'), expect_errors=True) + """Test that an attempt to retrieve a valid execution that isn't an Inquiry fails""" + test_exec = json.loads(self.app.post_json("/v1/executions", LIVE_ACTION_1).body) + get_resp = self._do_get_one(test_exec.get("id"), expect_errors=True) self.assertEqual(get_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('is not an inquiry', get_resp.json['faultstring']) + self.assertIn("is not an inquiry", get_resp.json["faultstring"]) def test_get_one_nondefault_params(self): - """Ensure an Inquiry with custom parameters contains those in result - """ + """Ensure an Inquiry with custom parameters contains those in result""" post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_2) inquiry_id = self._get_inquiry_id(post_resp) get_resp = self._do_get_one(inquiry_id) @@ -282,14 +261,15 @@ def test_get_one_nondefault_params(self): self.assertEqual(get_resp.json.get(param), RESULT_2.get(param)) @mock.patch.object( - action_service, 'get_root_liveaction', - mock.MagicMock(return_value=ROOT_LIVEACTION_DB)) + action_service, + "get_root_liveaction", + mock.MagicMock(return_value=ROOT_LIVEACTION_DB), + ) @mock.patch.object( - action_service, 'request_resume', - mock.MagicMock(return_value=None)) + action_service, "request_resume", mock.MagicMock(return_value=None) + ) def test_respond(self): - """Test that a correct response is successful - """ + """Test that a correct response is successful""" post_resp = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT) inquiry_id = self._get_inquiry_id(post_resp) @@ -300,21 +280,22 @@ def test_respond(self): # The inquiry no longer exists, since the status should not be "pending" # Get the execution and confirm this. inquiry_execution = self._do_get_execution(inquiry_id) - self.assertEqual(inquiry_execution.json.get('status'), 'succeeded') + self.assertEqual(inquiry_execution.json.get("status"), "succeeded") # This Inquiry is in a workflow, so has a parent. Assert that the resume # was requested for this parent. action_service.request_resume.assert_called_once() @mock.patch.object( - action_service, 'get_root_liveaction', - mock.MagicMock(return_value=ROOT_LIVEACTION_DB)) + action_service, + "get_root_liveaction", + mock.MagicMock(return_value=ROOT_LIVEACTION_DB), + ) @mock.patch.object( - action_service, 'request_resume', - mock.MagicMock(return_value=None)) + action_service, "request_resume", mock.MagicMock(return_value=None) + ) def test_respond_multiple(self): - """Test that a more complicated response is successful - """ + """Test that a more complicated response is successful""" post_resp = self._do_create_inquiry(INQUIRY_1, RESULT_MULTIPLE) inquiry_id = self._get_inquiry_id(post_resp) @@ -324,38 +305,35 @@ def test_respond_multiple(self): # The inquiry no longer exists, since the status should not be "pending" # Get the execution and confirm this. inquiry_execution = self._do_get_execution(inquiry_id) - self.assertEqual(inquiry_execution.json.get('status'), 'succeeded') + self.assertEqual(inquiry_execution.json.get("status"), "succeeded") # This Inquiry is in a workflow, so has a parent. Assert that the resume # was requested for this parent. action_service.request_resume.assert_called_once() def test_respond_fail(self): - """Test that an incorrect response is unsuccessful - """ + """Test that an incorrect response is unsuccessful""" post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_DEFAULT) inquiry_id = self._get_inquiry_id(post_resp) response = {"continue": 123} put_resp = self._do_respond(inquiry_id, response, expect_errors=True) self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('did not pass schema validation', put_resp.json['faultstring']) + self.assertIn("did not pass schema validation", put_resp.json["faultstring"]) def test_respond_not_an_inquiry(self): - """Test that attempts to respond to an execution ID that isn't an Inquiry fails - """ - test_exec = json.loads(self.app.post_json('/v1/executions', LIVE_ACTION_1).body) + """Test that attempts to respond to an execution ID that isn't an Inquiry fails""" + test_exec = json.loads(self.app.post_json("/v1/executions", LIVE_ACTION_1).body) response = {"continue": 123} - put_resp = self._do_respond(test_exec.get('id'), response, expect_errors=True) + put_resp = self._do_respond(test_exec.get("id"), response, expect_errors=True) self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('is not an inquiry', put_resp.json['faultstring']) + self.assertIn("is not an inquiry", put_resp.json["faultstring"]) @mock.patch.object( - action_service, 'request_resume', - mock.MagicMock(return_value=None)) + action_service, "request_resume", mock.MagicMock(return_value=None) + ) def test_respond_no_parent(self): - """Test that a resume was not requested for an Inquiry without a parent - """ + """Test that a resume was not requested for an Inquiry without a parent""" post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_DEFAULT) inquiry_id = self._get_inquiry_id(post_resp) @@ -365,8 +343,7 @@ def test_respond_no_parent(self): action_service.request_resume.assert_not_called() def test_respond_duplicate_rejected(self): - """Test that responding to an already-responded Inquiry fails - """ + """Test that responding to an already-responded Inquiry fails""" post_resp = self._do_create_inquiry(INQUIRY_2, RESULT_DEFAULT) inquiry_id = self._get_inquiry_id(post_resp) @@ -377,28 +354,30 @@ def test_respond_duplicate_rejected(self): # The inquiry no longer exists, since the status should not be "pending" # Get the execution and confirm this. inquiry_execution = self._do_get_execution(inquiry_id) - self.assertEqual(inquiry_execution.json.get('status'), 'succeeded') + self.assertEqual(inquiry_execution.json.get("status"), "succeeded") # A second, equivalent response attempt should not succeed, since the Inquiry # has already been successfully responded to put_resp = self._do_respond(inquiry_id, response, expect_errors=True) self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('has already been responded to', put_resp.json['faultstring']) + self.assertIn("has already been responded to", put_resp.json["faultstring"]) def test_respond_timeout_rejected(self): - """Test that responding to a timed-out Inquiry fails - """ + """Test that responding to a timed-out Inquiry fails""" - post_resp = self._do_create_inquiry(INQUIRY_TIMEOUT, RESULT_DEFAULT, status='timeout') + post_resp = self._do_create_inquiry( + INQUIRY_TIMEOUT, RESULT_DEFAULT, status="timeout" + ) inquiry_id = self._get_inquiry_id(post_resp) response = {"continue": True} put_resp = self._do_respond(inquiry_id, response, expect_errors=True) self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('timed out and cannot be responded to', put_resp.json['faultstring']) + self.assertIn( + "timed out and cannot be responded to", put_resp.json["faultstring"] + ) def test_respond_restrict_users(self): - """Test that Inquiries can reject responses from users not in a list - """ + """Test that Inquiries can reject responses from users not in a list""" # Default user for tests is "stanley", which is not in the 'users' list # Should be rejected @@ -407,7 +386,9 @@ def test_respond_restrict_users(self): response = {"continue": True} put_resp = self._do_respond(inquiry_id, response, expect_errors=True) self.assertEqual(put_resp.status_int, http_client.FORBIDDEN) - self.assertIn('does not have permission to respond', put_resp.json['faultstring']) + self.assertIn( + "does not have permission to respond", put_resp.json["faultstring"] + ) # Responding as a use in the list should be accepted old_user = cfg.CONF.system_user.user @@ -425,8 +406,8 @@ def test_get_all_invalid_exclude_and_include_parameter(self): pass def _insert_mock_models(self): - id_1 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json['id'] - id_2 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json['id'] + id_1 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json["id"] + id_2 = self._do_create_inquiry(INQUIRY_1, RESULT_DEFAULT).json["id"] return [id_1, id_2] diff --git a/st2api/tests/unit/controllers/v1/test_kvps.py b/st2api/tests/unit/controllers/v1/test_kvps.py index 06103134bdb..61a903a3ad4 100644 --- a/st2api/tests/unit/controllers/v1/test_kvps.py +++ b/st2api/tests/unit/controllers/v1/test_kvps.py @@ -21,83 +21,66 @@ from six.moves import http_client -__all__ = [ - 'KeyValuePairControllerTestCase' -] +__all__ = ["KeyValuePairControllerTestCase"] -KVP = { - 'name': 'keystone_endpoint', - 'value': 'http://127.0.0.1:5000/v3' -} +KVP = {"name": "keystone_endpoint", "value": "http://127.0.0.1:5000/v3"} -KVP_2 = { - 'name': 'keystone_version', - 'value': 'v3' -} +KVP_2 = {"name": "keystone_version", "value": "v3"} -KVP_2_USER = { - 'name': 'keystone_version', - 'value': 'user_v3', - 'scope': 'st2kv.user' -} +KVP_2_USER = {"name": "keystone_version", "value": "user_v3", "scope": "st2kv.user"} -KVP_2_USER_LEGACY = { - 'name': 'keystone_version', - 'value': 'user_v3', - 'scope': 'user' -} +KVP_2_USER_LEGACY = {"name": "keystone_version", "value": "user_v3", "scope": "user"} KVP_3_USER = { - 'name': 'keystone_endpoint', - 'value': 'http://127.0.1.1:5000/v3', - 'scope': 'st2kv.user' + "name": "keystone_endpoint", + "value": "http://127.0.1.1:5000/v3", + "scope": "st2kv.user", } KVP_4_USER = { - 'name': 'customer_ssn', - 'value': '123-456-7890', - 'secret': True, - 'scope': 'st2kv.user' + "name": "customer_ssn", + "value": "123-456-7890", + "secret": True, + "scope": "st2kv.user", } KVP_WITH_TTL = { - 'name': 'keystone_endpoint', - 'value': 'http://127.0.0.1:5000/v3', - 'ttl': 10 + "name": "keystone_endpoint", + "value": "http://127.0.0.1:5000/v3", + "ttl": 10, } -SECRET_KVP = { - 'name': 'secret_key1', - 'value': 'secret_value1', - 'secret': True -} +SECRET_KVP = {"name": "secret_key1", "value": "secret_value1", "secret": True} # value = S3cret!Value # encrypted with st2tests/conf/st2_kvstore_tests.crypto.key.json ENCRYPTED_KVP = { - 'name': 'secret_key1', - 'value': ('3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E' - 'B30170DACF79498F30520236A629912C3584847098D'), - 'encrypted': True + "name": "secret_key1", + "value": ( + "3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E" + "B30170DACF79498F30520236A629912C3584847098D" + ), + "encrypted": True, } ENCRYPTED_KVP_SECRET_FALSE = { - 'name': 'secret_key2', - 'value': ('3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E' - 'B30170DACF79498F30520236A629912C3584847098D'), - 'secret': True, - 'encrypted': True + "name": "secret_key2", + "value": ( + "3030303030298D848B45A24EDCD1A82FAB4E831E3FCE6E60956817A48A180E4C040801E" + "B30170DACF79498F30520236A629912C3584847098D" + ), + "secret": True, + "encrypted": True, } class KeyValuePairControllerTestCase(FunctionalTest): - def test_get_all(self): - resp = self.app.get('/v1/keys') + resp = self.app.get("/v1/keys") self.assertEqual(resp.status_int, 200) def test_get_one(self): - put_resp = self.__do_put('key1', KVP) + put_resp = self.__do_put("key1", KVP) kvp_id = self.__get_kvp_id(put_resp) get_resp = self.__do_get_one(kvp_id) self.assertEqual(get_resp.status_int, 200) @@ -107,484 +90,534 @@ def test_get_one(self): def test_get_all_all_scope(self): # Test which cases various scenarios which ensure non-admin users can't read / view keys # from other users - user_db_1 = UserDB(name='user1') - user_db_2 = UserDB(name='user2') - user_db_3 = UserDB(name='user3') + user_db_1 = UserDB(name="user1") + user_db_2 = UserDB(name="user2") + user_db_3 = UserDB(name="user3") # Insert some mock data # System scoped keys - put_resp = self.__do_put('system1', {'name': 'system1', 'value': 'val1', - 'scope': 'st2kv.system'}) + put_resp = self.__do_put( + "system1", {"name": "system1", "value": "val1", "scope": "st2kv.system"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'system1') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') + self.assertEqual(put_resp.json["name"], "system1") + self.assertEqual(put_resp.json["scope"], "st2kv.system") - put_resp = self.__do_put('system2', {'name': 'system2', 'value': 'val2', - 'scope': 'st2kv.system'}) + put_resp = self.__do_put( + "system2", {"name": "system2", "value": "val2", "scope": "st2kv.system"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'system2') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') + self.assertEqual(put_resp.json["name"], "system2") + self.assertEqual(put_resp.json["scope"], "st2kv.system") # user1 scoped keys self.use_user(user_db_1) - put_resp = self.__do_put('user1', {'name': 'user1', 'value': 'user1', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "user1", {"name": "user1", "value": "user1", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'user1') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user1') + self.assertEqual(put_resp.json["name"], "user1") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user1") - put_resp = self.__do_put('userkey', {'name': 'userkey', 'value': 'user1', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "userkey", {"name": "userkey", "value": "user1", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'userkey') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user1') + self.assertEqual(put_resp.json["name"], "userkey") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user1") # user2 scoped keys self.use_user(user_db_2) - put_resp = self.__do_put('user2', {'name': 'user2', 'value': 'user2', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "user2", {"name": "user2", "value": "user2", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'user2') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user2') + self.assertEqual(put_resp.json["name"], "user2") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user2") - put_resp = self.__do_put('userkey', {'name': 'userkey', 'value': 'user2', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "userkey", {"name": "userkey", "value": "user2", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'userkey') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user2') + self.assertEqual(put_resp.json["name"], "userkey") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user2") # user3 scoped keys self.use_user(user_db_3) - put_resp = self.__do_put('user3', {'name': 'user3', 'value': 'user3', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "user3", {"name": "user3", "value": "user3", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'user3') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user3') + self.assertEqual(put_resp.json["name"], "user3") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user3") - put_resp = self.__do_put('userkey', {'name': 'userkey', 'value': 'user3', - 'scope': 'st2kv.user'}) + put_resp = self.__do_put( + "userkey", {"name": "userkey", "value": "user3", "scope": "st2kv.user"} + ) self.assertEqual(put_resp.status_int, 200) - self.assertEqual(put_resp.json['name'], 'userkey') - self.assertEqual(put_resp.json['scope'], 'st2kv.user') - self.assertEqual(put_resp.json['value'], 'user3') + self.assertEqual(put_resp.json["name"], "userkey") + self.assertEqual(put_resp.json["scope"], "st2kv.user") + self.assertEqual(put_resp.json["value"], "user3") # 1. "all" scope as user1 - should only be able to view system + current user items self.use_user(user_db_1) - resp = self.app.get('/v1/keys?scope=all') + resp = self.app.get("/v1/keys?scope=all") self.assertEqual(len(resp.json), 2 + 2) # 2 system, 2 user - self.assertEqual(resp.json[0]['name'], 'system1') - self.assertEqual(resp.json[0]['scope'], 'st2kv.system') + self.assertEqual(resp.json[0]["name"], "system1") + self.assertEqual(resp.json[0]["scope"], "st2kv.system") - self.assertEqual(resp.json[1]['name'], 'system2') - self.assertEqual(resp.json[1]['scope'], 'st2kv.system') + self.assertEqual(resp.json[1]["name"], "system2") + self.assertEqual(resp.json[1]["scope"], "st2kv.system") - self.assertEqual(resp.json[2]['name'], 'user1') - self.assertEqual(resp.json[2]['scope'], 'st2kv.user') - self.assertEqual(resp.json[2]['user'], 'user1') + self.assertEqual(resp.json[2]["name"], "user1") + self.assertEqual(resp.json[2]["scope"], "st2kv.user") + self.assertEqual(resp.json[2]["user"], "user1") - self.assertEqual(resp.json[3]['name'], 'userkey') - self.assertEqual(resp.json[3]['scope'], 'st2kv.user') - self.assertEqual(resp.json[3]['user'], 'user1') + self.assertEqual(resp.json[3]["name"], "userkey") + self.assertEqual(resp.json[3]["scope"], "st2kv.user") + self.assertEqual(resp.json[3]["user"], "user1") # Verify user can't retrieve values for other users by manipulating "prefix" - resp = self.app.get('/v1/keys?scope=all&prefix=user2:') + resp = self.app.get("/v1/keys?scope=all&prefix=user2:") self.assertEqual(resp.json, []) - resp = self.app.get('/v1/keys?scope=all&prefix=user') + resp = self.app.get("/v1/keys?scope=all&prefix=user") self.assertEqual(len(resp.json), 2) # 2 user - self.assertEqual(resp.json[0]['name'], 'user1') - self.assertEqual(resp.json[0]['scope'], 'st2kv.user') - self.assertEqual(resp.json[0]['user'], 'user1') + self.assertEqual(resp.json[0]["name"], "user1") + self.assertEqual(resp.json[0]["scope"], "st2kv.user") + self.assertEqual(resp.json[0]["user"], "user1") - self.assertEqual(resp.json[1]['name'], 'userkey') - self.assertEqual(resp.json[1]['scope'], 'st2kv.user') - self.assertEqual(resp.json[1]['user'], 'user1') + self.assertEqual(resp.json[1]["name"], "userkey") + self.assertEqual(resp.json[1]["scope"], "st2kv.user") + self.assertEqual(resp.json[1]["user"], "user1") # 2. "all" scope user user2 - should only be able to view system + current user items self.use_user(user_db_2) - resp = self.app.get('/v1/keys?scope=all') + resp = self.app.get("/v1/keys?scope=all") self.assertEqual(len(resp.json), 2 + 2) # 2 system, 2 user - self.assertEqual(resp.json[0]['name'], 'system1') - self.assertEqual(resp.json[0]['scope'], 'st2kv.system') + self.assertEqual(resp.json[0]["name"], "system1") + self.assertEqual(resp.json[0]["scope"], "st2kv.system") - self.assertEqual(resp.json[1]['name'], 'system2') - self.assertEqual(resp.json[1]['scope'], 'st2kv.system') + self.assertEqual(resp.json[1]["name"], "system2") + self.assertEqual(resp.json[1]["scope"], "st2kv.system") - self.assertEqual(resp.json[2]['name'], 'user2') - self.assertEqual(resp.json[2]['scope'], 'st2kv.user') - self.assertEqual(resp.json[2]['user'], 'user2') + self.assertEqual(resp.json[2]["name"], "user2") + self.assertEqual(resp.json[2]["scope"], "st2kv.user") + self.assertEqual(resp.json[2]["user"], "user2") - self.assertEqual(resp.json[3]['name'], 'userkey') - self.assertEqual(resp.json[3]['scope'], 'st2kv.user') - self.assertEqual(resp.json[3]['user'], 'user2') + self.assertEqual(resp.json[3]["name"], "userkey") + self.assertEqual(resp.json[3]["scope"], "st2kv.user") + self.assertEqual(resp.json[3]["user"], "user2") # Verify user can't retrieve values for other users by manipulating "prefix" - resp = self.app.get('/v1/keys?scope=all&prefix=user1:') + resp = self.app.get("/v1/keys?scope=all&prefix=user1:") self.assertEqual(resp.json, []) - resp = self.app.get('/v1/keys?scope=all&prefix=user') + resp = self.app.get("/v1/keys?scope=all&prefix=user") self.assertEqual(len(resp.json), 2) # 2 user - self.assertEqual(resp.json[0]['name'], 'user2') - self.assertEqual(resp.json[0]['scope'], 'st2kv.user') - self.assertEqual(resp.json[0]['user'], 'user2') + self.assertEqual(resp.json[0]["name"], "user2") + self.assertEqual(resp.json[0]["scope"], "st2kv.user") + self.assertEqual(resp.json[0]["user"], "user2") - self.assertEqual(resp.json[1]['name'], 'userkey') - self.assertEqual(resp.json[1]['scope'], 'st2kv.user') - self.assertEqual(resp.json[1]['user'], 'user2') + self.assertEqual(resp.json[1]["name"], "userkey") + self.assertEqual(resp.json[1]["scope"], "st2kv.user") + self.assertEqual(resp.json[1]["user"], "user2") # Verify non-admon user can't retrieve key for an arbitrary users - resp = self.app.get('/v1/keys?scope=user&user=user1', expect_errors=True) - expected_error = '"user" attribute can only be provided by admins when RBAC is enabled' + resp = self.app.get("/v1/keys?scope=user&user=user1", expect_errors=True) + expected_error = ( + '"user" attribute can only be provided by admins when RBAC is enabled' + ) self.assertEqual(resp.status_int, http_client.FORBIDDEN) - self.assertEqual(resp.json['faultstring'], expected_error) + self.assertEqual(resp.json["faultstring"], expected_error) # 3. "all" scope user user3 - should only be able to view system + current user items self.use_user(user_db_3) - resp = self.app.get('/v1/keys?scope=all') + resp = self.app.get("/v1/keys?scope=all") self.assertEqual(len(resp.json), 2 + 2) # 2 system, 2 user - self.assertEqual(resp.json[0]['name'], 'system1') - self.assertEqual(resp.json[0]['scope'], 'st2kv.system') + self.assertEqual(resp.json[0]["name"], "system1") + self.assertEqual(resp.json[0]["scope"], "st2kv.system") - self.assertEqual(resp.json[1]['name'], 'system2') - self.assertEqual(resp.json[1]['scope'], 'st2kv.system') + self.assertEqual(resp.json[1]["name"], "system2") + self.assertEqual(resp.json[1]["scope"], "st2kv.system") - self.assertEqual(resp.json[2]['name'], 'user3') - self.assertEqual(resp.json[2]['scope'], 'st2kv.user') - self.assertEqual(resp.json[2]['user'], 'user3') + self.assertEqual(resp.json[2]["name"], "user3") + self.assertEqual(resp.json[2]["scope"], "st2kv.user") + self.assertEqual(resp.json[2]["user"], "user3") - self.assertEqual(resp.json[3]['name'], 'userkey') - self.assertEqual(resp.json[3]['scope'], 'st2kv.user') - self.assertEqual(resp.json[3]['user'], 'user3') + self.assertEqual(resp.json[3]["name"], "userkey") + self.assertEqual(resp.json[3]["scope"], "st2kv.user") + self.assertEqual(resp.json[3]["user"], "user3") # Verify user can't retrieve values for other users by manipulating "prefix" - resp = self.app.get('/v1/keys?scope=all&prefix=user1:') + resp = self.app.get("/v1/keys?scope=all&prefix=user1:") self.assertEqual(resp.json, []) - resp = self.app.get('/v1/keys?scope=all&prefix=user') + resp = self.app.get("/v1/keys?scope=all&prefix=user") self.assertEqual(len(resp.json), 2) # 2 user - self.assertEqual(resp.json[0]['name'], 'user3') - self.assertEqual(resp.json[0]['scope'], 'st2kv.user') - self.assertEqual(resp.json[0]['user'], 'user3') + self.assertEqual(resp.json[0]["name"], "user3") + self.assertEqual(resp.json[0]["scope"], "st2kv.user") + self.assertEqual(resp.json[0]["user"], "user3") - self.assertEqual(resp.json[1]['name'], 'userkey') - self.assertEqual(resp.json[1]['scope'], 'st2kv.user') - self.assertEqual(resp.json[1]['user'], 'user3') + self.assertEqual(resp.json[1]["name"], "userkey") + self.assertEqual(resp.json[1]["scope"], "st2kv.user") + self.assertEqual(resp.json[1]["user"], "user3") # Clean up - self.__do_delete('system1') - self.__do_delete('system2') + self.__do_delete("system1") + self.__do_delete("system2") self.use_user(user_db_1) - self.__do_delete('user1?scope=user') - self.__do_delete('userkey?scope=user') + self.__do_delete("user1?scope=user") + self.__do_delete("userkey?scope=user") self.use_user(user_db_2) - self.__do_delete('user2?scope=user') - self.__do_delete('userkey?scope=user') + self.__do_delete("user2?scope=user") + self.__do_delete("userkey?scope=user") self.use_user(user_db_3) - self.__do_delete('user3?scope=user') - self.__do_delete('userkey?scope=user') + self.__do_delete("user3?scope=user") + self.__do_delete("userkey?scope=user") def test_get_all_user_query_param_can_only_be_used_with_rbac(self): - resp = self.app.get('/v1/keys?user=foousera', expect_errors=True) + resp = self.app.get("/v1/keys?user=foousera", expect_errors=True) - expected_error = '"user" attribute can only be provided by admins when RBAC is enabled' + expected_error = ( + '"user" attribute can only be provided by admins when RBAC is enabled' + ) self.assertEqual(resp.status_int, http_client.FORBIDDEN) - self.assertEqual(resp.json['faultstring'], expected_error) + self.assertEqual(resp.json["faultstring"], expected_error) def test_get_one_user_query_param_can_only_be_used_with_rbac(self): - resp = self.app.get('/v1/keys/keystone_endpoint?user=foousera', expect_errors=True) + resp = self.app.get( + "/v1/keys/keystone_endpoint?user=foousera", expect_errors=True + ) - expected_error = '"user" attribute can only be provided by admins when RBAC is enabled' + expected_error = ( + '"user" attribute can only be provided by admins when RBAC is enabled' + ) self.assertEqual(resp.status_int, http_client.FORBIDDEN) - self.assertEqual(resp.json['faultstring'], expected_error) + self.assertEqual(resp.json["faultstring"], expected_error) def test_get_all_prefix_filtering(self): - put_resp1 = self.__do_put(KVP['name'], KVP) - put_resp2 = self.__do_put(KVP_2['name'], KVP_2) + put_resp1 = self.__do_put(KVP["name"], KVP) + put_resp2 = self.__do_put(KVP_2["name"], KVP_2) self.assertEqual(put_resp1.status_int, 200) self.assertEqual(put_resp2.status_int, 200) # No keys with that prefix - resp = self.app.get('/v1/keys?prefix=something') + resp = self.app.get("/v1/keys?prefix=something") self.assertEqual(resp.json, []) # Two keys with the provided prefix - resp = self.app.get('/v1/keys?prefix=keystone') + resp = self.app.get("/v1/keys?prefix=keystone") self.assertEqual(len(resp.json), 2) # One key with the provided prefix - resp = self.app.get('/v1/keys?prefix=keystone_endpoint') + resp = self.app.get("/v1/keys?prefix=keystone_endpoint") self.assertEqual(len(resp.json), 1) self.__do_delete(self.__get_kvp_id(put_resp1)) self.__do_delete(self.__get_kvp_id(put_resp2)) def test_get_one_fail(self): - resp = self.app.get('/v1/keys/1', expect_errors=True) + resp = self.app.get("/v1/keys/1", expect_errors=True) self.assertEqual(resp.status_int, 404) def test_put(self): - put_resp = self.__do_put('key1', KVP) + put_resp = self.__do_put("key1", KVP) update_input = put_resp.json - update_input['value'] = 'http://127.0.0.1:35357/v3' + update_input["value"] = "http://127.0.0.1:35357/v3" put_resp = self.__do_put(self.__get_kvp_id(put_resp), update_input) self.assertEqual(put_resp.status_int, 200) self.__do_delete(self.__get_kvp_id(put_resp)) def test_put_with_scope(self): - self.app.put_json('/v1/keys/%s' % 'keystone_endpoint', KVP, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=st2kv.system' % 'keystone_version', KVP_2, - expect_errors=False) - - get_resp_1 = self.app.get('/v1/keys/keystone_endpoint') + self.app.put_json("/v1/keys/%s" % "keystone_endpoint", KVP, expect_errors=False) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.system" % "keystone_version", + KVP_2, + expect_errors=False, + ) + + get_resp_1 = self.app.get("/v1/keys/keystone_endpoint") self.assertTrue(get_resp_1.status_int, 200) - self.assertEqual(self.__get_kvp_id(get_resp_1), 'keystone_endpoint') - get_resp_2 = self.app.get('/v1/keys/keystone_version?scope=st2kv.system') + self.assertEqual(self.__get_kvp_id(get_resp_1), "keystone_endpoint") + get_resp_2 = self.app.get("/v1/keys/keystone_version?scope=st2kv.system") self.assertTrue(get_resp_2.status_int, 200) - self.assertEqual(self.__get_kvp_id(get_resp_2), 'keystone_version') - get_resp_3 = self.app.get('/v1/keys/keystone_version') + self.assertEqual(self.__get_kvp_id(get_resp_2), "keystone_version") + get_resp_3 = self.app.get("/v1/keys/keystone_version") self.assertTrue(get_resp_3.status_int, 200) - self.assertEqual(self.__get_kvp_id(get_resp_3), 'keystone_version') - self.app.delete('/v1/keys/keystone_endpoint?scope=st2kv.system') - self.app.delete('/v1/keys/keystone_version?scope=st2kv.system') + self.assertEqual(self.__get_kvp_id(get_resp_3), "keystone_version") + self.app.delete("/v1/keys/keystone_endpoint?scope=st2kv.system") + self.app.delete("/v1/keys/keystone_version?scope=st2kv.system") def test_put_user_scope_and_system_scope_dont_overlap(self): - self.app.put_json('/v1/keys/%s?scope=st2kv.system' % 'keystone_version', KVP_2, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_version', KVP_2_USER, - expect_errors=False) - get_resp = self.app.get('/v1/keys/keystone_version?scope=st2kv.system') - self.assertEqual(get_resp.json['value'], KVP_2['value']) - - get_resp = self.app.get('/v1/keys/keystone_version?scope=st2kv.user') - self.assertEqual(get_resp.json['value'], KVP_2_USER['value']) - self.app.delete('/v1/keys/keystone_version?scope=st2kv.system') - self.app.delete('/v1/keys/keystone_version?scope=st2kv.user') + self.app.put_json( + "/v1/keys/%s?scope=st2kv.system" % "keystone_version", + KVP_2, + expect_errors=False, + ) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.user" % "keystone_version", + KVP_2_USER, + expect_errors=False, + ) + get_resp = self.app.get("/v1/keys/keystone_version?scope=st2kv.system") + self.assertEqual(get_resp.json["value"], KVP_2["value"]) + + get_resp = self.app.get("/v1/keys/keystone_version?scope=st2kv.user") + self.assertEqual(get_resp.json["value"], KVP_2_USER["value"]) + self.app.delete("/v1/keys/keystone_version?scope=st2kv.system") + self.app.delete("/v1/keys/keystone_version?scope=st2kv.user") def test_put_invalid_scope(self): - put_resp = self.app.put_json('/v1/keys/keystone_version?scope=st2', KVP_2, - expect_errors=True) + put_resp = self.app.put_json( + "/v1/keys/keystone_version?scope=st2", KVP_2, expect_errors=True + ) self.assertTrue(put_resp.status_int, 400) def test_get_all_with_scope(self): - self.app.put_json('/v1/keys/%s?scope=st2kv.system' % 'keystone_version', KVP_2, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_version', KVP_2_USER, - expect_errors=False) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.system" % "keystone_version", + KVP_2, + expect_errors=False, + ) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.user" % "keystone_version", + KVP_2_USER, + expect_errors=False, + ) # Note that the following two calls overwrite st2sytem and st2kv.user scoped variables with # same name. - self.app.put_json('/v1/keys/%s?scope=system' % 'keystone_version', KVP_2, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=user' % 'keystone_version', KVP_2_USER_LEGACY, - expect_errors=False) - - get_resp_all = self.app.get('/v1/keys?scope=all') + self.app.put_json( + "/v1/keys/%s?scope=system" % "keystone_version", KVP_2, expect_errors=False + ) + self.app.put_json( + "/v1/keys/%s?scope=user" % "keystone_version", + KVP_2_USER_LEGACY, + expect_errors=False, + ) + + get_resp_all = self.app.get("/v1/keys?scope=all") self.assertTrue(len(get_resp_all.json), 2) - get_resp_sys = self.app.get('/v1/keys?scope=st2kv.system') + get_resp_sys = self.app.get("/v1/keys?scope=st2kv.system") self.assertTrue(len(get_resp_sys.json), 1) - self.assertEqual(get_resp_sys.json[0]['value'], KVP_2['value']) + self.assertEqual(get_resp_sys.json[0]["value"], KVP_2["value"]) - get_resp_sys = self.app.get('/v1/keys?scope=system') + get_resp_sys = self.app.get("/v1/keys?scope=system") self.assertTrue(len(get_resp_sys.json), 1) - self.assertEqual(get_resp_sys.json[0]['value'], KVP_2['value']) + self.assertEqual(get_resp_sys.json[0]["value"], KVP_2["value"]) - get_resp_sys = self.app.get('/v1/keys?scope=st2kv.user') + get_resp_sys = self.app.get("/v1/keys?scope=st2kv.user") self.assertTrue(len(get_resp_sys.json), 1) - self.assertEqual(get_resp_sys.json[0]['value'], KVP_2_USER['value']) + self.assertEqual(get_resp_sys.json[0]["value"], KVP_2_USER["value"]) - get_resp_sys = self.app.get('/v1/keys?scope=user') + get_resp_sys = self.app.get("/v1/keys?scope=user") self.assertTrue(len(get_resp_sys.json), 1) - self.assertEqual(get_resp_sys.json[0]['value'], KVP_2_USER['value']) + self.assertEqual(get_resp_sys.json[0]["value"], KVP_2_USER["value"]) - self.app.delete('/v1/keys/keystone_version?scope=st2kv.system') - self.app.delete('/v1/keys/keystone_version?scope=st2kv.user') + self.app.delete("/v1/keys/keystone_version?scope=st2kv.system") + self.app.delete("/v1/keys/keystone_version?scope=st2kv.user") def test_get_all_with_scope_and_prefix_filtering(self): - self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_version', KVP_2_USER, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'keystone_endpoint', KVP_3_USER, - expect_errors=False) - self.app.put_json('/v1/keys/%s?scope=st2kv.user' % 'customer_ssn', KVP_4_USER, - expect_errors=False) - get_prefix = self.app.get('/v1/keys?scope=st2kv.user&prefix=keystone') + self.app.put_json( + "/v1/keys/%s?scope=st2kv.user" % "keystone_version", + KVP_2_USER, + expect_errors=False, + ) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.user" % "keystone_endpoint", + KVP_3_USER, + expect_errors=False, + ) + self.app.put_json( + "/v1/keys/%s?scope=st2kv.user" % "customer_ssn", + KVP_4_USER, + expect_errors=False, + ) + get_prefix = self.app.get("/v1/keys?scope=st2kv.user&prefix=keystone") self.assertEqual(len(get_prefix.json), 2) - self.app.delete('/v1/keys/keystone_version?scope=st2kv.user') - self.app.delete('/v1/keys/keystone_endpoint?scope=st2kv.user') - self.app.delete('/v1/keys/customer_ssn?scope=st2kv.user') + self.app.delete("/v1/keys/keystone_version?scope=st2kv.user") + self.app.delete("/v1/keys/keystone_endpoint?scope=st2kv.user") + self.app.delete("/v1/keys/customer_ssn?scope=st2kv.user") def test_put_with_ttl(self): - put_resp = self.__do_put('key_with_ttl', KVP_WITH_TTL) + put_resp = self.__do_put("key_with_ttl", KVP_WITH_TTL) self.assertEqual(put_resp.status_int, 200) - get_resp = self.app.get('/v1/keys') - self.assertTrue(get_resp.json[0]['expire_timestamp']) + get_resp = self.app.get("/v1/keys") + self.assertTrue(get_resp.json[0]["expire_timestamp"]) self.__do_delete(self.__get_kvp_id(put_resp)) def test_put_secret(self): - put_resp = self.__do_put('secret_key1', SECRET_KVP) + put_resp = self.__do_put("secret_key1", SECRET_KVP) kvp_id = self.__get_kvp_id(put_resp) get_resp = self.__do_get_one(kvp_id) - self.assertTrue(get_resp.json['encrypted']) - crypto_val = get_resp.json['value'] - self.assertNotEqual(SECRET_KVP['value'], crypto_val) + self.assertTrue(get_resp.json["encrypted"]) + crypto_val = get_resp.json["value"] + self.assertNotEqual(SECRET_KVP["value"], crypto_val) self.__do_delete(self.__get_kvp_id(put_resp)) def test_get_one_secret_no_decrypt(self): - put_resp = self.__do_put('secret_key1', SECRET_KVP) + put_resp = self.__do_put("secret_key1", SECRET_KVP) kvp_id = self.__get_kvp_id(put_resp) - get_resp = self.app.get('/v1/keys/secret_key1') + get_resp = self.app.get("/v1/keys/secret_key1") self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_kvp_id(get_resp), kvp_id) - self.assertTrue(get_resp.json['secret']) - self.assertTrue(get_resp.json['encrypted']) + self.assertTrue(get_resp.json["secret"]) + self.assertTrue(get_resp.json["encrypted"]) self.__do_delete(kvp_id) def test_get_one_secret_decrypt(self): - put_resp = self.__do_put('secret_key1', SECRET_KVP) + put_resp = self.__do_put("secret_key1", SECRET_KVP) kvp_id = self.__get_kvp_id(put_resp) - get_resp = self.app.get('/v1/keys/secret_key1?decrypt=true') + get_resp = self.app.get("/v1/keys/secret_key1?decrypt=true") self.assertEqual(get_resp.status_int, 200) self.assertEqual(self.__get_kvp_id(get_resp), kvp_id) - self.assertTrue(get_resp.json['secret']) - self.assertFalse(get_resp.json['encrypted']) - self.assertEqual(get_resp.json['value'], SECRET_KVP['value']) + self.assertTrue(get_resp.json["secret"]) + self.assertFalse(get_resp.json["encrypted"]) + self.assertEqual(get_resp.json["value"], SECRET_KVP["value"]) self.__do_delete(kvp_id) def test_get_all_decrypt(self): - put_resp = self.__do_put('secret_key1', SECRET_KVP) + put_resp = self.__do_put("secret_key1", SECRET_KVP) kvp_id_1 = self.__get_kvp_id(put_resp) - put_resp = self.__do_put('key1', KVP) + put_resp = self.__do_put("key1", KVP) kvp_id_2 = self.__get_kvp_id(put_resp) - kvps = {'key1': KVP, 'secret_key1': SECRET_KVP} - stored_kvps = self.app.get('/v1/keys?decrypt=true').json + kvps = {"key1": KVP, "secret_key1": SECRET_KVP} + stored_kvps = self.app.get("/v1/keys?decrypt=true").json self.assertTrue(len(stored_kvps), 2) for stored_kvp in stored_kvps: - self.assertFalse(stored_kvp['encrypted']) - exp_kvp = kvps.get(stored_kvp['name']) + self.assertFalse(stored_kvp["encrypted"]) + exp_kvp = kvps.get(stored_kvp["name"]) self.assertIsNotNone(exp_kvp) - self.assertEqual(exp_kvp['value'], stored_kvp['value']) + self.assertEqual(exp_kvp["value"], stored_kvp["value"]) self.__do_delete(kvp_id_1) self.__do_delete(kvp_id_2) def test_put_encrypted_value(self): # 1. encrypted=True, secret=True - put_resp = self.__do_put('secret_key1', ENCRYPTED_KVP) + put_resp = self.__do_put("secret_key1", ENCRYPTED_KVP) kvp_id = self.__get_kvp_id(put_resp) # Verify there is no secrets leakage self.assertEqual(put_resp.status_code, 200) - self.assertEqual(put_resp.json['name'], 'secret_key1') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') - self.assertEqual(put_resp.json['encrypted'], True) - self.assertEqual(put_resp.json['secret'], True) - self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value']) - self.assertTrue(put_resp.json['value'] != 'S3cret!Value') - self.assertTrue(len(put_resp.json['value']) > len('S3cret!Value') * 2) - - get_resp = self.__do_get_one(kvp_id + '?decrypt=True') - self.assertEqual(put_resp.json['name'], 'secret_key1') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') - self.assertEqual(put_resp.json['encrypted'], True) - self.assertEqual(put_resp.json['secret'], True) - self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value']) + self.assertEqual(put_resp.json["name"], "secret_key1") + self.assertEqual(put_resp.json["scope"], "st2kv.system") + self.assertEqual(put_resp.json["encrypted"], True) + self.assertEqual(put_resp.json["secret"], True) + self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"]) + self.assertTrue(put_resp.json["value"] != "S3cret!Value") + self.assertTrue(len(put_resp.json["value"]) > len("S3cret!Value") * 2) + + get_resp = self.__do_get_one(kvp_id + "?decrypt=True") + self.assertEqual(put_resp.json["name"], "secret_key1") + self.assertEqual(put_resp.json["scope"], "st2kv.system") + self.assertEqual(put_resp.json["encrypted"], True) + self.assertEqual(put_resp.json["secret"], True) + self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"]) # Verify data integrity post decryption - get_resp = self.__do_get_one(kvp_id + '?decrypt=True') - self.assertFalse(get_resp.json['encrypted']) - self.assertEqual(get_resp.json['value'], 'S3cret!Value') + get_resp = self.__do_get_one(kvp_id + "?decrypt=True") + self.assertFalse(get_resp.json["encrypted"]) + self.assertEqual(get_resp.json["value"], "S3cret!Value") self.__do_delete(self.__get_kvp_id(put_resp)) # 2. encrypted=True, secret=False # encrypted should always imply secret=True - put_resp = self.__do_put('secret_key2', ENCRYPTED_KVP_SECRET_FALSE) + put_resp = self.__do_put("secret_key2", ENCRYPTED_KVP_SECRET_FALSE) kvp_id = self.__get_kvp_id(put_resp) # Verify there is no secrets leakage self.assertEqual(put_resp.status_code, 200) - self.assertEqual(put_resp.json['name'], 'secret_key2') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') - self.assertEqual(put_resp.json['encrypted'], True) - self.assertEqual(put_resp.json['secret'], True) - self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value']) - self.assertTrue(put_resp.json['value'] != 'S3cret!Value') - self.assertTrue(len(put_resp.json['value']) > len('S3cret!Value') * 2) - - get_resp = self.__do_get_one(kvp_id + '?decrypt=True') - self.assertEqual(put_resp.json['name'], 'secret_key2') - self.assertEqual(put_resp.json['scope'], 'st2kv.system') - self.assertEqual(put_resp.json['encrypted'], True) - self.assertEqual(put_resp.json['secret'], True) - self.assertEqual(put_resp.json['value'], ENCRYPTED_KVP['value']) + self.assertEqual(put_resp.json["name"], "secret_key2") + self.assertEqual(put_resp.json["scope"], "st2kv.system") + self.assertEqual(put_resp.json["encrypted"], True) + self.assertEqual(put_resp.json["secret"], True) + self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"]) + self.assertTrue(put_resp.json["value"] != "S3cret!Value") + self.assertTrue(len(put_resp.json["value"]) > len("S3cret!Value") * 2) + + get_resp = self.__do_get_one(kvp_id + "?decrypt=True") + self.assertEqual(put_resp.json["name"], "secret_key2") + self.assertEqual(put_resp.json["scope"], "st2kv.system") + self.assertEqual(put_resp.json["encrypted"], True) + self.assertEqual(put_resp.json["secret"], True) + self.assertEqual(put_resp.json["value"], ENCRYPTED_KVP["value"]) # Verify data integrity post decryption - get_resp = self.__do_get_one(kvp_id + '?decrypt=True') - self.assertFalse(get_resp.json['encrypted']) - self.assertEqual(get_resp.json['value'], 'S3cret!Value') + get_resp = self.__do_get_one(kvp_id + "?decrypt=True") + self.assertFalse(get_resp.json["encrypted"]) + self.assertEqual(get_resp.json["value"], "S3cret!Value") self.__do_delete(self.__get_kvp_id(put_resp)) def test_put_encrypted_value_integrity_check_failed(self): data = copy.deepcopy(ENCRYPTED_KVP) - data['value'] = 'corrupted' - put_resp = self.__do_put('secret_key1', data, expect_errors=True) + data["value"] = "corrupted" + put_resp = self.__do_put("secret_key1", data, expect_errors=True) self.assertEqual(put_resp.status_code, 400) - expected_error = ('Failed to verify the integrity of the provided value for key ' - '"secret_key1".') - self.assertIn(expected_error, put_resp.json['faultstring']) + expected_error = ( + "Failed to verify the integrity of the provided value for key " + '"secret_key1".' + ) + self.assertIn(expected_error, put_resp.json["faultstring"]) data = copy.deepcopy(ENCRYPTED_KVP) - data['value'] = str(data['value'][:-2]) - put_resp = self.__do_put('secret_key1', data, expect_errors=True) + data["value"] = str(data["value"][:-2]) + put_resp = self.__do_put("secret_key1", data, expect_errors=True) self.assertEqual(put_resp.status_code, 400) - expected_error = ('Failed to verify the integrity of the provided value for key ' - '"secret_key1".') - self.assertIn(expected_error, put_resp.json['faultstring']) + expected_error = ( + "Failed to verify the integrity of the provided value for key " + '"secret_key1".' + ) + self.assertIn(expected_error, put_resp.json["faultstring"]) def test_put_delete(self): - put_resp = self.__do_put('key1', KVP) + put_resp = self.__do_put("key1", KVP) self.assertEqual(put_resp.status_int, 200) self.__do_delete(self.__get_kvp_id(put_resp)) def test_delete(self): - put_resp = self.__do_put('key1', KVP) + put_resp = self.__do_put("key1", KVP) del_resp = self.__do_delete(self.__get_kvp_id(put_resp)) self.assertEqual(del_resp.status_int, 204) def test_delete_fail(self): - resp = self.__do_delete('inexistentkey', expect_errors=True) + resp = self.__do_delete("inexistentkey", expect_errors=True) self.assertEqual(resp.status_int, 404) @staticmethod def __get_kvp_id(resp): - return resp.json['name'] + return resp.json["name"] def __do_get_one(self, kvp_id, expect_errors=False): - return self.app.get('/v1/keys/%s' % kvp_id, expect_errors=expect_errors) + return self.app.get("/v1/keys/%s" % kvp_id, expect_errors=expect_errors) def __do_put(self, kvp_id, kvp, expect_errors=False): - return self.app.put_json('/v1/keys/%s' % kvp_id, kvp, expect_errors=expect_errors) + return self.app.put_json( + "/v1/keys/%s" % kvp_id, kvp, expect_errors=expect_errors + ) def __do_delete(self, kvp_id, expect_errors=False): - return self.app.delete('/v1/keys/%s' % kvp_id, expect_errors=expect_errors) + return self.app.delete("/v1/keys/%s" % kvp_id, expect_errors=expect_errors) diff --git a/st2api/tests/unit/controllers/v1/test_pack_config_schema.py b/st2api/tests/unit/controllers/v1/test_pack_config_schema.py index bff5935e386..a38c278f077 100644 --- a/st2api/tests/unit/controllers/v1/test_pack_config_schema.py +++ b/st2api/tests/unit/controllers/v1/test_pack_config_schema.py @@ -19,12 +19,10 @@ from st2tests.fixturesloader import get_fixtures_packs_base_path -__all__ = [ - 'PackConfigSchemasControllerTestCase' -] +__all__ = ["PackConfigSchemasControllerTestCase"] PACKS_PATH = get_fixtures_packs_base_path() -CONFIG_SCHEMA_COUNT = len(glob.glob('%s/*/config.schema.yaml' % (PACKS_PATH))) +CONFIG_SCHEMA_COUNT = len(glob.glob("%s/*/config.schema.yaml" % (PACKS_PATH))) assert CONFIG_SCHEMA_COUNT > 1 @@ -32,29 +30,34 @@ class PackConfigSchemasControllerTestCase(FunctionalTest): register_packs = True def test_get_all(self): - resp = self.app.get('/v1/config_schemas') + resp = self.app.get("/v1/config_schemas") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), CONFIG_SCHEMA_COUNT, - '/v1/config_schemas did not return all schemas.') + self.assertEqual( + len(resp.json), + CONFIG_SCHEMA_COUNT, + "/v1/config_schemas did not return all schemas.", + ) def test_get_one_success(self): - resp = self.app.get('/v1/config_schemas/dummy_pack_1') + resp = self.app.get("/v1/config_schemas/dummy_pack_1") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['pack'], 'dummy_pack_1') - self.assertIn('api_key', resp.json['attributes']) + self.assertEqual(resp.json["pack"], "dummy_pack_1") + self.assertIn("api_key", resp.json["attributes"]) def test_get_one_doesnt_exist(self): # Pack exists, schema doesnt - resp = self.app.get('/v1/config_schemas/dummy_pack_2', - expect_errors=True) + resp = self.app.get("/v1/config_schemas/dummy_pack_2", expect_errors=True) self.assertEqual(resp.status_int, 404) - self.assertIn('Unable to identify resource with pack_ref ', resp.json['faultstring']) + self.assertIn( + "Unable to identify resource with pack_ref ", resp.json["faultstring"] + ) # Pack doesn't exist - ref_or_id = 'pack_doesnt_exist' - resp = self.app.get('/v1/config_schemas/%s' % ref_or_id, - expect_errors=True) + ref_or_id = "pack_doesnt_exist" + resp = self.app.get("/v1/config_schemas/%s" % ref_or_id, expect_errors=True) self.assertEqual(resp.status_int, 404) # Changed from: 'Unable to find the PackDB instance' - self.assertTrue('Resource with a ref or id "%s" not found' % ref_or_id in - resp.json['faultstring']) + self.assertTrue( + 'Resource with a ref or id "%s" not found' % ref_or_id + in resp.json["faultstring"] + ) diff --git a/st2api/tests/unit/controllers/v1/test_pack_configs.py b/st2api/tests/unit/controllers/v1/test_pack_configs.py index 6e789c413ab..5a87719eaab 100644 --- a/st2api/tests/unit/controllers/v1/test_pack_configs.py +++ b/st2api/tests/unit/controllers/v1/test_pack_configs.py @@ -21,12 +21,10 @@ from st2api.controllers.v1.pack_configs import PackConfigsController from st2tests.fixturesloader import get_fixtures_packs_base_path -__all__ = [ - 'PackConfigsControllerTestCase' -] +__all__ = ["PackConfigsControllerTestCase"] PACKS_PATH = get_fixtures_packs_base_path() -CONFIGS_COUNT = len(glob.glob('%s/configs/*.yaml' % (PACKS_PATH))) +CONFIGS_COUNT = len(glob.glob("%s/configs/*.yaml" % (PACKS_PATH))) assert CONFIGS_COUNT > 1 @@ -35,60 +33,80 @@ class PackConfigsControllerTestCase(FunctionalTest): register_pack_configs = True def test_get_all(self): - resp = self.app.get('/v1/configs') + resp = self.app.get("/v1/configs") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), CONFIGS_COUNT, '/v1/configs did not return all configs.') + self.assertEqual( + len(resp.json), CONFIGS_COUNT, "/v1/configs did not return all configs." + ) def test_get_one_success(self): - resp = self.app.get('/v1/configs/dummy_pack_1', params={'show_secrets': True}, - expect_errors=True) + resp = self.app.get( + "/v1/configs/dummy_pack_1", + params={"show_secrets": True}, + expect_errors=True, + ) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['pack'], 'dummy_pack_1') - self.assertEqual(resp.json['values']['api_key'], '{{st2kv.user.api_key}}') - self.assertEqual(resp.json['values']['region'], 'us-west-1') + self.assertEqual(resp.json["pack"], "dummy_pack_1") + self.assertEqual(resp.json["values"]["api_key"], "{{st2kv.user.api_key}}") + self.assertEqual(resp.json["values"]["region"], "us-west-1") def test_get_one_mask_secret(self): - resp = self.app.get('/v1/configs/dummy_pack_1') + resp = self.app.get("/v1/configs/dummy_pack_1") self.assertEqual(resp.status_int, 200) - self.assertNotEqual(resp.json['values']['api_key'], '{{st2kv.user.api_key}}') + self.assertNotEqual(resp.json["values"]["api_key"], "{{st2kv.user.api_key}}") def test_get_one_pack_config_doesnt_exist(self): # Pack exists, config doesnt - resp = self.app.get('/v1/configs/dummy_pack_2', - expect_errors=True) + resp = self.app.get("/v1/configs/dummy_pack_2", expect_errors=True) self.assertEqual(resp.status_int, 404) - self.assertIn('Unable to identify resource with pack_ref ', resp.json['faultstring']) + self.assertIn( + "Unable to identify resource with pack_ref ", resp.json["faultstring"] + ) # Pack doesn't exist - resp = self.app.get('/v1/configs/pack_doesnt_exist', - expect_errors=True) + resp = self.app.get("/v1/configs/pack_doesnt_exist", expect_errors=True) self.assertEqual(resp.status_int, 404) # Changed from : 'Unable to find the PackDB instance.' - self.assertIn('Unable to identify resource with pack_ref', resp.json['faultstring']) + self.assertIn( + "Unable to identify resource with pack_ref", resp.json["faultstring"] + ) - @mock.patch.object(PackConfigsController, '_dump_config_to_disk', mock.MagicMock()) + @mock.patch.object(PackConfigsController, "_dump_config_to_disk", mock.MagicMock()) def test_put_pack_config(self): - get_resp = self.app.get('/v1/configs/dummy_pack_1', params={'show_secrets': True}, - expect_errors=True) - config = copy.copy(get_resp.json['values']) - config['region'] = 'us-west-2' + get_resp = self.app.get( + "/v1/configs/dummy_pack_1", + params={"show_secrets": True}, + expect_errors=True, + ) + config = copy.copy(get_resp.json["values"]) + config["region"] = "us-west-2" - put_resp = self.app.put_json('/v1/configs/dummy_pack_1', config) + put_resp = self.app.put_json("/v1/configs/dummy_pack_1", config) self.assertEqual(put_resp.status_int, 200) - put_resp_undo = self.app.put_json('/v1/configs/dummy_pack_1?show_secrets=true', - get_resp.json['values'], expect_errors=True) + put_resp_undo = self.app.put_json( + "/v1/configs/dummy_pack_1?show_secrets=true", + get_resp.json["values"], + expect_errors=True, + ) self.assertEqual(put_resp.status_int, 200) self.assertEqual(get_resp.json, put_resp_undo.json) - @mock.patch.object(PackConfigsController, '_dump_config_to_disk', mock.MagicMock()) + @mock.patch.object(PackConfigsController, "_dump_config_to_disk", mock.MagicMock()) def test_put_invalid_pack_config(self): - get_resp = self.app.get('/v1/configs/dummy_pack_11', params={'show_secrets': True}, - expect_errors=True) - config = copy.copy(get_resp.json['values']) - put_resp = self.app.put_json('/v1/configs/dummy_pack_11', config, expect_errors=True) + get_resp = self.app.get( + "/v1/configs/dummy_pack_11", + params={"show_secrets": True}, + expect_errors=True, + ) + config = copy.copy(get_resp.json["values"]) + put_resp = self.app.put_json( + "/v1/configs/dummy_pack_11", config, expect_errors=True + ) self.assertEqual(put_resp.status_int, 400) - expected_msg = ('Values specified as "secret: True" in config schema are automatically ' - 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' - 'for such values. Please check the specified values in the config or ' - 'the default values in the schema.') - self.assertIn(expected_msg, put_resp.json['faultstring']) + expected_msg = ( + 'Values specified as "secret: True" in config schema are automatically ' + 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' + "for such values. Please check the specified values in the config or " + "the default values in the schema." + ) + self.assertIn(expected_msg, put_resp.json["faultstring"]) diff --git a/st2api/tests/unit/controllers/v1/test_packs.py b/st2api/tests/unit/controllers/v1/test_packs.py index 9406a50af05..07cacd0be8d 100644 --- a/st2api/tests/unit/controllers/v1/test_packs.py +++ b/st2api/tests/unit/controllers/v1/test_packs.py @@ -33,9 +33,7 @@ from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'PacksControllerTestCase' -] +__all__ = ["PacksControllerTestCase"] PACK_INDEX = { "test": { @@ -45,7 +43,7 @@ "author": "st2-dev", "keywords": ["some", "search", "another", "terms"], "email": "info@stackstorm.com", - "description": "st2 pack to test package management pipeline" + "description": "st2 pack to test package management pipeline", }, "test2": { "version": "0.5.0", @@ -54,13 +52,13 @@ "author": "stanley", "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", - "description": "another st2 pack to test package management pipeline" - } + "description": "another st2 pack to test package management pipeline", + }, } PACK_INDEXES = { - 'http://main.example.com': PACK_INDEX, - 'http://fallback.example.com': { + "http://main.example.com": PACK_INDEX, + "http://fallback.example.com": { "test": { "version": "0.1.0", "name": "test", @@ -68,10 +66,10 @@ "author": "st2-dev", "keywords": ["some", "search", "another", "terms"], "email": "info@stackstorm.com", - "description": "st2 pack to test package management pipeline" + "description": "st2 pack to test package management pipeline", } }, - 'http://override.example.com': { + "http://override.example.com": { "test2": { "version": "1.0.0", "name": "test2", @@ -79,10 +77,12 @@ "author": "stanley", "keywords": ["some", "special", "terms"], "email": "info@stackstorm.com", - "description": "another st2 pack to test package management pipeline" + "description": "another st2 pack to test package management pipeline", } }, - 'http://broken.example.com': requests.exceptions.RequestException('index is broken') + "http://broken.example.com": requests.exceptions.RequestException( + "index is broken" + ), } @@ -93,10 +93,7 @@ def mock_index_get(url, *args, **kwargs): raise index status = 200 - content = { - 'metadata': {}, - 'packs': index - } + content = {"metadata": {}, "packs": index} # Return mock response object @@ -104,311 +101,371 @@ def mock_index_get(url, *args, **kwargs): mock_resp.raise_for_status = mock.Mock() mock_resp.status_code = status mock_resp.content = content - mock_resp.json = mock.Mock( - return_value=content - ) + mock_resp.json = mock.Mock(return_value=content) return mock_resp -class PacksControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/packs' +class PacksControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/packs" controller_cls = PacksController - include_attribute_field_name = 'version' - exclude_attribute_field_name = 'author' + include_attribute_field_name = "version" + exclude_attribute_field_name = "author" @classmethod def setUpClass(cls): super(PacksControllerTestCase, cls).setUpClass() - cls.pack_db_1 = PackDB(name='pack1', description='foo', version='0.1.0', author='foo', - email='test@example.com', ref='pack1') - cls.pack_db_2 = PackDB(name='pack2', description='foo', version='0.1.0', author='foo', - email='test@example.com', ref='pack2') - cls.pack_db_3 = PackDB(name='pack3-name', ref='pack3-ref', description='foo', - version='0.1.0', author='foo', - email='test@example.com') + cls.pack_db_1 = PackDB( + name="pack1", + description="foo", + version="0.1.0", + author="foo", + email="test@example.com", + ref="pack1", + ) + cls.pack_db_2 = PackDB( + name="pack2", + description="foo", + version="0.1.0", + author="foo", + email="test@example.com", + ref="pack2", + ) + cls.pack_db_3 = PackDB( + name="pack3-name", + ref="pack3-ref", + description="foo", + version="0.1.0", + author="foo", + email="test@example.com", + ) Pack.add_or_update(cls.pack_db_1) Pack.add_or_update(cls.pack_db_2) Pack.add_or_update(cls.pack_db_3) def test_get_all(self): - resp = self.app.get('/v1/packs') + resp = self.app.get("/v1/packs") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 3, '/v1/actionalias did not return all packs.') + self.assertEqual(len(resp.json), 3, "/v1/actionalias did not return all packs.") def test_get_one(self): # Get by id - resp = self.app.get('/v1/packs/%s' % (self.pack_db_1.id)) + resp = self.app.get("/v1/packs/%s" % (self.pack_db_1.id)) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['name'], self.pack_db_1.name) + self.assertEqual(resp.json["name"], self.pack_db_1.name) # Get by name - resp = self.app.get('/v1/packs/%s' % (self.pack_db_1.ref)) + resp = self.app.get("/v1/packs/%s" % (self.pack_db_1.ref)) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['ref'], self.pack_db_1.ref) - self.assertEqual(resp.json['name'], self.pack_db_1.name) + self.assertEqual(resp.json["ref"], self.pack_db_1.ref) + self.assertEqual(resp.json["name"], self.pack_db_1.name) # Get by ref (ref != name) - resp = self.app.get('/v1/packs/%s' % (self.pack_db_3.ref)) + resp = self.app.get("/v1/packs/%s" % (self.pack_db_3.ref)) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['ref'], self.pack_db_3.ref) + self.assertEqual(resp.json["ref"], self.pack_db_3.ref) def test_get_one_doesnt_exist(self): - resp = self.app.get('/v1/packs/doesntexistfoo', expect_errors=True) + resp = self.app.get("/v1/packs/doesntexistfoo", expect_errors=True) self.assertEqual(resp.status_int, 404) - @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution') + @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution") def test_install(self, _handle_schedule_execution): - _handle_schedule_execution.return_value = Response(json={'id': '123'}) - payload = {'packs': ['some']} + _handle_schedule_execution.return_value = Response(json={"id": "123"}) + payload = {"packs": ["some"]} - resp = self.app.post_json('/v1/packs/install', payload) + resp = self.app.post_json("/v1/packs/install", payload) self.assertEqual(resp.status_int, 202) - self.assertEqual(resp.json, {'execution_id': '123'}) + self.assertEqual(resp.json, {"execution_id": "123"}) - @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution') + @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution") def test_install_with_force_parameter(self, _handle_schedule_execution): - _handle_schedule_execution.return_value = Response(json={'id': '123'}) - payload = {'packs': ['some'], 'force': True} + _handle_schedule_execution.return_value = Response(json={"id": "123"}) + payload = {"packs": ["some"], "force": True} - resp = self.app.post_json('/v1/packs/install', payload) + resp = self.app.post_json("/v1/packs/install", payload) self.assertEqual(resp.status_int, 202) - self.assertEqual(resp.json, {'execution_id': '123'}) + self.assertEqual(resp.json, {"execution_id": "123"}) - @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution') + @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution") def test_install_with_skip_dependencies_parameter(self, _handle_schedule_execution): - _handle_schedule_execution.return_value = Response(json={'id': '123'}) - payload = {'packs': ['some'], 'skip_dependencies': True} + _handle_schedule_execution.return_value = Response(json={"id": "123"}) + payload = {"packs": ["some"], "skip_dependencies": True} - resp = self.app.post_json('/v1/packs/install', payload) + resp = self.app.post_json("/v1/packs/install", payload) self.assertEqual(resp.status_int, 202) - self.assertEqual(resp.json, {'execution_id': '123'}) + self.assertEqual(resp.json, {"execution_id": "123"}) - @mock.patch.object(ActionExecutionsControllerMixin, '_handle_schedule_execution') + @mock.patch.object(ActionExecutionsControllerMixin, "_handle_schedule_execution") def test_uninstall(self, _handle_schedule_execution): - _handle_schedule_execution.return_value = Response(json={'id': '123'}) - payload = {'packs': ['some']} + _handle_schedule_execution.return_value = Response(json={"id": "123"}) + payload = {"packs": ["some"]} - resp = self.app.post_json('/v1/packs/uninstall', payload) + resp = self.app.post_json("/v1/packs/uninstall", payload) self.assertEqual(resp.status_int, 202) - self.assertEqual(resp.json, {'execution_id': '123'}) + self.assertEqual(resp.json, {"execution_id": "123"}) - @mock.patch.object(pack_service, 'fetch_pack_index', - mock.MagicMock(return_value=(PACK_INDEX, {}))) + @mock.patch.object( + pack_service, "fetch_pack_index", mock.MagicMock(return_value=(PACK_INDEX, {})) + ) def test_search_with_query(self): test_scenarios = [ { - 'input': {'query': 'test'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test'], PACK_INDEX['test2']] + "input": {"query": "test"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test"], PACK_INDEX["test2"]], }, { - 'input': {'query': 'stanley'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test2']] + "input": {"query": "stanley"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test2"]], }, { - 'input': {'query': 'special'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test2']] + "input": {"query": "special"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test2"]], }, { - 'input': {'query': 'TEST'}, # Search should be case insensitive by default - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test'], PACK_INDEX['test2']] + "input": { + "query": "TEST" + }, # Search should be case insensitive by default + "expected_code": 200, + "expected_result": [PACK_INDEX["test"], PACK_INDEX["test2"]], }, { - 'input': {'query': 'SPECIAL'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test2']] + "input": {"query": "SPECIAL"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test2"]], }, { - 'input': {'query': 'sPeCiAL'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test2']] + "input": {"query": "sPeCiAL"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test2"]], }, { - 'input': {'query': 'st2-dev'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test']] + "input": {"query": "st2-dev"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test"]], }, { - 'input': {'query': 'ST2-dev'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test']] + "input": {"query": "ST2-dev"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test"]], }, { - 'input': {'query': '-dev'}, - 'expected_code': 200, - 'expected_result': [PACK_INDEX['test']] + "input": {"query": "-dev"}, + "expected_code": 200, + "expected_result": [PACK_INDEX["test"]], }, - { - 'input': {'query': 'core'}, - 'expected_code': 200, - 'expected_result': [] - } + {"input": {"query": "core"}, "expected_code": 200, "expected_result": []}, ] for scenario in test_scenarios: - resp = self.app.post_json('/v1/packs/index/search', scenario['input']) - self.assertEqual(resp.status_int, scenario['expected_code']) - self.assertEqual(resp.json, scenario['expected_result']) - - @mock.patch.object(pack_service, 'get_pack_from_index', - mock.MagicMock(return_value=PACK_INDEX['test'])) + resp = self.app.post_json("/v1/packs/index/search", scenario["input"]) + self.assertEqual(resp.status_int, scenario["expected_code"]) + self.assertEqual(resp.json, scenario["expected_result"]) + + @mock.patch.object( + pack_service, + "get_pack_from_index", + mock.MagicMock(return_value=PACK_INDEX["test"]), + ) def test_search_with_pack_has_result(self): - resp = self.app.post_json('/v1/packs/index/search', {'pack': 'st2-dev'}) + resp = self.app.post_json("/v1/packs/index/search", {"pack": "st2-dev"}) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, PACK_INDEX['test']) + self.assertEqual(resp.json, PACK_INDEX["test"]) - @mock.patch.object(pack_service, 'get_pack_from_index', - mock.MagicMock(return_value=None)) + @mock.patch.object( + pack_service, "get_pack_from_index", mock.MagicMock(return_value=None) + ) def test_search_with_pack_no_result(self): - resp = self.app.post_json('/v1/packs/index/search', {'pack': 'not-found'}) + resp = self.app.post_json("/v1/packs/index/search", {"pack": "not-found"}) self.assertEqual(resp.status_int, 200) self.assertEqual(resp.json, []) - @mock.patch.object(pack_service, 'fetch_pack_index', - mock.MagicMock(return_value=(PACK_INDEX, {}))) + @mock.patch.object( + pack_service, "fetch_pack_index", mock.MagicMock(return_value=(PACK_INDEX, {})) + ) def test_show(self): - resp = self.app.post_json('/v1/packs/index/search', {'pack': 'test'}) + resp = self.app.post_json("/v1/packs/index/search", {"pack": "test"}) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, PACK_INDEX['test']) + self.assertEqual(resp.json, PACK_INDEX["test"]) - resp = self.app.post_json('/v1/packs/index/search', {'pack': 'test2'}) + resp = self.app.post_json("/v1/packs/index/search", {"pack": "test2"}) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, PACK_INDEX['test2']) + self.assertEqual(resp.json, PACK_INDEX["test2"]) - @mock.patch.object(pack_service, '_build_index_list', - mock.MagicMock(return_value=['http://main.example.com'])) - @mock.patch.object(requests, 'get', mock_index_get) + @mock.patch.object( + pack_service, + "_build_index_list", + mock.MagicMock(return_value=["http://main.example.com"]), + ) + @mock.patch.object(requests, "get", mock_index_get) def test_index_health(self): - resp = self.app.get('/v1/packs/index/health') + resp = self.app.get("/v1/packs/index/health") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'packs': { - 'count': 2 + self.assertEqual( + resp.json, + { + "packs": {"count": 2}, + "indexes": { + "count": 1, + "status": [ + { + "url": "http://main.example.com", + "message": "Success.", + "packs": 2, + "error": None, + } + ], + "valid": 1, + "errors": {}, + "invalid": 0, + }, }, - 'indexes': { - 'count': 1, - 'status': [{ - 'url': 'http://main.example.com', - 'message': 'Success.', - 'packs': 2, - 'error': None - }], - 'valid': 1, - 'errors': {}, - 'invalid': 0 - } - }) - - @mock.patch.object(pack_service, '_build_index_list', - mock.MagicMock(return_value=['http://main.example.com', - 'http://broken.example.com'])) - @mock.patch.object(requests, 'get', mock_index_get) + ) + + @mock.patch.object( + pack_service, + "_build_index_list", + mock.MagicMock( + return_value=["http://main.example.com", "http://broken.example.com"] + ), + ) + @mock.patch.object(requests, "get", mock_index_get) def test_index_health_broken(self): - resp = self.app.get('/v1/packs/index/health') + resp = self.app.get("/v1/packs/index/health") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'packs': { - 'count': 2 - }, - 'indexes': { - 'count': 2, - 'status': [{ - 'url': 'http://main.example.com', - 'message': 'Success.', - 'packs': 2, - 'error': None - }, { - 'url': 'http://broken.example.com', - 'message': "RequestException('index is broken',)", - 'packs': 0, - 'error': 'unresponsive' - }], - 'valid': 1, - 'errors': { - 'unresponsive': 1 + self.assertEqual( + resp.json, + { + "packs": {"count": 2}, + "indexes": { + "count": 2, + "status": [ + { + "url": "http://main.example.com", + "message": "Success.", + "packs": 2, + "error": None, + }, + { + "url": "http://broken.example.com", + "message": "RequestException('index is broken',)", + "packs": 0, + "error": "unresponsive", + }, + ], + "valid": 1, + "errors": {"unresponsive": 1}, + "invalid": 1, }, - 'invalid': 1 - } - }) + }, + ) - @mock.patch.object(pack_service, '_build_index_list', - mock.MagicMock(return_value=['http://main.example.com'])) - @mock.patch.object(requests, 'get', mock_index_get) + @mock.patch.object( + pack_service, + "_build_index_list", + mock.MagicMock(return_value=["http://main.example.com"]), + ) + @mock.patch.object(requests, "get", mock_index_get) def test_index(self): - resp = self.app.get('/v1/packs/index') + resp = self.app.get("/v1/packs/index") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'status': [{ - 'url': 'http://main.example.com', - 'message': 'Success.', - 'packs': 2, - 'error': None - }], - 'index': PACK_INDEX - }) - - @mock.patch.object(pack_service, '_build_index_list', - mock.MagicMock(return_value=['http://fallback.example.com', - 'http://main.example.com'])) - @mock.patch.object(requests, 'get', mock_index_get) + self.assertEqual( + resp.json, + { + "status": [ + { + "url": "http://main.example.com", + "message": "Success.", + "packs": 2, + "error": None, + } + ], + "index": PACK_INDEX, + }, + ) + + @mock.patch.object( + pack_service, + "_build_index_list", + mock.MagicMock( + return_value=["http://fallback.example.com", "http://main.example.com"] + ), + ) + @mock.patch.object(requests, "get", mock_index_get) def test_index_fallback(self): - resp = self.app.get('/v1/packs/index') + resp = self.app.get("/v1/packs/index") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'status': [{ - 'url': 'http://fallback.example.com', - 'message': 'Success.', - 'packs': 1, - 'error': None - }, { - 'url': 'http://main.example.com', - 'message': 'Success.', - 'packs': 2, - 'error': None - }], - 'index': PACK_INDEX - }) - - @mock.patch.object(pack_service, '_build_index_list', - mock.MagicMock(return_value=['http://main.example.com', - 'http://override.example.com'])) - @mock.patch.object(requests, 'get', mock_index_get) + self.assertEqual( + resp.json, + { + "status": [ + { + "url": "http://fallback.example.com", + "message": "Success.", + "packs": 1, + "error": None, + }, + { + "url": "http://main.example.com", + "message": "Success.", + "packs": 2, + "error": None, + }, + ], + "index": PACK_INDEX, + }, + ) + + @mock.patch.object( + pack_service, + "_build_index_list", + mock.MagicMock( + return_value=["http://main.example.com", "http://override.example.com"] + ), + ) + @mock.patch.object(requests, "get", mock_index_get) def test_index_override(self): - resp = self.app.get('/v1/packs/index') + resp = self.app.get("/v1/packs/index") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'status': [{ - 'url': 'http://main.example.com', - 'message': 'Success.', - 'packs': 2, - 'error': None - }, { - 'url': 'http://override.example.com', - 'message': 'Success.', - 'packs': 1, - 'error': None - }], - 'index': { - 'test': PACK_INDEX['test'], - 'test2': PACK_INDEXES['http://override.example.com']['test2'] - } - }) + self.assertEqual( + resp.json, + { + "status": [ + { + "url": "http://main.example.com", + "message": "Success.", + "packs": 2, + "error": None, + }, + { + "url": "http://override.example.com", + "message": "Success.", + "packs": 1, + "error": None, + }, + ], + "index": { + "test": PACK_INDEX["test"], + "test2": PACK_INDEXES["http://override.example.com"]["test2"], + }, + }, + ) def test_packs_register_endpoint_resource_register_order(self): # Verify that resources are registered in the same order as they are inside @@ -416,17 +473,17 @@ def test_packs_register_endpoint_resource_register_order(self): # Note: Sadly there is no easier / better way to test this resource_types = list(ENTITIES.keys()) expected_order = [ - 'trigger', - 'sensor', - 'action', - 'rule', - 'alias', - 'policy', - 'config' + "trigger", + "sensor", + "action", + "rule", + "alias", + "policy", + "config", ] self.assertEqual(resource_types, expected_order) - @mock.patch.object(ContentPackLoader, 'get_packs') + @mock.patch.object(ContentPackLoader, "get_packs") def test_packs_register_endpoint(self, mock_get_packs): # Register resources from all packs - make sure the count values are correctly added # together @@ -434,12 +491,12 @@ def test_packs_register_endpoint(self, mock_get_packs): # Note: We only register a couple of packs and not all on disk to speed # things up. Registering all the packs takes a long time. fixtures_base_path = get_fixtures_base_path() - packs_base_path = os.path.join(fixtures_base_path, 'packs') + packs_base_path = os.path.join(fixtures_base_path, "packs") pack_names = [ - 'dummy_pack_1', - 'dummy_pack_2', - 'dummy_pack_3', - 'dummy_pack_10', + "dummy_pack_1", + "dummy_pack_2", + "dummy_pack_3", + "dummy_pack_10", ] mock_return_value = {} for pack_name in pack_names: @@ -447,160 +504,180 @@ def test_packs_register_endpoint(self, mock_get_packs): mock_get_packs.return_value = mock_return_value - resp = self.app.post_json('/v1/packs/register', {'fail_on_failure': False}) + resp = self.app.post_json("/v1/packs/register", {"fail_on_failure": False}) self.assertEqual(resp.status_int, 200) - self.assertIn('runners', resp.json) - self.assertIn('actions', resp.json) - self.assertIn('triggers', resp.json) - self.assertIn('sensors', resp.json) - self.assertIn('rules', resp.json) - self.assertIn('rule_types', resp.json) - self.assertIn('aliases', resp.json) - self.assertIn('policy_types', resp.json) - self.assertIn('policies', resp.json) - self.assertIn('configs', resp.json) - - self.assertTrue(resp.json['actions'] >= 3) - self.assertTrue(resp.json['configs'] >= 1) + self.assertIn("runners", resp.json) + self.assertIn("actions", resp.json) + self.assertIn("triggers", resp.json) + self.assertIn("sensors", resp.json) + self.assertIn("rules", resp.json) + self.assertIn("rule_types", resp.json) + self.assertIn("aliases", resp.json) + self.assertIn("policy_types", resp.json) + self.assertIn("policies", resp.json) + self.assertIn("configs", resp.json) + + self.assertTrue(resp.json["actions"] >= 3) + self.assertTrue(resp.json["configs"] >= 1) # Register resources from a specific pack - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'], - 'fail_on_failure': False}) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["dummy_pack_1"], "fail_on_failure": False} + ) self.assertEqual(resp.status_int, 200) - self.assertTrue(resp.json['actions'] >= 1) - self.assertTrue(resp.json['sensors'] >= 1) - self.assertTrue(resp.json['configs'] >= 1) + self.assertTrue(resp.json["actions"] >= 1) + self.assertTrue(resp.json["sensors"] >= 1) + self.assertTrue(resp.json["configs"] >= 1) # Verify metadata_file attribute is set - action_dbs = Action.query(pack='dummy_pack_1') - self.assertEqual(action_dbs[0].metadata_file, 'actions/my_action.yaml') + action_dbs = Action.query(pack="dummy_pack_1") + self.assertEqual(action_dbs[0].metadata_file, "actions/my_action.yaml") # Register 'all' resource types should try include any possible content for the pack - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'], - 'fail_on_failure': False, - 'types': ['all']}) + resp = self.app.post_json( + "/v1/packs/register", + {"packs": ["dummy_pack_1"], "fail_on_failure": False, "types": ["all"]}, + ) self.assertEqual(resp.status_int, 200) - self.assertIn('runners', resp.json) - self.assertIn('actions', resp.json) - self.assertIn('triggers', resp.json) - self.assertIn('sensors', resp.json) - self.assertIn('rules', resp.json) - self.assertIn('rule_types', resp.json) - self.assertIn('aliases', resp.json) - self.assertIn('policy_types', resp.json) - self.assertIn('policies', resp.json) - self.assertIn('configs', resp.json) + self.assertIn("runners", resp.json) + self.assertIn("actions", resp.json) + self.assertIn("triggers", resp.json) + self.assertIn("sensors", resp.json) + self.assertIn("rules", resp.json) + self.assertIn("rule_types", resp.json) + self.assertIn("aliases", resp.json) + self.assertIn("policy_types", resp.json) + self.assertIn("policies", resp.json) + self.assertIn("configs", resp.json) # Registering single resource type should also cause dependent resources # to be registered # * actions -> runners # * rules -> rule types # * policies -> policy types - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'], - 'fail_on_failure': False, - 'types': ['actions']}) + resp = self.app.post_json( + "/v1/packs/register", + {"packs": ["dummy_pack_1"], "fail_on_failure": False, "types": ["actions"]}, + ) self.assertEqual(resp.status_int, 200) - self.assertTrue(resp.json['runners'] >= 1) - self.assertTrue(resp.json['actions'] >= 1) + self.assertTrue(resp.json["runners"] >= 1) + self.assertTrue(resp.json["actions"] >= 1) - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1'], - 'fail_on_failure': False, - 'types': ['rules']}) + resp = self.app.post_json( + "/v1/packs/register", + {"packs": ["dummy_pack_1"], "fail_on_failure": False, "types": ["rules"]}, + ) self.assertEqual(resp.status_int, 200) - self.assertTrue(resp.json['rule_types'] >= 1) - self.assertTrue(resp.json['rules'] >= 1) + self.assertTrue(resp.json["rule_types"] >= 1) + self.assertTrue(resp.json["rules"] >= 1) - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_2'], - 'fail_on_failure': False, - 'types': ['policies']}) + resp = self.app.post_json( + "/v1/packs/register", + { + "packs": ["dummy_pack_2"], + "fail_on_failure": False, + "types": ["policies"], + }, + ) self.assertEqual(resp.status_int, 200) - self.assertTrue(resp.json['policy_types'] >= 1) - self.assertTrue(resp.json['policies'] >= 0) + self.assertTrue(resp.json["policy_types"] >= 1) + self.assertTrue(resp.json["policies"] >= 0) # Register specific type for all packs - resp = self.app.post_json('/v1/packs/register', {'types': ['sensor'], - 'fail_on_failure': False}) + resp = self.app.post_json( + "/v1/packs/register", {"types": ["sensor"], "fail_on_failure": False} + ) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, {'sensors': 3}) + self.assertEqual(resp.json, {"sensors": 3}) # Verify that plural name form also works - resp = self.app.post_json('/v1/packs/register', {'types': ['sensors'], - 'fail_on_failure': False}) + resp = self.app.post_json( + "/v1/packs/register", {"types": ["sensors"], "fail_on_failure": False} + ) self.assertEqual(resp.status_int, 200) # Register specific type for a single packs - resp = self.app.post_json('/v1/packs/register', - {'packs': ['dummy_pack_1'], 'types': ['action']}) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["dummy_pack_1"], "types": ["action"]} + ) self.assertEqual(resp.status_int, 200) # 13 real plus 1 mock runner - self.assertEqual(resp.json, {'actions': 1, 'runners': 14}) + self.assertEqual(resp.json, {"actions": 1, "runners": 14}) # Verify that plural name form also works - resp = self.app.post_json('/v1/packs/register', - {'packs': ['dummy_pack_1'], 'types': ['actions']}) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["dummy_pack_1"], "types": ["actions"]} + ) self.assertEqual(resp.status_int, 200) # 13 real plus 1 mock runner - self.assertEqual(resp.json, {'actions': 1, 'runners': 14}) + self.assertEqual(resp.json, {"actions": 1, "runners": 14}) # Register single resource from a single pack specified multiple times - verify that # resources from the same pack are only registered once - resp = self.app.post_json('/v1/packs/register', - {'packs': ['dummy_pack_1', 'dummy_pack_1', 'dummy_pack_1'], - 'types': ['actions'], - 'fail_on_failure': False}) + resp = self.app.post_json( + "/v1/packs/register", + { + "packs": ["dummy_pack_1", "dummy_pack_1", "dummy_pack_1"], + "types": ["actions"], + "fail_on_failure": False, + }, + ) self.assertEqual(resp.status_int, 200) # 13 real plus 1 mock runner - self.assertEqual(resp.json, {'actions': 1, 'runners': 14}) + self.assertEqual(resp.json, {"actions": 1, "runners": 14}) # Register resources from a single (non-existent pack) - resp = self.app.post_json('/v1/packs/register', {'packs': ['doesntexist']}, - expect_errors=True) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["doesntexist"]}, expect_errors=True + ) self.assertEqual(resp.status_int, 400) - self.assertIn('Pack "doesntexist" not found on disk:', resp.json['faultstring']) + self.assertIn('Pack "doesntexist" not found on disk:', resp.json["faultstring"]) # Fail on failure is enabled by default - resp = self.app.post_json('/v1/packs/register', expect_errors=True) + resp = self.app.post_json("/v1/packs/register", expect_errors=True) expected_msg = 'Failed to register pack "dummy_pack_10":' self.assertEqual(resp.status_int, 400) - self.assertIn(expected_msg, resp.json['faultstring']) + self.assertIn(expected_msg, resp.json["faultstring"]) # Fail on failure (broken pack metadata) - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_1']}, - expect_errors=True) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["dummy_pack_1"]}, expect_errors=True + ) expected_msg = 'Referenced policy_type "action.mock_policy_error" doesnt exist' self.assertEqual(resp.status_int, 400) - self.assertIn(expected_msg, resp.json['faultstring']) + self.assertIn(expected_msg, resp.json["faultstring"]) # Fail on failure (broken action metadata) - resp = self.app.post_json('/v1/packs/register', {'packs': ['dummy_pack_15']}, - expect_errors=True) + resp = self.app.post_json( + "/v1/packs/register", {"packs": ["dummy_pack_15"]}, expect_errors=True + ) - expected_msg = 'Failed to register action' + expected_msg = "Failed to register action" self.assertEqual(resp.status_int, 400) - self.assertIn(expected_msg, resp.json['faultstring']) + self.assertIn(expected_msg, resp.json["faultstring"]) - expected_msg = '\'stringa\' is not valid under any of the given schemas' + expected_msg = "'stringa' is not valid under any of the given schemas" self.assertEqual(resp.status_int, 400) - self.assertIn(expected_msg, resp.json['faultstring']) + self.assertIn(expected_msg, resp.json["faultstring"]) def test_get_all_invalid_exclude_and_include_parameter(self): pass def _insert_mock_models(self): - return [self.pack_db_1['id'], self.pack_db_2['id'], self.pack_db_3['id']] + return [self.pack_db_1["id"], self.pack_db_2["id"], self.pack_db_3["id"]] def _do_delete(self, object_ids): pass diff --git a/st2api/tests/unit/controllers/v1/test_packs_views.py b/st2api/tests/unit/controllers/v1/test_packs_views.py index 5535a6e22bc..a1b96a4aea1 100644 --- a/st2api/tests/unit/controllers/v1/test_packs_views.py +++ b/st2api/tests/unit/controllers/v1/test_packs_views.py @@ -21,7 +21,7 @@ from st2tests.api import FunctionalTest -@mock.patch('st2common.bootstrap.base.REGISTERED_PACKS_CACHE', {}) +@mock.patch("st2common.bootstrap.base.REGISTERED_PACKS_CACHE", {}) class PacksViewsControllerTestCase(FunctionalTest): @classmethod def setUpClass(cls): @@ -31,32 +31,34 @@ def setUpClass(cls): actions_registrar.register_actions(use_pack_cache=False) def test_get_pack_files_success(self): - resp = self.app.get('/v1/packs/views/files/dummy_pack_1') + resp = self.app.get("/v1/packs/views/files/dummy_pack_1") self.assertEqual(resp.status_int, http_client.OK) self.assertTrue(len(resp.json) > 1) - item = [_item for _item in resp.json if _item['file_path'] == 'pack.yaml'][0] - self.assertEqual(item['file_path'], 'pack.yaml') - item = [_item for _item in resp.json if _item['file_path'] == 'actions/my_action.py'][0] - self.assertEqual(item['file_path'], 'actions/my_action.py') + item = [_item for _item in resp.json if _item["file_path"] == "pack.yaml"][0] + self.assertEqual(item["file_path"], "pack.yaml") + item = [ + _item for _item in resp.json if _item["file_path"] == "actions/my_action.py" + ][0] + self.assertEqual(item["file_path"], "actions/my_action.py") def test_get_pack_files_pack_doesnt_exist(self): - resp = self.app.get('/v1/packs/views/files/doesntexist', expect_errors=True) + resp = self.app.get("/v1/packs/views/files/doesntexist", expect_errors=True) self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_get_pack_files_binary_files_are_excluded(self): binary_files = [ - 'icon.png', - 'etc/permissions.png', - 'etc/travisci.png', - 'etc/generate_new_token.png' + "icon.png", + "etc/permissions.png", + "etc/travisci.png", + "etc/generate_new_token.png", ] - pack_db = Pack.get_by_ref('dummy_pack_1') + pack_db = Pack.get_by_ref("dummy_pack_1") all_files_count = len(pack_db.files) non_binary_files_count = all_files_count - len(binary_files) - resp = self.app.get('/v1/packs/views/files/dummy_pack_1') + resp = self.app.get("/v1/packs/views/files/dummy_pack_1") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), non_binary_files_count) @@ -65,63 +67,75 @@ def test_get_pack_files_binary_files_are_excluded(self): # But not in files controller response for file_path in binary_files: - item = [item for item in resp.json if item['file_path'] == file_path] + item = [item for item in resp.json if item["file_path"] == file_path] self.assertFalse(item) def test_get_pack_file_success(self): - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml') + resp = self.app.get("/v1/packs/views/file/dummy_pack_1/pack.yaml") self.assertEqual(resp.status_int, http_client.OK) - self.assertIn(b'name : dummy_pack_1', resp.body) + self.assertIn(b"name : dummy_pack_1", resp.body) def test_get_pack_file_pack_doesnt_exist(self): - resp = self.app.get('/v1/packs/views/files/doesntexist/pack.yaml', expect_errors=True) + resp = self.app.get( + "/v1/packs/views/files/doesntexist/pack.yaml", expect_errors=True + ) self.assertEqual(resp.status_int, http_client.NOT_FOUND) - @mock.patch('st2api.controllers.v1.pack_views.MAX_FILE_SIZE', 1) + @mock.patch("st2api.controllers.v1.pack_views.MAX_FILE_SIZE", 1) def test_pack_file_file_larger_then_maximum_size(self): - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml', expect_errors=True) + resp = self.app.get( + "/v1/packs/views/file/dummy_pack_1/pack.yaml", expect_errors=True + ) self.assertEqual(resp.status_int, http_client.BAD_REQUEST) - self.assertIn('File pack.yaml exceeds maximum allowed file size', resp) + self.assertIn("File pack.yaml exceeds maximum allowed file size", resp) def test_headers_get_pack_file(self): - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml') + resp = self.app.get("/v1/packs/views/file/dummy_pack_1/pack.yaml") self.assertEqual(resp.status_int, http_client.OK) - self.assertIn(b'name : dummy_pack_1', resp.body) - self.assertIsNotNone(resp.headers['ETag']) - self.assertIsNotNone(resp.headers['Last-Modified']) + self.assertIn(b"name : dummy_pack_1", resp.body) + self.assertIsNotNone(resp.headers["ETag"]) + self.assertIsNotNone(resp.headers["Last-Modified"]) def test_no_change_get_pack_file(self): - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml') + resp = self.app.get("/v1/packs/views/file/dummy_pack_1/pack.yaml") self.assertEqual(resp.status_int, http_client.OK) - self.assertIn(b'name : dummy_pack_1', resp.body) + self.assertIn(b"name : dummy_pack_1", resp.body) # Confirm NOT_MODIFIED - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml', - headers={'If-None-Match': resp.headers['ETag']}) + resp = self.app.get( + "/v1/packs/views/file/dummy_pack_1/pack.yaml", + headers={"If-None-Match": resp.headers["ETag"]}, + ) self.assertEqual(resp.status_code, http_client.NOT_MODIFIED) - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml', - headers={'If-Modified-Since': resp.headers['Last-Modified']}) + resp = self.app.get( + "/v1/packs/views/file/dummy_pack_1/pack.yaml", + headers={"If-Modified-Since": resp.headers["Last-Modified"]}, + ) self.assertEqual(resp.status_code, http_client.NOT_MODIFIED) # Confirm value is returned if header do not match - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml', - headers={'If-None-Match': 'ETAG'}) + resp = self.app.get( + "/v1/packs/views/file/dummy_pack_1/pack.yaml", + headers={"If-None-Match": "ETAG"}, + ) self.assertEqual(resp.status_code, http_client.OK) - self.assertIn(b'name : dummy_pack_1', resp.body) + self.assertIn(b"name : dummy_pack_1", resp.body) - resp = self.app.get('/v1/packs/views/file/dummy_pack_1/pack.yaml', - headers={'If-Modified-Since': 'Last-Modified'}) + resp = self.app.get( + "/v1/packs/views/file/dummy_pack_1/pack.yaml", + headers={"If-Modified-Since": "Last-Modified"}, + ) self.assertEqual(resp.status_code, http_client.OK) - self.assertIn(b'name : dummy_pack_1', resp.body) + self.assertIn(b"name : dummy_pack_1", resp.body) def test_get_pack_files_and_pack_file_ref_doesnt_equal_pack_name(self): # Ref is not equal to the name, controller should still work - resp = self.app.get('/v1/packs/views/files/dummy_pack_16') + resp = self.app.get("/v1/packs/views/files/dummy_pack_16") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['file_path'], 'pack.yaml') + self.assertEqual(resp.json[0]["file_path"], "pack.yaml") - resp = self.app.get('/v1/packs/views/file/dummy_pack_16/pack.yaml') + resp = self.app.get("/v1/packs/views/file/dummy_pack_16/pack.yaml") self.assertEqual(resp.status_int, http_client.OK) - self.assertIn(b'ref: dummy_pack_16', resp.body) + self.assertIn(b"ref: dummy_pack_16", resp.body) diff --git a/st2api/tests/unit/controllers/v1/test_policies.py b/st2api/tests/unit/controllers/v1/test_policies.py index a26c3dea24b..3127b3aeb70 100644 --- a/st2api/tests/unit/controllers/v1/test_policies.py +++ b/st2api/tests/unit/controllers/v1/test_policies.py @@ -27,36 +27,28 @@ from st2tests.api import FunctionalTest from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase -__all__ = [ - 'PolicyTypeControllerTestCase', - 'PolicyControllerTestCase' -] +__all__ = ["PolicyTypeControllerTestCase", "PolicyControllerTestCase"] TEST_FIXTURES = { - 'policytypes': [ - 'fake_policy_type_1.yaml', - 'fake_policy_type_2.yaml' - ], - 'policies': [ - 'policy_1.yaml', - 'policy_2.yaml' - ] + "policytypes": ["fake_policy_type_1.yaml", "fake_policy_type_2.yaml"], + "policies": ["policy_1.yaml", "policy_2.yaml"], } -PACK = 'generic' +PACK = "generic" LOADER = FixturesLoader() FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) -class PolicyTypeControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/policytypes' +class PolicyTypeControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/policytypes" controller_cls = PolicyTypeController - include_attribute_field_name = 'module' - exclude_attribute_field_name = 'parameters' + include_attribute_field_name = "module" + exclude_attribute_field_name = "parameters" - base_url = '/v1/policytypes' + base_url = "/v1/policytypes" @classmethod def setUpClass(cls): @@ -64,7 +56,7 @@ def setUpClass(cls): cls.policy_type_dbs = [] - for _, fixture in six.iteritems(FIXTURES['policytypes']): + for _, fixture in six.iteritems(FIXTURES["policytypes"]): instance = PolicyTypeAPI(**fixture) policy_type_db = PolicyType.add_or_update(PolicyTypeAPI.to_model(instance)) cls.policy_type_dbs.append(policy_type_db) @@ -80,23 +72,25 @@ def test_policy_type_filter(self): self.assertGreater(len(resp.json), 0) selected = resp.json[0] - resp = self.__do_get_all(filter='resource_type=%s&name=%s' % - (selected['resource_type'], selected['name'])) + resp = self.__do_get_all( + filter="resource_type=%s&name=%s" + % (selected["resource_type"], selected["name"]) + ) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id']) + self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"]) - resp = self.__do_get_all(filter='name=%s' % selected['name']) + resp = self.__do_get_all(filter="name=%s" % selected["name"]) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id']) + self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"]) - resp = self.__do_get_all(filter='resource_type=%s' % selected['resource_type']) + resp = self.__do_get_all(filter="resource_type=%s" % selected["resource_type"]) self.assertEqual(resp.status_int, 200) self.assertGreater(len(resp.json), 1) def test_policy_type_filter_empty(self): - resp = self.__do_get_all(filter='resource_type=yo&name=whatever') + resp = self.__do_get_all(filter="resource_type=yo&name=whatever") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 0) @@ -106,16 +100,16 @@ def test_policy_type_get_one(self): self.assertGreater(len(resp.json), 0) selected = resp.json[0] - resp = self.__do_get_one(selected['id']) + resp = self.__do_get_one(selected["id"]) self.assertEqual(resp.status_int, 200) - self.assertEqual(self.__get_obj_id(resp), selected['id']) + self.assertEqual(self.__get_obj_id(resp), selected["id"]) - resp = self.__do_get_one(selected['ref']) + resp = self.__do_get_one(selected["ref"]) self.assertEqual(resp.status_int, 200) - self.assertEqual(self.__get_obj_id(resp), selected['id']) + self.assertEqual(self.__get_obj_id(resp), selected["id"]) def test_policy_type_get_one_fail(self): - resp = self.__do_get_one('1') + resp = self.__do_get_one("1") self.assertEqual(resp.status_int, 404) def _insert_mock_models(self): @@ -130,36 +124,37 @@ def _delete_mock_models(self, object_ids): @staticmethod def __get_obj_id(resp, idx=-1): - return resp.json['id'] if idx < 0 else resp.json[idx]['id'] + return resp.json["id"] if idx < 0 else resp.json[idx]["id"] def __do_get_all(self, filter=None): - url = '%s?%s' % (self.base_url, filter) if filter else self.base_url + url = "%s?%s" % (self.base_url, filter) if filter else self.base_url return self.app.get(url, expect_errors=True) def __do_get_one(self, id): - return self.app.get('%s/%s' % (self.base_url, id), expect_errors=True) + return self.app.get("%s/%s" % (self.base_url, id), expect_errors=True) -class PolicyControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/policies' +class PolicyControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/policies" controller_cls = PolicyController - include_attribute_field_name = 'policy_type' - exclude_attribute_field_name = 'parameters' + include_attribute_field_name = "policy_type" + exclude_attribute_field_name = "parameters" - base_url = '/v1/policies' + base_url = "/v1/policies" @classmethod def setUpClass(cls): super(PolicyControllerTestCase, cls).setUpClass() - for _, fixture in six.iteritems(FIXTURES['policytypes']): + for _, fixture in six.iteritems(FIXTURES["policytypes"]): instance = PolicyTypeAPI(**fixture) PolicyType.add_or_update(PolicyTypeAPI.to_model(instance)) cls.policy_dbs = [] - for _, fixture in six.iteritems(FIXTURES['policies']): + for _, fixture in six.iteritems(FIXTURES["policies"]): instance = PolicyAPI(**fixture) policy_db = Policy.add_or_update(PolicyAPI.to_model(instance)) cls.policy_dbs.append(policy_db) @@ -175,22 +170,24 @@ def test_filter(self): self.assertGreater(len(resp.json), 0) selected = resp.json[0] - resp = self.__do_get_all(filter='pack=%s&name=%s' % (selected['pack'], selected['name'])) + resp = self.__do_get_all( + filter="pack=%s&name=%s" % (selected["pack"], selected["name"]) + ) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id']) + self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"]) - resp = self.__do_get_all(filter='name=%s' % selected['name']) + resp = self.__do_get_all(filter="name=%s" % selected["name"]) self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 1) - self.assertEqual(self.__get_obj_id(resp, idx=0), selected['id']) + self.assertEqual(self.__get_obj_id(resp, idx=0), selected["id"]) - resp = self.__do_get_all(filter='pack=%s' % selected['pack']) + resp = self.__do_get_all(filter="pack=%s" % selected["pack"]) self.assertEqual(resp.status_int, 200) self.assertGreater(len(resp.json), 1) def test_filter_empty(self): - resp = self.__do_get_all(filter='pack=yo&name=whatever') + resp = self.__do_get_all(filter="pack=yo&name=whatever") self.assertEqual(resp.status_int, 200) self.assertEqual(len(resp.json), 0) @@ -200,16 +197,16 @@ def test_get_one(self): self.assertGreater(len(resp.json), 0) selected = resp.json[0] - resp = self.__do_get_one(selected['id']) + resp = self.__do_get_one(selected["id"]) self.assertEqual(resp.status_int, 200) - self.assertEqual(self.__get_obj_id(resp), selected['id']) + self.assertEqual(self.__get_obj_id(resp), selected["id"]) - resp = self.__do_get_one(selected['ref']) + resp = self.__do_get_one(selected["ref"]) self.assertEqual(resp.status_int, 200) - self.assertEqual(self.__get_obj_id(resp), selected['id']) + self.assertEqual(self.__get_obj_id(resp), selected["id"]) def test_get_one_fail(self): - resp = self.__do_get_one('1') + resp = self.__do_get_one("1") self.assertEqual(resp.status_int, 404) def test_crud(self): @@ -221,10 +218,10 @@ def test_crud(self): self.assertEqual(get_resp.status_int, http_client.OK) updated_input = get_resp.json - updated_input['enabled'] = not updated_input['enabled'] + updated_input["enabled"] = not updated_input["enabled"] put_resp = self.__do_put(self.__get_obj_id(post_resp), updated_input) self.assertEqual(put_resp.status_int, http_client.OK) - self.assertEqual(put_resp.json['enabled'], updated_input['enabled']) + self.assertEqual(put_resp.json["enabled"], updated_input["enabled"]) del_resp = self.__do_delete(self.__get_obj_id(post_resp)) self.assertEqual(del_resp.status_int, http_client.NO_CONTENT) @@ -243,41 +240,45 @@ def test_post_duplicate(self): def test_put_not_found(self): updated_input = self.__create_instance() - put_resp = self.__do_put('12345', updated_input) + put_resp = self.__do_put("12345", updated_input) self.assertEqual(put_resp.status_int, http_client.NOT_FOUND) def test_put_sys_pack(self): instance = self.__create_instance() - instance['pack'] = 'core' + instance["pack"] = "core" post_resp = self.__do_post(instance) self.assertEqual(post_resp.status_int, http_client.CREATED) updated_input = post_resp.json - updated_input['enabled'] = not updated_input['enabled'] + updated_input["enabled"] = not updated_input["enabled"] put_resp = self.__do_put(self.__get_obj_id(post_resp), updated_input) self.assertEqual(put_resp.status_int, http_client.BAD_REQUEST) - self.assertEqual(put_resp.json['faultstring'], - "Resources belonging to system level packs can't be manipulated") + self.assertEqual( + put_resp.json["faultstring"], + "Resources belonging to system level packs can't be manipulated", + ) # Clean up manually since API won't delete object in sys pack. Policy.delete(Policy.get_by_id(self.__get_obj_id(post_resp))) def test_delete_not_found(self): - del_resp = self.__do_delete('12345') + del_resp = self.__do_delete("12345") self.assertEqual(del_resp.status_int, http_client.NOT_FOUND) def test_delete_sys_pack(self): instance = self.__create_instance() - instance['pack'] = 'core' + instance["pack"] = "core" post_resp = self.__do_post(instance) self.assertEqual(post_resp.status_int, http_client.CREATED) del_resp = self.__do_delete(self.__get_obj_id(post_resp)) self.assertEqual(del_resp.status_int, http_client.BAD_REQUEST) - self.assertEqual(del_resp.json['faultstring'], - "Resources belonging to system level packs can't be manipulated") + self.assertEqual( + del_resp.json["faultstring"], + "Resources belonging to system level packs can't be manipulated", + ) # Clean up manually since API won't delete object in sys pack. Policy.delete(Policy.get_by_id(self.__get_obj_id(post_resp))) @@ -295,34 +296,34 @@ def _delete_mock_models(self, object_ids): @staticmethod def __create_instance(): return { - 'name': 'myaction.mypolicy', - 'pack': 'mypack', - 'resource_ref': 'mypack.myaction', - 'policy_type': 'action.mock_policy_error', - 'parameters': { - 'k1': 'v1' - } + "name": "myaction.mypolicy", + "pack": "mypack", + "resource_ref": "mypack.myaction", + "policy_type": "action.mock_policy_error", + "parameters": {"k1": "v1"}, } @staticmethod def __get_obj_id(resp, idx=-1): - return resp.json['id'] if idx < 0 else resp.json[idx]['id'] + return resp.json["id"] if idx < 0 else resp.json[idx]["id"] def __do_get_all(self, filter=None): - url = '%s?%s' % (self.base_url, filter) if filter else self.base_url + url = "%s?%s" % (self.base_url, filter) if filter else self.base_url return self.app.get(url, expect_errors=True) def __do_get_one(self, id): - return self.app.get('%s/%s' % (self.base_url, id), expect_errors=True) + return self.app.get("%s/%s" % (self.base_url, id), expect_errors=True) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_post(self, instance): return self.app.post_json(self.base_url, instance, expect_errors=True) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_put(self, id, instance): - return self.app.put_json('%s/%s' % (self.base_url, id), instance, expect_errors=True) + return self.app.put_json( + "%s/%s" % (self.base_url, id), instance, expect_errors=True + ) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_delete(self, id): - return self.app.delete('%s/%s' % (self.base_url, id), expect_errors=True) + return self.app.delete("%s/%s" % (self.base_url, id), expect_errors=True) diff --git a/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py b/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py index 84c2a66b4aa..0a7a104d350 100644 --- a/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py +++ b/st2api/tests/unit/controllers/v1/test_rule_enforcement_views.py @@ -21,87 +21,109 @@ from st2tests.api import FunctionalTest from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase -__all__ = [ - 'RuleEnforcementViewsControllerTestCase' -] +__all__ = ["RuleEnforcementViewsControllerTestCase"] http_client = six.moves.http_client TEST_FIXTURES = { - 'enforcements': ['enforcement1.yaml', 'enforcement2.yaml', 'enforcement3.yaml'], - 'executions': ['execution1.yaml'], - 'triggerinstances': ['trigger_instance_1.yaml'] + "enforcements": ["enforcement1.yaml", "enforcement2.yaml", "enforcement3.yaml"], + "executions": ["execution1.yaml"], + "triggerinstances": ["trigger_instance_1.yaml"], } -FIXTURES_PACK = 'rule_enforcements' +FIXTURES_PACK = "rule_enforcements" -class RuleEnforcementViewsControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/ruleenforcements/views' +class RuleEnforcementViewsControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/ruleenforcements/views" controller_cls = RuleEnforcementViewController - include_attribute_field_name = 'enforced_at' - exclude_attribute_field_name = 'status' + include_attribute_field_name = "enforced_at" + exclude_attribute_field_name = "status" fixtures_loader = FixturesLoader() @classmethod def setUpClass(cls): super(RuleEnforcementViewsControllerTestCase, cls).setUpClass() - cls.models = RuleEnforcementViewsControllerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES, - use_object_ids=True) - cls.ENFORCEMENT_1 = cls.models['enforcements']['enforcement1.yaml'] + cls.models = ( + RuleEnforcementViewsControllerTestCase.fixtures_loader.save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, + fixtures_dict=TEST_FIXTURES, + use_object_ids=True, + ) + ) + cls.ENFORCEMENT_1 = cls.models["enforcements"]["enforcement1.yaml"] def test_get_all(self): - resp = self.app.get('/v1/ruleenforcements/views') + resp = self.app.get("/v1/ruleenforcements/views") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) # Verify it includes corresponding execution and trigger instance object - self.assertEqual(resp.json[0]['trigger_instance']['id'], '565e15ce32ed350857dfa623') - self.assertEqual(resp.json[0]['trigger_instance']['payload'], {'foo': 'bar', 'name': 'Joe'}) - - self.assertEqual(resp.json[0]['execution']['action']['ref'], 'core.local') - self.assertEqual(resp.json[0]['execution']['action']['parameters'], - {'sudo': {'immutable': True}}) - self.assertEqual(resp.json[0]['execution']['runner']['name'], 'action-chain') - self.assertEqual(resp.json[0]['execution']['runner']['runner_parameters'], - {'foo': {'type': 'string'}}) - self.assertEqual(resp.json[0]['execution']['parameters'], {'cmd': 'echo bar'}) - self.assertEqual(resp.json[0]['execution']['status'], 'scheduled') - - self.assertEqual(resp.json[1]['trigger_instance'], {}) - self.assertEqual(resp.json[1]['execution'], {}) - - self.assertEqual(resp.json[2]['trigger_instance'], {}) - self.assertEqual(resp.json[2]['execution'], {}) + self.assertEqual( + resp.json[0]["trigger_instance"]["id"], "565e15ce32ed350857dfa623" + ) + self.assertEqual( + resp.json[0]["trigger_instance"]["payload"], {"foo": "bar", "name": "Joe"} + ) + + self.assertEqual(resp.json[0]["execution"]["action"]["ref"], "core.local") + self.assertEqual( + resp.json[0]["execution"]["action"]["parameters"], + {"sudo": {"immutable": True}}, + ) + self.assertEqual(resp.json[0]["execution"]["runner"]["name"], "action-chain") + self.assertEqual( + resp.json[0]["execution"]["runner"]["runner_parameters"], + {"foo": {"type": "string"}}, + ) + self.assertEqual(resp.json[0]["execution"]["parameters"], {"cmd": "echo bar"}) + self.assertEqual(resp.json[0]["execution"]["status"], "scheduled") + + self.assertEqual(resp.json[1]["trigger_instance"], {}) + self.assertEqual(resp.json[1]["execution"], {}) + + self.assertEqual(resp.json[2]["trigger_instance"], {}) + self.assertEqual(resp.json[2]["execution"], {}) def test_filter_by_rule_ref(self): - resp = self.app.get('/v1/ruleenforcements/views?rule_ref=wolfpack.golden_rule') + resp = self.app.get("/v1/ruleenforcements/views?rule_ref=wolfpack.golden_rule") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['rule']['ref'], 'wolfpack.golden_rule') + self.assertEqual(resp.json[0]["rule"]["ref"], "wolfpack.golden_rule") def test_get_one_success(self): - resp = self.app.get('/v1/ruleenforcements/views/%s' % (str(self.ENFORCEMENT_1.id))) - self.assertEqual(resp.json['id'], str(self.ENFORCEMENT_1.id)) - - self.assertEqual(resp.json['trigger_instance']['id'], '565e15ce32ed350857dfa623') - self.assertEqual(resp.json['trigger_instance']['payload'], {'foo': 'bar', 'name': 'Joe'}) - - self.assertEqual(resp.json['execution']['action']['ref'], 'core.local') - self.assertEqual(resp.json['execution']['action']['parameters'], - {'sudo': {'immutable': True}}) - self.assertEqual(resp.json['execution']['runner']['name'], 'action-chain') - self.assertEqual(resp.json['execution']['runner']['runner_parameters'], - {'foo': {'type': 'string'}}) - self.assertEqual(resp.json['execution']['parameters'], {'cmd': 'echo bar'}) - self.assertEqual(resp.json['execution']['status'], 'scheduled') + resp = self.app.get( + "/v1/ruleenforcements/views/%s" % (str(self.ENFORCEMENT_1.id)) + ) + self.assertEqual(resp.json["id"], str(self.ENFORCEMENT_1.id)) + + self.assertEqual( + resp.json["trigger_instance"]["id"], "565e15ce32ed350857dfa623" + ) + self.assertEqual( + resp.json["trigger_instance"]["payload"], {"foo": "bar", "name": "Joe"} + ) + + self.assertEqual(resp.json["execution"]["action"]["ref"], "core.local") + self.assertEqual( + resp.json["execution"]["action"]["parameters"], + {"sudo": {"immutable": True}}, + ) + self.assertEqual(resp.json["execution"]["runner"]["name"], "action-chain") + self.assertEqual( + resp.json["execution"]["runner"]["runner_parameters"], + {"foo": {"type": "string"}}, + ) + self.assertEqual(resp.json["execution"]["parameters"], {"cmd": "echo bar"}) + self.assertEqual(resp.json["execution"]["status"], "scheduled") def _insert_mock_models(self): - enfrocement_ids = [enforcement['id'] for enforcement in - self.models['enforcements'].values()] + enfrocement_ids = [ + enforcement["id"] for enforcement in self.models["enforcements"].values() + ] return enfrocement_ids def _delete_mock_models(self, object_ids): diff --git a/st2api/tests/unit/controllers/v1/test_rule_enforcements.py b/st2api/tests/unit/controllers/v1/test_rule_enforcements.py index 172b186098c..f2de1e2b2a5 100644 --- a/st2api/tests/unit/controllers/v1/test_rule_enforcements.py +++ b/st2api/tests/unit/controllers/v1/test_rule_enforcements.py @@ -24,92 +24,106 @@ http_client = six.moves.http_client TEST_FIXTURES = { - 'enforcements': ['enforcement1.yaml', 'enforcement2.yaml', 'enforcement3.yaml'] + "enforcements": ["enforcement1.yaml", "enforcement2.yaml", "enforcement3.yaml"] } -FIXTURES_PACK = 'rule_enforcements' +FIXTURES_PACK = "rule_enforcements" -class RuleEnforcementControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/ruleenforcements' +class RuleEnforcementControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/ruleenforcements" controller_cls = RuleEnforcementController - include_attribute_field_name = 'enforced_at' - exclude_attribute_field_name = 'status' + include_attribute_field_name = "enforced_at" + exclude_attribute_field_name = "status" fixtures_loader = FixturesLoader() @classmethod def setUpClass(cls): super(RuleEnforcementControllerTestCase, cls).setUpClass() - cls.models = RuleEnforcementControllerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) - RuleEnforcementControllerTestCase.ENFORCEMENT_1 = \ - cls.models['enforcements']['enforcement1.yaml'] + cls.models = ( + RuleEnforcementControllerTestCase.fixtures_loader.save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + ) + RuleEnforcementControllerTestCase.ENFORCEMENT_1 = cls.models["enforcements"][ + "enforcement1.yaml" + ] def test_get_all(self): - resp = self.app.get('/v1/ruleenforcements') + resp = self.app.get("/v1/ruleenforcements") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) def test_get_all_minus_one(self): - resp = self.app.get('/v1/ruleenforcements/?limit=-1') + resp = self.app.get("/v1/ruleenforcements/?limit=-1") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) def test_get_all_limit(self): - resp = self.app.get('/v1/ruleenforcements/?limit=1') + resp = self.app.get("/v1/ruleenforcements/?limit=1") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) def test_get_all_limit_negative_number(self): - resp = self.app.get('/v1/ruleenforcements?limit=-22', expect_errors=True) + resp = self.app.get("/v1/ruleenforcements?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_get_one_by_id(self): e_id = str(RuleEnforcementControllerTestCase.ENFORCEMENT_1.id) - resp = self.app.get('/v1/ruleenforcements/%s' % e_id) + resp = self.app.get("/v1/ruleenforcements/%s" % e_id) self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(resp.json['id'], e_id) + self.assertEqual(resp.json["id"], e_id) def test_get_one_fail(self): - resp = self.app.get('/v1/ruleenforcements/1', expect_errors=True) + resp = self.app.get("/v1/ruleenforcements/1", expect_errors=True) self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_filter_by_rule_ref(self): - resp = self.app.get('/v1/ruleenforcements?rule_ref=wolfpack.golden_rule') + resp = self.app.get("/v1/ruleenforcements?rule_ref=wolfpack.golden_rule") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) def test_filter_by_rule_id(self): - resp = self.app.get('/v1/ruleenforcements?rule_id=565e15c032ed35086c54f331') + resp = self.app.get("/v1/ruleenforcements?rule_id=565e15c032ed35086c54f331") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 2) def test_filter_by_execution_id(self): - resp = self.app.get('/v1/ruleenforcements?execution=565e15cd32ed350857dfa620') + resp = self.app.get("/v1/ruleenforcements?execution=565e15cd32ed350857dfa620") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) def test_filter_by_trigger_instance_id(self): - resp = self.app.get('/v1/ruleenforcements?trigger_instance=565e15ce32ed350857dfa623') + resp = self.app.get( + "/v1/ruleenforcements?trigger_instance=565e15ce32ed350857dfa623" + ) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) def test_filter_by_enforced_at(self): - resp = self.app.get('/v1/ruleenforcements?enforced_at_gt=2015-12-01T21:49:01.000000Z') + resp = self.app.get( + "/v1/ruleenforcements?enforced_at_gt=2015-12-01T21:49:01.000000Z" + ) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 2) - resp = self.app.get('/v1/ruleenforcements?enforced_at_lt=2015-12-01T21:49:01.000000Z') + resp = self.app.get( + "/v1/ruleenforcements?enforced_at_lt=2015-12-01T21:49:01.000000Z" + ) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) def _insert_mock_models(self): - enfrocement_ids = [enforcement['id'] for enforcement in - self.models['enforcements'].values()] + enfrocement_ids = [ + enforcement["id"] for enforcement in self.models["enforcements"].values() + ] return enfrocement_ids def _delete_mock_models(self, object_ids): diff --git a/st2api/tests/unit/controllers/v1/test_rule_views.py b/st2api/tests/unit/controllers/v1/test_rule_views.py index f8a25e5d3d7..95839c3110f 100644 --- a/st2api/tests/unit/controllers/v1/test_rule_views.py +++ b/st2api/tests/unit/controllers/v1/test_rule_views.py @@ -25,25 +25,24 @@ http_client = six.moves.http_client TEST_FIXTURES = { - 'runners': ['testrunner1.yaml'], - 'actions': ['action1.yaml', 'action2.yaml'], - 'triggers': ['trigger1.yaml'], - 'triggertypes': ['triggertype1.yaml'] + "runners": ["testrunner1.yaml"], + "actions": ["action1.yaml", "action2.yaml"], + "triggers": ["trigger1.yaml"], + "triggertypes": ["triggertype1.yaml"], } -TEST_FIXTURES_RULES = { - 'rules': ['rule1.yaml', 'rule4.yaml', 'rule5.yaml'] -} +TEST_FIXTURES_RULES = {"rules": ["rule1.yaml", "rule4.yaml", "rule5.yaml"]} -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" -class RuleViewControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/rules/views' +class RuleViewControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/rules/views" controller_cls = RuleViewController - include_attribute_field_name = 'criteria' - exclude_attribute_field_name = 'enabled' + include_attribute_field_name = "criteria" + exclude_attribute_field_name = "enabled" fixtures_loader = FixturesLoader() @@ -51,17 +50,21 @@ class RuleViewControllerTestCase(FunctionalTest, def setUpClass(cls): super(RuleViewControllerTestCase, cls).setUpClass() models = RuleViewControllerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) - RuleViewControllerTestCase.ACTION_1 = models['actions']['action1.yaml'] - RuleViewControllerTestCase.TRIGGER_TYPE_1 = models['triggertypes']['triggertype1.yaml'] - - file_name = 'rule1.yaml' + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + RuleViewControllerTestCase.ACTION_1 = models["actions"]["action1.yaml"] + RuleViewControllerTestCase.TRIGGER_TYPE_1 = models["triggertypes"][ + "triggertype1.yaml" + ] + + file_name = "rule1.yaml" cls.rules = RuleViewControllerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_RULES)['rules'] + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_RULES + )["rules"] RuleViewControllerTestCase.RULE_1 = cls.rules[file_name] def test_get_all(self): - resp = self.app.get('/v1/rules/views') + resp = self.app.get("/v1/rules/views") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) @@ -70,25 +73,29 @@ def test_get_one_by_id(self): get_resp = self.__do_get_one(rule_id) self.assertEqual(get_resp.status_int, http_client.OK) self.assertEqual(self.__get_rule_id(get_resp), rule_id) - self.assertEqual(get_resp.json['action']['description'], - RuleViewControllerTestCase.ACTION_1.description) - self.assertEqual(get_resp.json['trigger']['description'], - RuleViewControllerTestCase.TRIGGER_TYPE_1.description) + self.assertEqual( + get_resp.json["action"]["description"], + RuleViewControllerTestCase.ACTION_1.description, + ) + self.assertEqual( + get_resp.json["trigger"]["description"], + RuleViewControllerTestCase.TRIGGER_TYPE_1.description, + ) def test_get_one_by_ref(self): rule_name = RuleViewControllerTestCase.RULE_1.name rule_pack = RuleViewControllerTestCase.RULE_1.pack ref = ResourceReference.to_string_reference(name=rule_name, pack=rule_pack) get_resp = self.__do_get_one(ref) - self.assertEqual(get_resp.json['name'], rule_name) + self.assertEqual(get_resp.json["name"], rule_name) self.assertEqual(get_resp.status_int, http_client.OK) def test_get_one_fail(self): - resp = self.app.get('/v1/rules/1', expect_errors=True) + resp = self.app.get("/v1/rules/1", expect_errors=True) self.assertEqual(resp.status_int, http_client.NOT_FOUND) def _insert_mock_models(self): - rule_ids = [rule['id'] for rule in self.rules.values()] + rule_ids = [rule["id"] for rule in self.rules.values()] return rule_ids def _delete_mock_models(self, object_ids): @@ -96,7 +103,7 @@ def _delete_mock_models(self, object_ids): @staticmethod def __get_rule_id(resp): - return resp.json['id'] + return resp.json["id"] def __do_get_one(self, rule_id): - return self.app.get('/v1/rules/views/%s' % rule_id, expect_errors=True) + return self.app.get("/v1/rules/views/%s" % rule_id, expect_errors=True) diff --git a/st2api/tests/unit/controllers/v1/test_rules.py b/st2api/tests/unit/controllers/v1/test_rules.py index f52b4294cab..daf6845bcb7 100644 --- a/st2api/tests/unit/controllers/v1/test_rules.py +++ b/st2api/tests/unit/controllers/v1/test_rules.py @@ -34,21 +34,23 @@ http_client = six.moves.http_client TEST_FIXTURES = { - 'runners': ['testrunner1.yaml'], - 'actions': ['action1.yaml'], - 'triggers': ['trigger1.yaml'], - 'triggertypes': ['triggertype1.yaml', 'triggertype_with_parameters_2.yaml'] + "runners": ["testrunner1.yaml"], + "actions": ["action1.yaml"], + "triggers": ["trigger1.yaml"], + "triggertypes": ["triggertype1.yaml", "triggertype_with_parameters_2.yaml"], } -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) -class RulesControllerTestCase(FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/rules' +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) +class RulesControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/rules" controller_cls = RuleController - include_attribute_field_name = 'criteria' - exclude_attribute_field_name = 'enabled' + include_attribute_field_name = "criteria" + exclude_attribute_field_name = "enabled" VALIDATE_TRIGGER_PAYLOAD = None @@ -64,71 +66,96 @@ def setUpClass(cls): cls.VALIDATE_TRIGGER_PAYLOAD = cfg.CONF.system.validate_trigger_parameters models = RulesControllerTestCase.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) - RulesControllerTestCase.RUNNER_TYPE = models['runners']['testrunner1.yaml'] - RulesControllerTestCase.ACTION = models['actions']['action1.yaml'] - RulesControllerTestCase.TRIGGER = models['triggers']['trigger1.yaml'] + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + RulesControllerTestCase.RUNNER_TYPE = models["runners"]["testrunner1.yaml"] + RulesControllerTestCase.ACTION = models["actions"]["action1.yaml"] + RulesControllerTestCase.TRIGGER = models["triggers"]["trigger1.yaml"] # Don't load rule into DB as that is what is being tested. - file_name = 'rule1.yaml' - RulesControllerTestCase.RULE_1 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'cron_timer_rule_invalid_parameters.yaml' - RulesControllerTestCase.RULE_2 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'rule_no_enabled_attribute.yaml' - RulesControllerTestCase.RULE_3 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'backstop_rule.yaml' - RulesControllerTestCase.RULE_4 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'date_timer_rule_invalid_parameters.yaml' - RulesControllerTestCase.RULE_5 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'cron_timer_rule_invalid_parameters_1.yaml' - RulesControllerTestCase.RULE_6 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'cron_timer_rule_invalid_parameters_2.yaml' - RulesControllerTestCase.RULE_7 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'cron_timer_rule_invalid_parameters_3.yaml' - RulesControllerTestCase.RULE_8 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'rule_invalid_trigger_parameter_type.yaml' - RulesControllerTestCase.RULE_9 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'rule_trigger_with_no_parameters.yaml' - RulesControllerTestCase.RULE_10 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'rule_invalid_trigger_parameter_type_default_cfg.yaml' - RulesControllerTestCase.RULE_11 = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] - - file_name = 'rule space.yaml' - RulesControllerTestCase.RULE_SPACE = RulesControllerTestCase.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] + file_name = "rule1.yaml" + RulesControllerTestCase.RULE_1 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "cron_timer_rule_invalid_parameters.yaml" + RulesControllerTestCase.RULE_2 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "rule_no_enabled_attribute.yaml" + RulesControllerTestCase.RULE_3 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "backstop_rule.yaml" + RulesControllerTestCase.RULE_4 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "date_timer_rule_invalid_parameters.yaml" + RulesControllerTestCase.RULE_5 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "cron_timer_rule_invalid_parameters_1.yaml" + RulesControllerTestCase.RULE_6 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "cron_timer_rule_invalid_parameters_2.yaml" + RulesControllerTestCase.RULE_7 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "cron_timer_rule_invalid_parameters_3.yaml" + RulesControllerTestCase.RULE_8 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "rule_invalid_trigger_parameter_type.yaml" + RulesControllerTestCase.RULE_9 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "rule_trigger_with_no_parameters.yaml" + RulesControllerTestCase.RULE_10 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "rule_invalid_trigger_parameter_type_default_cfg.yaml" + RulesControllerTestCase.RULE_11 = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) + + file_name = "rule space.yaml" + RulesControllerTestCase.RULE_SPACE = ( + RulesControllerTestCase.fixtures_loader.load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] + ) @classmethod def tearDownClass(cls): @@ -136,18 +163,19 @@ def tearDownClass(cls): cfg.CONF.system.validate_trigger_payload = cls.VALIDATE_TRIGGER_PAYLOAD RulesControllerTestCase.fixtures_loader.delete_fixtures_from_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) super(RulesControllerTestCase, cls).setUpClass() def test_get_all_and_minus_one(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) post_resp_rule_3 = self.__do_post(RulesControllerTestCase.RULE_3) - resp = self.app.get('/v1/rules') + resp = self.app.get("/v1/rules") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 2) - resp = self.app.get('/v1/rules/?limit=-1') + resp = self.app.get("/v1/rules/?limit=-1") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 2) @@ -158,10 +186,12 @@ def test_get_all_limit_negative_number(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) post_resp_rule_3 = self.__do_post(RulesControllerTestCase.RULE_3) - resp = self.app.get('/v1/rules?limit=-22', expect_errors=True) + resp = self.app.get("/v1/rules?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) self.__do_delete(self.__get_rule_id(post_resp_rule_1)) self.__do_delete(self.__get_rule_id(post_resp_rule_3)) @@ -171,18 +201,18 @@ def test_get_all_enabled(self): post_resp_rule_3 = self.__do_post(RulesControllerTestCase.RULE_3) # enabled=True - resp = self.app.get('/v1/rules?enabled=True') + resp = self.app.get("/v1/rules?enabled=True") self.assertEqual(resp.status_int, http_client.OK) rule = resp.json[0] - self.assertEqual(self.__get_rule_id(post_resp_rule_1), rule['id']) - self.assertEqual(rule['enabled'], True) + self.assertEqual(self.__get_rule_id(post_resp_rule_1), rule["id"]) + self.assertEqual(rule["enabled"], True) # enabled=False - resp = self.app.get('/v1/rules?enabled=False') + resp = self.app.get("/v1/rules?enabled=False") self.assertEqual(resp.status_int, http_client.OK) rule = resp.json[0] - self.assertEqual(self.__get_rule_id(post_resp_rule_3), rule['id']) - self.assertEqual(rule['enabled'], False) + self.assertEqual(self.__get_rule_id(post_resp_rule_3), rule["id"]) + self.assertEqual(rule["enabled"], False) self.__do_delete(self.__get_rule_id(post_resp_rule_1)) self.__do_delete(self.__get_rule_id(post_resp_rule_3)) @@ -191,37 +221,45 @@ def test_get_all_action_parameters_secrets_masking(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) # Verify parameter is masked by default - resp = self.app.get('/v1/rules') - self.assertEqual('action' in resp.json[0], True) - self.assertEqual(resp.json[0]['action']['parameters']['action_secret'], - MASKED_ATTRIBUTE_VALUE) + resp = self.app.get("/v1/rules") + self.assertEqual("action" in resp.json[0], True) + self.assertEqual( + resp.json[0]["action"]["parameters"]["action_secret"], + MASKED_ATTRIBUTE_VALUE, + ) # Verify ?show_secrets=true works - resp = self.app.get('/v1/rules?include_attributes=action&show_secrets=true') - self.assertEqual('action' in resp.json[0], True) - self.assertEqual(resp.json[0]['action']['parameters']['action_secret'], 'secret') + resp = self.app.get("/v1/rules?include_attributes=action&show_secrets=true") + self.assertEqual("action" in resp.json[0], True) + self.assertEqual( + resp.json[0]["action"]["parameters"]["action_secret"], "secret" + ) self.__do_delete(self.__get_rule_id(post_resp_rule_1)) def test_get_all_parameters_mask_with_exclude_parameters(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) - resp = self.app.get('/v1/rules?exclude_attributes=action') - self.assertEqual('action' in resp.json[0], False) + resp = self.app.get("/v1/rules?exclude_attributes=action") + self.assertEqual("action" in resp.json[0], False) self.__do_delete(self.__get_rule_id(post_resp_rule_1)) def test_get_all_parameters_mask_with_include_parameters(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) # Verify parameter is masked by default - resp = self.app.get('/v1/rules?include_attributes=action') - self.assertEqual('action' in resp.json[0], True) - self.assertEqual(resp.json[0]['action']['parameters']['action_secret'], - MASKED_ATTRIBUTE_VALUE) + resp = self.app.get("/v1/rules?include_attributes=action") + self.assertEqual("action" in resp.json[0], True) + self.assertEqual( + resp.json[0]["action"]["parameters"]["action_secret"], + MASKED_ATTRIBUTE_VALUE, + ) # Verify ?show_secrets=true works - resp = self.app.get('/v1/rules?include_attributes=action&show_secrets=true') - self.assertEqual('action' in resp.json[0], True) - self.assertEqual(resp.json[0]['action']['parameters']['action_secret'], 'secret') + resp = self.app.get("/v1/rules?include_attributes=action&show_secrets=true") + self.assertEqual("action" in resp.json[0], True) + self.assertEqual( + resp.json[0]["action"]["parameters"]["action_secret"], "secret" + ) self.__do_delete(self.__get_rule_id(post_resp_rule_1)) @@ -229,13 +267,16 @@ def test_get_one_action_parameters_secrets_masking(self): post_resp_rule_1 = self.__do_post(RulesControllerTestCase.RULE_1) # Verify parameter is masked by default - resp = self.app.get('/v1/rules/%s' % (post_resp_rule_1.json['id'])) - self.assertEqual(resp.json['action']['parameters']['action_secret'], - MASKED_ATTRIBUTE_VALUE) + resp = self.app.get("/v1/rules/%s" % (post_resp_rule_1.json["id"])) + self.assertEqual( + resp.json["action"]["parameters"]["action_secret"], MASKED_ATTRIBUTE_VALUE + ) # Verify ?show_secrets=true works - resp = self.app.get('/v1/rules/%s?show_secrets=true' % (post_resp_rule_1.json['id'])) - self.assertEqual(resp.json['action']['parameters']['action_secret'], 'secret') + resp = self.app.get( + "/v1/rules/%s?show_secrets=true" % (post_resp_rule_1.json["id"]) + ) + self.assertEqual(resp.json["action"]["parameters"]["action_secret"], "secret") self.__do_delete(self.__get_rule_id(post_resp_rule_1)) @@ -249,27 +290,27 @@ def test_get_one_by_id(self): def test_get_one_by_ref(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_1) - rule_name = post_resp.json['name'] - rule_pack = post_resp.json['pack'] + rule_name = post_resp.json["name"] + rule_pack = post_resp.json["pack"] ref = ResourceReference.to_string_reference(name=rule_name, pack=rule_pack) - rule_id = post_resp.json['id'] + rule_id = post_resp.json["id"] get_resp = self.__do_get_one(ref) - self.assertEqual(get_resp.json['name'], rule_name) + self.assertEqual(get_resp.json["name"], rule_name) self.assertEqual(get_resp.status_int, http_client.OK) self.__do_delete(rule_id) post_resp = self.__do_post(RulesControllerTestCase.RULE_SPACE) - rule_name = post_resp.json['name'] - rule_pack = post_resp.json['pack'] + rule_name = post_resp.json["name"] + rule_pack = post_resp.json["pack"] ref = ResourceReference.to_string_reference(name=rule_name, pack=rule_pack) - rule_id = post_resp.json['id'] + rule_id = post_resp.json["id"] get_resp = self.__do_get_one(ref) - self.assertEqual(get_resp.json['name'], rule_name) + self.assertEqual(get_resp.json["name"], rule_name) self.assertEqual(get_resp.status_int, http_client.OK) self.__do_delete(rule_id) def test_get_one_fail(self): - resp = self.app.get('/v1/rules/1', expect_errors=True) + resp = self.app.get("/v1/rules/1", expect_errors=True) self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_post(self): @@ -283,38 +324,44 @@ def test_post_duplicate(self): self.assertEqual(post_resp.status_int, http_client.CREATED) post_resp_2 = self.__do_post(RulesControllerTestCase.RULE_1) self.assertEqual(post_resp_2.status_int, http_client.CONFLICT) - self.assertEqual(post_resp_2.json['conflict-id'], org_id) + self.assertEqual(post_resp_2.json["conflict-id"], org_id) self.__do_delete(org_id) def test_post_invalid_rule_data(self): - post_resp = self.__do_post({'name': 'rule'}) + post_resp = self.__do_post({"name": "rule"}) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) expected_msg = "'trigger' is a required property" - self.assertEqual(post_resp.json['faultstring'], expected_msg) + self.assertEqual(post_resp.json["faultstring"], expected_msg) def test_post_trigger_parameter_schema_validation_fails(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_2) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) if six.PY3: - expected_msg = b'Additional properties are not allowed (\'minutex\' was unexpected)' + expected_msg = ( + b"Additional properties are not allowed ('minutex' was unexpected)" + ) else: - expected_msg = b'Additional properties are not allowed (u\'minutex\' was unexpected)' + expected_msg = ( + b"Additional properties are not allowed (u'minutex' was unexpected)" + ) self.assertIn(expected_msg, post_resp.body) - def test_post_trigger_parameter_schema_validation_fails_missing_required_param(self): + def test_post_trigger_parameter_schema_validation_fails_missing_required_param( + self, + ): post_resp = self.__do_post(RulesControllerTestCase.RULE_5) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - expected_msg = b'\'date\' is a required property' + expected_msg = b"'date' is a required property" self.assertIn(expected_msg, post_resp.body) def test_post_invalid_crontimer_trigger_parameters(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_6) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - expected_msg = b'1000 is greater than the maximum of 6' + expected_msg = b"1000 is greater than the maximum of 6" self.assertIn(expected_msg, post_resp.body) post_resp = self.__do_post(RulesControllerTestCase.RULE_7) @@ -329,7 +376,9 @@ def test_post_invalid_crontimer_trigger_parameters(self): expected_msg = b'Invalid weekday name \\"a\\"' self.assertIn(expected_msg, post_resp.body) - def test_post_invalid_custom_trigger_parameter_trigger_param_validation_enabled(self): + def test_post_invalid_custom_trigger_parameter_trigger_param_validation_enabled( + self, + ): # Invalid custom trigger parameter (invalid type) and non-system trigger parameter # validation is enabled - trigger creation should fail cfg.CONF.system.validate_trigger_parameters = True @@ -338,16 +387,22 @@ def test_post_invalid_custom_trigger_parameter_trigger_param_validation_enabled( self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) if six.PY3: - expected_msg_1 = "Failed validating 'type' in schema['properties']['param1']:" - expected_msg_2 = '12345 is not of type \'string\'' + expected_msg_1 = ( + "Failed validating 'type' in schema['properties']['param1']:" + ) + expected_msg_2 = "12345 is not of type 'string'" else: - expected_msg_1 = "Failed validating u'type' in schema[u'properties'][u'param1']:" - expected_msg_2 = '12345 is not of type u\'string\'' + expected_msg_1 = ( + "Failed validating u'type' in schema[u'properties'][u'param1']:" + ) + expected_msg_2 = "12345 is not of type u'string'" - self.assertIn(expected_msg_1, post_resp.json['faultstring']) - self.assertIn(expected_msg_2, post_resp.json['faultstring']) + self.assertIn(expected_msg_1, post_resp.json["faultstring"]) + self.assertIn(expected_msg_2, post_resp.json["faultstring"]) - def test_post_invalid_custom_trigger_parameter_trigger_param_validation_disabled(self): + def test_post_invalid_custom_trigger_parameter_trigger_param_validation_disabled( + self, + ): # Invalid custom trigger parameter (invalid type) and non-system trigger parameter # validation is disabled - trigger creation should succeed cfg.CONF.system.validate_trigger_parameters = False @@ -368,33 +423,33 @@ def test_post_invalid_custom_trigger_parameter_trigger_no_parameters_schema(self def test_post_no_enabled_attribute_disabled_by_default(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_3) self.assertEqual(post_resp.status_int, http_client.CREATED) - self.assertFalse(post_resp.json['enabled']) + self.assertFalse(post_resp.json["enabled"]) self.__do_delete(self.__get_rule_id(post_resp)) def test_put(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_1) update_input = post_resp.json - update_input['enabled'] = not update_input['enabled'] + update_input["enabled"] = not update_input["enabled"] put_resp = self.__do_put(self.__get_rule_id(post_resp), update_input) self.assertEqual(put_resp.status_int, http_client.OK) self.__do_delete(self.__get_rule_id(put_resp)) def test_post_no_pack_info(self): rule = copy.deepcopy(RulesControllerTestCase.RULE_1) - del rule['pack'] + del rule["pack"] post_resp = self.__do_post(rule) - self.assertEqual(post_resp.json['pack'], DEFAULT_PACK_NAME) + self.assertEqual(post_resp.json["pack"], DEFAULT_PACK_NAME) self.assertEqual(post_resp.status_int, http_client.CREATED) self.__do_delete(self.__get_rule_id(post_resp)) def test_put_no_pack_info(self): post_resp = self.__do_post(RulesControllerTestCase.RULE_1) test_rule = post_resp.json - if 'pack' in test_rule: - del test_rule['pack'] - self.assertNotIn('pack', test_rule) + if "pack" in test_rule: + del test_rule["pack"] + self.assertNotIn("pack", test_rule) put_resp = self.__do_put(self.__get_rule_id(post_resp), test_rule) - self.assertEqual(put_resp.json['pack'], DEFAULT_PACK_NAME) + self.assertEqual(put_resp.json["pack"], DEFAULT_PACK_NAME) self.assertEqual(put_resp.status_int, http_client.OK) self.__do_delete(self.__get_rule_id(put_resp)) @@ -417,7 +472,7 @@ def test_rule_with_tags(self): get_resp = self.__do_get_one(rule_id) self.assertEqual(get_resp.status_int, http_client.OK) self.assertEqual(self.__get_rule_id(get_resp), rule_id) - self.assertEqual(get_resp.json['tags'], RulesControllerTestCase.RULE_1['tags']) + self.assertEqual(get_resp.json["tags"], RulesControllerTestCase.RULE_1["tags"]) self.__do_delete(rule_id) def test_rule_without_type(self): @@ -426,10 +481,13 @@ def test_rule_without_type(self): get_resp = self.__do_get_one(rule_id) self.assertEqual(get_resp.status_int, http_client.OK) self.assertEqual(self.__get_rule_id(get_resp), rule_id) - assigned_rule_type = get_resp.json['type'] - self.assertTrue(assigned_rule_type, 'rule_type should be assigned') - self.assertEqual(assigned_rule_type['ref'], RULE_TYPE_STANDARD, - 'rule_type should be standard') + assigned_rule_type = get_resp.json["type"] + self.assertTrue(assigned_rule_type, "rule_type should be assigned") + self.assertEqual( + assigned_rule_type["ref"], + RULE_TYPE_STANDARD, + "rule_type should be standard", + ) self.__do_delete(rule_id) def test_rule_with_type(self): @@ -438,10 +496,13 @@ def test_rule_with_type(self): get_resp = self.__do_get_one(rule_id) self.assertEqual(get_resp.status_int, http_client.OK) self.assertEqual(self.__get_rule_id(get_resp), rule_id) - assigned_rule_type = get_resp.json['type'] - self.assertTrue(assigned_rule_type, 'rule_type should be assigned') - self.assertEqual(assigned_rule_type['ref'], RULE_TYPE_BACKSTOP, - 'rule_type should be backstop') + assigned_rule_type = get_resp.json["type"] + self.assertTrue(assigned_rule_type, "rule_type should be assigned") + self.assertEqual( + assigned_rule_type["ref"], + RULE_TYPE_BACKSTOP, + "rule_type should be backstop", + ) self.__do_delete(rule_id) def test_update_rule_no_data(self): @@ -451,7 +512,7 @@ def test_update_rule_no_data(self): put_resp = self.__do_put(rule_1_id, {}) expected_msg = "'name' is a required property" self.assertEqual(put_resp.status_code, http_client.BAD_REQUEST) - self.assertEqual(put_resp.json['faultstring'], expected_msg) + self.assertEqual(put_resp.json["faultstring"], expected_msg) self.__do_delete(rule_1_id) @@ -460,16 +521,16 @@ def test_update_rule_missing_id_in_body(self): rule_1_id = self.__get_rule_id(post_resp) rule_without_id = copy.deepcopy(self.RULE_1) - rule_without_id.pop('id', None) + rule_without_id.pop("id", None) put_resp = self.__do_put(rule_1_id, rule_without_id) self.assertEqual(put_resp.status_int, http_client.OK) - self.assertEqual(put_resp.json['id'], rule_1_id) + self.assertEqual(put_resp.json["id"], rule_1_id) self.__do_delete(rule_1_id) def _insert_mock_models(self): rule = copy.deepcopy(RulesControllerTestCase.RULE_1) - rule['name'] += '-253' + rule["name"] += "-253" post_resp = self.__do_post(rule) rule_1_id = self.__get_rule_id(post_resp) return [rule_1_id] @@ -479,32 +540,32 @@ def _do_delete(self, rule_id): @staticmethod def __get_rule_id(resp): - return resp.json['id'] + return resp.json["id"] def __do_get_one(self, rule_id): - return self.app.get('/v1/rules/%s' % rule_id, expect_errors=True) + return self.app.get("/v1/rules/%s" % rule_id, expect_errors=True) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_post(self, rule): - return self.app.post_json('/v1/rules', rule, expect_errors=True) + return self.app.post_json("/v1/rules", rule, expect_errors=True) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_put(self, rule_id, rule): - return self.app.put_json('/v1/rules/%s' % rule_id, rule, expect_errors=True) + return self.app.put_json("/v1/rules/%s" % rule_id, rule, expect_errors=True) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def __do_delete(self, rule_id): - return self.app.delete('/v1/rules/%s' % rule_id) + return self.app.delete("/v1/rules/%s" % rule_id) TEST_FIXTURES_2 = { - 'runners': ['testrunner1.yaml'], - 'actions': ['action1.yaml'], - 'triggertypes': ['triggertype_with_parameter.yaml'] + "runners": ["testrunner1.yaml"], + "actions": ["action1.yaml"], + "triggertypes": ["triggertype_with_parameter.yaml"], } -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class RulesControllerTestCaseTriggerCreator(FunctionalTest): fixtures_loader = FixturesLoader() @@ -513,32 +574,33 @@ class RulesControllerTestCaseTriggerCreator(FunctionalTest): def setUpClass(cls): super(RulesControllerTestCaseTriggerCreator, cls).setUpClass() cls.models = cls.fixtures_loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_2) + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES_2 + ) # Don't load rule into DB as that is what is being tested. - file_name = 'rule_trigger_params.yaml' + file_name = "rule_trigger_params.yaml" cls.RULE_1 = cls.fixtures_loader.load_fixtures( - fixtures_pack=FIXTURES_PACK, - fixtures_dict={'rules': [file_name]})['rules'][file_name] + fixtures_pack=FIXTURES_PACK, fixtures_dict={"rules": [file_name]} + )["rules"][file_name] def test_ref_count_trigger_increment(self): post_resp = self.__do_post(self.RULE_1) rule_1_id = self.__get_rule_id(post_resp) self.assertEqual(post_resp.status_int, http_client.CREATED) # ref_count is not served over API. Likely a choice that will prove unwise. - triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']}) - self.assertEqual(len(triggers), 1, 'Exactly 1 should exist') - self.assertEqual(triggers[0].ref_count, 1, 'ref_count should be 1') + triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]}) + self.assertEqual(len(triggers), 1, "Exactly 1 should exist") + self.assertEqual(triggers[0].ref_count, 1, "ref_count should be 1") # different rule same params rule_2 = copy.copy(self.RULE_1) - rule_2['name'] = rule_2['name'] + '-2' + rule_2["name"] = rule_2["name"] + "-2" post_resp = self.__do_post(rule_2) rule_2_id = self.__get_rule_id(post_resp) self.assertEqual(post_resp.status_int, http_client.CREATED) - triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']}) - self.assertEqual(len(triggers), 1, 'Exactly 1 should exist') - self.assertEqual(triggers[0].ref_count, 2, 'ref_count should be 1') + triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]}) + self.assertEqual(len(triggers), 1, "Exactly 1 should exist") + self.assertEqual(triggers[0].ref_count, 2, "ref_count should be 1") self.__do_delete(rule_1_id) self.__do_delete(rule_2_id) @@ -549,16 +611,16 @@ def test_ref_count_trigger_decrement(self): self.assertEqual(post_resp.status_int, http_client.CREATED) rule_2 = copy.copy(self.RULE_1) - rule_2['name'] = rule_2['name'] + '-2' + rule_2["name"] = rule_2["name"] + "-2" post_resp = self.__do_post(rule_2) rule_2_id = self.__get_rule_id(post_resp) self.assertEqual(post_resp.status_int, http_client.CREATED) # validate decrement self.__do_delete(rule_1_id) - triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']}) - self.assertEqual(len(triggers), 1, 'Exactly 1 should exist') - self.assertEqual(triggers[0].ref_count, 1, 'ref_count should be 1') + triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]}) + self.assertEqual(len(triggers), 1, "Exactly 1 should exist") + self.assertEqual(triggers[0].ref_count, 1, "ref_count should be 1") self.__do_delete(rule_2_id) def test_trigger_cleanup(self): @@ -567,34 +629,34 @@ def test_trigger_cleanup(self): self.assertEqual(post_resp.status_int, http_client.CREATED) rule_2 = copy.copy(self.RULE_1) - rule_2['name'] = rule_2['name'] + '-2' + rule_2["name"] = rule_2["name"] + "-2" post_resp = self.__do_post(rule_2) rule_2_id = self.__get_rule_id(post_resp) self.assertEqual(post_resp.status_int, http_client.CREATED) - triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']}) - self.assertEqual(len(triggers), 1, 'Exactly 1 should exist') - self.assertEqual(triggers[0].ref_count, 2, 'ref_count should be 1') + triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]}) + self.assertEqual(len(triggers), 1, "Exactly 1 should exist") + self.assertEqual(triggers[0].ref_count, 2, "ref_count should be 1") self.__do_delete(rule_1_id) self.__do_delete(rule_2_id) # validate cleanup - triggers = Trigger.get_all(**{'type': post_resp.json['trigger']['type']}) - self.assertEqual(len(triggers), 0, 'Exactly 1 should exist') + triggers = Trigger.get_all(**{"type": post_resp.json["trigger"]["type"]}) + self.assertEqual(len(triggers), 0, "Exactly 1 should exist") @staticmethod def __get_rule_id(resp): - return resp.json['id'] + return resp.json["id"] def __do_get_one(self, rule_id): - return self.app.get('/v1/rules/%s' % rule_id, expect_errors=True) + return self.app.get("/v1/rules/%s" % rule_id, expect_errors=True) def __do_post(self, rule): - return self.app.post_json('/v1/rules', rule, expect_errors=True) + return self.app.post_json("/v1/rules", rule, expect_errors=True) def __do_put(self, rule_id, rule): - return self.app.put_json('/v1/rules/%s' % rule_id, rule, expect_errors=True) + return self.app.put_json("/v1/rules/%s" % rule_id, rule, expect_errors=True) def __do_delete(self, rule_id): - return self.app.delete('/v1/rules/%s' % rule_id) + return self.app.delete("/v1/rules/%s" % rule_id) diff --git a/st2api/tests/unit/controllers/v1/test_ruletypes.py b/st2api/tests/unit/controllers/v1/test_ruletypes.py index 5cba9614098..87b1c4c584a 100644 --- a/st2api/tests/unit/controllers/v1/test_ruletypes.py +++ b/st2api/tests/unit/controllers/v1/test_ruletypes.py @@ -26,20 +26,26 @@ def setUpClass(cls): ruletypes_registrar.register_rule_types() def test_get_one(self): - list_resp = self.app.get('/v1/ruletypes') + list_resp = self.app.get("/v1/ruletypes") self.assertEqual(list_resp.status_int, 200) - self.assertTrue(len(list_resp.json) > 0, '/v1/ruletypes did not return correct ruletypes.') - ruletype_id = list_resp.json[0]['id'] - get_resp = self.app.get('/v1/ruletypes/%s' % ruletype_id) - retrieved_id = get_resp.json['id'] + self.assertTrue( + len(list_resp.json) > 0, "/v1/ruletypes did not return correct ruletypes." + ) + ruletype_id = list_resp.json[0]["id"] + get_resp = self.app.get("/v1/ruletypes/%s" % ruletype_id) + retrieved_id = get_resp.json["id"] self.assertEqual(get_resp.status_int, 200) - self.assertEqual(retrieved_id, ruletype_id, '/v1/ruletypes returned incorrect ruletype.') + self.assertEqual( + retrieved_id, ruletype_id, "/v1/ruletypes returned incorrect ruletype." + ) def test_get_all(self): - resp = self.app.get('/v1/ruletypes') + resp = self.app.get("/v1/ruletypes") self.assertEqual(resp.status_int, 200) - self.assertTrue(len(resp.json) > 0, '/v1/ruletypes did not return correct ruletypes.') + self.assertTrue( + len(resp.json) > 0, "/v1/ruletypes did not return correct ruletypes." + ) def test_get_one_fail_doesnt_exist(self): - resp = self.app.get('/v1/ruletypes/1', expect_errors=True) + resp = self.app.get("/v1/ruletypes/1", expect_errors=True) self.assertEqual(resp.status_int, 404) diff --git a/st2api/tests/unit/controllers/v1/test_runnertypes.py b/st2api/tests/unit/controllers/v1/test_runnertypes.py index edaacdf6dda..34c243c5452 100644 --- a/st2api/tests/unit/controllers/v1/test_runnertypes.py +++ b/st2api/tests/unit/controllers/v1/test_runnertypes.py @@ -18,67 +18,76 @@ from st2tests.api import FunctionalTest from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase -__all__ = [ - 'RunnerTypesControllerTestCase' -] +__all__ = ["RunnerTypesControllerTestCase"] -class RunnerTypesControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/runnertypes' +class RunnerTypesControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/runnertypes" controller_cls = RunnerTypesController - include_attribute_field_name = 'runner_package' - exclude_attribute_field_name = 'runner_module' - test_exact_object_count = False # runners are registered dynamically in base test class + include_attribute_field_name = "runner_package" + exclude_attribute_field_name = "runner_module" + test_exact_object_count = ( + False # runners are registered dynamically in base test class + ) def test_get_one(self): - resp = self.app.get('/v1/runnertypes') + resp = self.app.get("/v1/runnertypes") self.assertEqual(resp.status_int, 200) - self.assertTrue(len(resp.json) > 0, '/v1/runnertypes did not return correct runnertypes.') + self.assertTrue( + len(resp.json) > 0, "/v1/runnertypes did not return correct runnertypes." + ) runnertype_id = RunnerTypesControllerTestCase.__get_runnertype_id(resp.json[0]) - resp = self.app.get('/v1/runnertypes/%s' % runnertype_id) + resp = self.app.get("/v1/runnertypes/%s" % runnertype_id) retrieved_id = RunnerTypesControllerTestCase.__get_runnertype_id(resp.json) self.assertEqual(resp.status_int, 200) - self.assertEqual(retrieved_id, runnertype_id, - '/v1/runnertypes returned incorrect runnertype.') + self.assertEqual( + retrieved_id, + runnertype_id, + "/v1/runnertypes returned incorrect runnertype.", + ) def test_get_all(self): - resp = self.app.get('/v1/runnertypes') + resp = self.app.get("/v1/runnertypes") self.assertEqual(resp.status_int, 200) - self.assertTrue(len(resp.json) > 0, '/v1/runnertypes did not return correct runnertypes.') + self.assertTrue( + len(resp.json) > 0, "/v1/runnertypes did not return correct runnertypes." + ) def test_get_one_fail_doesnt_exist(self): - resp = self.app.get('/v1/runnertypes/1', expect_errors=True) + resp = self.app.get("/v1/runnertypes/1", expect_errors=True) self.assertEqual(resp.status_int, 404) def test_put_disable_runner(self): - runnertype_id = 'action-chain' - resp = self.app.get('/v1/runnertypes/%s' % runnertype_id) - self.assertTrue(resp.json['enabled']) + runnertype_id = "action-chain" + resp = self.app.get("/v1/runnertypes/%s" % runnertype_id) + self.assertTrue(resp.json["enabled"]) # Disable the runner update_input = resp.json - update_input['enabled'] = False - update_input['name'] = 'foobar' + update_input["enabled"] = False + update_input["name"] = "foobar" put_resp = self.__do_put(runnertype_id, update_input) - self.assertFalse(put_resp.json['enabled']) + self.assertFalse(put_resp.json["enabled"]) # Verify that the name hasn't been updated - we only allow updating # enabled attribute on the runner - self.assertEqual(put_resp.json['name'], 'action-chain') + self.assertEqual(put_resp.json["name"], "action-chain") # Enable the runner update_input = resp.json - update_input['enabled'] = True + update_input["enabled"] = True put_resp = self.__do_put(runnertype_id, update_input) - self.assertTrue(put_resp.json['enabled']) + self.assertTrue(put_resp.json["enabled"]) def __do_put(self, runner_type_id, runner_type): - return self.app.put_json('/v1/runnertypes/%s' % runner_type_id, runner_type, - expect_errors=True) + return self.app.put_json( + "/v1/runnertypes/%s" % runner_type_id, runner_type, expect_errors=True + ) @staticmethod def __get_runnertype_id(resp_json): - return resp_json['id'] + return resp_json["id"] diff --git a/st2api/tests/unit/controllers/v1/test_sensortypes.py b/st2api/tests/unit/controllers/v1/test_sensortypes.py index 8e66cdfb40c..c59a1c28e2c 100644 --- a/st2api/tests/unit/controllers/v1/test_sensortypes.py +++ b/st2api/tests/unit/controllers/v1/test_sensortypes.py @@ -25,17 +25,16 @@ http_client = six.moves.http_client -__all__ = [ - 'SensorTypeControllerTestCase' -] +__all__ = ["SensorTypeControllerTestCase"] -class SensorTypeControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/sensortypes' +class SensorTypeControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/sensortypes" controller_cls = SensorTypeController - include_attribute_field_name = 'entry_point' - exclude_attribute_field_name = 'artifact_uri' + include_attribute_field_name = "entry_point" + exclude_attribute_field_name = "artifact_uri" test_exact_object_count = False @classmethod @@ -46,106 +45,108 @@ def setUpClass(cls): sensors_registrar.register_sensors(use_pack_cache=False) def test_get_all_and_minus_one(self): - resp = self.app.get('/v1/sensortypes') + resp = self.app.get("/v1/sensortypes") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) - self.assertEqual(resp.json[0]['name'], 'SampleSensor') + self.assertEqual(resp.json[0]["name"], "SampleSensor") - resp = self.app.get('/v1/sensortypes/?limit=-1') + resp = self.app.get("/v1/sensortypes/?limit=-1") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) - self.assertEqual(resp.json[0]['name'], 'SampleSensor') + self.assertEqual(resp.json[0]["name"], "SampleSensor") def test_get_all_negative_limit(self): - resp = self.app.get('/v1/sensortypes/?limit=-22', expect_errors=True) + resp = self.app.get("/v1/sensortypes/?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_get_all_filters(self): - resp = self.app.get('/v1/sensortypes') + resp = self.app.get("/v1/sensortypes") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 3) # ?name filter - resp = self.app.get('/v1/sensortypes?name=foobar') + resp = self.app.get("/v1/sensortypes?name=foobar") self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/sensortypes?name=SampleSensor2') + resp = self.app.get("/v1/sensortypes?name=SampleSensor2") self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['name'], 'SampleSensor2') - self.assertEqual(resp.json[0]['ref'], 'dummy_pack_1.SampleSensor2') + self.assertEqual(resp.json[0]["name"], "SampleSensor2") + self.assertEqual(resp.json[0]["ref"], "dummy_pack_1.SampleSensor2") - resp = self.app.get('/v1/sensortypes?name=SampleSensor3') + resp = self.app.get("/v1/sensortypes?name=SampleSensor3") self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['name'], 'SampleSensor3') + self.assertEqual(resp.json[0]["name"], "SampleSensor3") # ?pack filter - resp = self.app.get('/v1/sensortypes?pack=foobar') + resp = self.app.get("/v1/sensortypes?pack=foobar") self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/sensortypes?pack=dummy_pack_1') + resp = self.app.get("/v1/sensortypes?pack=dummy_pack_1") self.assertEqual(len(resp.json), 3) # ?enabled filter - resp = self.app.get('/v1/sensortypes?enabled=False') + resp = self.app.get("/v1/sensortypes?enabled=False") self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['enabled'], False) + self.assertEqual(resp.json[0]["enabled"], False) - resp = self.app.get('/v1/sensortypes?enabled=True') + resp = self.app.get("/v1/sensortypes?enabled=True") self.assertEqual(len(resp.json), 2) - self.assertEqual(resp.json[0]['enabled'], True) - self.assertEqual(resp.json[1]['enabled'], True) + self.assertEqual(resp.json[0]["enabled"], True) + self.assertEqual(resp.json[1]["enabled"], True) # ?trigger filter - resp = self.app.get('/v1/sensortypes?trigger=dummy_pack_1.event3') + resp = self.app.get("/v1/sensortypes?trigger=dummy_pack_1.event3") self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['trigger_types'], ['dummy_pack_1.event3']) + self.assertEqual(resp.json[0]["trigger_types"], ["dummy_pack_1.event3"]) - resp = self.app.get('/v1/sensortypes?trigger=dummy_pack_1.event') + resp = self.app.get("/v1/sensortypes?trigger=dummy_pack_1.event") self.assertEqual(len(resp.json), 2) - self.assertEqual(resp.json[0]['trigger_types'], ['dummy_pack_1.event']) - self.assertEqual(resp.json[1]['trigger_types'], ['dummy_pack_1.event']) + self.assertEqual(resp.json[0]["trigger_types"], ["dummy_pack_1.event"]) + self.assertEqual(resp.json[1]["trigger_types"], ["dummy_pack_1.event"]) def test_get_one_success(self): - resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor') + resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor") self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(resp.json['name'], 'SampleSensor') - self.assertEqual(resp.json['ref'], 'dummy_pack_1.SampleSensor') + self.assertEqual(resp.json["name"], "SampleSensor") + self.assertEqual(resp.json["ref"], "dummy_pack_1.SampleSensor") def test_get_one_doesnt_exist(self): - resp = self.app.get('/v1/sensortypes/1', expect_errors=True) + resp = self.app.get("/v1/sensortypes/1", expect_errors=True) self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_disable_and_enable_sensor(self): # Verify initial state - resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor') + resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor") self.assertEqual(resp.status_int, http_client.OK) - self.assertTrue(resp.json['enabled']) + self.assertTrue(resp.json["enabled"]) sensor_data = resp.json # Disable sensor data = copy.deepcopy(sensor_data) - data['enabled'] = False - put_resp = self.app.put_json('/v1/sensortypes/dummy_pack_1.SampleSensor', data) + data["enabled"] = False + put_resp = self.app.put_json("/v1/sensortypes/dummy_pack_1.SampleSensor", data) self.assertEqual(put_resp.status_int, http_client.OK) - self.assertEqual(put_resp.json['ref'], 'dummy_pack_1.SampleSensor') - self.assertFalse(put_resp.json['enabled']) + self.assertEqual(put_resp.json["ref"], "dummy_pack_1.SampleSensor") + self.assertFalse(put_resp.json["enabled"]) # Verify sensor has been disabled - resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor') + resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor") self.assertEqual(resp.status_int, http_client.OK) - self.assertFalse(resp.json['enabled']) + self.assertFalse(resp.json["enabled"]) # Enable sensor data = copy.deepcopy(sensor_data) - data['enabled'] = True - put_resp = self.app.put_json('/v1/sensortypes/dummy_pack_1.SampleSensor', data) + data["enabled"] = True + put_resp = self.app.put_json("/v1/sensortypes/dummy_pack_1.SampleSensor", data) self.assertEqual(put_resp.status_int, http_client.OK) - self.assertTrue(put_resp.json['enabled']) + self.assertTrue(put_resp.json["enabled"]) # Verify sensor has been enabled - resp = self.app.get('/v1/sensortypes/dummy_pack_1.SampleSensor') + resp = self.app.get("/v1/sensortypes/dummy_pack_1.SampleSensor") self.assertEqual(resp.status_int, http_client.OK) - self.assertTrue(resp.json['enabled']) + self.assertTrue(resp.json["enabled"]) diff --git a/st2api/tests/unit/controllers/v1/test_service_registry.py b/st2api/tests/unit/controllers/v1/test_service_registry.py index efeb7d432ad..d195c2361eb 100644 --- a/st2api/tests/unit/controllers/v1/test_service_registry.py +++ b/st2api/tests/unit/controllers/v1/test_service_registry.py @@ -22,9 +22,7 @@ from st2tests.api import FunctionalTest -__all__ = [ - 'ServiceyRegistryControllerTestCase' -] +__all__ = ["ServiceyRegistryControllerTestCase"] class ServiceyRegistryControllerTestCase(FunctionalTest): @@ -41,10 +39,11 @@ def setUpClass(cls): # NOTE: We mock call common_setup to emulate service being registered in the service # registry during bootstrap phase - register_service_in_service_registry(service='mock_service', - capabilities={'key1': 'value1', - 'name': 'mock_service'}, - start_heart=True) + register_service_in_service_registry( + service="mock_service", + capabilities={"key1": "value1", "name": "mock_service"}, + start_heart=True, + ) @classmethod def tearDownClass(cls): @@ -53,33 +52,40 @@ def tearDownClass(cls): coordination.coordinator_teardown(cls.coordinator) def test_get_groups(self): - list_resp = self.app.get('/v1/service_registry/groups') + list_resp = self.app.get("/v1/service_registry/groups") self.assertEqual(list_resp.status_int, 200) - self.assertEqual(list_resp.json, {'groups': ['mock_service']}) + self.assertEqual(list_resp.json, {"groups": ["mock_service"]}) def test_get_group_members(self): proc_info = system_info.get_process_info() member_id = get_member_id() # 1. Group doesn't exist - resp = self.app.get('/v1/service_registry/groups/doesnt-exist/members', expect_errors=True) + resp = self.app.get( + "/v1/service_registry/groups/doesnt-exist/members", expect_errors=True + ) self.assertEqual(resp.status_int, 404) - self.assertEqual(resp.json['faultstring'], 'Group with ID "doesnt-exist" not found.') + self.assertEqual( + resp.json["faultstring"], 'Group with ID "doesnt-exist" not found.' + ) # 2. Group exists and has a single member - resp = self.app.get('/v1/service_registry/groups/mock_service/members') + resp = self.app.get("/v1/service_registry/groups/mock_service/members") self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json, { - 'members': [ - { - 'group_id': 'mock_service', - 'member_id': member_id.decode('utf-8'), - 'capabilities': { - 'key1': 'value1', - 'name': 'mock_service', - 'hostname': proc_info['hostname'], - 'pid': proc_info['pid'] + self.assertEqual( + resp.json, + { + "members": [ + { + "group_id": "mock_service", + "member_id": member_id.decode("utf-8"), + "capabilities": { + "key1": "value1", + "name": "mock_service", + "hostname": proc_info["hostname"], + "pid": proc_info["pid"], + }, } - } - ] - }) + ] + }, + ) diff --git a/st2api/tests/unit/controllers/v1/test_timers.py b/st2api/tests/unit/controllers/v1/test_timers.py index 492c231b100..cb578445391 100644 --- a/st2api/tests/unit/controllers/v1/test_timers.py +++ b/st2api/tests/unit/controllers/v1/test_timers.py @@ -17,20 +17,29 @@ import st2common.services.triggers as trigger_service -with mock.patch.object(trigger_service, 'create_trigger_type_db', mock.MagicMock()): +with mock.patch.object(trigger_service, "create_trigger_type_db", mock.MagicMock()): from st2api.controllers.v1.timers import TimersHolder from st2common.models.system.common import ResourceReference from st2tests.base import DbTestCase from st2tests.fixturesloader import FixturesLoader -from st2common.constants.triggers import INTERVAL_TIMER_TRIGGER_REF, DATE_TIMER_TRIGGER_REF +from st2common.constants.triggers import ( + INTERVAL_TIMER_TRIGGER_REF, + DATE_TIMER_TRIGGER_REF, +) from st2common.constants.triggers import CRON_TIMER_TRIGGER_REF from st2tests.api import FunctionalTest -PACK = 'timers' +PACK = "timers" FIXTURES = { - 'triggers': ['cron1.yaml', 'date1.yaml', 'interval1.yaml', 'interval2.yaml', 'interval3.yaml'] + "triggers": [ + "cron1.yaml", + "date1.yaml", + "interval1.yaml", + "interval2.yaml", + "interval3.yaml", + ] } @@ -43,23 +52,28 @@ def setUpClass(cls): loader = FixturesLoader() TestTimersHolder.MODELS = loader.load_fixtures( - fixtures_pack=PACK, fixtures_dict=FIXTURES)['triggers'] + fixtures_pack=PACK, fixtures_dict=FIXTURES + )["triggers"] loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=FIXTURES) def test_add_trigger(self): holder = TimersHolder() for _, model in TestTimersHolder.MODELS.items(): holder.add_trigger( - ref=ResourceReference.to_string_reference(pack=model['pack'], name=model['name']), - trigger=model + ref=ResourceReference.to_string_reference( + pack=model["pack"], name=model["name"] + ), + trigger=model, ) self.assertEqual(len(holder._timers), 5) def test_remove_trigger(self): holder = TimersHolder() - model = TestTimersHolder.MODELS.get('cron1.yaml', None) + model = TestTimersHolder.MODELS.get("cron1.yaml", None) self.assertIsNotNone(model) - ref = ResourceReference.to_string_reference(pack=model['pack'], name=model['name']) + ref = ResourceReference.to_string_reference( + pack=model["pack"], name=model["name"] + ) holder.add_trigger(ref, model) self.assertEqual(len(holder._timers), 1) holder.remove_trigger(ref, model) @@ -69,8 +83,10 @@ def test_get_all(self): holder = TimersHolder() for _, model in TestTimersHolder.MODELS.items(): holder.add_trigger( - ref=ResourceReference.to_string_reference(pack=model['pack'], name=model['name']), - trigger=model + ref=ResourceReference.to_string_reference( + pack=model["pack"], name=model["name"] + ), + trigger=model, ) self.assertEqual(len(holder.get_all()), 5) @@ -78,8 +94,10 @@ def test_get_all_filters_filter_by_type(self): holder = TimersHolder() for _, model in TestTimersHolder.MODELS.items(): holder.add_trigger( - ref=ResourceReference.to_string_reference(pack=model['pack'], name=model['name']), - trigger=model + ref=ResourceReference.to_string_reference( + pack=model["pack"], name=model["name"] + ), + trigger=model, ) self.assertEqual(len(holder.get_all(timer_type=INTERVAL_TIMER_TRIGGER_REF)), 3) self.assertEqual(len(holder.get_all(timer_type=DATE_TIMER_TRIGGER_REF)), 1) @@ -95,20 +113,23 @@ def setUpClass(cls): loader = FixturesLoader() TestTimersController.MODELS = loader.save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES)['triggers'] + fixtures_pack=PACK, fixtures_dict=FIXTURES + )["triggers"] def test_timerscontroller_get_one_with_id(self): - model = TestTimersController.MODELS['interval1.yaml'] + model = TestTimersController.MODELS["interval1.yaml"] get_resp = self._do_get_one(model.id) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['parameters'], model['parameters']) + self.assertEqual(get_resp.json["parameters"], model["parameters"]) def test_timerscontroller_get_one_with_ref(self): - model = TestTimersController.MODELS['interval1.yaml'] - ref = ResourceReference.to_string_reference(pack=model['pack'], name=model['name']) + model = TestTimersController.MODELS["interval1.yaml"] + ref = ResourceReference.to_string_reference( + pack=model["pack"], name=model["name"] + ) get_resp = self._do_get_one(ref) self.assertEqual(get_resp.status_int, 200) - self.assertEqual(get_resp.json['parameters'], model['parameters']) + self.assertEqual(get_resp.json["parameters"], model["parameters"]) def _do_get_one(self, timer_id, expect_errors=False): - return self.app.get('/v1/timers/%s' % timer_id, expect_errors=expect_errors) + return self.app.get("/v1/timers/%s" % timer_id, expect_errors=expect_errors) diff --git a/st2api/tests/unit/controllers/v1/test_traces.py b/st2api/tests/unit/controllers/v1/test_traces.py index 0ce16a2a292..79bbdad6aef 100644 --- a/st2api/tests/unit/controllers/v1/test_traces.py +++ b/st2api/tests/unit/controllers/v1/test_traces.py @@ -19,23 +19,24 @@ from st2tests.api import FunctionalTest from st2tests.api import APIControllerWithIncludeAndExcludeFilterTestCase -FIXTURES_PACK = 'traces' +FIXTURES_PACK = "traces" TEST_MODELS = { - 'traces': [ - 'trace_empty.yaml', - 'trace_one_each.yaml', - 'trace_multiple_components.yaml' + "traces": [ + "trace_empty.yaml", + "trace_one_each.yaml", + "trace_multiple_components.yaml", ] } -class TracesControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/traces' +class TracesControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/traces" controller_cls = TracesController - include_attribute_field_name = 'trace_tag' - exclude_attribute_field_name = 'start_timestamp' + include_attribute_field_name = "trace_tag" + exclude_attribute_field_name = "start_timestamp" models = None trace1 = None @@ -45,112 +46,145 @@ class TracesControllerTestCase(FunctionalTest, @classmethod def setUpClass(cls): super(TracesControllerTestCase, cls).setUpClass() - cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - cls.trace1 = cls.models['traces']['trace_empty.yaml'] - cls.trace2 = cls.models['traces']['trace_one_each.yaml'] - cls.trace3 = cls.models['traces']['trace_multiple_components.yaml'] + cls.models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + cls.trace1 = cls.models["traces"]["trace_empty.yaml"] + cls.trace2 = cls.models["traces"]["trace_one_each.yaml"] + cls.trace3 = cls.models["traces"]["trace_multiple_components.yaml"] def test_get_all_and_minus_one(self): - resp = self.app.get('/v1/traces') + resp = self.app.get("/v1/traces") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.') + self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.") # Note: traces are returned sorted by start_timestamp in descending order by default - retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json] - self.assertEqual(retrieved_trace_tags, - [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], - 'Incorrect traces retrieved.') - - resp = self.app.get('/v1/traces/?limit=-1') + retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json] + self.assertEqual( + retrieved_trace_tags, + [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], + "Incorrect traces retrieved.", + ) + + resp = self.app.get("/v1/traces/?limit=-1") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.') + self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.") # Note: traces are returned sorted by start_timestamp in descending order by default - retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json] - self.assertEqual(retrieved_trace_tags, - [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], - 'Incorrect traces retrieved.') + retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json] + self.assertEqual( + retrieved_trace_tags, + [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], + "Incorrect traces retrieved.", + ) def test_get_all_ascending_and_descending(self): - resp = self.app.get('/v1/traces?sort_asc=True') + resp = self.app.get("/v1/traces?sort_asc=True") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.') + self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.") - retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json] - self.assertEqual(retrieved_trace_tags, - [self.trace1.trace_tag, self.trace2.trace_tag, self.trace3.trace_tag], - 'Incorrect traces retrieved.') + retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json] + self.assertEqual( + retrieved_trace_tags, + [self.trace1.trace_tag, self.trace2.trace_tag, self.trace3.trace_tag], + "Incorrect traces retrieved.", + ) - resp = self.app.get('/v1/traces?sort_desc=True') + resp = self.app.get("/v1/traces?sort_desc=True") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 3, '/v1/traces did not return all traces.') + self.assertEqual(len(resp.json), 3, "/v1/traces did not return all traces.") - retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json] - self.assertEqual(retrieved_trace_tags, - [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], - 'Incorrect traces retrieved.') + retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json] + self.assertEqual( + retrieved_trace_tags, + [self.trace3.trace_tag, self.trace2.trace_tag, self.trace1.trace_tag], + "Incorrect traces retrieved.", + ) def test_get_all_limit(self): - resp = self.app.get('/v1/traces?limit=1') + resp = self.app.get("/v1/traces?limit=1") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, '/v1/traces did not return all traces.') + self.assertEqual(len(resp.json), 1, "/v1/traces did not return all traces.") - retrieved_trace_tags = [trace['trace_tag'] for trace in resp.json] - self.assertEqual(retrieved_trace_tags, - [self.trace3.trace_tag], 'Incorrect traces retrieved.') + retrieved_trace_tags = [trace["trace_tag"] for trace in resp.json] + self.assertEqual( + retrieved_trace_tags, [self.trace3.trace_tag], "Incorrect traces retrieved." + ) def test_get_all_limit_negative_number(self): - resp = self.app.get('/v1/traces?limit=-22', expect_errors=True) + resp = self.app.get("/v1/traces?limit=-22", expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_get_by_id(self): - resp = self.app.get('/v1/traces/%s' % self.trace1.id) + resp = self.app.get("/v1/traces/%s" % self.trace1.id) self.assertEqual(resp.status_int, 200) - self.assertEqual(resp.json['id'], str(self.trace1.id), - 'Incorrect trace retrieved.') + self.assertEqual( + resp.json["id"], str(self.trace1.id), "Incorrect trace retrieved." + ) def test_query_by_trace_tag(self): - resp = self.app.get('/v1/traces?trace_tag=test-trace-1') + resp = self.app.get("/v1/traces?trace_tag=test-trace-1") self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, '/v1/traces?trace_tag=x did not return correct trace.') + self.assertEqual( + len(resp.json), 1, "/v1/traces?trace_tag=x did not return correct trace." + ) - self.assertEqual(resp.json[0]['trace_tag'], self.trace1['trace_tag'], - 'Correct trace not returned.') + self.assertEqual( + resp.json[0]["trace_tag"], + self.trace1["trace_tag"], + "Correct trace not returned.", + ) def test_query_by_action_execution(self): - execution_id = self.trace3['action_executions'][0].object_id - resp = self.app.get('/v1/traces?execution=%s' % execution_id) + execution_id = self.trace3["action_executions"][0].object_id + resp = self.app.get("/v1/traces?execution=%s" % execution_id) self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, - '/v1/traces?execution=x did not return correct trace.') - self.assertEqual(resp.json[0]['trace_tag'], self.trace3['trace_tag'], - 'Correct trace not returned.') + self.assertEqual( + len(resp.json), 1, "/v1/traces?execution=x did not return correct trace." + ) + self.assertEqual( + resp.json[0]["trace_tag"], + self.trace3["trace_tag"], + "Correct trace not returned.", + ) def test_query_by_rule(self): - rule_id = self.trace3['rules'][0].object_id - resp = self.app.get('/v1/traces?rule=%s' % rule_id) + rule_id = self.trace3["rules"][0].object_id + resp = self.app.get("/v1/traces?rule=%s" % rule_id) self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, '/v1/traces?rule=x did not return correct trace.') - self.assertEqual(resp.json[0]['trace_tag'], self.trace3['trace_tag'], - 'Correct trace not returned.') + self.assertEqual( + len(resp.json), 1, "/v1/traces?rule=x did not return correct trace." + ) + self.assertEqual( + resp.json[0]["trace_tag"], + self.trace3["trace_tag"], + "Correct trace not returned.", + ) def test_query_by_trigger_instance(self): - trigger_instance_id = self.trace3['trigger_instances'][0].object_id - resp = self.app.get('/v1/traces?trigger_instance=%s' % trigger_instance_id) + trigger_instance_id = self.trace3["trigger_instances"][0].object_id + resp = self.app.get("/v1/traces?trigger_instance=%s" % trigger_instance_id) self.assertEqual(resp.status_int, 200) - self.assertEqual(len(resp.json), 1, - '/v1/traces?trigger_instance=x did not return correct trace.') - self.assertEqual(resp.json[0]['trace_tag'], self.trace3['trace_tag'], - 'Correct trace not returned.') + self.assertEqual( + len(resp.json), + 1, + "/v1/traces?trigger_instance=x did not return correct trace.", + ) + self.assertEqual( + resp.json[0]["trace_tag"], + self.trace3["trace_tag"], + "Correct trace not returned.", + ) def _insert_mock_models(self): - trace_ids = [trace['id'] for trace in self.models['traces'].values()] + trace_ids = [trace["id"] for trace in self.models["traces"].values()] return trace_ids def _delete_mock_models(self, object_ids): diff --git a/st2api/tests/unit/controllers/v1/test_triggerinstances.py b/st2api/tests/unit/controllers/v1/test_triggerinstances.py index 0d81de723da..2a4149707c8 100644 --- a/st2api/tests/unit/controllers/v1/test_triggerinstances.py +++ b/st2api/tests/unit/controllers/v1/test_triggerinstances.py @@ -31,13 +31,14 @@ http_client = six.moves.http_client -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) -class TriggerInstanceTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/triggerinstances' +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) +class TriggerInstanceTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/triggerinstances" controller_cls = TriggerInstanceController - include_attribute_field_name = 'trigger' - exclude_attribute_field_name = 'payload' + include_attribute_field_name = "trigger" + exclude_attribute_field_name = "payload" @classmethod def setUpClass(cls): @@ -47,74 +48,84 @@ def setUpClass(cls): cls._setupTriggerInstance() def test_get_all(self): - resp = self.app.get('/v1/triggerinstances') + resp = self.app.get("/v1/triggerinstances") self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(len(resp.json), self.triggerinstance_count, 'Get all failure.') + self.assertEqual(len(resp.json), self.triggerinstance_count, "Get all failure.") def test_get_all_limit(self): limit = 1 - resp = self.app.get('/v1/triggerinstances?limit=%d' % limit) + resp = self.app.get("/v1/triggerinstances?limit=%d" % limit) self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(len(resp.json), limit, 'Get all failure. Length doesn\'t match limit.') + self.assertEqual( + len(resp.json), limit, "Get all failure. Length doesn't match limit." + ) def test_get_all_limit_negative_number(self): limit = -22 - resp = self.app.get('/v1/triggerinstances?limit=%d' % limit, expect_errors=True) + resp = self.app.get("/v1/triggerinstances?limit=%d" % limit, expect_errors=True) self.assertEqual(resp.status_int, 400) - self.assertEqual(resp.json['faultstring'], - u'Limit, "-22" specified, must be a positive number.') + self.assertEqual( + resp.json["faultstring"], + 'Limit, "-22" specified, must be a positive number.', + ) def test_get_all_filter_by_trigger(self): - trigger = 'dummy_pack_1.st2.test.trigger0' - resp = self.app.get('/v1/triggerinstances?trigger=%s' % trigger) + trigger = "dummy_pack_1.st2.test.trigger0" + resp = self.app.get("/v1/triggerinstances?trigger=%s" % trigger) self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(len(resp.json), 1, 'Get all failure. Must get only one such instance.') + self.assertEqual( + len(resp.json), 1, "Get all failure. Must get only one such instance." + ) def test_get_all_filter_by_timestamp(self): - resp = self.app.get('/v1/triggerinstances') + resp = self.app.get("/v1/triggerinstances") self.assertEqual(resp.status_int, http_client.OK) - timestamp_largest = resp.json[0]['occurrence_time'] - timestamp_middle = resp.json[1]['occurrence_time'] + timestamp_largest = resp.json[0]["occurrence_time"] + timestamp_middle = resp.json[1]["occurrence_time"] dt = isotime.parse(timestamp_largest) dt = dt + datetime.timedelta(seconds=1) timestamp_largest = isotime.format(dt, offset=False) - resp = self.app.get('/v1/triggerinstances?timestamp_gt=%s' % timestamp_largest) + resp = self.app.get("/v1/triggerinstances?timestamp_gt=%s" % timestamp_largest) # Since we sort trigger instances by time (latest first), the previous # get should return no trigger instances. self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/triggerinstances?timestamp_lt=%s' % (timestamp_middle)) + resp = self.app.get("/v1/triggerinstances?timestamp_lt=%s" % (timestamp_middle)) self.assertEqual(len(resp.json), 1) def test_get_all_trigger_type_ref_filtering(self): # 1. Invalid / inexistent trigger type ref - resp = self.app.get('/v1/triggerinstances?trigger_type=foo.bar.invalid') + resp = self.app.get("/v1/triggerinstances?trigger_type=foo.bar.invalid") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 0) # 2. Valid trigger type ref with corresponding trigger instances - resp = self.app.get('/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype0') + resp = self.app.get( + "/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype0" + ) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) # 3. Valid trigger type ref with no corresponding trigger instances - resp = self.app.get('/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype3') + resp = self.app.get( + "/v1/triggerinstances?trigger_type=dummy_pack_1.st2.test.triggertype3" + ) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 0) def test_reemit_trigger_instance(self): - resp = self.app.get('/v1/triggerinstances') + resp = self.app.get("/v1/triggerinstances") self.assertEqual(resp.status_int, http_client.OK) - instance_id = resp.json[0]['id'] - resp = self.app.post('/v1/triggerinstances/%s/re_emit' % instance_id) + instance_id = resp.json[0]["id"] + resp = self.app.post("/v1/triggerinstances/%s/re_emit" % instance_id) self.assertEqual(resp.status_int, http_client.OK) - resent_message = resp.json['message'] - resent_payload = resp.json['payload'] + resent_message = resp.json["message"] + resent_payload = resp.json["payload"] self.assertIn(instance_id, resent_message) - self.assertIn('__context', resent_payload) - self.assertEqual(resent_payload['__context']['original_id'], instance_id) + self.assertIn("__context", resent_payload) + self.assertEqual(resent_payload["__context"]["original_id"], instance_id) def test_get_one(self): triggerinstance_id = str(self.triggerinstance_1.id) @@ -133,79 +144,78 @@ def test_get_one(self): self.assertEqual(self._get_id(resp), triggerinstance_id) def test_get_one_fail(self): - resp = self._do_get_one('1') + resp = self._do_get_one("1") self.assertEqual(resp.status_int, http_client.NOT_FOUND) @classmethod def _setupTriggerTypes(cls): TRIGGERTYPE_0 = { - 'name': 'st2.test.triggertype0', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {} + "name": "st2.test.triggertype0", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {}, } TRIGGERTYPE_1 = { - 'name': 'st2.test.triggertype1', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, + "name": "st2.test.triggertype1", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, } TRIGGERTYPE_2 = { - 'name': 'st2.test.triggertype2', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {'param1': {'type': 'object'}} + "name": "st2.test.triggertype2", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {"param1": {"type": "object"}}, } TRIGGERTYPE_3 = { - 'name': 'st2.test.triggertype3', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {'param1': {'type': 'object'}} + "name": "st2.test.triggertype3", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {"param1": {"type": "object"}}, } - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_0, expect_errors=False) - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_1, expect_errors=False) - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_2, expect_errors=False) - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_3, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_0, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_1, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_2, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_3, expect_errors=False) @classmethod def _setupTriggers(cls): TRIGGER_0 = { - 'name': 'st2.test.trigger0', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype0', - 'parameters': {} + "name": "st2.test.trigger0", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype0", + "parameters": {}, } TRIGGER_1 = { - 'name': 'st2.test.trigger1', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype1', - 'parameters': {} + "name": "st2.test.trigger1", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype1", + "parameters": {}, } TRIGGER_2 = { - 'name': 'st2.test.trigger2', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype2', - 'parameters': { - 'param1': { - 'foo': 'bar' - } - } + "name": "st2.test.trigger2", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype2", + "parameters": {"param1": {"foo": "bar"}}, } - cls.app.post_json('/v1/triggers', TRIGGER_0, expect_errors=False) - cls.app.post_json('/v1/triggers', TRIGGER_1, expect_errors=False) - cls.app.post_json('/v1/triggers', TRIGGER_2, expect_errors=False) + cls.app.post_json("/v1/triggers", TRIGGER_0, expect_errors=False) + cls.app.post_json("/v1/triggers", TRIGGER_1, expect_errors=False) + cls.app.post_json("/v1/triggers", TRIGGER_2, expect_errors=False) def _insert_mock_models(self): - return [self.triggerinstance_1['id'], self.triggerinstance_2['id'], - self.triggerinstance_3['id']] + return [ + self.triggerinstance_1["id"], + self.triggerinstance_2["id"], + self.triggerinstance_3["id"], + ] def _delete_mock_models(self, object_ids): return None @@ -214,17 +224,20 @@ def _delete_mock_models(self, object_ids): def _setupTriggerInstance(cls): cls.triggerinstance_count = 0 cls.triggerinstance_1 = cls._create_trigger_instance( - trigger_ref='dummy_pack_1.st2.test.trigger0', - payload={'tp1': 1, 'tp2': 2, 'tp3': 3}, - seconds=1) + trigger_ref="dummy_pack_1.st2.test.trigger0", + payload={"tp1": 1, "tp2": 2, "tp3": 3}, + seconds=1, + ) cls.triggerinstance_2 = cls._create_trigger_instance( - trigger_ref='dummy_pack_1.st2.test.trigger1', - payload={'tp1': 'a', 'tp2': 'b', 'tp3': 'c'}, - seconds=2) + trigger_ref="dummy_pack_1.st2.test.trigger1", + payload={"tp1": "a", "tp2": "b", "tp3": "c"}, + seconds=2, + ) cls.triggerinstance_3 = cls._create_trigger_instance( - trigger_ref='dummy_pack_1.st2.test.trigger2', - payload={'tp1': None, 'tp2': None, 'tp3': None}, - seconds=3) + trigger_ref="dummy_pack_1.st2.test.trigger2", + payload={"tp1": None, "tp2": None, "tp3": None}, + seconds=3, + ) @classmethod def _create_trigger_instance(cls, trigger_ref, payload, seconds): @@ -244,7 +257,9 @@ def _create_trigger_instance(cls, trigger_ref, payload, seconds): @staticmethod def _get_id(resp): - return resp.json['id'] + return resp.json["id"] def _do_get_one(self, triggerinstance_id): - return self.app.get('/v1/triggerinstances/%s' % triggerinstance_id, expect_errors=True) + return self.app.get( + "/v1/triggerinstances/%s" % triggerinstance_id, expect_errors=True + ) diff --git a/st2api/tests/unit/controllers/v1/test_triggers.py b/st2api/tests/unit/controllers/v1/test_triggers.py index d3526e624a1..5067c7674f6 100644 --- a/st2api/tests/unit/controllers/v1/test_triggers.py +++ b/st2api/tests/unit/controllers/v1/test_triggers.py @@ -22,57 +22,52 @@ http_client = six.moves.http_client TRIGGER_0 = { - 'name': 'st2.test.trigger0', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype0', - 'parameters': {} + "name": "st2.test.trigger0", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype0", + "parameters": {}, } TRIGGER_1 = { - 'name': 'st2.test.trigger1', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype1', - 'parameters': {} + "name": "st2.test.trigger1", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype1", + "parameters": {}, } TRIGGER_2 = { - 'name': 'st2.test.trigger2', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'type': 'dummy_pack_1.st2.test.triggertype2', - 'parameters': { - 'param1': { - 'foo': 'bar' - } - } + "name": "st2.test.trigger2", + "pack": "dummy_pack_1", + "description": "test trigger", + "type": "dummy_pack_1.st2.test.triggertype2", + "parameters": {"param1": {"foo": "bar"}}, } -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class TestTriggerController(FunctionalTest): - @classmethod def setUpClass(cls): super(TestTriggerController, cls).setUpClass() cls._setupTriggerTypes() def test_get_all(self): - resp = self.app.get('/v1/triggers') + resp = self.app.get("/v1/triggers") self.assertEqual(resp.status_int, http_client.OK) # TriggerType without parameters will register a trigger # with same name. - self.assertEqual(len(resp.json), 2, 'Get all failure. %s' % resp.json) + self.assertEqual(len(resp.json), 2, "Get all failure. %s" % resp.json) post_resp = self._do_post(TRIGGER_0) trigger_id_0 = self._get_trigger_id(post_resp) post_resp = self._do_post(TRIGGER_1) trigger_id_1 = self._get_trigger_id(post_resp) - resp = self.app.get('/v1/triggers') + resp = self.app.get("/v1/triggers") self.assertEqual(resp.status_int, http_client.OK) # TriggerType without parameters will register a trigger # with same name. So here we see 4 instead of 2. - self.assertEqual(len(resp.json), 4, 'Get all failure.') + self.assertEqual(len(resp.json), 4, "Get all failure.") self._do_delete(trigger_id_0) self._do_delete(trigger_id_1) @@ -85,7 +80,7 @@ def test_get_one(self): self._do_delete(trigger_id) def test_get_one_fail(self): - resp = self._do_get_one('1') + resp = self._do_get_one("1") self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_post(self): @@ -106,13 +101,15 @@ def test_post_duplicate(self): # id is same in both cases. post_resp_2 = self._do_post(TRIGGER_1) self.assertEqual(post_resp_2.status_int, http_client.CREATED) - self.assertEqual(self._get_trigger_id(post_resp), self._get_trigger_id(post_resp_2)) + self.assertEqual( + self._get_trigger_id(post_resp), self._get_trigger_id(post_resp_2) + ) self._do_delete(self._get_trigger_id(post_resp)) def test_put(self): post_resp = self._do_post(TRIGGER_1) update_input = post_resp.json - update_input['description'] = 'updated description.' + update_input["description"] = "updated description." put_resp = self._do_put(self._get_trigger_id(post_resp), update_input) self.assertEqual(put_resp.status_int, http_client.OK) self._do_delete(self._get_trigger_id(put_resp)) @@ -133,41 +130,43 @@ def test_delete(self): @classmethod def _setupTriggerTypes(cls): TRIGGERTYPE_0 = { - 'name': 'st2.test.triggertype0', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {} + "name": "st2.test.triggertype0", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {}, } TRIGGERTYPE_1 = { - 'name': 'st2.test.triggertype1', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, + "name": "st2.test.triggertype1", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, } TRIGGERTYPE_2 = { - 'name': 'st2.test.triggertype2', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {'param1': {'type': 'object'}} + "name": "st2.test.triggertype2", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {"param1": {"type": "object"}}, } - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_0, expect_errors=False) - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_1, expect_errors=False) - cls.app.post_json('/v1/triggertypes', TRIGGERTYPE_2, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_0, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_1, expect_errors=False) + cls.app.post_json("/v1/triggertypes", TRIGGERTYPE_2, expect_errors=False) @staticmethod def _get_trigger_id(resp): - return resp.json['id'] + return resp.json["id"] def _do_get_one(self, trigger_id): - return self.app.get('/v1/triggers/%s' % trigger_id, expect_errors=True) + return self.app.get("/v1/triggers/%s" % trigger_id, expect_errors=True) def _do_post(self, trigger): - return self.app.post_json('/v1/triggers', trigger, expect_errors=True) + return self.app.post_json("/v1/triggers", trigger, expect_errors=True) def _do_put(self, trigger_id, trigger): - return self.app.put_json('/v1/triggers/%s' % trigger_id, trigger, expect_errors=True) + return self.app.put_json( + "/v1/triggers/%s" % trigger_id, trigger, expect_errors=True + ) def _do_delete(self, trigger_id): - return self.app.delete('/v1/triggers/%s' % trigger_id) + return self.app.delete("/v1/triggers/%s" % trigger_id) diff --git a/st2api/tests/unit/controllers/v1/test_triggertypes.py b/st2api/tests/unit/controllers/v1/test_triggertypes.py index c7848f5c2dd..414fc343601 100644 --- a/st2api/tests/unit/controllers/v1/test_triggertypes.py +++ b/st2api/tests/unit/controllers/v1/test_triggertypes.py @@ -23,33 +23,34 @@ http_client = six.moves.http_client TRIGGER_0 = { - 'name': 'st2.test.triggertype0', - 'pack': 'dummy_pack_1', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {} + "name": "st2.test.triggertype0", + "pack": "dummy_pack_1", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {}, } TRIGGER_1 = { - 'name': 'st2.test.triggertype1', - 'pack': 'dummy_pack_2', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, + "name": "st2.test.triggertype1", + "pack": "dummy_pack_2", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, } TRIGGER_2 = { - 'name': 'st2.test.triggertype3', - 'pack': 'dummy_pack_3', - 'description': 'test trigger', - 'payload_schema': {'tp1': None, 'tp2': None, 'tp3': None}, - 'parameters_schema': {'param1': {'type': 'object'}} + "name": "st2.test.triggertype3", + "pack": "dummy_pack_3", + "description": "test trigger", + "payload_schema": {"tp1": None, "tp2": None, "tp3": None}, + "parameters_schema": {"param1": {"type": "object"}}, } -class TriggerTypeControllerTestCase(FunctionalTest, - APIControllerWithIncludeAndExcludeFilterTestCase): - get_all_path = '/v1/triggertypes' +class TriggerTypeControllerTestCase( + FunctionalTest, APIControllerWithIncludeAndExcludeFilterTestCase +): + get_all_path = "/v1/triggertypes" controller_cls = TriggerTypeController - include_attribute_field_name = 'payload_schema' - exclude_attribute_field_name = 'parameters_schema' + include_attribute_field_name = "payload_schema" + exclude_attribute_field_name = "parameters_schema" @classmethod def setUpClass(cls): @@ -71,19 +72,19 @@ def test_get_all(self): trigger_id_0 = self.__get_trigger_id(post_resp) post_resp = self.__do_post(TRIGGER_1) trigger_id_1 = self.__get_trigger_id(post_resp) - resp = self.app.get('/v1/triggertypes') + resp = self.app.get("/v1/triggertypes") self.assertEqual(resp.status_int, http_client.OK) - self.assertEqual(len(resp.json), 2, 'Get all failure.') + self.assertEqual(len(resp.json), 2, "Get all failure.") # ?pack query filter - resp = self.app.get('/v1/triggertypes?pack=doesnt-exist-invalid') + resp = self.app.get("/v1/triggertypes?pack=doesnt-exist-invalid") self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 0) - resp = self.app.get('/v1/triggertypes?pack=%s' % (TRIGGER_0['pack'])) + resp = self.app.get("/v1/triggertypes?pack=%s" % (TRIGGER_0["pack"])) self.assertEqual(resp.status_int, http_client.OK) self.assertEqual(len(resp.json), 1) - self.assertEqual(resp.json[0]['pack'], TRIGGER_0['pack']) + self.assertEqual(resp.json[0]["pack"], TRIGGER_0["pack"]) self.__do_delete(trigger_id_0) self.__do_delete(trigger_id_1) @@ -97,7 +98,7 @@ def test_get_one(self): self.__do_delete(trigger_id) def test_get_one_fail(self): - resp = self.__do_get_one('1') + resp = self.__do_get_one("1") self.assertEqual(resp.status_int, http_client.NOT_FOUND) def test_post(self): @@ -116,13 +117,13 @@ def test_post_duplicate(self): self.assertEqual(post_resp.status_int, http_client.CREATED) post_resp_2 = self.__do_post(TRIGGER_1) self.assertEqual(post_resp_2.status_int, http_client.CONFLICT) - self.assertEqual(post_resp_2.json['conflict-id'], org_id) + self.assertEqual(post_resp_2.json["conflict-id"], org_id) self.__do_delete(org_id) def test_put(self): post_resp = self.__do_post(TRIGGER_1) update_input = post_resp.json - update_input['description'] = 'updated description.' + update_input["description"] = "updated description." put_resp = self.__do_put(self.__get_trigger_id(post_resp), update_input) self.assertEqual(put_resp.status_int, http_client.OK) self.__do_delete(self.__get_trigger_id(put_resp)) @@ -151,16 +152,18 @@ def _do_delete(self, trigger_id): @staticmethod def __get_trigger_id(resp): - return resp.json['id'] + return resp.json["id"] def __do_get_one(self, trigger_id): - return self.app.get('/v1/triggertypes/%s' % trigger_id, expect_errors=True) + return self.app.get("/v1/triggertypes/%s" % trigger_id, expect_errors=True) def __do_post(self, trigger): - return self.app.post_json('/v1/triggertypes', trigger, expect_errors=True) + return self.app.post_json("/v1/triggertypes", trigger, expect_errors=True) def __do_put(self, trigger_id, trigger): - return self.app.put_json('/v1/triggertypes/%s' % trigger_id, trigger, expect_errors=True) + return self.app.put_json( + "/v1/triggertypes/%s" % trigger_id, trigger, expect_errors=True + ) def __do_delete(self, trigger_id): - return self.app.delete('/v1/triggertypes/%s' % trigger_id) + return self.app.delete("/v1/triggertypes/%s" % trigger_id) diff --git a/st2api/tests/unit/controllers/v1/test_webhooks.py b/st2api/tests/unit/controllers/v1/test_webhooks.py index 487830a0922..e8fedc673ca 100644 --- a/st2api/tests/unit/controllers/v1/test_webhooks.py +++ b/st2api/tests/unit/controllers/v1/test_webhooks.py @@ -21,7 +21,7 @@ import st2common.services.triggers as trigger_service -with mock.patch.object(trigger_service, 'create_trigger_type_db', mock.MagicMock()): +with mock.patch.object(trigger_service, "create_trigger_type_db", mock.MagicMock()): from st2api.controllers.v1.webhooks import WebhooksController, HooksHolder from st2common.constants.triggers import WEBHOOK_TRIGGER_TYPES @@ -34,28 +34,20 @@ http_client = six.moves.http_client -WEBHOOK_1 = { - 'action': 'closed', - 'pull_request': { - 'merged': True - } -} +WEBHOOK_1 = {"action": "closed", "pull_request": {"merged": True}} ST2_WEBHOOK = { - 'trigger': 'git.pr-merged', - 'payload': { - 'value_str': 'string!', - 'value_int': 12345 - } + "trigger": "git.pr-merged", + "payload": {"value_str": "string!", "value_int": 12345}, } WEBHOOK_DATA = { - 'value_str': 'test string 1', - 'value_int': 987654, + "value_str": "test string 1", + "value_int": 987654, } # 1. Trigger which references a system webhook trigger type -DUMMY_TRIGGER_DB = TriggerDB(name='pr-merged', pack='git') +DUMMY_TRIGGER_DB = TriggerDB(name="pr-merged", pack="git") DUMMY_TRIGGER_DB.type = list(WEBHOOK_TRIGGER_TYPES.keys())[0] @@ -63,34 +55,24 @@ DUMMY_TRIGGER_DICT = vars(DUMMY_TRIGGER_API) # 2. Custom TriggerType object -DUMMY_TRIGGER_TYPE_DB = TriggerTypeDB(name='pr-merged', pack='git') +DUMMY_TRIGGER_TYPE_DB = TriggerTypeDB(name="pr-merged", pack="git") DUMMY_TRIGGER_TYPE_DB.payload_schema = { - 'type': 'object', - 'properties': { - 'body': { - 'properties': { - 'value_str': { - 'type': 'string', - 'required': True - }, - 'value_int': { - 'type': 'integer', - 'required': True - } + "type": "object", + "properties": { + "body": { + "properties": { + "value_str": {"type": "string", "required": True}, + "value_int": {"type": "integer", "required": True}, } } - } + }, } # 2. Custom TriggerType object -DUMMY_TRIGGER_TYPE_DB_2 = TriggerTypeDB(name='pr-merged', pack='git') +DUMMY_TRIGGER_TYPE_DB_2 = TriggerTypeDB(name="pr-merged", pack="git") DUMMY_TRIGGER_TYPE_DB_2.payload_schema = { - 'type': 'object', - 'properties': { - 'body': { - 'type': 'array' - } - } + "type": "object", + "properties": {"body": {"type": "array"}}, } @@ -100,190 +82,244 @@ def setUp(self): cfg.CONF.system.validate_trigger_payload = True - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_all', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, "get_all", mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]) + ) def test_get_all(self): - get_resp = self.app.get('/v1/webhooks', expect_errors=False) + get_resp = self.app.get("/v1/webhooks", expect_errors=False) self.assertEqual(get_resp.status_int, http_client.OK) self.assertEqual(get_resp.json, [DUMMY_TRIGGER_DICT]) - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_post(self, dispatch_mock): - post_resp = self.__do_post('git', WEBHOOK_1, expect_errors=False) + post_resp = self.__do_post("git", WEBHOOK_1, expect_errors=False) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertTrue(dispatch_mock.call_args[1]['trace_context'].trace_tag) - - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock( - return_value=DUMMY_TRIGGER_TYPE_DB)) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + self.assertTrue(dispatch_mock.call_args[1]["trace_context"].trace_tag) + + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_post_with_trace(self, dispatch_mock): - post_resp = self.__do_post('git', WEBHOOK_1, expect_errors=False, - headers={'St2-Trace-Tag': 'tag1'}) + post_resp = self.__do_post( + "git", WEBHOOK_1, expect_errors=False, headers={"St2-Trace-Tag": "tag1"} + ) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1') + self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1") - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) def test_post_hook_not_registered(self): - post_resp = self.__do_post('foo', WEBHOOK_1, expect_errors=True) + post_resp = self.__do_post("foo", WEBHOOK_1, expect_errors=True) self.assertEqual(post_resp.status_int, http_client.NOT_FOUND) - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock( - return_value=DUMMY_TRIGGER_TYPE_DB)) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_st2_webhook_success(self, dispatch_mock): - post_resp = self.__do_post('st2', ST2_WEBHOOK) + post_resp = self.__do_post("st2", ST2_WEBHOOK) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertTrue(dispatch_mock.call_args[1]['trace_context'].trace_tag) + self.assertTrue(dispatch_mock.call_args[1]["trace_context"].trace_tag) - post_resp = self.__do_post('st2/', ST2_WEBHOOK) + post_resp = self.__do_post("st2/", ST2_WEBHOOK) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock( - return_value=DUMMY_TRIGGER_TYPE_DB)) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_st2_webhook_failure_payload_validation_failed(self, dispatch_mock): - data = { - 'trigger': 'git.pr-merged', - 'payload': 'invalid' - } - post_resp = self.__do_post('st2', data, expect_errors=True) + data = {"trigger": "git.pr-merged", "payload": "invalid"} + post_resp = self.__do_post("st2", data, expect_errors=True) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - expected_msg = 'Trigger payload validation failed' - self.assertIn(expected_msg, post_resp.json['faultstring']) + expected_msg = "Trigger payload validation failed" + self.assertIn(expected_msg, post_resp.json["faultstring"]) expected_msg = "'invalid' is not of type 'object'" - self.assertIn(expected_msg, post_resp.json['faultstring']) - - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock( - return_value=DUMMY_TRIGGER_TYPE_DB)) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + self.assertIn(expected_msg, post_resp.json["faultstring"]) + + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_st2_webhook_with_trace(self, dispatch_mock): - post_resp = self.__do_post('st2', ST2_WEBHOOK, headers={'St2-Trace-Tag': 'tag1'}) + post_resp = self.__do_post( + "st2", ST2_WEBHOOK, headers={"St2-Trace-Tag": "tag1"} + ) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1') + self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1") - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) def test_st2_webhook_body_missing_trigger(self): - post_resp = self.__do_post('st2', {'payload': {}}, expect_errors=True) - self.assertIn('Trigger not specified.', post_resp) + post_resp = self.__do_post("st2", {"payload": {}}, expect_errors=True) + self.assertIn("Trigger not specified.", post_resp) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_json_request_body(self, dispatch_mock): # 1. Send JSON using application/json content type data = WEBHOOK_1 - post_resp = self.__do_post('git', data, - headers={'St2-Trace-Tag': 'tag1'}) + post_resp = self.__do_post("git", data, headers={"St2-Trace-Tag": "tag1"}) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertEqual(dispatch_mock.call_args[1]['payload']['headers']['Content-Type'], - 'application/json') - self.assertEqual(dispatch_mock.call_args[1]['payload']['body'], data) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1') + self.assertEqual( + dispatch_mock.call_args[1]["payload"]["headers"]["Content-Type"], + "application/json", + ) + self.assertEqual(dispatch_mock.call_args[1]["payload"]["body"], data) + self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1") # 2. Send JSON using application/json + charset content type data = WEBHOOK_1 - headers = {'St2-Trace-Tag': 'tag1', 'Content-Type': 'application/json; charset=utf-8'} - post_resp = self.__do_post('git', data, - headers=headers) + headers = { + "St2-Trace-Tag": "tag1", + "Content-Type": "application/json; charset=utf-8", + } + post_resp = self.__do_post("git", data, headers=headers) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertEqual(dispatch_mock.call_args[1]['payload']['headers']['Content-Type'], - 'application/json; charset=utf-8') - self.assertEqual(dispatch_mock.call_args[1]['payload']['body'], data) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1') + self.assertEqual( + dispatch_mock.call_args[1]["payload"]["headers"]["Content-Type"], + "application/json; charset=utf-8", + ) + self.assertEqual(dispatch_mock.call_args[1]["payload"]["body"], data) + self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1") # 3. JSON content type, invalid JSON body - data = 'invalid' - headers = {'St2-Trace-Tag': 'tag1', 'Content-Type': 'application/json'} - post_resp = self.app.post('/v1/webhooks/git', data, headers=headers, - expect_errors=True) + data = "invalid" + headers = {"St2-Trace-Tag": "tag1", "Content-Type": "application/json"} + post_resp = self.app.post( + "/v1/webhooks/git", data, headers=headers, expect_errors=True + ) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('Failed to parse request body', post_resp) - - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + self.assertIn("Failed to parse request body", post_resp) + + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_form_encoded_request_body(self, dispatch_mock): # Send request body as form urlencoded data if six.PY3: - data = {b'form': [b'test']} + data = {b"form": [b"test"]} else: - data = {'form': ['test']} + data = {"form": ["test"]} headers = { - 'Content-Type': 'application/x-www-form-urlencoded', - 'St2-Trace-Tag': 'tag1' + "Content-Type": "application/x-www-form-urlencoded", + "St2-Trace-Tag": "tag1", } - self.app.post('/v1/webhooks/git', data, headers=headers) - self.assertEqual(dispatch_mock.call_args[1]['payload']['headers']['Content-Type'], - 'application/x-www-form-urlencoded') - self.assertEqual(dispatch_mock.call_args[1]['payload']['body'], data) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, 'tag1') + self.app.post("/v1/webhooks/git", data, headers=headers) + self.assertEqual( + dispatch_mock.call_args[1]["payload"]["headers"]["Content-Type"], + "application/x-www-form-urlencoded", + ) + self.assertEqual(dispatch_mock.call_args[1]["payload"]["body"], data) + self.assertEqual(dispatch_mock.call_args[1]["trace_context"].trace_tag, "tag1") def test_unsupported_content_type(self): # Invalid / unsupported content type - should throw data = WEBHOOK_1 - headers = {'St2-Trace-Tag': 'tag1', 'Content-Type': 'foo/invalid'} - post_resp = self.app.post('/v1/webhooks/git', json.dumps(data), headers=headers, - expect_errors=True) + headers = {"St2-Trace-Tag": "tag1", "Content-Type": "foo/invalid"} + post_resp = self.app.post( + "/v1/webhooks/git", json.dumps(data), headers=headers, expect_errors=True + ) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - self.assertIn('Failed to parse request body', post_resp) - self.assertIn('Unsupported Content-Type', post_resp) - - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.services.triggers.get_trigger_type_db', mock.MagicMock( - return_value=DUMMY_TRIGGER_TYPE_DB_2)) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + self.assertIn("Failed to parse request body", post_resp) + self.assertIn("Unsupported Content-Type", post_resp) + + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=DUMMY_TRIGGER_TYPE_DB_2), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_custom_webhook_array_input_type(self, _): - post_resp = self.__do_post('sample', [{'foo': 'bar'}]) + post_resp = self.__do_post("sample", [{"foo": "bar"}]) self.assertEqual(post_resp.status_int, http_client.ACCEPTED) - self.assertEqual(post_resp.json, [{'foo': 'bar'}]) + self.assertEqual(post_resp.json, [{"foo": "bar"}]) def test_st2_webhook_array_webhook_array_input_type_not_valid(self): - post_resp = self.__do_post('st2', [{'foo': 'bar'}], expect_errors=True) + post_resp = self.__do_post("st2", [{"foo": "bar"}], expect_errors=True) self.assertEqual(post_resp.status_int, http_client.BAD_REQUEST) - self.assertEqual(post_resp.json['faultstring'], - 'Webhook body needs to be an object, got: array') + self.assertEqual( + post_resp.json["faultstring"], + "Webhook body needs to be an object, got: array", + ) def test_leading_trailing_slashes(self): # Ideally the test should setup fixtures in DB. However, the triggerwatcher @@ -296,52 +332,65 @@ def test_leading_trailing_slashes(self): # require hacking into the test app and force dependency on pecan internals. # TLDR; sorry for the ghetto test. Not sure how else to test this as a unit test. def get_webhook_trigger(name, url): - trigger = TriggerDB(name=name, pack='test') + trigger = TriggerDB(name=name, pack="test") trigger.type = list(WEBHOOK_TRIGGER_TYPES.keys())[0] - trigger.parameters = {'url': url} + trigger.parameters = {"url": url} return trigger test_triggers = [ - get_webhook_trigger('no_slash', 'no_slash'), - get_webhook_trigger('with_leading_slash', '/with_leading_slash'), - get_webhook_trigger('with_trailing_slash', '/with_trailing_slash/'), - get_webhook_trigger('with_leading_trailing_slash', '/with_leading_trailing_slash/'), - get_webhook_trigger('with_mixed_slash', '/with/mixed/slash/') + get_webhook_trigger("no_slash", "no_slash"), + get_webhook_trigger("with_leading_slash", "/with_leading_slash"), + get_webhook_trigger("with_trailing_slash", "/with_trailing_slash/"), + get_webhook_trigger( + "with_leading_trailing_slash", "/with_leading_trailing_slash/" + ), + get_webhook_trigger("with_mixed_slash", "/with/mixed/slash/"), ] controller = WebhooksController() for trigger in test_triggers: controller.add_trigger(trigger) - self.assertTrue(controller._is_valid_hook('no_slash')) - self.assertFalse(controller._is_valid_hook('/no_slash')) - self.assertTrue(controller._is_valid_hook('with_leading_slash')) - self.assertTrue(controller._is_valid_hook('with_trailing_slash')) - self.assertTrue(controller._is_valid_hook('with_leading_trailing_slash')) - self.assertTrue(controller._is_valid_hook('with/mixed/slash')) - - @mock.patch.object(TriggerInstancePublisher, 'publish_trigger', mock.MagicMock( - return_value=True)) - @mock.patch.object(WebhooksController, '_is_valid_hook', mock.MagicMock( - return_value=True)) - @mock.patch.object(HooksHolder, 'get_triggers_for_hook', mock.MagicMock( - return_value=[DUMMY_TRIGGER_DICT])) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + self.assertTrue(controller._is_valid_hook("no_slash")) + self.assertFalse(controller._is_valid_hook("/no_slash")) + self.assertTrue(controller._is_valid_hook("with_leading_slash")) + self.assertTrue(controller._is_valid_hook("with_trailing_slash")) + self.assertTrue(controller._is_valid_hook("with_leading_trailing_slash")) + self.assertTrue(controller._is_valid_hook("with/mixed/slash")) + + @mock.patch.object( + TriggerInstancePublisher, "publish_trigger", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + WebhooksController, "_is_valid_hook", mock.MagicMock(return_value=True) + ) + @mock.patch.object( + HooksHolder, + "get_triggers_for_hook", + mock.MagicMock(return_value=[DUMMY_TRIGGER_DICT]), + ) + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_authentication_headers_should_be_removed(self, dispatch_mock): headers = { - 'Content-Type': 'application/x-www-form-urlencoded', - 'St2-Api-Key': 'foobar', - 'X-Auth-Token': 'deadbeaf', - 'Cookie': 'foo=bar' + "Content-Type": "application/x-www-form-urlencoded", + "St2-Api-Key": "foobar", + "X-Auth-Token": "deadbeaf", + "Cookie": "foo=bar", } - self.app.post('/v1/webhooks/git', WEBHOOK_1, headers=headers) - self.assertNotIn('St2-Api-Key', dispatch_mock.call_args[1]['payload']['headers']) - self.assertNotIn('X-Auth-Token', dispatch_mock.call_args[1]['payload']['headers']) - self.assertNotIn('Cookie', dispatch_mock.call_args[1]['payload']['headers']) + self.app.post("/v1/webhooks/git", WEBHOOK_1, headers=headers) + self.assertNotIn( + "St2-Api-Key", dispatch_mock.call_args[1]["payload"]["headers"] + ) + self.assertNotIn( + "X-Auth-Token", dispatch_mock.call_args[1]["payload"]["headers"] + ) + self.assertNotIn("Cookie", dispatch_mock.call_args[1]["payload"]["headers"]) def __do_post(self, hook, webhook, expect_errors=False, headers=None): - return self.app.post_json('/v1/webhooks/' + hook, - params=webhook, - expect_errors=expect_errors, - headers=headers) + return self.app.post_json( + "/v1/webhooks/" + hook, + params=webhook, + expect_errors=expect_errors, + headers=headers, + ) diff --git a/st2api/tests/unit/controllers/v1/test_workflow_inspection.py b/st2api/tests/unit/controllers/v1/test_workflow_inspection.py index 3b45421d797..91e251fe9dc 100644 --- a/st2api/tests/unit/controllers/v1/test_workflow_inspection.py +++ b/st2api/tests/unit/controllers/v1/test_workflow_inspection.py @@ -22,13 +22,17 @@ from st2tests.api import FunctionalTest -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK -PACKS = [TEST_PACK_PATH, st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core'] +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) +PACKS = [ + TEST_PACK_PATH, + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", +] class WorkflowInspectionControllerTest(FunctionalTest, st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(WorkflowInspectionControllerTest, cls).setUpClass() @@ -39,8 +43,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -48,14 +51,14 @@ def setUpClass(cls): def _do_post(self, wf_def, expect_errors=False): return self.app.post( - '/v1/workflows/inspect', + "/v1/workflows/inspect", wf_def, expect_errors=expect_errors, - content_type='text/plain' + content_type="text/plain", ) def test_inspection(self): - wf_file = 'sequential.yaml' + wf_file = "sequential.yaml" wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta) @@ -65,48 +68,48 @@ def test_inspection(self): self.assertListEqual(response.json, expected_errors) def test_inspection_return_errors(self): - wf_file = 'fail-inspection.yaml' + wf_file = "fail-inspection.yaml" wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_file) wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta) expected_errors = [ { - 'type': 'content', - 'message': 'The action "std.noop" is not registered in the database.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action', - 'spec_path': 'tasks.task3.action' + "type": "content", + "message": 'The action "std.noop" is not registered in the database.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action", + "spec_path": "tasks.task3.action", }, { - 'type': 'context', - 'language': 'yaql', - 'expression': '<% ctx().foobar %>', - 'message': 'Variable "foobar" is referenced before assignment.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input', - 'spec_path': 'tasks.task1.input', + "type": "context", + "language": "yaql", + "expression": "<% ctx().foobar %>", + "message": 'Variable "foobar" is referenced before assignment.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input", + "spec_path": "tasks.task1.input", }, { - 'type': 'expression', - 'language': 'yaql', - 'expression': '<% <% succeeded() %>', - 'message': ( - 'Parse error: unexpected \'<\' at ' - 'position 0 of expression \'<% succeeded()\'' + "type": "expression", + "language": "yaql", + "expression": "<% <% succeeded() %>", + "message": ( + "Parse error: unexpected '<' at " + "position 0 of expression '<% succeeded()'" ), - 'schema_path': ( - r'properties.tasks.patternProperties.^\w+$.' - 'properties.next.items.properties.when' + "schema_path": ( + r"properties.tasks.patternProperties.^\w+$." + "properties.next.items.properties.when" ), - 'spec_path': 'tasks.task2.next[0].when' + "spec_path": "tasks.task2.next[0].when", }, { - 'type': 'syntax', - 'message': ( - '[{\'cmd\': \'echo <% ctx().macro %>\'}] is ' - 'not valid under any of the given schemas' + "type": "syntax", + "message": ( + "[{'cmd': 'echo <% ctx().macro %>'}] is " + "not valid under any of the given schemas" ), - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input.oneOf', - 'spec_path': 'tasks.task2.input' - } + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input.oneOf", + "spec_path": "tasks.task2.input", + }, ] response = self._do_post(wf_def, expect_errors=False) diff --git a/st2api/tests/unit/test_validation_utils.py b/st2api/tests/unit/test_validation_utils.py index eaf1cd75a5e..bad17b22a53 100644 --- a/st2api/tests/unit/test_validation_utils.py +++ b/st2api/tests/unit/test_validation_utils.py @@ -19,9 +19,7 @@ from st2api.validation import validate_rbac_is_correctly_configured from st2tests import config as tests_config -__all__ = [ - 'ValidationUtilsTestCase' -] +__all__ = ["ValidationUtilsTestCase"] class ValidationUtilsTestCase(unittest2.TestCase): @@ -34,26 +32,34 @@ def test_validate_rbac_is_correctly_configured_succcess(self): self.assertTrue(result) def test_validate_rbac_is_correctly_configured_auth_not_enabled(self): - cfg.CONF.set_override(group='rbac', name='enable', override=True) - cfg.CONF.set_override(group='auth', name='enable', override=False) + cfg.CONF.set_override(group="rbac", name="enable", override=True) + cfg.CONF.set_override(group="auth", name="enable", override=False) - expected_msg = ('Authentication is not enabled. RBAC only works when authentication is ' - 'enabled. You can either enable authentication or disable RBAC.') - self.assertRaisesRegexp(ValueError, expected_msg, - validate_rbac_is_correctly_configured) + expected_msg = ( + "Authentication is not enabled. RBAC only works when authentication is " + "enabled. You can either enable authentication or disable RBAC." + ) + self.assertRaisesRegexp( + ValueError, expected_msg, validate_rbac_is_correctly_configured + ) def test_validate_rbac_is_correctly_configured_non_default_backend_set(self): - cfg.CONF.set_override(group='rbac', name='enable', override=True) - cfg.CONF.set_override(group='rbac', name='backend', override='invalid') - cfg.CONF.set_override(group='auth', name='enable', override=True) - - expected_msg = ('You have enabled RBAC, but RBAC backend is not set to "default".') - self.assertRaisesRegexp(ValueError, expected_msg, - validate_rbac_is_correctly_configured) - - def test_validate_rbac_is_correctly_configured_default_backend_available_success(self): - cfg.CONF.set_override(group='rbac', name='enable', override=True) - cfg.CONF.set_override(group='rbac', name='backend', override='default') - cfg.CONF.set_override(group='auth', name='enable', override=True) + cfg.CONF.set_override(group="rbac", name="enable", override=True) + cfg.CONF.set_override(group="rbac", name="backend", override="invalid") + cfg.CONF.set_override(group="auth", name="enable", override=True) + + expected_msg = ( + 'You have enabled RBAC, but RBAC backend is not set to "default".' + ) + self.assertRaisesRegexp( + ValueError, expected_msg, validate_rbac_is_correctly_configured + ) + + def test_validate_rbac_is_correctly_configured_default_backend_available_success( + self, + ): + cfg.CONF.set_override(group="rbac", name="enable", override=True) + cfg.CONF.set_override(group="rbac", name="backend", override="default") + cfg.CONF.set_override(group="auth", name="enable", override=True) result = validate_rbac_is_correctly_configured() self.assertTrue(result) diff --git a/st2auth/dist_utils.py b/st2auth/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/st2auth/dist_utils.py +++ b/st2auth/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2auth/setup.py b/st2auth/setup.py index f77ee72f03a..c6e266472b1 100644 --- a/st2auth/setup.py +++ b/st2auth/setup.py @@ -22,9 +22,9 @@ from dist_utils import apply_vagrant_workaround from st2auth import __version__ -ST2_COMPONENT = 'st2auth' +ST2_COMPONENT = "st2auth" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -33,23 +33,21 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - scripts=[ - 'bin/st2auth' - ], + packages=find_packages(exclude=["setuptools", "tests"]), + scripts=["bin/st2auth"], entry_points={ - 'st2auth.sso.backends': [ - 'noop = st2auth.sso.noop:NoOpSingleSignOnBackend' - ] - } + "st2auth.sso.backends": ["noop = st2auth.sso.noop:NoOpSingleSignOnBackend"] + }, ) diff --git a/st2auth/st2auth/__init__.py b/st2auth/st2auth/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/st2auth/st2auth/__init__.py +++ b/st2auth/st2auth/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2auth/st2auth/app.py b/st2auth/st2auth/app.py index 3398104c6c0..b9b8f7d595c 100644 --- a/st2auth/st2auth/app.py +++ b/st2auth/st2auth/app.py @@ -36,34 +36,38 @@ def setup_app(config=None): config = config or {} - LOG.info('Creating st2auth: %s as OpenAPI app.', VERSION_STRING) + LOG.info("Creating st2auth: %s as OpenAPI app.", VERSION_STRING) - is_gunicorn = config.get('is_gunicorn', False) + is_gunicorn = config.get("is_gunicorn", False) if is_gunicorn: # NOTE: We only want to perform this logic in the WSGI worker st2auth_config.register_opts() capabilities = { - 'name': 'auth', - 'listen_host': cfg.CONF.auth.host, - 'listen_port': cfg.CONF.auth.port, - 'listen_ssl': cfg.CONF.auth.use_ssl, - 'type': 'active' + "name": "auth", + "listen_host": cfg.CONF.auth.host, + "listen_port": cfg.CONF.auth.port, + "listen_ssl": cfg.CONF.auth.use_ssl, + "type": "active", } # This should be called in gunicorn case because we only want # workers to connect to db, rabbbitmq etc. In standalone HTTP # server case, this setup would have already occurred. - common_setup(service='auth', config=st2auth_config, setup_db=True, - register_mq_exchanges=False, - register_signal_handlers=True, - register_internal_trigger_types=False, - run_migrations=False, - service_registry=True, - capabilities=capabilities, - config_args=config.get('config_args', None)) + common_setup( + service="auth", + config=st2auth_config, + setup_db=True, + register_mq_exchanges=False, + register_signal_handlers=True, + register_internal_trigger_types=False, + run_migrations=False, + service_registry=True, + capabilities=capabilities, + config_args=config.get("config_args", None), + ) # pysaml2 uses subprocess communicate which calls communicate_with_poll - if cfg.CONF.auth.sso and cfg.CONF.auth.sso_backend == 'saml2': + if cfg.CONF.auth.sso and cfg.CONF.auth.sso_backend == "saml2": use_select_poll_workaround(nose_only=False) # Additional pre-run time checks @@ -71,10 +75,8 @@ def setup_app(config=None): router = Router(debug=cfg.CONF.auth.debug, is_gunicorn=is_gunicorn) - spec = spec_loader.load_spec('st2common', 'openapi.yaml.j2') - transforms = { - '^/auth/v1/': ['/', '/v1/'] - } + spec = spec_loader.load_spec("st2common", "openapi.yaml.j2") + transforms = {"^/auth/v1/": ["/", "/v1/"]} router.add_spec(spec, transforms=transforms) app = router.as_wsgi @@ -83,8 +85,8 @@ def setup_app(config=None): app = ErrorHandlingMiddleware(app) app = CorsMiddleware(app) app = LoggingMiddleware(app, router) - app = ResponseInstrumentationMiddleware(app, router, service_name='auth') + app = ResponseInstrumentationMiddleware(app, router, service_name="auth") app = RequestIDMiddleware(app) - app = RequestInstrumentationMiddleware(app, router, service_name='auth') + app = RequestInstrumentationMiddleware(app, router, service_name="auth") return app diff --git a/st2auth/st2auth/backends/__init__.py b/st2auth/st2auth/backends/__init__.py index 64d3275af53..a626f0d0822 100644 --- a/st2auth/st2auth/backends/__init__.py +++ b/st2auth/st2auth/backends/__init__.py @@ -22,14 +22,11 @@ from st2common import log as logging from st2common.util import driver_loader -__all__ = [ - 'get_available_backends', - 'get_backend_instance' -] +__all__ = ["get_available_backends", "get_backend_instance"] LOG = logging.getLogger(__name__) -BACKENDS_NAMESPACE = 'st2auth.backends.backend' +BACKENDS_NAMESPACE = "st2auth.backends.backend" def get_available_backends(): @@ -43,8 +40,10 @@ def get_backend_instance(name): try: kwargs = json.loads(backend_kwargs) except ValueError as e: - raise ValueError('Failed to JSON parse backend settings for backend "%s": %s' % - (name, six.text_type(e))) + raise ValueError( + 'Failed to JSON parse backend settings for backend "%s": %s' + % (name, six.text_type(e)) + ) else: kwargs = {} @@ -55,9 +54,11 @@ def get_backend_instance(name): except Exception as e: tb_msg = traceback.format_exc() class_name = cls.__name__ - msg = ('Failed to instantiate auth backend "%s" (class %s) with backend settings ' - '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e))) - msg += '\n\n' + tb_msg + msg = ( + 'Failed to instantiate auth backend "%s" (class %s) with backend settings ' + '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e)) + ) + msg += "\n\n" + tb_msg exc_cls = type(e) raise exc_cls(msg) diff --git a/st2auth/st2auth/backends/base.py b/st2auth/st2auth/backends/base.py index 0246729c1a2..4d32e51860a 100644 --- a/st2auth/st2auth/backends/base.py +++ b/st2auth/st2auth/backends/base.py @@ -19,9 +19,7 @@ from st2auth.backends.constants import AuthBackendCapability -__all__ = [ - 'BaseAuthenticationBackend' -] +__all__ = ["BaseAuthenticationBackend"] @six.add_metaclass(abc.ABCMeta) @@ -31,9 +29,7 @@ class BaseAuthenticationBackend(object): """ # Capabilities offered by the auth backend - CAPABILITIES = ( - AuthBackendCapability.CAN_AUTHENTICATE_USER - ) + CAPABILITIES = AuthBackendCapability.CAN_AUTHENTICATE_USER @abc.abstractmethod def authenticate(self, username, password): @@ -47,7 +43,7 @@ def get_user(self, username): :rtype: ``dict`` """ - raise NotImplementedError('get_user() not implemented for this backend') + raise NotImplementedError("get_user() not implemented for this backend") def get_user_groups(self, username): """ @@ -57,4 +53,4 @@ def get_user_groups(self, username): :rtype: ``list`` of ``str`` """ - raise NotImplementedError('get_groups() not implemented for this backend') + raise NotImplementedError("get_groups() not implemented for this backend") diff --git a/st2auth/st2auth/backends/constants.py b/st2auth/st2auth/backends/constants.py index 6cb990c64db..b50625e7459 100644 --- a/st2auth/st2auth/backends/constants.py +++ b/st2auth/st2auth/backends/constants.py @@ -19,17 +19,15 @@ from st2common.util.enum import Enum -__all__ = [ - 'AuthBackendCapability' -] +__all__ = ["AuthBackendCapability"] class AuthBackendCapability(Enum): # This auth backend can authenticate a user. - CAN_AUTHENTICATE_USER = 'can_authenticate_user' + CAN_AUTHENTICATE_USER = "can_authenticate_user" # Auth backend can provide additional information about a particular user. - HAS_USER_INFORMATION = 'has_user_info' + HAS_USER_INFORMATION = "has_user_info" # Auth backend can provide a group membership information for a particular user. - HAS_GROUP_INFORMATION = 'has_groups_info' + HAS_GROUP_INFORMATION = "has_groups_info" diff --git a/st2auth/st2auth/cmd/api.py b/st2auth/st2auth/cmd/api.py index d1fd7605bd0..4c52f2649c2 100644 --- a/st2auth/st2auth/cmd/api.py +++ b/st2auth/st2auth/cmd/api.py @@ -14,6 +14,7 @@ # limitations under the License. from st2common.util.monkey_patch import monkey_patch + monkey_patch() import eventlet @@ -27,15 +28,14 @@ from st2common.service_setup import setup as common_setup from st2common.service_setup import teardown as common_teardown from st2auth import config + config.register_opts() from st2auth import app from st2auth.validation import validate_auth_backend_is_correctly_configured -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) @@ -43,15 +43,23 @@ def _setup(): capabilities = { - 'name': 'auth', - 'listen_host': cfg.CONF.auth.host, - 'listen_port': cfg.CONF.auth.port, - 'listen_ssl': cfg.CONF.auth.use_ssl, - 'type': 'active' + "name": "auth", + "listen_host": cfg.CONF.auth.host, + "listen_port": cfg.CONF.auth.port, + "listen_ssl": cfg.CONF.auth.use_ssl, + "type": "active", } - common_setup(service='auth', config=config, setup_db=True, register_mq_exchanges=False, - register_signal_handlers=True, register_internal_trigger_types=False, - run_migrations=False, service_registry=True, capabilities=capabilities) + common_setup( + service="auth", + config=config, + setup_db=True, + register_mq_exchanges=False, + register_signal_handlers=True, + register_internal_trigger_types=False, + run_migrations=False, + service_registry=True, + capabilities=capabilities, + ) # Additional pre-run time checks validate_auth_backend_is_correctly_configured() @@ -74,14 +82,18 @@ def _run_server(): socket = eventlet.listen((host, port)) if use_ssl: - socket = eventlet.wrap_ssl(socket, - certfile=cert_file_path, - keyfile=key_file_path, - server_side=True) + socket = eventlet.wrap_ssl( + socket, certfile=cert_file_path, keyfile=key_file_path, server_side=True + ) LOG.info('ST2 Auth API running in "%s" auth mode', cfg.CONF.auth.mode) - LOG.info('(PID=%s) ST2 Auth API is serving on %s://%s:%s.', os.getpid(), - 'https' if use_ssl else 'http', host, port) + LOG.info( + "(PID=%s) ST2 Auth API is serving on %s://%s:%s.", + os.getpid(), + "https" if use_ssl else "http", + host, + port, + ) wsgi.server(socket, app.setup_app(), log=LOG, log_output=False) return 0 @@ -98,7 +110,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except Exception: - LOG.exception('(PID=%s) ST2 Auth API quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) ST2 Auth API quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2auth/st2auth/config.py b/st2auth/st2auth/config.py index 00cfa2aca70..dee0d2d064c 100644 --- a/st2auth/st2auth/config.py +++ b/st2auth/st2auth/config.py @@ -28,8 +28,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -50,47 +53,61 @@ def _register_app_opts(): auth_opts = [ cfg.StrOpt( - 'host', default='127.0.0.1', - help='Host on which the service should listen on.'), + "host", + default="127.0.0.1", + help="Host on which the service should listen on.", + ), cfg.IntOpt( - 'port', default=9100, - help='Port on which the service should listen on.'), - cfg.BoolOpt( - 'use_ssl', default=False, - help='Specify to enable SSL / TLS mode'), + "port", default=9100, help="Port on which the service should listen on." + ), + cfg.BoolOpt("use_ssl", default=False, help="Specify to enable SSL / TLS mode"), cfg.StrOpt( - 'cert', default='/etc/apache2/ssl/mycert.crt', - help='Path to the SSL certificate file. Only used when "use_ssl" is specified.'), + "cert", + default="/etc/apache2/ssl/mycert.crt", + help='Path to the SSL certificate file. Only used when "use_ssl" is specified.', + ), cfg.StrOpt( - 'key', default='/etc/apache2/ssl/mycert.key', - help='Path to the SSL private key file. Only used when "use_ssl" is specified.'), + "key", + default="/etc/apache2/ssl/mycert.key", + help='Path to the SSL private key file. Only used when "use_ssl" is specified.', + ), cfg.StrOpt( - 'logging', default='/etc/st2/logging.auth.conf', - help='Path to the logging config.'), - cfg.BoolOpt( - 'debug', default=False, - help='Specify to enable debug mode.'), + "logging", + default="/etc/st2/logging.auth.conf", + help="Path to the logging config.", + ), + cfg.BoolOpt("debug", default=False, help="Specify to enable debug mode."), cfg.StrOpt( - 'mode', default=DEFAULT_MODE, - help='Authentication mode (%s)' % (','.join(VALID_MODES))), + "mode", + default=DEFAULT_MODE, + help="Authentication mode (%s)" % (",".join(VALID_MODES)), + ), cfg.StrOpt( - 'backend', default=DEFAULT_BACKEND, - help='Authentication backend to use in a standalone mode. Available ' - 'backends: %s.' % (', '.join(available_backends))), + "backend", + default=DEFAULT_BACKEND, + help="Authentication backend to use in a standalone mode. Available " + "backends: %s." % (", ".join(available_backends)), + ), cfg.StrOpt( - 'backend_kwargs', default=None, - help='JSON serialized arguments which are passed to the authentication ' - 'backend in a standalone mode.'), + "backend_kwargs", + default=None, + help="JSON serialized arguments which are passed to the authentication " + "backend in a standalone mode.", + ), cfg.BoolOpt( - 'sso', default=False, - help='Enable Single Sign On for GUI if true.'), + "sso", default=False, help="Enable Single Sign On for GUI if true." + ), cfg.StrOpt( - 'sso_backend', default=DEFAULT_SSO_BACKEND, - help='Single Sign On backend to use when SSO is enabled. Available ' - 'backends: noop, saml2.'), + "sso_backend", + default=DEFAULT_SSO_BACKEND, + help="Single Sign On backend to use when SSO is enabled. Available " + "backends: noop, saml2.", + ), cfg.StrOpt( - 'sso_backend_kwargs', default=None, - help='JSON serialized arguments which are passed to the SSO backend.') + "sso_backend_kwargs", + default=None, + help="JSON serialized arguments which are passed to the SSO backend.", + ), ] - cfg.CONF.register_cli_opts(auth_opts, group='auth') + cfg.CONF.register_cli_opts(auth_opts, group="auth") diff --git a/st2auth/st2auth/controllers/v1/auth.py b/st2auth/st2auth/controllers/v1/auth.py index f0042632e91..c77546141f4 100644 --- a/st2auth/st2auth/controllers/v1/auth.py +++ b/st2auth/st2auth/controllers/v1/auth.py @@ -29,8 +29,8 @@ HANDLER_MAPPINGS = { - 'proxy': handlers.ProxyAuthHandler, - 'standalone': handlers.StandaloneAuthHandler + "proxy": handlers.ProxyAuthHandler, + "standalone": handlers.StandaloneAuthHandler, } LOG = logging.getLogger(__name__) @@ -38,17 +38,17 @@ class TokenValidationController(object): def post(self, request): - token = getattr(request, 'token', None) + token = getattr(request, "token", None) if not token: - raise exc.HTTPBadRequest('Token is not provided.') + raise exc.HTTPBadRequest("Token is not provided.") try: - return {'valid': auth_utils.validate_token(token) is not None} + return {"valid": auth_utils.validate_token(token) is not None} except (TokenNotFoundError, TokenExpiredError): - return {'valid': False} + return {"valid": False} except Exception: - msg = 'Unexpected error occurred while verifying token.' + msg = "Unexpected error occurred while verifying token." LOG.exception(msg) raise exc.HTTPInternalServerError(msg) @@ -60,30 +60,32 @@ def __init__(self): try: self.handler = HANDLER_MAPPINGS[cfg.CONF.auth.mode]() except KeyError: - raise ParamException("%s is not a valid auth mode" % - cfg.CONF.auth.mode) + raise ParamException("%s is not a valid auth mode" % cfg.CONF.auth.mode) def post(self, request, **kwargs): headers = {} - if 'x-forwarded-for' in kwargs: - headers['x-forwarded-for'] = kwargs.pop('x-forwarded-for') + if "x-forwarded-for" in kwargs: + headers["x-forwarded-for"] = kwargs.pop("x-forwarded-for") - authorization = kwargs.pop('authorization', None) + authorization = kwargs.pop("authorization", None) if authorization: - authorization = tuple(authorization.split(' ')) - - token = self.handler.handle_auth(request=request, headers=headers, - remote_addr=kwargs.pop('remote_addr', None), - remote_user=kwargs.pop('remote_user', None), - authorization=authorization, - **kwargs) + authorization = tuple(authorization.split(" ")) + + token = self.handler.handle_auth( + request=request, + headers=headers, + remote_addr=kwargs.pop("remote_addr", None), + remote_user=kwargs.pop("remote_user", None), + authorization=authorization, + **kwargs, + ) return process_successful_response(token=token) def process_successful_response(token): resp = Response(json=token, status=http_client.CREATED) # NOTE: gunicon fails and throws an error if header value is not a string (e.g. if it's None) - resp.headers['X-API-URL'] = api_utils.get_base_public_api_url() + resp.headers["X-API-URL"] = api_utils.get_base_public_api_url() return resp diff --git a/st2auth/st2auth/controllers/v1/sso.py b/st2auth/st2auth/controllers/v1/sso.py index f25effe6812..ef1096462cf 100644 --- a/st2auth/st2auth/controllers/v1/sso.py +++ b/st2auth/st2auth/controllers/v1/sso.py @@ -32,7 +32,6 @@ class IdentityProviderCallbackController(object): - def __init__(self): self.st2_auth_handler = handlers.ProxyAuthHandler() @@ -40,16 +39,21 @@ def post(self, response, **kwargs): try: verified_user = SSO_BACKEND.verify_response(response) - st2_auth_token_create_request = {'user': verified_user['username'], 'ttl': None} + st2_auth_token_create_request = { + "user": verified_user["username"], + "ttl": None, + } st2_auth_token = self.st2_auth_handler.handle_auth( request=st2_auth_token_create_request, - remote_addr=verified_user['referer'], - remote_user=verified_user['username'], - headers={} + remote_addr=verified_user["referer"], + remote_user=verified_user["username"], + headers={}, ) - return process_successful_authn_response(verified_user['referer'], st2_auth_token) + return process_successful_authn_response( + verified_user["referer"], st2_auth_token + ) except NotImplementedError as e: return process_failure_response(http_client.INTERNAL_SERVER_ERROR, e) except auth_exc.SSOVerificationError as e: @@ -59,7 +63,6 @@ def post(self, response, **kwargs): class SingleSignOnRequestController(object): - def get(self, referer): try: response = router.Response(status=http_client.TEMPORARY_REDIRECT) @@ -76,15 +79,15 @@ class SingleSignOnController(object): callback = IdentityProviderCallbackController() def _get_sso_enabled_config(self): - return {'enabled': cfg.CONF.auth.sso} + return {"enabled": cfg.CONF.auth.sso} def get(self): try: result = self._get_sso_enabled_config() return process_successful_response(http_client.OK, result) except Exception: - LOG.exception('Error encountered while getting SSO configuration.') - result = {'enabled': False} + LOG.exception("Error encountered while getting SSO configuration.") + result = {"enabled": False} return process_successful_response(http_client.OK, result) @@ -107,23 +110,23 @@ def get(self): def process_successful_authn_response(referer, token): token_json = { - 'id': str(token.id), - 'user': token.user, - 'token': token.token, - 'expiry': str(token.expiry), - 'service': False, - 'metadata': {} + "id": str(token.id), + "user": token.user, + "token": token.token, + "expiry": str(token.expiry), + "service": False, + "metadata": {}, } body = CALLBACK_SUCCESS_RESPONSE_BODY % referer resp = router.Response(body=body) - resp.headers['Content-Type'] = 'text/html' + resp.headers["Content-Type"] = "text/html" resp.set_cookie( - 'st2-auth-token', + "st2-auth-token", value=urllib.parse.quote(json.dumps(token_json)), expires=datetime.timedelta(seconds=60), - overwrite=True + overwrite=True, ) return resp @@ -135,7 +138,7 @@ def process_successful_response(status_code, json_body): def process_failure_response(status_code, exception): LOG.error(str(exception)) - json_body = {'faultstring': str(exception)} + json_body = {"faultstring": str(exception)} return router.Response(status_code=status_code, json_body=json_body) diff --git a/st2auth/st2auth/handlers.py b/st2auth/st2auth/handlers.py index 59d74085cf5..f6540bcda73 100644 --- a/st2auth/st2auth/handlers.py +++ b/st2auth/st2auth/handlers.py @@ -35,13 +35,22 @@ LOG = logging.getLogger(__name__) -def abort_request(status_code=http_client.UNAUTHORIZED, message='Invalid or missing credentials'): +def abort_request( + status_code=http_client.UNAUTHORIZED, message="Invalid or missing credentials" +): return abort(status_code, message) class AuthHandlerBase(object): - def handle_auth(self, request, headers=None, remote_addr=None, - remote_user=None, authorization=None, **kwargs): + def handle_auth( + self, + request, + headers=None, + remote_addr=None, + remote_user=None, + authorization=None, + **kwargs, + ): raise NotImplementedError() def _create_token_for_user(self, username, ttl=None): @@ -49,80 +58,90 @@ def _create_token_for_user(self, username, ttl=None): return TokenAPI.from_model(tokendb) def _get_username_for_request(self, username, request): - impersonate_user = getattr(request, 'user', None) + impersonate_user = getattr(request, "user", None) if impersonate_user is not None: # check this is a service account try: if not User.get_by_name(username).is_service: - message = "Current user is not a service and cannot " \ - "request impersonated tokens" - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = ( + "Current user is not a service and cannot " + "request impersonated tokens" + ) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return username = impersonate_user except (UserNotFoundError, StackStormDBObjectNotFoundError): - message = "Could not locate user %s" % \ - (impersonate_user) - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = "Could not locate user %s" % (impersonate_user) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return else: - impersonate_user = getattr(request, 'impersonate_user', None) - nickname_origin = getattr(request, 'nickname_origin', None) + impersonate_user = getattr(request, "impersonate_user", None) + nickname_origin = getattr(request, "nickname_origin", None) if impersonate_user is not None: try: # check this is a service account if not User.get_by_name(username).is_service: raise NotServiceUserError() - username = User.get_by_nickname(impersonate_user, - nickname_origin).name + username = User.get_by_nickname( + impersonate_user, nickname_origin + ).name except NotServiceUserError: - message = "Current user is not a service and cannot " \ - "request impersonated tokens" - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = ( + "Current user is not a service and cannot " + "request impersonated tokens" + ) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return except (UserNotFoundError, StackStormDBObjectNotFoundError): - message = "Could not locate user %s@%s" % \ - (impersonate_user, nickname_origin) - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = "Could not locate user %s@%s" % ( + impersonate_user, + nickname_origin, + ) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return except NoNicknameOriginProvidedError: - message = "Nickname origin is not provided for nickname '%s'" % \ - impersonate_user - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = ( + "Nickname origin is not provided for nickname '%s'" + % impersonate_user + ) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return except AmbiguousUserError: - message = "%s@%s matched more than one username" % \ - (impersonate_user, nickname_origin) - abort_request(status_code=http_client.BAD_REQUEST, - message=message) + message = "%s@%s matched more than one username" % ( + impersonate_user, + nickname_origin, + ) + abort_request(status_code=http_client.BAD_REQUEST, message=message) return return username class ProxyAuthHandler(AuthHandlerBase): - def handle_auth(self, request, headers=None, remote_addr=None, - remote_user=None, authorization=None, **kwargs): - remote_addr = headers.get('x-forwarded-for', - remote_addr) - extra = {'remote_addr': remote_addr} + def handle_auth( + self, + request, + headers=None, + remote_addr=None, + remote_user=None, + authorization=None, + **kwargs, + ): + remote_addr = headers.get("x-forwarded-for", remote_addr) + extra = {"remote_addr": remote_addr} if remote_user: - ttl = getattr(request, 'ttl', None) + ttl = getattr(request, "ttl", None) username = self._get_username_for_request(remote_user, request) try: - token = self._create_token_for_user(username=username, - ttl=ttl) + token = self._create_token_for_user(username=username, ttl=ttl) except TTLTooLargeException as e: - abort_request(status_code=http_client.BAD_REQUEST, - message=six.text_type(e)) + abort_request( + status_code=http_client.BAD_REQUEST, message=six.text_type(e) + ) return token - LOG.audit('Access denied to anonymous user.', extra=extra) + LOG.audit("Access denied to anonymous user.", extra=extra) abort_request() @@ -131,77 +150,91 @@ def __init__(self, *args, **kwargs): self._auth_backend = get_auth_backend_instance(name=cfg.CONF.auth.backend) super(StandaloneAuthHandler, self).__init__(*args, **kwargs) - def handle_auth(self, request, headers=None, remote_addr=None, remote_user=None, - authorization=None, **kwargs): + def handle_auth( + self, + request, + headers=None, + remote_addr=None, + remote_user=None, + authorization=None, + **kwargs, + ): auth_backend = self._auth_backend.__class__.__name__ - extra = {'auth_backend': auth_backend, 'remote_addr': remote_addr} + extra = {"auth_backend": auth_backend, "remote_addr": remote_addr} if not authorization: - LOG.audit('Authorization header not provided', extra=extra) + LOG.audit("Authorization header not provided", extra=extra) abort_request() return auth_type, auth_value = authorization - if auth_type.lower() not in ['basic']: - extra['auth_type'] = auth_type - LOG.audit('Unsupported authorization type: %s' % (auth_type), extra=extra) + if auth_type.lower() not in ["basic"]: + extra["auth_type"] = auth_type + LOG.audit("Unsupported authorization type: %s" % (auth_type), extra=extra) abort_request() return try: auth_value = base64.b64decode(auth_value) except Exception: - LOG.audit('Invalid authorization header', extra=extra) + LOG.audit("Invalid authorization header", extra=extra) abort_request() return - split = auth_value.split(b':', 1) + split = auth_value.split(b":", 1) if len(split) != 2: - LOG.audit('Invalid authorization header', extra=extra) + LOG.audit("Invalid authorization header", extra=extra) abort_request() return username, password = split if six.PY3 and isinstance(username, six.binary_type): - username = username.decode('utf-8') + username = username.decode("utf-8") if six.PY3 and isinstance(password, six.binary_type): - password = password.decode('utf-8') + password = password.decode("utf-8") result = self._auth_backend.authenticate(username=username, password=password) if result is True: - ttl = getattr(request, 'ttl', None) + ttl = getattr(request, "ttl", None) username = self._get_username_for_request(username, request) try: token = self._create_token_for_user(username=username, ttl=ttl) except TTLTooLargeException as e: - abort_request(status_code=http_client.BAD_REQUEST, - message=six.text_type(e)) + abort_request( + status_code=http_client.BAD_REQUEST, message=six.text_type(e) + ) return # If remote group sync is enabled, sync the remote groups with local StackStorm roles - if cfg.CONF.rbac.sync_remote_groups and cfg.CONF.rbac.backend != 'noop': - LOG.debug('Retrieving auth backend groups for user "%s"' % (username), - extra=extra) + if cfg.CONF.rbac.sync_remote_groups and cfg.CONF.rbac.backend != "noop": + LOG.debug( + 'Retrieving auth backend groups for user "%s"' % (username), + extra=extra, + ) try: user_groups = self._auth_backend.get_user_groups(username=username) except (NotImplementedError, AttributeError): - LOG.debug('Configured auth backend doesn\'t expose user group membership ' - 'information, skipping sync...') + LOG.debug( + "Configured auth backend doesn't expose user group membership " + "information, skipping sync..." + ) return token if not user_groups: # No groups, return early return token - extra['username'] = username - extra['user_groups'] = user_groups + extra["username"] = username + extra["user_groups"] = user_groups - LOG.debug('Found "%s" groups for user "%s"' % (len(user_groups), username), - extra=extra) + LOG.debug( + 'Found "%s" groups for user "%s"' % (len(user_groups), username), + extra=extra, + ) user_db = UserDB(name=username) @@ -212,14 +245,19 @@ def handle_auth(self, request, headers=None, remote_addr=None, remote_user=None, syncer.sync(user_db=user_db, groups=user_groups) except Exception: # Note: Failed sync is not fatal - LOG.exception('Failed to synchronize remote groups for user "%s"' % (username), - extra=extra) + LOG.exception( + 'Failed to synchronize remote groups for user "%s"' + % (username), + extra=extra, + ) else: - LOG.debug('Successfully synchronized groups for user "%s"' % (username), - extra=extra) + LOG.debug( + 'Successfully synchronized groups for user "%s"' % (username), + extra=extra, + ) return token return token - LOG.audit('Invalid credentials provided', extra=extra) + LOG.audit("Invalid credentials provided", extra=extra) abort_request() diff --git a/st2auth/st2auth/sso/__init__.py b/st2auth/st2auth/sso/__init__.py index 5839059ed9d..b6d0df930a0 100644 --- a/st2auth/st2auth/sso/__init__.py +++ b/st2auth/st2auth/sso/__init__.py @@ -25,15 +25,11 @@ from st2common.util import driver_loader -__all__ = [ - 'get_available_backends', - 'get_backend_instance', - 'get_sso_backend' -] +__all__ = ["get_available_backends", "get_backend_instance", "get_sso_backend"] LOG = logging.getLogger(__name__) -BACKENDS_NAMESPACE = 'st2auth.sso.backends' +BACKENDS_NAMESPACE = "st2auth.sso.backends" def get_available_backends(): @@ -41,7 +37,9 @@ def get_available_backends(): def get_backend_instance(name): - sso_backend_cls = driver_loader.get_backend_driver(namespace=BACKENDS_NAMESPACE, name=name) + sso_backend_cls = driver_loader.get_backend_driver( + namespace=BACKENDS_NAMESPACE, name=name + ) kwargs = {} sso_backend_kwargs = cfg.CONF.auth.sso_backend_kwargs @@ -51,8 +49,8 @@ def get_backend_instance(name): kwargs = json.loads(sso_backend_kwargs) except ValueError as e: raise ValueError( - 'Failed to JSON parse backend settings for backend "%s": %s' % - (name, six.text_type(e)) + 'Failed to JSON parse backend settings for backend "%s": %s' + % (name, six.text_type(e)) ) try: @@ -60,9 +58,11 @@ def get_backend_instance(name): except Exception as e: tb_msg = traceback.format_exc() class_name = sso_backend_cls.__name__ - msg = ('Failed to instantiate SSO backend "%s" (class %s) with backend settings ' - '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e))) - msg += '\n\n' + tb_msg + msg = ( + 'Failed to instantiate SSO backend "%s" (class %s) with backend settings ' + '"%s": %s' % (name, class_name, str(kwargs), six.text_type(e)) + ) + msg += "\n\n" + tb_msg exc_cls = type(e) raise exc_cls(msg) diff --git a/st2auth/st2auth/sso/base.py b/st2auth/st2auth/sso/base.py index c96782aba2d..5e111998181 100644 --- a/st2auth/st2auth/sso/base.py +++ b/st2auth/st2auth/sso/base.py @@ -16,9 +16,7 @@ import six -__all__ = [ - 'BaseSingleSignOnBackend' -] +__all__ = ["BaseSingleSignOnBackend"] @six.add_metaclass(abc.ABCMeta) @@ -32,5 +30,7 @@ def get_request_redirect_url(self, referer): raise NotImplementedError(msg) def verify_response(self, response): - msg = 'The function "verify_response" is not implemented in the base SSO backend.' + msg = ( + 'The function "verify_response" is not implemented in the base SSO backend.' + ) raise NotImplementedError(msg) diff --git a/st2auth/st2auth/sso/noop.py b/st2auth/st2auth/sso/noop.py index 6cacb5e7e9f..6699e084f3a 100644 --- a/st2auth/st2auth/sso/noop.py +++ b/st2auth/st2auth/sso/noop.py @@ -17,13 +17,11 @@ from st2auth.sso.base import BaseSingleSignOnBackend -__all__ = [ - 'NoOpSingleSignOnBackend' -] +__all__ = ["NoOpSingleSignOnBackend"] NOT_IMPLEMENTED_MESSAGE = ( 'The default "noop" SSO backend is not a proper implementation. ' - 'Please refer to the enterprise version for configuring SSO.' + "Please refer to the enterprise version for configuring SSO." ) diff --git a/st2auth/st2auth/validation.py b/st2auth/st2auth/validation.py index 924ad390f2d..ccea9060624 100644 --- a/st2auth/st2auth/validation.py +++ b/st2auth/st2auth/validation.py @@ -19,26 +19,28 @@ from st2auth.backends import get_backend_instance as get_auth_backend_instance from st2auth.backends.constants import AuthBackendCapability -__all__ = [ - 'validate_auth_backend_is_correctly_configured' -] +__all__ = ["validate_auth_backend_is_correctly_configured"] def validate_auth_backend_is_correctly_configured(): # 1. Verify correct mode is specified if cfg.CONF.auth.mode not in VALID_MODES: - msg = ('Invalid auth mode "%s" specified in the config. Valid modes are: %s' % - (cfg.CONF.auth.mode, ', '.join(VALID_MODES))) + msg = 'Invalid auth mode "%s" specified in the config. Valid modes are: %s' % ( + cfg.CONF.auth.mode, + ", ".join(VALID_MODES), + ) raise ValueError(msg) # 2. Verify that auth backend used by the user exposes group information if cfg.CONF.rbac.enable and cfg.CONF.rbac.sync_remote_groups: auth_backend = get_auth_backend_instance(name=cfg.CONF.auth.backend) - capabilies = getattr(auth_backend, 'CAPABILITIES', ()) + capabilies = getattr(auth_backend, "CAPABILITIES", ()) if AuthBackendCapability.HAS_GROUP_INFORMATION not in capabilies: - msg = ('Configured auth backend doesn\'t expose user group information. Disable ' - 'remote group synchronization or use a different backend which exposes ' - 'user group membership information.') + msg = ( + "Configured auth backend doesn't expose user group information. Disable " + "remote group synchronization or use a different backend which exposes " + "user group membership information." + ) raise ValueError(msg) return True diff --git a/st2auth/st2auth/wsgi.py b/st2auth/st2auth/wsgi.py index 2fb9bee07a4..16a44e64f35 100644 --- a/st2auth/st2auth/wsgi.py +++ b/st2auth/st2auth/wsgi.py @@ -16,6 +16,7 @@ import os from st2common.util.monkey_patch import monkey_patch + # Note: We need to perform monkey patching in the worker. If we do it in # the master process (gunicorn_config.py), it breaks tons of things # including shutdown @@ -28,8 +29,11 @@ from st2auth import app config = { - 'is_gunicorn': True, - 'config_args': ['--config-file', os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf')] + "is_gunicorn": True, + "config_args": [ + "--config-file", + os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf"), + ], } application = app.setup_app(config) diff --git a/st2auth/tests/base.py b/st2auth/tests/base.py index e3bc2e1a052..dc63c1094ec 100644 --- a/st2auth/tests/base.py +++ b/st2auth/tests/base.py @@ -20,7 +20,6 @@ class FunctionalTest(DbTestCase): - @classmethod def setUpClass(cls, **kwargs): super(FunctionalTest, cls).setUpClass() diff --git a/st2auth/tests/unit/controllers/v1/test_sso.py b/st2auth/tests/unit/controllers/v1/test_sso.py index 81d9dcea1d2..2b6edb1f839 100644 --- a/st2auth/tests/unit/controllers/v1/test_sso.py +++ b/st2auth/tests/unit/controllers/v1/test_sso.py @@ -13,6 +13,7 @@ # limitations under the License. import st2tests.config as tests_config + tests_config.parse_args() import json @@ -28,110 +29,125 @@ from tests.base import FunctionalTest -SSO_V1_PATH = '/v1/sso' -SSO_REQUEST_V1_PATH = SSO_V1_PATH + '/request' -SSO_CALLBACK_V1_PATH = SSO_V1_PATH + '/callback' -MOCK_REFERER = 'https://127.0.0.1' -MOCK_USER = 'stanley' +SSO_V1_PATH = "/v1/sso" +SSO_REQUEST_V1_PATH = SSO_V1_PATH + "/request" +SSO_CALLBACK_V1_PATH = SSO_V1_PATH + "/callback" +MOCK_REFERER = "https://127.0.0.1" +MOCK_USER = "stanley" class TestSingleSignOnController(FunctionalTest): - def test_sso_enabled(self): - cfg.CONF.set_override(group='auth', name='sso', override=True) + cfg.CONF.set_override(group="auth", name="sso", override=True) response = self.app.get(SSO_V1_PATH, expect_errors=False) self.assertTrue(response.status_code, http_client.OK) - self.assertDictEqual(response.json, {'enabled': True}) + self.assertDictEqual(response.json, {"enabled": True}) def test_sso_disabled(self): - cfg.CONF.set_override(group='auth', name='sso', override=False) + cfg.CONF.set_override(group="auth", name="sso", override=False) response = self.app.get(SSO_V1_PATH, expect_errors=False) self.assertTrue(response.status_code, http_client.OK) - self.assertDictEqual(response.json, {'enabled': False}) + self.assertDictEqual(response.json, {"enabled": False}) @mock.patch.object( sso_api_controller.SingleSignOnController, - '_get_sso_enabled_config', - mock.MagicMock(side_effect=KeyError('foobar'))) + "_get_sso_enabled_config", + mock.MagicMock(side_effect=KeyError("foobar")), + ) def test_unknown_exception(self): - cfg.CONF.set_override(group='auth', name='sso', override=True) + cfg.CONF.set_override(group="auth", name="sso", override=True) response = self.app.get(SSO_V1_PATH, expect_errors=False) self.assertTrue(response.status_code, http_client.OK) - self.assertDictEqual(response.json, {'enabled': False}) - self.assertTrue(sso_api_controller.SingleSignOnController._get_sso_enabled_config.called) + self.assertDictEqual(response.json, {"enabled": False}) + self.assertTrue( + sso_api_controller.SingleSignOnController._get_sso_enabled_config.called + ) class TestSingleSignOnRequestController(FunctionalTest): - @mock.patch.object( sso_api_controller.SSO_BACKEND, - 'get_request_redirect_url', - mock.MagicMock(side_effect=Exception('fooobar'))) + "get_request_redirect_url", + mock.MagicMock(side_effect=Exception("fooobar")), + ) def test_default_backend_unknown_exception(self): - expected_error = {'faultstring': 'Internal Server Error'} + expected_error = {"faultstring": "Internal Server Error"} response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=True) self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) self.assertDictEqual(response.json, expected_error) def test_default_backend_not_implemented(self): - expected_error = {'faultstring': noop.NOT_IMPLEMENTED_MESSAGE} + expected_error = {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE} response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=True) self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) self.assertDictEqual(response.json, expected_error) @mock.patch.object( sso_api_controller.SSO_BACKEND, - 'get_request_redirect_url', - mock.MagicMock(return_value='https://127.0.0.1')) + "get_request_redirect_url", + mock.MagicMock(return_value="https://127.0.0.1"), + ) def test_idp_redirect(self): response = self.app.get(SSO_REQUEST_V1_PATH, expect_errors=False) self.assertTrue(response.status_code, http_client.TEMPORARY_REDIRECT) - self.assertEqual(response.location, 'https://127.0.0.1') + self.assertEqual(response.location, "https://127.0.0.1") class TestIdentityProviderCallbackController(FunctionalTest): - @mock.patch.object( sso_api_controller.SSO_BACKEND, - 'verify_response', - mock.MagicMock(side_effect=Exception('fooobar'))) + "verify_response", + mock.MagicMock(side_effect=Exception("fooobar")), + ) def test_default_backend_unknown_exception(self): - expected_error = {'faultstring': 'Internal Server Error'} - response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=True) + expected_error = {"faultstring": "Internal Server Error"} + response = self.app.post_json( + SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True + ) self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) self.assertDictEqual(response.json, expected_error) def test_default_backend_not_implemented(self): - expected_error = {'faultstring': noop.NOT_IMPLEMENTED_MESSAGE} - response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=True) + expected_error = {"faultstring": noop.NOT_IMPLEMENTED_MESSAGE} + response = self.app.post_json( + SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True + ) self.assertTrue(response.status_code, http_client.INTERNAL_SERVER_ERROR) self.assertDictEqual(response.json, expected_error) @mock.patch.object( sso_api_controller.SSO_BACKEND, - 'verify_response', - mock.MagicMock(return_value={'referer': MOCK_REFERER, 'username': MOCK_USER})) + "verify_response", + mock.MagicMock(return_value={"referer": MOCK_REFERER, "username": MOCK_USER}), + ) def test_idp_callback(self): expected_body = sso_api_controller.CALLBACK_SUCCESS_RESPONSE_BODY % MOCK_REFERER - response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=False) + response = self.app.post_json( + SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=False + ) self.assertTrue(response.status_code, http_client.OK) - self.assertEqual(expected_body, response.body.decode('utf-8')) + self.assertEqual(expected_body, response.body.decode("utf-8")) - set_cookies_list = [h for h in response.headerlist if h[0] == 'Set-Cookie'] + set_cookies_list = [h for h in response.headerlist if h[0] == "Set-Cookie"] self.assertEqual(len(set_cookies_list), 1) - self.assertIn('st2-auth-token', set_cookies_list[0][1]) + self.assertIn("st2-auth-token", set_cookies_list[0][1]) - cookie = urllib.parse.unquote(set_cookies_list[0][1]).split('=') - st2_auth_token = json.loads(cookie[1].split(';')[0]) - self.assertIn('token', st2_auth_token) - self.assertEqual(st2_auth_token['user'], MOCK_USER) + cookie = urllib.parse.unquote(set_cookies_list[0][1]).split("=") + st2_auth_token = json.loads(cookie[1].split(";")[0]) + self.assertIn("token", st2_auth_token) + self.assertEqual(st2_auth_token["user"], MOCK_USER) @mock.patch.object( sso_api_controller.SSO_BACKEND, - 'verify_response', - mock.MagicMock(side_effect=auth_exc.SSOVerificationError('Verification Failed'))) + "verify_response", + mock.MagicMock( + side_effect=auth_exc.SSOVerificationError("Verification Failed") + ), + ) def test_idp_callback_verification_failed(self): - expected_error = {'faultstring': 'Verification Failed'} - response = self.app.post_json(SSO_CALLBACK_V1_PATH, {'foo': 'bar'}, expect_errors=True) + expected_error = {"faultstring": "Verification Failed"} + response = self.app.post_json( + SSO_CALLBACK_V1_PATH, {"foo": "bar"}, expect_errors=True + ) self.assertTrue(response.status_code, http_client.UNAUTHORIZED) self.assertDictEqual(response.json, expected_error) diff --git a/st2auth/tests/unit/controllers/v1/test_token.py b/st2auth/tests/unit/controllers/v1/test_token.py index ab5f12342b4..cd90a6cef14 100644 --- a/st2auth/tests/unit/controllers/v1/test_token.py +++ b/st2auth/tests/unit/controllers/v1/test_token.py @@ -29,25 +29,25 @@ from st2common.persistence.auth import User, Token, ApiKey -USERNAME = ''.join(random.choice(string.ascii_lowercase) for i in range(10)) -TOKEN_DEFAULT_PATH = '/tokens' -TOKEN_V1_PATH = '/v1/tokens' -TOKEN_VERIFY_PATH = '/v1/tokens/validate' +USERNAME = "".join(random.choice(string.ascii_lowercase) for i in range(10)) +TOKEN_DEFAULT_PATH = "/tokens" +TOKEN_V1_PATH = "/v1/tokens" +TOKEN_VERIFY_PATH = "/v1/tokens/validate" class TestTokenController(FunctionalTest): - @classmethod def setUpClass(cls, **kwargs): - kwargs['extra_environ'] = { - 'REMOTE_USER': USERNAME - } + kwargs["extra_environ"] = {"REMOTE_USER": USERNAME} super(TestTokenController, cls).setUpClass(**kwargs) def test_token_model(self): dt = date_utils.get_datetime_utc_now() - tk1 = TokenAPI(user='stanley', token=uuid.uuid4().hex, - expiry=isotime.format(dt, offset=False)) + tk1 = TokenAPI( + user="stanley", + token=uuid.uuid4().hex, + expiry=isotime.format(dt, offset=False), + ) tkdb1 = TokenAPI.to_model(tk1) self.assertIsNotNone(tkdb1) self.assertIsInstance(tkdb1, TokenDB) @@ -64,7 +64,7 @@ def test_token_model(self): def test_token_model_null_token(self): dt = date_utils.get_datetime_utc_now() - tk = TokenAPI(user='stanley', token=None, expiry=isotime.format(dt)) + tk = TokenAPI(user="stanley", token=None, expiry=isotime.format(dt)) self.assertRaises(ValueError, Token.add_or_update, TokenAPI.to_model(tk)) def test_token_model_null_user(self): @@ -73,191 +73,215 @@ def test_token_model_null_user(self): self.assertRaises(ValueError, Token.add_or_update, TokenAPI.to_model(tk)) def test_token_model_null_expiry(self): - tk = TokenAPI(user='stanley', token=uuid.uuid4().hex, expiry=None) + tk = TokenAPI(user="stanley", token=uuid.uuid4().hex, expiry=None) self.assertRaises(ValueError, Token.add_or_update, TokenAPI.to_model(tk)) def _test_token_post(self, path=TOKEN_V1_PATH): ttl = cfg.CONF.auth.token_ttl timestamp = date_utils.get_datetime_utc_now() response = self.app.post_json(path, {}, expect_errors=False) - expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl) + expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=ttl + ) expected_expiry = date_utils.add_utc_tz(expected_expiry) self.assertEqual(response.status_int, 201) - self.assertIsNotNone(response.json['token']) - self.assertEqual(response.json['user'], USERNAME) - actual_expiry = isotime.parse(response.json['expiry']) + self.assertIsNotNone(response.json["token"]) + self.assertEqual(response.json["user"], USERNAME) + actual_expiry = isotime.parse(response.json["expiry"]) self.assertLess(timestamp, actual_expiry) self.assertLess(actual_expiry, expected_expiry) return response def test_token_post_unauthorized(self): - response = self.app.post_json(TOKEN_V1_PATH, {}, expect_errors=True, extra_environ={ - 'REMOTE_USER': '' - }) + response = self.app.post_json( + TOKEN_V1_PATH, {}, expect_errors=True, extra_environ={"REMOTE_USER": ""} + ) self.assertEqual(response.status_int, 401) + @mock.patch.object(User, "get_by_name", mock.MagicMock(side_effect=Exception())) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(side_effect=Exception())) - @mock.patch.object( - User, 'add_or_update', - mock.Mock(return_value=UserDB(name=USERNAME))) + User, "add_or_update", mock.Mock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_new_user(self): self._test_token_post() @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_existing_user(self): self._test_token_post() @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_success_x_api_url_header_value(self): # auth.api_url option is explicitly set - cfg.CONF.set_override('api_url', override='https://example.com', group='auth') + cfg.CONF.set_override("api_url", override="https://example.com", group="auth") resp = self._test_token_post() - self.assertEqual(resp.headers['X-API-URL'], 'https://example.com') + self.assertEqual(resp.headers["X-API-URL"], "https://example.com") # auth.api_url option is not set, url is inferred from listen host and port - cfg.CONF.set_override('api_url', override=None, group='auth') + cfg.CONF.set_override("api_url", override=None, group="auth") resp = self._test_token_post() - self.assertEqual(resp.headers['X-API-URL'], 'http://127.0.0.1:9101') + self.assertEqual(resp.headers["X-API-URL"], "http://127.0.0.1:9101") @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_default_url_path(self): self._test_token_post(path=TOKEN_DEFAULT_PATH) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_set_ttl(self): timestamp = date_utils.add_utc_tz(date_utils.get_datetime_utc_now()) - response = self.app.post_json(TOKEN_V1_PATH, {'ttl': 60}, expect_errors=False) - expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=60) + response = self.app.post_json(TOKEN_V1_PATH, {"ttl": 60}, expect_errors=False) + expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=60 + ) self.assertEqual(response.status_int, 201) - actual_expiry = isotime.parse(response.json['expiry']) + actual_expiry = isotime.parse(response.json["expiry"]) self.assertLess(timestamp, actual_expiry) self.assertLess(actual_expiry, expected_expiry) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_no_data_in_body_text_plain_context_type_used(self): - response = self.app.post(TOKEN_V1_PATH, expect_errors=False, content_type='text/plain') + response = self.app.post( + TOKEN_V1_PATH, expect_errors=False, content_type="text/plain" + ) self.assertEqual(response.status_int, 201) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_set_ttl_over_policy(self): ttl = cfg.CONF.auth.token_ttl - response = self.app.post_json(TOKEN_V1_PATH, {'ttl': ttl + 60}, expect_errors=True) - self.assertEqual(response.status_int, 400) - message = 'TTL specified %s is greater than max allowed %s.' % ( - ttl + 60, ttl + response = self.app.post_json( + TOKEN_V1_PATH, {"ttl": ttl + 60}, expect_errors=True ) - self.assertEqual(response.json['faultstring'], message) + self.assertEqual(response.status_int, 400) + message = "TTL specified %s is greater than max allowed %s." % (ttl + 60, ttl) + self.assertEqual(response.json["faultstring"], message) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_post_set_bad_ttl(self): - response = self.app.post_json(TOKEN_V1_PATH, {'ttl': -1}, expect_errors=True) + response = self.app.post_json(TOKEN_V1_PATH, {"ttl": -1}, expect_errors=True) self.assertEqual(response.status_int, 400) - response = self.app.post_json(TOKEN_V1_PATH, {'ttl': 0}, expect_errors=True) + response = self.app.post_json(TOKEN_V1_PATH, {"ttl": 0}, expect_errors=True) self.assertEqual(response.status_int, 400) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_get_unauthorized(self): # Create a new token. response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False) # Verify the token. 401 is expected because an API key or token is not provided in header. - data = {'token': str(response.json['token'])} + data = {"token": str(response.json["token"])} response = self.app.post_json(TOKEN_VERIFY_PATH, data, expect_errors=True) self.assertEqual(response.status_int, 401) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_get_unauthorized_bad_api_key(self): # Create a new token. response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False) # Verify the token. 401 is expected because the API key is bad. - headers = {'St2-Api-Key': 'foobar'} - data = {'token': str(response.json['token'])} - response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True) + headers = {"St2-Api-Key": "foobar"} + data = {"token": str(response.json["token"])} + response = self.app.post_json( + TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True + ) self.assertEqual(response.status_int, 401) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_get_unauthorized_bad_token(self): # Create a new token. response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False) # Verify the token. 401 is expected because the token is bad. - headers = {'X-Auth-Token': 'foobar'} - data = {'token': str(response.json['token'])} - response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True) + headers = {"X-Auth-Token": "foobar"} + data = {"token": str(response.json["token"])} + response = self.app.post_json( + TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True + ) self.assertEqual(response.status_int, 401) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) @mock.patch.object( - ApiKey, 'get', - mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash='foobar'))) + ApiKey, + "get", + mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash="foobar")), + ) def test_token_get_auth_with_api_key(self): # Create a new token. response = self.app.post_json(TOKEN_V1_PATH, expect_errors=False) # Verify the token. Use an API key to authenticate with the st2 auth get token endpoint. - headers = {'St2-Api-Key': 'foobar'} - data = {'token': str(response.json['token'])} - response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True) + headers = {"St2-Api-Key": "foobar"} + data = {"token": str(response.json["token"])} + response = self.app.post_json( + TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True + ) self.assertEqual(response.status_int, 200) - self.assertTrue(response.json['valid']) + self.assertTrue(response.json["valid"]) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) def test_token_get_auth_with_token(self): # Create a new token. response = self.app.post_json(TOKEN_V1_PATH, {}, expect_errors=False) # Verify the token. Use a token to authenticate with the st2 auth get token endpoint. - headers = {'X-Auth-Token': str(response.json['token'])} - data = {'token': str(response.json['token'])} - response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True) + headers = {"X-Auth-Token": str(response.json["token"])} + data = {"token": str(response.json["token"])} + response = self.app.post_json( + TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=True + ) self.assertEqual(response.status_int, 200) - self.assertTrue(response.json['valid']) + self.assertTrue(response.json["valid"]) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name=USERNAME))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name=USERNAME)) + ) @mock.patch.object( - ApiKey, 'get', - mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash='foobar'))) + ApiKey, + "get", + mock.MagicMock(return_value=ApiKeyDB(user=USERNAME, key_hash="foobar")), + ) @mock.patch.object( - Token, 'get', + Token, + "get", mock.MagicMock( return_value=TokenDB( - user=USERNAME, token='12345', - expiry=date_utils.get_datetime_utc_now() - datetime.timedelta(minutes=1)))) + user=USERNAME, + token="12345", + expiry=date_utils.get_datetime_utc_now() + - datetime.timedelta(minutes=1), + ) + ), + ) def test_token_get_unauthorized_bad_ttl(self): # Verify the token. 400 is expected because the token has expired. - headers = {'St2-Api-Key': 'foobar'} - data = {'token': '12345'} - response = self.app.post_json(TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=False) + headers = {"St2-Api-Key": "foobar"} + data = {"token": "12345"} + response = self.app.post_json( + TOKEN_VERIFY_PATH, data, headers=headers, expect_errors=False + ) self.assertEqual(response.status_int, 200) - self.assertFalse(response.json['valid']) + self.assertFalse(response.json["valid"]) diff --git a/st2auth/tests/unit/test_auth_backends.py b/st2auth/tests/unit/test_auth_backends.py index 96856e8a3e2..b367e328daf 100644 --- a/st2auth/tests/unit/test_auth_backends.py +++ b/st2auth/tests/unit/test_auth_backends.py @@ -25,4 +25,4 @@ class AuthenticationBackendsTestCase(unittest2.TestCase): def test_flat_file_backend_is_available_by_default(self): available_backends = get_available_backends() - self.assertIn('flat_file', available_backends) + self.assertIn("flat_file", available_backends) diff --git a/st2auth/tests/unit/test_handlers.py b/st2auth/tests/unit/test_handlers.py index a3627019d8a..cf00e642a69 100644 --- a/st2auth/tests/unit/test_handlers.py +++ b/st2auth/tests/unit/test_handlers.py @@ -30,25 +30,23 @@ from st2tests.mocks.auth import MockRequest from st2tests.mocks.auth import get_mock_backend -__all__ = [ - 'AuthHandlerTestCase' -] +__all__ = ["AuthHandlerTestCase"] -@mock.patch('st2auth.handlers.get_auth_backend_instance', get_mock_backend) +@mock.patch("st2auth.handlers.get_auth_backend_instance", get_mock_backend) class AuthHandlerTestCase(CleanDbTestCase): def setUp(self): super(AuthHandlerTestCase, self).setUp() - cfg.CONF.auth.backend = 'mock' + cfg.CONF.auth.backend = "mock" def test_proxy_handler(self): h = handlers.ProxyAuthHandler() request = {} token = h.handle_auth( - request, headers={}, remote_addr=None, - remote_user='test_proxy_handler') - self.assertEqual(token.user, 'test_proxy_handler') + request, headers={}, remote_addr=None, remote_user="test_proxy_handler" + ) + self.assertEqual(token.user, "test_proxy_handler") def test_standalone_bad_auth_type(self): h = handlers.StandaloneAuthHandler() @@ -56,8 +54,12 @@ def test_standalone_bad_auth_type(self): with self.assertRaises(exc.HTTPUnauthorized): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('complex', DUMMY_CREDS)) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("complex", DUMMY_CREDS), + ) def test_standalone_no_auth(self): h = handlers.StandaloneAuthHandler() @@ -65,8 +67,12 @@ def test_standalone_no_auth(self): with self.assertRaises(exc.HTTPUnauthorized): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=None) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=None, + ) def test_standalone_bad_auth_value(self): h = handlers.StandaloneAuthHandler() @@ -74,109 +80,159 @@ def test_standalone_bad_auth_value(self): with self.assertRaises(exc.HTTPUnauthorized): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', 'gobblegobble')) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", "gobblegobble"), + ) def test_standalone_handler(self): h = handlers.StandaloneAuthHandler() request = {} token = h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) - self.assertEqual(token.user, 'auser') + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) + self.assertEqual(token.user, "auser") def test_standalone_handler_ttl(self): h = handlers.StandaloneAuthHandler() token1 = h.handle_auth( - MockRequest(23), headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) + MockRequest(23), + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) token2 = h.handle_auth( - MockRequest(2300), headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) - self.assertEqual(token1.user, 'auser') + MockRequest(2300), + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) + self.assertEqual(token1.user, "auser") self.assertNotEqual(token1.expiry, token2.expiry) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name='auser'))) + User, "get_by_name", mock.MagicMock(return_value=UserDB(name="auser")) + ) def test_standalone_for_user_not_service(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.user = 'anotheruser' + request.user = "anotheruser" with self.assertRaises(exc.HTTPBadRequest): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name='auser', is_service=True))) + User, + "get_by_name", + mock.MagicMock(return_value=UserDB(name="auser", is_service=True)), + ) def test_standalone_for_user_service(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.user = 'anotheruser' + request.user = "anotheruser" token = h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) - self.assertEqual(token.user, 'anotheruser') + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) + self.assertEqual(token.user, "anotheruser") def test_standalone_for_user_not_found(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.user = 'anotheruser' + request.user = "anotheruser" with self.assertRaises(exc.HTTPBadRequest): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) def test_standalone_impersonate_user_not_found(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.impersonate_user = 'anotheruser' + request.impersonate_user = "anotheruser" with self.assertRaises(exc.HTTPBadRequest): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) @mock.patch.object( - User, 'get_by_name', - mock.MagicMock(return_value=UserDB(name='auser', is_service=True))) + User, + "get_by_name", + mock.MagicMock(return_value=UserDB(name="auser", is_service=True)), + ) @mock.patch.object( - User, 'get_by_nickname', - mock.MagicMock(return_value=UserDB(name='anotheruser', is_service=True))) + User, + "get_by_nickname", + mock.MagicMock(return_value=UserDB(name="anotheruser", is_service=True)), + ) def test_standalone_impersonate_user_with_nick_origin(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.impersonate_user = 'anotheruser' - request.nickname_origin = 'slack' + request.impersonate_user = "anotheruser" + request.nickname_origin = "slack" token = h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) - self.assertEqual(token.user, 'anotheruser') + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) + self.assertEqual(token.user, "anotheruser") def test_standalone_impersonate_user_no_origin(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - request.impersonate_user = '@anotheruser' + request.impersonate_user = "@anotheruser" with self.assertRaises(exc.HTTPBadRequest): h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=('basic', DUMMY_CREDS)) + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=("basic", DUMMY_CREDS), + ) def test_password_contains_colon(self): h = handlers.StandaloneAuthHandler() request = MockRequest(60) - authorization = ('Basic', base64.b64encode(b'username:password:password')) + authorization = ("Basic", base64.b64encode(b"username:password:password")) token = h.handle_auth( - request, headers={}, remote_addr=None, - remote_user=None, authorization=authorization) - self.assertEqual(token.user, 'username') + request, + headers={}, + remote_addr=None, + remote_user=None, + authorization=authorization, + ) + self.assertEqual(token.user, "username") diff --git a/st2auth/tests/unit/test_validation_utils.py b/st2auth/tests/unit/test_validation_utils.py index 21ab5e26b5e..213e106625f 100644 --- a/st2auth/tests/unit/test_validation_utils.py +++ b/st2auth/tests/unit/test_validation_utils.py @@ -19,9 +19,7 @@ from st2auth.validation import validate_auth_backend_is_correctly_configured from st2tests import config as tests_config -__all__ = [ - 'ValidationUtilsTestCase' -] +__all__ = ["ValidationUtilsTestCase"] class ValidationUtilsTestCase(unittest2.TestCase): @@ -34,22 +32,31 @@ def test_validate_auth_backend_is_correctly_configured_success(self): self.assertTrue(result) def test_validate_auth_backend_is_correctly_configured_invalid_backend(self): - cfg.CONF.set_override(group='auth', name='mode', override='invalid') - expected_msg = ('Invalid auth mode "invalid" specified in the config. ' - 'Valid modes are: proxy, standalone') - self.assertRaisesRegexp(ValueError, expected_msg, - validate_auth_backend_is_correctly_configured) - - def test_validate_auth_backend_is_correctly_configured_backend_doesnt_expose_groups(self): + cfg.CONF.set_override(group="auth", name="mode", override="invalid") + expected_msg = ( + 'Invalid auth mode "invalid" specified in the config. ' + "Valid modes are: proxy, standalone" + ) + self.assertRaisesRegexp( + ValueError, expected_msg, validate_auth_backend_is_correctly_configured + ) + + def test_validate_auth_backend_is_correctly_configured_backend_doesnt_expose_groups( + self, + ): # Flat file backend doesn't expose user group membership information aha provide # "has group info" capability - cfg.CONF.set_override(group='auth', name='backend', override='flat_file') - cfg.CONF.set_override(group='auth', name='backend_kwargs', - override='{"file_path": "dummy"}') - cfg.CONF.set_override(group='rbac', name='enable', override=True) - cfg.CONF.set_override(group='rbac', name='sync_remote_groups', override=True) - - expected_msg = ('Configured auth backend doesn\'t expose user group information. Disable ' - 'remote group synchronization or') - self.assertRaisesRegexp(ValueError, expected_msg, - validate_auth_backend_is_correctly_configured) + cfg.CONF.set_override(group="auth", name="backend", override="flat_file") + cfg.CONF.set_override( + group="auth", name="backend_kwargs", override='{"file_path": "dummy"}' + ) + cfg.CONF.set_override(group="rbac", name="enable", override=True) + cfg.CONF.set_override(group="rbac", name="sync_remote_groups", override=True) + + expected_msg = ( + "Configured auth backend doesn't expose user group information. Disable " + "remote group synchronization or" + ) + self.assertRaisesRegexp( + ValueError, expected_msg, validate_auth_backend_is_correctly_configured + ) diff --git a/st2client/dist_utils.py b/st2client/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/st2client/dist_utils.py +++ b/st2client/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2client/setup.py b/st2client/setup.py index 916b2823012..b318aed359f 100644 --- a/st2client/setup.py +++ b/st2client/setup.py @@ -26,10 +26,10 @@ check_pip_version() -ST2_COMPONENT = 'st2client' +ST2_COMPONENT = "st2client" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') -README_FILE = os.path.join(BASE_DIR, 'README.rst') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") +README_FILE = os.path.join(BASE_DIR, "README.rst") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) apply_vagrant_workaround() @@ -40,43 +40,41 @@ setup( name=ST2_COMPONENT, version=__version__, - description=('Python client library and CLI for the StackStorm (st2) event-driven ' - 'automation platform.'), + description=( + "Python client library and CLI for the StackStorm (st2) event-driven " + "automation platform." + ), long_description=readme, - author='StackStorm', - author_email='info@stackstorm.com', - url='https://stackstorm.com/', + author="StackStorm", + author_email="info@stackstorm.com", + url="https://stackstorm.com/", classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Information Technology', - 'Intended Audience :: Developers', - 'Intended Audience :: System Administrators', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: POSIX :: Linux', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.7' + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Information Technology", + "Intended Audience :: Developers", + "Intended Audience :: System Administrators", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.7", ], install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - entry_points={ - 'console_scripts': [ - 'st2 = st2client.shell:main' - ] - }, + packages=find_packages(exclude=["setuptools", "tests"]), + entry_points={"console_scripts": ["st2 = st2client.shell:main"]}, project_urls={ - 'Pack Exchange': 'https://exchange.stackstorm.org', - 'Repository': 'https://github.com/StackStorm/st2', - 'Documentation': 'https://docs.stackstorm.com', - 'Community': 'https://stackstorm.com/community-signup', - 'Questions': 'https://forum.stackstorm.com/', - 'Donate': 'https://funding.communitybridge.org/projects/stackstorm', - 'News/Blog': 'https://stackstorm.com/blog', - 'Security': 'https://docs.stackstorm.com/latest/security.html', - 'Bug Reports': 'https://github.com/StackStorm/st2/issues', - } + "Pack Exchange": "https://exchange.stackstorm.org", + "Repository": "https://github.com/StackStorm/st2", + "Documentation": "https://docs.stackstorm.com", + "Community": "https://stackstorm.com/community-signup", + "Questions": "https://forum.stackstorm.com/", + "Donate": "https://funding.communitybridge.org/projects/stackstorm", + "News/Blog": "https://stackstorm.com/blog", + "Security": "https://docs.stackstorm.com/latest/security.html", + "Bug Reports": "https://github.com/StackStorm/st2/issues", + }, ) diff --git a/st2client/st2client/__init__.py b/st2client/st2client/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/st2client/st2client/__init__.py +++ b/st2client/st2client/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2client/st2client/base.py b/st2client/st2client/base.py index e4355407269..54b7a91b142 100644 --- a/st2client/st2client/base.py +++ b/st2client/st2client/base.py @@ -38,9 +38,7 @@ from st2client.utils.date import parse as parse_isotime from st2client.utils.misc import merge_dicts -__all__ = [ - 'BaseCLIApp' -] +__all__ = ["BaseCLIApp"] # Fix for "os.getlogin()) OSError: [Errno 2] No such file or directory" os.getlogin = lambda: pwd.getpwuid(os.getuid())[0] @@ -51,14 +49,14 @@ TOKEN_EXPIRATION_GRACE_PERIOD_SECONDS = 15 CONFIG_OPTION_TO_CLIENT_KWARGS_MAP = { - 'base_url': ['general', 'base_url'], - 'auth_url': ['auth', 'url'], - 'stream_url': ['stream', 'url'], - 'api_url': ['api', 'url'], - 'api_version': ['general', 'api_version'], - 'api_key': ['credentials', 'api_key'], - 'cacert': ['general', 'cacert'], - 'debug': ['cli', 'debug'] + "base_url": ["general", "base_url"], + "auth_url": ["auth", "url"], + "stream_url": ["stream", "url"], + "api_url": ["api", "url"], + "api_version": ["general", "api_version"], + "api_key": ["credentials", "api_key"], + "cacert": ["general", "cacert"], + "debug": ["cli", "debug"], } @@ -74,7 +72,7 @@ class BaseCLIApp(object): SKIP_AUTH_CLASSES = [] def get_client(self, args, debug=False): - ST2_CLI_SKIP_CONFIG = os.environ.get('ST2_CLI_SKIP_CONFIG', 0) + ST2_CLI_SKIP_CONFIG = os.environ.get("ST2_CLI_SKIP_CONFIG", 0) ST2_CLI_SKIP_CONFIG = int(ST2_CLI_SKIP_CONFIG) skip_config = args.skip_config @@ -82,12 +80,19 @@ def get_client(self, args, debug=False): # Note: Options provided as the CLI argument have the highest precedence # Precedence order: cli arguments > environment variables > rc file variables - cli_options = ['base_url', 'auth_url', 'api_url', 'stream_url', 'api_version', 'cacert'] + cli_options = [ + "base_url", + "auth_url", + "api_url", + "stream_url", + "api_version", + "cacert", + ] cli_options = {opt: getattr(args, opt, None) for opt in cli_options} if cli_options.get("cacert", None) is not None: - if cli_options["cacert"].lower() in ['true', '1', 't', 'y', 'yes']: + if cli_options["cacert"].lower() in ["true", "1", "t", "y", "yes"]: cli_options["cacert"] = True - elif cli_options["cacert"].lower() in ['false', '0', 'f', 'no']: + elif cli_options["cacert"].lower() in ["false", "0", "f", "no"]: cli_options["cacert"] = False config_file_options = self._get_config_file_options(args=args) @@ -98,20 +103,22 @@ def get_client(self, args, debug=False): kwargs = merge_dicts(kwargs, config_file_options) kwargs = merge_dicts(kwargs, cli_options) - kwargs['debug'] = debug + kwargs["debug"] = debug client = Client(**kwargs) if skip_config: # Config parsing is skipped - self.LOG.info('Skipping parsing CLI config') + self.LOG.info("Skipping parsing CLI config") return client # Ok to use config at this point rc_config = get_config() # Silence SSL warnings - silence_ssl_warnings = rc_config.get('general', {}).get('silence_ssl_warnings', False) + silence_ssl_warnings = rc_config.get("general", {}).get( + "silence_ssl_warnings", False + ) if silence_ssl_warnings: # pylint: disable=no-member requests.packages.urllib3.disable_warnings(InsecureRequestWarning) @@ -127,34 +134,45 @@ def get_client(self, args, debug=False): # We also skip automatic authentication if token is provided via the environment variable # or as a command line argument - env_var_token = os.environ.get('ST2_AUTH_TOKEN', None) - cli_argument_token = getattr(args, 'token', None) - env_var_api_key = os.environ.get('ST2_API_KEY', None) - cli_argument_api_key = getattr(args, 'api_key', None) - if env_var_token or cli_argument_token or env_var_api_key or cli_argument_api_key: + env_var_token = os.environ.get("ST2_AUTH_TOKEN", None) + cli_argument_token = getattr(args, "token", None) + env_var_api_key = os.environ.get("ST2_API_KEY", None) + cli_argument_api_key = getattr(args, "api_key", None) + if ( + env_var_token + or cli_argument_token + or env_var_api_key + or cli_argument_api_key + ): return client # If credentials are provided in the CLI config use them and try to authenticate - credentials = rc_config.get('credentials', {}) - username = credentials.get('username', None) - password = credentials.get('password', None) - cache_token = rc_config.get('cli', {}).get('cache_token', False) + credentials = rc_config.get("credentials", {}) + username = credentials.get("username", None) + password = credentials.get("password", None) + cache_token = rc_config.get("cli", {}).get("cache_token", False) if username: # Credentials are provided, try to authenticate agaist the API try: - token = self._get_auth_token(client=client, username=username, password=password, - cache_token=cache_token) + token = self._get_auth_token( + client=client, + username=username, + password=password, + cache_token=cache_token, + ) except requests.exceptions.ConnectionError as e: - self.LOG.warn('Auth API server is not available, skipping authentication.') + self.LOG.warn( + "Auth API server is not available, skipping authentication." + ) self.LOG.exception(e) return client except Exception as e: - print('Failed to authenticate with credentials provided in the config.') + print("Failed to authenticate with credentials provided in the config.") raise e client.token = token # TODO: Hack, refactor when splitting out the client - os.environ['ST2_AUTH_TOKEN'] = token + os.environ["ST2_AUTH_TOKEN"] = token return client @@ -166,9 +184,12 @@ def _get_config_file_options(self, args, validate_config_permissions=False): :rtype: ``dict`` """ rc_options = self._parse_config_file( - args=args, validate_config_permissions=validate_config_permissions) + args=args, validate_config_permissions=validate_config_permissions + ) result = {} - for kwarg_name, (section, option) in six.iteritems(CONFIG_OPTION_TO_CLIENT_KWARGS_MAP): + for kwarg_name, (section, option) in six.iteritems( + CONFIG_OPTION_TO_CLIENT_KWARGS_MAP + ): result[kwarg_name] = rc_options.get(section, {}).get(option, None) return result @@ -176,10 +197,12 @@ def _get_config_file_options(self, args, validate_config_permissions=False): def _parse_config_file(self, args, validate_config_permissions=False): config_file_path = self._get_config_file_path(args=args) - parser = CLIConfigParser(config_file_path=config_file_path, - validate_config_exists=False, - validate_config_permissions=validate_config_permissions, - log=self.LOG) + parser = CLIConfigParser( + config_file_path=config_file_path, + validate_config_exists=False, + validate_config_permissions=validate_config_permissions, + log=self.LOG, + ) result = parser.parse() return result @@ -189,7 +212,7 @@ def _get_config_file_path(self, args): :rtype: ``str`` """ - path = os.environ.get('ST2_CONFIG_FILE', ST2_CONFIG_PATH) + path = os.environ.get("ST2_CONFIG_FILE", ST2_CONFIG_PATH) if args.config_file: path = args.config_file @@ -212,15 +235,16 @@ def _get_auth_token(self, client, username, password, cache_token): :rtype: ``str`` """ if cache_token: - token = self._get_cached_auth_token(client=client, username=username, - password=password) + token = self._get_cached_auth_token( + client=client, username=username, password=password + ) else: token = None if not token: # Token is either expired or not available - token_obj = self._authenticate_and_retrieve_auth_token(client=client, - username=username, - password=password) + token_obj = self._authenticate_and_retrieve_auth_token( + client=client, username=username, password=password + ) self._cache_auth_token(token_obj=token_obj) token = token_obj.token @@ -243,10 +267,12 @@ def _get_cached_auth_token(self, client, username, password): if not os.access(ST2_CONFIG_DIRECTORY, os.R_OK): # We don't have read access to the file with a cached token - message = ('Unable to retrieve cached token from "%s" (user %s doesn\'t have read ' - 'access to the parent directory). Subsequent requests won\'t use a ' - 'cached token meaning they may be slower.' % (cached_token_path, - os.getlogin())) + message = ( + 'Unable to retrieve cached token from "%s" (user %s doesn\'t have read ' + "access to the parent directory). Subsequent requests won't use a " + "cached token meaning they may be slower." + % (cached_token_path, os.getlogin()) + ) self.LOG.warn(message) return None @@ -255,9 +281,11 @@ def _get_cached_auth_token(self, client, username, password): if not os.access(cached_token_path, os.R_OK): # We don't have read access to the file with a cached token - message = ('Unable to retrieve cached token from "%s" (user %s doesn\'t have read ' - 'access to this file). Subsequent requests won\'t use a cached token ' - 'meaning they may be slower.' % (cached_token_path, os.getlogin())) + message = ( + 'Unable to retrieve cached token from "%s" (user %s doesn\'t have read ' + "access to this file). Subsequent requests won't use a cached token " + "meaning they may be slower." % (cached_token_path, os.getlogin()) + ) self.LOG.warn(message) return None @@ -267,9 +295,11 @@ def _get_cached_auth_token(self, client, username, password): if others_st_mode >= 2: # Every user has access to this file which is dangerous - message = ('Permissions (%s) for cached token file "%s" are too permissive. Please ' - 'restrict the permissions and make sure only your own user can read ' - 'from or write to the file.' % (file_st_mode, cached_token_path)) + message = ( + 'Permissions (%s) for cached token file "%s" are too permissive. Please ' + "restrict the permissions and make sure only your own user can read " + "from or write to the file." % (file_st_mode, cached_token_path) + ) self.LOG.warn(message) with open(cached_token_path) as fp: @@ -278,16 +308,20 @@ def _get_cached_auth_token(self, client, username, password): try: data = json.loads(data) - token = data['token'] - expire_timestamp = data['expire_timestamp'] + token = data["token"] + expire_timestamp = data["expire_timestamp"] except Exception as e: - msg = ('File "%s" with cached token is corrupted or invalid (%s). Please delete ' - ' this file' % (cached_token_path, six.text_type(e))) + msg = ( + 'File "%s" with cached token is corrupted or invalid (%s). Please delete ' + " this file" % (cached_token_path, six.text_type(e)) + ) raise ValueError(msg) now = int(time.time()) if (expire_timestamp - TOKEN_EXPIRATION_GRACE_PERIOD_SECONDS) < now: - self.LOG.debug('Cached token from file "%s" has expired' % (cached_token_path)) + self.LOG.debug( + 'Cached token from file "%s" has expired' % (cached_token_path) + ) # Token has expired return None @@ -312,19 +346,25 @@ def _cache_auth_token(self, token_obj): if not os.access(ST2_CONFIG_DIRECTORY, os.W_OK): # We don't have write access to the file with a cached token - message = ('Unable to write token to "%s" (user %s doesn\'t have write ' - 'access to the parent directory). Subsequent requests won\'t use a ' - 'cached token meaning they may be slower.' % (cached_token_path, - os.getlogin())) + message = ( + 'Unable to write token to "%s" (user %s doesn\'t have write ' + "access to the parent directory). Subsequent requests won't use a " + "cached token meaning they may be slower." + % (cached_token_path, os.getlogin()) + ) self.LOG.warn(message) return None - if os.path.isfile(cached_token_path) and not os.access(cached_token_path, os.W_OK): + if os.path.isfile(cached_token_path) and not os.access( + cached_token_path, os.W_OK + ): # We don't have write access to the file with a cached token - message = ('Unable to write token to "%s" (user %s doesn\'t have write ' - 'access to this file). Subsequent requests won\'t use a ' - 'cached token meaning they may be slower.' % (cached_token_path, - os.getlogin())) + message = ( + 'Unable to write token to "%s" (user %s doesn\'t have write ' + "access to this file). Subsequent requests won't use a " + "cached token meaning they may be slower." + % (cached_token_path, os.getlogin()) + ) self.LOG.warn(message) return None @@ -333,8 +373,8 @@ def _cache_auth_token(self, token_obj): expire_timestamp = calendar.timegm(expire_timestamp.timetuple()) data = {} - data['token'] = token - data['expire_timestamp'] = expire_timestamp + data["token"] = token + data["expire_timestamp"] = expire_timestamp data = json.dumps(data) # Note: We explictly use fdopen instead of open + chmod to avoid a security issue. @@ -342,7 +382,7 @@ def _cache_auth_token(self, token_obj): # open and chmod) when file can potentially be read by other users if the default # permissions used during create allow that. fd = os.open(cached_token_path, os.O_WRONLY | os.O_CREAT, 0o660) - with os.fdopen(fd, 'w') as fp: + with os.fdopen(fd, "w") as fp: fp.write(data) os.chmod(cached_token_path, 0o660) @@ -350,8 +390,12 @@ def _cache_auth_token(self, token_obj): return True def _authenticate_and_retrieve_auth_token(self, client, username, password): - manager = models.ResourceManager(models.Token, client.endpoints['auth'], - cacert=client.cacert, debug=client.debug) + manager = models.ResourceManager( + models.Token, + client.endpoints["auth"], + cacert=client.cacert, + debug=client.debug, + ) instance = models.Token() instance = manager.create(instance, auth=(username, password)) return instance @@ -360,7 +404,7 @@ def _get_cached_token_path_for_user(self, username): """ Retrieve cached token path for the provided username. """ - file_name = 'token-%s' % (username) + file_name = "token-%s" % (username) result = os.path.abspath(os.path.join(ST2_CONFIG_DIRECTORY, file_name)) return result @@ -368,10 +412,10 @@ def _print_config(self, args): config = self._parse_config_file(args=args, validate_config_permissions=False) for section, options in six.iteritems(config): - print('[%s]' % (section)) + print("[%s]" % (section)) for name, value in six.iteritems(options): - print('%s = %s' % (name, value)) + print("%s = %s" % (name, value)) def _print_debug_info(self, args): # Print client settings @@ -388,19 +432,19 @@ def _print_client_settings(self, args): config_file_path = self._get_config_file_path(args=args) - print('CLI settings:') - print('----------------') - print('Config file path: %s' % (config_file_path)) - print('Client settings:') - print('----------------') - print('ST2_BASE_URL: %s' % (client.endpoints['base'])) - print('ST2_AUTH_URL: %s' % (client.endpoints['auth'])) - print('ST2_API_URL: %s' % (client.endpoints['api'])) - print('ST2_STREAM_URL: %s' % (client.endpoints['stream'])) - print('ST2_AUTH_TOKEN: %s' % (os.environ.get('ST2_AUTH_TOKEN'))) - print('') - print('Proxy settings:') - print('---------------') - print('HTTP_PROXY: %s' % (os.environ.get('HTTP_PROXY', ''))) - print('HTTPS_PROXY: %s' % (os.environ.get('HTTPS_PROXY', ''))) - print('') + print("CLI settings:") + print("----------------") + print("Config file path: %s" % (config_file_path)) + print("Client settings:") + print("----------------") + print("ST2_BASE_URL: %s" % (client.endpoints["base"])) + print("ST2_AUTH_URL: %s" % (client.endpoints["auth"])) + print("ST2_API_URL: %s" % (client.endpoints["api"])) + print("ST2_STREAM_URL: %s" % (client.endpoints["stream"])) + print("ST2_AUTH_TOKEN: %s" % (os.environ.get("ST2_AUTH_TOKEN"))) + print("") + print("Proxy settings:") + print("---------------") + print("HTTP_PROXY: %s" % (os.environ.get("HTTP_PROXY", ""))) + print("HTTPS_PROXY: %s" % (os.environ.get("HTTPS_PROXY", ""))) + print("") diff --git a/st2client/st2client/client.py b/st2client/st2client/client.py index 6bda37942b4..9772c825b75 100644 --- a/st2client/st2client/client.py +++ b/st2client/st2client/client.py @@ -47,144 +47,224 @@ DEFAULT_AUTH_PORT = 9100 DEFAULT_STREAM_PORT = 9102 -DEFAULT_BASE_URL = 'http://127.0.0.1' -DEFAULT_API_VERSION = 'v1' +DEFAULT_BASE_URL = "http://127.0.0.1" +DEFAULT_API_VERSION = "v1" class Client(object): - def __init__(self, base_url=None, auth_url=None, api_url=None, stream_url=None, - api_version=None, cacert=None, debug=False, token=None, api_key=None): + def __init__( + self, + base_url=None, + auth_url=None, + api_url=None, + stream_url=None, + api_version=None, + cacert=None, + debug=False, + token=None, + api_key=None, + ): # Get CLI options. If not given, then try to get it from the environment. self.endpoints = dict() # Populate the endpoints if base_url: - self.endpoints['base'] = base_url + self.endpoints["base"] = base_url else: - self.endpoints['base'] = os.environ.get('ST2_BASE_URL', DEFAULT_BASE_URL) + self.endpoints["base"] = os.environ.get("ST2_BASE_URL", DEFAULT_BASE_URL) - api_version = api_version or os.environ.get('ST2_API_VERSION', DEFAULT_API_VERSION) + api_version = api_version or os.environ.get( + "ST2_API_VERSION", DEFAULT_API_VERSION + ) - self.endpoints['exp'] = '%s:%s/%s' % (self.endpoints['base'], DEFAULT_API_PORT, 'exp') + self.endpoints["exp"] = "%s:%s/%s" % ( + self.endpoints["base"], + DEFAULT_API_PORT, + "exp", + ) if api_url: - self.endpoints['api'] = api_url + self.endpoints["api"] = api_url else: - self.endpoints['api'] = os.environ.get( - 'ST2_API_URL', '%s:%s/%s' % (self.endpoints['base'], DEFAULT_API_PORT, api_version)) + self.endpoints["api"] = os.environ.get( + "ST2_API_URL", + "%s:%s/%s" % (self.endpoints["base"], DEFAULT_API_PORT, api_version), + ) if auth_url: - self.endpoints['auth'] = auth_url + self.endpoints["auth"] = auth_url else: - self.endpoints['auth'] = os.environ.get( - 'ST2_AUTH_URL', '%s:%s' % (self.endpoints['base'], DEFAULT_AUTH_PORT)) + self.endpoints["auth"] = os.environ.get( + "ST2_AUTH_URL", "%s:%s" % (self.endpoints["base"], DEFAULT_AUTH_PORT) + ) if stream_url: - self.endpoints['stream'] = stream_url + self.endpoints["stream"] = stream_url else: - self.endpoints['stream'] = os.environ.get( - 'ST2_STREAM_URL', - '%s:%s/%s' % ( - self.endpoints['base'], - DEFAULT_STREAM_PORT, - api_version - ) + self.endpoints["stream"] = os.environ.get( + "ST2_STREAM_URL", + "%s:%s/%s" % (self.endpoints["base"], DEFAULT_STREAM_PORT, api_version), ) if cacert is not None: self.cacert = cacert else: - self.cacert = os.environ.get('ST2_CACERT', None) + self.cacert = os.environ.get("ST2_CACERT", None) # Note: boolean is also a valid value for "cacert" is_cacert_string = isinstance(self.cacert, six.string_types) - if (self.cacert and is_cacert_string and not os.path.isfile(self.cacert)): + if self.cacert and is_cacert_string and not os.path.isfile(self.cacert): raise ValueError('CA cert file "%s" does not exist.' % (self.cacert)) self.debug = debug # Note: This is a nasty hack for now, but we need to get rid of the decrator abuse if token: - os.environ['ST2_AUTH_TOKEN'] = token + os.environ["ST2_AUTH_TOKEN"] = token self.token = token if api_key: - os.environ['ST2_API_KEY'] = api_key + os.environ["ST2_API_KEY"] = api_key self.api_key = api_key # Instantiate resource managers and assign appropriate API endpoint. self.managers = dict() - self.managers['Token'] = ResourceManager( - models.Token, self.endpoints['auth'], cacert=self.cacert, debug=self.debug) - self.managers['RunnerType'] = ResourceManager( - models.RunnerType, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Action'] = ActionResourceManager( - models.Action, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['ActionAlias'] = ActionAliasResourceManager( - models.ActionAlias, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['ActionAliasExecution'] = ActionAliasExecutionManager( - models.ActionAliasExecution, self.endpoints['api'], - cacert=self.cacert, debug=self.debug) - self.managers['ApiKey'] = ResourceManager( - models.ApiKey, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Config'] = ConfigManager( - models.Config, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['ConfigSchema'] = ResourceManager( - models.ConfigSchema, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Execution'] = ExecutionResourceManager( - models.Execution, self.endpoints['api'], cacert=self.cacert, debug=self.debug) + self.managers["Token"] = ResourceManager( + models.Token, self.endpoints["auth"], cacert=self.cacert, debug=self.debug + ) + self.managers["RunnerType"] = ResourceManager( + models.RunnerType, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Action"] = ActionResourceManager( + models.Action, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["ActionAlias"] = ActionAliasResourceManager( + models.ActionAlias, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["ActionAliasExecution"] = ActionAliasExecutionManager( + models.ActionAliasExecution, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["ApiKey"] = ResourceManager( + models.ApiKey, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Config"] = ConfigManager( + models.Config, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["ConfigSchema"] = ResourceManager( + models.ConfigSchema, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Execution"] = ExecutionResourceManager( + models.Execution, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) # NOTE: LiveAction has been deprecated in favor of Execution. It will be left here for # backward compatibility reasons until v3.2.0 - self.managers['LiveAction'] = self.managers['Execution'] - self.managers['Inquiry'] = InquiryResourceManager( - models.Inquiry, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Pack'] = PackResourceManager( - models.Pack, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Policy'] = ResourceManager( - models.Policy, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['PolicyType'] = ResourceManager( - models.PolicyType, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Rule'] = ResourceManager( - models.Rule, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Sensor'] = ResourceManager( - models.Sensor, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['TriggerType'] = ResourceManager( - models.TriggerType, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Trigger'] = ResourceManager( - models.Trigger, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['TriggerInstance'] = TriggerInstanceResourceManager( - models.TriggerInstance, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['KeyValuePair'] = ResourceManager( - models.KeyValuePair, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Webhook'] = WebhookManager( - models.Webhook, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Timer'] = ResourceManager( - models.Timer, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Trace'] = ResourceManager( - models.Trace, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['RuleEnforcement'] = ResourceManager( - models.RuleEnforcement, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['Stream'] = StreamManager( - self.endpoints['stream'], cacert=self.cacert, debug=self.debug) - self.managers['Workflow'] = WorkflowManager( - self.endpoints['api'], cacert=self.cacert, debug=self.debug) + self.managers["LiveAction"] = self.managers["Execution"] + self.managers["Inquiry"] = InquiryResourceManager( + models.Inquiry, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Pack"] = PackResourceManager( + models.Pack, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Policy"] = ResourceManager( + models.Policy, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["PolicyType"] = ResourceManager( + models.PolicyType, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Rule"] = ResourceManager( + models.Rule, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Sensor"] = ResourceManager( + models.Sensor, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["TriggerType"] = ResourceManager( + models.TriggerType, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Trigger"] = ResourceManager( + models.Trigger, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["TriggerInstance"] = TriggerInstanceResourceManager( + models.TriggerInstance, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["KeyValuePair"] = ResourceManager( + models.KeyValuePair, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Webhook"] = WebhookManager( + models.Webhook, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Timer"] = ResourceManager( + models.Timer, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["Trace"] = ResourceManager( + models.Trace, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["RuleEnforcement"] = ResourceManager( + models.RuleEnforcement, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + self.managers["Stream"] = StreamManager( + self.endpoints["stream"], cacert=self.cacert, debug=self.debug + ) + self.managers["Workflow"] = WorkflowManager( + self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) # Service Registry - self.managers['ServiceRegistryGroups'] = ServiceRegistryGroupsManager( - models.ServiceRegistryGroup, self.endpoints['api'], cacert=self.cacert, - debug=self.debug) - - self.managers['ServiceRegistryMembers'] = ServiceRegistryMembersManager( - models.ServiceRegistryMember, self.endpoints['api'], cacert=self.cacert, - debug=self.debug) + self.managers["ServiceRegistryGroups"] = ServiceRegistryGroupsManager( + models.ServiceRegistryGroup, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) + + self.managers["ServiceRegistryMembers"] = ServiceRegistryMembersManager( + models.ServiceRegistryMember, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) # RBAC - self.managers['Role'] = ResourceManager( - models.Role, self.endpoints['api'], cacert=self.cacert, debug=self.debug) - self.managers['UserRoleAssignment'] = ResourceManager( - models.UserRoleAssignment, self.endpoints['api'], cacert=self.cacert, debug=self.debug) + self.managers["Role"] = ResourceManager( + models.Role, self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) + self.managers["UserRoleAssignment"] = ResourceManager( + models.UserRoleAssignment, + self.endpoints["api"], + cacert=self.cacert, + debug=self.debug, + ) @add_auth_token_to_kwargs_from_env def get_user_info(self, **kwargs): @@ -193,9 +273,10 @@ def get_user_info(self, **kwargs): :rtype: ``dict`` """ - url = '/user' - client = httpclient.HTTPClient(root=self.endpoints['api'], cacert=self.cacert, - debug=self.debug) + url = "/user" + client = httpclient.HTTPClient( + root=self.endpoints["api"], cacert=self.cacert, debug=self.debug + ) response = client.get(url=url, **kwargs) if response.status_code != 200: @@ -205,80 +286,85 @@ def get_user_info(self, **kwargs): @property def actions(self): - return self.managers['Action'] + return self.managers["Action"] @property def apikeys(self): - return self.managers['ApiKey'] + return self.managers["ApiKey"] @property def keys(self): - return self.managers['KeyValuePair'] + return self.managers["KeyValuePair"] @property def executions(self): - return self.managers['Execution'] + return self.managers["Execution"] # NOTE: LiveAction has been deprecated in favor of Execution. It will be left here for # backward compatibility reasons until v3.2.0 @property def liveactions(self): - warnings.warn(('st2client.liveactions has been renamed to st2client.executions, please ' - 'update your code'), DeprecationWarning) + warnings.warn( + ( + "st2client.liveactions has been renamed to st2client.executions, please " + "update your code" + ), + DeprecationWarning, + ) return self.executions @property def inquiries(self): - return self.managers['Inquiry'] + return self.managers["Inquiry"] @property def packs(self): - return self.managers['Pack'] + return self.managers["Pack"] @property def policies(self): - return self.managers['Policy'] + return self.managers["Policy"] @property def policytypes(self): - return self.managers['PolicyType'] + return self.managers["PolicyType"] @property def rules(self): - return self.managers['Rule'] + return self.managers["Rule"] @property def runners(self): - return self.managers['RunnerType'] + return self.managers["RunnerType"] @property def sensors(self): - return self.managers['Sensor'] + return self.managers["Sensor"] @property def tokens(self): - return self.managers['Token'] + return self.managers["Token"] @property def triggertypes(self): - return self.managers['TriggerType'] + return self.managers["TriggerType"] @property def triggerinstances(self): - return self.managers['TriggerInstance'] + return self.managers["TriggerInstance"] @property def trace(self): - return self.managers['Trace'] + return self.managers["Trace"] @property def ruleenforcements(self): - return self.managers['RuleEnforcement'] + return self.managers["RuleEnforcement"] @property def webhooks(self): - return self.managers['Webhook'] + return self.managers["Webhook"] @property def workflows(self): - return self.managers['Workflow'] + return self.managers["Workflow"] diff --git a/st2client/st2client/commands/__init__.py b/st2client/st2client/commands/__init__.py index a9b9cee86b8..995d3fd9d3a 100644 --- a/st2client/st2client/commands/__init__.py +++ b/st2client/st2client/commands/__init__.py @@ -35,9 +35,9 @@ def __init__(self, name, description, app, subparsers, parent_parser=None): self.description = description self.app = app self.parent_parser = parent_parser - self.parser = subparsers.add_parser(self.name, - description=self.description, - help=self.description) + self.parser = subparsers.add_parser( + self.name, description=self.description, help=self.description + ) self.commands = dict() @@ -45,16 +45,19 @@ def __init__(self, name, description, app, subparsers, parent_parser=None): class Command(object): """Represents a commandlet in the command tree.""" - def __init__(self, name, description, app, subparsers, - parent_parser=None, add_help=True): + def __init__( + self, name, description, app, subparsers, parent_parser=None, add_help=True + ): self.name = name self.description = description self.app = app self.parent_parser = parent_parser - self.parser = subparsers.add_parser(self.name, - description=self.description, - help=self.description, - add_help=add_help) + self.parser = subparsers.add_parser( + self.name, + description=self.description, + help=self.description, + add_help=add_help, + ) self.parser.set_defaults(func=self.run_and_print) @abc.abstractmethod @@ -74,8 +77,8 @@ def run_and_print(self, args, **kwargs): raise NotImplementedError def format_output(self, subject, formatter, *args, **kwargs): - json = kwargs.get('json', False) - yaml = kwargs.get('yaml', False) + json = kwargs.get("json", False) + yaml = kwargs.get("yaml", False) if json: func = doc.JsonFormatter.format @@ -90,4 +93,4 @@ def print_output(self, subject, formatter, *args, **kwargs): output = self.format_output(subject, formatter, *args, **kwargs) print(output) else: - print('No matching items found') + print("No matching items found") diff --git a/st2client/st2client/commands/action.py b/st2client/st2client/commands/action.py index 7a41d9e2eb6..dcf76c3a7d1 100644 --- a/st2client/st2client/commands/action.py +++ b/st2client/st2client/commands/action.py @@ -44,60 +44,54 @@ LOG = logging.getLogger(__name__) -LIVEACTION_STATUS_REQUESTED = 'requested' -LIVEACTION_STATUS_SCHEDULED = 'scheduled' -LIVEACTION_STATUS_DELAYED = 'delayed' -LIVEACTION_STATUS_RUNNING = 'running' -LIVEACTION_STATUS_SUCCEEDED = 'succeeded' -LIVEACTION_STATUS_FAILED = 'failed' -LIVEACTION_STATUS_TIMED_OUT = 'timeout' -LIVEACTION_STATUS_ABANDONED = 'abandoned' -LIVEACTION_STATUS_CANCELING = 'canceling' -LIVEACTION_STATUS_CANCELED = 'canceled' -LIVEACTION_STATUS_PAUSING = 'pausing' -LIVEACTION_STATUS_PAUSED = 'paused' -LIVEACTION_STATUS_RESUMING = 'resuming' +LIVEACTION_STATUS_REQUESTED = "requested" +LIVEACTION_STATUS_SCHEDULED = "scheduled" +LIVEACTION_STATUS_DELAYED = "delayed" +LIVEACTION_STATUS_RUNNING = "running" +LIVEACTION_STATUS_SUCCEEDED = "succeeded" +LIVEACTION_STATUS_FAILED = "failed" +LIVEACTION_STATUS_TIMED_OUT = "timeout" +LIVEACTION_STATUS_ABANDONED = "abandoned" +LIVEACTION_STATUS_CANCELING = "canceling" +LIVEACTION_STATUS_CANCELED = "canceled" +LIVEACTION_STATUS_PAUSING = "pausing" +LIVEACTION_STATUS_PAUSED = "paused" +LIVEACTION_STATUS_RESUMING = "resuming" LIVEACTION_COMPLETED_STATES = [ LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED, LIVEACTION_STATUS_TIMED_OUT, LIVEACTION_STATUS_CANCELED, - LIVEACTION_STATUS_ABANDONED + LIVEACTION_STATUS_ABANDONED, ] # Who parameters should be masked when displaying action execution output -PARAMETERS_TO_MASK = [ - 'password', - 'private_key' -] +PARAMETERS_TO_MASK = ["password", "private_key"] # A list of environment variables which are never inherited when using run # --inherit-env flag ENV_VARS_BLACKLIST = [ - 'pwd', - 'mail', - 'username', - 'user', - 'path', - 'home', - 'ps1', - 'shell', - 'pythonpath', - 'ssh_tty', - 'ssh_connection', - 'lang', - 'ls_colors', - 'logname', - 'oldpwd', - 'term', - 'xdg_session_id' + "pwd", + "mail", + "username", + "user", + "path", + "home", + "ps1", + "shell", + "pythonpath", + "ssh_tty", + "ssh_connection", + "lang", + "ls_colors", + "logname", + "oldpwd", + "term", + "xdg_session_id", ] -WORKFLOW_RUNNER_TYPES = [ - 'action-chain', - 'orquesta' -] +WORKFLOW_RUNNER_TYPES = ["action-chain", "orquesta"] def format_parameters(value): @@ -108,15 +102,15 @@ def format_parameters(value): for param_name, _ in value.items(): if param_name in PARAMETERS_TO_MASK: - value[param_name] = '********' + value[param_name] = "********" return value # String for indenting etc. -WF_PREFIX = '+ ' -NON_WF_PREFIX = ' ' -INDENT_CHAR = ' ' +WF_PREFIX = "+ " +NON_WF_PREFIX = " " +INDENT_CHAR = " " def format_wf_instances(instances): @@ -127,7 +121,7 @@ def format_wf_instances(instances): # only add extr chars if there are workflows. has_wf = False for instance in instances: - if not getattr(instance, 'children', None): + if not getattr(instance, "children", None): continue else: has_wf = True @@ -136,7 +130,7 @@ def format_wf_instances(instances): return instances # Prepend wf and non_wf prefixes. for instance in instances: - if getattr(instance, 'children', None): + if getattr(instance, "children", None): instance.id = WF_PREFIX + instance.id else: instance.id = NON_WF_PREFIX + instance.id @@ -158,59 +152,75 @@ def format_execution_status(instance): executions which are in running state and execution total run time for all the executions which have finished. """ - status = getattr(instance, 'status', None) - start_timestamp = getattr(instance, 'start_timestamp', None) - end_timestamp = getattr(instance, 'end_timestamp', None) + status = getattr(instance, "status", None) + start_timestamp = getattr(instance, "start_timestamp", None) + end_timestamp = getattr(instance, "end_timestamp", None) if status == LIVEACTION_STATUS_RUNNING and start_timestamp: start_timestamp = instance.start_timestamp start_timestamp = parse_isotime(start_timestamp) start_timestamp = calendar.timegm(start_timestamp.timetuple()) now = int(time.time()) - elapsed_seconds = (now - start_timestamp) - instance.status = '%s (%ss elapsed)' % (instance.status, elapsed_seconds) + elapsed_seconds = now - start_timestamp + instance.status = "%s (%ss elapsed)" % (instance.status, elapsed_seconds) elif status in LIVEACTION_COMPLETED_STATES and start_timestamp and end_timestamp: start_timestamp = parse_isotime(start_timestamp) start_timestamp = calendar.timegm(start_timestamp.timetuple()) end_timestamp = parse_isotime(end_timestamp) end_timestamp = calendar.timegm(end_timestamp.timetuple()) - elapsed_seconds = (end_timestamp - start_timestamp) - instance.status = '%s (%ss elapsed)' % (instance.status, elapsed_seconds) + elapsed_seconds = end_timestamp - start_timestamp + instance.status = "%s (%ss elapsed)" % (instance.status, elapsed_seconds) return instance class ActionBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(ActionBranch, self).__init__( - models.Action, description, app, subparsers, + models.Action, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': ActionListCommand, - 'get': ActionGetCommand, - 'update': ActionUpdateCommand, - 'delete': ActionDeleteCommand - }) + "list": ActionListCommand, + "get": ActionGetCommand, + "update": ActionUpdateCommand, + "delete": ActionDeleteCommand, + }, + ) # Registers extended commands - self.commands['enable'] = ActionEnableCommand(self.resource, self.app, self.subparsers) - self.commands['disable'] = ActionDisableCommand(self.resource, self.app, self.subparsers) - self.commands['execute'] = ActionRunCommand( - self.resource, self.app, self.subparsers, - add_help=False) + self.commands["enable"] = ActionEnableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["disable"] = ActionDisableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["execute"] = ActionRunCommand( + self.resource, self.app, self.subparsers, add_help=False + ) class ActionListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['ref', 'pack', 'description'] + display_attributes = ["ref", "pack", "description"] class ActionGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'uid', 'ref', 'pack', 'name', 'description', - 'enabled', 'entry_point', 'runner_type', - 'parameters'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "uid", + "ref", + "pack", + "name", + "description", + "enabled", + "entry_point", + "runner_type", + "parameters", + ] class ActionUpdateCommand(resource.ContentPackResourceUpdateCommand): @@ -218,17 +228,33 @@ class ActionUpdateCommand(resource.ContentPackResourceUpdateCommand): class ActionEnableCommand(resource.ContentPackResourceEnableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'description', - 'enabled', 'entry_point', 'runner_type', - 'parameters'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "description", + "enabled", + "entry_point", + "runner_type", + "parameters", + ] class ActionDisableCommand(resource.ContentPackResourceDisableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'description', - 'enabled', 'entry_point', 'runner_type', - 'parameters'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "description", + "enabled", + "entry_point", + "runner_type", + "parameters", + ] class ActionDeleteCommand(resource.ContentPackResourceDeleteCommand): @@ -239,15 +265,32 @@ class ActionRunCommandMixin(object): """ Mixin class which contains utility functions related to action execution. """ - display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status', - 'start_timestamp', 'end_timestamp', 'result'] - attribute_display_order = ['id', 'action.ref', 'context.user', 'parameters', 'status', - 'start_timestamp', 'end_timestamp', 'result'] + + display_attributes = [ + "id", + "action.ref", + "context.user", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + "result", + ] + attribute_display_order = [ + "id", + "action.ref", + "context.user", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + "result", + ] attribute_transform_functions = { - 'start_timestamp': format_isodate_for_user_timezone, - 'end_timestamp': format_isodate_for_user_timezone, - 'parameters': format_parameters, - 'status': format_status + "start_timestamp": format_isodate_for_user_timezone, + "end_timestamp": format_isodate_for_user_timezone, + "parameters": format_parameters, + "status": format_status, } poll_interval = 2 # how often to poll for execution completion when using sync mode @@ -262,14 +305,19 @@ def run_and_print(self, args, **kwargs): execution = self.run(args, **kwargs) if args.action_async: - self.print_output('To get the results, execute:\n st2 execution get %s' % - (execution.id), six.text_type) - self.print_output('\nTo view output in real-time, execute:\n st2 execution ' - 'tail %s' % (execution.id), six.text_type) + self.print_output( + "To get the results, execute:\n st2 execution get %s" % (execution.id), + six.text_type, + ) + self.print_output( + "\nTo view output in real-time, execute:\n st2 execution " + "tail %s" % (execution.id), + six.text_type, + ) else: self._print_execution_details(execution=execution, args=args, **kwargs) - if execution.status == 'failed': + if execution.status == "failed": # Exit with non zero if the action has failed sys.exit(1) @@ -278,52 +326,99 @@ def _add_common_options(self): # Display options task_list_arg_grp = root_arg_grp.add_argument_group() - task_list_arg_grp.add_argument('--with-schema', - default=False, action='store_true', - help=('Show schema_ouput suggestion with action.')) - - task_list_arg_grp.add_argument('--raw', action='store_true', - help='Raw output, don\'t show sub-tasks for workflows.') - task_list_arg_grp.add_argument('--show-tasks', action='store_true', - help='Whether to show sub-tasks of an execution.') - task_list_arg_grp.add_argument('--depth', type=int, default=-1, - help='Depth to which to show sub-tasks. \ - By default all are shown.') - task_list_arg_grp.add_argument('-w', '--width', nargs='+', type=int, default=None, - help='Set the width of columns in output.') + task_list_arg_grp.add_argument( + "--with-schema", + default=False, + action="store_true", + help=("Show schema_ouput suggestion with action."), + ) + + task_list_arg_grp.add_argument( + "--raw", + action="store_true", + help="Raw output, don't show sub-tasks for workflows.", + ) + task_list_arg_grp.add_argument( + "--show-tasks", + action="store_true", + help="Whether to show sub-tasks of an execution.", + ) + task_list_arg_grp.add_argument( + "--depth", + type=int, + default=-1, + help="Depth to which to show sub-tasks. \ + By default all are shown.", + ) + task_list_arg_grp.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help="Set the width of columns in output.", + ) execution_details_arg_grp = root_arg_grp.add_mutually_exclusive_group() detail_arg_grp = execution_details_arg_grp.add_mutually_exclusive_group() - detail_arg_grp.add_argument('--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" or unspecified will ' - 'return all attributes.')) - detail_arg_grp.add_argument('-d', '--detail', action='store_true', - help='Display full detail of the execution in table format.') + detail_arg_grp.add_argument( + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" or unspecified will ' + "return all attributes." + ), + ) + detail_arg_grp.add_argument( + "-d", + "--detail", + action="store_true", + help="Display full detail of the execution in table format.", + ) result_arg_grp = execution_details_arg_grp.add_mutually_exclusive_group() - result_arg_grp.add_argument('-k', '--key', - help=('If result is type of JSON, then print specific ' - 'key-value pair; dot notation for nested JSON is ' - 'supported.')) - result_arg_grp.add_argument('--delay', type=int, default=None, - help=('How long (in milliseconds) to delay the ' - 'execution before scheduling.')) + result_arg_grp.add_argument( + "-k", + "--key", + help=( + "If result is type of JSON, then print specific " + "key-value pair; dot notation for nested JSON is " + "supported." + ), + ) + result_arg_grp.add_argument( + "--delay", + type=int, + default=None, + help=( + "How long (in milliseconds) to delay the " + "execution before scheduling." + ), + ) # Other options - detail_arg_grp.add_argument('--tail', action='store_true', - help='Automatically start tailing new execution.') + detail_arg_grp.add_argument( + "--tail", + action="store_true", + help="Automatically start tailing new execution.", + ) # Flag to opt-in to functionality introduced in PR #3670. More robust parsing # of complex datatypes is planned for 2.6, so this flag will be deprecated soon - detail_arg_grp.add_argument('--auto-dict', action='store_true', dest='auto_dict', - default=False, help='Automatically convert list items to ' - 'dictionaries when colons are detected. ' - '(NOTE - this parameter and its functionality will be ' - 'deprecated in the next release in favor of a more ' - 'robust conversion method)') + detail_arg_grp.add_argument( + "--auto-dict", + action="store_true", + dest="auto_dict", + default=False, + help="Automatically convert list items to " + "dictionaries when colons are detected. " + "(NOTE - this parameter and its functionality will be " + "deprecated in the next release in favor of a more " + "robust conversion method)", + ) return root_arg_grp @@ -334,20 +429,24 @@ def _print_execution_details(self, execution, args, **kwargs): This method takes into account if an executed action was workflow or not and formats the output accordingly. """ - runner_type = execution.action.get('runner_type', 'unknown') + runner_type = execution.action.get("runner_type", "unknown") is_workflow_action = runner_type in WORKFLOW_RUNNER_TYPES - show_tasks = getattr(args, 'show_tasks', False) - raw = getattr(args, 'raw', False) - detail = getattr(args, 'detail', False) - key = getattr(args, 'key', None) - attr = getattr(args, 'attr', []) + show_tasks = getattr(args, "show_tasks", False) + raw = getattr(args, "raw", False) + detail = getattr(args, "detail", False) + key = getattr(args, "key", None) + attr = getattr(args, "attr", []) if show_tasks and not is_workflow_action: - raise ValueError('--show-tasks option can only be used with workflow actions') + raise ValueError( + "--show-tasks option can only be used with workflow actions" + ) if not raw and not detail and (show_tasks or is_workflow_action): - self._run_and_print_child_task_list(execution=execution, args=args, **kwargs) + self._run_and_print_child_task_list( + execution=execution, args=args, **kwargs + ) else: instance = execution @@ -357,47 +456,61 @@ def _print_execution_details(self, execution, args, **kwargs): formatter = execution_formatter.ExecutionResult if detail: - options = {'attributes': copy.copy(self.display_attributes)} + options = {"attributes": copy.copy(self.display_attributes)} elif key: - options = {'attributes': ['result.%s' % (key)], 'key': key} + options = {"attributes": ["result.%s" % (key)], "key": key} else: - options = {'attributes': attr} - - options['json'] = args.json - options['yaml'] = args.yaml - options['with_schema'] = args.with_schema - options['attribute_transform_functions'] = self.attribute_transform_functions + options = {"attributes": attr} + + options["json"] = args.json + options["yaml"] = args.yaml + options["with_schema"] = args.with_schema + options[ + "attribute_transform_functions" + ] = self.attribute_transform_functions self.print_output(instance, formatter, **options) def _run_and_print_child_task_list(self, execution, args, **kwargs): - action_exec_mgr = self.app.client.managers['Execution'] + action_exec_mgr = self.app.client.managers["Execution"] instance = execution - options = {'attributes': ['id', 'action.ref', 'parameters', 'status', 'start_timestamp', - 'end_timestamp']} - options['json'] = args.json - options['attribute_transform_functions'] = self.attribute_transform_functions + options = { + "attributes": [ + "id", + "action.ref", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + ] + } + options["json"] = args.json + options["attribute_transform_functions"] = self.attribute_transform_functions formatter = execution_formatter.ExecutionResult - kwargs['depth'] = args.depth - child_instances = action_exec_mgr.get_property(execution.id, 'children', **kwargs) + kwargs["depth"] = args.depth + child_instances = action_exec_mgr.get_property( + execution.id, "children", **kwargs + ) child_instances = self._format_child_instances(child_instances, execution.id) child_instances = format_execution_statuses(child_instances) if not child_instances: # No child error, there might be a global error, include result in the output - options['attributes'].append('result') + options["attributes"].append("result") - status_index = options['attributes'].index('status') + status_index = options["attributes"].index("status") - if hasattr(instance, 'result') and isinstance(instance.result, dict): - tasks = instance.result.get('tasks', []) + if hasattr(instance, "result") and isinstance(instance.result, dict): + tasks = instance.result.get("tasks", []) else: tasks = [] # On failure we also want to include error message and traceback at the top level - if instance.status == 'failed': - top_level_error, top_level_traceback = self._get_top_level_error(live_action=instance) + if instance.status == "failed": + top_level_error, top_level_traceback = self._get_top_level_error( + live_action=instance + ) if len(tasks) >= 1: task_error, task_traceback = self._get_task_error(task=tasks[-1]) @@ -408,18 +521,18 @@ def _run_and_print_child_task_list(self, execution, args, **kwargs): # Top-level error instance.error = top_level_error instance.traceback = top_level_traceback - instance.result = 'See error and traceback.' - options['attributes'].insert(status_index + 1, 'error') - options['attributes'].insert(status_index + 2, 'traceback') + instance.result = "See error and traceback." + options["attributes"].insert(status_index + 1, "error") + options["attributes"].insert(status_index + 2, "traceback") elif task_error: # Task error instance.error = task_error instance.traceback = task_traceback - instance.result = 'See error and traceback.' - instance.failed_on = tasks[-1].get('name', 'unknown') - options['attributes'].insert(status_index + 1, 'error') - options['attributes'].insert(status_index + 2, 'traceback') - options['attributes'].insert(status_index + 3, 'failed_on') + instance.result = "See error and traceback." + instance.failed_on = tasks[-1].get("name", "unknown") + options["attributes"].insert(status_index + 1, "error") + options["attributes"].insert(status_index + 2, "traceback") + options["attributes"].insert(status_index + 3, "failed_on") # Include result on the top-level object so user doesn't need to issue another command to # see the result @@ -427,57 +540,63 @@ def _run_and_print_child_task_list(self, execution, args, **kwargs): task_result = self._get_task_result(task=tasks[-1]) if task_result: - instance.result_task = tasks[-1].get('name', 'unknown') - options['attributes'].insert(status_index + 1, 'result_task') - options['attributes'].insert(status_index + 2, 'result') + instance.result_task = tasks[-1].get("name", "unknown") + options["attributes"].insert(status_index + 1, "result_task") + options["attributes"].insert(status_index + 2, "result") instance.result = task_result # Otherwise include the result of the workflow execution. else: - if 'result' not in options['attributes']: - options['attributes'].append('result') + if "result" not in options["attributes"]: + options["attributes"].append("result") # print root task self.print_output(instance, formatter, **options) # print child tasks if child_instances: - self.print_output(child_instances, table.MultiColumnTable, - attributes=['id', 'status', 'task', 'action', 'start_timestamp'], - widths=args.width, json=args.json, - yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + child_instances, + table.MultiColumnTable, + attributes=["id", "status", "task", "action", "start_timestamp"], + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) def _get_execution_result(self, execution, action_exec_mgr, args, **kwargs): pending_statuses = [ LIVEACTION_STATUS_REQUESTED, LIVEACTION_STATUS_SCHEDULED, LIVEACTION_STATUS_RUNNING, - LIVEACTION_STATUS_CANCELING + LIVEACTION_STATUS_CANCELING, ] if args.tail: # Start tailing new execution print('Tailing execution "%s"' % (str(execution.id))) - execution_manager = self.app.client.managers['Execution'] - stream_manager = self.app.client.managers['Stream'] - ActionExecutionTailCommand.tail_execution(execution=execution, - execution_manager=execution_manager, - stream_manager=stream_manager, - **kwargs) + execution_manager = self.app.client.managers["Execution"] + stream_manager = self.app.client.managers["Stream"] + ActionExecutionTailCommand.tail_execution( + execution=execution, + execution_manager=execution_manager, + stream_manager=stream_manager, + **kwargs, + ) execution = action_exec_mgr.get_by_id(execution.id, **kwargs) - print('') + print("") return execution if not args.action_async: while execution.status in pending_statuses: time.sleep(self.poll_interval) if not args.json and not args.yaml: - sys.stdout.write('.') + sys.stdout.write(".") sys.stdout.flush() execution = action_exec_mgr.get_by_id(execution.id, **kwargs) - sys.stdout.write('\n') + sys.stdout.write("\n") if execution.status == LIVEACTION_STATUS_CANCELED: return execution @@ -491,8 +610,8 @@ def _get_top_level_error(self, live_action): :return: (error, traceback) """ if isinstance(live_action.result, dict): - error = live_action.result.get('error', None) - traceback = live_action.result.get('traceback', None) + error = live_action.result.get("error", None) + traceback = live_action.result.get("traceback", None) else: error = "See result" traceback = "See result" @@ -508,12 +627,12 @@ def _get_task_error(self, task): if not task: return None, None - result = task['result'] + result = task["result"] if isinstance(result, dict): - stderr = result.get('stderr', None) - error = result.get('error', None) - traceback = result.get('traceback', None) + stderr = result.get("stderr", None) + error = result.get("error", None) + traceback = result.get("traceback", None) error = error if error else stderr else: stderr = None @@ -526,7 +645,7 @@ def _get_task_result(self, task): if not task: return None - return task['result'] + return task["result"] def _get_action_parameters_from_args(self, action, runner, args): """ @@ -553,22 +672,22 @@ def read_file(file_path): if not os.path.isfile(file_path): raise ValueError('"%s" is not a file' % (file_path)) - with open(file_path, 'rb') as fp: + with open(file_path, "rb") as fp: content = fp.read() return content.decode("utf-8") def transform_object(value): # Also support simple key1=val1,key2=val2 syntax - if value.startswith('{'): + if value.startswith("{"): # Assume it's JSON result = value = json.loads(value) else: - pairs = value.split(',') + pairs = value.split(",") result = {} for pair in pairs: - split = pair.split('=', 1) + split = pair.split("=", 1) if len(split) != 2: continue @@ -605,18 +724,22 @@ def transform_array(value, action_params=None, auto_dict=False): try: result = json.loads(value) except ValueError: - result = [v.strip() for v in value.split(',')] + result = [v.strip() for v in value.split(",")] # When each values in this array represent dict type, this converts # the 'result' to the dict type value. - if all([isinstance(x, str) and ':' in x for x in result]) and auto_dict: + if all([isinstance(x, str) and ":" in x for x in result]) and auto_dict: result_dict = {} - for (k, v) in [x.split(':') for x in result]: + for (k, v) in [x.split(":") for x in result]: # To parse values using the 'transformer' according to the type which is # specified in the action metadata, calling 'normalize' method recursively. - if 'properties' in action_params and k in action_params['properties']: - result_dict[k] = normalize(k, v, action_params['properties'], - auto_dict=auto_dict) + if ( + "properties" in action_params + and k in action_params["properties"] + ): + result_dict[k] = normalize( + k, v, action_params["properties"], auto_dict=auto_dict + ) else: result_dict[k] = v return [result_dict] @@ -624,12 +747,12 @@ def transform_array(value, action_params=None, auto_dict=False): return result transformer = { - 'array': transform_array, - 'boolean': (lambda x: ast.literal_eval(x.capitalize())), - 'integer': int, - 'number': float, - 'object': transform_object, - 'string': str + "array": transform_array, + "boolean": (lambda x: ast.literal_eval(x.capitalize())), + "integer": int, + "number": float, + "object": transform_object, + "string": str, } def get_param_type(key, action_params=None): @@ -642,13 +765,13 @@ def get_param_type(key, action_params=None): param = action_params[key] if param: - return param['type'] + return param["type"] return None def normalize(name, value, action_params=None, auto_dict=False): - """ The desired type is contained in the action meta-data, so we can look that up - and call the desired "caster" function listed in the "transformer" dict + """The desired type is contained in the action meta-data, so we can look that up + and call the desired "caster" function listed in the "transformer" dict """ action_params = action_params or action.parameters @@ -663,8 +786,10 @@ def normalize(name, value, action_params=None, auto_dict=False): # (items: type: int for example) and this information is available here so we could # also leverage that to cast each array item to the correct type. param_type = get_param_type(name, action_params) - if param_type == 'array' and name in action_params: - return transformer[param_type](value, action_params[name], auto_dict=auto_dict) + if param_type == "array" and name in action_params: + return transformer[param_type]( + value, action_params[name], auto_dict=auto_dict + ) elif param_type: return transformer[param_type](value) @@ -677,11 +802,11 @@ def normalize(name, value, action_params=None, auto_dict=False): for idx in range(len(args.parameters)): arg = args.parameters[idx] - if '=' in arg: - k, v = arg.split('=', 1) + if "=" in arg: + k, v = arg.split("=", 1) # Attribute for files are prefixed with "@" - if k.startswith('@'): + if k.startswith("@"): k = k[1:] is_file = True else: @@ -695,15 +820,15 @@ def normalize(name, value, action_params=None, auto_dict=False): file_name = os.path.basename(file_path) content = read_file(file_path=file_path) - if action_ref_or_id == 'core.http': + if action_ref_or_id == "core.http": # Special case for http runner - result['_file_name'] = file_name - result['file_content'] = content + result["_file_name"] = file_name + result["file_content"] = content else: result[k] = content else: # This permits multiple declarations of argument only in the array type. - if get_param_type(k) == 'array' and k in result: + if get_param_type(k) == "array" and k in result: result[k] += normalize(k, v, auto_dict=args.auto_dict) else: result[k] = normalize(k, v, auto_dict=args.auto_dict) @@ -711,42 +836,44 @@ def normalize(name, value, action_params=None, auto_dict=False): except Exception as e: # TODO: Move transformers in a separate module and handle # exceptions there - if 'malformed string' in six.text_type(e): - message = ('Invalid value for boolean parameter. ' - 'Valid values are: true, false') + if "malformed string" in six.text_type(e): + message = ( + "Invalid value for boolean parameter. " + "Valid values are: true, false" + ) raise ValueError(message) else: raise e else: - result['cmd'] = ' '.join(args.parameters[idx:]) + result["cmd"] = " ".join(args.parameters[idx:]) break # Special case for http runner - if 'file_content' in result: - if 'method' not in result: + if "file_content" in result: + if "method" not in result: # Default to POST if a method is not provided - result['method'] = 'POST' + result["method"] = "POST" - if 'file_name' not in result: + if "file_name" not in result: # File name not provided, use default file name - result['file_name'] = result['_file_name'] + result["file_name"] = result["_file_name"] - del result['_file_name'] + del result["_file_name"] if args.inherit_env: - result['env'] = self._get_inherited_env_vars() + result["env"] = self._get_inherited_env_vars() return result @add_auth_token_to_kwargs_from_cli def _print_help(self, args, **kwargs): # Print appropriate help message if the help option is given. - action_mgr = self.app.client.managers['Action'] - action_exec_mgr = self.app.client.managers['Execution'] + action_mgr = self.app.client.managers["Action"] + action_exec_mgr = self.app.client.managers["Execution"] if args.help: - action_ref_or_id = getattr(args, 'ref_or_id', None) - action_exec_id = getattr(args, 'id', None) + action_ref_or_id = getattr(args, "ref_or_id", None) + action_exec_id = getattr(args, "id", None) if action_exec_id and not action_ref_or_id: action_exec = action_exec_mgr.get_by_id(action_exec_id, **kwargs) @@ -756,34 +883,47 @@ def _print_help(self, args, **kwargs): try: action = action_mgr.get_by_ref_or_id(args.ref_or_id, **kwargs) if not action: - raise resource.ResourceNotFoundError('Action %s not found' % args.ref_or_id) - runner_mgr = self.app.client.managers['RunnerType'] + raise resource.ResourceNotFoundError( + "Action %s not found" % args.ref_or_id + ) + runner_mgr = self.app.client.managers["RunnerType"] runner = runner_mgr.get_by_name(action.runner_type, **kwargs) - parameters, required, optional, _ = self._get_params_types(runner, - action) - print('') + parameters, required, optional, _ = self._get_params_types( + runner, action + ) + print("") print(textwrap.fill(action.description)) - print('') + print("") if required: - required = self._sort_parameters(parameters=parameters, - names=required) - - print('Required Parameters:') - [self._print_param(name, parameters.get(name)) - for name in required] + required = self._sort_parameters( + parameters=parameters, names=required + ) + + print("Required Parameters:") + [ + self._print_param(name, parameters.get(name)) + for name in required + ] if optional: - optional = self._sort_parameters(parameters=parameters, - names=optional) - - print('Optional Parameters:') - [self._print_param(name, parameters.get(name)) - for name in optional] + optional = self._sort_parameters( + parameters=parameters, names=optional + ) + + print("Optional Parameters:") + [ + self._print_param(name, parameters.get(name)) + for name in optional + ] except resource.ResourceNotFoundError: - print(('Action "%s" is not found. ' % args.ref_or_id) + - 'Use "st2 action list" to see the list of available actions.') + print( + ('Action "%s" is not found. ' % args.ref_or_id) + + 'Use "st2 action list" to see the list of available actions.' + ) except Exception as e: - print('ERROR: Unable to print help for action "%s". %s' % - (args.ref_or_id, e)) + print( + 'ERROR: Unable to print help for action "%s". %s' + % (args.ref_or_id, e) + ) else: self.parser.print_help() return True @@ -795,20 +935,20 @@ def _print_param(name, schema): raise ValueError('Missing schema for parameter "%s"' % (name)) wrapper = textwrap.TextWrapper(width=78) - wrapper.initial_indent = ' ' * 4 + wrapper.initial_indent = " " * 4 wrapper.subsequent_indent = wrapper.initial_indent print(wrapper.fill(name)) - wrapper.initial_indent = ' ' * 8 + wrapper.initial_indent = " " * 8 wrapper.subsequent_indent = wrapper.initial_indent - if 'description' in schema and schema['description']: - print(wrapper.fill(schema['description'])) - if 'type' in schema and schema['type']: - print(wrapper.fill('Type: %s' % schema['type'])) - if 'enum' in schema and schema['enum']: - print(wrapper.fill('Enum: %s' % ', '.join(schema['enum']))) - if 'default' in schema and schema['default'] is not None: - print(wrapper.fill('Default: %s' % schema['default'])) - print('') + if "description" in schema and schema["description"]: + print(wrapper.fill(schema["description"])) + if "type" in schema and schema["type"]: + print(wrapper.fill("Type: %s" % schema["type"])) + if "enum" in schema and schema["enum"]: + print(wrapper.fill("Enum: %s" % ", ".join(schema["enum"]))) + if "default" in schema and schema["default"] is not None: + print(wrapper.fill("Default: %s" % schema["default"])) + print("") @staticmethod def _get_params_types(runner, action): @@ -816,19 +956,18 @@ def _get_params_types(runner, action): action_params = action.parameters parameters = copy.copy(runner_params) parameters.update(copy.copy(action_params)) - required = set([k for k, v in six.iteritems(parameters) if v.get('required')]) + required = set([k for k, v in six.iteritems(parameters) if v.get("required")]) def is_immutable(runner_param_meta, action_param_meta): # If runner sets a param as immutable, action cannot override that. - if runner_param_meta.get('immutable', False): + if runner_param_meta.get("immutable", False): return True else: - return action_param_meta.get('immutable', False) + return action_param_meta.get("immutable", False) immutable = set() for param in parameters.keys(): - if is_immutable(runner_params.get(param, {}), - action_params.get(param, {})): + if is_immutable(runner_params.get(param, {}), action_params.get(param, {})): immutable.add(param) required = required - immutable @@ -837,12 +976,12 @@ def is_immutable(runner_param_meta, action_param_meta): return parameters, required, optional, immutable def _format_child_instances(self, children, parent_id): - ''' + """ The goal of this method is to add an indent at every level. This way the WF is represented as a tree structure while in a list. For the right visuals representation the list must be a DF traversal else the idents will end up looking strange. - ''' + """ # apply basic WF formating first. children = format_wf_instances(children) # setup a depth lookup table @@ -856,7 +995,9 @@ def _format_child_instances(self, children, parent_id): parent = None for instance in children: if WF_PREFIX in instance.id: - instance_id = instance.id[instance.id.index(WF_PREFIX) + len(WF_PREFIX):] + instance_id = instance.id[ + instance.id.index(WF_PREFIX) + len(WF_PREFIX) : + ] else: instance_id = instance.id if instance_id == child.parent: @@ -871,26 +1012,28 @@ def _format_child_instances(self, children, parent_id): return result def _format_for_common_representation(self, task): - ''' + """ Formats a task for common representation for action-chain. - ''' + """ # This really needs to be better handled on the back-end but that would be a bigger # change so handling in cli. - context = getattr(task, 'context', None) - if context and 'chain' in context: - task_name_key = 'context.chain.name' - elif context and 'orquesta' in context: - task_name_key = 'context.orquesta.task_name' + context = getattr(task, "context", None) + if context and "chain" in context: + task_name_key = "context.chain.name" + elif context and "orquesta" in context: + task_name_key = "context.orquesta.task_name" # Use Execution as the object so that the formatter lookup does not change. # AKA HACK! - return models.action.Execution(**{ - 'id': task.id, - 'status': task.status, - 'task': jsutil.get_value(vars(task), task_name_key), - 'action': task.action.get('ref', None), - 'start_timestamp': task.start_timestamp, - 'end_timestamp': getattr(task, 'end_timestamp', None) - }) + return models.action.Execution( + **{ + "id": task.id, + "status": task.status, + "task": jsutil.get_value(vars(task), task_name_key), + "action": task.action.get("ref", None), + "start_timestamp": task.start_timestamp, + "end_timestamp": getattr(task, "end_timestamp", None), + } + ) def _sort_parameters(self, parameters, names): """ @@ -899,10 +1042,12 @@ def _sort_parameters(self, parameters, names): :type parameters: ``list`` :type names: ``list`` or ``set`` """ - sorted_parameters = sorted(names, key=lambda name: - self._get_parameter_sort_value( - parameters=parameters, - name=name)) + sorted_parameters = sorted( + names, + key=lambda name: self._get_parameter_sort_value( + parameters=parameters, name=name + ), + ) return sorted_parameters @@ -919,7 +1064,7 @@ def _get_parameter_sort_value(self, parameters, name): if not parameter: return None - sort_value = parameter.get('position', name) + sort_value = parameter.get("position", name) return sort_value def _get_inherited_env_vars(self): @@ -938,44 +1083,76 @@ class ActionRunCommand(ActionRunCommandMixin, resource.ResourceCommand): def __init__(self, resource, *args, **kwargs): super(ActionRunCommand, self).__init__( - resource, kwargs.pop('name', 'execute'), - 'Invoke an action manually.', - *args, **kwargs) - - self.parser.add_argument('ref_or_id', nargs='?', - metavar='ref-or-id', - help='Action reference (pack.action_name) ' + - 'or ID of the action.') - self.parser.add_argument('parameters', nargs='*', - help='List of keyword args, positional args, ' - 'and optional args for the action.') - - self.parser.add_argument('-h', '--help', - action='store_true', dest='help', - help='Print usage for the given action.') + resource, + kwargs.pop("name", "execute"), + "Invoke an action manually.", + *args, + **kwargs, + ) + + self.parser.add_argument( + "ref_or_id", + nargs="?", + metavar="ref-or-id", + help="Action reference (pack.action_name) " + "or ID of the action.", + ) + self.parser.add_argument( + "parameters", + nargs="*", + help="List of keyword args, positional args, " + "and optional args for the action.", + ) + + self.parser.add_argument( + "-h", + "--help", + action="store_true", + dest="help", + help="Print usage for the given action.", + ) self._add_common_options() - if self.name in ['run', 'execute']: - self.parser.add_argument('--trace-tag', '--trace_tag', - help='A trace tag string to track execution later.', - dest='trace_tag', required=False) - self.parser.add_argument('--trace-id', - help='Existing trace id for this execution.', - dest='trace_id', required=False) - self.parser.add_argument('-a', '--async', - action='store_true', dest='action_async', - help='Do not wait for action to finish.') - self.parser.add_argument('-e', '--inherit-env', - action='store_true', dest='inherit_env', - help='Pass all the environment variables ' - 'which are accessible to the CLI as "env" ' - 'parameter to the action. Note: Only works ' - 'with python, local and remote runners.') - self.parser.add_argument('-u', '--user', type=str, default=None, - help='User under which to run the action (admins only).') - - if self.name == 'run': + if self.name in ["run", "execute"]: + self.parser.add_argument( + "--trace-tag", + "--trace_tag", + help="A trace tag string to track execution later.", + dest="trace_tag", + required=False, + ) + self.parser.add_argument( + "--trace-id", + help="Existing trace id for this execution.", + dest="trace_id", + required=False, + ) + self.parser.add_argument( + "-a", + "--async", + action="store_true", + dest="action_async", + help="Do not wait for action to finish.", + ) + self.parser.add_argument( + "-e", + "--inherit-env", + action="store_true", + dest="inherit_env", + help="Pass all the environment variables " + 'which are accessible to the CLI as "env" ' + "parameter to the action. Note: Only works " + "with python, local and remote runners.", + ) + self.parser.add_argument( + "-u", + "--user", + type=str, + default=None, + help="User under which to run the action (admins only).", + ) + + if self.name == "run": self.parser.set_defaults(action_async=False) else: self.parser.set_defaults(action_async=True) @@ -983,22 +1160,27 @@ def __init__(self, resource, *args, **kwargs): @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): if not args.ref_or_id: - self.parser.error('Missing action reference or id') + self.parser.error("Missing action reference or id") action = self.get_resource(args.ref_or_id, **kwargs) if not action: - raise resource.ResourceNotFoundError('Action "%s" cannot be found.' - % (args.ref_or_id)) + raise resource.ResourceNotFoundError( + 'Action "%s" cannot be found.' % (args.ref_or_id) + ) - runner_mgr = self.app.client.managers['RunnerType'] + runner_mgr = self.app.client.managers["RunnerType"] runner = runner_mgr.get_by_name(action.runner_type, **kwargs) if not runner: - raise resource.ResourceNotFoundError('Runner type "%s" for action "%s" cannot be \ - found.' % (action.runner_type, action.name)) + raise resource.ResourceNotFoundError( + 'Runner type "%s" for action "%s" cannot be \ + found.' + % (action.runner_type, action.name) + ) - action_ref = '.'.join([action.pack, action.name]) - action_parameters = self._get_action_parameters_from_args(action=action, runner=runner, - args=args) + action_ref = ".".join([action.pack, action.name]) + action_parameters = self._get_action_parameters_from_args( + action=action, runner=runner, args=args + ) execution = models.Execution() execution.action = action_ref @@ -1009,56 +1191,79 @@ def run(self, args, **kwargs): execution.delay = args.delay if not args.trace_id and args.trace_tag: - execution.context = {'trace_context': {'trace_tag': args.trace_tag}} + execution.context = {"trace_context": {"trace_tag": args.trace_tag}} if args.trace_id: - execution.context = {'trace_context': {'id_': args.trace_id}} + execution.context = {"trace_context": {"id_": args.trace_id}} - action_exec_mgr = self.app.client.managers['Execution'] + action_exec_mgr = self.app.client.managers["Execution"] execution = action_exec_mgr.create(execution, **kwargs) - execution = self._get_execution_result(execution=execution, - action_exec_mgr=action_exec_mgr, - args=args, **kwargs) + execution = self._get_execution_result( + execution=execution, action_exec_mgr=action_exec_mgr, args=args, **kwargs + ) return execution class ActionExecutionBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(ActionExecutionBranch, self).__init__( - models.Execution, description, app, subparsers, - parent_parser=parent_parser, read_only=True, - commands={'list': ActionExecutionListCommand, - 'get': ActionExecutionGetCommand}) + models.Execution, + description, + app, + subparsers, + parent_parser=parent_parser, + read_only=True, + commands={ + "list": ActionExecutionListCommand, + "get": ActionExecutionGetCommand, + }, + ) # Register extended commands - self.commands['re-run'] = ActionExecutionReRunCommand( - self.resource, self.app, self.subparsers, add_help=False) - self.commands['cancel'] = ActionExecutionCancelCommand( - self.resource, self.app, self.subparsers, add_help=True) - self.commands['pause'] = ActionExecutionPauseCommand( - self.resource, self.app, self.subparsers, add_help=True) - self.commands['resume'] = ActionExecutionResumeCommand( - self.resource, self.app, self.subparsers, add_help=True) - self.commands['tail'] = ActionExecutionTailCommand(self.resource, self.app, - self.subparsers, - add_help=True) - - -POSSIBLE_ACTION_STATUS_VALUES = ('succeeded', 'running', 'scheduled', 'paused', 'failed', - 'canceling', 'canceled') + self.commands["re-run"] = ActionExecutionReRunCommand( + self.resource, self.app, self.subparsers, add_help=False + ) + self.commands["cancel"] = ActionExecutionCancelCommand( + self.resource, self.app, self.subparsers, add_help=True + ) + self.commands["pause"] = ActionExecutionPauseCommand( + self.resource, self.app, self.subparsers, add_help=True + ) + self.commands["resume"] = ActionExecutionResumeCommand( + self.resource, self.app, self.subparsers, add_help=True + ) + self.commands["tail"] = ActionExecutionTailCommand( + self.resource, self.app, self.subparsers, add_help=True + ) + + +POSSIBLE_ACTION_STATUS_VALUES = ( + "succeeded", + "running", + "scheduled", + "paused", + "failed", + "canceling", + "canceled", +) class ActionExecutionListCommand(ResourceViewCommand): - display_attributes = ['id', 'action.ref', 'context.user', 'status', 'start_timestamp', - 'end_timestamp'] + display_attributes = [ + "id", + "action.ref", + "context.user", + "status", + "start_timestamp", + "end_timestamp", + ] attribute_transform_functions = { - 'start_timestamp': format_isodate_for_user_timezone, - 'end_timestamp': format_isodate_for_user_timezone, - 'parameters': format_parameters, - 'status': format_status + "start_timestamp": format_isodate_for_user_timezone, + "end_timestamp": format_isodate_for_user_timezone, + "parameters": format_parameters, + "status": format_status, } def __init__(self, resource, *args, **kwargs): @@ -1066,83 +1271,133 @@ def __init__(self, resource, *args, **kwargs): self.default_limit = 50 super(ActionExecutionListCommand, self).__init__( - resource, 'list', 'Get the list of the %s most recent %s.' % - (self.default_limit, resource.get_plural_display_name().lower()), - *args, **kwargs) + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() self.group = self.parser.add_argument_group() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) - self.parser.add_argument('-s', '--sort', type=str, dest='sort_order', - default='descending', - help=('Sort %s by start timestamp, ' - 'asc|ascending (earliest first) ' - 'or desc|descending (latest first)' % self.resource_name)) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) + self.parser.add_argument( + "-s", + "--sort", + type=str, + dest="sort_order", + default="descending", + help=( + "Sort %s by start timestamp, " + "asc|ascending (earliest first) " + "or desc|descending (latest first)" % self.resource_name + ), + ) # Filter options - self.group.add_argument('--action', help='Action reference to filter the list.') - self.group.add_argument('--status', help=('Only return executions with the provided \ - status. Possible values are \'%s\', \'%s\', \ - \'%s\', \'%s\', \'%s\', \'%s\' or \'%s\'' - '.' % POSSIBLE_ACTION_STATUS_VALUES)) - self.group.add_argument('--user', - help='Only return executions created by the provided user.') - self.group.add_argument('--trigger_instance', - help='Trigger instance id to filter the list.') - self.parser.add_argument('-tg', '--timestamp-gt', type=str, dest='timestamp_gt', - default=None, - help=('Only return executions with timestamp ' - 'greater than the one provided. ' - 'Use time in the format "2000-01-01T12:00:00.000Z".')) - self.parser.add_argument('-tl', '--timestamp-lt', type=str, dest='timestamp_lt', - default=None, - help=('Only return executions with timestamp ' - 'lower than the one provided. ' - 'Use time in the format "2000-01-01T12:00:00.000Z".')) - self.parser.add_argument('-l', '--showall', action='store_true', - help='') + self.group.add_argument("--action", help="Action reference to filter the list.") + self.group.add_argument( + "--status", + help=( + "Only return executions with the provided \ + status. Possible values are '%s', '%s', \ + '%s', '%s', '%s', '%s' or '%s'" + "." % POSSIBLE_ACTION_STATUS_VALUES + ), + ) + self.group.add_argument( + "--user", help="Only return executions created by the provided user." + ) + self.group.add_argument( + "--trigger_instance", help="Trigger instance id to filter the list." + ) + self.parser.add_argument( + "-tg", + "--timestamp-gt", + type=str, + dest="timestamp_gt", + default=None, + help=( + "Only return executions with timestamp " + "greater than the one provided. " + 'Use time in the format "2000-01-01T12:00:00.000Z".' + ), + ) + self.parser.add_argument( + "-tl", + "--timestamp-lt", + type=str, + dest="timestamp_lt", + default=None, + help=( + "Only return executions with timestamp " + "lower than the one provided. " + 'Use time in the format "2000-01-01T12:00:00.000Z".' + ), + ) + self.parser.add_argument("-l", "--showall", action="store_true", help="") # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.action: - kwargs['action'] = args.action + kwargs["action"] = args.action if args.status: - kwargs['status'] = args.status + kwargs["status"] = args.status if args.user: - kwargs['user'] = args.user + kwargs["user"] = args.user if args.trigger_instance: - kwargs['trigger_instance'] = args.trigger_instance + kwargs["trigger_instance"] = args.trigger_instance if not args.showall: # null is the magic string that translates to does not exist. - kwargs['parent'] = 'null' + kwargs["parent"] = "null" if args.timestamp_gt: - kwargs['timestamp_gt'] = args.timestamp_gt + kwargs["timestamp_gt"] = args.timestamp_gt if args.timestamp_lt: - kwargs['timestamp_lt'] = args.timestamp_lt + kwargs["timestamp_lt"] = args.timestamp_lt if args.sort_order: - if args.sort_order in ['asc', 'ascending']: - kwargs['sort_asc'] = True - elif args.sort_order in ['desc', 'descending']: - kwargs['sort_desc'] = True + if args.sort_order in ["asc", "ascending"]: + kwargs["sort_asc"] = True + elif args.sort_order in ["desc", "descending"]: + kwargs["sort_desc"] = True # We only retrieve attributes which are needed to speed things up include_attributes = self._get_include_attributes(args=args) if include_attributes: - kwargs['include_attributes'] = ','.join(include_attributes) + kwargs["include_attributes"] = ",".join(include_attributes) return self.manager.query_with_count(limit=args.last, **kwargs) @@ -1152,49 +1407,73 @@ def run_and_print(self, args, **kwargs): instances = format_wf_instances(result) if args.json or args.yaml: - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, - yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) else: # Include elapsed time for running executions instances = format_execution_statuses(instances) - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + attribute_transform_functions=self.attribute_transform_functions, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class ActionExecutionGetCommand(ActionRunCommandMixin, ResourceViewCommand): - display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status', - 'start_timestamp', 'end_timestamp', 'result'] - include_attributes = ['action.ref', 'action.runner_type', 'start_timestamp', - 'end_timestamp'] + display_attributes = [ + "id", + "action.ref", + "context.user", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + "result", + ] + include_attributes = [ + "action.ref", + "action.runner_type", + "start_timestamp", + "end_timestamp", + ] def __init__(self, resource, *args, **kwargs): super(ActionExecutionGetCommand, self).__init__( - resource, 'get', - 'Get individual %s.' % resource.get_display_name().lower(), - *args, **kwargs) + resource, + "get", + "Get individual %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) - self.parser.add_argument('id', - help=('ID of the %s.' % - resource.get_display_name().lower())) + self.parser.add_argument( + "id", help=("ID of the %s." % resource.get_display_name().lower()) + ) self._add_common_options() @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # We only retrieve attributes which are needed to speed things up - include_attributes = self._get_include_attributes(args=args, - extra_attributes=self.include_attributes) + include_attributes = self._get_include_attributes( + args=args, extra_attributes=self.include_attributes + ) if include_attributes: - include_attributes = ','.join(include_attributes) - kwargs['params'] = {'include_attributes': include_attributes} + include_attributes = ",".join(include_attributes) + kwargs["params"] = {"include_attributes": include_attributes} execution = self.get_resource_by_id(id=args.id, **kwargs) return execution @@ -1209,22 +1488,25 @@ def run_and_print(self, args, **kwargs): execution = format_execution_status(execution) except resource.ResourceNotFoundError: self.print_not_found(args.id) - raise ResourceNotFoundError('Execution with id %s not found.' % (args.id)) + raise ResourceNotFoundError("Execution with id %s not found." % (args.id)) return self._print_execution_details(execution=execution, args=args, **kwargs) class ActionExecutionCancelCommand(resource.ResourceCommand): - def __init__(self, resource, *args, **kwargs): super(ActionExecutionCancelCommand, self).__init__( - resource, 'cancel', 'Cancel %s.' % - resource.get_plural_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('ids', - nargs='+', - help=('IDs of the %ss to cancel.' % - resource.get_display_name().lower())) + resource, + "cancel", + "Cancel %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "ids", + nargs="+", + help=("IDs of the %ss to cancel." % resource.get_display_name().lower()), + ) def run(self, args, **kwargs): responses = [] @@ -1242,16 +1524,23 @@ def run_and_print(self, args, **kwargs): self._print_result(execution_id=execution_id, response=response) def _print_result(self, execution_id, response): - if response and 'faultstring' in response: - message = response.get('faultstring', 'Cancellation requested for %s with id %s.' % - (self.resource.get_display_name().lower(), execution_id)) + if response and "faultstring" in response: + message = response.get( + "faultstring", + "Cancellation requested for %s with id %s." + % (self.resource.get_display_name().lower(), execution_id), + ) elif response: - message = '%s with id %s canceled.' % (self.resource.get_display_name().lower(), - execution_id) + message = "%s with id %s canceled." % ( + self.resource.get_display_name().lower(), + execution_id, + ) else: - message = 'Cannot cancel %s with id %s.' % (self.resource.get_display_name().lower(), - execution_id) + message = "Cannot cancel %s with id %s." % ( + self.resource.get_display_name().lower(), + execution_id, + ) print(message) @@ -1259,35 +1548,58 @@ class ActionExecutionReRunCommand(ActionRunCommandMixin, resource.ResourceComman def __init__(self, resource, *args, **kwargs): super(ActionExecutionReRunCommand, self).__init__( - resource, kwargs.pop('name', 're-run'), - 'Re-run a particular action.', - *args, **kwargs) - - self.parser.add_argument('id', nargs='?', - metavar='id', - help='ID of action execution to re-run ') - self.parser.add_argument('parameters', nargs='*', - help='List of keyword args, positional args, ' - 'and optional args for the action.') - self.parser.add_argument('--tasks', nargs='*', - help='Name of the workflow tasks to re-run.') - self.parser.add_argument('--no-reset', dest='no_reset', nargs='*', - help='Name of the with-items tasks to not reset. This only ' - 'applies to Orquesta workflows. By default, all iterations ' - 'for with-items tasks is rerun. If no reset, only failed ' - ' iterations are rerun.') - self.parser.add_argument('-a', '--async', - action='store_true', dest='action_async', - help='Do not wait for action to finish.') - self.parser.add_argument('-e', '--inherit-env', - action='store_true', dest='inherit_env', - help='Pass all the environment variables ' - 'which are accessible to the CLI as "env" ' - 'parameter to the action. Note: Only works ' - 'with python, local and remote runners.') - self.parser.add_argument('-h', '--help', - action='store_true', dest='help', - help='Print usage for the given action.') + resource, + kwargs.pop("name", "re-run"), + "Re-run a particular action.", + *args, + **kwargs, + ) + + self.parser.add_argument( + "id", nargs="?", metavar="id", help="ID of action execution to re-run " + ) + self.parser.add_argument( + "parameters", + nargs="*", + help="List of keyword args, positional args, " + "and optional args for the action.", + ) + self.parser.add_argument( + "--tasks", nargs="*", help="Name of the workflow tasks to re-run." + ) + self.parser.add_argument( + "--no-reset", + dest="no_reset", + nargs="*", + help="Name of the with-items tasks to not reset. This only " + "applies to Orquesta workflows. By default, all iterations " + "for with-items tasks is rerun. If no reset, only failed " + " iterations are rerun.", + ) + self.parser.add_argument( + "-a", + "--async", + action="store_true", + dest="action_async", + help="Do not wait for action to finish.", + ) + self.parser.add_argument( + "-e", + "--inherit-env", + action="store_true", + dest="inherit_env", + help="Pass all the environment variables " + 'which are accessible to the CLI as "env" ' + "parameter to the action. Note: Only works " + "with python, local and remote runners.", + ) + self.parser.add_argument( + "-h", + "--help", + action="store_true", + dest="help", + help="Print usage for the given action.", + ) self._add_common_options() @add_auth_token_to_kwargs_from_cli @@ -1295,47 +1607,63 @@ def run(self, args, **kwargs): existing_execution = self.manager.get_by_id(args.id, **kwargs) if not existing_execution: - raise resource.ResourceNotFoundError('Action execution with id "%s" cannot be found.' % - (args.id)) + raise resource.ResourceNotFoundError( + 'Action execution with id "%s" cannot be found.' % (args.id) + ) - action_mgr = self.app.client.managers['Action'] - runner_mgr = self.app.client.managers['RunnerType'] - action_exec_mgr = self.app.client.managers['Execution'] + action_mgr = self.app.client.managers["Action"] + runner_mgr = self.app.client.managers["RunnerType"] + action_exec_mgr = self.app.client.managers["Execution"] - action_ref = existing_execution.action['ref'] + action_ref = existing_execution.action["ref"] action = action_mgr.get_by_ref_or_id(action_ref) runner = runner_mgr.get_by_name(action.runner_type) - action_parameters = self._get_action_parameters_from_args(action=action, runner=runner, - args=args) + action_parameters = self._get_action_parameters_from_args( + action=action, runner=runner, args=args + ) - execution = action_exec_mgr.re_run(execution_id=args.id, - parameters=action_parameters, - tasks=args.tasks, - no_reset=args.no_reset, - delay=args.delay if args.delay else 0, - **kwargs) + execution = action_exec_mgr.re_run( + execution_id=args.id, + parameters=action_parameters, + tasks=args.tasks, + no_reset=args.no_reset, + delay=args.delay if args.delay else 0, + **kwargs, + ) - execution = self._get_execution_result(execution=execution, - action_exec_mgr=action_exec_mgr, - args=args, **kwargs) + execution = self._get_execution_result( + execution=execution, action_exec_mgr=action_exec_mgr, args=args, **kwargs + ) return execution class ActionExecutionPauseCommand(ActionRunCommandMixin, ResourceViewCommand): - display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status', - 'start_timestamp', 'end_timestamp', 'result'] + display_attributes = [ + "id", + "action.ref", + "context.user", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + "result", + ] def __init__(self, resource, *args, **kwargs): super(ActionExecutionPauseCommand, self).__init__( - resource, 'pause', 'Pause %s (workflow executions only).' % - resource.get_plural_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('ids', - nargs='+', - help='ID of action execution to pause.') + resource, + "pause", + "Pause %s (workflow executions only)." + % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "ids", nargs="+", help="ID of action execution to pause." + ) self._add_common_options() @@ -1348,7 +1676,9 @@ def run(self, args, **kwargs): responses.append([execution_id, response]) except resource.ResourceNotFoundError: self.print_not_found(args.ids) - raise ResourceNotFoundError('Execution with id %s not found.' % (execution_id)) + raise ResourceNotFoundError( + "Execution with id %s not found." % (execution_id) + ) return responses @@ -1367,18 +1697,30 @@ def _print_result(self, args, execution_id, execution, **kwargs): class ActionExecutionResumeCommand(ActionRunCommandMixin, ResourceViewCommand): - display_attributes = ['id', 'action.ref', 'context.user', 'parameters', 'status', - 'start_timestamp', 'end_timestamp', 'result'] + display_attributes = [ + "id", + "action.ref", + "context.user", + "parameters", + "status", + "start_timestamp", + "end_timestamp", + "result", + ] def __init__(self, resource, *args, **kwargs): super(ActionExecutionResumeCommand, self).__init__( - resource, 'resume', 'Resume %s (workflow executions only).' % - resource.get_plural_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('ids', - nargs='+', - help='ID of action execution to resume.') + resource, + "resume", + "Resume %s (workflow executions only)." + % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "ids", nargs="+", help="ID of action execution to resume." + ) self._add_common_options() @@ -1391,7 +1733,9 @@ def run(self, args, **kwargs): responses.append([execution_id, response]) except resource.ResourceNotFoundError: self.print_not_found(execution_id) - raise ResourceNotFoundError('Execution with id %s not found.' % (execution_id)) + raise ResourceNotFoundError( + "Execution with id %s not found." % (execution_id) + ) return responses @@ -1412,22 +1756,33 @@ def _print_result(self, args, execution, **kwargs): class ActionExecutionTailCommand(resource.ResourceCommand): def __init__(self, resource, *args, **kwargs): super(ActionExecutionTailCommand, self).__init__( - resource, kwargs.pop('name', 'tail'), - 'Tail output of a particular execution.', - *args, **kwargs) - - self.parser.add_argument('id', nargs='?', - metavar='id', - default='last', - help='ID of action execution to tail.') - self.parser.add_argument('--type', dest='output_type', action='store', - help=('Type of output to tail for. If not provided, ' - 'defaults to all.')) - self.parser.add_argument('--include-metadata', dest='include_metadata', - action='store_true', - default=False, - help=('Include metadata (timestamp, output type) with the ' - 'output.')) + resource, + kwargs.pop("name", "tail"), + "Tail output of a particular execution.", + *args, + **kwargs, + ) + + self.parser.add_argument( + "id", + nargs="?", + metavar="id", + default="last", + help="ID of action execution to tail.", + ) + self.parser.add_argument( + "--type", + dest="output_type", + action="store", + help=("Type of output to tail for. If not provided, " "defaults to all."), + ) + self.parser.add_argument( + "--include-metadata", + dest="include_metadata", + action="store_true", + default=False, + help=("Include metadata (timestamp, output type) with the " "output."), + ) def run(self, args, **kwargs): pass @@ -1435,45 +1790,55 @@ def run(self, args, **kwargs): @add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): execution_id = args.id - output_type = getattr(args, 'output_type', None) + output_type = getattr(args, "output_type", None) include_metadata = args.include_metadata # Special case for id "last" - if execution_id == 'last': + if execution_id == "last": executions = self.manager.query(limit=1) if executions: execution = executions[0] execution_id = execution.id else: - print('No executions found in db.') + print("No executions found in db.") return else: execution = self.manager.get_by_id(execution_id, **kwargs) if not execution: - raise ResourceNotFoundError('Execution with id %s not found.' % (args.id)) + raise ResourceNotFoundError("Execution with id %s not found." % (args.id)) execution_manager = self.manager - stream_manager = self.app.client.managers['Stream'] - ActionExecutionTailCommand.tail_execution(execution=execution, - execution_manager=execution_manager, - stream_manager=stream_manager, - output_type=output_type, - include_metadata=include_metadata, - **kwargs) + stream_manager = self.app.client.managers["Stream"] + ActionExecutionTailCommand.tail_execution( + execution=execution, + execution_manager=execution_manager, + stream_manager=stream_manager, + output_type=output_type, + include_metadata=include_metadata, + **kwargs, + ) @classmethod - def tail_execution(cls, execution_manager, stream_manager, execution, output_type=None, - include_metadata=False, **kwargs): + def tail_execution( + cls, + execution_manager, + stream_manager, + execution, + output_type=None, + include_metadata=False, + **kwargs, + ): execution_id = str(execution.id) # Indicates if the execution we are tailing is a child execution in a workflow context = cls.get_normalized_context_execution_task_event(execution.__dict__) - has_parent_attribute = bool(getattr(execution, 'parent', None)) - has_parent_execution_id = bool(context['parent_execution_id']) + has_parent_attribute = bool(getattr(execution, "parent", None)) + has_parent_execution_id = bool(context["parent_execution_id"]) - is_tailing_execution_child_execution = bool(has_parent_attribute or - has_parent_execution_id) + is_tailing_execution_child_execution = bool( + has_parent_attribute or has_parent_execution_id + ) # Note: For non-workflow actions child_execution_id always matches parent_execution_id so # we don't need to do any other checks to determine if executions represents a workflow @@ -1484,10 +1849,14 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ # NOTE: This doesn't recurse down into child executions if user is tailing a workflow # execution if execution.status in LIVEACTION_COMPLETED_STATES: - output = execution_manager.get_output(execution_id=execution_id, - output_type=output_type) + output = execution_manager.get_output( + execution_id=execution_id, output_type=output_type + ) print(output) - print('Execution %s has completed (status=%s).' % (execution_id, execution.status)) + print( + "Execution %s has completed (status=%s)." + % (execution_id, execution.status) + ) return # We keep track of all the workflow executions which could contain children. @@ -1497,29 +1866,27 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ # Retrieve parent execution object so we can keep track of any existing children # executions (only applies to already running executions). - filters = { - 'params': { - 'include_attributes': 'id,children' - } - } + filters = {"params": {"include_attributes": "id,children"}} execution = execution_manager.get_by_id(id=execution_id, **filters) - children_execution_ids = getattr(execution, 'children', []) + children_execution_ids = getattr(execution, "children", []) workflow_execution_ids.update(children_execution_ids) - events = ['st2.execution__update', 'st2.execution.output__create'] - for event in stream_manager.listen(events, - end_execution_id=execution_id, - end_event="st2.execution__update", - **kwargs): - status = event.get('status', None) + events = ["st2.execution__update", "st2.execution.output__create"] + for event in stream_manager.listen( + events, + end_execution_id=execution_id, + end_event="st2.execution__update", + **kwargs, + ): + status = event.get("status", None) is_execution_event = status is not None if is_execution_event: context = cls.get_normalized_context_execution_task_event(event) - task_execution_id = context['execution_id'] - task_name = context['task_name'] - task_parent_execution_id = context['parent_execution_id'] + task_execution_id = context["execution_id"] + task_name = context["task_name"] + task_parent_execution_id = context["parent_execution_id"] # An execution is considered a child execution if it has parent execution id is_child_execution = bool(task_parent_execution_id) @@ -1536,14 +1903,18 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ if is_child_execution: if status == LIVEACTION_STATUS_RUNNING: - print('Child execution (task=%s) %s has started.' % (task_name, - task_execution_id)) - print('') + print( + "Child execution (task=%s) %s has started." + % (task_name, task_execution_id) + ) + print("") continue elif status in LIVEACTION_COMPLETED_STATES: - print('') - print('Child execution (task=%s) %s has finished (status=%s).' % - (task_name, task_execution_id, status)) + print("") + print( + "Child execution (task=%s) %s has finished (status=%s)." + % (task_name, task_execution_id, status) + ) if is_tailing_execution_child_execution: # User is tailing a child execution inside a workflow, stop the command. @@ -1556,56 +1927,69 @@ def tail_execution(cls, execution_manager, stream_manager, execution, output_typ else: # NOTE: In some situations execution update event with "running" status is # dispatched twice so we ignore any duplicated events - if status == LIVEACTION_STATUS_RUNNING and not event.get('children', []): - print('Execution %s has started.' % (execution_id)) - print('') + if status == LIVEACTION_STATUS_RUNNING and not event.get( + "children", [] + ): + print("Execution %s has started." % (execution_id)) + print("") continue elif status in LIVEACTION_COMPLETED_STATES: # Bail out once parent execution has finished - print('') - print('Execution %s has completed (status=%s).' % (execution_id, status)) + print("") + print( + "Execution %s has completed (status=%s)." + % (execution_id, status) + ) break else: # We don't care about other execution events continue # Ignore events for executions which don't belong to the one we are tailing - event_execution_id = event['execution_id'] + event_execution_id = event["execution_id"] if event_execution_id not in workflow_execution_ids: continue # Filter on output_type if provided - event_output_type = event.get('output_type', None) - if output_type != 'all' and output_type and (event_output_type != output_type): + event_output_type = event.get("output_type", None) + if ( + output_type != "all" + and output_type + and (event_output_type != output_type) + ): continue if include_metadata: - sys.stdout.write('[%s][%s] %s' % (event['timestamp'], event['output_type'], - event['data'])) + sys.stdout.write( + "[%s][%s] %s" + % (event["timestamp"], event["output_type"], event["data"]) + ) else: - sys.stdout.write(event['data']) + sys.stdout.write(event["data"]) @classmethod def get_normalized_context_execution_task_event(cls, event): """ Return a dictionary with normalized context attributes for execution event or object. """ - context = event.get('context', {}) - - result = { - 'parent_execution_id': None, - 'execution_id': None, - 'task_name': None - } - - if 'orquesta' in context: - result['parent_execution_id'] = context.get('parent', {}).get('execution_id', None) - result['execution_id'] = event['id'] - result['task_name'] = context.get('orquesta', {}).get('task_name', 'unknown') + context = event.get("context", {}) + + result = {"parent_execution_id": None, "execution_id": None, "task_name": None} + + if "orquesta" in context: + result["parent_execution_id"] = context.get("parent", {}).get( + "execution_id", None + ) + result["execution_id"] = event["id"] + result["task_name"] = context.get("orquesta", {}).get( + "task_name", "unknown" + ) else: # Action chain workflow - result['parent_execution_id'] = context.get('parent', {}).get('execution_id', None) - result['execution_id'] = event['id'] - result['task_name'] = context.get('chain', {}).get('name', 'unknown') + result["parent_execution_id"] = context.get("parent", {}).get( + "execution_id", None + ) + result["execution_id"] = event["id"] + result["task_name"] = context.get("chain", {}).get("name", "unknown") return result diff --git a/st2client/st2client/commands/action_alias.py b/st2client/st2client/commands/action_alias.py index 32a65776cce..d6f5fbcfc1e 100644 --- a/st2client/st2client/commands/action_alias.py +++ b/st2client/st2client/commands/action_alias.py @@ -22,63 +22,87 @@ from st2client.formatters import table -__all__ = [ - 'ActionAliasBranch', - 'ActionAliasMatchCommand', - 'ActionAliasExecuteCommand' -] +__all__ = ["ActionAliasBranch", "ActionAliasMatchCommand", "ActionAliasExecuteCommand"] class ActionAliasBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(ActionAliasBranch, self).__init__( - ActionAlias, description, app, subparsers, - parent_parser=parent_parser, read_only=False, - commands={ - 'list': ActionAliasListCommand, - 'get': ActionAliasGetCommand - }) - - self.commands['match'] = ActionAliasMatchCommand( - self.resource, self.app, self.subparsers, - add_help=True) - self.commands['execute'] = ActionAliasExecuteCommand( - self.resource, self.app, self.subparsers, - add_help=True) + ActionAlias, + description, + app, + subparsers, + parent_parser=parent_parser, + read_only=False, + commands={"list": ActionAliasListCommand, "get": ActionAliasGetCommand}, + ) + + self.commands["match"] = ActionAliasMatchCommand( + self.resource, self.app, self.subparsers, add_help=True + ) + self.commands["execute"] = ActionAliasExecuteCommand( + self.resource, self.app, self.subparsers, add_help=True + ) class ActionAliasListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['ref', 'pack', 'description', 'enabled'] + display_attributes = ["ref", "pack", "description", "enabled"] class ActionAliasGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'description', - 'enabled', 'action_ref', 'formats'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "description", + "enabled", + "action_ref", + "formats", + ] class ActionAliasMatchCommand(resource.ResourceCommand): - display_attributes = ['name', 'description'] + display_attributes = ["name", "description"] def __init__(self, resource, *args, **kwargs): super(ActionAliasMatchCommand, self).__init__( - resource, 'match', - 'Get the %s that match the command text.' % - resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('match_text', - metavar='command', - help=('Get the %s that match the command text.' % - resource.get_display_name().lower())) - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + resource, + "match", + "Get the %s that match the command text." + % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "match_text", + metavar="command", + help=( + "Get the %s that match the command text." + % resource.get_display_name().lower() + ), + ) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -90,40 +114,62 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class ActionAliasExecuteCommand(resource.ResourceCommand): - display_attributes = ['name'] + display_attributes = ["name"] def __init__(self, resource, *args, **kwargs): super(ActionAliasExecuteCommand, self).__init__( - resource, 'execute', - ('Execute the command text by finding a matching %s.' % - resource.get_display_name().lower()), *args, **kwargs) - - self.parser.add_argument('command_text', - metavar='command', - help=('Execute the command text by finding a matching %s.' % - resource.get_display_name().lower())) - self.parser.add_argument('-u', '--user', type=str, default=None, - help='User under which to run the action (admins only).') + resource, + "execute", + ( + "Execute the command text by finding a matching %s." + % resource.get_display_name().lower() + ), + *args, + **kwargs, + ) + + self.parser.add_argument( + "command_text", + metavar="command", + help=( + "Execute the command text by finding a matching %s." + % resource.get_display_name().lower() + ), + ) + self.parser.add_argument( + "-u", + "--user", + type=str, + default=None, + help="User under which to run the action (admins only).", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): payload = core.Resource() payload.command = args.command_text payload.user = args.user or "" - payload.source_channel = 'cli' + payload.source_channel = "cli" - alias_execution_mgr = self.app.client.managers['ActionAliasExecution'] + alias_execution_mgr = self.app.client.managers["ActionAliasExecution"] execution = alias_execution_mgr.match_and_execute(payload) return execution def run_and_print(self, args, **kwargs): execution = self.run(args, **kwargs) - print("Matching Action-alias: '%s'" % execution.actionalias['ref']) - print("To get the results, execute:\n st2 execution get %s" % - (execution.execution['id'])) + print("Matching Action-alias: '%s'" % execution.actionalias["ref"]) + print( + "To get the results, execute:\n st2 execution get %s" + % (execution.execution["id"]) + ) diff --git a/st2client/st2client/commands/auth.py b/st2client/st2client/commands/auth.py index 40066d1a5e9..5b0507f324d 100644 --- a/st2client/st2client/commands/auth.py +++ b/st2client/st2client/commands/auth.py @@ -39,36 +39,54 @@ class TokenCreateCommand(resource.ResourceCommand): - display_attributes = ['user', 'token', 'expiry'] + display_attributes = ["user", "token", "expiry"] def __init__(self, resource, *args, **kwargs): - kwargs['has_token_opt'] = False + kwargs["has_token_opt"] = False super(TokenCreateCommand, self).__init__( - resource, kwargs.pop('name', 'create'), - 'Authenticate user and acquire access token.', - *args, **kwargs) - - self.parser.add_argument('username', - help='Name of the user to authenticate.') - - self.parser.add_argument('-p', '--password', dest='password', - help='Password for the user. If password is not provided, ' - 'it will be prompted for.') - self.parser.add_argument('-l', '--ttl', type=int, dest='ttl', default=None, - help='The life span of the token in seconds. ' - 'Max TTL configured by the admin supersedes this.') - self.parser.add_argument('-t', '--only-token', action='store_true', dest='only_token', - default=False, - help='On successful authentication, print only token to the ' - 'console.') + resource, + kwargs.pop("name", "create"), + "Authenticate user and acquire access token.", + *args, + **kwargs, + ) + + self.parser.add_argument("username", help="Name of the user to authenticate.") + + self.parser.add_argument( + "-p", + "--password", + dest="password", + help="Password for the user. If password is not provided, " + "it will be prompted for.", + ) + self.parser.add_argument( + "-l", + "--ttl", + type=int, + dest="ttl", + default=None, + help="The life span of the token in seconds. " + "Max TTL configured by the admin supersedes this.", + ) + self.parser.add_argument( + "-t", + "--only-token", + action="store_true", + dest="only_token", + default=False, + help="On successful authentication, print only token to the " "console.", + ) def run(self, args, **kwargs): if not args.password: args.password = getpass.getpass() instance = self.resource(ttl=args.ttl) if args.ttl else self.resource() - return self.manager.create(instance, auth=(args.username, args.password), **kwargs) + return self.manager.create( + instance, auth=(args.username, args.password), **kwargs + ) def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) @@ -76,35 +94,57 @@ def run_and_print(self, args, **kwargs): if args.only_token: print(instance.token) else: - self.print_output(instance, table.PropertyValueTable, - attributes=self.display_attributes, json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=self.display_attributes, + json=args.json, + yaml=args.yaml, + ) class LoginCommand(resource.ResourceCommand): - display_attributes = ['user', 'token', 'expiry'] + display_attributes = ["user", "token", "expiry"] def __init__(self, resource, *args, **kwargs): - kwargs['has_token_opt'] = False + kwargs["has_token_opt"] = False super(LoginCommand, self).__init__( - resource, kwargs.pop('name', 'create'), - 'Authenticate user, acquire access token, and update CLI config directory', - *args, **kwargs) - - self.parser.add_argument('username', - help='Name of the user to authenticate.') - - self.parser.add_argument('-p', '--password', dest='password', - help='Password for the user. If password is not provided, ' - 'it will be prompted for.') - self.parser.add_argument('-l', '--ttl', type=int, dest='ttl', default=None, - help='The life span of the token in seconds. ' - 'Max TTL configured by the admin supersedes this.') - self.parser.add_argument('-w', '--write-password', action='store_true', default=False, - dest='write_password', - help='Write the password in plain text to the config file ' - '(default is to omit it)') + resource, + kwargs.pop("name", "create"), + "Authenticate user, acquire access token, and update CLI config directory", + *args, + **kwargs, + ) + + self.parser.add_argument("username", help="Name of the user to authenticate.") + + self.parser.add_argument( + "-p", + "--password", + dest="password", + help="Password for the user. If password is not provided, " + "it will be prompted for.", + ) + self.parser.add_argument( + "-l", + "--ttl", + type=int, + dest="ttl", + default=None, + help="The life span of the token in seconds. " + "Max TTL configured by the admin supersedes this.", + ) + self.parser.add_argument( + "-w", + "--write-password", + action="store_true", + default=False, + dest="write_password", + help="Write the password in plain text to the config file " + "(default is to omit it)", + ) def run(self, args, **kwargs): @@ -122,7 +162,9 @@ def run(self, args, **kwargs): config_file = config_parser.ST2_CONFIG_PATH # Retrieve token - manager = self.manager.create(instance, auth=(args.username, args.password), **kwargs) + manager = self.manager.create( + instance, auth=(args.username, args.password), **kwargs + ) cli._cache_auth_token(token_obj=manager) # Update existing configuration with new credentials @@ -130,18 +172,18 @@ def run(self, args, **kwargs): config.read(config_file) # Modify config (and optionally populate with password) - if not config.has_section('credentials'): - config.add_section('credentials') + if not config.has_section("credentials"): + config.add_section("credentials") - config.set('credentials', 'username', args.username) + config.set("credentials", "username", args.username) if args.write_password: - config.set('credentials', 'password', args.password) + config.set("credentials", "password", args.password) else: # Remove any existing password from config - config.remove_option('credentials', 'password') + config.remove_option("credentials", "password") config_existed = os.path.exists(config_file) - with open(config_file, 'w') as cfg_file_out: + with open(config_file, "w") as cfg_file_out: config.write(cfg_file_out) # If we created the config file, correct the permissions if not config_existed: @@ -156,35 +198,44 @@ def run_and_print(self, args, **kwargs): if self.app.client.debug: raise - raise Exception('Failed to log in as %s: %s' % (args.username, six.text_type(e))) + raise Exception( + "Failed to log in as %s: %s" % (args.username, six.text_type(e)) + ) - print('Logged in as %s' % (args.username)) + print("Logged in as %s" % (args.username)) if not args.write_password: # Note: Client can't depend and import from common so we need to hard-code this # default value token_expire_hours = 24 - print('') - print('Note: You didn\'t use --write-password option so the password hasn\'t been ' - 'stored in the client config and you will need to login again in %s hours when ' - 'the auth token expires.' % (token_expire_hours)) - print('As an alternative, you can run st2 login command with the "--write-password" ' - 'flag, but keep it mind this will cause it to store the password in plain-text ' - 'in the client config file (~/.st2/config).') + print("") + print( + "Note: You didn't use --write-password option so the password hasn't been " + "stored in the client config and you will need to login again in %s hours when " + "the auth token expires." % (token_expire_hours) + ) + print( + 'As an alternative, you can run st2 login command with the "--write-password" ' + "flag, but keep it mind this will cause it to store the password in plain-text " + "in the client config file (~/.st2/config)." + ) class WhoamiCommand(resource.ResourceCommand): - display_attributes = ['user', 'token', 'expiry'] + display_attributes = ["user", "token", "expiry"] def __init__(self, resource, *args, **kwargs): - kwargs['has_token_opt'] = False + kwargs["has_token_opt"] = False super(WhoamiCommand, self).__init__( - resource, kwargs.pop('name', 'create'), - 'Display the currently authenticated user', - *args, **kwargs) + resource, + kwargs.pop("name", "create"), + "Display the currently authenticated user", + *args, + **kwargs, + ) def run(self, args, **kwargs): user_info = self.app.client.get_user_info(**kwargs) @@ -194,119 +245,157 @@ def run_and_print(self, args, **kwargs): try: user_info = self.run(args, **kwargs) except Exception as e: - response = getattr(e, 'response', None) - status_code = getattr(response, 'status_code', None) - is_unathorized_error = (status_code == http_client.UNAUTHORIZED) + response = getattr(e, "response", None) + status_code = getattr(response, "status_code", None) + is_unathorized_error = status_code == http_client.UNAUTHORIZED if response and is_unathorized_error: - print('Not authenticated') + print("Not authenticated") else: - print('Unable to retrieve currently logged-in user') + print("Unable to retrieve currently logged-in user") if self.app.client.debug: raise return - print('Currently logged in as "%s".' % (user_info['username'])) - print('') - print('Authentication method: %s' % (user_info['authentication']['method'])) + print('Currently logged in as "%s".' % (user_info["username"])) + print("") + print("Authentication method: %s" % (user_info["authentication"]["method"])) - if user_info['authentication']['method'] == 'authentication token': - print('Authentication token expire time: %s' % - (user_info['authentication']['token_expire'])) + if user_info["authentication"]["method"] == "authentication token": + print( + "Authentication token expire time: %s" + % (user_info["authentication"]["token_expire"]) + ) - print('') - print('RBAC:') - print(' - Enabled: %s' % (user_info['rbac']['enabled'])) - print(' - Roles: %s' % (', '.join(user_info['rbac']['roles']))) + print("") + print("RBAC:") + print(" - Enabled: %s" % (user_info["rbac"]["enabled"])) + print(" - Roles: %s" % (", ".join(user_info["rbac"]["roles"]))) class ApiKeyBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(ApiKeyBranch, self).__init__( - models.ApiKey, description, app, subparsers, + models.ApiKey, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': ApiKeyListCommand, - 'get': ApiKeyGetCommand, - 'create': ApiKeyCreateCommand, - 'update': NoopCommand, - 'delete': ApiKeyDeleteCommand - }) - - self.commands['enable'] = ApiKeyEnableCommand(self.resource, self.app, self.subparsers) - self.commands['disable'] = ApiKeyDisableCommand(self.resource, self.app, self.subparsers) - self.commands['load'] = ApiKeyLoadCommand(self.resource, self.app, self.subparsers) + "list": ApiKeyListCommand, + "get": ApiKeyGetCommand, + "create": ApiKeyCreateCommand, + "update": NoopCommand, + "delete": ApiKeyDeleteCommand, + }, + ) + + self.commands["enable"] = ApiKeyEnableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["disable"] = ApiKeyDisableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["load"] = ApiKeyLoadCommand( + self.resource, self.app, self.subparsers + ) class ApiKeyListCommand(resource.ResourceListCommand): - detail_display_attributes = ['all'] - display_attributes = ['id', 'user', 'metadata'] + detail_display_attributes = ["all"] + display_attributes = ["id", "user", "metadata"] def __init__(self, resource, *args, **kwargs): super(ApiKeyListCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-u', '--user', type=str, - help='Only return ApiKeys belonging to the provided user') - self.parser.add_argument('-d', '--detail', action='store_true', - help='Full list of attributes.') - self.parser.add_argument('--show-secrets', action='store_true', - help='Full list of attributes.') + self.parser.add_argument( + "-u", + "--user", + type=str, + help="Only return ApiKeys belonging to the provided user", + ) + self.parser.add_argument( + "-d", "--detail", action="store_true", help="Full list of attributes." + ) + self.parser.add_argument( + "--show-secrets", action="store_true", help="Full list of attributes." + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): filters = {} - filters['user'] = args.user + filters["user"] = args.user filters.update(**kwargs) # show_secrets is not a filter but a query param. There is some special # handling for filters in the get method which reuqires this odd hack. if args.show_secrets: - params = filters.get('params', {}) - params['show_secrets'] = True - filters['params'] = params + params = filters.get("params", {}) + params["show_secrets"] = True + filters["params"] = params return self.manager.get_all(**filters) def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) attr = self.detail_display_attributes if args.detail else args.attr - self.print_output(instances, table.MultiColumnTable, - attributes=attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class ApiKeyGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'user', 'metadata'] + display_attributes = ["all"] + attribute_display_order = ["id", "user", "metadata"] - pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK + pk_argument_name = "key_or_id" # name of the attribute which stores resource PK class ApiKeyCreateCommand(resource.ResourceCommand): - def __init__(self, resource, *args, **kwargs): super(ApiKeyCreateCommand, self).__init__( - resource, 'create', 'Create a new %s.' % resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('-u', '--user', type=str, - help='User for which to create API Keys.', - default='') - self.parser.add_argument('-m', '--metadata', type=json.loads, - help='Optional metadata to associate with the API Keys.', - default={}) - self.parser.add_argument('-k', '--only-key', action='store_true', dest='only_key', - default=False, - help='Only print API Key to the console on creation.') + resource, + "create", + "Create a new %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "-u", + "--user", + type=str, + help="User for which to create API Keys.", + default="", + ) + self.parser.add_argument( + "-m", + "--metadata", + type=json.loads, + help="Optional metadata to associate with the API Keys.", + default={}, + ) + self.parser.add_argument( + "-k", + "--only-key", + action="store_true", + dest="only_key", + default=False, + help="Only print API Key to the console on creation.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): data = {} if args.user: - data['user'] = args.user + data["user"] = args.user if args.metadata: - data['metadata'] = args.metadata + data["metadata"] = args.metadata instance = self.resource.deserialize(data) return self.manager.create(instance, **kwargs) @@ -314,39 +403,59 @@ def run_and_print(self, args, **kwargs): try: instance = self.run(args, **kwargs) if not instance: - raise Exception('Server did not create instance.') + raise Exception("Server did not create instance.") except Exception as e: message = six.text_type(e) - print('ERROR: %s' % (message)) + print("ERROR: %s" % (message)) raise OperationFailureException(message) if args.only_key: print(instance.key) else: - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) class ApiKeyLoadCommand(resource.ResourceCommand): - def __init__(self, resource, *args, **kwargs): super(ApiKeyLoadCommand, self).__init__( - resource, 'load', 'Load %s from a file.' % resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('file', - help=('JSON/YAML file containing the %s(s) to load.' - % resource.get_display_name().lower()), - default='') - - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + resource, + "load", + "Load %s from a file." % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "file", + help=( + "JSON/YAML file containing the %s(s) to load." + % resource.get_display_name().lower() + ), + default="", + ) + + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): resources = resource.load_meta_file(args.file) if not resources: - print('No %s found in %s.' % (self.resource.get_display_name().lower(), args.file)) + print( + "No %s found in %s." + % (self.resource.get_display_name().lower(), args.file) + ) return None if not isinstance(resources, list): resources = [resources] @@ -354,14 +463,14 @@ def run(self, args, **kwargs): for res in resources: # pick only the meaningful properties. data = { - 'user': res['user'], # required - 'key_hash': res['key_hash'], # required - 'metadata': res.get('metadata', {}), - 'enabled': res.get('enabled', False) + "user": res["user"], # required + "key_hash": res["key_hash"], # required + "metadata": res.get("metadata", {}), + "enabled": res.get("enabled", False), } - if 'id' in res: - data['id'] = res['id'] + if "id" in res: + data["id"] = res["id"] instance = self.resource.deserialize(data) @@ -381,19 +490,23 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) if instances: - self.print_output(instances, table.MultiColumnTable, - attributes=ApiKeyListCommand.display_attributes, - widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=ApiKeyListCommand.display_attributes, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class ApiKeyDeleteCommand(resource.ResourceDeleteCommand): - pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK + pk_argument_name = "key_or_id" # name of the attribute which stores resource PK class ApiKeyEnableCommand(resource.ResourceEnableCommand): - pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK + pk_argument_name = "key_or_id" # name of the attribute which stores resource PK class ApiKeyDisableCommand(resource.ResourceDisableCommand): - pk_argument_name = 'key_or_id' # name of the attribute which stores resource PK + pk_argument_name = "key_or_id" # name of the attribute which stores resource PK diff --git a/st2client/st2client/commands/inquiry.py b/st2client/st2client/commands/inquiry.py index 250c86b3d73..d9395a5c548 100644 --- a/st2client/st2client/commands/inquiry.py +++ b/st2client/st2client/commands/inquiry.py @@ -25,60 +25,81 @@ LOG = logging.getLogger(__name__) -DEFAULT_SCOPE = 'system' +DEFAULT_SCOPE = "system" class InquiryBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(InquiryBranch, self).__init__( - Inquiry, description, app, subparsers, - parent_parser=parent_parser, read_only=True, - commands={'list': InquiryListCommand, - 'get': InquiryGetCommand}) + Inquiry, + description, + app, + subparsers, + parent_parser=parent_parser, + read_only=True, + commands={"list": InquiryListCommand, "get": InquiryGetCommand}, + ) # Register extended commands - self.commands['respond'] = InquiryRespondCommand( - self.resource, self.app, self.subparsers) + self.commands["respond"] = InquiryRespondCommand( + self.resource, self.app, self.subparsers + ) class InquiryListCommand(resource.ResourceCommand): # Omitting "schema" and "response", as it doesn't really show up in a table well. # The user can drill into a specific Inquiry to get this - display_attributes = [ - 'id', - 'roles', - 'users', - 'route', - 'ttl' - ] + display_attributes = ["id", "roles", "users", "route", "ttl"] def __init__(self, resource, *args, **kwargs): self.default_limit = 50 super(InquiryListCommand, self).__init__( - resource, 'list', 'Get the list of the %s most recent %s.' % - (self.default_limit, resource.get_plural_display_name().lower()), - *args, **kwargs) + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -87,17 +108,21 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances, count = self.run(args, **kwargs) - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, - yaml=args.yaml) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class InquiryGetCommand(resource.ResourceGetCommand): - pk_argument_name = 'id' - display_attributes = ['id', 'roles', 'users', 'route', 'ttl', 'schema'] + pk_argument_name = "id" + display_attributes = ["id", "roles", "users", "route", "ttl", "schema"] def __init__(self, kv_resource, *args, **kwargs): super(InquiryGetCommand, self).__init__(kv_resource, *args, **kwargs) @@ -109,22 +134,28 @@ def run(self, args, **kwargs): class InquiryRespondCommand(resource.ResourceCommand): - display_attributes = ['id', 'response'] + display_attributes = ["id", "response"] def __init__(self, resource, *args, **kwargs): super(InquiryRespondCommand, self).__init__( - resource, 'respond', - 'Respond to an %s.' % resource.get_display_name().lower(), - *args, **kwargs + resource, + "respond", + "Respond to an %s." % resource.get_display_name().lower(), + *args, + **kwargs, ) - self.parser.add_argument('id', - metavar='id', - help='Inquiry ID') - self.parser.add_argument('-r', '--response', type=str, dest='response', - default=None, - help=('Entire response payload as JSON string ' - '(bypass interactive mode)')) + self.parser.add_argument("id", metavar="id", help="Inquiry ID") + self.parser.add_argument( + "-r", + "--response", + type=str, + dest="response", + default=None, + help=( + "Entire response payload as JSON string " "(bypass interactive mode)" + ), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -135,12 +166,13 @@ def run(self, args, **kwargs): instance.response = json.loads(args.response) else: response = InteractiveForm( - inquiry.schema.get('properties')).initiate_dialog() + inquiry.schema.get("properties") + ).initiate_dialog() instance.response = response - return self.manager.respond(inquiry_id=instance.id, - inquiry_response=instance.response, - **kwargs) + return self.manager.respond( + inquiry_id=instance.id, inquiry_response=instance.response, **kwargs + ) def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) diff --git a/st2client/st2client/commands/keyvalue.py b/st2client/st2client/commands/keyvalue.py index 9c0d06f806a..e87f6afa358 100644 --- a/st2client/st2client/commands/keyvalue.py +++ b/st2client/st2client/commands/keyvalue.py @@ -31,83 +31,125 @@ LOG = logging.getLogger(__name__) -DEFAULT_LIST_SCOPE = 'all' -DEFAULT_GET_SCOPE = 'system' -DEFAULT_CUD_SCOPE = 'system' +DEFAULT_LIST_SCOPE = "all" +DEFAULT_GET_SCOPE = "system" +DEFAULT_CUD_SCOPE = "system" class KeyValuePairBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(KeyValuePairBranch, self).__init__( - KeyValuePair, description, app, subparsers, + KeyValuePair, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': KeyValuePairListCommand, - 'get': KeyValuePairGetCommand, - 'delete': KeyValuePairDeleteCommand, - 'create': NoopCommand, - 'update': NoopCommand - }) + "list": KeyValuePairListCommand, + "get": KeyValuePairGetCommand, + "delete": KeyValuePairDeleteCommand, + "create": NoopCommand, + "update": NoopCommand, + }, + ) # Registers extended commands - self.commands['set'] = KeyValuePairSetCommand(self.resource, self.app, - self.subparsers) - self.commands['load'] = KeyValuePairLoadCommand( - self.resource, self.app, self.subparsers) - self.commands['delete_by_prefix'] = KeyValuePairDeleteByPrefixCommand( - self.resource, self.app, self.subparsers) + self.commands["set"] = KeyValuePairSetCommand( + self.resource, self.app, self.subparsers + ) + self.commands["load"] = KeyValuePairLoadCommand( + self.resource, self.app, self.subparsers + ) + self.commands["delete_by_prefix"] = KeyValuePairDeleteByPrefixCommand( + self.resource, self.app, self.subparsers + ) # Remove unsupported commands # TODO: Refactor parent class and make it nicer - del self.commands['create'] - del self.commands['update'] + del self.commands["create"] + del self.commands["update"] class KeyValuePairListCommand(resource.ResourceTableCommand): - display_attributes = ['name', 'value', 'secret', 'encrypted', 'scope', 'user', - 'expire_timestamp'] + display_attributes = [ + "name", + "value", + "secret", + "encrypted", + "scope", + "user", + "expire_timestamp", + ] attribute_transform_functions = { - 'expire_timestamp': format_isodate_for_user_timezone, + "expire_timestamp": format_isodate_for_user_timezone, } def __init__(self, resource, *args, **kwargs): self.default_limit = 50 - super(KeyValuePairListCommand, self).__init__(resource, 'list', - 'Get the list of the %s most recent %s.' % - (self.default_limit, - resource.get_plural_display_name().lower()), - *args, **kwargs) + super(KeyValuePairListCommand, self).__init__( + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() # Filter options - self.parser.add_argument('--prefix', help=('Only return values with names starting with ' - 'the provided prefix.')) - self.parser.add_argument('-d', '--decrypt', action='store_true', - help='Decrypt secrets and displays plain text.') - self.parser.add_argument('-s', '--scope', default=DEFAULT_LIST_SCOPE, dest='scope', - help='Scope item is under. Example: "user".') - self.parser.add_argument('-u', '--user', dest='user', default=None, - help='User for user scoped items (admin only).') - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) + self.parser.add_argument( + "--prefix", + help=( + "Only return values with names starting with " "the provided prefix." + ), + ) + self.parser.add_argument( + "-d", + "--decrypt", + action="store_true", + help="Decrypt secrets and displays plain text.", + ) + self.parser.add_argument( + "-s", + "--scope", + default=DEFAULT_LIST_SCOPE, + dest="scope", + help='Scope item is under. Example: "user".', + ) + self.parser.add_argument( + "-u", + "--user", + dest="user", + default=None, + help="User for user scoped items (admin only).", + ) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.prefix: - kwargs['prefix'] = args.prefix + kwargs["prefix"] = args.prefix - decrypt = getattr(args, 'decrypt', False) - kwargs['params'] = {'decrypt': str(decrypt).lower()} - scope = getattr(args, 'scope', DEFAULT_LIST_SCOPE) - kwargs['params']['scope'] = scope + decrypt = getattr(args, "decrypt", False) + kwargs["params"] = {"decrypt": str(decrypt).lower()} + scope = getattr(args, "scope", DEFAULT_LIST_SCOPE) + kwargs["params"]["scope"] = scope if args.user: - kwargs['params']['user'] = args.user - kwargs['params']['limit'] = args.last + kwargs["params"]["user"] = args.user + kwargs["params"]["limit"] = args.last return self.manager.query_with_count(**kwargs) @@ -115,73 +157,124 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances, count = self.run(args, **kwargs) if args.json or args.yaml: - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) else: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + attribute_transform_functions=self.attribute_transform_functions, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class KeyValuePairGetCommand(resource.ResourceGetCommand): - pk_argument_name = 'name' - display_attributes = ['name', 'value', 'secret', 'encrypted', 'scope', 'expire_timestamp'] + pk_argument_name = "name" + display_attributes = [ + "name", + "value", + "secret", + "encrypted", + "scope", + "expire_timestamp", + ] def __init__(self, kv_resource, *args, **kwargs): super(KeyValuePairGetCommand, self).__init__(kv_resource, *args, **kwargs) - self.parser.add_argument('-d', '--decrypt', action='store_true', - help='Decrypt secret if encrypted and show plain text.') - self.parser.add_argument('-s', '--scope', default=DEFAULT_GET_SCOPE, dest='scope', - help='Scope item is under. Example: "user".') + self.parser.add_argument( + "-d", + "--decrypt", + action="store_true", + help="Decrypt secret if encrypted and show plain text.", + ) + self.parser.add_argument( + "-s", + "--scope", + default=DEFAULT_GET_SCOPE, + dest="scope", + help='Scope item is under. Example: "user".', + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): resource_name = getattr(args, self.pk_argument_name, None) - decrypt = getattr(args, 'decrypt', False) - scope = getattr(args, 'scope', DEFAULT_GET_SCOPE) - kwargs['params'] = {'decrypt': str(decrypt).lower()} - kwargs['params']['scope'] = scope + decrypt = getattr(args, "decrypt", False) + scope = getattr(args, "scope", DEFAULT_GET_SCOPE) + kwargs["params"] = {"decrypt": str(decrypt).lower()} + kwargs["params"]["scope"] = scope return self.get_resource_by_id(id=resource_name, **kwargs) class KeyValuePairSetCommand(resource.ResourceCommand): - display_attributes = ['name', 'value', 'scope', 'expire_timestamp'] + display_attributes = ["name", "value", "scope", "expire_timestamp"] def __init__(self, resource, *args, **kwargs): super(KeyValuePairSetCommand, self).__init__( - resource, 'set', - 'Set an existing %s.' % resource.get_display_name().lower(), - *args, **kwargs + resource, + "set", + "Set an existing %s." % resource.get_display_name().lower(), + *args, + **kwargs, ) # --encrypt and --encrypted options are mutually exclusive. # --encrypt implies provided value is plain text and should be encrypted whereas # --encrypted implies value is already encrypted and should be treated as-is. encryption_group = self.parser.add_mutually_exclusive_group() - encryption_group.add_argument('-e', '--encrypt', dest='secret', - action='store_true', - help='Encrypt value before saving.') - encryption_group.add_argument('--encrypted', dest='encrypted', - action='store_true', - help=('Value provided is already encrypted with the ' - 'instance crypto key and should be stored as-is.')) - - self.parser.add_argument('name', - metavar='name', - help='Name of the key value pair.') - self.parser.add_argument('value', help='Value paired with the key.') - self.parser.add_argument('-l', '--ttl', dest='ttl', type=int, default=None, - help='TTL (in seconds) for this value.') - self.parser.add_argument('-s', '--scope', dest='scope', default=DEFAULT_CUD_SCOPE, - help='Specify the scope under which you want ' + - 'to place the item.') - self.parser.add_argument('-u', '--user', dest='user', default=None, - help='User for user scoped items (admin only).') + encryption_group.add_argument( + "-e", + "--encrypt", + dest="secret", + action="store_true", + help="Encrypt value before saving.", + ) + encryption_group.add_argument( + "--encrypted", + dest="encrypted", + action="store_true", + help=( + "Value provided is already encrypted with the " + "instance crypto key and should be stored as-is." + ), + ) + + self.parser.add_argument( + "name", metavar="name", help="Name of the key value pair." + ) + self.parser.add_argument("value", help="Value paired with the key.") + self.parser.add_argument( + "-l", + "--ttl", + dest="ttl", + type=int, + default=None, + help="TTL (in seconds) for this value.", + ) + self.parser.add_argument( + "-s", + "--scope", + dest="scope", + default=DEFAULT_CUD_SCOPE, + help="Specify the scope under which you want " + "to place the item.", + ) + self.parser.add_argument( + "-u", + "--user", + dest="user", + default=None, + help="User for user scoped items (admin only).", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -205,35 +298,49 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) - self.print_output(instance, table.PropertyValueTable, - attributes=self.display_attributes, json=args.json, - yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=self.display_attributes, + json=args.json, + yaml=args.yaml, + ) class KeyValuePairDeleteCommand(resource.ResourceDeleteCommand): - pk_argument_name = 'name' + pk_argument_name = "name" def __init__(self, resource, *args, **kwargs): super(KeyValuePairDeleteCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-s', '--scope', dest='scope', default=DEFAULT_CUD_SCOPE, - help='Specify the scope under which you want ' + - 'to place the item.') - self.parser.add_argument('-u', '--user', dest='user', default=None, - help='User for user scoped items (admin only).') + self.parser.add_argument( + "-s", + "--scope", + dest="scope", + default=DEFAULT_CUD_SCOPE, + help="Specify the scope under which you want " + "to place the item.", + ) + self.parser.add_argument( + "-u", + "--user", + dest="user", + default=None, + help="User for user scoped items (admin only).", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): resource_id = getattr(args, self.pk_argument_name, None) - scope = getattr(args, 'scope', DEFAULT_CUD_SCOPE) - kwargs['params'] = {} - kwargs['params']['scope'] = scope - kwargs['params']['user'] = args.user + scope = getattr(args, "scope", DEFAULT_CUD_SCOPE) + kwargs["params"] = {} + kwargs["params"]["scope"] = scope + kwargs["params"]["user"] = args.user instance = self.get_resource(resource_id, **kwargs) if not instance: - raise resource.ResourceNotFoundError('KeyValuePair with id "%s" not found' - % resource_id) + raise resource.ResourceNotFoundError( + 'KeyValuePair with id "%s" not found' % resource_id + ) instance.id = resource_id # TODO: refactor and get rid of id self.manager.delete(instance, **kwargs) @@ -244,14 +351,23 @@ class KeyValuePairDeleteByPrefixCommand(resource.ResourceCommand): Commands which delete all the key value pairs which match the provided prefix. """ + def __init__(self, resource, *args, **kwargs): - super(KeyValuePairDeleteByPrefixCommand, self).__init__(resource, 'delete_by_prefix', - 'Delete KeyValue pairs which \ - match the provided prefix', - *args, **kwargs) + super(KeyValuePairDeleteByPrefixCommand, self).__init__( + resource, + "delete_by_prefix", + "Delete KeyValue pairs which \ + match the provided prefix", + *args, + **kwargs, + ) - self.parser.add_argument('-p', '--prefix', required=True, - help='Name prefix (e.g. twitter.TwitterSensor:)') + self.parser.add_argument( + "-p", + "--prefix", + required=True, + help="Name prefix (e.g. twitter.TwitterSensor:)", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -276,27 +392,39 @@ def run_and_print(self, args, **kwargs): deleted = self.run(args, **kwargs) key_ids = [key_pair.id for key_pair in deleted] - print('Deleted %s keys' % (len(deleted))) - print('Deleted key ids: %s' % (', '.join(key_ids))) + print("Deleted %s keys" % (len(deleted))) + print("Deleted key ids: %s" % (", ".join(key_ids))) class KeyValuePairLoadCommand(resource.ResourceCommand): - pk_argument_name = 'name' - display_attributes = ['name', 'value'] + pk_argument_name = "name" + display_attributes = ["name", "value"] def __init__(self, resource, *args, **kwargs): - help_text = ('Load a list of %s from file.' % - resource.get_plural_display_name().lower()) - super(KeyValuePairLoadCommand, self).__init__(resource, 'load', - help_text, *args, **kwargs) - - self.parser.add_argument('-c', '--convert', action='store_true', - help=('Convert non-string types (hash, array, boolean,' - ' int, float) to a JSON string before loading it' - ' into the datastore.')) + help_text = ( + "Load a list of %s from file." % resource.get_plural_display_name().lower() + ) + super(KeyValuePairLoadCommand, self).__init__( + resource, "load", help_text, *args, **kwargs + ) + + self.parser.add_argument( + "-c", + "--convert", + action="store_true", + help=( + "Convert non-string types (hash, array, boolean," + " int, float) to a JSON string before loading it" + " into the datastore." + ), + ) self.parser.add_argument( - 'file', help=('JSON/YAML file containing the %s(s) to load' - % resource.get_plural_display_name().lower())) + "file", + help=( + "JSON/YAML file containing the %s(s) to load" + % resource.get_plural_display_name().lower() + ), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -318,15 +446,15 @@ def run(self, args, **kwargs): for item in kvps: # parse required KeyValuePair properties - name = item['name'] - value = item['value'] + name = item["name"] + value = item["value"] # parse optional KeyValuePair properties - scope = item.get('scope', DEFAULT_CUD_SCOPE) - user = item.get('user', None) - encrypted = item.get('encrypted', False) - secret = item.get('secret', False) - ttl = item.get('ttl', None) + scope = item.get("scope", DEFAULT_CUD_SCOPE) + user = item.get("user", None) + encrypted = item.get("encrypted", False) + secret = item.get("secret", False) + ttl = item.get("ttl", None) # if the value is not a string, convert it to JSON # all keys in the datastore must strings @@ -334,10 +462,15 @@ def run(self, args, **kwargs): if args.convert: value = json.dumps(value) else: - raise ValueError(("Item '%s' has a value that is not a string." - " Either pass in the -c/--convert option to convert" - " non-string types to JSON strings automatically, or" - " convert the data to a string in the file") % name) + raise ValueError( + ( + "Item '%s' has a value that is not a string." + " Either pass in the -c/--convert option to convert" + " non-string types to JSON strings automatically, or" + " convert the data to a string in the file" + ) + % name + ) # create the KeyValuePair instance instance = KeyValuePair() @@ -368,7 +501,10 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - attributes=['name', 'value', 'secret', 'scope', 'user', 'ttl'], - json=args.json, - yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=["name", "value", "secret", "scope", "user", "ttl"], + json=args.json, + yaml=args.yaml, + ) diff --git a/st2client/st2client/commands/pack.py b/st2client/st2client/commands/pack.py index 827db663df7..8d05fc88dc6 100644 --- a/st2client/st2client/commands/pack.py +++ b/st2client/st2client/commands/pack.py @@ -34,43 +34,56 @@ from st2client.utils import interactive -LIVEACTION_STATUS_REQUESTED = 'requested' -LIVEACTION_STATUS_SCHEDULED = 'scheduled' -LIVEACTION_STATUS_DELAYED = 'delayed' -LIVEACTION_STATUS_RUNNING = 'running' -LIVEACTION_STATUS_SUCCEEDED = 'succeeded' -LIVEACTION_STATUS_FAILED = 'failed' -LIVEACTION_STATUS_TIMED_OUT = 'timeout' -LIVEACTION_STATUS_ABANDONED = 'abandoned' -LIVEACTION_STATUS_CANCELING = 'canceling' -LIVEACTION_STATUS_CANCELED = 'canceled' +LIVEACTION_STATUS_REQUESTED = "requested" +LIVEACTION_STATUS_SCHEDULED = "scheduled" +LIVEACTION_STATUS_DELAYED = "delayed" +LIVEACTION_STATUS_RUNNING = "running" +LIVEACTION_STATUS_SUCCEEDED = "succeeded" +LIVEACTION_STATUS_FAILED = "failed" +LIVEACTION_STATUS_TIMED_OUT = "timeout" +LIVEACTION_STATUS_ABANDONED = "abandoned" +LIVEACTION_STATUS_CANCELING = "canceling" +LIVEACTION_STATUS_CANCELED = "canceled" LIVEACTION_COMPLETED_STATES = [ LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED, LIVEACTION_STATUS_TIMED_OUT, LIVEACTION_STATUS_CANCELED, - LIVEACTION_STATUS_ABANDONED + LIVEACTION_STATUS_ABANDONED, ] class PackBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(PackBranch, self).__init__( - Pack, description, app, subparsers, + Pack, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': PackListCommand, - 'get': PackGetCommand - }) - - self.commands['show'] = PackShowCommand(self.resource, self.app, self.subparsers) - self.commands['search'] = PackSearchCommand(self.resource, self.app, self.subparsers) - self.commands['install'] = PackInstallCommand(self.resource, self.app, self.subparsers) - self.commands['remove'] = PackRemoveCommand(self.resource, self.app, self.subparsers) - self.commands['register'] = PackRegisterCommand(self.resource, self.app, self.subparsers) - self.commands['config'] = PackConfigCommand(self.resource, self.app, self.subparsers) + commands={"list": PackListCommand, "get": PackGetCommand}, + ) + + self.commands["show"] = PackShowCommand( + self.resource, self.app, self.subparsers + ) + self.commands["search"] = PackSearchCommand( + self.resource, self.app, self.subparsers + ) + self.commands["install"] = PackInstallCommand( + self.resource, self.app, self.subparsers + ) + self.commands["remove"] = PackRemoveCommand( + self.resource, self.app, self.subparsers + ) + self.commands["register"] = PackRegisterCommand( + self.resource, self.app, self.subparsers + ) + self.commands["config"] = PackConfigCommand( + self.resource, self.app, self.subparsers + ) class PackResourceCommand(resource.ResourceCommand): @@ -79,13 +92,18 @@ def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) if not instance: raise resource.ResourceNotFoundError("No matching items found") - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) except resource.ResourceNotFoundError: print("No matching items found") except Exception as e: message = six.text_type(e) - print('ERROR: %s' % (message)) + print("ERROR: %s" % (message)) raise OperationFailureException(message) @@ -93,48 +111,72 @@ class PackAsyncCommand(ActionRunCommandMixin, resource.ResourceCommand): def __init__(self, *args, **kwargs): super(PackAsyncCommand, self).__init__(*args, **kwargs) - self.parser.add_argument('-w', '--width', nargs='+', type=int, default=None, - help='Set the width of columns in output.') + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help="Set the width of columns in output.", + ) detail_arg_grp = self.parser.add_mutually_exclusive_group() - detail_arg_grp.add_argument('--attr', nargs='+', - default=['ref', 'name', 'description', 'version', 'author'], - help=('List of attributes to include in the ' - 'output. "all" or unspecified will ' - 'return all attributes.')) - detail_arg_grp.add_argument('-d', '--detail', action='store_true', - help='Display full detail of the execution in table format.') + detail_arg_grp.add_argument( + "--attr", + nargs="+", + default=["ref", "name", "description", "version", "author"], + help=( + "List of attributes to include in the " + 'output. "all" or unspecified will ' + "return all attributes." + ), + ) + detail_arg_grp.add_argument( + "-d", + "--detail", + action="store_true", + help="Display full detail of the execution in table format.", + ) @add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) if not instance: - raise Exception('Server did not create instance.') + raise Exception("Server did not create instance.") parent_id = instance.execution_id - stream_mgr = self.app.client.managers['Stream'] + stream_mgr = self.app.client.managers["Stream"] execution = None with term.TaskIndicator() as indicator: - events = ['st2.execution__create', 'st2.execution__update'] - for event in stream_mgr.listen(events, end_execution_id=parent_id, - end_event="st2.execution__update", **kwargs): + events = ["st2.execution__create", "st2.execution__update"] + for event in stream_mgr.listen( + events, + end_execution_id=parent_id, + end_event="st2.execution__update", + **kwargs, + ): execution = Execution(**event) - if execution.id == parent_id \ - and execution.status in LIVEACTION_COMPLETED_STATES: + if ( + execution.id == parent_id + and execution.status in LIVEACTION_COMPLETED_STATES + ): break # Suppress intermediate output in case output formatter is requested if args.json or args.yaml: continue - if getattr(execution, 'parent', None) == parent_id: + if getattr(execution, "parent", None) == parent_id: status = execution.status - name = execution.context['orquesta']['task_name'] \ - if 'orquesta' in execution.context else execution.context['chain']['name'] + name = ( + execution.context["orquesta"]["task_name"] + if "orquesta" in execution.context + else execution.context["chain"]["name"] + ) if status == LIVEACTION_STATUS_SCHEDULED: indicator.add_stage(status, name) @@ -148,31 +190,48 @@ def run_and_print(self, args, **kwargs): self._print_execution_details(execution=execution, args=args, **kwargs) sys.exit(1) - return self.app.client.managers['Execution'].get_by_id(parent_id, **kwargs) + return self.app.client.managers["Execution"].get_by_id(parent_id, **kwargs) class PackListCommand(resource.ResourceListCommand): - display_attributes = ['ref', 'name', 'description', 'version', 'author'] - attribute_display_order = ['ref', 'name', 'description', 'version', 'author'] + display_attributes = ["ref", "name", "description", "version", "author"] + attribute_display_order = ["ref", "name", "description", "version", "author"] class PackGetCommand(resource.ResourceGetCommand): - pk_argument_name = 'ref' - display_attributes = ['name', 'version', 'author', 'email', 'keywords', 'description'] - attribute_display_order = ['name', 'version', 'author', 'email', 'keywords', 'description'] - help_string = 'Get information about an installed pack.' + pk_argument_name = "ref" + display_attributes = [ + "name", + "version", + "author", + "email", + "keywords", + "description", + ] + attribute_display_order = [ + "name", + "version", + "author", + "email", + "keywords", + "description", + ] + help_string = "Get information about an installed pack." class PackShowCommand(PackResourceCommand): def __init__(self, resource, *args, **kwargs): - help_string = ('Get information about an available %s from the index.' % - resource.get_display_name().lower()) - super(PackShowCommand, self).__init__(resource, 'show', help_string, - *args, **kwargs) - - self.parser.add_argument('pack', - help='Name of the %s to show.' % - resource.get_display_name().lower()) + help_string = ( + "Get information about an available %s from the index." + % resource.get_display_name().lower() + ) + super(PackShowCommand, self).__init__( + resource, "show", help_string, *args, **kwargs + ) + + self.parser.add_argument( + "pack", help="Name of the %s to show." % resource.get_display_name().lower() + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -181,27 +240,39 @@ def run(self, args, **kwargs): class PackInstallCommand(PackAsyncCommand): def __init__(self, resource, *args, **kwargs): - super(PackInstallCommand, self).__init__(resource, 'install', 'Install new %s.' - % resource.get_plural_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('packs', - nargs='+', - metavar='pack', - help='Name of the %s in Exchange, or a git repo URL.' % - resource.get_plural_display_name().lower()) - self.parser.add_argument('--python3', - action='store_true', - default=False, - help='Use Python 3 binary for pack virtual environment.') - self.parser.add_argument('--force', - action='store_true', - default=False, - help='Force pack installation.') - self.parser.add_argument('--skip-dependencies', - action='store_true', - default=False, - help='Skip pack dependency installation.') + super(PackInstallCommand, self).__init__( + resource, + "install", + "Install new %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "packs", + nargs="+", + metavar="pack", + help="Name of the %s in Exchange, or a git repo URL." + % resource.get_plural_display_name().lower(), + ) + self.parser.add_argument( + "--python3", + action="store_true", + default=False, + help="Use Python 3 binary for pack virtual environment.", + ) + self.parser.add_argument( + "--force", + action="store_true", + default=False, + help="Force pack installation.", + ) + self.parser.add_argument( + "--skip-dependencies", + action="store_true", + default=False, + help="Skip pack dependency installation.", + ) def run(self, args, **kwargs): is_structured_output = args.json or args.yaml @@ -212,30 +283,42 @@ def run(self, args, **kwargs): self._get_content_counts_for_pack(args, **kwargs) if args.python3: - warnings.warn('DEPRECATION WARNING: --python3 flag is ignored and will be removed ' - 'in v3.5.0 as StackStorm now runs with python3 only') - - return self.manager.install(args.packs, force=args.force, - skip_dependencies=args.skip_dependencies, **kwargs) + warnings.warn( + "DEPRECATION WARNING: --python3 flag is ignored and will be removed " + "in v3.5.0 as StackStorm now runs with python3 only" + ) + + return self.manager.install( + args.packs, + force=args.force, + skip_dependencies=args.skip_dependencies, + **kwargs, + ) def _get_content_counts_for_pack(self, args, **kwargs): # Global content list, excluding "tests" # Note: We skip this step for local packs - pack_content = {'actions': 0, 'rules': 0, 'sensors': 0, 'aliases': 0, 'triggers': 0} + pack_content = { + "actions": 0, + "rules": 0, + "sensors": 0, + "aliases": 0, + "triggers": 0, + } if len(args.packs) == 1: args.pack = args.packs[0] - if args.pack.startswith('file://'): + if args.pack.startswith("file://"): return pack_info = self.manager.search(args, ignore_errors=True, **kwargs) - content = getattr(pack_info, 'content', {}) + content = getattr(pack_info, "content", {}) if content: for entity in content.keys(): if entity in pack_content: - pack_content[entity] += content[entity]['count'] + pack_content[entity] += content[entity]["count"] self._print_pack_content(args.packs, pack_content) else: @@ -246,122 +329,165 @@ def _get_content_counts_for_pack(self, args, **kwargs): # args.pack required for search args.pack = pack - if args.pack.startswith('file://'): + if args.pack.startswith("file://"): return pack_info = self.manager.search(args, ignore_errors=True, **kwargs) - content = getattr(pack_info, 'content', {}) + content = getattr(pack_info, "content", {}) if content: for entity in content.keys(): if entity in pack_content: - pack_content[entity] += content[entity]['count'] + pack_content[entity] += content[entity]["count"] if content: self._print_pack_content(args.packs, pack_content) @staticmethod def _print_pack_content(pack_name, pack_content): - print('\nFor the "%s" %s, the following content will be registered:\n' - % (', '.join(pack_name), 'pack' if len(pack_name) == 1 else 'packs')) + print( + '\nFor the "%s" %s, the following content will be registered:\n' + % (", ".join(pack_name), "pack" if len(pack_name) == 1 else "packs") + ) for item, count in pack_content.items(): - print('%-10s| %s' % (item, count)) - print('\nInstallation may take a while for packs with many items.') + print("%-10s| %s" % (item, count)) + print("\nInstallation may take a while for packs with many items.") @add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): instance = super(PackInstallCommand, self).run_and_print(args, **kwargs) # Hack to get a list of resolved references of installed packs - packs = instance.result['output']['packs_list'] + packs = instance.result["output"]["packs_list"] if len(packs) == 1: - pack_instance = self.app.client.managers['Pack'].get_by_ref_or_id(packs[0], **kwargs) - self.print_output(pack_instance, table.PropertyValueTable, - attributes=args.attr, json=args.json, yaml=args.yaml, - attribute_display_order=self.attribute_display_order) + pack_instance = self.app.client.managers["Pack"].get_by_ref_or_id( + packs[0], **kwargs + ) + self.print_output( + pack_instance, + table.PropertyValueTable, + attributes=args.attr, + json=args.json, + yaml=args.yaml, + attribute_display_order=self.attribute_display_order, + ) else: - all_pack_instances = self.app.client.managers['Pack'].get_all(**kwargs) + all_pack_instances = self.app.client.managers["Pack"].get_all(**kwargs) pack_instances = [] for pack in all_pack_instances: if pack.name in packs or pack.ref in packs: pack_instances.append(pack) - self.print_output(pack_instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + pack_instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) - warnings = instance.result['output']['warning_list'] + warnings = instance.result["output"]["warning_list"] for warning in warnings: print(warning) class PackRemoveCommand(PackAsyncCommand): def __init__(self, resource, *args, **kwargs): - super(PackRemoveCommand, self).__init__(resource, 'remove', 'Remove %s.' - % resource.get_plural_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('packs', - nargs='+', - metavar='pack', - help='Name of the %s to remove.' % - resource.get_plural_display_name().lower()) + super(PackRemoveCommand, self).__init__( + resource, + "remove", + "Remove %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "packs", + nargs="+", + metavar="pack", + help="Name of the %s to remove." + % resource.get_plural_display_name().lower(), + ) def run(self, args, **kwargs): return self.manager.remove(args.packs, **kwargs) @add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): - all_pack_instances = self.app.client.managers['Pack'].get_all(**kwargs) + all_pack_instances = self.app.client.managers["Pack"].get_all(**kwargs) super(PackRemoveCommand, self).run_and_print(args, **kwargs) packs = args.packs if len(packs) == 1: - pack_instance = self.app.client.managers['Pack'].get_by_ref_or_id(packs[0], **kwargs) + pack_instance = self.app.client.managers["Pack"].get_by_ref_or_id( + packs[0], **kwargs + ) if pack_instance: - raise OperationFailureException('Pack %s has not been removed properly' % packs[0]) - - removed_pack_instance = next((pack for pack in all_pack_instances - if pack.name == packs[0]), None) - - self.print_output(removed_pack_instance, table.PropertyValueTable, - attributes=args.attr, json=args.json, yaml=args.yaml, - attribute_display_order=self.attribute_display_order) + raise OperationFailureException( + "Pack %s has not been removed properly" % packs[0] + ) + + removed_pack_instance = next( + (pack for pack in all_pack_instances if pack.name == packs[0]), None + ) + + self.print_output( + removed_pack_instance, + table.PropertyValueTable, + attributes=args.attr, + json=args.json, + yaml=args.yaml, + attribute_display_order=self.attribute_display_order, + ) else: - remaining_pack_instances = self.app.client.managers['Pack'].get_all(**kwargs) + remaining_pack_instances = self.app.client.managers["Pack"].get_all( + **kwargs + ) pack_instances = [] for pack in all_pack_instances: if pack.name in packs or pack.ref in packs: pack_instances.append(pack) if pack in remaining_pack_instances: - raise OperationFailureException('Pack %s has not been removed properly' - % pack.name) + raise OperationFailureException( + "Pack %s has not been removed properly" % pack.name + ) - self.print_output(pack_instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + pack_instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class PackRegisterCommand(PackResourceCommand): def __init__(self, resource, *args, **kwargs): - super(PackRegisterCommand, self).__init__(resource, 'register', - 'Register %s(s): sync all file changes with DB.' - % resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('packs', - nargs='*', - metavar='pack', - help='Name of the %s(s) to register.' % - resource.get_display_name().lower()) - - self.parser.add_argument('--types', - nargs='+', - help='Types of content to register.') + super(PackRegisterCommand, self).__init__( + resource, + "register", + "Register %s(s): sync all file changes with DB." + % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "packs", + nargs="*", + metavar="pack", + help="Name of the %s(s) to register." % resource.get_display_name().lower(), + ) + + self.parser.add_argument( + "--types", nargs="+", help="Types of content to register." + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -369,18 +495,21 @@ def run(self, args, **kwargs): class PackSearchCommand(resource.ResourceTableCommand): - display_attributes = ['name', 'description', 'version', 'author'] - attribute_display_order = ['name', 'description', 'version', 'author'] + display_attributes = ["name", "description", "version", "author"] + attribute_display_order = ["name", "description", "version", "author"] def __init__(self, resource, *args, **kwargs): - super(PackSearchCommand, self).__init__(resource, 'search', - 'Search the index for a %s with any attribute \ - matching the query.' - % resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('query', - help='Search query.') + super(PackSearchCommand, self).__init__( + resource, + "search", + "Search the index for a %s with any attribute \ + matching the query." + % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument("query", help="Search query.") @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -389,31 +518,41 @@ def run(self, args, **kwargs): class PackConfigCommand(resource.ResourceCommand): def __init__(self, resource, *args, **kwargs): - super(PackConfigCommand, self).__init__(resource, 'config', - 'Configure a %s based on config schema.' - % resource.get_display_name().lower(), - *args, **kwargs) - - self.parser.add_argument('name', - help='Name of the %s(s) to configure.' % - resource.get_display_name().lower()) + super(PackConfigCommand, self).__init__( + resource, + "config", + "Configure a %s based on config schema." + % resource.get_display_name().lower(), + *args, + **kwargs, + ) + + self.parser.add_argument( + "name", + help="Name of the %s(s) to configure." + % resource.get_display_name().lower(), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): - schema = self.app.client.managers['ConfigSchema'].get_by_ref_or_id(args.name, **kwargs) + schema = self.app.client.managers["ConfigSchema"].get_by_ref_or_id( + args.name, **kwargs + ) if not schema: - msg = '%s "%s" doesn\'t exist or doesn\'t have a config schema defined.' - raise resource.ResourceNotFoundError(msg % (self.resource.get_display_name(), - args.name)) + msg = "%s \"%s\" doesn't exist or doesn't have a config schema defined." + raise resource.ResourceNotFoundError( + msg % (self.resource.get_display_name(), args.name) + ) config = interactive.InteractiveForm(schema.attributes).initiate_dialog() - message = '---\nDo you want to preview the config in an editor before saving?' - description = 'Secrets will be shown in plain text.' - preview_dialog = interactive.Question(message, {'default': 'y', - 'description': description}) - if preview_dialog.read() == 'y': + message = "---\nDo you want to preview the config in an editor before saving?" + description = "Secrets will be shown in plain text." + preview_dialog = interactive.Question( + message, {"default": "y", "description": description} + ) + if preview_dialog.read() == "y": try: contents = yaml.safe_dump(config, indent=4, default_flow_style=False) modified = editor.edit(contents=contents) @@ -421,13 +560,13 @@ def run(self, args, **kwargs): except editor.EditorError as e: print(six.text_type(e)) - message = '---\nDo you want me to save it?' - save_dialog = interactive.Question(message, {'default': 'y'}) - if save_dialog.read() == 'n': - raise OperationFailureException('Interrupted') + message = "---\nDo you want me to save it?" + save_dialog = interactive.Question(message, {"default": "y"}) + if save_dialog.read() == "n": + raise OperationFailureException("Interrupted") config_item = Config(pack=args.name, values=config) - result = self.app.client.managers['Config'].update(config_item, **kwargs) + result = self.app.client.managers["Config"].update(config_item, **kwargs) return result @@ -436,14 +575,19 @@ def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) if not instance: raise Exception("Configuration failed") - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) except (KeyboardInterrupt, SystemExit): - raise OperationFailureException('Interrupted') + raise OperationFailureException("Interrupted") except Exception as e: if self.app.client.debug: raise message = six.text_type(e) - print('ERROR: %s' % (message)) + print("ERROR: %s" % (message)) raise OperationFailureException(message) diff --git a/st2client/st2client/commands/policy.py b/st2client/st2client/commands/policy.py index de6c8ba997e..cd891bc3a8b 100644 --- a/st2client/st2client/commands/policy.py +++ b/st2client/st2client/commands/policy.py @@ -25,31 +25,36 @@ class PolicyTypeBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(PolicyTypeBranch, self).__init__( - models.PolicyType, description, app, subparsers, + models.PolicyType, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': PolicyTypeListCommand, - 'get': PolicyTypeGetCommand - }) + commands={"list": PolicyTypeListCommand, "get": PolicyTypeGetCommand}, + ) class PolicyTypeListCommand(resource.ResourceListCommand): - display_attributes = ['id', 'resource_type', 'name', 'description'] + display_attributes = ["id", "resource_type", "name", "description"] def __init__(self, resource, *args, **kwargs): super(PolicyTypeListCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-r', '--resource-type', type=str, dest='resource_type', - help='Return policy types for the resource type.') + self.parser.add_argument( + "-r", + "--resource-type", + type=str, + dest="resource_type", + help="Return policy types for the resource type.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): if args.resource_type: - filters = {'resource_type': args.resource_type} + filters = {"resource_type": args.resource_type} filters.update(**kwargs) instances = self.manager.query(**filters) return instances @@ -58,36 +63,49 @@ def run(self, args, **kwargs): class PolicyTypeGetCommand(resource.ResourceGetCommand): - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" def get_resource(self, ref_or_id, **kwargs): return self.get_resource_by_ref_or_id(ref_or_id=ref_or_id, **kwargs) class PolicyBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(PolicyBranch, self).__init__( - models.Policy, description, app, subparsers, + models.Policy, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': PolicyListCommand, - 'get': PolicyGetCommand, - 'update': PolicyUpdateCommand, - 'delete': PolicyDeleteCommand - }) + "list": PolicyListCommand, + "get": PolicyGetCommand, + "update": PolicyUpdateCommand, + "delete": PolicyDeleteCommand, + }, + ) class PolicyListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['ref', 'resource_ref', 'policy_type', 'enabled'] + display_attributes = ["ref", "resource_ref", "policy_type", "enabled"] def __init__(self, resource, *args, **kwargs): super(PolicyListCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-r', '--resource-ref', type=str, dest='resource_ref', - help='Return policies for the resource ref.') - self.parser.add_argument('-pt', '--policy-type', type=str, dest='policy_type', - help='Return policies of the policy type.') + self.parser.add_argument( + "-r", + "--resource-ref", + type=str, + dest="resource_ref", + help="Return policies for the resource ref.", + ) + self.parser.add_argument( + "-pt", + "--policy-type", + type=str, + dest="policy_type", + help="Return policies of the policy type.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -95,10 +113,10 @@ def run(self, args, **kwargs): filters = {} if args.resource_ref: - filters['resource_ref'] = args.resource_ref + filters["resource_ref"] = args.resource_ref if args.policy_type: - filters['policy_type'] = args.policy_type + filters["policy_type"] = args.policy_type filters.update(**kwargs) @@ -108,10 +126,18 @@ def run(self, args, **kwargs): class PolicyGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'description', - 'enabled', 'resource_ref', 'policy_type', - 'parameters'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "description", + "enabled", + "resource_ref", + "policy_type", + "parameters", + ] class PolicyUpdateCommand(resource.ContentPackResourceUpdateCommand): diff --git a/st2client/st2client/commands/rbac.py b/st2client/st2client/commands/rbac.py index 0d7ea7f400e..9a9e8c274b3 100644 --- a/st2client/st2client/commands/rbac.py +++ b/st2client/st2client/commands/rbac.py @@ -20,58 +20,77 @@ from st2client.models.rbac import Role from st2client.models.rbac import UserRoleAssignment -__all__ = [ - 'RoleBranch', - 'RoleAssignmentBranch' +__all__ = ["RoleBranch", "RoleAssignmentBranch"] + +ROLE_ATTRIBUTE_DISPLAY_ORDER = ["id", "name", "system", "permission_grants"] +ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER = [ + "id", + "role", + "user", + "is_remote", + "description", ] -ROLE_ATTRIBUTE_DISPLAY_ORDER = ['id', 'name', 'system', 'permission_grants'] -ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER = ['id', 'role', 'user', 'is_remote', 'description'] - class RoleBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(RoleBranch, self).__init__( - Role, description, app, subparsers, + Role, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': RoleListCommand, - 'get': RoleGetCommand - }) + commands={"list": RoleListCommand, "get": RoleGetCommand}, + ) class RoleListCommand(resource.ResourceCommand): - display_attributes = ['id', 'name', 'system', 'description'] + display_attributes = ["id", "name", "system", "description"] attribute_display_order = ROLE_ATTRIBUTE_DISPLAY_ORDER def __init__(self, resource, *args, **kwargs): super(RoleListCommand, self).__init__( - resource, 'list', 'Get the list of the %s.' % - resource.get_plural_display_name().lower(), - *args, **kwargs) + resource, + "list", + "Get the list of the %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) self.group = self.parser.add_mutually_exclusive_group() # Filter options - self.group.add_argument('-s', '--system', action='store_true', - help='Only display system roles.') + self.group.add_argument( + "-s", "--system", action="store_true", help="Only display system roles." + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.system: - kwargs['system'] = args.system + kwargs["system"] = args.system if args.system: result = self.manager.query(**kwargs) @@ -82,67 +101,93 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class RoleGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] + display_attributes = ["all"] attribute_display_order = ROLE_ATTRIBUTE_DISPLAY_ORDER - pk_argument_name = 'id' + pk_argument_name = "id" class RoleAssignmentBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(RoleAssignmentBranch, self).__init__( - UserRoleAssignment, description, app, subparsers, + UserRoleAssignment, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, commands={ - 'list': RoleAssignmentListCommand, - 'get': RoleAssignmentGetCommand - }) + "list": RoleAssignmentListCommand, + "get": RoleAssignmentGetCommand, + }, + ) class RoleAssignmentListCommand(resource.ResourceCommand): - display_attributes = ['id', 'role', 'user', 'is_remote', 'source', 'description'] + display_attributes = ["id", "role", "user", "is_remote", "source", "description"] attribute_display_order = ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER def __init__(self, resource, *args, **kwargs): super(RoleAssignmentListCommand, self).__init__( - resource, 'list', 'Get the list of the %s.' % - resource.get_plural_display_name().lower(), - *args, **kwargs) + resource, + "list", + "Get the list of the %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) # Filter options - self.parser.add_argument('-r', '--role', help='Role to filter on.') - self.parser.add_argument('-u', '--user', help='User to filter on.') - self.parser.add_argument('-s', '--source', help='Source to filter on.') - self.parser.add_argument('--remote', action='store_true', - help='Only display remote role assignments.') + self.parser.add_argument("-r", "--role", help="Role to filter on.") + self.parser.add_argument("-u", "--user", help="User to filter on.") + self.parser.add_argument("-s", "--source", help="Source to filter on.") + self.parser.add_argument( + "--remote", + action="store_true", + help="Only display remote role assignments.", + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.role: - kwargs['role'] = args.role + kwargs["role"] = args.role if args.user: - kwargs['user'] = args.user + kwargs["user"] = args.user if args.source: - kwargs['source'] = args.source + kwargs["source"] = args.source if args.remote: - kwargs['remote'] = args.remote + kwargs["remote"] = args.remote if args.role or args.user or args.remote or args.source: result = self.manager.query(**kwargs) @@ -153,12 +198,17 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class RoleAssignmentGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] + display_attributes = ["all"] attribute_display_order = ROLE_ASSIGNMENT_ATTRIBUTE_DISPLAY_ORDER - pk_argument_name = 'id' + pk_argument_name = "id" diff --git a/st2client/st2client/commands/resource.py b/st2client/st2client/commands/resource.py index 15ca68bb093..da0fbc85e36 100644 --- a/st2client/st2client/commands/resource.py +++ b/st2client/st2client/commands/resource.py @@ -32,8 +32,8 @@ from st2client.formatters import table from st2client.utils.types import OrderedSet -ALLOWED_EXTS = ['.json', '.yaml', '.yml'] -PARSER_FUNCS = {'.json': json.load, '.yml': yaml.safe_load, '.yaml': yaml.safe_load} +ALLOWED_EXTS = [".json", ".yaml", ".yml"] +PARSER_FUNCS = {".json": json.load, ".yml": yaml.safe_load, ".yaml": yaml.safe_load} LOG = logging.getLogger(__name__) @@ -41,11 +41,12 @@ def add_auth_token_to_kwargs_from_cli(func): @wraps(func) def decorate(*args, **kwargs): ns = args[1] - if getattr(ns, 'token', None): - kwargs['token'] = ns.token - if getattr(ns, 'api_key', None): - kwargs['api_key'] = ns.api_key + if getattr(ns, "token", None): + kwargs["token"] = ns.token + if getattr(ns, "api_key", None): + kwargs["api_key"] = ns.api_key return func(*args, **kwargs) + return decorate @@ -58,20 +59,34 @@ class ResourceNotFoundError(Exception): class ResourceBranch(commands.Branch): - - def __init__(self, resource, description, app, subparsers, - parent_parser=None, read_only=False, commands=None, - has_disable=False): + def __init__( + self, + resource, + description, + app, + subparsers, + parent_parser=None, + read_only=False, + commands=None, + has_disable=False, + ): self.resource = resource super(ResourceBranch, self).__init__( - self.resource.get_alias().lower(), description, - app, subparsers, parent_parser=parent_parser) + self.resource.get_alias().lower(), + description, + app, + subparsers, + parent_parser=parent_parser, + ) # Registers subcommands for managing the resource type. self.subparsers = self.parser.add_subparsers( - help=('List of commands for managing %s.' % - self.resource.get_plural_display_name().lower())) + help=( + "List of commands for managing %s." + % self.resource.get_plural_display_name().lower() + ) + ) # Resolves if commands need to be overridden. commands = commands or {} @@ -82,7 +97,7 @@ def __init__(self, resource, description, app, subparsers, "update": ResourceUpdateCommand, "delete": ResourceDeleteCommand, "enable": ResourceEnableCommand, - "disable": ResourceDisableCommand + "disable": ResourceDisableCommand, } for cmd, cmd_class in cmd_map.items(): if cmd not in commands: @@ -90,17 +105,17 @@ def __init__(self, resource, description, app, subparsers, # Instantiate commands. args = [self.resource, self.app, self.subparsers] - self.commands['list'] = commands['list'](*args) - self.commands['get'] = commands['get'](*args) + self.commands["list"] = commands["list"](*args) + self.commands["get"] = commands["get"](*args) if not read_only: - self.commands['create'] = commands['create'](*args) - self.commands['update'] = commands['update'](*args) - self.commands['delete'] = commands['delete'](*args) + self.commands["create"] = commands["create"](*args) + self.commands["update"] = commands["update"](*args) + self.commands["delete"] = commands["delete"](*args) if has_disable: - self.commands['enable'] = commands['enable'](*args) - self.commands['disable'] = commands['disable'](*args) + self.commands["enable"] = commands["enable"](*args) + self.commands["disable"] = commands["disable"](*args) @six.add_metaclass(abc.ABCMeta) @@ -109,29 +124,44 @@ class ResourceCommand(commands.Command): def __init__(self, resource, *args, **kwargs): - has_token_opt = kwargs.pop('has_token_opt', True) + has_token_opt = kwargs.pop("has_token_opt", True) super(ResourceCommand, self).__init__(*args, **kwargs) self.resource = resource if has_token_opt: - self.parser.add_argument('-t', '--token', dest='token', - help='Access token for user authentication. ' - 'Get ST2_AUTH_TOKEN from the environment ' - 'variables by default.') - self.parser.add_argument('--api-key', dest='api_key', - help='Api Key for user authentication. ' - 'Get ST2_API_KEY from the environment ' - 'variables by default.') + self.parser.add_argument( + "-t", + "--token", + dest="token", + help="Access token for user authentication. " + "Get ST2_AUTH_TOKEN from the environment " + "variables by default.", + ) + self.parser.add_argument( + "--api-key", + dest="api_key", + help="Api Key for user authentication. " + "Get ST2_API_KEY from the environment " + "variables by default.", + ) # Formatter flags - self.parser.add_argument('-j', '--json', - action='store_true', dest='json', - help='Print output in JSON format.') - self.parser.add_argument('-y', '--yaml', - action='store_true', dest='yaml', - help='Print output in YAML format.') + self.parser.add_argument( + "-j", + "--json", + action="store_true", + dest="json", + help="Print output in JSON format.", + ) + self.parser.add_argument( + "-y", + "--yaml", + action="store_true", + dest="yaml", + help="Print output in YAML format.", + ) @property def manager(self): @@ -140,18 +170,17 @@ def manager(self): @property def arg_name_for_resource_id(self): resource_name = self.resource.get_display_name().lower() - return '%s-id' % resource_name.replace(' ', '-') + return "%s-id" % resource_name.replace(" ", "-") def print_not_found(self, name): - print('%s "%s" is not found.\n' % - (self.resource.get_display_name(), name)) + print('%s "%s" is not found.\n' % (self.resource.get_display_name(), name)) def get_resource(self, name_or_id, **kwargs): pk_argument_name = self.pk_argument_name - if pk_argument_name == 'name_or_id': + if pk_argument_name == "name_or_id": instance = self.get_resource_by_name_or_id(name_or_id=name_or_id, **kwargs) - elif pk_argument_name == 'ref_or_id': + elif pk_argument_name == "ref_or_id": instance = self.get_resource_by_ref_or_id(ref_or_id=name_or_id, **kwargs) else: instance = self.get_resource_by_pk(pk=name_or_id, **kwargs) @@ -167,8 +196,8 @@ def get_resource_by_pk(self, pk, **kwargs): except Exception as e: traceback.print_exc() # Hack for "Unauthorized" exceptions, we do want to propagate those - response = getattr(e, 'response', None) - status_code = getattr(response, 'status_code', None) + response = getattr(e, "response", None) + status_code = getattr(response, "status_code", None) if status_code and status_code == http_client.UNAUTHORIZED: raise e @@ -180,7 +209,7 @@ def get_resource_by_id(self, id, **kwargs): instance = self.get_resource_by_pk(pk=id, **kwargs) if not instance: - message = ('Resource with id "%s" doesn\'t exist.' % (id)) + message = 'Resource with id "%s" doesn\'t exist.' % (id) raise ResourceNotFoundError(message) return instance @@ -197,8 +226,7 @@ def get_resource_by_name_or_id(self, name_or_id, **kwargs): instance = self.get_resource_by_pk(pk=name_or_id, **kwargs) if not instance: - message = ('Resource with id or name "%s" doesn\'t exist.' % - (name_or_id)) + message = 'Resource with id or name "%s" doesn\'t exist.' % (name_or_id) raise ResourceNotFoundError(message) return instance @@ -206,8 +234,7 @@ def get_resource_by_ref_or_id(self, ref_or_id, **kwargs): instance = self.manager.get_by_ref_or_id(ref_or_id=ref_or_id, **kwargs) if not instance: - message = ('Resource with id or reference "%s" doesn\'t exist.' % - (ref_or_id)) + message = 'Resource with id or reference "%s" doesn\'t exist.' % (ref_or_id) raise ResourceNotFoundError(message) return instance @@ -220,18 +247,18 @@ def run_and_print(self, args, **kwargs): raise NotImplementedError def _get_metavar_for_argument(self, argument): - return argument.replace('_', '-') + return argument.replace("_", "-") def _get_help_for_argument(self, resource, argument): argument_display_name = argument.title() resource_display_name = resource.get_display_name().lower() - if 'ref' in argument: - result = ('Reference or ID of the %s.' % (resource_display_name)) - elif 'name_or_id' in argument: - result = ('Name or ID of the %s.' % (resource_display_name)) + if "ref" in argument: + result = "Reference or ID of the %s." % (resource_display_name) + elif "name_or_id" in argument: + result = "Name or ID of the %s." % (resource_display_name) else: - result = ('%s of the %s.' % (argument_display_name, resource_display_name)) + result = "%s of the %s." % (argument_display_name, resource_display_name) return result @@ -263,7 +290,7 @@ def _get_include_attributes(cls, args, extra_attributes=None): # into account # Special case for "all" - if 'all' in args.attr: + if "all" in args.attr: return None for attr in args.attr: @@ -272,7 +299,7 @@ def _get_include_attributes(cls, args, extra_attributes=None): if include_attributes: return include_attributes - display_attributes = getattr(cls, 'display_attributes', []) + display_attributes = getattr(cls, "display_attributes", []) if display_attributes: include_attributes += display_attributes @@ -283,97 +310,129 @@ def _get_include_attributes(cls, args, extra_attributes=None): class ResourceTableCommand(ResourceViewCommand): - display_attributes = ['id', 'name', 'description'] + display_attributes = ["id", "name", "description"] def __init__(self, resource, name, description, *args, **kwargs): - super(ResourceTableCommand, self).__init__(resource, name, description, - *args, **kwargs) - - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + super(ResourceTableCommand, self).__init__( + resource, name, description, *args, **kwargs + ) + + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): include_attributes = self._get_include_attributes(args=args) if include_attributes: - include_attributes = ','.join(include_attributes) - kwargs['params'] = {'include_attributes': include_attributes} + include_attributes = ",".join(include_attributes) + kwargs["params"] = {"include_attributes": include_attributes} return self.manager.get_all(**kwargs) def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) class ResourceListCommand(ResourceTableCommand): def __init__(self, resource, *args, **kwargs): super(ResourceListCommand, self).__init__( - resource, 'list', 'Get the list of %s.' % resource.get_plural_display_name().lower(), - *args, **kwargs) + resource, + "list", + "Get the list of %s." % resource.get_plural_display_name().lower(), + *args, + **kwargs, + ) class ContentPackResourceListCommand(ResourceListCommand): """ Base command class for use with resources which belong to a content pack. """ + def __init__(self, resource, *args, **kwargs): - super(ContentPackResourceListCommand, self).__init__(resource, - *args, **kwargs) + super(ContentPackResourceListCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-p', '--pack', type=str, - help=('Only return resources belonging to the' - ' provided pack')) + self.parser.add_argument( + "-p", + "--pack", + type=str, + help=("Only return resources belonging to the" " provided pack"), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): - filters = {'pack': args.pack} + filters = {"pack": args.pack} filters.update(**kwargs) include_attributes = self._get_include_attributes(args=args) if include_attributes: - include_attributes = ','.join(include_attributes) - filters['params'] = {'include_attributes': include_attributes} + include_attributes = ",".join(include_attributes) + filters["params"] = {"include_attributes": include_attributes} return self.manager.get_all(**filters) class ResourceGetCommand(ResourceViewCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'name', 'description'] + display_attributes = ["all"] + attribute_display_order = ["id", "name", "description"] - pk_argument_name = 'name_or_id' # name of the attribute which stores resource PK + pk_argument_name = "name_or_id" # name of the attribute which stores resource PK help_string = None def __init__(self, resource, *args, **kwargs): super(ResourceGetCommand, self).__init__( - resource, 'get', - self.help_string or 'Get individual %s.' % resource.get_display_name().lower(), - *args, **kwargs + resource, + "get", + self.help_string + or "Get individual %s." % resource.get_display_name().lower(), + *args, + **kwargs, ) argument = self.pk_argument_name metavar = self._get_metavar_for_argument(argument=self.pk_argument_name) - help = self._get_help_for_argument(resource=resource, - argument=self.pk_argument_name) - - self.parser.add_argument(argument, - metavar=metavar, - help=help) - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" or unspecified will ' - 'return all attributes.')) + help = self._get_help_for_argument( + resource=resource, argument=self.pk_argument_name + ) + + self.parser.add_argument(argument, metavar=metavar, help=help) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" or unspecified will ' + "return all attributes." + ), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -383,13 +442,18 @@ def run(self, args, **kwargs): def run_and_print(self, args, **kwargs): try: instance = self.run(args, **kwargs) - self.print_output(instance, table.PropertyValueTable, - attributes=args.attr, json=args.json, yaml=args.yaml, - attribute_display_order=self.attribute_display_order) + self.print_output( + instance, + table.PropertyValueTable, + attributes=args.attr, + json=args.json, + yaml=args.yaml, + attribute_display_order=self.attribute_display_order, + ) except ResourceNotFoundError: resource_id = getattr(args, self.pk_argument_name, None) self.print_not_found(resource_id) - raise OperationFailureException('Resource %s not found.' % resource_id) + raise OperationFailureException("Resource %s not found." % resource_id) class ContentPackResourceGetCommand(ResourceGetCommand): @@ -400,24 +464,31 @@ class ContentPackResourceGetCommand(ResourceGetCommand): retrieved by a reference or by an id. """ - attribute_display_order = ['id', 'pack', 'name', 'description'] + attribute_display_order = ["id", "pack", "name", "description"] - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" def get_resource(self, ref_or_id, **kwargs): return self.get_resource_by_ref_or_id(ref_or_id=ref_or_id, **kwargs) class ResourceCreateCommand(ResourceCommand): - def __init__(self, resource, *args, **kwargs): - super(ResourceCreateCommand, self).__init__(resource, 'create', - 'Create a new %s.' % resource.get_display_name().lower(), - *args, **kwargs) + super(ResourceCreateCommand, self).__init__( + resource, + "create", + "Create a new %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) - self.parser.add_argument('file', - help=('JSON/YAML file containing the %s to create.' - % resource.get_display_name().lower())) + self.parser.add_argument( + "file", + help=( + "JSON/YAML file containing the %s to create." + % resource.get_display_name().lower() + ), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -429,34 +500,46 @@ def run_and_print(self, args, **kwargs): try: instance = self.run(args, **kwargs) if not instance: - raise Exception('Server did not create instance.') - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + raise Exception("Server did not create instance.") + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) except Exception as e: message = six.text_type(e) - print('ERROR: %s' % (message)) + print("ERROR: %s" % (message)) raise OperationFailureException(message) class ResourceUpdateCommand(ResourceCommand): - pk_argument_name = 'name_or_id' + pk_argument_name = "name_or_id" def __init__(self, resource, *args, **kwargs): - super(ResourceUpdateCommand, self).__init__(resource, 'update', - 'Updating an existing %s.' % resource.get_display_name().lower(), - *args, **kwargs) + super(ResourceUpdateCommand, self).__init__( + resource, + "update", + "Updating an existing %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) argument = self.pk_argument_name metavar = self._get_metavar_for_argument(argument=self.pk_argument_name) - help = self._get_help_for_argument(resource=resource, - argument=self.pk_argument_name) + help = self._get_help_for_argument( + resource=resource, argument=self.pk_argument_name + ) - self.parser.add_argument(argument, - metavar=metavar, - help=help) - self.parser.add_argument('file', - help=('JSON/YAML file containing the %s to update.' - % resource.get_display_name().lower())) + self.parser.add_argument(argument, metavar=metavar, help=help) + self.parser.add_argument( + "file", + help=( + "JSON/YAML file containing the %s to update." + % resource.get_display_name().lower() + ), + ) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -465,46 +548,55 @@ def run(self, args, **kwargs): data = load_meta_file(args.file) modified_instance = self.resource.deserialize(data) - if not getattr(modified_instance, 'id', None): + if not getattr(modified_instance, "id", None): modified_instance.id = instance.id else: if modified_instance.id != instance.id: - raise Exception('The value for the %s id in the JSON/YAML file ' - 'does not match the ID provided in the ' - 'command line arguments.' % - self.resource.get_display_name().lower()) + raise Exception( + "The value for the %s id in the JSON/YAML file " + "does not match the ID provided in the " + "command line arguments." % self.resource.get_display_name().lower() + ) return self.manager.update(modified_instance, **kwargs) def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) try: - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) except Exception as e: - print('ERROR: %s' % (six.text_type(e))) + print("ERROR: %s" % (six.text_type(e))) raise OperationFailureException(six.text_type(e)) class ContentPackResourceUpdateCommand(ResourceUpdateCommand): - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" class ResourceEnableCommand(ResourceCommand): - pk_argument_name = 'name_or_id' + pk_argument_name = "name_or_id" def __init__(self, resource, *args, **kwargs): - super(ResourceEnableCommand, self).__init__(resource, 'enable', - 'Enable an existing %s.' % resource.get_display_name().lower(), - *args, **kwargs) + super(ResourceEnableCommand, self).__init__( + resource, + "enable", + "Enable an existing %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) argument = self.pk_argument_name metavar = self._get_metavar_for_argument(argument=self.pk_argument_name) - help = self._get_help_for_argument(resource=resource, - argument=self.pk_argument_name) + help = self._get_help_for_argument( + resource=resource, argument=self.pk_argument_name + ) - self.parser.add_argument(argument, - metavar=metavar, - help=help) + self.parser.add_argument(argument, metavar=metavar, help=help) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -513,40 +605,48 @@ def run(self, args, **kwargs): data = instance.serialize() - if 'ref' in data: - del data['ref'] + if "ref" in data: + del data["ref"] - data['enabled'] = True + data["enabled"] = True modified_instance = self.resource.deserialize(data) return self.manager.update(modified_instance, **kwargs) def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) class ContentPackResourceEnableCommand(ResourceEnableCommand): - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" class ResourceDisableCommand(ResourceCommand): - pk_argument_name = 'name_or_id' + pk_argument_name = "name_or_id" def __init__(self, resource, *args, **kwargs): - super(ResourceDisableCommand, self).__init__(resource, 'disable', - 'Disable an existing %s.' % resource.get_display_name().lower(), - *args, **kwargs) + super(ResourceDisableCommand, self).__init__( + resource, + "disable", + "Disable an existing %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) argument = self.pk_argument_name metavar = self._get_metavar_for_argument(argument=self.pk_argument_name) - help = self._get_help_for_argument(resource=resource, - argument=self.pk_argument_name) + help = self._get_help_for_argument( + resource=resource, argument=self.pk_argument_name + ) - self.parser.add_argument(argument, - metavar=metavar, - help=help) + self.parser.add_argument(argument, metavar=metavar, help=help) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -555,40 +655,48 @@ def run(self, args, **kwargs): data = instance.serialize() - if 'ref' in data: - del data['ref'] + if "ref" in data: + del data["ref"] - data['enabled'] = False + data["enabled"] = False modified_instance = self.resource.deserialize(data) return self.manager.update(modified_instance, **kwargs) def run_and_print(self, args, **kwargs): instance = self.run(args, **kwargs) - self.print_output(instance, table.PropertyValueTable, - attributes=['all'], json=args.json, yaml=args.yaml) + self.print_output( + instance, + table.PropertyValueTable, + attributes=["all"], + json=args.json, + yaml=args.yaml, + ) class ContentPackResourceDisableCommand(ResourceDisableCommand): - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" class ResourceDeleteCommand(ResourceCommand): - pk_argument_name = 'name_or_id' + pk_argument_name = "name_or_id" def __init__(self, resource, *args, **kwargs): - super(ResourceDeleteCommand, self).__init__(resource, 'delete', - 'Delete an existing %s.' % resource.get_display_name().lower(), - *args, **kwargs) + super(ResourceDeleteCommand, self).__init__( + resource, + "delete", + "Delete an existing %s." % resource.get_display_name().lower(), + *args, + **kwargs, + ) argument = self.pk_argument_name metavar = self._get_metavar_for_argument(argument=self.pk_argument_name) - help = self._get_help_for_argument(resource=resource, - argument=self.pk_argument_name) + help = self._get_help_for_argument( + resource=resource, argument=self.pk_argument_name + ) - self.parser.add_argument(argument, - metavar=metavar, - help=help) + self.parser.add_argument(argument, metavar=metavar, help=help) @add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -601,10 +709,12 @@ def run_and_print(self, args, **kwargs): try: self.run(args, **kwargs) - print('Resource with id "%s" has been successfully deleted.' % (resource_id)) + print( + 'Resource with id "%s" has been successfully deleted.' % (resource_id) + ) except ResourceNotFoundError: self.print_not_found(resource_id) - raise OperationFailureException('Resource %s not found.' % resource_id) + raise OperationFailureException("Resource %s not found." % resource_id) class ContentPackResourceDeleteCommand(ResourceDeleteCommand): @@ -612,7 +722,7 @@ class ContentPackResourceDeleteCommand(ResourceDeleteCommand): Base command class for deleting a resource which belongs to a content pack. """ - pk_argument_name = 'ref_or_id' + pk_argument_name = "ref_or_id" def load_meta_file(file_path): @@ -621,8 +731,10 @@ def load_meta_file(file_path): file_name, file_ext = os.path.splitext(file_path) if file_ext not in ALLOWED_EXTS: - raise Exception('Unsupported meta type %s, file %s. Allowed: %s' % - (file_ext, file_path, ALLOWED_EXTS)) + raise Exception( + "Unsupported meta type %s, file %s. Allowed: %s" + % (file_ext, file_path, ALLOWED_EXTS) + ) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: return PARSER_FUNCS[file_ext](f) diff --git a/st2client/st2client/commands/rule.py b/st2client/st2client/commands/rule.py index 7f0f5e58dbe..cbab939e101 100644 --- a/st2client/st2client/commands/rule.py +++ b/st2client/st2client/commands/rule.py @@ -21,99 +21,143 @@ class RuleBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(RuleBranch, self).__init__( - models.Rule, description, app, subparsers, + models.Rule, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': RuleListCommand, - 'get': RuleGetCommand, - 'update': RuleUpdateCommand, - 'delete': RuleDeleteCommand - }) + "list": RuleListCommand, + "get": RuleGetCommand, + "update": RuleUpdateCommand, + "delete": RuleDeleteCommand, + }, + ) - self.commands['enable'] = RuleEnableCommand(self.resource, self.app, self.subparsers) - self.commands['disable'] = RuleDisableCommand(self.resource, self.app, self.subparsers) + self.commands["enable"] = RuleEnableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["disable"] = RuleDisableCommand( + self.resource, self.app, self.subparsers + ) class RuleListCommand(resource.ResourceTableCommand): - display_attributes = ['ref', 'pack', 'description', 'enabled'] - display_attributes_iftt = ['ref', 'trigger.ref', 'action.ref', 'enabled'] + display_attributes = ["ref", "pack", "description", "enabled"] + display_attributes_iftt = ["ref", "trigger.ref", "action.ref", "enabled"] def __init__(self, resource, *args, **kwargs): self.default_limit = 50 - super(RuleListCommand, self).__init__(resource, 'list', - 'Get the list of the %s most recent %s.' % - (self.default_limit, - resource.get_plural_display_name().lower()), - *args, **kwargs) + super(RuleListCommand, self).__init__( + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() self.group = self.parser.add_argument_group() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) - self.parser.add_argument('--iftt', action='store_true', - help='Show trigger and action in display list.') - self.parser.add_argument('-p', '--pack', type=str, - help=('Only return resources belonging to the' - ' provided pack')) - self.group.add_argument('-c', '--action', - help='Action reference to filter the list.') - self.group.add_argument('-g', '--trigger', - help='Trigger type reference to filter the list.') + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) + self.parser.add_argument( + "--iftt", + action="store_true", + help="Show trigger and action in display list.", + ) + self.parser.add_argument( + "-p", + "--pack", + type=str, + help=("Only return resources belonging to the" " provided pack"), + ) + self.group.add_argument( + "-c", "--action", help="Action reference to filter the list." + ) + self.group.add_argument( + "-g", "--trigger", help="Trigger type reference to filter the list." + ) self.enabled_filter_group = self.parser.add_mutually_exclusive_group() - self.enabled_filter_group.add_argument('--enabled', action='store_true', - help='Show rules that are enabled.') - self.enabled_filter_group.add_argument('--disabled', action='store_true', - help='Show rules that are disabled.') + self.enabled_filter_group.add_argument( + "--enabled", action="store_true", help="Show rules that are enabled." + ) + self.enabled_filter_group.add_argument( + "--disabled", action="store_true", help="Show rules that are disabled." + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.pack: - kwargs['pack'] = args.pack + kwargs["pack"] = args.pack if args.action: - kwargs['action'] = args.action + kwargs["action"] = args.action if args.trigger: - kwargs['trigger'] = args.trigger + kwargs["trigger"] = args.trigger if args.enabled: - kwargs['enabled'] = True + kwargs["enabled"] = True if args.disabled: - kwargs['enabled'] = False + kwargs["enabled"] = False if args.iftt: # switch attr to display the trigger and action args.attr = self.display_attributes_iftt include_attributes = self._get_include_attributes(args=args) if include_attributes: - include_attributes = ','.join(include_attributes) - kwargs['params'] = {'include_attributes': include_attributes} + include_attributes = ",".join(include_attributes) + kwargs["params"] = {"include_attributes": include_attributes} return self.manager.query_with_count(limit=args.last, **kwargs) def run_and_print(self, args, **kwargs): instances, count = self.run(args, **kwargs) if args.json or args.yaml: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) else: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class RuleGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'uid', 'ref', 'pack', 'name', 'description', - 'enabled'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "uid", + "ref", + "pack", + "name", + "description", + "enabled", + ] class RuleUpdateCommand(resource.ContentPackResourceUpdateCommand): @@ -121,15 +165,29 @@ class RuleUpdateCommand(resource.ContentPackResourceUpdateCommand): class RuleEnableCommand(resource.ContentPackResourceEnableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'description', - 'enabled'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "enabled", + "description", + "enabled", + ] class RuleDisableCommand(resource.ContentPackResourceDisableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'description', - 'enabled'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "enabled", + "description", + "enabled", + ] class RuleDeleteCommand(resource.ContentPackResourceDeleteCommand): diff --git a/st2client/st2client/commands/rule_enforcement.py b/st2client/st2client/commands/rule_enforcement.py index ecebba2b071..dd624d4a725 100644 --- a/st2client/st2client/commands/rule_enforcement.py +++ b/st2client/st2client/commands/rule_enforcement.py @@ -22,24 +22,39 @@ class RuleEnforcementBranch(resource.ResourceBranch): - def __init__(self, description, app, subparsers, parent_parser=None): super(RuleEnforcementBranch, self).__init__( - models.RuleEnforcement, description, app, subparsers, + models.RuleEnforcement, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': RuleEnforcementListCommand, - 'get': RuleEnforcementGetCommand, - }) + "list": RuleEnforcementListCommand, + "get": RuleEnforcementGetCommand, + }, + ) class RuleEnforcementGetCommand(resource.ResourceGetCommand): - display_attributes = ['id', 'rule.ref', 'trigger_instance_id', - 'execution_id', 'failure_reason', 'enforced_at'] - attribute_display_order = ['id', 'rule.ref', 'trigger_instance_id', - 'execution_id', 'failure_reason', 'enforced_at'] - - pk_argument_name = 'id' + display_attributes = [ + "id", + "rule.ref", + "trigger_instance_id", + "execution_id", + "failure_reason", + "enforced_at", + ] + attribute_display_order = [ + "id", + "rule.ref", + "trigger_instance_id", + "execution_id", + "failure_reason", + "enforced_at", + ] + + pk_argument_name = "id" @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -48,84 +63,137 @@ def run(self, args, **kwargs): class RuleEnforcementListCommand(resource.ResourceCommand): - display_attributes = ['id', 'rule.ref', 'trigger_instance_id', - 'execution_id', 'enforced_at'] - attribute_display_order = ['id', 'rule.ref', 'trigger_instance_id', - 'execution_id', 'enforced_at'] - - attribute_transform_functions = { - 'enforced_at': format_isodate_for_user_timezone - } + display_attributes = [ + "id", + "rule.ref", + "trigger_instance_id", + "execution_id", + "enforced_at", + ] + attribute_display_order = [ + "id", + "rule.ref", + "trigger_instance_id", + "execution_id", + "enforced_at", + ] + + attribute_transform_functions = {"enforced_at": format_isodate_for_user_timezone} def __init__(self, resource, *args, **kwargs): self.default_limit = 50 super(RuleEnforcementListCommand, self).__init__( - resource, 'list', 'Get the list of the %s most recent %s.' % - (self.default_limit, resource.get_plural_display_name().lower()), - *args, **kwargs) + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() self.group = self.parser.add_argument_group() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) # Filter options - self.group.add_argument('--trigger-instance', - help='Trigger instance id to filter the list.') - - self.group.add_argument('--execution', - help='Execution id to filter the list.') - self.group.add_argument('--rule', - help='Rule ref to filter the list.') - - self.parser.add_argument('-tg', '--timestamp-gt', type=str, dest='timestamp_gt', - default=None, - help=('Only return enforcements with enforced_at ' - 'greater than the one provided. ' - 'Use time in the format 2000-01-01T12:00:00.000Z')) - self.parser.add_argument('-tl', '--timestamp-lt', type=str, dest='timestamp_lt', - default=None, - help=('Only return enforcements with enforced_at ' - 'lower than the one provided. ' - 'Use time in the format 2000-01-01T12:00:00.000Z')) + self.group.add_argument( + "--trigger-instance", help="Trigger instance id to filter the list." + ) + + self.group.add_argument("--execution", help="Execution id to filter the list.") + self.group.add_argument("--rule", help="Rule ref to filter the list.") + + self.parser.add_argument( + "-tg", + "--timestamp-gt", + type=str, + dest="timestamp_gt", + default=None, + help=( + "Only return enforcements with enforced_at " + "greater than the one provided. " + "Use time in the format 2000-01-01T12:00:00.000Z" + ), + ) + self.parser.add_argument( + "-tl", + "--timestamp-lt", + type=str, + dest="timestamp_lt", + default=None, + help=( + "Only return enforcements with enforced_at " + "lower than the one provided. " + "Use time in the format 2000-01-01T12:00:00.000Z" + ), + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.trigger_instance: - kwargs['trigger_instance'] = args.trigger_instance + kwargs["trigger_instance"] = args.trigger_instance if args.execution: - kwargs['execution'] = args.execution + kwargs["execution"] = args.execution if args.rule: - kwargs['rule_ref'] = args.rule + kwargs["rule_ref"] = args.rule if args.timestamp_gt: - kwargs['enforced_at_gt'] = args.timestamp_gt + kwargs["enforced_at_gt"] = args.timestamp_gt if args.timestamp_lt: - kwargs['enforced_at_lt'] = args.timestamp_lt + kwargs["enforced_at_lt"] = args.timestamp_lt return self.manager.query_with_count(limit=args.last, **kwargs) def run_and_print(self, args, **kwargs): instances, count = self.run(args, **kwargs) if args.json or args.yaml: - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) else: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) diff --git a/st2client/st2client/commands/sensor.py b/st2client/st2client/commands/sensor.py index 0d729c8c027..ca4cc335631 100644 --- a/st2client/st2client/commands/sensor.py +++ b/st2client/st2client/commands/sensor.py @@ -22,35 +22,67 @@ class SensorBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(SensorBranch, self).__init__( - Sensor, description, app, subparsers, + Sensor, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': SensorListCommand, - 'get': SensorGetCommand - }) + commands={"list": SensorListCommand, "get": SensorGetCommand}, + ) - self.commands['enable'] = SensorEnableCommand(self.resource, self.app, self.subparsers) - self.commands['disable'] = SensorDisableCommand(self.resource, self.app, self.subparsers) + self.commands["enable"] = SensorEnableCommand( + self.resource, self.app, self.subparsers + ) + self.commands["disable"] = SensorDisableCommand( + self.resource, self.app, self.subparsers + ) class SensorListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['ref', 'pack', 'description', 'enabled'] + display_attributes = ["ref", "pack", "description", "enabled"] class SensorGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'uid', 'ref', 'pack', 'name', 'enabled', 'entry_point', - 'artifact_uri', 'trigger_types'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "uid", + "ref", + "pack", + "name", + "enabled", + "entry_point", + "artifact_uri", + "trigger_types", + ] class SensorEnableCommand(resource.ContentPackResourceEnableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'poll_interval', - 'entry_point', 'artifact_uri', 'trigger_types'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "enabled", + "poll_interval", + "entry_point", + "artifact_uri", + "trigger_types", + ] class SensorDisableCommand(resource.ContentPackResourceDisableCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'enabled', 'poll_interval', - 'entry_point', 'artifact_uri', 'trigger_types'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "enabled", + "poll_interval", + "entry_point", + "artifact_uri", + "trigger_types", + ] diff --git a/st2client/st2client/commands/service_registry.py b/st2client/st2client/commands/service_registry.py index b609e051a91..6b9bff60b99 100644 --- a/st2client/st2client/commands/service_registry.py +++ b/st2client/st2client/commands/service_registry.py @@ -25,76 +25,87 @@ class ServiceRegistryBranch(commands.Branch): def __init__(self, description, app, subparsers, parent_parser=None): super(ServiceRegistryBranch, self).__init__( - 'service-registry', description, - app, subparsers, parent_parser=parent_parser) + "service-registry", + description, + app, + subparsers, + parent_parser=parent_parser, + ) self.subparsers = self.parser.add_subparsers( - help=('List of commands for managing service registry.')) + help=("List of commands for managing service registry.") + ) # Instantiate commands - args_groups = ['Manage service registry groups', self.app, self.subparsers] - args_members = ['Manage service registry members', self.app, self.subparsers] + args_groups = ["Manage service registry groups", self.app, self.subparsers] + args_members = ["Manage service registry members", self.app, self.subparsers] - self.commands['groups'] = ServiceRegistryGroupsBranch(*args_groups) - self.commands['members'] = ServiceRegistryMembersBranch(*args_members) + self.commands["groups"] = ServiceRegistryGroupsBranch(*args_groups) + self.commands["members"] = ServiceRegistryMembersBranch(*args_members) class ServiceRegistryGroupsBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(ServiceRegistryGroupsBranch, self).__init__( - ServiceRegistryGroup, description, app, subparsers, + ServiceRegistryGroup, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': ServiceRegistryListGroupsCommand, - 'get': NoopCommand - }) + commands={"list": ServiceRegistryListGroupsCommand, "get": NoopCommand}, + ) - del self.commands['get'] + del self.commands["get"] class ServiceRegistryMembersBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(ServiceRegistryMembersBranch, self).__init__( - ServiceRegistryMember, description, app, subparsers, + ServiceRegistryMember, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': ServiceRegistryListMembersCommand, - 'get': NoopCommand - }) + commands={"list": ServiceRegistryListMembersCommand, "get": NoopCommand}, + ) - del self.commands['get'] + del self.commands["get"] class ServiceRegistryListGroupsCommand(resource.ResourceListCommand): - display_attributes = ['group_id'] - attribute_display_order = ['group_id'] + display_attributes = ["group_id"] + attribute_display_order = ["group_id"] @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): - manager = self.app.client.managers['ServiceRegistryGroups'] + manager = self.app.client.managers["ServiceRegistryGroups"] groups = manager.list() return groups class ServiceRegistryListMembersCommand(resource.ResourceListCommand): - display_attributes = ['group_id', 'member_id', 'capabilities'] - attribute_display_order = ['group_id', 'member_id', 'capabilities'] + display_attributes = ["group_id", "member_id", "capabilities"] + attribute_display_order = ["group_id", "member_id", "capabilities"] def __init__(self, resource, *args, **kwargs): super(ServiceRegistryListMembersCommand, self).__init__( resource, *args, **kwargs ) - self.parser.add_argument('--group-id', dest='group_id', default=None, - help='If provided only retrieve members for the specified group.') + self.parser.add_argument( + "--group-id", + dest="group_id", + default=None, + help="If provided only retrieve members for the specified group.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): - groups_manager = self.app.client.managers['ServiceRegistryGroups'] - members_manager = self.app.client.managers['ServiceRegistryMembers'] + groups_manager = self.app.client.managers["ServiceRegistryGroups"] + members_manager = self.app.client.managers["ServiceRegistryMembers"] # If group ID is provided only retrieve members for that group, otherwise retrieve members # for all groups diff --git a/st2client/st2client/commands/timer.py b/st2client/st2client/commands/timer.py index e3fc9e223fa..c1833672912 100644 --- a/st2client/st2client/commands/timer.py +++ b/st2client/st2client/commands/timer.py @@ -22,30 +22,39 @@ class TimerBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(TimerBranch, self).__init__( - Timer, description, app, subparsers, + Timer, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': TimerListCommand, - 'get': TimerGetCommand - }) + commands={"list": TimerListCommand, "get": TimerGetCommand}, + ) class TimerListCommand(resource.ResourceListCommand): - display_attributes = ['id', 'uid', 'pack', 'name', 'type', 'parameters'] + display_attributes = ["id", "uid", "pack", "name", "type", "parameters"] def __init__(self, resource, *args, **kwargs): super(TimerListCommand, self).__init__(resource, *args, **kwargs) - self.parser.add_argument('-ty', '--timer-type', type=str, dest='timer_type', - help=("List %s type, example: 'core.st2.IntervalTimer', \ - 'core.st2.DateTimer', 'core.st2.CronTimer'." % - resource.get_plural_display_name().lower()), required=False) + self.parser.add_argument( + "-ty", + "--timer-type", + type=str, + dest="timer_type", + help=( + "List %s type, example: 'core.st2.IntervalTimer', \ + 'core.st2.DateTimer', 'core.st2.CronTimer'." + % resource.get_plural_display_name().lower() + ), + required=False, + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): if args.timer_type: - kwargs['timer_type'] = args.timer_type + kwargs["timer_type"] = args.timer_type if kwargs: return self.manager.query(**kwargs) @@ -54,5 +63,5 @@ def run(self, args, **kwargs): class TimerGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['type', 'pack', 'name', 'description', 'parameters'] + display_attributes = ["all"] + attribute_display_order = ["type", "pack", "name", "description", "parameters"] diff --git a/st2client/st2client/commands/trace.py b/st2client/st2client/commands/trace.py index b5e59c2cf13..ac8de676c26 100644 --- a/st2client/st2client/commands/trace.py +++ b/st2client/st2client/commands/trace.py @@ -23,53 +23,62 @@ from st2client.utils.date import format_isodate_for_user_timezone -TRACE_ATTRIBUTE_DISPLAY_ORDER = ['id', 'trace_tag', 'action_executions', 'rules', - 'trigger_instances', 'start_timestamp'] +TRACE_ATTRIBUTE_DISPLAY_ORDER = [ + "id", + "trace_tag", + "action_executions", + "rules", + "trigger_instances", + "start_timestamp", +] -TRACE_HEADER_DISPLAY_ORDER = ['id', 'trace_tag', 'start_timestamp'] +TRACE_HEADER_DISPLAY_ORDER = ["id", "trace_tag", "start_timestamp"] -TRACE_COMPONENT_DISPLAY_LABELS = ['id', 'type', 'ref', 'updated_at'] +TRACE_COMPONENT_DISPLAY_LABELS = ["id", "type", "ref", "updated_at"] -TRACE_DISPLAY_ATTRIBUTES = ['all'] +TRACE_DISPLAY_ATTRIBUTES = ["all"] TRIGGER_INSTANCE_DISPLAY_OPTIONS = [ - 'all', - 'trigger-instances', - 'trigger_instances', - 'triggerinstances', - 'triggers' + "all", + "trigger-instances", + "trigger_instances", + "triggerinstances", + "triggers", ] ACTION_EXECUTION_DISPLAY_OPTIONS = [ - 'all', - 'executions', - 'action-executions', - 'action_executions', - 'actionexecutions', - 'actions' + "all", + "executions", + "action-executions", + "action_executions", + "actionexecutions", + "actions", ] class TraceBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(TraceBranch, self).__init__( - Trace, description, app, subparsers, + Trace, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': TraceListCommand, - 'get': TraceGetCommand - }) + commands={"list": TraceListCommand, "get": TraceGetCommand}, + ) class SingleTraceDisplayMixin(object): - def print_trace_details(self, trace, args, **kwargs): - options = {'attributes': TRACE_ATTRIBUTE_DISPLAY_ORDER if args.json else - TRACE_HEADER_DISPLAY_ORDER} - options['json'] = args.json - options['yaml'] = args.yaml - options['attribute_transform_functions'] = self.attribute_transform_functions + options = { + "attributes": TRACE_ATTRIBUTE_DISPLAY_ORDER + if args.json + else TRACE_HEADER_DISPLAY_ORDER + } + options["json"] = args.json + options["yaml"] = args.yaml + options["attribute_transform_functions"] = self.attribute_transform_functions formatter = execution_formatter.ExecutionResult @@ -81,35 +90,63 @@ def print_trace_details(self, trace, args, **kwargs): components = [] if any(attr in args.attr for attr in TRIGGER_INSTANCE_DISPLAY_OPTIONS): - components.extend([Resource(**{'id': trigger_instance['object_id'], - 'type': TriggerInstance._alias.lower(), - 'ref': trigger_instance['ref'], - 'updated_at': trigger_instance['updated_at']}) - for trigger_instance in trace.trigger_instances]) - if any(attr in args.attr for attr in ['all', 'rules']): - components.extend([Resource(**{'id': rule['object_id'], - 'type': Rule._alias.lower(), - 'ref': rule['ref'], - 'updated_at': rule['updated_at']}) - for rule in trace.rules]) + components.extend( + [ + Resource( + **{ + "id": trigger_instance["object_id"], + "type": TriggerInstance._alias.lower(), + "ref": trigger_instance["ref"], + "updated_at": trigger_instance["updated_at"], + } + ) + for trigger_instance in trace.trigger_instances + ] + ) + if any(attr in args.attr for attr in ["all", "rules"]): + components.extend( + [ + Resource( + **{ + "id": rule["object_id"], + "type": Rule._alias.lower(), + "ref": rule["ref"], + "updated_at": rule["updated_at"], + } + ) + for rule in trace.rules + ] + ) if any(attr in args.attr for attr in ACTION_EXECUTION_DISPLAY_OPTIONS): - components.extend([Resource(**{'id': execution['object_id'], - 'type': Execution._alias.lower(), - 'ref': execution['ref'], - 'updated_at': execution['updated_at']}) - for execution in trace.action_executions]) + components.extend( + [ + Resource( + **{ + "id": execution["object_id"], + "type": Execution._alias.lower(), + "ref": execution["ref"], + "updated_at": execution["updated_at"], + } + ) + for execution in trace.action_executions + ] + ) if components: components.sort(key=lambda resource: resource.updated_at) - self.print_output(components, table.MultiColumnTable, - attributes=TRACE_COMPONENT_DISPLAY_LABELS, - json=args.json, yaml=args.yaml) + self.print_output( + components, + table.MultiColumnTable, + attributes=TRACE_COMPONENT_DISPLAY_LABELS, + json=args.json, + yaml=args.yaml, + ) class TraceListCommand(resource.ResourceCommand, SingleTraceDisplayMixin): - display_attributes = ['id', 'uid', 'trace_tag', 'start_timestamp'] + display_attributes = ["id", "uid", "trace_tag", "start_timestamp"] attribute_transform_functions = { - 'start_timestamp': format_isodate_for_user_timezone + "start_timestamp": format_isodate_for_user_timezone } attribute_display_order = TRACE_ATTRIBUTE_DISPLAY_ORDER @@ -119,55 +156,90 @@ def __init__(self, resource, *args, **kwargs): self.default_limit = 50 super(TraceListCommand, self).__init__( - resource, 'list', 'Get the list of the %s most recent %s.' % - (self.default_limit, resource.get_plural_display_name().lower()), - *args, **kwargs) + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() self.group = self.parser.add_mutually_exclusive_group() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) - self.parser.add_argument('-s', '--sort', type=str, dest='sort_order', - default='descending', - help=('Sort %s by start timestamp, ' - 'asc|ascending (earliest first) ' - 'or desc|descending (latest first)' % self.resource_name)) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) + self.parser.add_argument( + "-s", + "--sort", + type=str, + dest="sort_order", + default="descending", + help=( + "Sort %s by start timestamp, " + "asc|ascending (earliest first) " + "or desc|descending (latest first)" % self.resource_name + ), + ) # Filter options - self.group.add_argument('-c', '--trace-tag', help='Trace-tag to filter the list.') - self.group.add_argument('-e', '--execution', help='Execution to filter the list.') - self.group.add_argument('-r', '--rule', help='Rule to filter the list.') - self.group.add_argument('-g', '--trigger-instance', - help='TriggerInstance to filter the list.') + self.group.add_argument( + "-c", "--trace-tag", help="Trace-tag to filter the list." + ) + self.group.add_argument( + "-e", "--execution", help="Execution to filter the list." + ) + self.group.add_argument("-r", "--rule", help="Rule to filter the list.") + self.group.add_argument( + "-g", "--trigger-instance", help="TriggerInstance to filter the list." + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.trace_tag: - kwargs['trace_tag'] = args.trace_tag + kwargs["trace_tag"] = args.trace_tag if args.trigger_instance: - kwargs['trigger_instance'] = args.trigger_instance + kwargs["trigger_instance"] = args.trigger_instance if args.execution: - kwargs['execution'] = args.execution + kwargs["execution"] = args.execution if args.rule: - kwargs['rule'] = args.rule + kwargs["rule"] = args.rule if args.sort_order: - if args.sort_order in ['asc', 'ascending']: - kwargs['sort_asc'] = True - elif args.sort_order in ['desc', 'descending']: - kwargs['sort_desc'] = True + if args.sort_order in ["asc", "ascending"]: + kwargs["sort_asc"] = True + elif args.sort_order in ["desc", "descending"]: + kwargs["sort_desc"] = True return self.manager.query_with_count(limit=args.last, **kwargs) def run_and_print(self, args, **kwargs): @@ -177,7 +249,7 @@ def run_and_print(self, args, **kwargs): # For a single Trace we must include the components unless # user has overriden the attributes to display if args.attr == self.display_attributes: - args.attr = ['all'] + args.attr = ["all"] self.print_trace_details(trace=instances[0], args=args) if not args.json and not args.yaml: @@ -185,27 +257,36 @@ def run_and_print(self, args, **kwargs): table.SingleRowTable.note_box(self.resource_name, 1) else: if args.json or args.yaml: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) else: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + attribute_transform_functions=self.attribute_transform_functions, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class TraceGetCommand(resource.ResourceGetCommand, SingleTraceDisplayMixin): - display_attributes = ['all'] + display_attributes = ["all"] attribute_display_order = TRACE_ATTRIBUTE_DISPLAY_ORDER attribute_transform_functions = { - 'start_timestamp': format_isodate_for_user_timezone + "start_timestamp": format_isodate_for_user_timezone } - pk_argument_name = 'id' + pk_argument_name = "id" def __init__(self, resource, *args, **kwargs): super(TraceGetCommand, self).__init__(resource, *args, **kwargs) @@ -213,23 +294,36 @@ def __init__(self, resource, *args, **kwargs): # Causation chains self.causation_group = self.parser.add_mutually_exclusive_group() - self.causation_group.add_argument('-e', '--execution', - help='Execution to show causation chain.') - self.causation_group.add_argument('-r', '--rule', help='Rule to show causation chain.') - self.causation_group.add_argument('-g', '--trigger-instance', - help='TriggerInstance to show causation chain.') + self.causation_group.add_argument( + "-e", "--execution", help="Execution to show causation chain." + ) + self.causation_group.add_argument( + "-r", "--rule", help="Rule to show causation chain." + ) + self.causation_group.add_argument( + "-g", "--trigger-instance", help="TriggerInstance to show causation chain." + ) # display filter group self.display_filter_group = self.parser.add_argument_group() - self.display_filter_group.add_argument('--show-executions', action='store_true', - help='Only show executions.') - self.display_filter_group.add_argument('--show-rules', action='store_true', - help='Only show rules.') - self.display_filter_group.add_argument('--show-trigger-instances', action='store_true', - help='Only show trigger instances.') - self.display_filter_group.add_argument('-n', '--hide-noop-triggers', action='store_true', - help='Hide noop trigger instances.') + self.display_filter_group.add_argument( + "--show-executions", action="store_true", help="Only show executions." + ) + self.display_filter_group.add_argument( + "--show-rules", action="store_true", help="Only show rules." + ) + self.display_filter_group.add_argument( + "--show-trigger-instances", + action="store_true", + help="Only show trigger instances.", + ) + self.display_filter_group.add_argument( + "-n", + "--hide-noop-triggers", + action="store_true", + help="Hide noop trigger instances.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): @@ -243,7 +337,7 @@ def run_and_print(self, args, **kwargs): trace = self.run(args, **kwargs) except resource.ResourceNotFoundError: self.print_not_found(args.id) - raise OperationFailureException('Trace %s not found.' % (args.id)) + raise OperationFailureException("Trace %s not found." % (args.id)) # First filter for causation chains trace = self._filter_trace_components(trace=trace, args=args) # next filter for display purposes @@ -266,13 +360,13 @@ def _filter_trace_components(trace, args): # pick the right component type if args.execution: component_id = args.execution - component_type = 'action_execution' + component_type = "action_execution" elif args.rule: component_id = args.rule - component_type = 'rule' + component_type = "rule" elif args.trigger_instance: component_id = args.trigger_instance - component_type = 'trigger_instance' + component_type = "trigger_instance" # Initialize collection to use action_executions = [] @@ -284,13 +378,13 @@ def _filter_trace_components(trace, args): while search_target_found: components_list = [] - if component_type == 'action_execution': + if component_type == "action_execution": components_list = trace.action_executions to_update_list = action_executions - elif component_type == 'rule': + elif component_type == "rule": components_list = trace.rules to_update_list = rules - elif component_type == 'trigger_instance': + elif component_type == "trigger_instance": components_list = trace.trigger_instances to_update_list = trigger_instances # Look for search_target in the right collection and @@ -300,22 +394,25 @@ def _filter_trace_components(trace, args): # init to default value component_caused_by_id = None for component in components_list: - test_id = component['object_id'] + test_id = component["object_id"] if test_id == component_id: - caused_by = component.get('caused_by', {}) - component_id = caused_by.get('id', None) - component_type = caused_by.get('type', None) + caused_by = component.get("caused_by", {}) + component_id = caused_by.get("id", None) + component_type = caused_by.get("type", None) # If provided the component_caused_by_id must match as well. This is mostly # applicable for rules since the same rule may appear multiple times and can # only be distinguished by causing TriggerInstance. - if component_caused_by_id and component_caused_by_id != component_id: + if ( + component_caused_by_id + and component_caused_by_id != component_id + ): continue component_caused_by_id = None to_update_list.append(component) # In some cases the component_id and the causing component are combined to # provide the complete causation chain. Think rule + triggerinstance - if component_id and ':' in component_id: - component_id_split = component_id.split(':') + if component_id and ":" in component_id: + component_id_split = component_id.split(":") component_id = component_id_split[0] component_caused_by_id = component_id_split[1] search_target_found = True @@ -333,19 +430,21 @@ def _apply_display_filters(trace, args): should be displayed. """ # If all the filters are false nothing is to be filtered. - all_component_types = not(args.show_executions or - args.show_rules or - args.show_trigger_instances) + all_component_types = not ( + args.show_executions or args.show_rules or args.show_trigger_instances + ) # check if noop_triggers are to be hidden. This check applies whenever TriggerInstances # are to be shown. - if (all_component_types or args.show_trigger_instances) and args.hide_noop_triggers: + if ( + all_component_types or args.show_trigger_instances + ) and args.hide_noop_triggers: filtered_trigger_instances = [] for trigger_instance in trace.trigger_instances: is_noop_trigger_instance = True for rule in trace.rules: - caused_by_id = rule.get('caused_by', {}).get('id', None) - if caused_by_id == trigger_instance['object_id']: + caused_by_id = rule.get("caused_by", {}).get("id", None) + if caused_by_id == trigger_instance["object_id"]: is_noop_trigger_instance = False if not is_noop_trigger_instance: filtered_trigger_instances.append(trigger_instance) diff --git a/st2client/st2client/commands/trigger.py b/st2client/st2client/commands/trigger.py index 2fd966261ce..3a960fddc8c 100644 --- a/st2client/st2client/commands/trigger.py +++ b/st2client/st2client/commands/trigger.py @@ -23,29 +23,40 @@ class TriggerTypeBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(TriggerTypeBranch, self).__init__( - TriggerType, description, app, subparsers, + TriggerType, + description, + app, + subparsers, parent_parser=parent_parser, commands={ - 'list': TriggerTypeListCommand, - 'get': TriggerTypeGetCommand, - 'update': TriggerTypeUpdateCommand, - 'delete': TriggerTypeDeleteCommand - }) + "list": TriggerTypeListCommand, + "get": TriggerTypeGetCommand, + "update": TriggerTypeUpdateCommand, + "delete": TriggerTypeDeleteCommand, + }, + ) # Registers extended commands - self.commands['getspecs'] = TriggerTypeSubTriggerCommand( - self.resource, self.app, self.subparsers, - add_help=False) + self.commands["getspecs"] = TriggerTypeSubTriggerCommand( + self.resource, self.app, self.subparsers, add_help=False + ) class TriggerTypeListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['ref', 'pack', 'description'] + display_attributes = ["ref", "pack", "description"] class TriggerTypeGetCommand(resource.ContentPackResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'ref', 'pack', 'name', 'description', - 'parameters_schema', 'payload_schema'] + display_attributes = ["all"] + attribute_display_order = [ + "id", + "ref", + "pack", + "name", + "description", + "parameters_schema", + "payload_schema", + ] class TriggerTypeUpdateCommand(resource.ContentPackResourceUpdateCommand): @@ -57,29 +68,45 @@ class TriggerTypeDeleteCommand(resource.ContentPackResourceDeleteCommand): class TriggerTypeSubTriggerCommand(resource.ResourceCommand): - attribute_display_order = ['id', 'ref', 'context', 'parameters', 'status', - 'start_timestamp', 'result'] + attribute_display_order = [ + "id", + "ref", + "context", + "parameters", + "status", + "start_timestamp", + "result", + ] def __init__(self, resource, *args, **kwargs): super(TriggerTypeSubTriggerCommand, self).__init__( - resource, kwargs.pop('name', 'getspecs'), - 'Return Trigger Specifications of a Trigger.', - *args, **kwargs) - - self.parser.add_argument('ref', nargs='?', - metavar='ref', - help='Fully qualified name (pack.trigger_name) ' + - 'of the trigger.') - - self.parser.add_argument('-h', '--help', - action='store_true', dest='help', - help='Print usage for the given action.') + resource, + kwargs.pop("name", "getspecs"), + "Return Trigger Specifications of a Trigger.", + *args, + **kwargs, + ) + + self.parser.add_argument( + "ref", + nargs="?", + metavar="ref", + help="Fully qualified name (pack.trigger_name) " + "of the trigger.", + ) + + self.parser.add_argument( + "-h", + "--help", + action="store_true", + dest="help", + help="Print usage for the given action.", + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): - trigger_mgr = self.app.client.managers['Trigger'] - return trigger_mgr.query(**{'type': args.ref}) + trigger_mgr = self.app.client.managers["Trigger"] + return trigger_mgr.query(**{"type": args.ref}) @resource.add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): @@ -87,5 +114,6 @@ def run_and_print(self, args, **kwargs): self.parser.print_help() return instances = self.run(args, **kwargs) - self.print_output(instances, table.MultiColumnTable, - json=args.json, yaml=args.yaml) + self.print_output( + instances, table.MultiColumnTable, json=args.json, yaml=args.yaml + ) diff --git a/st2client/st2client/commands/triggerinstance.py b/st2client/st2client/commands/triggerinstance.py index 2ac4da73da0..12966fea924 100644 --- a/st2client/st2client/commands/triggerinstance.py +++ b/st2client/st2client/commands/triggerinstance.py @@ -25,17 +25,23 @@ class TriggerInstanceResendCommand(resource.ResourceCommand): def __init__(self, resource, *args, **kwargs): super(TriggerInstanceResendCommand, self).__init__( - resource, kwargs.pop('name', 're-emit'), - 'Re-emit a particular trigger instance.', - *args, **kwargs) + resource, + kwargs.pop("name", "re-emit"), + "Re-emit a particular trigger instance.", + *args, + **kwargs, + ) - self.parser.add_argument('id', nargs='?', - metavar='id', - help='ID of trigger instance to re-emit.') self.parser.add_argument( - '-h', '--help', - action='store_true', dest='help', - help='Print usage for the given command.') + "id", nargs="?", metavar="id", help="ID of trigger instance to re-emit." + ) + self.parser.add_argument( + "-h", + "--help", + action="store_true", + dest="help", + help="Print usage for the given command.", + ) def run(self, args, **kwargs): return self.manager.re_emit(args.id) @@ -43,29 +49,35 @@ def run(self, args, **kwargs): @resource.add_auth_token_to_kwargs_from_cli def run_and_print(self, args, **kwargs): ret = self.run(args, **kwargs) - if 'message' in ret: - print(ret['message']) + if "message" in ret: + print(ret["message"]) class TriggerInstanceBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(TriggerInstanceBranch, self).__init__( - TriggerInstance, description, app, subparsers, - parent_parser=parent_parser, read_only=True, + TriggerInstance, + description, + app, + subparsers, + parent_parser=parent_parser, + read_only=True, commands={ - 'list': TriggerInstanceListCommand, - 'get': TriggerInstanceGetCommand - }) + "list": TriggerInstanceListCommand, + "get": TriggerInstanceGetCommand, + }, + ) - self.commands['re-emit'] = TriggerInstanceResendCommand(self.resource, self.app, - self.subparsers, add_help=False) + self.commands["re-emit"] = TriggerInstanceResendCommand( + self.resource, self.app, self.subparsers, add_help=False + ) class TriggerInstanceListCommand(resource.ResourceViewCommand): - display_attributes = ['id', 'trigger', 'occurrence_time', 'status'] + display_attributes = ["id", "trigger", "occurrence_time", "status"] attribute_transform_functions = { - 'occurrence_time': format_isodate_for_user_timezone + "occurrence_time": format_isodate_for_user_timezone } def __init__(self, resource, *args, **kwargs): @@ -73,83 +85,133 @@ def __init__(self, resource, *args, **kwargs): self.default_limit = 50 super(TriggerInstanceListCommand, self).__init__( - resource, 'list', 'Get the list of the %s most recent %s.' % - (self.default_limit, resource.get_plural_display_name().lower()), - *args, **kwargs) + resource, + "list", + "Get the list of the %s most recent %s." + % (self.default_limit, resource.get_plural_display_name().lower()), + *args, + **kwargs, + ) self.resource_name = resource.get_plural_display_name().lower() self.group = self.parser.add_argument_group() - self.parser.add_argument('-n', '--last', type=int, dest='last', - default=self.default_limit, - help=('List N most recent %s. Use -n -1 to fetch the full result \ - set.' % self.resource_name)) + self.parser.add_argument( + "-n", + "--last", + type=int, + dest="last", + default=self.default_limit, + help=( + "List N most recent %s. Use -n -1 to fetch the full result \ + set." + % self.resource_name + ), + ) # Filter options - self.group.add_argument('--trigger', help='Trigger reference to filter the list.') - - self.parser.add_argument('-tg', '--timestamp-gt', type=str, dest='timestamp_gt', - default=None, - help=('Only return trigger instances with occurrence_time ' - 'greater than the one provided. ' - 'Use time in the format 2000-01-01T12:00:00.000Z')) - self.parser.add_argument('-tl', '--timestamp-lt', type=str, dest='timestamp_lt', - default=None, - help=('Only return trigger instances with timestamp ' - 'lower than the one provided. ' - 'Use time in the format 2000-01-01T12:00:00.000Z')) - - self.group.add_argument('--status', - help='Can be pending, processing, processed or processing_failed.') + self.group.add_argument( + "--trigger", help="Trigger reference to filter the list." + ) + + self.parser.add_argument( + "-tg", + "--timestamp-gt", + type=str, + dest="timestamp_gt", + default=None, + help=( + "Only return trigger instances with occurrence_time " + "greater than the one provided. " + "Use time in the format 2000-01-01T12:00:00.000Z" + ), + ) + self.parser.add_argument( + "-tl", + "--timestamp-lt", + type=str, + dest="timestamp_lt", + default=None, + help=( + "Only return trigger instances with timestamp " + "lower than the one provided. " + "Use time in the format 2000-01-01T12:00:00.000Z" + ), + ) + + self.group.add_argument( + "--status", + help="Can be pending, processing, processed or processing_failed.", + ) # Display options - self.parser.add_argument('-a', '--attr', nargs='+', - default=self.display_attributes, - help=('List of attributes to include in the ' - 'output. "all" will return all ' - 'attributes.')) - self.parser.add_argument('-w', '--width', nargs='+', type=int, - default=None, - help=('Set the width of columns in output.')) + self.parser.add_argument( + "-a", + "--attr", + nargs="+", + default=self.display_attributes, + help=( + "List of attributes to include in the " + 'output. "all" will return all ' + "attributes." + ), + ) + self.parser.add_argument( + "-w", + "--width", + nargs="+", + type=int, + default=None, + help=("Set the width of columns in output."), + ) @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): # Filtering options if args.trigger: - kwargs['trigger'] = args.trigger + kwargs["trigger"] = args.trigger if args.timestamp_gt: - kwargs['timestamp_gt'] = args.timestamp_gt + kwargs["timestamp_gt"] = args.timestamp_gt if args.timestamp_lt: - kwargs['timestamp_lt'] = args.timestamp_lt + kwargs["timestamp_lt"] = args.timestamp_lt if args.status: - kwargs['status'] = args.status + kwargs["status"] = args.status include_attributes = self._get_include_attributes(args=args) if include_attributes: - include_attributes = ','.join(include_attributes) - kwargs['params'] = {'include_attributes': include_attributes} + include_attributes = ",".join(include_attributes) + kwargs["params"] = {"include_attributes": include_attributes} return self.manager.query_with_count(limit=args.last, **kwargs) def run_and_print(self, args, **kwargs): instances, count = self.run(args, **kwargs) if args.json or args.yaml: - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + attribute_transform_functions=self.attribute_transform_functions, + ) else: - self.print_output(reversed(instances), table.MultiColumnTable, - attributes=args.attr, widths=args.width, - attribute_transform_functions=self.attribute_transform_functions) + self.print_output( + reversed(instances), + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + attribute_transform_functions=self.attribute_transform_functions, + ) if args.last and count and count > args.last: table.SingleRowTable.note_box(self.resource_name, args.last) class TriggerInstanceGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['id', 'trigger', 'occurrence_time', 'payload'] + display_attributes = ["all"] + attribute_display_order = ["id", "trigger", "occurrence_time", "payload"] - pk_argument_name = 'id' + pk_argument_name = "id" @resource.add_auth_token_to_kwargs_from_cli def run(self, args, **kwargs): diff --git a/st2client/st2client/commands/webhook.py b/st2client/st2client/commands/webhook.py index 3a483445006..4b555ac59f5 100644 --- a/st2client/st2client/commands/webhook.py +++ b/st2client/st2client/commands/webhook.py @@ -23,37 +23,47 @@ class WebhookBranch(resource.ResourceBranch): def __init__(self, description, app, subparsers, parent_parser=None): super(WebhookBranch, self).__init__( - Webhook, description, app, subparsers, + Webhook, + description, + app, + subparsers, parent_parser=parent_parser, read_only=True, - commands={ - 'list': WebhookListCommand, - 'get': WebhookGetCommand - }) + commands={"list": WebhookListCommand, "get": WebhookGetCommand}, + ) class WebhookListCommand(resource.ContentPackResourceListCommand): - display_attributes = ['url', 'type', 'description'] + display_attributes = ["url", "type", "description"] def run_and_print(self, args, **kwargs): instances = self.run(args, **kwargs) for instance in instances: - instance.url = instance.parameters['url'] + instance.url = instance.parameters["url"] instances = sorted(instances, key=lambda k: k.url) if args.json or args.yaml: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width, - json=args.json, yaml=args.yaml) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + json=args.json, + yaml=args.yaml, + ) else: - self.print_output(instances, table.MultiColumnTable, - attributes=args.attr, widths=args.width) + self.print_output( + instances, + table.MultiColumnTable, + attributes=args.attr, + widths=args.width, + ) class WebhookGetCommand(resource.ResourceGetCommand): - display_attributes = ['all'] - attribute_display_order = ['type', 'description'] + display_attributes = ["all"] + attribute_display_order = ["type", "description"] - pk_argument_name = 'url' + pk_argument_name = "url" diff --git a/st2client/st2client/commands/workflow.py b/st2client/st2client/commands/workflow.py index 5348f767068..57f9f52a5f9 100644 --- a/st2client/st2client/commands/workflow.py +++ b/st2client/st2client/commands/workflow.py @@ -1,4 +1,3 @@ - # Copyright 2020 The StackStorm Authors. # Copyright 2019 Extreme Networks, Inc. # @@ -28,26 +27,25 @@ class WorkflowBranch(commands.Branch): - def __init__(self, description, app, subparsers, parent_parser=None): super(WorkflowBranch, self).__init__( - 'workflow', description, app, subparsers, - parent_parser=parent_parser + "workflow", description, app, subparsers, parent_parser=parent_parser ) # Add subparser to register subcommands for managing workflows. - help_message = 'List of commands for managing workflows.' + help_message = "List of commands for managing workflows." self.subparsers = self.parser.add_subparsers(help=help_message) # Register workflow commands. - self.commands['inspect'] = WorkflowInspectionCommand(self.app, self.subparsers) + self.commands["inspect"] = WorkflowInspectionCommand(self.app, self.subparsers) class WorkflowInspectionCommand(commands.Command): - def __init__(self, *args, **kwargs): - name = 'inspect' - description = 'Inspect workflow definition and return the list of errors if any.' + name = "inspect" + description = ( + "Inspect workflow definition and return the list of errors if any." + ) args = tuple([name, description] + list(args)) super(WorkflowInspectionCommand, self).__init__(*args, **kwargs) @@ -55,27 +53,25 @@ def __init__(self, *args, **kwargs): arg_group = self.parser.add_mutually_exclusive_group() arg_group.add_argument( - '--file', - dest='file', - help='Local file path to the workflow definition.' + "--file", dest="file", help="Local file path to the workflow definition." ) arg_group.add_argument( - '--action', - dest='action', - help='Reference name for the registered action. This option works only if the file ' - 'referenced by the entry point is installed locally under /opt/stackstorm/packs.' + "--action", + dest="action", + help="Reference name for the registered action. This option works only if the file " + "referenced by the entry point is installed locally under /opt/stackstorm/packs.", ) @property def manager(self): - return self.app.client.managers['Workflow'] + return self.app.client.managers["Workflow"] def get_file_content(self, file_path): if not os.path.isfile(file_path): raise Exception('File "%s" does not exist on local system.' % file_path) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: content = f.read() return content @@ -88,13 +84,18 @@ def run(self, args, **kwargs): # is executed locally where the content is stored. if not wf_def_file: action_ref = args.action - action_manager = self.app.client.managers['Action'] + action_manager = self.app.client.managers["Action"] action = action_manager.get_by_ref_or_id(ref_or_id=action_ref) if not action: raise Exception('Unable to identify action "%s".' % action_ref) - wf_def_file = '/opt/stackstorm/packs/' + action.pack + '/actions/' + action.entry_point + wf_def_file = ( + "/opt/stackstorm/packs/" + + action.pack + + "/actions/" + + action.entry_point + ) wf_def = self.get_file_content(wf_def_file) @@ -105,10 +106,10 @@ def run_and_print(self, args, **kwargs): errors = self.run(args, **kwargs) if not isinstance(errors, list): - raise TypeError('The inspection result is not type of list: %s' % errors) + raise TypeError("The inspection result is not type of list: %s" % errors) if not errors: - print('No errors found in workflow definition.') + print("No errors found in workflow definition.") return print(yaml.safe_dump(errors, default_flow_style=False, allow_unicode=True)) diff --git a/st2client/st2client/config.py b/st2client/st2client/config.py index 5de500aec22..c002c7f4149 100644 --- a/st2client/st2client/config.py +++ b/st2client/st2client/config.py @@ -13,10 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'get_config', - 'set_config' -] +__all__ = ["get_config", "set_config"] # Stores parsed config dictionary CONFIG = {} diff --git a/st2client/st2client/config_parser.py b/st2client/st2client/config_parser.py index e5095df3e2f..ca88209f87f 100644 --- a/st2client/st2client/config_parser.py +++ b/st2client/st2client/config_parser.py @@ -31,88 +31,38 @@ __all__ = [ - 'CLIConfigParser', - - 'ST2_CONFIG_DIRECTORY', - 'ST2_CONFIG_PATH', - - 'CONFIG_DEFAULT_VALUES' + "CLIConfigParser", + "ST2_CONFIG_DIRECTORY", + "ST2_CONFIG_PATH", + "CONFIG_DEFAULT_VALUES", ] -ST2_CONFIG_DIRECTORY = '~/.st2' +ST2_CONFIG_DIRECTORY = "~/.st2" ST2_CONFIG_DIRECTORY = os.path.abspath(os.path.expanduser(ST2_CONFIG_DIRECTORY)) -ST2_CONFIG_PATH = os.path.abspath(os.path.join(ST2_CONFIG_DIRECTORY, 'config')) +ST2_CONFIG_PATH = os.path.abspath(os.path.join(ST2_CONFIG_DIRECTORY, "config")) CONFIG_FILE_OPTIONS = { - 'general': { - 'base_url': { - 'type': 'string', - 'default': None - }, - 'api_version': { - 'type': 'string', - 'default': None - }, - 'cacert': { - 'type': 'string', - 'default': None - }, - 'silence_ssl_warnings': { - 'type': 'bool', - 'default': False - }, - 'silence_schema_output': { - 'type': 'bool', - 'default': True - } - }, - 'cli': { - 'debug': { - 'type': 'bool', - 'default': False - }, - 'cache_token': { - 'type': 'boolean', - 'default': True - }, - 'timezone': { - 'type': 'string', - 'default': 'UTC' - } - }, - 'credentials': { - 'username': { - 'type': 'string', - 'default': None - }, - 'password': { - 'type': 'string', - 'default': None - }, - 'api_key': { - 'type': 'string', - 'default': None - } + "general": { + "base_url": {"type": "string", "default": None}, + "api_version": {"type": "string", "default": None}, + "cacert": {"type": "string", "default": None}, + "silence_ssl_warnings": {"type": "bool", "default": False}, + "silence_schema_output": {"type": "bool", "default": True}, }, - 'api': { - 'url': { - 'type': 'string', - 'default': None - } + "cli": { + "debug": {"type": "bool", "default": False}, + "cache_token": {"type": "boolean", "default": True}, + "timezone": {"type": "string", "default": "UTC"}, }, - 'auth': { - 'url': { - 'type': 'string', - 'default': None - } + "credentials": { + "username": {"type": "string", "default": None}, + "password": {"type": "string", "default": None}, + "api_key": {"type": "string", "default": None}, }, - 'stream': { - 'url': { - 'type': 'string', - 'default': None - } - } + "api": {"url": {"type": "string", "default": None}}, + "auth": {"url": {"type": "string", "default": None}}, + "stream": {"url": {"type": "string", "default": None}}, } CONFIG_DEFAULT_VALUES = {} @@ -121,13 +71,18 @@ CONFIG_DEFAULT_VALUES[section] = {} for key, options in six.iteritems(keys): - default_value = options['default'] + default_value = options["default"] CONFIG_DEFAULT_VALUES[section][key] = default_value class CLIConfigParser(object): - def __init__(self, config_file_path, validate_config_exists=True, - validate_config_permissions=True, log=None): + def __init__( + self, + config_file_path, + validate_config_exists=True, + validate_config_permissions=True, + log=None, + ): if validate_config_exists and not os.path.isfile(config_file_path): raise ValueError('Config file "%s" doesn\'t exist') @@ -158,37 +113,40 @@ def parse(self): if bool(os.stat(config_dir_path).st_mode & 0o7): self.LOG.warn( "The StackStorm configuration directory permissions are " - "insecure (too permissive): others have access.") + "insecure (too permissive): others have access." + ) # Make sure the setgid bit is set on the directory if not bool(os.stat(config_dir_path).st_mode & 0o2000): self.LOG.info( "The SGID bit is not set on the StackStorm configuration " - "directory.") + "directory." + ) # Make sure the file permissions == 0o660 if bool(os.stat(self.config_file_path).st_mode & 0o7): self.LOG.warn( "The StackStorm configuration file permissions are " - "insecure: others have access.") + "insecure: others have access." + ) config = ConfigParser() - with io.open(self.config_file_path, 'r', encoding='utf8') as fp: + with io.open(self.config_file_path, "r", encoding="utf8") as fp: config.readfp(fp) for section, keys in six.iteritems(CONFIG_FILE_OPTIONS): for key, options in six.iteritems(keys): - key_type = options['type'] - key_default_value = options['default'] + key_type = options["type"] + key_default_value = options["default"] if config.has_option(section, key): - if key_type in ['str', 'string']: + if key_type in ["str", "string"]: get_func = config.get - elif key_type in ['int', 'integer']: + elif key_type in ["int", "integer"]: get_func = config.getint - elif key_type in ['float']: + elif key_type in ["float"]: get_func = config.getfloat - elif key_type in ['bool', 'boolean']: + elif key_type in ["bool", "boolean"]: get_func = config.getboolean else: msg = 'Invalid type "%s" for option "%s"' % (key_type, key) diff --git a/st2client/st2client/exceptions/base.py b/st2client/st2client/exceptions/base.py index f9cd3436658..97c9bb8a09f 100644 --- a/st2client/st2client/exceptions/base.py +++ b/st2client/st2client/exceptions/base.py @@ -16,7 +16,8 @@ class StackStormCLIBaseException(Exception): """ - The root of the exception class hierarchy for all - StackStorm CLI exceptions. + The root of the exception class hierarchy for all + StackStorm CLI exceptions. """ + pass diff --git a/st2client/st2client/formatters/__init__.py b/st2client/st2client/formatters/__init__.py index dcaaee3ee15..e0d8e5f718a 100644 --- a/st2client/st2client/formatters/__init__.py +++ b/st2client/st2client/formatters/__init__.py @@ -25,10 +25,8 @@ class Formatter(six.with_metaclass(abc.ABCMeta, object)): - @classmethod @abc.abstractmethod def format(cls, subject, *args, **kwargs): - """Override this method to customize output format for the subject. - """ + """Override this method to customize output format for the subject.""" raise NotImplementedError diff --git a/st2client/st2client/formatters/doc.py b/st2client/st2client/formatters/doc.py index ea2218dec4f..5f6ca96dce1 100644 --- a/st2client/st2client/formatters/doc.py +++ b/st2client/st2client/formatters/doc.py @@ -23,10 +23,7 @@ from st2client import formatters from st2client.utils import jsutil -__all__ = [ - 'JsonFormatter', - 'YAMLFormatter' -] +__all__ = ["JsonFormatter", "YAMLFormatter"] LOG = logging.getLogger(__name__) @@ -34,25 +31,34 @@ class BaseFormatter(formatters.Formatter): @classmethod def format(self, subject, *args, **kwargs): - attributes = kwargs.get('attributes', None) + attributes = kwargs.get("attributes", None) if type(subject) is str: subject = json.loads(subject) - elif not isinstance(subject, (list, tuple)) and not hasattr(subject, '__iter__'): + elif not isinstance(subject, (list, tuple)) and not hasattr( + subject, "__iter__" + ): doc = subject if isinstance(subject, dict) else subject.__dict__ - keys = list(doc.keys()) if not attributes or 'all' in attributes else attributes + keys = ( + list(doc.keys()) + if not attributes or "all" in attributes + else attributes + ) docs = jsutil.get_kvps(doc, keys) else: docs = [] for item in subject: doc = item if isinstance(item, dict) else item.__dict__ - keys = list(doc.keys()) if not attributes or 'all' in attributes else attributes + keys = ( + list(doc.keys()) + if not attributes or "all" in attributes + else attributes + ) docs.append(jsutil.get_kvps(doc, keys)) return docs class JsonFormatter(BaseFormatter): - @classmethod def format(self, subject, *args, **kwargs): docs = BaseFormatter.format(subject, *args, **kwargs) @@ -60,7 +66,6 @@ def format(self, subject, *args, **kwargs): class YAMLFormatter(BaseFormatter): - @classmethod def format(self, subject, *args, **kwargs): docs = BaseFormatter.format(subject, *args, **kwargs) diff --git a/st2client/st2client/formatters/execution.py b/st2client/st2client/formatters/execution.py index 69da8cdb416..b52527d4de2 100644 --- a/st2client/st2client/formatters/execution.py +++ b/st2client/st2client/formatters/execution.py @@ -32,32 +32,31 @@ LOG = logging.getLogger(__name__) -PLATFORM_MAXINT = 2 ** (struct.Struct('i').size * 8 - 1) - 1 +PLATFORM_MAXINT = 2 ** (struct.Struct("i").size * 8 - 1) - 1 def _print_bordered(text): - lines = text.split('\n') + lines = text.split("\n") width = max(len(s) for s in lines) + 2 - res = ['\n+' + '-' * width + '+'] + res = ["\n+" + "-" * width + "+"] for s in lines: - res.append('| ' + (s + ' ' * width)[:width - 2] + ' |') - res.append('+' + '-' * width + '+') - return '\n'.join(res) + res.append("| " + (s + " " * width)[: width - 2] + " |") + res.append("+" + "-" * width + "+") + return "\n".join(res) class ExecutionResult(formatters.Formatter): - @classmethod def format(cls, entry, *args, **kwargs): - attrs = kwargs.get('attributes', []) - attribute_transform_functions = kwargs.get('attribute_transform_functions', {}) - key = kwargs.get('key', None) + attrs = kwargs.get("attributes", []) + attribute_transform_functions = kwargs.get("attribute_transform_functions", {}) + key = kwargs.get("key", None) if key: output = jsutil.get_value(entry.result, key) else: # drop entry to the dict so that jsutil can operate entry = vars(entry) - output = '' + output = "" for attr in attrs: value = jsutil.get_value(entry, attr) value = strutil.strip_carriage_returns(strutil.unescape(value)) @@ -65,8 +64,12 @@ def format(cls, entry, *args, **kwargs): # if the leading character is objectish start and last character is objectish # end but the string isn't supposed to be a object. Try/Except will catch # this for now, but this should be improved. - if (isinstance(value, six.string_types) and len(value) > 0 and - value[0] in ['{', '['] and value[len(value) - 1] in ['}', ']']): + if ( + isinstance(value, six.string_types) + and len(value) > 0 + and value[0] in ["{", "["] + and value[len(value) - 1] in ["}", "]"] + ): try: new_value = ast.literal_eval(value) except: @@ -79,31 +82,40 @@ def format(cls, entry, *args, **kwargs): # 2. Drop the trailing newline # 3. Set width to maxint so pyyaml does not split text. Anything longer # and likely we will see other issues like storage :P. - formatted_value = yaml.safe_dump({attr: value}, - default_flow_style=False, - width=PLATFORM_MAXINT, - indent=2)[len(attr) + 2:-1] - value = ('\n' if isinstance(value, dict) else '') + formatted_value + formatted_value = yaml.safe_dump( + {attr: value}, + default_flow_style=False, + width=PLATFORM_MAXINT, + indent=2, + )[len(attr) + 2 : -1] + value = ("\n" if isinstance(value, dict) else "") + formatted_value value = strutil.dedupe_newlines(value) # transform the value of our attribute so things like 'status' # and 'timestamp' are formatted nicely - transform_function = attribute_transform_functions.get(attr, - lambda value: value) + transform_function = attribute_transform_functions.get( + attr, lambda value: value + ) value = transform_function(value=value) - output += ('\n' if output else '') + '%s: %s' % \ - (DisplayColors.colorize(attr, DisplayColors.BLUE), value) + output += ("\n" if output else "") + "%s: %s" % ( + DisplayColors.colorize(attr, DisplayColors.BLUE), + value, + ) - output_schema = entry.get('action', {}).get('output_schema') - schema_check = get_config()['general']['silence_schema_output'] - if not output_schema and kwargs.get('with_schema'): + output_schema = entry.get("action", {}).get("output_schema") + schema_check = get_config()["general"]["silence_schema_output"] + if not output_schema and kwargs.get("with_schema"): rendered_schema = { - 'output_schema': schema.render_output_schema_from_output(entry['result']) + "output_schema": schema.render_output_schema_from_output( + entry["result"] + ) } - rendered_schema = yaml.safe_dump(rendered_schema, default_flow_style=False) - output += '\n' + rendered_schema = yaml.safe_dump( + rendered_schema, default_flow_style=False + ) + output += "\n" output += _print_bordered( "Based on the action output the following inferred schema was built:" "\n\n" @@ -120,7 +132,11 @@ def format(cls, entry, *args, **kwargs): else: # Assume Python 2 try: - result = strutil.unescape(str(output)).decode('unicode_escape').encode('utf-8') + result = ( + strutil.unescape(str(output)) + .decode("unicode_escape") + .encode("utf-8") + ) except UnicodeDecodeError: # String contains a value which is not an unicode escape sequence, ignore the error result = strutil.unescape(str(output)) diff --git a/st2client/st2client/formatters/table.py b/st2client/st2client/formatters/table.py index 404469ce0ea..91cc59e009a 100644 --- a/st2client/st2client/formatters/table.py +++ b/st2client/st2client/formatters/table.py @@ -40,40 +40,38 @@ MIN_COL_WIDTH = 5 # Default attribute display order to use if one is not provided -DEFAULT_ATTRIBUTE_DISPLAY_ORDER = ['id', 'name', 'pack', 'description'] +DEFAULT_ATTRIBUTE_DISPLAY_ORDER = ["id", "name", "pack", "description"] # Attributes which contain bash escape sequences - we can't split those across multiple lines # since things would break COLORIZED_ATTRIBUTES = { - 'status': { - 'col_width': 24 # Note: len('succeed' + ' (XXXX elapsed)') <= 24 - } + "status": {"col_width": 24} # Note: len('succeed' + ' (XXXX elapsed)') <= 24 } class MultiColumnTable(formatters.Formatter): - def __init__(self): self._table_width = 0 @classmethod def format(cls, entries, *args, **kwargs): - attributes = kwargs.get('attributes', []) - attribute_transform_functions = kwargs.get('attribute_transform_functions', {}) - widths = kwargs.get('widths', []) + attributes = kwargs.get("attributes", []) + attribute_transform_functions = kwargs.get("attribute_transform_functions", {}) + widths = kwargs.get("widths", []) widths = widths or [] if not widths and attributes: # Dynamically calculate column size based on the terminal size cols = get_terminal_size_columns() - if attributes[0] == 'id': + if attributes[0] == "id": # consume iterator and save as entries so collection is accessible later. entries = [e for e in entries] # first column contains id, make sure it's not broken up - first_col_width = cls._get_required_column_width(values=[e.id for e in entries], - minimum_width=MIN_ID_COL_WIDTH) - cols = (cols - first_col_width) + first_col_width = cls._get_required_column_width( + values=[e.id for e in entries], minimum_width=MIN_ID_COL_WIDTH + ) + cols = cols - first_col_width col_width = int(math.floor((cols / len(attributes)))) else: col_width = int(math.floor((cols / len(attributes)))) @@ -88,14 +86,16 @@ def format(cls, entries, *args, **kwargs): continue if attribute_name in COLORIZED_ATTRIBUTES: - current_col_width = COLORIZED_ATTRIBUTES[attribute_name]['col_width'] - subtract += (current_col_width - col_width) + current_col_width = COLORIZED_ATTRIBUTES[attribute_name][ + "col_width" + ] + subtract += current_col_width - col_width else: # Make sure we subtract the added width from the last column so we account # for the fixed width columns and make sure table is not wider than the # terminal width. if index == (len(attributes) - 1) and subtract: - current_col_width = (col_width - subtract) + current_col_width = col_width - subtract if current_col_width <= MIN_COL_WIDTH: # Make sure column width is always grater than MIN_COL_WIDTH @@ -105,12 +105,14 @@ def format(cls, entries, *args, **kwargs): widths.append(current_col_width) - if not attributes or 'all' in attributes: + if not attributes or "all" in attributes: entries = list(entries) if entries else [] if len(entries) >= 1: attributes = list(entries[0].__dict__.keys()) - attributes = sorted([attr for attr in attributes if not attr.startswith('_')]) + attributes = sorted( + [attr for attr in attributes if not attr.startswith("_")] + ) else: # There are no entries so we can't infer available attributes attributes = [] @@ -123,8 +125,7 @@ def format(cls, entries, *args, **kwargs): # If only 1 width value is provided then # apply it to all columns else fix at 28. width = widths[0] if len(widths) == 1 else 28 - columns = zip(attributes, - [width for i in range(0, len(attributes))]) + columns = zip(attributes, [width for i in range(0, len(attributes))]) # Format result to table. table = PrettyTable() @@ -132,14 +133,14 @@ def format(cls, entries, *args, **kwargs): table.field_names.append(column[0]) table.max_width[column[0]] = column[1] table.padding_width = 1 - table.align = 'l' - table.valign = 't' + table.align = "l" + table.valign = "t" for entry in entries: # TODO: Improve getting values of nested dict. values = [] for field_name in table.field_names: - if '.' in field_name: - field_names = field_name.split('.') + if "." in field_name: + field_names = field_name.split(".") value = getattr(entry, field_names.pop(0), {}) for name in field_names: value = cls._get_field_value(value, name) @@ -149,8 +150,9 @@ def format(cls, entries, *args, **kwargs): values.append(value) else: value = cls._get_simple_field_value(entry, field_name) - transform_function = attribute_transform_functions.get(field_name, - lambda value: value) + transform_function = attribute_transform_functions.get( + field_name, lambda value: value + ) value = transform_function(value=value) value = strutil.strip_carriage_returns(strutil.unescape(value)) values.append(value) @@ -177,14 +179,14 @@ def _get_simple_field_value(entry, field_name): """ Format a value for a simple field. """ - value = getattr(entry, field_name, '') + value = getattr(entry, field_name, "") if isinstance(value, (list, tuple)): if len(value) == 0: - value = '' + value = "" elif isinstance(value[0], (str, six.text_type)): # List contains simple string values, format it as comma # separated string - value = ', '.join(value) + value = ", ".join(value) return value @@ -192,10 +194,10 @@ def _get_simple_field_value(entry, field_name): def _get_field_value(value, field_name): r_val = value.get(field_name, None) if r_val is None: - return '' + return "" if isinstance(r_val, list) or isinstance(r_val, dict): - return r_val if len(r_val) > 0 else '' + return r_val if len(r_val) > 0 else "" return r_val @staticmethod @@ -203,7 +205,7 @@ def _get_friendly_column_name(name): if not name: return None - friendly_name = name.replace('_', ' ').replace('.', ' ').capitalize() + friendly_name = name.replace("_", " ").replace(".", " ").capitalize() return friendly_name @staticmethod @@ -213,33 +215,34 @@ def _get_required_column_width(values, minimum_width=0): class PropertyValueTable(formatters.Formatter): - @classmethod def format(cls, subject, *args, **kwargs): - attributes = kwargs.get('attributes', None) - attribute_display_order = kwargs.get('attribute_display_order', - DEFAULT_ATTRIBUTE_DISPLAY_ORDER) - attribute_transform_functions = kwargs.get('attribute_transform_functions', {}) + attributes = kwargs.get("attributes", None) + attribute_display_order = kwargs.get( + "attribute_display_order", DEFAULT_ATTRIBUTE_DISPLAY_ORDER + ) + attribute_transform_functions = kwargs.get("attribute_transform_functions", {}) - if not attributes or 'all' in attributes: - attributes = sorted([attr for attr in subject.__dict__ - if not attr.startswith('_')]) + if not attributes or "all" in attributes: + attributes = sorted( + [attr for attr in subject.__dict__ if not attr.startswith("_")] + ) for attr in attribute_display_order[::-1]: if attr in attributes: attributes.remove(attr) attributes = [attr] + attributes table = PrettyTable() - table.field_names = ['Property', 'Value'] - table.max_width['Property'] = 20 - table.max_width['Value'] = 60 + table.field_names = ["Property", "Value"] + table.max_width["Property"] = 20 + table.max_width["Value"] = 60 table.padding_width = 1 - table.align = 'l' - table.valign = 't' + table.align = "l" + table.valign = "t" for attribute in attributes: - if '.' in attribute: - field_names = attribute.split('.') + if "." in attribute: + field_names = attribute.split(".") value = cls._get_attribute_value(subject, field_names.pop(0)) for name in field_names: value = cls._get_attribute_value(value, name) @@ -248,8 +251,9 @@ def format(cls, subject, *args, **kwargs): else: value = cls._get_attribute_value(subject, attribute) - transform_function = attribute_transform_functions.get(attribute, - lambda value: value) + transform_function = attribute_transform_functions.get( + attribute, lambda value: value + ) value = transform_function(value=value) if type(value) is dict or type(value) is list: @@ -266,9 +270,9 @@ def _get_attribute_value(subject, attribute): else: r_val = getattr(subject, attribute, None) if r_val is None: - return '' + return "" if isinstance(r_val, list) or isinstance(r_val, dict): - return r_val if len(r_val) > 0 else '' + return r_val if len(r_val) > 0 else "" return r_val @@ -284,19 +288,25 @@ def note_box(entity, limit): else: entity = entity[:-1] - message = "Note: Only one %s is displayed. Use -n/--last flag for more results." \ + message = ( + "Note: Only one %s is displayed. Use -n/--last flag for more results." % entity + ) else: - message = "Note: Only first %s %s are displayed. Use -n/--last flag for more results."\ + message = ( + "Note: Only first %s %s are displayed. Use -n/--last flag for more results." % (limit, entity) + ) # adding default padding message_length = len(message) + 3 m = MultiColumnTable() if m.table_width > message_length: - note = PrettyTable([""], right_padding_width=(m.table_width - message_length)) + note = PrettyTable( + [""], right_padding_width=(m.table_width - message_length) + ) else: note = PrettyTable([""]) note.header = False note.add_row([message]) - sys.stderr.write((str(note) + '\n')) + sys.stderr.write((str(note) + "\n")) return diff --git a/st2client/st2client/models/__init__.py b/st2client/st2client/models/__init__.py index 2862f59d283..8e27a77050e 100644 --- a/st2client/st2client/models/__init__.py +++ b/st2client/st2client/models/__init__.py @@ -15,19 +15,19 @@ from __future__ import absolute_import -from st2client.models.core import * # noqa -from st2client.models.auth import * # noqa -from st2client.models.action import * # noqa +from st2client.models.core import * # noqa +from st2client.models.auth import * # noqa +from st2client.models.action import * # noqa from st2client.models.action_alias import * # noqa from st2client.models.aliasexecution import * # noqa from st2client.models.config import * # noqa from st2client.models.inquiry import * # noqa -from st2client.models.keyvalue import * # noqa -from st2client.models.pack import * # noqa -from st2client.models.policy import * # noqa -from st2client.models.reactor import * # noqa -from st2client.models.trace import * # noqa -from st2client.models.webhook import * # noqa -from st2client.models.timer import * # noqa -from st2client.models.service_registry import * # noqa -from st2client.models.rbac import * # noqa +from st2client.models.keyvalue import * # noqa +from st2client.models.pack import * # noqa +from st2client.models.policy import * # noqa +from st2client.models.reactor import * # noqa +from st2client.models.trace import * # noqa +from st2client.models.webhook import * # noqa +from st2client.models.timer import * # noqa +from st2client.models.service_registry import * # noqa +from st2client.models.rbac import * # noqa diff --git a/st2client/st2client/models/action.py b/st2client/st2client/models/action.py index 10692d3dc47..d31b694f800 100644 --- a/st2client/st2client/models/action.py +++ b/st2client/st2client/models/action.py @@ -24,27 +24,33 @@ class RunnerType(core.Resource): - _alias = 'Runner' - _display_name = 'Runner' - _plural = 'RunnerTypes' - _plural_display_name = 'Runners' - _repr_attributes = ['name', 'enabled', 'description'] + _alias = "Runner" + _display_name = "Runner" + _plural = "RunnerTypes" + _plural_display_name = "Runners" + _repr_attributes = ["name", "enabled", "description"] class Action(core.Resource): - _plural = 'Actions' - _repr_attributes = ['name', 'pack', 'enabled', 'runner_type'] - _url_path = 'actions' + _plural = "Actions" + _repr_attributes = ["name", "pack", "enabled", "runner_type"] + _url_path = "actions" class Execution(core.Resource): - _alias = 'Execution' - _display_name = 'Action Execution' - _url_path = 'executions' - _plural = 'ActionExecutions' - _plural_display_name = 'Action executions' - _repr_attributes = ['status', 'action', 'start_timestamp', 'end_timestamp', 'parameters', - 'delay'] + _alias = "Execution" + _display_name = "Action Execution" + _url_path = "executions" + _plural = "ActionExecutions" + _plural_display_name = "Action executions" + _repr_attributes = [ + "status", + "action", + "start_timestamp", + "end_timestamp", + "parameters", + "delay", + ] # NOTE: LiveAction has been deprecated in favor of Execution. It will be left here for diff --git a/st2client/st2client/models/action_alias.py b/st2client/st2client/models/action_alias.py index 42162eae3b0..1c1a696cff2 100644 --- a/st2client/st2client/models/action_alias.py +++ b/st2client/st2client/models/action_alias.py @@ -17,25 +17,22 @@ from st2client.models import core -__all__ = [ - 'ActionAlias', - 'ActionAliasMatch' -] +__all__ = ["ActionAlias", "ActionAliasMatch"] class ActionAlias(core.Resource): - _alias = 'Action-Alias' - _display_name = 'Action Alias' - _plural = 'ActionAliases' - _plural_display_name = 'Action Aliases' - _url_path = 'actionalias' - _repr_attributes = ['name', 'pack', 'action_ref'] + _alias = "Action-Alias" + _display_name = "Action Alias" + _plural = "ActionAliases" + _plural_display_name = "Action Aliases" + _url_path = "actionalias" + _repr_attributes = ["name", "pack", "action_ref"] class ActionAliasMatch(core.Resource): - _alias = 'Action-Alias-Match' - _display_name = 'ActionAlias Match' - _plural = 'ActionAliasMatches' - _plural_display_name = 'Action Alias Matches' - _url_path = 'actionalias' - _repr_attributes = ['command'] + _alias = "Action-Alias-Match" + _display_name = "ActionAlias Match" + _plural = "ActionAliasMatches" + _plural_display_name = "Action Alias Matches" + _url_path = "actionalias" + _repr_attributes = ["command"] diff --git a/st2client/st2client/models/aliasexecution.py b/st2client/st2client/models/aliasexecution.py index 12cfc67cf5d..a2d7e62a57b 100644 --- a/st2client/st2client/models/aliasexecution.py +++ b/st2client/st2client/models/aliasexecution.py @@ -17,16 +17,21 @@ from st2client.models import core -__all__ = [ - 'ActionAliasExecution' -] +__all__ = ["ActionAliasExecution"] class ActionAliasExecution(core.Resource): - _alias = 'Action-Alias-Execution' - _display_name = 'ActionAlias Execution' - _plural = 'ActionAliasExecutions' - _plural_display_name = 'Runners' - _url_path = 'aliasexecution' - _repr_attributes = ['name', 'format', 'command', 'user', 'source_channel', - 'notification_channel', 'notification_route'] + _alias = "Action-Alias-Execution" + _display_name = "ActionAlias Execution" + _plural = "ActionAliasExecutions" + _plural_display_name = "Runners" + _url_path = "aliasexecution" + _repr_attributes = [ + "name", + "format", + "command", + "user", + "source_channel", + "notification_channel", + "notification_route", + ] diff --git a/st2client/st2client/models/auth.py b/st2client/st2client/models/auth.py index 9fa626a19a0..7c909ea1721 100644 --- a/st2client/st2client/models/auth.py +++ b/st2client/st2client/models/auth.py @@ -24,14 +24,14 @@ class Token(core.Resource): - _display_name = 'Access Token' - _plural = 'Tokens' - _plural_display_name = 'Access Tokens' - _repr_attributes = ['user', 'expiry', 'metadata'] + _display_name = "Access Token" + _plural = "Tokens" + _plural_display_name = "Access Tokens" + _repr_attributes = ["user", "expiry", "metadata"] class ApiKey(core.Resource): - _display_name = 'API Key' - _plural = 'ApiKeys' - _plural_display_name = 'API Keys' - _repr_attributes = ['id', 'user', 'metadata'] + _display_name = "API Key" + _plural = "ApiKeys" + _plural_display_name = "API Keys" + _repr_attributes = ["id", "user", "metadata"] diff --git a/st2client/st2client/models/config.py b/st2client/st2client/models/config.py index 247b4fcaf95..f9054751ed4 100644 --- a/st2client/st2client/models/config.py +++ b/st2client/st2client/models/config.py @@ -19,14 +19,14 @@ class Config(core.Resource): - _display_name = 'Config' - _plural = 'Configs' - _plural_display_name = 'Configs' + _display_name = "Config" + _plural = "Configs" + _plural_display_name = "Configs" class ConfigSchema(core.Resource): - _display_name = 'Config Schema' - _plural = 'ConfigSchema' - _plural_display_name = 'Config Schemas' - _url_path = 'config_schemas' - _repr_attributes = ['id', 'pack', 'attributes'] + _display_name = "Config Schema" + _plural = "ConfigSchema" + _plural_display_name = "Config Schemas" + _url_path = "config_schemas" + _repr_attributes = ["id", "pack", "attributes"] diff --git a/st2client/st2client/models/core.py b/st2client/st2client/models/core.py index 255c91534f7..d2a9b694f16 100644 --- a/st2client/st2client/models/core.py +++ b/st2client/st2client/models/core.py @@ -34,12 +34,13 @@ def add_auth_token_to_kwargs_from_env(func): @wraps(func) def decorate(*args, **kwargs): - if not kwargs.get('token') and os.environ.get('ST2_AUTH_TOKEN', None): - kwargs['token'] = os.environ.get('ST2_AUTH_TOKEN') - if not kwargs.get('api_key') and os.environ.get('ST2_API_KEY', None): - kwargs['api_key'] = os.environ.get('ST2_API_KEY') + if not kwargs.get("token") and os.environ.get("ST2_AUTH_TOKEN", None): + kwargs["token"] = os.environ.get("ST2_AUTH_TOKEN") + if not kwargs.get("api_key") and os.environ.get("ST2_API_KEY", None): + kwargs["api_key"] = os.environ.get("ST2_API_KEY") return func(*args, **kwargs) + return decorate @@ -81,8 +82,11 @@ def to_dict(self, exclude_attributes=None): exclude_attributes = exclude_attributes or [] attributes = list(self.__dict__.keys()) - attributes = [attr for attr in attributes if not attr.startswith('__') and - attr not in exclude_attributes] + attributes = [ + attr + for attr in attributes + if not attr.startswith("__") and attr not in exclude_attributes + ] result = {} for attribute in attributes: @@ -102,15 +106,15 @@ def get_display_name(cls): @classmethod def get_plural_name(cls): if not cls._plural: - raise Exception('The %s class is missing class attributes ' - 'in its definition.' % cls.__name__) + raise Exception( + "The %s class is missing class attributes " + "in its definition." % cls.__name__ + ) return cls._plural @classmethod def get_plural_display_name(cls): - return (cls._plural_display_name - if cls._plural_display_name - else cls._plural) + return cls._plural_display_name if cls._plural_display_name else cls._plural @classmethod def get_url_path_name(cls): @@ -120,9 +124,9 @@ def get_url_path_name(cls): return cls.get_plural_name().lower() def serialize(self): - return dict((k, v) - for k, v in six.iteritems(self.__dict__) - if not k.startswith('_')) + return dict( + (k, v) for k, v in six.iteritems(self.__dict__) if not k.startswith("_") + ) @classmethod def deserialize(cls, doc): @@ -140,16 +144,15 @@ def __repr__(self): attributes = [] for attribute in self._repr_attributes: value = getattr(self, attribute, None) - attributes.append('%s=%s' % (attribute, value)) + attributes.append("%s=%s" % (attribute, value)) - attributes = ','.join(attributes) + attributes = ",".join(attributes) class_name = self.__class__.__name__ - result = '<%s %s>' % (class_name, attributes) + result = "<%s %s>" % (class_name, attributes) return result class ResourceManager(object): - def __init__(self, resource, endpoint, cacert=None, debug=False): self.resource = resource self.debug = debug @@ -159,46 +162,47 @@ def __init__(self, resource, endpoint, cacert=None, debug=False): def handle_error(response): try: content = response.json() - fault = content.get('faultstring', '') if content else '' + fault = content.get("faultstring", "") if content else "" if fault: - response.reason += '\nMESSAGE: %s' % fault + response.reason += "\nMESSAGE: %s" % fault except Exception as e: - response.reason += ('\nUnable to retrieve detailed message ' - 'from the HTTP response. %s\n' % six.text_type(e)) + response.reason += ( + "\nUnable to retrieve detailed message " + "from the HTTP response. %s\n" % six.text_type(e) + ) response.raise_for_status() @add_auth_token_to_kwargs_from_env def get_all(self, **kwargs): # TODO: This is ugly, stop abusing kwargs - url = '/%s' % self.resource.get_url_path_name() - limit = kwargs.pop('limit', None) - pack = kwargs.pop('pack', None) - prefix = kwargs.pop('prefix', None) - user = kwargs.pop('user', None) + url = "/%s" % self.resource.get_url_path_name() + limit = kwargs.pop("limit", None) + pack = kwargs.pop("pack", None) + prefix = kwargs.pop("prefix", None) + user = kwargs.pop("user", None) - params = kwargs.pop('params', {}) + params = kwargs.pop("params", {}) if limit: - params['limit'] = limit + params["limit"] = limit if pack: - params['pack'] = pack + params["pack"] = pack if prefix: - params['prefix'] = prefix + params["prefix"] = prefix if user: - params['user'] = user + params["user"] = user response = self.client.get(url=url, params=params, **kwargs) if response.status_code != http_client.OK: self.handle_error(response) - return [self.resource.deserialize(item) - for item in response.json()] + return [self.resource.deserialize(item) for item in response.json()] @add_auth_token_to_kwargs_from_env def get_by_id(self, id, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), id) + url = "/%s/%s" % (self.resource.get_url_path_name(), id) response = self.client.get(url, **kwargs) if response.status_code == http_client.NOT_FOUND: return None @@ -214,14 +218,18 @@ def get_property(self, id_, property_name, self_deserialize=True, **kwargs): property_name: Name of the property self_deserialize: #Implies use the deserialize method implemented by this resource. """ - token = kwargs.pop('token', None) - api_key = kwargs.pop('api_key', None) + token = kwargs.pop("token", None) + api_key = kwargs.pop("api_key", None) if kwargs: - url = '/%s/%s/%s/?%s' % (self.resource.get_url_path_name(), id_, property_name, - urllib.parse.urlencode(kwargs)) + url = "/%s/%s/%s/?%s" % ( + self.resource.get_url_path_name(), + id_, + property_name, + urllib.parse.urlencode(kwargs), + ) else: - url = '/%s/%s/%s/' % (self.resource.get_url_path_name(), id_, property_name) + url = "/%s/%s/%s/" % (self.resource.get_url_path_name(), id_, property_name) if token: response = self.client.get(url, token=token) @@ -246,19 +254,21 @@ def get_by_ref_or_id(self, ref_or_id, **kwargs): def _query_details(self, **kwargs): if not kwargs: - raise Exception('Query parameter is not provided.') + raise Exception("Query parameter is not provided.") - token = kwargs.get('token', None) - api_key = kwargs.get('api_key', None) - params = kwargs.get('params', {}) + token = kwargs.get("token", None) + api_key = kwargs.get("api_key", None) + params = kwargs.get("params", {}) for k, v in six.iteritems(kwargs): # Note: That's a special case to support api_key and token kwargs - if k not in ['token', 'api_key', 'params']: + if k not in ["token", "api_key", "params"]: params[k] = v - url = '/%s/?%s' % (self.resource.get_url_path_name(), - urllib.parse.urlencode(params)) + url = "/%s/?%s" % ( + self.resource.get_url_path_name(), + urllib.parse.urlencode(params), + ) if token: response = self.client.get(url, token=token) @@ -284,8 +294,8 @@ def query(self, **kwargs): @add_auth_token_to_kwargs_from_env def query_with_count(self, **kwargs): instances, response = self._query_details(**kwargs) - if response and 'X-Total-Count' in response.headers: - return (instances, int(response.headers['X-Total-Count'])) + if response and "X-Total-Count" in response.headers: + return (instances, int(response.headers["X-Total-Count"])) else: return (instances, None) @@ -296,13 +306,15 @@ def get_by_name(self, name, **kwargs): return None else: if len(instances) > 1: - raise Exception('More than one %s named "%s" are found.' % - (self.resource.__name__.lower(), name)) + raise Exception( + 'More than one %s named "%s" are found.' + % (self.resource.__name__.lower(), name) + ) return instances[0] @add_auth_token_to_kwargs_from_env def create(self, instance, **kwargs): - url = '/%s' % self.resource.get_url_path_name() + url = "/%s" % self.resource.get_url_path_name() response = self.client.post(url, instance.serialize(), **kwargs) if response.status_code != http_client.OK: self.handle_error(response) @@ -311,7 +323,7 @@ def create(self, instance, **kwargs): @add_auth_token_to_kwargs_from_env def update(self, instance, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), instance.id) + url = "/%s/%s" % (self.resource.get_url_path_name(), instance.id) response = self.client.put(url, instance.serialize(), **kwargs) if response.status_code != http_client.OK: self.handle_error(response) @@ -320,12 +332,14 @@ def update(self, instance, **kwargs): @add_auth_token_to_kwargs_from_env def delete(self, instance, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), instance.id) + url = "/%s/%s" % (self.resource.get_url_path_name(), instance.id) response = self.client.delete(url, **kwargs) - if response.status_code not in [http_client.OK, - http_client.NO_CONTENT, - http_client.NOT_FOUND]: + if response.status_code not in [ + http_client.OK, + http_client.NO_CONTENT, + http_client.NOT_FOUND, + ]: self.handle_error(response) return False @@ -333,11 +347,13 @@ def delete(self, instance, **kwargs): @add_auth_token_to_kwargs_from_env def delete_by_id(self, instance_id, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), instance_id) + url = "/%s/%s" % (self.resource.get_url_path_name(), instance_id) response = self.client.delete(url, **kwargs) - if response.status_code not in [http_client.OK, - http_client.NO_CONTENT, - http_client.NOT_FOUND]: + if response.status_code not in [ + http_client.OK, + http_client.NO_CONTENT, + http_client.NOT_FOUND, + ]: self.handle_error(response) return False try: @@ -357,18 +373,21 @@ def __init__(self, resource, endpoint, cacert=None, debug=False): @add_auth_token_to_kwargs_from_env def match(self, instance, **kwargs): - url = '/%s/match' % self.resource.get_url_path_name() + url = "/%s/match" % self.resource.get_url_path_name() response = self.client.post(url, instance.serialize(), **kwargs) if response.status_code != http_client.OK: self.handle_error(response) match = response.json() - return (self.resource.deserialize(match['actionalias']), match['representation']) + return ( + self.resource.deserialize(match["actionalias"]), + match["representation"], + ) class ActionAliasExecutionManager(ResourceManager): @add_auth_token_to_kwargs_from_env def match_and_execute(self, instance, **kwargs): - url = '/%s/match_and_execute' % self.resource.get_url_path_name() + url = "/%s/match_and_execute" % self.resource.get_url_path_name() response = self.client.post(url, instance.serialize(), **kwargs) if response.status_code != http_client.OK: @@ -380,7 +399,10 @@ def match_and_execute(self, instance, **kwargs): class ActionResourceManager(ResourceManager): @add_auth_token_to_kwargs_from_env def get_entrypoint(self, ref_or_id, **kwargs): - url = '/%s/views/entry_point/%s' % (self.resource.get_url_path_name(), ref_or_id) + url = "/%s/views/entry_point/%s" % ( + self.resource.get_url_path_name(), + ref_or_id, + ) response = self.client.get(url, **kwargs) if response.status_code != http_client.OK: @@ -391,20 +413,30 @@ def get_entrypoint(self, ref_or_id, **kwargs): class ExecutionResourceManager(ResourceManager): @add_auth_token_to_kwargs_from_env - def re_run(self, execution_id, parameters=None, tasks=None, no_reset=None, delay=0, **kwargs): - url = '/%s/%s/re_run' % (self.resource.get_url_path_name(), execution_id) + def re_run( + self, + execution_id, + parameters=None, + tasks=None, + no_reset=None, + delay=0, + **kwargs, + ): + url = "/%s/%s/re_run" % (self.resource.get_url_path_name(), execution_id) tasks = tasks or [] no_reset = no_reset or [] if list(set(no_reset) - set(tasks)): - raise ValueError('List of tasks to reset does not match the tasks to rerun.') + raise ValueError( + "List of tasks to reset does not match the tasks to rerun." + ) data = { - 'parameters': parameters or {}, - 'tasks': tasks, - 'reset': list(set(tasks) - set(no_reset)), - 'delay': delay + "parameters": parameters or {}, + "tasks": tasks, + "reset": list(set(tasks) - set(no_reset)), + "delay": delay, } response = self.client.post(url, data, **kwargs) @@ -416,10 +448,10 @@ def re_run(self, execution_id, parameters=None, tasks=None, no_reset=None, delay @add_auth_token_to_kwargs_from_env def get_output(self, execution_id, output_type=None, **kwargs): - url = '/%s/%s/output' % (self.resource.get_url_path_name(), execution_id) + url = "/%s/%s/output" % (self.resource.get_url_path_name(), execution_id) if output_type: - url += '?' + urllib.parse.urlencode({'output_type': output_type}) + url += "?" + urllib.parse.urlencode({"output_type": output_type}) response = self.client.get(url, **kwargs) if response.status_code != http_client.OK: @@ -429,8 +461,8 @@ def get_output(self, execution_id, output_type=None, **kwargs): @add_auth_token_to_kwargs_from_env def pause(self, execution_id, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), execution_id) - data = {'status': 'pausing'} + url = "/%s/%s" % (self.resource.get_url_path_name(), execution_id) + data = {"status": "pausing"} response = self.client.put(url, data, **kwargs) @@ -441,8 +473,8 @@ def pause(self, execution_id, **kwargs): @add_auth_token_to_kwargs_from_env def resume(self, execution_id, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), execution_id) - data = {'status': 'resuming'} + url = "/%s/%s" % (self.resource.get_url_path_name(), execution_id) + data = {"status": "resuming"} response = self.client.put(url, data, **kwargs) @@ -453,14 +485,14 @@ def resume(self, execution_id, **kwargs): @add_auth_token_to_kwargs_from_env def get_children(self, execution_id, **kwargs): - url = '/%s/%s/children' % (self.resource.get_url_path_name(), execution_id) + url = "/%s/%s/children" % (self.resource.get_url_path_name(), execution_id) - depth = kwargs.pop('depth', -1) + depth = kwargs.pop("depth", -1) - params = kwargs.pop('params', {}) + params = kwargs.pop("params", {}) if depth: - params['depth'] = depth + params["depth"] = depth response = self.client.get(url=url, params=params, **kwargs) if response.status_code != http_client.OK: @@ -469,19 +501,15 @@ def get_children(self, execution_id, **kwargs): class InquiryResourceManager(ResourceManager): - @add_auth_token_to_kwargs_from_env def respond(self, inquiry_id, inquiry_response, **kwargs): """ Update st2.inquiry.respond action Update st2client respond command to use this? """ - url = '/%s/%s' % (self.resource.get_url_path_name(), inquiry_id) + url = "/%s/%s" % (self.resource.get_url_path_name(), inquiry_id) - payload = { - "id": inquiry_id, - "response": inquiry_response - } + payload = {"id": inquiry_id, "response": inquiry_response} response = self.client.put(url, payload, **kwargs) @@ -494,7 +522,10 @@ def respond(self, inquiry_id, inquiry_response, **kwargs): class TriggerInstanceResourceManager(ResourceManager): @add_auth_token_to_kwargs_from_env def re_emit(self, trigger_instance_id, **kwargs): - url = '/%s/%s/re_emit' % (self.resource.get_url_path_name(), trigger_instance_id) + url = "/%s/%s/re_emit" % ( + self.resource.get_url_path_name(), + trigger_instance_id, + ) response = self.client.post(url, None, **kwargs) if response.status_code != http_client.OK: self.handle_error(response) @@ -508,11 +539,11 @@ class AsyncRequest(Resource): class PackResourceManager(ResourceManager): @add_auth_token_to_kwargs_from_env def install(self, packs, force=False, skip_dependencies=False, **kwargs): - url = '/%s/install' % (self.resource.get_url_path_name()) + url = "/%s/install" % (self.resource.get_url_path_name()) payload = { - 'packs': packs, - 'force': force, - 'skip_dependencies': skip_dependencies + "packs": packs, + "force": force, + "skip_dependencies": skip_dependencies, } response = self.client.post(url, payload, **kwargs) if response.status_code != http_client.OK: @@ -522,8 +553,8 @@ def install(self, packs, force=False, skip_dependencies=False, **kwargs): @add_auth_token_to_kwargs_from_env def remove(self, packs, **kwargs): - url = '/%s/uninstall' % (self.resource.get_url_path_name()) - response = self.client.post(url, {'packs': packs}, **kwargs) + url = "/%s/uninstall" % (self.resource.get_url_path_name()) + response = self.client.post(url, {"packs": packs}, **kwargs) if response.status_code != http_client.OK: self.handle_error(response) instance = AsyncRequest.deserialize(response.json()) @@ -531,11 +562,11 @@ def remove(self, packs, **kwargs): @add_auth_token_to_kwargs_from_env def search(self, args, ignore_errors=False, **kwargs): - url = '/%s/index/search' % (self.resource.get_url_path_name()) - if 'query' in vars(args): - payload = {'query': args.query} + url = "/%s/index/search" % (self.resource.get_url_path_name()) + if "query" in vars(args): + payload = {"query": args.query} else: - payload = {'pack': args.pack} + payload = {"pack": args.pack} response = self.client.post(url, payload, **kwargs) @@ -552,12 +583,12 @@ def search(self, args, ignore_errors=False, **kwargs): @add_auth_token_to_kwargs_from_env def register(self, packs=None, types=None, **kwargs): - url = '/%s/register' % (self.resource.get_url_path_name()) + url = "/%s/register" % (self.resource.get_url_path_name()) payload = {} if types: - payload['types'] = types + payload["types"] = types if packs: - payload['packs'] = packs + payload["packs"] = packs response = self.client.post(url, payload, **kwargs) if response.status_code != http_client.OK: self.handle_error(response) @@ -568,7 +599,7 @@ def register(self, packs=None, types=None, **kwargs): class ConfigManager(ResourceManager): @add_auth_token_to_kwargs_from_env def update(self, instance, **kwargs): - url = '/%s/%s' % (self.resource.get_url_path_name(), instance.pack) + url = "/%s/%s" % (self.resource.get_url_path_name(), instance.pack) response = self.client.put(url, instance.values, **kwargs) if response.status_code != http_client.OK: self.handle_error(response) @@ -584,16 +615,13 @@ def __init__(self, resource, endpoint, cacert=None, debug=False): @add_auth_token_to_kwargs_from_env def post_generic_webhook(self, trigger, payload=None, trace_tag=None, **kwargs): - url = '/webhooks/st2' + url = "/webhooks/st2" headers = {} - data = { - 'trigger': trigger, - 'payload': payload or {} - } + data = {"trigger": trigger, "payload": payload or {}} if trace_tag: - headers['St2-Trace-Tag'] = trace_tag + headers["St2-Trace-Tag"] = trace_tag response = self.client.post(url, data=data, headers=headers, **kwargs) @@ -604,17 +632,20 @@ def post_generic_webhook(self, trigger, payload=None, trace_tag=None, **kwargs): @add_auth_token_to_kwargs_from_env def match(self, instance, **kwargs): - url = '/%s/match' % self.resource.get_url_path_name() + url = "/%s/match" % self.resource.get_url_path_name() response = self.client.post(url, instance.serialize(), **kwargs) if response.status_code != http_client.OK: self.handle_error(response) match = response.json() - return (self.resource.deserialize(match['actionalias']), match['representation']) + return ( + self.resource.deserialize(match["actionalias"]), + match["representation"], + ) class StreamManager(object): def __init__(self, endpoint, cacert=None, debug=False): - self._url = httpclient.get_url_without_trailing_slash(endpoint) + '/stream' + self._url = httpclient.get_url_without_trailing_slash(endpoint) + "/stream" self.debug = debug self.cacert = cacert @@ -631,25 +662,25 @@ def listen(self, events=None, **kwargs): if events and isinstance(events, six.string_types): events = [events] - if 'token' in kwargs: - query_params['x-auth-token'] = kwargs.get('token') + if "token" in kwargs: + query_params["x-auth-token"] = kwargs.get("token") - if 'api_key' in kwargs: - query_params['st2-api-key'] = kwargs.get('api_key') + if "api_key" in kwargs: + query_params["st2-api-key"] = kwargs.get("api_key") - if 'end_event' in kwargs: - query_params['end_event'] = kwargs.get('end_event') + if "end_event" in kwargs: + query_params["end_event"] = kwargs.get("end_event") - if 'end_execution_id' in kwargs: - query_params['end_execution_id'] = kwargs.get('end_execution_id') + if "end_execution_id" in kwargs: + query_params["end_execution_id"] = kwargs.get("end_execution_id") if events: - query_params['events'] = ','.join(events) + query_params["events"] = ",".join(events) if self.cacert is not None: - request_params['verify'] = self.cacert + request_params["verify"] = self.cacert - query_string = '?' + urllib.parse.urlencode(query_params) + query_string = "?" + urllib.parse.urlencode(query_params) url = url + query_string response = requests.get(url, stream=True, **request_params) @@ -667,36 +698,38 @@ class WorkflowManager(object): def __init__(self, endpoint, cacert, debug): self.debug = debug self.cacert = cacert - self.endpoint = endpoint + '/workflows' - self.client = httpclient.HTTPClient(root=self.endpoint, cacert=cacert, debug=debug) + self.endpoint = endpoint + "/workflows" + self.client = httpclient.HTTPClient( + root=self.endpoint, cacert=cacert, debug=debug + ) @staticmethod def handle_error(response): try: content = response.json() - fault = content.get('faultstring', '') if content else '' + fault = content.get("faultstring", "") if content else "" if fault: - response.reason += '\nMESSAGE: %s' % fault + response.reason += "\nMESSAGE: %s" % fault except Exception as e: response.reason += ( - '\nUnable to retrieve detailed message ' - 'from the HTTP response. %s\n' % six.text_type(e) + "\nUnable to retrieve detailed message " + "from the HTTP response. %s\n" % six.text_type(e) ) response.raise_for_status() @add_auth_token_to_kwargs_from_env def inspect(self, definition, **kwargs): - url = '/inspect' + url = "/inspect" if not isinstance(definition, six.string_types): - raise TypeError('Workflow definition is not type of string.') + raise TypeError("Workflow definition is not type of string.") - if 'headers' not in kwargs: - kwargs['headers'] = {} + if "headers" not in kwargs: + kwargs["headers"] = {} - kwargs['headers']['content-type'] = 'text/plain' + kwargs["headers"]["content-type"] = "text/plain" response = self.client.post_raw(url, definition, **kwargs) @@ -709,7 +742,7 @@ def inspect(self, definition, **kwargs): class ServiceRegistryGroupsManager(ResourceManager): @add_auth_token_to_kwargs_from_env def list(self, **kwargs): - url = '/service_registry/groups' + url = "/service_registry/groups" headers = {} response = self.client.get(url, headers=headers, **kwargs) @@ -717,21 +750,20 @@ def list(self, **kwargs): if response.status_code != http_client.OK: self.handle_error(response) - groups = response.json()['groups'] + groups = response.json()["groups"] result = [] for group in groups: - item = self.resource.deserialize({'group_id': group}) + item = self.resource.deserialize({"group_id": group}) result.append(item) return result class ServiceRegistryMembersManager(ResourceManager): - @add_auth_token_to_kwargs_from_env def list(self, group_id, **kwargs): - url = '/service_registry/groups/%s/members' % (group_id) + url = "/service_registry/groups/%s/members" % (group_id) headers = {} response = self.client.get(url, headers=headers, **kwargs) @@ -739,14 +771,14 @@ def list(self, group_id, **kwargs): if response.status_code != http_client.OK: self.handle_error(response) - members = response.json()['members'] + members = response.json()["members"] result = [] for member in members: data = { - 'group_id': group_id, - 'member_id': member['member_id'], - 'capabilities': member['capabilities'] + "group_id": group_id, + "member_id": member["member_id"], + "capabilities": member["capabilities"], } item = self.resource.deserialize(data) result.append(item) diff --git a/st2client/st2client/models/inquiry.py b/st2client/st2client/models/inquiry.py index 5d1a1076f5d..93161ee68fe 100644 --- a/st2client/st2client/models/inquiry.py +++ b/st2client/st2client/models/inquiry.py @@ -24,15 +24,8 @@ class Inquiry(core.Resource): - _display_name = 'Inquiry' - _plural = 'Inquiries' - _plural_display_name = 'Inquiries' - _url_path = 'inquiries' - _repr_attributes = [ - 'id', - 'schema', - 'roles', - 'users', - 'route', - 'ttl' - ] + _display_name = "Inquiry" + _plural = "Inquiries" + _plural_display_name = "Inquiries" + _url_path = "inquiries" + _repr_attributes = ["id", "schema", "roles", "users", "route", "ttl"] diff --git a/st2client/st2client/models/keyvalue.py b/st2client/st2client/models/keyvalue.py index f7095a4b8f1..5bcd1de8dec 100644 --- a/st2client/st2client/models/keyvalue.py +++ b/st2client/st2client/models/keyvalue.py @@ -24,11 +24,11 @@ class KeyValuePair(core.Resource): - _alias = 'Key' - _display_name = 'Key Value Pair' - _plural = 'Keys' - _plural_display_name = 'Key Value Pairs' - _repr_attributes = ['name', 'value'] + _alias = "Key" + _display_name = "Key Value Pair" + _plural = "Keys" + _plural_display_name = "Key Value Pairs" + _repr_attributes = ["name", "value"] # Note: This is a temporary hack until we refactor client and make it support non id PKs def get_id(self): diff --git a/st2client/st2client/models/pack.py b/st2client/st2client/models/pack.py index 5d681266ada..7333c1a28e0 100644 --- a/st2client/st2client/models/pack.py +++ b/st2client/st2client/models/pack.py @@ -19,8 +19,8 @@ class Pack(core.Resource): - _display_name = 'Pack' - _plural = 'Packs' - _plural_display_name = 'Packs' - _url_path = 'packs' - _repr_attributes = ['name', 'description', 'version', 'author'] + _display_name = "Pack" + _plural = "Packs" + _plural_display_name = "Packs" + _url_path = "packs" + _repr_attributes = ["name", "description", "version", "author"] diff --git a/st2client/st2client/models/policy.py b/st2client/st2client/models/policy.py index 851779d7fd2..4b8bb0c8139 100644 --- a/st2client/st2client/models/policy.py +++ b/st2client/st2client/models/policy.py @@ -24,13 +24,13 @@ class PolicyType(core.Resource): - _alias = 'Policy-Type' - _display_name = 'Policy type' - _plural = 'PolicyTypes' - _plural_display_name = 'Policy types' - _repr_attributes = ['ref', 'enabled', 'description'] + _alias = "Policy-Type" + _display_name = "Policy type" + _plural = "PolicyTypes" + _plural_display_name = "Policy types" + _repr_attributes = ["ref", "enabled", "description"] class Policy(core.Resource): - _plural = 'Policies' - _repr_attributes = ['name', 'pack', 'enabled', 'policy_type', 'resource_ref'] + _plural = "Policies" + _repr_attributes = ["name", "pack", "enabled", "policy_type", "resource_ref"] diff --git a/st2client/st2client/models/rbac.py b/st2client/st2client/models/rbac.py index 6df4aa4f941..94c765ddf38 100644 --- a/st2client/st2client/models/rbac.py +++ b/st2client/st2client/models/rbac.py @@ -17,25 +17,22 @@ from st2client.models import core -__all__ = [ - 'Role', - 'UserRoleAssignment' -] +__all__ = ["Role", "UserRoleAssignment"] class Role(core.Resource): - _alias = 'role' - _display_name = 'Role' - _plural = 'Roles' - _plural_display_name = 'Roles' - _repr_attributes = ['id', 'name', 'system'] - _url_path = 'rbac/roles' + _alias = "role" + _display_name = "Role" + _plural = "Roles" + _plural_display_name = "Roles" + _repr_attributes = ["id", "name", "system"] + _url_path = "rbac/roles" class UserRoleAssignment(core.Resource): - _alias = 'role-assignment' - _display_name = 'Role Assignment' - _plural = 'RoleAssignments' - _plural_display_name = 'Role Assignments' - _repr_attributes = ['id', 'role', 'user', 'is_remote'] - _url_path = 'rbac/role_assignments' + _alias = "role-assignment" + _display_name = "Role Assignment" + _plural = "RoleAssignments" + _plural_display_name = "Role Assignments" + _repr_attributes = ["id", "role", "user", "is_remote"] + _url_path = "rbac/role_assignments" diff --git a/st2client/st2client/models/reactor.py b/st2client/st2client/models/reactor.py index 140d1aaf50f..ef4c054f69e 100644 --- a/st2client/st2client/models/reactor.py +++ b/st2client/st2client/models/reactor.py @@ -24,43 +24,49 @@ class Sensor(core.Resource): - _plural = 'Sensortypes' - _repr_attributes = ['name', 'pack'] + _plural = "Sensortypes" + _repr_attributes = ["name", "pack"] class TriggerType(core.Resource): - _alias = 'Trigger' - _display_name = 'Trigger' - _plural = 'Triggertypes' - _plural_display_name = 'Triggers' - _repr_attributes = ['name', 'pack'] + _alias = "Trigger" + _display_name = "Trigger" + _plural = "Triggertypes" + _plural_display_name = "Triggers" + _repr_attributes = ["name", "pack"] class TriggerInstance(core.Resource): - _alias = 'Trigger-Instance' - _display_name = 'TriggerInstance' - _plural = 'Triggerinstances' - _plural_display_name = 'TriggerInstances' - _repr_attributes = ['id', 'trigger', 'occurrence_time', 'payload', 'status'] + _alias = "Trigger-Instance" + _display_name = "TriggerInstance" + _plural = "Triggerinstances" + _plural_display_name = "TriggerInstances" + _repr_attributes = ["id", "trigger", "occurrence_time", "payload", "status"] class Trigger(core.Resource): - _alias = 'TriggerSpecification' - _display_name = 'Trigger Specification' - _plural = 'Triggers' - _plural_display_name = 'Trigger Specifications' - _repr_attributes = ['name', 'pack'] + _alias = "TriggerSpecification" + _display_name = "Trigger Specification" + _plural = "Triggers" + _plural_display_name = "Trigger Specifications" + _repr_attributes = ["name", "pack"] class Rule(core.Resource): - _alias = 'Rule' - _plural = 'Rules' - _repr_attributes = ['name', 'pack', 'trigger', 'criteria', 'enabled'] + _alias = "Rule" + _plural = "Rules" + _repr_attributes = ["name", "pack", "trigger", "criteria", "enabled"] class RuleEnforcement(core.Resource): - _alias = 'Rule-Enforcement' - _plural = 'RuleEnforcements' - _display_name = 'Rule Enforcement' - _plural_display_name = 'Rule Enforcements' - _repr_attributes = ['id', 'trigger_instance_id', 'execution_id', 'rule.ref', 'enforced_at'] + _alias = "Rule-Enforcement" + _plural = "RuleEnforcements" + _display_name = "Rule Enforcement" + _plural_display_name = "Rule Enforcements" + _repr_attributes = [ + "id", + "trigger_instance_id", + "execution_id", + "rule.ref", + "enforced_at", + ] diff --git a/st2client/st2client/models/service_registry.py b/st2client/st2client/models/service_registry.py index 3b3057a3c36..ca95cd73cbe 100644 --- a/st2client/st2client/models/service_registry.py +++ b/st2client/st2client/models/service_registry.py @@ -17,32 +17,27 @@ from st2client.models import core -__all__ = [ - 'ServiceRegistry', - - 'ServiceRegistryGroup', - 'ServiceRegistryMember' -] +__all__ = ["ServiceRegistry", "ServiceRegistryGroup", "ServiceRegistryMember"] class ServiceRegistry(core.Resource): - _alias = 'service-registry' - _display_name = 'Service Registry' - _plural = 'Service Registry' - _plural_display_name = 'Service Registry' + _alias = "service-registry" + _display_name = "Service Registry" + _plural = "Service Registry" + _plural_display_name = "Service Registry" class ServiceRegistryGroup(core.Resource): - _alias = 'group' - _display_name = 'Group' - _plural = 'Groups' - _plural_display_name = 'Groups' - _repr_attributes = ['group_id'] + _alias = "group" + _display_name = "Group" + _plural = "Groups" + _plural_display_name = "Groups" + _repr_attributes = ["group_id"] class ServiceRegistryMember(core.Resource): - _alias = 'member' - _display_name = 'Group Member' - _plural = 'Group Members' - _plural_display_name = 'Group Members' - _repr_attributes = ['group_id', 'member_id'] + _alias = "member" + _display_name = "Group Member" + _plural = "Group Members" + _plural_display_name = "Group Members" + _repr_attributes = ["group_id", "member_id"] diff --git a/st2client/st2client/models/timer.py b/st2client/st2client/models/timer.py index 4ba58547f3f..fbfbd6cfcd5 100644 --- a/st2client/st2client/models/timer.py +++ b/st2client/st2client/models/timer.py @@ -24,7 +24,7 @@ class Timer(core.Resource): - _alias = 'Timer' - _display_name = 'Timer' - _plural = 'Timers' - _plural_display_name = 'Timers' + _alias = "Timer" + _display_name = "Timer" + _plural = "Timers" + _plural_display_name = "Timers" diff --git a/st2client/st2client/models/trace.py b/st2client/st2client/models/trace.py index a03b4a88125..3b7bfe44499 100644 --- a/st2client/st2client/models/trace.py +++ b/st2client/st2client/models/trace.py @@ -19,8 +19,8 @@ class Trace(core.Resource): - _alias = 'Trace' - _display_name = 'Trace' - _plural = 'Traces' - _plural_display_name = 'Traces' - _repr_attributes = ['id', 'trace_tag'] + _alias = "Trace" + _display_name = "Trace" + _plural = "Traces" + _plural_display_name = "Traces" + _repr_attributes = ["id", "trace_tag"] diff --git a/st2client/st2client/models/webhook.py b/st2client/st2client/models/webhook.py index 83d939f061e..161d1bdb4c6 100644 --- a/st2client/st2client/models/webhook.py +++ b/st2client/st2client/models/webhook.py @@ -24,8 +24,8 @@ class Webhook(core.Resource): - _alias = 'Webhook' - _display_name = 'Webhook' - _plural = 'Webhooks' - _plural_display_name = 'Webhooks' - _repr_attributes = ['parameters', 'type', 'pack', 'name'] + _alias = "Webhook" + _display_name = "Webhook" + _plural = "Webhooks" + _plural_display_name = "Webhooks" + _repr_attributes = ["parameters", "type", "pack", "name"] diff --git a/st2client/st2client/shell.py b/st2client/st2client/shell.py index ac6108d796f..7d3359c532c 100755 --- a/st2client/st2client/shell.py +++ b/st2client/st2client/shell.py @@ -25,6 +25,7 @@ # Ignore CryptographyDeprecationWarning warnings which appear on older versions of Python 2.7 import warnings from cryptography.utils import CryptographyDeprecationWarning + warnings.filterwarnings("ignore", category=CryptographyDeprecationWarning) import os @@ -66,13 +67,13 @@ from st2client.commands.auth import LoginCommand -__all__ = [ - 'Shell' -] +__all__ = ["Shell"] LOGGER = logging.getLogger(__name__) -CLI_DESCRIPTION = 'CLI for StackStorm event-driven automation platform. https://stackstorm.com' +CLI_DESCRIPTION = ( + "CLI for StackStorm event-driven automation platform. https://stackstorm.com" +) USAGE_STRING = """ Usage: %(prog)s [options] [options] @@ -83,15 +84,19 @@ %(prog)s --debug run core.local cmd=date """.strip() -NON_UTF8_LOCALE = """ +NON_UTF8_LOCALE = ( + """ Locale %s with encoding %s which is not UTF-8 is used. This means that some functionality which relies on outputting unicode characters won't work. You are encouraged to use UTF-8 locale by setting LC_ALL environment variable to en_US.UTF-8 or similar. -""".strip().replace('\n', ' ').replace(' ', ' ') +""".strip() + .replace("\n", " ") + .replace(" ", " ") +) -PACKAGE_METADATA_FILE_PATH = '/opt/stackstorm/st2/package.meta' +PACKAGE_METADATA_FILE_PATH = "/opt/stackstorm/st2/package.meta" def get_stackstorm_version(): @@ -101,7 +106,7 @@ def get_stackstorm_version(): :rtype: ``str`` """ - if 'dev' in __version__: + if "dev" in __version__: version = __version__ if not os.path.isfile(PACKAGE_METADATA_FILE_PATH): @@ -115,11 +120,11 @@ def get_stackstorm_version(): return version try: - git_revision = config.get('server', 'git_sha') + git_revision = config.get("server", "git_sha") except Exception: return version - version = '%s (%s)' % (version, git_revision) + version = "%s (%s)" % (version, git_revision) else: version = __version__ @@ -143,214 +148,237 @@ def __init__(self): # Set up general program options. self.parser.add_argument( - '--version', - action='version', - version='%(prog)s {version}, on Python {python_major}.{python_minor}.{python_patch}' - .format(version=get_stackstorm_version(), - python_major=sys.version_info.major, - python_minor=sys.version_info.minor, - python_patch=sys.version_info.micro)) + "--version", + action="version", + version="%(prog)s {version}, on Python {python_major}.{python_minor}.{python_patch}".format( + version=get_stackstorm_version(), + python_major=sys.version_info.major, + python_minor=sys.version_info.minor, + python_patch=sys.version_info.micro, + ), + ) self.parser.add_argument( - '--url', - action='store', - dest='base_url', + "--url", + action="store", + dest="base_url", default=None, - help='Base URL for the API servers. Assumes all servers use the ' - 'same base URL and default ports are used. Get ST2_BASE_URL ' - 'from the environment variables by default.' + help="Base URL for the API servers. Assumes all servers use the " + "same base URL and default ports are used. Get ST2_BASE_URL " + "from the environment variables by default.", ) self.parser.add_argument( - '--auth-url', - action='store', - dest='auth_url', + "--auth-url", + action="store", + dest="auth_url", default=None, - help='URL for the authentication service. Get ST2_AUTH_URL ' - 'from the environment variables by default.' + help="URL for the authentication service. Get ST2_AUTH_URL " + "from the environment variables by default.", ) self.parser.add_argument( - '--api-url', - action='store', - dest='api_url', + "--api-url", + action="store", + dest="api_url", default=None, - help='URL for the API server. Get ST2_API_URL ' - 'from the environment variables by default.' + help="URL for the API server. Get ST2_API_URL " + "from the environment variables by default.", ) self.parser.add_argument( - '--stream-url', - action='store', - dest='stream_url', + "--stream-url", + action="store", + dest="stream_url", default=None, - help='URL for the stream endpoint. Get ST2_STREAM_URL' - 'from the environment variables by default.' + help="URL for the stream endpoint. Get ST2_STREAM_URL" + "from the environment variables by default.", ) self.parser.add_argument( - '--api-version', - action='store', - dest='api_version', + "--api-version", + action="store", + dest="api_version", default=None, - help='API version to use. Get ST2_API_VERSION ' - 'from the environment variables by default.' + help="API version to use. Get ST2_API_VERSION " + "from the environment variables by default.", ) self.parser.add_argument( - '--cacert', - action='store', - dest='cacert', + "--cacert", + action="store", + dest="cacert", default=None, - help='Path to the CA cert bundle for the SSL endpoints. ' - 'Get ST2_CACERT from the environment variables by default. ' - 'If this is not provided, then SSL cert will not be verified.' + help="Path to the CA cert bundle for the SSL endpoints. " + "Get ST2_CACERT from the environment variables by default. " + "If this is not provided, then SSL cert will not be verified.", ) self.parser.add_argument( - '--config-file', - action='store', - dest='config_file', + "--config-file", + action="store", + dest="config_file", default=None, - help='Path to the CLI config file' + help="Path to the CLI config file", ) self.parser.add_argument( - '--print-config', - action='store_true', - dest='print_config', + "--print-config", + action="store_true", + dest="print_config", default=False, - help='Parse the config file and print the values' + help="Parse the config file and print the values", ) self.parser.add_argument( - '--skip-config', - action='store_true', - dest='skip_config', + "--skip-config", + action="store_true", + dest="skip_config", default=False, - help='Don\'t parse and use the CLI config file' + help="Don't parse and use the CLI config file", ) self.parser.add_argument( - '--debug', - action='store_true', - dest='debug', + "--debug", + action="store_true", + dest="debug", default=False, - help='Enable debug mode' + help="Enable debug mode", ) # Set up list of commands and subcommands. - self.subparsers = self.parser.add_subparsers(dest='parser') + self.subparsers = self.parser.add_subparsers(dest="parser") self.subparsers.required = True self.commands = {} - self.commands['run'] = action.ActionRunCommand( - models.Action, self, self.subparsers, name='run', add_help=False) + self.commands["run"] = action.ActionRunCommand( + models.Action, self, self.subparsers, name="run", add_help=False + ) - self.commands['action'] = action.ActionBranch( - 'An activity that happens as a response to the external event.', - self, self.subparsers) + self.commands["action"] = action.ActionBranch( + "An activity that happens as a response to the external event.", + self, + self.subparsers, + ) - self.commands['action-alias'] = action_alias.ActionAliasBranch( - 'Action aliases.', - self, self.subparsers) + self.commands["action-alias"] = action_alias.ActionAliasBranch( + "Action aliases.", self, self.subparsers + ) - self.commands['auth'] = auth.TokenCreateCommand( - models.Token, self, self.subparsers, name='auth') + self.commands["auth"] = auth.TokenCreateCommand( + models.Token, self, self.subparsers, name="auth" + ) - self.commands['login'] = auth.LoginCommand( - models.Token, self, self.subparsers, name='login') + self.commands["login"] = auth.LoginCommand( + models.Token, self, self.subparsers, name="login" + ) - self.commands['whoami'] = auth.WhoamiCommand( - models.Token, self, self.subparsers, name='whoami') + self.commands["whoami"] = auth.WhoamiCommand( + models.Token, self, self.subparsers, name="whoami" + ) - self.commands['api-key'] = auth.ApiKeyBranch( - 'API Keys.', - self, self.subparsers) + self.commands["api-key"] = auth.ApiKeyBranch("API Keys.", self, self.subparsers) - self.commands['execution'] = action.ActionExecutionBranch( - 'An invocation of an action.', - self, self.subparsers) + self.commands["execution"] = action.ActionExecutionBranch( + "An invocation of an action.", self, self.subparsers + ) - self.commands['inquiry'] = inquiry.InquiryBranch( - 'Inquiries provide an opportunity to ask a question ' - 'and wait for a response in a workflow.', - self, self.subparsers) + self.commands["inquiry"] = inquiry.InquiryBranch( + "Inquiries provide an opportunity to ask a question " + "and wait for a response in a workflow.", + self, + self.subparsers, + ) - self.commands['key'] = keyvalue.KeyValuePairBranch( - 'Key value pair is used to store commonly used configuration ' - 'for reuse in sensors, actions, and rules.', - self, self.subparsers) + self.commands["key"] = keyvalue.KeyValuePairBranch( + "Key value pair is used to store commonly used configuration " + "for reuse in sensors, actions, and rules.", + self, + self.subparsers, + ) - self.commands['pack'] = pack.PackBranch( - 'A group of related integration resources: ' - 'actions, rules, and sensors.', - self, self.subparsers) + self.commands["pack"] = pack.PackBranch( + "A group of related integration resources: " "actions, rules, and sensors.", + self, + self.subparsers, + ) - self.commands['policy'] = policy.PolicyBranch( - 'Policy that is enforced on a resource.', - self, self.subparsers) + self.commands["policy"] = policy.PolicyBranch( + "Policy that is enforced on a resource.", self, self.subparsers + ) - self.commands['policy-type'] = policy.PolicyTypeBranch( - 'Type of policy that can be applied to resources.', - self, self.subparsers) + self.commands["policy-type"] = policy.PolicyTypeBranch( + "Type of policy that can be applied to resources.", self, self.subparsers + ) - self.commands['rule'] = rule.RuleBranch( + self.commands["rule"] = rule.RuleBranch( 'A specification to invoke an "action" on a "trigger" selectively ' - 'based on some criteria.', - self, self.subparsers) + "based on some criteria.", + self, + self.subparsers, + ) - self.commands['webhook'] = webhook.WebhookBranch( - 'Webhooks.', - self, self.subparsers) + self.commands["webhook"] = webhook.WebhookBranch( + "Webhooks.", self, self.subparsers + ) - self.commands['timer'] = timer.TimerBranch( - 'Timers.', - self, self.subparsers) + self.commands["timer"] = timer.TimerBranch("Timers.", self, self.subparsers) - self.commands['runner'] = resource.ResourceBranch( + self.commands["runner"] = resource.ResourceBranch( models.RunnerType, - 'Runner is a type of handler for a specific class of actions.', - self, self.subparsers, read_only=True, has_disable=True) + "Runner is a type of handler for a specific class of actions.", + self, + self.subparsers, + read_only=True, + has_disable=True, + ) - self.commands['sensor'] = sensor.SensorBranch( - 'An adapter which allows you to integrate StackStorm with external system.', - self, self.subparsers) + self.commands["sensor"] = sensor.SensorBranch( + "An adapter which allows you to integrate StackStorm with external system.", + self, + self.subparsers, + ) - self.commands['trace'] = trace.TraceBranch( - 'A group of executions, rules and triggerinstances that are related.', - self, self.subparsers) + self.commands["trace"] = trace.TraceBranch( + "A group of executions, rules and triggerinstances that are related.", + self, + self.subparsers, + ) - self.commands['trigger'] = trigger.TriggerTypeBranch( - 'An external event that is mapped to a st2 input. It is the ' - 'st2 invocation point.', - self, self.subparsers) + self.commands["trigger"] = trigger.TriggerTypeBranch( + "An external event that is mapped to a st2 input. It is the " + "st2 invocation point.", + self, + self.subparsers, + ) - self.commands['trigger-instance'] = triggerinstance.TriggerInstanceBranch( - 'Actual instances of triggers received by st2.', - self, self.subparsers) + self.commands["trigger-instance"] = triggerinstance.TriggerInstanceBranch( + "Actual instances of triggers received by st2.", self, self.subparsers + ) - self.commands['rule-enforcement'] = rule_enforcement.RuleEnforcementBranch( - 'Models that represent enforcement of rules.', - self, self.subparsers) + self.commands["rule-enforcement"] = rule_enforcement.RuleEnforcementBranch( + "Models that represent enforcement of rules.", self, self.subparsers + ) - self.commands['workflow'] = workflow.WorkflowBranch( - 'Commands for workflow authoring related operations. ' - 'Only orquesta workflows are supported.', - self, self.subparsers) + self.commands["workflow"] = workflow.WorkflowBranch( + "Commands for workflow authoring related operations. " + "Only orquesta workflows are supported.", + self, + self.subparsers, + ) # Service Registry - self.commands['service-registry'] = service_registry.ServiceRegistryBranch( - 'Service registry group and membership related commands.', - self, self.subparsers) + self.commands["service-registry"] = service_registry.ServiceRegistryBranch( + "Service registry group and membership related commands.", + self, + self.subparsers, + ) # RBAC - self.commands['role'] = rbac.RoleBranch( - 'RBAC roles.', - self, self.subparsers) - self.commands['role-assignment'] = rbac.RoleAssignmentBranch( - 'RBAC role assignments.', - self, self.subparsers) + self.commands["role"] = rbac.RoleBranch("RBAC roles.", self, self.subparsers) + self.commands["role-assignment"] = rbac.RoleAssignmentBranch( + "RBAC role assignments.", self, self.subparsers + ) def run(self, argv): debug = False @@ -369,9 +397,9 @@ def run(self, argv): # Provide autocomplete for shell argcomplete.autocomplete(self.parser) - if '--print-config' in argv: + if "--print-config" in argv: # Hack because --print-config requires no command to be specified - argv = argv + ['action', 'list'] + argv = argv + ["action", "list"] # Parse command line arguments. args = self.parser.parse_args(args=argv) @@ -389,7 +417,7 @@ def run(self, argv): # Setup client and run the command try: - debug = getattr(args, 'debug', False) + debug = getattr(args, "debug", False) if debug: set_log_level_for_all_loggers(level=logging.DEBUG) @@ -399,7 +427,7 @@ def run(self, argv): # TODO: This is not so nice work-around for Python 3 because of a breaking change in # Python 3 - https://bugs.python.org/issue16308 try: - func = getattr(args, 'func') + func = getattr(args, "func") except AttributeError: parser.print_help() sys.exit(2) @@ -414,9 +442,9 @@ def run(self, argv): return 2 except Exception as e: # We allow exception to define custom exit codes - exit_code = getattr(e, 'exit_code', 1) + exit_code = getattr(e, "exit_code", 1) - print('ERROR: %s\n' % e) + print("ERROR: %s\n" % e) if debug: self._print_debug_info(args=args) @@ -426,10 +454,10 @@ def _print_config(self, args): config = self._parse_config_file(args=args) for section, options in six.iteritems(config): - print('[%s]' % (section)) + print("[%s]" % (section)) for name, value in six.iteritems(options): - print('%s = %s' % (name, value)) + print("%s = %s" % (name, value)) def _check_locale_and_print_warning(self): """ @@ -440,23 +468,23 @@ def _check_locale_and_print_warning(self): preferred_encoding = locale.getpreferredencoding() except ValueError: # Ignore unknown locale errors for now - default_locale = 'unknown' - preferred_encoding = 'unknown' + default_locale = "unknown" + preferred_encoding = "unknown" - if preferred_encoding and preferred_encoding.lower() != 'utf-8': - msg = NON_UTF8_LOCALE % (default_locale or 'unknown', preferred_encoding) + if preferred_encoding and preferred_encoding.lower() != "utf-8": + msg = NON_UTF8_LOCALE % (default_locale or "unknown", preferred_encoding) LOGGER.warn(msg) def setup_logging(argv): - debug = '--debug' in argv + debug = "--debug" in argv root = LOGGER root.setLevel(logging.WARNING) handler = logging.StreamHandler(sys.stderr) handler.setLevel(logging.WARNING) - formatter = logging.Formatter('%(asctime)s %(levelname)s - %(message)s') + formatter = logging.Formatter("%(asctime)s %(levelname)s - %(message)s") handler.setFormatter(formatter) if not debug: @@ -470,5 +498,5 @@ def main(argv=sys.argv[1:]): return Shell().run(argv) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main(sys.argv[1:])) diff --git a/st2client/st2client/utils/color.py b/st2client/st2client/utils/color.py index 8b184021369..f1106851e2a 100644 --- a/st2client/st2client/utils/color.py +++ b/st2client/st2client/utils/color.py @@ -16,40 +16,36 @@ from __future__ import absolute_import import os -__all__ = [ - 'DisplayColors', - - 'format_status' -] +__all__ = ["DisplayColors", "format_status"] TERMINAL_SUPPORTS_ANSI_CODES = [ - 'xterm', - 'xterm-color', - 'screen', - 'vt100', - 'vt100-color', - 'xterm-256color' + "xterm", + "xterm-color", + "screen", + "vt100", + "vt100-color", + "xterm-256color", ] -DISABLED = os.environ.get('ST2_COLORIZE', '') +DISABLED = os.environ.get("ST2_COLORIZE", "") class DisplayColors(object): - RED = '\033[91m' - PURPLE = '\033[35m' - GREEN = '\033[92m' - YELLOW = '\033[93m' - BLUE = '\033[94m' - BROWN = '\033[33m' - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' + RED = "\033[91m" + PURPLE = "\033[35m" + GREEN = "\033[92m" + YELLOW = "\033[93m" + BLUE = "\033[94m" + BROWN = "\033[33m" + ENDC = "\033[0m" + BOLD = "\033[1m" + UNDERLINE = "\033[4m" @staticmethod - def colorize(value, color=''): + def colorize(value, color=""): # TODO: use list of supported terminals - term = os.environ.get('TERM', None) + term = os.environ.get("TERM", None) if term is None or term.lower() not in TERMINAL_SUPPORTS_ANSI_CODES: # Terminal doesn't support colors @@ -58,33 +54,33 @@ def colorize(value, color=''): if DISABLED or not color: return value - return '%s%s%s' % (color, value, DisplayColors.ENDC) + return "%s%s%s" % (color, value, DisplayColors.ENDC) # Lookup table STATUS_LOOKUP = { - 'succeeded': DisplayColors.GREEN, - 'delayed': DisplayColors.BLUE, - 'failed': DisplayColors.RED, - 'timeout': DisplayColors.BROWN, - 'running': DisplayColors.YELLOW + "succeeded": DisplayColors.GREEN, + "delayed": DisplayColors.BLUE, + "failed": DisplayColors.RED, + "timeout": DisplayColors.BROWN, + "running": DisplayColors.YELLOW, } def format_status(value): # Support status values with elapsed info - split = value.split('(', 1) + split = value.split("(", 1) if len(split) == 2: status = split[0].strip() - remainder = '(' + split[1] + remainder = "(" + split[1] else: status = value - remainder = '' + remainder = "" color = STATUS_LOOKUP.get(status, DisplayColors.YELLOW) result = DisplayColors.colorize(status, color) if remainder: - result = result + ' ' + remainder + result = result + " " + remainder return result diff --git a/st2client/st2client/utils/date.py b/st2client/st2client/utils/date.py index b19e27f3ece..3a76a44c81a 100644 --- a/st2client/st2client/utils/date.py +++ b/st2client/st2client/utils/date.py @@ -20,10 +20,7 @@ from st2client.config import get_config -__all__ = [ - 'parse', - 'format_isodate' -] +__all__ = ["parse", "format_isodate"] def add_utc_tz(dt): @@ -39,7 +36,7 @@ def format_dt(dt): """ Format datetime object for human friendly representation. """ - value = dt.strftime('%a, %d %b %Y %H:%M:%S %Z') + value = dt.strftime("%a, %d %b %Y %H:%M:%S %Z") return value @@ -52,7 +49,7 @@ def format_isodate(value, timezone=None): :rtype: ``str`` """ if not value: - return '' + return "" # For some reason pylint thinks it returns a tuple but it returns a datetime object dt = dateutil.parser.parse(str(value)) @@ -70,6 +67,6 @@ def format_isodate_for_user_timezone(value): specific in the config. """ config = get_config() - timezone = config.get('cli', {}).get('timezone', 'UTC') + timezone = config.get("cli", {}).get("timezone", "UTC") result = format_isodate(value=value, timezone=timezone) return result diff --git a/st2client/st2client/utils/httpclient.py b/st2client/st2client/utils/httpclient.py index 089f6b88d6b..6af6595ec50 100644 --- a/st2client/st2client/utils/httpclient.py +++ b/st2client/st2client/utils/httpclient.py @@ -27,38 +27,41 @@ def add_ssl_verify_to_kwargs(func): def decorate(*args, **kwargs): - if isinstance(args[0], HTTPClient) and 'https' in getattr(args[0], 'root', ''): - cacert = getattr(args[0], 'cacert', None) - kwargs['verify'] = cacert if cacert is not None else False + if isinstance(args[0], HTTPClient) and "https" in getattr(args[0], "root", ""): + cacert = getattr(args[0], "cacert", None) + kwargs["verify"] = cacert if cacert is not None else False return func(*args, **kwargs) + return decorate def add_auth_token_to_headers(func): def decorate(*args, **kwargs): - headers = kwargs.get('headers', dict()) + headers = kwargs.get("headers", dict()) - token = kwargs.pop('token', None) + token = kwargs.pop("token", None) if token: - headers['X-Auth-Token'] = str(token) - kwargs['headers'] = headers + headers["X-Auth-Token"] = str(token) + kwargs["headers"] = headers - api_key = kwargs.pop('api_key', None) + api_key = kwargs.pop("api_key", None) if api_key: - headers['St2-Api-Key'] = str(api_key) - kwargs['headers'] = headers + headers["St2-Api-Key"] = str(api_key) + kwargs["headers"] = headers return func(*args, **kwargs) + return decorate def add_json_content_type_to_headers(func): def decorate(*args, **kwargs): - headers = kwargs.get('headers', dict()) - content_type = headers.get('content-type', 'application/json') - headers['content-type'] = content_type - kwargs['headers'] = headers + headers = kwargs.get("headers", dict()) + content_type = headers.get("content-type", "application/json") + headers["content-type"] = content_type + kwargs["headers"] = headers return func(*args, **kwargs) + return decorate @@ -71,12 +74,11 @@ def get_url_without_trailing_slash(value): :rtype: ``str`` """ - result = value[:-1] if value.endswith('/') else value + result = value[:-1] if value.endswith("/") else value return result class HTTPClient(object): - def __init__(self, root, cacert=None, debug=False): self.root = get_url_without_trailing_slash(root) self.cacert = cacert @@ -136,30 +138,30 @@ def _response_hook(self, response): print("# -------- begin %d response ----------" % (id(self))) print(response.text) print("# -------- end %d response ------------" % (id(self))) - print('') + print("") return response def _get_curl_line_for_request(self, request): - parts = ['curl'] + parts = ["curl"] # method method = request.method.upper() - if method in ['HEAD']: - parts.extend(['--head']) + if method in ["HEAD"]: + parts.extend(["--head"]) else: - parts.extend(['-X', pquote(method)]) + parts.extend(["-X", pquote(method)]) # headers for key, value in request.headers.items(): - parts.extend(['-H ', pquote('%s: %s' % (key, value))]) + parts.extend(["-H ", pquote("%s: %s" % (key, value))]) # body if request.body: - parts.extend(['--data-binary', pquote(request.body)]) + parts.extend(["--data-binary", pquote(request.body)]) # URL parts.extend([pquote(request.url)]) - curl_line = ' '.join(parts) + curl_line = " ".join(parts) return curl_line diff --git a/st2client/st2client/utils/interactive.py b/st2client/st2client/utils/interactive.py index 35065e5d948..7e6f81b29be 100644 --- a/st2client/st2client/utils/interactive.py +++ b/st2client/st2client/utils/interactive.py @@ -28,8 +28,8 @@ from six.moves import range -POSITIVE_BOOLEAN = {'1', 'y', 'yes', 'true'} -NEGATIVE_BOOLEAN = {'0', 'n', 'no', 'nope', 'nah', 'false'} +POSITIVE_BOOLEAN = {"1", "y", "yes", "true"} +NEGATIVE_BOOLEAN = {"0", "n", "no", "nope", "nah", "false"} class ReaderNotImplemented(OperationFailureException): @@ -58,10 +58,8 @@ class StringReader(object): def __init__(self, name, spec, prefix=None, secret=False, **kw): self.name = name self.spec = spec - self.prefix = prefix or '' - self.options = { - 'is_password': secret - } + self.prefix = prefix or "" + self.options = {"is_password": secret} self._construct_description() self._construct_template() @@ -84,7 +82,7 @@ def read(self): message = self.template.format(self.prefix + self.name, **self.spec) response = prompt(message, **self.options) - result = self.spec.get('default', None) + result = self.spec.get("default", None) if response: result = self._transform_response(response) @@ -92,20 +90,21 @@ def read(self): return result def _construct_description(self): - if 'description' in self.spec: + if "description" in self.spec: + def get_bottom_toolbar_tokens(cli): - return [(token.Token.Toolbar, self.spec['description'])] + return [(token.Token.Toolbar, self.spec["description"])] - self.options['get_bottom_toolbar_tokens'] = get_bottom_toolbar_tokens + self.options["get_bottom_toolbar_tokens"] = get_bottom_toolbar_tokens def _construct_template(self): - self.template = u'{0}: ' + self.template = "{0}: " - if 'default' in self.spec: - self.template = u'{0} [{default}]: ' + if "default" in self.spec: + self.template = "{0} [{default}]: " def _construct_validators(self): - self.options['validator'] = MuxValidator([self.validate], self.spec) + self.options["validator"] = MuxValidator([self.validate], self.spec) def _transform_response(self, response): return response @@ -114,25 +113,27 @@ def _transform_response(self, response): class BooleanReader(StringReader): @staticmethod def condition(spec): - return spec.get('type', None) == 'boolean' + return spec.get("type", None) == "boolean" @staticmethod def validate(input, spec): - if not input and (not spec.get('required', None) or spec.get('default', None)): + if not input and (not spec.get("required", None) or spec.get("default", None)): return if input.lower() not in POSITIVE_BOOLEAN | NEGATIVE_BOOLEAN: - raise validation.ValidationError(len(input), - 'Does not look like boolean. Pick from [%s]' - % ', '.join(POSITIVE_BOOLEAN | NEGATIVE_BOOLEAN)) + raise validation.ValidationError( + len(input), + "Does not look like boolean. Pick from [%s]" + % ", ".join(POSITIVE_BOOLEAN | NEGATIVE_BOOLEAN), + ) def _construct_template(self): - self.template = u'{0} (boolean)' + self.template = "{0} (boolean)" - if 'default' in self.spec: - self.template += u' [{}]: '.format(self.spec.get('default') and 'y' or 'n') + if "default" in self.spec: + self.template += " [{}]: ".format(self.spec.get("default") and "y" or "n") else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): if response.lower() in POSITIVE_BOOLEAN: @@ -141,14 +142,16 @@ def _transform_response(self, response): return False # Hopefully, it will never happen - raise OperationFailureException('Response neither positive no negative. ' - 'Value have not been properly validated.') + raise OperationFailureException( + "Response neither positive no negative. " + "Value have not been properly validated." + ) class NumberReader(StringReader): @staticmethod def condition(spec): - return spec.get('type', None) == 'number' + return spec.get("type", None) == "number" @staticmethod def validate(input, spec): @@ -161,12 +164,12 @@ def validate(input, spec): super(NumberReader, NumberReader).validate(input, spec) def _construct_template(self): - self.template = u'{0} (float)' + self.template = "{0} (float)" - if 'default' in self.spec: - self.template += u' [{default}]: '.format(default=self.spec.get('default')) + if "default" in self.spec: + self.template += " [{default}]: ".format(default=self.spec.get("default")) else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): return float(response) @@ -175,7 +178,7 @@ def _transform_response(self, response): class IntegerReader(StringReader): @staticmethod def condition(spec): - return spec.get('type', None) == 'integer' + return spec.get("type", None) == "integer" @staticmethod def validate(input, spec): @@ -188,12 +191,12 @@ def validate(input, spec): super(IntegerReader, IntegerReader).validate(input, spec) def _construct_template(self): - self.template = u'{0} (integer)' + self.template = "{0} (integer)" - if 'default' in self.spec: - self.template += u' [{default}]: '.format(default=self.spec.get('default')) + if "default" in self.spec: + self.template += " [{default}]: ".format(default=self.spec.get("default")) else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): return int(response) @@ -205,71 +208,71 @@ def __init__(self, *args, **kwargs): @staticmethod def condition(spec): - return spec.get('secret', None) + return spec.get("secret", None) def _construct_template(self): - self.template = u'{0} (secret)' + self.template = "{0} (secret)" - if 'default' in self.spec: - self.template += u' [{default}]: '.format(default=self.spec.get('default')) + if "default" in self.spec: + self.template += " [{default}]: ".format(default=self.spec.get("default")) else: - self.template += u': ' + self.template += ": " class EnumReader(StringReader): @staticmethod def condition(spec): - return spec.get('enum', None) + return spec.get("enum", None) @staticmethod def validate(input, spec): - if not input and (not spec.get('required', None) or spec.get('default', None)): + if not input and (not spec.get("required", None) or spec.get("default", None)): return if not input.isdigit(): - raise validation.ValidationError(len(input), 'Not a number') + raise validation.ValidationError(len(input), "Not a number") - enum = spec.get('enum') + enum = spec.get("enum") try: enum[int(input)] except IndexError: - raise validation.ValidationError(len(input), 'Out of bounds') + raise validation.ValidationError(len(input), "Out of bounds") def _construct_template(self): - self.template = u'{0}: ' + self.template = "{0}: " - enum = self.spec.get('enum') + enum = self.spec.get("enum") for index, value in enumerate(enum): - self.template += u'\n {} - {}'.format(index, value) + self.template += "\n {} - {}".format(index, value) num_options = len(enum) - more = '' + more = "" if num_options > 3: num_options = 3 - more = '...' + more = "..." options = [str(i) for i in range(0, num_options)] - self.template += u'\nChoose from {}{}'.format(', '.join(options), more) + self.template += "\nChoose from {}{}".format(", ".join(options), more) - if 'default' in self.spec: - self.template += u' [{}]: '.format(enum.index(self.spec.get('default'))) + if "default" in self.spec: + self.template += " [{}]: ".format(enum.index(self.spec.get("default"))) else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): - return self.spec.get('enum')[int(response)] + return self.spec.get("enum")[int(response)] class ObjectReader(StringReader): - @staticmethod def condition(spec): - return spec.get('type', None) == 'object' + return spec.get("type", None) == "object" def read(self): - prefix = u'{}.'.format(self.name) + prefix = "{}.".format(self.name) - result = InteractiveForm(self.spec.get('properties', {}), - prefix=prefix, reraise=True).initiate_dialog() + result = InteractiveForm( + self.spec.get("properties", {}), prefix=prefix, reraise=True + ).initiate_dialog() return result @@ -277,25 +280,27 @@ def read(self): class ArrayReader(StringReader): @staticmethod def condition(spec): - return spec.get('type', None) == 'array' + return spec.get("type", None) == "array" @staticmethod def validate(input, spec): - if not input and (not spec.get('required', None) or spec.get('default', None)): + if not input and (not spec.get("required", None) or spec.get("default", None)): return - for m in re.finditer(r'[^, ]+', input): + for m in re.finditer(r"[^, ]+", input): index, item = m.start(), m.group() try: - StringReader.validate(item, spec.get('items', {})) + StringReader.validate(item, spec.get("items", {})) except validation.ValidationError as e: raise validation.ValidationError(index, six.text_type(e)) def read(self): - item_type = self.spec.get('items', {}).get('type', 'string') + item_type = self.spec.get("items", {}).get("type", "string") - if item_type not in ['string', 'integer', 'number', 'boolean']: - message = 'Interactive mode does not support arrays of %s type yet' % item_type + if item_type not in ["string", "integer", "number", "boolean"]: + message = ( + "Interactive mode does not support arrays of %s type yet" % item_type + ) raise ReaderNotImplemented(message) result = super(ArrayReader, self).read() @@ -303,37 +308,46 @@ def read(self): return result def _construct_template(self): - self.template = u'{0} (comma-separated list)' + self.template = "{0} (comma-separated list)" - if 'default' in self.spec: - self.template += u' [{default}]: '.format(default=','.join(self.spec.get('default'))) + if "default" in self.spec: + self.template += " [{default}]: ".format( + default=",".join(self.spec.get("default")) + ) else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): - return [item.strip() for item in response.split(',')] + return [item.strip() for item in response.split(",")] class ArrayObjectReader(StringReader): @staticmethod def condition(spec): - return spec.get('type', None) == 'array' and spec.get('items', {}).get('type') == 'object' + return ( + spec.get("type", None) == "array" + and spec.get("items", {}).get("type") == "object" + ) def read(self): results = [] - properties = self.spec.get('items', {}).get('properties', {}) - message = '~~~ Would you like to add another item to "%s" array / list?' % self.name + properties = self.spec.get("items", {}).get("properties", {}) + message = ( + '~~~ Would you like to add another item to "%s" array / list?' % self.name + ) is_continue = True index = 0 while is_continue: - prefix = u'{name}[{index}].'.format(name=self.name, index=index) - results.append(InteractiveForm(properties, - prefix=prefix, - reraise=True).initiate_dialog()) + prefix = "{name}[{index}].".format(name=self.name, index=index) + results.append( + InteractiveForm( + properties, prefix=prefix, reraise=True + ).initiate_dialog() + ) index += 1 - if Question(message, {'default': 'y'}).read() != 'y': + if Question(message, {"default": "y"}).read() != "y": is_continue = False return results @@ -341,53 +355,55 @@ def read(self): class ArrayEnumReader(EnumReader): def __init__(self, name, spec, prefix=None): - self.items = spec.get('items', {}) + self.items = spec.get("items", {}) super(ArrayEnumReader, self).__init__(name, spec, prefix) @staticmethod def condition(spec): - return spec.get('type', None) == 'array' and 'enum' in spec.get('items', {}) + return spec.get("type", None) == "array" and "enum" in spec.get("items", {}) @staticmethod def validate(input, spec): - if not input and (not spec.get('required', None) or spec.get('default', None)): + if not input and (not spec.get("required", None) or spec.get("default", None)): return - for m in re.finditer(r'[^, ]+', input): + for m in re.finditer(r"[^, ]+", input): index, item = m.start(), m.group() try: - EnumReader.validate(item, spec.get('items', {})) + EnumReader.validate(item, spec.get("items", {})) except validation.ValidationError as e: raise validation.ValidationError(index, six.text_type(e)) def _construct_template(self): - self.template = u'{0}: ' + self.template = "{0}: " - enum = self.items.get('enum') + enum = self.items.get("enum") for index, value in enumerate(enum): - self.template += u'\n {} - {}'.format(index, value) + self.template += "\n {} - {}".format(index, value) num_options = len(enum) - more = '' + more = "" if num_options > 3: num_options = 3 - more = '...' + more = "..." options = [str(i) for i in range(0, num_options)] - self.template += u'\nChoose from {}{}'.format(', '.join(options), more) + self.template += "\nChoose from {}{}".format(", ".join(options), more) - if 'default' in self.spec: - default_choises = [str(enum.index(item)) for item in self.spec.get('default')] - self.template += u' [{}]: '.format(', '.join(default_choises)) + if "default" in self.spec: + default_choises = [ + str(enum.index(item)) for item in self.spec.get("default") + ] + self.template += " [{}]: ".format(", ".join(default_choises)) else: - self.template += u': ' + self.template += ": " def _transform_response(self, response): result = [] - for i in (item.strip() for item in response.split(',')): + for i in (item.strip() for item in response.split(",")): if i: - result.append(self.items.get('enum')[int(i)]) + result.append(self.items.get("enum")[int(i)]) return result @@ -403,7 +419,7 @@ class InteractiveForm(object): ArrayObjectReader, ArrayReader, SecretStringReader, - StringReader + StringReader, ] def __init__(self, schema, prefix=None, reraise=False): @@ -419,11 +435,11 @@ def initiate_dialog(self): try: result[field] = self._read_field(field) except ReaderNotImplemented as e: - print('%s. Skipping...' % six.text_type(e)) + print("%s. Skipping..." % six.text_type(e)) except DialogInterrupted: if self.reraise: raise - print('Dialog interrupted.') + print("Dialog interrupted.") return result @@ -438,7 +454,7 @@ def _read_field(self, field): break if not reader: - raise ReaderNotImplemented('No reader for the field spec') + raise ReaderNotImplemented("No reader for the field spec") try: return reader.read() diff --git a/st2client/st2client/utils/jsutil.py b/st2client/st2client/utils/jsutil.py index 7aaf20dfe06..1d98ab8f467 100644 --- a/st2client/st2client/utils/jsutil.py +++ b/st2client/st2client/utils/jsutil.py @@ -48,7 +48,7 @@ def _get_value_simple(doc, key): Returns the extracted value from the key specified (if found) Returns None if the key can not be found """ - split_key = key.split('.') + split_key = key.split(".") if not split_key: return None @@ -82,8 +82,9 @@ def get_value(doc, key): raise ValueError("key is None or empty: '{}'".format(key)) if not isinstance(doc, dict): - raise ValueError("doc is not an instance of dict: type={} value='{}'".format(type(doc), - doc)) + raise ValueError( + "doc is not an instance of dict: type={} value='{}'".format(type(doc), doc) + ) # jsonpath_rw can be very slow when processing expressions. # In the case of a simple expression we've created a "fast path" that avoids # the complexity introduced by running jsonpath_rw code. @@ -113,12 +114,12 @@ def get_kvps(doc, keys): value = get_value(doc, key) if value is not None: nested = new_doc - while '.' in key: - attr = key[:key.index('.')] + while "." in key: + attr = key[: key.index(".")] if attr not in nested: nested[attr] = {} nested = nested[attr] - key = key[key.index('.') + 1:] + key = key[key.index(".") + 1 :] nested[key] = value return new_doc diff --git a/st2client/st2client/utils/logging.py b/st2client/st2client/utils/logging.py index dd8b8b9e440..8328a5c55eb 100644 --- a/st2client/st2client/utils/logging.py +++ b/st2client/st2client/utils/logging.py @@ -18,9 +18,9 @@ import logging __all__ = [ - 'LogLevelFilter', - 'set_log_level_for_all_handlers', - 'set_log_level_for_all_loggers' + "LogLevelFilter", + "set_log_level_for_all_handlers", + "set_log_level_for_all_loggers", ] diff --git a/st2client/st2client/utils/misc.py b/st2client/st2client/utils/misc.py index e8623b3070b..62c7b1a61f0 100644 --- a/st2client/st2client/utils/misc.py +++ b/st2client/st2client/utils/misc.py @@ -18,9 +18,7 @@ import six -__all__ = [ - 'merge_dicts' -] +__all__ = ["merge_dicts"] def merge_dicts(d1, d2): diff --git a/st2client/st2client/utils/schema.py b/st2client/st2client/utils/schema.py index 33142daa71e..2cf7d5b2314 100644 --- a/st2client/st2client/utils/schema.py +++ b/st2client/st2client/utils/schema.py @@ -17,36 +17,30 @@ TYPE_TABLE = { - dict: 'object', - list: 'array', - int: 'integer', - str: 'string', - float: 'number', - bool: 'boolean', - type(None): 'null', + dict: "object", + list: "array", + int: "integer", + str: "string", + float: "number", + bool: "boolean", + type(None): "null", } if sys.version_info[0] < 3: - TYPE_TABLE[unicode] = 'string' # noqa # pylint: disable=E0602 + TYPE_TABLE[unicode] = "string" # noqa # pylint: disable=E0602 def _dict_to_schema(item): schema = {} for key, value in item.iteritems(): if isinstance(value, dict): - schema[key] = { - 'type': 'object', - 'parameters': _dict_to_schema(value) - } + schema[key] = {"type": "object", "parameters": _dict_to_schema(value)} else: - schema[key] = { - 'type': TYPE_TABLE[type(value)] - } + schema[key] = {"type": TYPE_TABLE[type(value)]} return schema def render_output_schema_from_output(output): - """Given an action output produce a reasonable schema to match. - """ + """Given an action output produce a reasonable schema to match.""" return _dict_to_schema(output) diff --git a/st2client/st2client/utils/strutil.py b/st2client/st2client/utils/strutil.py index d6bc23d9cc1..0bb970ff3eb 100644 --- a/st2client/st2client/utils/strutil.py +++ b/st2client/st2client/utils/strutil.py @@ -24,9 +24,9 @@ def unescape(s): This function unescapes those chars. """ if isinstance(s, six.string_types): - s = s.replace('\\n', '\n') - s = s.replace('\\r', '\r') - s = s.replace('\\"', '\"') + s = s.replace("\\n", "\n") + s = s.replace("\\r", "\r") + s = s.replace('\\"', '"') return s @@ -39,14 +39,14 @@ def dedupe_newlines(s): """ if isinstance(s, six.string_types): - s = s.replace('\n\n', '\n') + s = s.replace("\n\n", "\n") return s def strip_carriage_returns(s): if isinstance(s, six.string_types): - s = s.replace('\\r', '') - s = s.replace('\r', '') + s = s.replace("\\r", "") + s = s.replace("\r", "") return s diff --git a/st2client/st2client/utils/terminal.py b/st2client/st2client/utils/terminal.py index 555753fc95f..6ce28a4d741 100644 --- a/st2client/st2client/utils/terminal.py +++ b/st2client/st2client/utils/terminal.py @@ -24,11 +24,7 @@ DEFAULT_TERMINAL_SIZE_COLUMNS = 150 -__all__ = [ - 'DEFAULT_TERMINAL_SIZE_COLUMNS', - - 'get_terminal_size_columns' -] +__all__ = ["DEFAULT_TERMINAL_SIZE_COLUMNS", "get_terminal_size_columns"] def get_terminal_size_columns(default=DEFAULT_TERMINAL_SIZE_COLUMNS): @@ -48,7 +44,7 @@ def get_terminal_size_columns(default=DEFAULT_TERMINAL_SIZE_COLUMNS): # This way it's consistent with upstream implementation. In the past, our implementation # checked those variables at the end as a fall back. try: - columns = os.environ['COLUMNS'] + columns = os.environ["COLUMNS"] return int(columns) except (KeyError, ValueError): pass @@ -56,8 +52,9 @@ def get_terminal_size_columns(default=DEFAULT_TERMINAL_SIZE_COLUMNS): def ioctl_GWINSZ(fd): import fcntl import termios + # Return a tuple (lines, columns) - return struct.unpack('hh', fcntl.ioctl(fd, termios.TIOCGWINSZ, '1234')) + return struct.unpack("hh", fcntl.ioctl(fd, termios.TIOCGWINSZ, "1234")) # 2. try stdin, stdout, stderr for fd in (0, 1, 2): @@ -78,10 +75,12 @@ def ioctl_GWINSZ(fd): # 4. try `stty size` try: - process = subprocess.Popen(['stty', 'size'], - shell=False, - stdout=subprocess.PIPE, - stderr=open(os.devnull, 'w')) + process = subprocess.Popen( + ["stty", "size"], + shell=False, + stdout=subprocess.PIPE, + stderr=open(os.devnull, "w"), + ) result = process.communicate() if process.returncode == 0: return tuple(int(x) for x in result[0].split())[1] @@ -101,23 +100,23 @@ def __exit__(self, type, value, traceback): return self.close() def add_stage(self, status, name): - self._write('\t[{:^20}] {}'.format(format_status(status), name)) + self._write("\t[{:^20}] {}".format(format_status(status), name)) def update_stage(self, status, name): - self._write('\t[{:^20}] {}'.format(format_status(status), name), override=True) + self._write("\t[{:^20}] {}".format(format_status(status), name), override=True) def finish_stage(self, status, name): - self._write('\t[{:^20}] {}'.format(format_status(status), name), override=True) + self._write("\t[{:^20}] {}".format(format_status(status), name), override=True) def close(self): if self.dirty: - self._write('\n') + self._write("\n") def _write(self, string, override=False): if override: - sys.stdout.write('\r') + sys.stdout.write("\r") else: - sys.stdout.write('\n') + sys.stdout.write("\n") sys.stdout.write(string) sys.stdout.flush() diff --git a/st2client/st2client/utils/types.py b/st2client/st2client/utils/types.py index 5c25990a6ed..ad70f078b9b 100644 --- a/st2client/st2client/utils/types.py +++ b/st2client/st2client/utils/types.py @@ -20,17 +20,14 @@ from __future__ import absolute_import import collections -__all__ = [ - 'OrderedSet' -] +__all__ = ["OrderedSet"] class OrderedSet(collections.MutableSet): - def __init__(self, iterable=None): self.end = end = [] - end += [None, end, end] # sentinel node for doubly linked list - self.map = {} # key --> [key, prev, next] + end += [None, end, end] # sentinel node for doubly linked list + self.map = {} # key --> [key, prev, next] if iterable is not None: self |= iterable @@ -68,15 +65,15 @@ def __reversed__(self): def pop(self, last=True): if not self: - raise KeyError('set is empty') + raise KeyError("set is empty") key = self.end[1][0] if last else self.end[2][0] self.discard(key) return key def __repr__(self): if not self: - return '%s()' % (self.__class__.__name__,) - return '%s(%r)' % (self.__class__.__name__, list(self)) + return "%s()" % (self.__class__.__name__,) + return "%s(%r)" % (self.__class__.__name__, list(self)) def __eq__(self, other): if isinstance(other, OrderedSet): diff --git a/st2client/tests/base.py b/st2client/tests/base.py index 307f00f74b2..80c14efef68 100644 --- a/st2client/tests/base.py +++ b/st2client/tests/base.py @@ -27,26 +27,22 @@ LOG = logging.getLogger(__name__) -FAKE_ENDPOINT = 'http://127.0.0.1:8268' +FAKE_ENDPOINT = "http://127.0.0.1:8268" RESOURCES = [ { "id": "123", "name": "abc", }, - { - "id": "456", - "name": "def" - } + {"id": "456", "name": "def"}, ] class FakeResource(models.Resource): - _plural = 'FakeResources' + _plural = "FakeResources" class FakeResponse(object): - def __init__(self, text, status_code, reason, *args): self.text = text self.status_code = status_code @@ -64,8 +60,7 @@ def raise_for_status(self): class FakeClient(object): def __init__(self): self.managers = { - 'FakeResource': models.ResourceManager(FakeResource, - FAKE_ENDPOINT) + "FakeResource": models.ResourceManager(FakeResource, FAKE_ENDPOINT) } @@ -75,23 +70,32 @@ def __init__(self): class BaseCLITestCase(unittest2.TestCase): - capture_output = True # if True, stdout and stderr are saved to self.stdout and self.stderr + capture_output = ( + True # if True, stdout and stderr are saved to self.stdout and self.stderr + ) stdout = six.moves.StringIO() stderr = six.moves.StringIO() - DEFAULT_SKIP_CONFIG = '1' + DEFAULT_SKIP_CONFIG = "1" def setUp(self): super(BaseCLITestCase, self).setUp() # Setup environment - for var in ['ST2_BASE_URL', 'ST2_AUTH_URL', 'ST2_API_URL', 'ST2_STREAM_URL', - 'ST2_AUTH_TOKEN', 'ST2_CONFIG_FILE', 'ST2_API_KEY']: + for var in [ + "ST2_BASE_URL", + "ST2_AUTH_URL", + "ST2_API_URL", + "ST2_STREAM_URL", + "ST2_AUTH_TOKEN", + "ST2_CONFIG_FILE", + "ST2_API_KEY", + ]: if var in os.environ: del os.environ[var] - os.environ['ST2_CLI_SKIP_CONFIG'] = self.DEFAULT_SKIP_CONFIG + os.environ["ST2_CLI_SKIP_CONFIG"] = self.DEFAULT_SKIP_CONFIG if self.capture_output: # Make sure we reset it for each test class instance @@ -134,5 +138,5 @@ def _reset_output_streams(self): self.stderr.truncate() # Verify it has been reset correctly - self.assertEqual(self.stdout.getvalue(), '') - self.assertEqual(self.stderr.getvalue(), '') + self.assertEqual(self.stdout.getvalue(), "") + self.assertEqual(self.stderr.getvalue(), "") diff --git a/st2client/tests/fixtures/loader.py b/st2client/tests/fixtures/loader.py index a471d8e7103..049a82b7a6f 100644 --- a/st2client/tests/fixtures/loader.py +++ b/st2client/tests/fixtures/loader.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except ImportError: @@ -24,8 +25,8 @@ import yaml -ALLOWED_EXTS = ['.json', '.yaml', '.yml', '.txt'] -PARSER_FUNCS = {'.json': json.load, '.yml': yaml.safe_load, '.yaml': yaml.safe_load} +ALLOWED_EXTS = [".json", ".yaml", ".yml", ".txt"] +PARSER_FUNCS = {".json": json.load, ".yml": yaml.safe_load, ".yaml": yaml.safe_load} def get_fixtures_base_path(): @@ -44,12 +45,14 @@ def load_content(file_path): file_name, file_ext = os.path.splitext(file_path) if file_ext not in ALLOWED_EXTS: - raise Exception('Unsupported meta type %s, file %s. Allowed: %s' % - (file_ext, file_path, ALLOWED_EXTS)) + raise Exception( + "Unsupported meta type %s, file %s. Allowed: %s" + % (file_ext, file_path, ALLOWED_EXTS) + ) parser_func = PARSER_FUNCS.get(file_ext, None) - with open(file_path, 'r') as fd: + with open(file_path, "r") as fd: return parser_func(fd) if parser_func else fd.read() @@ -75,7 +78,7 @@ def load_fixtures(fixtures_dict=None): for fixture_type, fixtures in six.iteritems(fixtures_dict): loaded_fixtures = {} for fixture in fixtures: - fixture_path = fixtures_base_path + '/' + fixture + fixture_path = fixtures_base_path + "/" + fixture fixture_dict = load_content(fixture_path) loaded_fixtures[fixture] = fixture_dict all_fixtures[fixture_type] = loaded_fixtures diff --git a/st2client/tests/unit/test_action.py b/st2client/tests/unit/test_action.py index e02c1ea1ca4..1bb8be3810d 100644 --- a/st2client/tests/unit/test_action.py +++ b/st2client/tests/unit/test_action.py @@ -34,7 +34,7 @@ "float": {"type": "number"}, "json": {"type": "object"}, "list": {"type": "array"}, - "str": {"type": "string"} + "str": {"type": "string"}, }, "name": "mock-runner1", } @@ -46,7 +46,7 @@ "parameters": {}, "enabled": True, "entry_point": "", - "pack": "mockety" + "pack": "mockety", } RUNNER2 = { @@ -65,475 +65,583 @@ "float": {"type": "number"}, "json": {"type": "object"}, "list": {"type": "array"}, - "str": {"type": "string"} + "str": {"type": "string"}, }, "enabled": True, "entry_point": "", - "pack": "mockety" + "pack": "mockety", } LIVE_ACTION = { - 'action': 'mockety.mock', - 'status': 'complete', - 'result': {'stdout': 'non-empty'} + "action": "mockety.mock", + "status": "complete", + "result": {"stdout": "non-empty"}, } def get_by_name(name, **kwargs): - if name == 'mock-runner1': + if name == "mock-runner1": return models.RunnerType(**RUNNER1) - if name == 'mock-runner2': + if name == "mock-runner2": return models.RunnerType(**RUNNER2) def get_by_ref(**kwargs): - ref = kwargs.get('ref_or_id', None) + ref = kwargs.get("ref_or_id", None) if not ref: raise Exception('Actions must be referred to by "ref".') - if ref == 'mockety.mock1': + if ref == "mockety.mock1": return models.Action(**ACTION1) - if ref == 'mockety.mock2': + if ref == "mockety.mock2": return models.Action(**ACTION2) class ActionCommandTestCase(base.BaseCLITestCase): - def __init__(self, *args, **kwargs): super(ActionCommandTestCase, self).__init__(*args, **kwargs) self.shell = shell.Shell() @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_bool_conversion(self): - self.shell.run(['run', 'mockety.mock1', 'bool=false']) - expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'bool': False}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock1", "bool=false"]) + expected = { + "action": "mockety.mock1", + "user": None, + "parameters": {"bool": False}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_integer_conversion(self): - self.shell.run(['run', 'mockety.mock1', 'int=30']) - expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'int': 30}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock1", "int=30"]) + expected = {"action": "mockety.mock1", "user": None, "parameters": {"int": 30}} + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_float_conversion(self): - self.shell.run(['run', 'mockety.mock1', 'float=3.01']) - expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'float': 3.01}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock1", "float=3.01"]) + expected = { + "action": "mockety.mock1", + "user": None, + "parameters": {"float": 3.01}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_json_conversion(self): - self.shell.run(['run', 'mockety.mock1', 'json={"a":1}']) - expected = {'action': 'mockety.mock1', 'user': None, 'parameters': {'json': {'a': 1}}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock1", 'json={"a":1}']) + expected = { + "action": "mockety.mock1", + "user": None, + "parameters": {"json": {"a": 1}}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_array_conversion(self): - self.shell.run(['run', 'mockety.mock1', 'list=one,two,three']) + self.shell.run(["run", "mockety.mock1", "list=one,two,three"]) expected = { - 'action': 'mockety.mock1', - 'user': None, - 'parameters': { - 'list': [ - 'one', - 'two', - 'three' - ] - } + "action": "mockety.mock1", + "user": None, + "parameters": {"list": ["one", "two", "three"]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_runner_param_array_object_conversion(self): self.shell.run( [ - 'run', - 'mockety.mock1', - 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]' + "run", + "mockety.mock1", + 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]', ] ) expected = { - 'action': 'mockety.mock1', - 'user': None, - 'parameters': { - 'list': [ - { - 'foo': 1, - 'ponies': 'rainbows' - }, - { - 'pluto': False, - 'earth': True - } + "action": "mockety.mock1", + "user": None, + "parameters": { + "list": [ + {"foo": 1, "ponies": "rainbows"}, + {"pluto": False, "earth": True}, ] - } + }, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_bool_conversion(self): - self.shell.run(['run', 'mockety.mock2', 'bool=false']) - expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'bool': False}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock2", "bool=false"]) + expected = { + "action": "mockety.mock2", + "user": None, + "parameters": {"bool": False}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_integer_conversion(self): - self.shell.run(['run', 'mockety.mock2', 'int=30']) - expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'int': 30}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock2", "int=30"]) + expected = {"action": "mockety.mock2", "user": None, "parameters": {"int": 30}} + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_float_conversion(self): - self.shell.run(['run', 'mockety.mock2', 'float=3.01']) - expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'float': 3.01}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock2", "float=3.01"]) + expected = { + "action": "mockety.mock2", + "user": None, + "parameters": {"float": 3.01}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_json_conversion(self): - self.shell.run(['run', 'mockety.mock2', 'json={"a":1}']) - expected = {'action': 'mockety.mock2', 'user': None, 'parameters': {'json': {'a': 1}}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock2", 'json={"a":1}']) + expected = { + "action": "mockety.mock2", + "user": None, + "parameters": {"json": {"a": 1}}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_array_conversion(self): - self.shell.run(['run', 'mockety.mock2', 'list=one,two,three']) + self.shell.run(["run", "mockety.mock2", "list=one,two,three"]) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - 'one', - 'two', - 'three' - ] - } + "action": "mockety.mock2", + "user": None, + "parameters": {"list": ["one", "two", "three"]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_array_conversion_single_element_str(self): - self.shell.run(['run', 'mockety.mock2', 'list=one']) + self.shell.run(["run", "mockety.mock2", "list=one"]) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - 'one' - ] - } + "action": "mockety.mock2", + "user": None, + "parameters": {"list": ["one"]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_array_conversion_single_element_int(self): - self.shell.run(['run', 'mockety.mock2', 'list=1']) + self.shell.run(["run", "mockety.mock2", "list=1"]) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - 1 - ] - } + "action": "mockety.mock2", + "user": None, + "parameters": {"list": [1]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_array_object_conversion(self): self.shell.run( [ - 'run', - 'mockety.mock2', - 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]' + "run", + "mockety.mock2", + 'list=[{"foo":1, "ponies":"rainbows"},{"pluto":false, "earth":true}]', ] ) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - { - 'foo': 1, - 'ponies': 'rainbows' - }, - { - 'pluto': False, - 'earth': True - } + "action": "mockety.mock2", + "user": None, + "parameters": { + "list": [ + {"foo": 1, "ponies": "rainbows"}, + {"pluto": False, "earth": True}, ] - } + }, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_dict_conversion_flag(self): - """Ensure that the automatic conversion to dict based on colons only occurs with the flag - """ + """Ensure that the automatic conversion to dict based on colons only occurs with the flag""" self.shell.run( - [ - 'run', - 'mockety.mock2', - 'list=key1:value1,key2:value2', - '--auto-dict' - ] + ["run", "mockety.mock2", "list=key1:value1,key2:value2", "--auto-dict"] ) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - { - 'key1': 'value1', - 'key2': 'value2' - } - ] - } + "action": "mockety.mock2", + "user": None, + "parameters": {"list": [{"key1": "value1", "key2": "value2"}]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) - self.shell.run( - [ - 'run', - 'mockety.mock2', - 'list=key1:value1,key2:value2' - ] - ) + self.shell.run(["run", "mockety.mock2", "list=key1:value1,key2:value2"]) expected = { - 'action': 'mockety.mock2', - 'user': None, - 'parameters': { - 'list': [ - 'key1:value1', - 'key2:value2' - ] - } + "action": "mockety.mock2", + "user": None, + "parameters": {"list": ["key1:value1", "key2:value2"]}, } - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_param_value_with_equal_sign(self): - self.shell.run(['run', 'mockety.mock2', 'key=foo=bar&ponies=unicorns']) - expected = {'action': 'mockety.mock2', 'user': None, - 'parameters': {'key': 'foo=bar&ponies=unicorns'}} - httpclient.HTTPClient.post.assert_called_with('/executions', expected) + self.shell.run(["run", "mockety.mock2", "key=foo=bar&ponies=unicorns"]) + expected = { + "action": "mockety.mock2", + "user": None, + "parameters": {"key": "foo=bar&ponies=unicorns"}, + } + httpclient.HTTPClient.post.assert_called_with("/executions", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_cancel_single_execution(self): - self.shell.run(['execution', 'cancel', '123']) - httpclient.HTTPClient.delete.assert_called_with('/executions/123') + self.shell.run(["execution", "cancel", "123"]) + httpclient.HTTPClient.delete.assert_called_with("/executions/123") @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_cancel_multiple_executions(self): - self.shell.run(['execution', 'cancel', '123', '456', '789']) - calls = [mock.call('/executions/123'), - mock.call('/executions/456'), - mock.call('/executions/789')] + self.shell.run(["execution", "cancel", "123", "456", "789"]) + calls = [ + mock.call("/executions/123"), + mock.call("/executions/456"), + mock.call("/executions/789"), + ] httpclient.HTTPClient.delete.assert_has_calls(calls) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_pause_single_execution(self): - self.shell.run(['execution', 'pause', '123']) - expected = {'status': 'pausing'} - httpclient.HTTPClient.put.assert_called_with('/executions/123', expected) + self.shell.run(["execution", "pause", "123"]) + expected = {"status": "pausing"} + httpclient.HTTPClient.put.assert_called_with("/executions/123", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_pause_multiple_executions(self): - self.shell.run(['execution', 'pause', '123', '456', '789']) - expected = {'status': 'pausing'} - calls = [mock.call('/executions/123', expected), - mock.call('/executions/456', expected), - mock.call('/executions/789', expected)] + self.shell.run(["execution", "pause", "123", "456", "789"]) + expected = {"status": "pausing"} + calls = [ + mock.call("/executions/123", expected), + mock.call("/executions/456", expected), + mock.call("/executions/789", expected), + ] httpclient.HTTPClient.put.assert_has_calls(calls) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_resume_single_execution(self): - self.shell.run(['execution', 'resume', '123']) - expected = {'status': 'resuming'} - httpclient.HTTPClient.put.assert_called_with('/executions/123', expected) + self.shell.run(["execution", "resume", "123"]) + expected = {"status": "resuming"} + httpclient.HTTPClient.put.assert_called_with("/executions/123", expected) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(side_effect=get_by_name)) + models.ResourceManager, "get_by_name", mock.MagicMock(side_effect=get_by_name) + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(LIVE_ACTION), 200, "OK") + ), + ) def test_resume_multiple_executions(self): - self.shell.run(['execution', 'resume', '123', '456', '789']) - expected = {'status': 'resuming'} - calls = [mock.call('/executions/123', expected), - mock.call('/executions/456', expected), - mock.call('/executions/789', expected)] + self.shell.run(["execution", "resume", "123", "456", "789"]) + expected = {"status": "resuming"} + calls = [ + mock.call("/executions/123", expected), + mock.call("/executions/456", expected), + mock.call("/executions/789", expected), + ] httpclient.HTTPClient.put.assert_has_calls(calls) diff --git a/st2client/tests/unit/test_action_alias.py b/st2client/tests/unit/test_action_alias.py index a360fd5139e..753b4e71a86 100644 --- a/st2client/tests/unit/test_action_alias.py +++ b/st2client/tests/unit/test_action_alias.py @@ -29,9 +29,7 @@ "execution": { "id": "mock-id", }, - "actionalias": { - "ref": "mock-ref" - } + "actionalias": {"ref": "mock-ref"}, } ] } @@ -43,20 +41,26 @@ def __init__(self, *args, **kwargs): self.shell = shell.Shell() @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_MATCH_AND_EXECUTE_RESULT), - 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_MATCH_AND_EXECUTE_RESULT), 200, "OK" + ) + ), + ) def test_match_and_execute(self): - ret = self.shell.run(['action-alias', 'execute', "run whoami on localhost"]) + ret = self.shell.run(["action-alias", "execute", "run whoami on localhost"]) self.assertEqual(ret, 0) expected_args = { - 'command': 'run whoami on localhost', - 'user': '', - 'source_channel': 'cli' + "command": "run whoami on localhost", + "user": "", + "source_channel": "cli", } - httpclient.HTTPClient.post.assert_called_with('/aliasexecution/match_and_execute', - expected_args) + httpclient.HTTPClient.post.assert_called_with( + "/aliasexecution/match_and_execute", expected_args + ) mock_stdout = self.stdout.getvalue() diff --git a/st2client/tests/unit/test_app.py b/st2client/tests/unit/test_app.py index eb1a67242eb..217d3875ad7 100644 --- a/st2client/tests/unit/test_app.py +++ b/st2client/tests/unit/test_app.py @@ -26,33 +26,33 @@ class BaseCLIAppTestCase(unittest2.TestCase): - @mock.patch('os.path.isfile', mock.Mock()) + @mock.patch("os.path.isfile", mock.Mock()) def test_cli_config_file_path(self): app = BaseCLIApp() args = mock.Mock() # 1. Absolute path - args.config_file = '/tmp/full/abs/path/config.ini' + args.config_file = "/tmp/full/abs/path/config.ini" result = app._get_config_file_path(args=args) self.assertEqual(result, args.config_file) - args.config_file = '/home/user/st2/config.ini' + args.config_file = "/home/user/st2/config.ini" result = app._get_config_file_path(args=args) self.assertEqual(result, args.config_file) # 2. Path relative to user home directory, should get expanded - args.config_file = '~/.st2/config.ini' + args.config_file = "~/.st2/config.ini" result = app._get_config_file_path(args=args) - expected = os.path.join(os.path.expanduser('~' + USER), '.st2/config.ini') + expected = os.path.join(os.path.expanduser("~" + USER), ".st2/config.ini") self.assertEqual(result, expected) # 3. Relative path (should get converted to absolute one) - args.config_file = 'config.ini' + args.config_file = "config.ini" result = app._get_config_file_path(args=args) - expected = os.path.join(os.getcwd(), 'config.ini') + expected = os.path.join(os.getcwd(), "config.ini") self.assertEqual(result, expected) - args.config_file = '.st2/config.ini' + args.config_file = ".st2/config.ini" result = app._get_config_file_path(args=args) - expected = os.path.join(os.getcwd(), '.st2/config.ini') + expected = os.path.join(os.getcwd(), ".st2/config.ini") self.assertEqual(result, expected) diff --git a/st2client/tests/unit/test_auth.py b/st2client/tests/unit/test_auth.py index cd838712cb0..e59b31dfaf7 100644 --- a/st2client/tests/unit/test_auth.py +++ b/st2client/tests/unit/test_auth.py @@ -29,24 +29,27 @@ from st2client import shell from st2client.models.core import add_auth_token_to_kwargs_from_env from st2client.commands.resource import add_auth_token_to_kwargs_from_cli -from st2client.utils.httpclient import add_auth_token_to_headers, add_json_content_type_to_headers +from st2client.utils.httpclient import ( + add_auth_token_to_headers, + add_json_content_type_to_headers, +) LOG = logging.getLogger(__name__) if six.PY3: RULE = { - 'name': 'drule', - 'description': 'i am THE rule.', - 'pack': 'cli', - 'id': uuid.uuid4().hex + "name": "drule", + "description": "i am THE rule.", + "pack": "cli", + "id": uuid.uuid4().hex, } else: RULE = { - 'id': uuid.uuid4().hex, - 'description': 'i am THE rule.', - 'name': 'drule', - 'pack': 'cli', + "id": uuid.uuid4().hex, + "description": "i am THE rule.", + "name": "drule", + "pack": "cli", } @@ -59,9 +62,9 @@ class TestLoginBase(base.BaseCLITestCase): on duplicate code in each test class """ - DOTST2_PATH = os.path.expanduser('~/.st2/') - CONFIG_FILE_NAME = 'st2.conf' - PARENT_DIR = 'testconfig' + DOTST2_PATH = os.path.expanduser("~/.st2/") + CONFIG_FILE_NAME = "st2.conf" + PARENT_DIR = "testconfig" TMP_DIR = tempfile.mkdtemp() CONFIG_CONTENTS = """ [credentials] @@ -73,11 +76,11 @@ def __init__(self, *args, **kwargs): super(TestLoginBase, self).__init__(*args, **kwargs) # We're overriding the default behavior for CLI test cases here - self.DEFAULT_SKIP_CONFIG = '0' + self.DEFAULT_SKIP_CONFIG = "0" self.parser = argparse.ArgumentParser() - self.parser.add_argument('-t', '--token', dest='token') - self.parser.add_argument('--api-key', dest='api_key') + self.parser.add_argument("-t", "--token", dest="token") + self.parser.add_argument("--api-key", dest="api_key") self.shell = shell.Shell() self.CONFIG_DIR = os.path.join(self.TMP_DIR, self.PARENT_DIR) @@ -94,9 +97,9 @@ def setUp(self): if os.path.isfile(self.CONFIG_FILE): os.remove(self.CONFIG_FILE) - with open(self.CONFIG_FILE, 'w') as cfg: - for line in self.CONFIG_CONTENTS.split('\n'): - cfg.write('%s\n' % line.strip()) + with open(self.CONFIG_FILE, "w") as cfg: + for line in self.CONFIG_CONTENTS.split("\n"): + cfg.write("%s\n" % line.strip()) os.chmod(self.CONFIG_FILE, 0o660) @@ -107,7 +110,7 @@ def tearDown(self): os.remove(self.CONFIG_FILE) # Clean up tokens - for file in [f for f in os.listdir(self.DOTST2_PATH) if 'token-' in f]: + for file in [f for f in os.listdir(self.DOTST2_PATH) if "token-" in f]: os.remove(self.DOTST2_PATH + file) # Clean up config directory @@ -116,181 +119,208 @@ def tearDown(self): class TestLoginPasswordAndConfig(TestLoginBase): - CONFIG_FILE_NAME = 'logintest.cfg' + CONFIG_FILE_NAME = "logintest.cfg" TOKEN = { - 'user': 'st2admin', - 'token': '44583f15945b4095afbf57058535ca64', - 'expiry': '2017-02-12T00:53:09.632783Z', - 'id': '589e607532ed3535707f10eb', - 'metadata': {} + "user": "st2admin", + "token": "44583f15945b4095afbf57058535ca64", + "expiry": "2017-02-12T00:53:09.632783Z", + "id": "589e607532ed3535707f10eb", + "metadata": {}, } @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")), + ) def runTest(self): - '''Test 'st2 login' functionality by specifying a password and a configuration file - ''' - - expected_username = self.TOKEN['user'] - args = ['--config', self.CONFIG_FILE, 'login', expected_username, '--password', - 'Password1!'] + """Test 'st2 login' functionality by specifying a password and a configuration file""" + + expected_username = self.TOKEN["user"] + args = [ + "--config", + self.CONFIG_FILE, + "login", + expected_username, + "--password", + "Password1!", + ] self.shell.run(args) - with open(self.CONFIG_FILE, 'r') as config_file: + with open(self.CONFIG_FILE, "r") as config_file: for line in config_file.readlines(): # Make sure certain values are not present - self.assertNotIn('password', line) - self.assertNotIn('olduser', line) + self.assertNotIn("password", line) + self.assertNotIn("olduser", line) # Make sure configured username is what we expect - if 'username' in line: - self.assertEqual(line.split(' ')[2][:-1], expected_username) + if "username" in line: + self.assertEqual(line.split(" ")[2][:-1], expected_username) # validate token was created - self.assertTrue(os.path.isfile('%stoken-%s' % (self.DOTST2_PATH, expected_username))) + self.assertTrue( + os.path.isfile("%stoken-%s" % (self.DOTST2_PATH, expected_username)) + ) class TestLoginIntPwdAndConfig(TestLoginBase): - CONFIG_FILE_NAME = 'logintest.cfg' + CONFIG_FILE_NAME = "logintest.cfg" TOKEN = { - 'user': 'st2admin', - 'token': '44583f15945b4095afbf57058535ca64', - 'expiry': '2017-02-12T00:53:09.632783Z', - 'id': '589e607532ed3535707f10eb', - 'metadata': {} + "user": "st2admin", + "token": "44583f15945b4095afbf57058535ca64", + "expiry": "2017-02-12T00:53:09.632783Z", + "id": "589e607532ed3535707f10eb", + "metadata": {}, } @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")), + ) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) - @mock.patch('st2client.commands.auth.getpass') + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) + @mock.patch("st2client.commands.auth.getpass") def runTest(self, mock_gp): - '''Test 'st2 login' functionality with interactive password entry - ''' + """Test 'st2 login' functionality with interactive password entry""" - expected_username = self.TOKEN['user'] - args = ['--config', self.CONFIG_FILE, 'login', expected_username] + expected_username = self.TOKEN["user"] + args = ["--config", self.CONFIG_FILE, "login", expected_username] - mock_gp.getpass.return_value = 'Password1!' + mock_gp.getpass.return_value = "Password1!" self.shell.run(args) expected_kwargs = { - 'headers': {'content-type': 'application/json'}, - 'auth': ('st2admin', 'Password1!') + "headers": {"content-type": "application/json"}, + "auth": ("st2admin", "Password1!"), } - requests.post.assert_called_with('http://127.0.0.1:9100/tokens', '{}', **expected_kwargs) + requests.post.assert_called_with( + "http://127.0.0.1:9100/tokens", "{}", **expected_kwargs + ) # Check file permissions self.assertEqual(os.stat(self.CONFIG_FILE).st_mode & 0o777, 0o660) - with open(self.CONFIG_FILE, 'r') as config_file: + with open(self.CONFIG_FILE, "r") as config_file: for line in config_file.readlines(): # Make sure certain values are not present - self.assertNotIn('password', line) - self.assertNotIn('olduser', line) + self.assertNotIn("password", line) + self.assertNotIn("olduser", line) # Make sure configured username is what we expect - if 'username' in line: - self.assertEqual(line.split(' ')[2][:-1], expected_username) + if "username" in line: + self.assertEqual(line.split(" ")[2][:-1], expected_username) # validate token was created - self.assertTrue(os.path.isfile('%stoken-%s' % (self.DOTST2_PATH, expected_username))) + self.assertTrue( + os.path.isfile("%stoken-%s" % (self.DOTST2_PATH, expected_username)) + ) # Validate token is sent on subsequent requests to st2 API - args = ['--config', self.CONFIG_FILE, 'pack', 'list'] + args = ["--config", self.CONFIG_FILE, "pack", "list"] self.shell.run(args) expected_kwargs = { - 'headers': { - 'X-Auth-Token': self.TOKEN['token'] - }, - 'params': { - 'include_attributes': 'ref,name,description,version,author' - } + "headers": {"X-Auth-Token": self.TOKEN["token"]}, + "params": {"include_attributes": "ref,name,description,version,author"}, } - requests.get.assert_called_with('http://127.0.0.1:9101/v1/packs', **expected_kwargs) + requests.get.assert_called_with( + "http://127.0.0.1:9101/v1/packs", **expected_kwargs + ) class TestLoginWritePwdOkay(TestLoginBase): - CONFIG_FILE_NAME = 'logintest.cfg' + CONFIG_FILE_NAME = "logintest.cfg" TOKEN = { - 'user': 'st2admin', - 'token': '44583f15945b4095afbf57058535ca64', - 'expiry': '2017-02-12T00:53:09.632783Z', - 'id': '589e607532ed3535707f10eb', - 'metadata': {} + "user": "st2admin", + "token": "44583f15945b4095afbf57058535ca64", + "expiry": "2017-02-12T00:53:09.632783Z", + "id": "589e607532ed3535707f10eb", + "metadata": {}, } @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK'))) - @mock.patch('st2client.commands.auth.getpass') + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")), + ) + @mock.patch("st2client.commands.auth.getpass") def runTest(self, mock_gp): - '''Test 'st2 login' functionality with --write-password flag set - ''' - - expected_username = self.TOKEN['user'] - args = ['--config', self.CONFIG_FILE, 'login', expected_username, '--password', - 'Password1!', '--write-password'] + """Test 'st2 login' functionality with --write-password flag set""" + + expected_username = self.TOKEN["user"] + args = [ + "--config", + self.CONFIG_FILE, + "login", + expected_username, + "--password", + "Password1!", + "--write-password", + ] self.shell.run(args) - with open(self.CONFIG_FILE, 'r') as config_file: + with open(self.CONFIG_FILE, "r") as config_file: for line in config_file.readlines(): # Make sure certain values are not present - self.assertNotIn('olduser', line) + self.assertNotIn("olduser", line) # Make sure configured username is what we expect - if 'username' in line: - self.assertEqual(line.split(' ')[2][:-1], expected_username) + if "username" in line: + self.assertEqual(line.split(" ")[2][:-1], expected_username) # validate token was created - self.assertTrue(os.path.isfile('%stoken-%s' % (self.DOTST2_PATH, expected_username))) + self.assertTrue( + os.path.isfile("%stoken-%s" % (self.DOTST2_PATH, expected_username)) + ) class TestLoginUncaughtException(TestLoginBase): - CONFIG_FILE_NAME = 'logintest.cfg' + CONFIG_FILE_NAME = "logintest.cfg" TOKEN = { - 'user': 'st2admin', - 'token': '44583f15945b4095afbf57058535ca64', - 'expiry': '2017-02-12T00:53:09.632783Z', - 'id': '589e607532ed3535707f10eb', - 'metadata': {} + "user": "st2admin", + "token": "44583f15945b4095afbf57058535ca64", + "expiry": "2017-02-12T00:53:09.632783Z", + "id": "589e607532ed3535707f10eb", + "metadata": {}, } @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, 'OK'))) - @mock.patch('st2client.commands.auth.getpass') + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(TOKEN), 200, "OK")), + ) + @mock.patch("st2client.commands.auth.getpass") def runTest(self, mock_gp): - '''Test 'st2 login' ability to detect unhandled exceptions - ''' + """Test 'st2 login' ability to detect unhandled exceptions""" - expected_username = self.TOKEN['user'] - args = ['--config', self.CONFIG_FILE, 'login', expected_username] + expected_username = self.TOKEN["user"] + args = ["--config", self.CONFIG_FILE, "login", expected_username] mock_gp.getpass = mock.MagicMock(side_effect=Exception) self.shell.run(args) retcode = self.shell.run(args) - self.assertIn('Failed to log in as %s' % expected_username, self.stdout.getvalue()) - self.assertNotIn('Logged in as', self.stdout.getvalue()) + self.assertIn( + "Failed to log in as %s" % expected_username, self.stdout.getvalue() + ) + self.assertNotIn("Logged in as", self.stdout.getvalue()) self.assertEqual(retcode, 1) @@ -301,26 +331,26 @@ class TestAuthToken(base.BaseCLITestCase): def __init__(self, *args, **kwargs): super(TestAuthToken, self).__init__(*args, **kwargs) self.parser = argparse.ArgumentParser() - self.parser.add_argument('-t', '--token', dest='token') - self.parser.add_argument('--api-key', dest='api_key') + self.parser.add_argument("-t", "--token", dest="token") + self.parser.add_argument("--api-key", dest="api_key") self.shell = shell.Shell() def setUp(self): super(TestAuthToken, self).setUp() # Setup environment. - os.environ['ST2_BASE_URL'] = 'http://127.0.0.1' + os.environ["ST2_BASE_URL"] = "http://127.0.0.1" def tearDown(self): super(TestAuthToken, self).tearDown() # Clean up environment. - if 'ST2_AUTH_TOKEN' in os.environ: - del os.environ['ST2_AUTH_TOKEN'] - if 'ST2_API_KEY' in os.environ: - del os.environ['ST2_API_KEY'] - if 'ST2_BASE_URL' in os.environ: - del os.environ['ST2_BASE_URL'] + if "ST2_AUTH_TOKEN" in os.environ: + del os.environ["ST2_AUTH_TOKEN"] + if "ST2_API_KEY" in os.environ: + del os.environ["ST2_API_KEY"] + if "ST2_BASE_URL" in os.environ: + del os.environ["ST2_BASE_URL"] @add_auth_token_to_kwargs_from_cli @add_auth_token_to_kwargs_from_env @@ -329,27 +359,27 @@ def _mock_run(self, args, **kwargs): def test_decorate_auth_token_by_cli(self): token = uuid.uuid4().hex - args = self.parser.parse_args(args=['-t', token]) - self.assertDictEqual(self._mock_run(args), {'token': token}) - args = self.parser.parse_args(args=['--token', token]) - self.assertDictEqual(self._mock_run(args), {'token': token}) + args = self.parser.parse_args(args=["-t", token]) + self.assertDictEqual(self._mock_run(args), {"token": token}) + args = self.parser.parse_args(args=["--token", token]) + self.assertDictEqual(self._mock_run(args), {"token": token}) def test_decorate_api_key_by_cli(self): token = uuid.uuid4().hex - args = self.parser.parse_args(args=['--api-key', token]) - self.assertDictEqual(self._mock_run(args), {'api_key': token}) + args = self.parser.parse_args(args=["--api-key", token]) + self.assertDictEqual(self._mock_run(args), {"api_key": token}) def test_decorate_auth_token_by_env(self): token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token + os.environ["ST2_AUTH_TOKEN"] = token args = self.parser.parse_args(args=[]) - self.assertDictEqual(self._mock_run(args), {'token': token}) + self.assertDictEqual(self._mock_run(args), {"token": token}) def test_decorate_api_key_by_env(self): token = uuid.uuid4().hex - os.environ['ST2_API_KEY'] = token + os.environ["ST2_API_KEY"] = token args = self.parser.parse_args(args=[]) - self.assertDictEqual(self._mock_run(args), {'api_key': token}) + self.assertDictEqual(self._mock_run(args), {"api_key": token}) def test_decorate_without_auth_token(self): args = self.parser.parse_args(args=[]) @@ -362,187 +392,215 @@ def _mock_http(self, url, **kwargs): def test_decorate_auth_token_to_http_headers(self): token = uuid.uuid4().hex - kwargs = self._mock_http('/', token=token) - expected = {'content-type': 'application/json', 'X-Auth-Token': token} - self.assertIn('headers', kwargs) - self.assertDictEqual(kwargs['headers'], expected) + kwargs = self._mock_http("/", token=token) + expected = {"content-type": "application/json", "X-Auth-Token": token} + self.assertIn("headers", kwargs) + self.assertDictEqual(kwargs["headers"], expected) def test_decorate_api_key_to_http_headers(self): token = uuid.uuid4().hex - kwargs = self._mock_http('/', api_key=token) - expected = {'content-type': 'application/json', 'St2-Api-Key': token} - self.assertIn('headers', kwargs) - self.assertDictEqual(kwargs['headers'], expected) + kwargs = self._mock_http("/", api_key=token) + expected = {"content-type": "application/json", "St2-Api-Key": token} + self.assertIn("headers", kwargs) + self.assertDictEqual(kwargs["headers"], expected) def test_decorate_without_auth_token_to_http_headers(self): - kwargs = self._mock_http('/', auth=('stanley', 'stanley')) - expected = {'content-type': 'application/json'} - self.assertIn('auth', kwargs) - self.assertEqual(kwargs['auth'], ('stanley', 'stanley')) - self.assertIn('headers', kwargs) - self.assertDictEqual(kwargs['headers'], expected) + kwargs = self._mock_http("/", auth=("stanley", "stanley")) + expected = {"content-type": "application/json"} + self.assertIn("auth", kwargs) + self.assertEqual(kwargs["auth"], ("stanley", "stanley")) + self.assertIn("headers", kwargs) + self.assertDictEqual(kwargs["headers"], expected) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_resource_list(self): - url = ('http://127.0.0.1:9101/v1/rules/' - '?include_attributes=ref,pack,description,enabled&limit=50') - url = url.replace(',', '%2C') + url = ( + "http://127.0.0.1:9101/v1/rules/" + "?include_attributes=ref,pack,description,enabled&limit=50" + ) + url = url.replace(",", "%2C") # Test without token. - self.shell.run(['rule', 'list']) + self.shell.run(["rule", "list"]) kwargs = {} requests.get.assert_called_with(url, **kwargs) # Test with token from cli. token = uuid.uuid4().hex - self.shell.run(['rule', 'list', '-t', token]) - kwargs = {'headers': {'X-Auth-Token': token}} + self.shell.run(["rule", "list", "-t", token]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(url, **kwargs) # Test with token from env. token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token - self.shell.run(['rule', 'list']) - kwargs = {'headers': {'X-Auth-Token': token}} + os.environ["ST2_AUTH_TOKEN"] = token + self.shell.run(["rule", "list"]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(url, **kwargs) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")), + ) def test_decorate_resource_get(self): - rule_ref = '%s.%s' % (RULE['pack'], RULE['name']) - url = 'http://127.0.0.1:9101/v1/rules/%s' % rule_ref + rule_ref = "%s.%s" % (RULE["pack"], RULE["name"]) + url = "http://127.0.0.1:9101/v1/rules/%s" % rule_ref # Test without token. - self.shell.run(['rule', 'get', rule_ref]) + self.shell.run(["rule", "get", rule_ref]) kwargs = {} requests.get.assert_called_with(url, **kwargs) # Test with token from cli. token = uuid.uuid4().hex - self.shell.run(['rule', 'get', rule_ref, '-t', token]) - kwargs = {'headers': {'X-Auth-Token': token}} + self.shell.run(["rule", "get", rule_ref, "-t", token]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(url, **kwargs) # Test with token from env. token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token - self.shell.run(['rule', 'get', rule_ref]) - kwargs = {'headers': {'X-Auth-Token': token}} + os.environ["ST2_AUTH_TOKEN"] = token + self.shell.run(["rule", "get", rule_ref]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(url, **kwargs) @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")), + ) def test_decorate_resource_post(self): - url = 'http://127.0.0.1:9101/v1/rules' - data = {'name': RULE['name'], 'description': RULE['description']} + url = "http://127.0.0.1:9101/v1/rules" + data = {"name": RULE["name"], "description": RULE["description"]} - fd, path = tempfile.mkstemp(suffix='.json') + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(data, indent=4)) # Test without token. - self.shell.run(['rule', 'create', path]) - kwargs = {'headers': {'content-type': 'application/json'}} + self.shell.run(["rule", "create", path]) + kwargs = {"headers": {"content-type": "application/json"}} requests.post.assert_called_with(url, json.dumps(data), **kwargs) # Test with token from cli. token = uuid.uuid4().hex - self.shell.run(['rule', 'create', path, '-t', token]) - kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}} + self.shell.run(["rule", "create", path, "-t", token]) + kwargs = { + "headers": {"content-type": "application/json", "X-Auth-Token": token} + } requests.post.assert_called_with(url, json.dumps(data), **kwargs) # Test with token from env. token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token - self.shell.run(['rule', 'create', path]) - kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}} + os.environ["ST2_AUTH_TOKEN"] = token + self.shell.run(["rule", "create", path]) + kwargs = { + "headers": {"content-type": "application/json", "X-Auth-Token": token} + } requests.post.assert_called_with(url, json.dumps(data), **kwargs) finally: os.close(fd) os.unlink(path) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")), + ) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK'))) + requests, + "put", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")), + ) def test_decorate_resource_put(self): - rule_ref = '%s.%s' % (RULE['pack'], RULE['name']) - - get_url = 'http://127.0.0.1:9101/v1/rules/%s' % rule_ref - put_url = 'http://127.0.0.1:9101/v1/rules/%s' % RULE['id'] - data = {'name': RULE['name'], 'description': RULE['description'], 'pack': RULE['pack']} + rule_ref = "%s.%s" % (RULE["pack"], RULE["name"]) + + get_url = "http://127.0.0.1:9101/v1/rules/%s" % rule_ref + put_url = "http://127.0.0.1:9101/v1/rules/%s" % RULE["id"] + data = { + "name": RULE["name"], + "description": RULE["description"], + "pack": RULE["pack"], + } - fd, path = tempfile.mkstemp(suffix='.json') + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(data, indent=4)) # Test without token. - self.shell.run(['rule', 'update', rule_ref, path]) + self.shell.run(["rule", "update", rule_ref, path]) kwargs = {} requests.get.assert_called_with(get_url, **kwargs) - kwargs = {'headers': {'content-type': 'application/json'}} + kwargs = {"headers": {"content-type": "application/json"}} requests.put.assert_called_with(put_url, json.dumps(RULE), **kwargs) # Test with token from cli. token = uuid.uuid4().hex - self.shell.run(['rule', 'update', rule_ref, path, '-t', token]) - kwargs = {'headers': {'X-Auth-Token': token}} + self.shell.run(["rule", "update", rule_ref, path, "-t", token]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(get_url, **kwargs) - kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}} + kwargs = { + "headers": {"content-type": "application/json", "X-Auth-Token": token} + } requests.put.assert_called_with(put_url, json.dumps(RULE), **kwargs) # Test with token from env. token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token - self.shell.run(['rule', 'update', rule_ref, path]) - kwargs = {'headers': {'X-Auth-Token': token}} + os.environ["ST2_AUTH_TOKEN"] = token + self.shell.run(["rule", "update", rule_ref, path]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(get_url, **kwargs) # Note: We parse the payload because data might not be in the same # order as the fixture - kwargs = {'headers': {'content-type': 'application/json', 'X-Auth-Token': token}} + kwargs = { + "headers": {"content-type": "application/json", "X-Auth-Token": token} + } requests.put.assert_called_with(put_url, json.dumps(RULE), **kwargs) finally: os.close(fd) os.unlink(path) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(RULE), 200, "OK")), + ) @mock.patch.object( - requests, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 204, 'OK'))) + requests, + "delete", + mock.MagicMock(return_value=base.FakeResponse("", 204, "OK")), + ) def test_decorate_resource_delete(self): - rule_ref = '%s.%s' % (RULE['pack'], RULE['name']) - get_url = 'http://127.0.0.1:9101/v1/rules/%s' % rule_ref - del_url = 'http://127.0.0.1:9101/v1/rules/%s' % RULE['id'] + rule_ref = "%s.%s" % (RULE["pack"], RULE["name"]) + get_url = "http://127.0.0.1:9101/v1/rules/%s" % rule_ref + del_url = "http://127.0.0.1:9101/v1/rules/%s" % RULE["id"] # Test without token. - self.shell.run(['rule', 'delete', rule_ref]) + self.shell.run(["rule", "delete", rule_ref]) kwargs = {} requests.get.assert_called_with(get_url, **kwargs) requests.delete.assert_called_with(del_url, **kwargs) # Test with token from cli. token = uuid.uuid4().hex - self.shell.run(['rule', 'delete', rule_ref, '-t', token]) - kwargs = {'headers': {'X-Auth-Token': token}} + self.shell.run(["rule", "delete", rule_ref, "-t", token]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(get_url, **kwargs) requests.delete.assert_called_with(del_url, **kwargs) # Test with token from env. token = uuid.uuid4().hex - os.environ['ST2_AUTH_TOKEN'] = token - self.shell.run(['rule', 'delete', rule_ref]) - kwargs = {'headers': {'X-Auth-Token': token}} + os.environ["ST2_AUTH_TOKEN"] = token + self.shell.run(["rule", "delete", rule_ref]) + kwargs = {"headers": {"X-Auth-Token": token}} requests.get.assert_called_with(get_url, **kwargs) requests.delete.assert_called_with(del_url, **kwargs) diff --git a/st2client/tests/unit/test_client.py b/st2client/tests/unit/test_client.py index 2e9fd950950..2d0a380ab55 100644 --- a/st2client/tests/unit/test_client.py +++ b/st2client/tests/unit/test_client.py @@ -25,25 +25,25 @@ LOG = logging.getLogger(__name__) -NONRESOURCES = ['workflows'] +NONRESOURCES = ["workflows"] class TestClientEndpoints(unittest2.TestCase): - def tearDown(self): for var in [ - 'ST2_BASE_URL', - 'ST2_API_URL', - 'ST2_STREAM_URL', - 'ST2_DATASTORE_URL', - 'ST2_AUTH_TOKEN' + "ST2_BASE_URL", + "ST2_API_URL", + "ST2_STREAM_URL", + "ST2_DATASTORE_URL", + "ST2_AUTH_TOKEN", ]: if var in os.environ: del os.environ[var] def test_managers(self): - property_names = [k for k, v in six.iteritems(Client.__dict__) - if isinstance(v, property)] + property_names = [ + k for k, v in six.iteritems(Client.__dict__) if isinstance(v, property) + ] client = Client() @@ -55,96 +55,109 @@ def test_managers(self): self.assertIsInstance(manager, models.ResourceManager) def test_default(self): - base_url = 'http://127.0.0.1' - api_url = 'http://127.0.0.1:9101/v1' - stream_url = 'http://127.0.0.1:9102/v1' + base_url = "http://127.0.0.1" + api_url = "http://127.0.0.1:9101/v1" + stream_url = "http://127.0.0.1:9102/v1" client = Client() endpoints = client.endpoints - self.assertEqual(endpoints['base'], base_url) - self.assertEqual(endpoints['api'], api_url) - self.assertEqual(endpoints['stream'], stream_url) + self.assertEqual(endpoints["base"], base_url) + self.assertEqual(endpoints["api"], api_url) + self.assertEqual(endpoints["stream"], stream_url) def test_env(self): - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" - os.environ['ST2_BASE_URL'] = base_url - os.environ['ST2_API_URL'] = api_url - os.environ['ST2_STREAM_URL'] = stream_url - self.assertEqual(os.environ.get('ST2_BASE_URL'), base_url) - self.assertEqual(os.environ.get('ST2_API_URL'), api_url) - self.assertEqual(os.environ.get('ST2_STREAM_URL'), stream_url) + os.environ["ST2_BASE_URL"] = base_url + os.environ["ST2_API_URL"] = api_url + os.environ["ST2_STREAM_URL"] = stream_url + self.assertEqual(os.environ.get("ST2_BASE_URL"), base_url) + self.assertEqual(os.environ.get("ST2_API_URL"), api_url) + self.assertEqual(os.environ.get("ST2_STREAM_URL"), stream_url) client = Client() endpoints = client.endpoints - self.assertEqual(endpoints['base'], base_url) - self.assertEqual(endpoints['api'], api_url) - self.assertEqual(endpoints['stream'], stream_url) + self.assertEqual(endpoints["base"], base_url) + self.assertEqual(endpoints["api"], api_url) + self.assertEqual(endpoints["stream"], stream_url) def test_env_base_only(self): - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.stackstorm.com:9101/v1' - stream_url = 'http://www.stackstorm.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.stackstorm.com:9101/v1" + stream_url = "http://www.stackstorm.com:9102/v1" - os.environ['ST2_BASE_URL'] = base_url - self.assertEqual(os.environ.get('ST2_BASE_URL'), base_url) - self.assertEqual(os.environ.get('ST2_API_URL'), None) - self.assertEqual(os.environ.get('ST2_STREAM_URL'), None) + os.environ["ST2_BASE_URL"] = base_url + self.assertEqual(os.environ.get("ST2_BASE_URL"), base_url) + self.assertEqual(os.environ.get("ST2_API_URL"), None) + self.assertEqual(os.environ.get("ST2_STREAM_URL"), None) client = Client() endpoints = client.endpoints - self.assertEqual(endpoints['base'], base_url) - self.assertEqual(endpoints['api'], api_url) - self.assertEqual(endpoints['stream'], stream_url) + self.assertEqual(endpoints["base"], base_url) + self.assertEqual(endpoints["api"], api_url) + self.assertEqual(endpoints["stream"], stream_url) def test_args(self): - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url) endpoints = client.endpoints - self.assertEqual(endpoints['base'], base_url) - self.assertEqual(endpoints['api'], api_url) - self.assertEqual(endpoints['stream'], stream_url) + self.assertEqual(endpoints["base"], base_url) + self.assertEqual(endpoints["api"], api_url) + self.assertEqual(endpoints["stream"], stream_url) def test_cacert_arg(self): # Valid value, boolean True - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" - client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=True) + client = Client( + base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=True + ) self.assertEqual(client.cacert, True) # Valid value, boolean False - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" - client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=False) + client = Client( + base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=False + ) self.assertEqual(client.cacert, False) # Valid value, existing path to a CA bundle cacert = os.path.abspath(__file__) - client = Client(base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=cacert) + client = Client( + base_url=base_url, api_url=api_url, stream_url=stream_url, cacert=cacert + ) self.assertEqual(client.cacert, cacert) # Invalid value, path to the bundle doesn't exist cacert = os.path.abspath(__file__) expected_msg = 'CA cert file "doesntexist" does not exist' - self.assertRaisesRegexp(ValueError, expected_msg, Client, base_url=base_url, - api_url=api_url, stream_url=stream_url, cacert='doesntexist') + self.assertRaisesRegexp( + ValueError, + expected_msg, + Client, + base_url=base_url, + api_url=api_url, + stream_url=stream_url, + cacert="doesntexist", + ) def test_args_base_only(self): - base_url = 'http://www.stackstorm.com' - api_url = 'http://www.stackstorm.com:9101/v1' - stream_url = 'http://www.stackstorm.com:9102/v1' + base_url = "http://www.stackstorm.com" + api_url = "http://www.stackstorm.com:9101/v1" + stream_url = "http://www.stackstorm.com:9102/v1" client = Client(base_url=base_url) endpoints = client.endpoints - self.assertEqual(endpoints['base'], base_url) - self.assertEqual(endpoints['api'], api_url) - self.assertEqual(endpoints['stream'], stream_url) + self.assertEqual(endpoints["base"], base_url) + self.assertEqual(endpoints["api"], api_url) + self.assertEqual(endpoints["stream"], stream_url) diff --git a/st2client/tests/unit/test_client_actions.py b/st2client/tests/unit/test_client_actions.py index 82b12b788d2..141e7c8ece5 100644 --- a/st2client/tests/unit/test_client_actions.py +++ b/st2client/tests/unit/test_client_actions.py @@ -31,22 +31,17 @@ EXECUTION = { "id": 12345, - "action": { - "ref": "mock.foobar" - }, + "action": {"ref": "mock.foobar"}, "status": "failed", - "result": "non-empty" + "result": "non-empty", } ENTRYPOINT = ( "version: 1.0" - "description: A basic workflow that runs an arbitrary linux command." - "input:" " - cmd" " - timeout" - "tasks:" " task1:" " action: core.local cmd=<% ctx(cmd) %> timeout=<% ctx(timeout) %>" @@ -55,51 +50,63 @@ " publish:" " - stdout: <% result().stdout %>" " - stderr: <% result().stderr %>" - "output:" " - stdout: <% ctx(stdout) %>" ) class TestActionResourceManager(unittest2.TestCase): - @classmethod def setUpClass(cls): super(TestActionResourceManager, cls).setUpClass() cls.client = client.Client() @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, "OK") + ), + ) def test_get_action_entry_point_by_ref(self): - actual_entrypoint = self.client.actions.get_entrypoint(EXECUTION['action']['ref']) + actual_entrypoint = self.client.actions.get_entrypoint( + EXECUTION["action"]["ref"] + ) actual_entrypoint = json.loads(actual_entrypoint) - endpoint = '/actions/views/entry_point/%s' % EXECUTION['action']['ref'] + endpoint = "/actions/views/entry_point/%s" % EXECUTION["action"]["ref"] httpclient.HTTPClient.get.assert_called_with(endpoint) self.assertEqual(ENTRYPOINT, actual_entrypoint) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(ENTRYPOINT), 200, "OK") + ), + ) def test_get_action_entry_point_by_id(self): - actual_entrypoint = self.client.actions.get_entrypoint(EXECUTION['id']) + actual_entrypoint = self.client.actions.get_entrypoint(EXECUTION["id"]) actual_entrypoint = json.loads(actual_entrypoint) - endpoint = '/actions/views/entry_point/%s' % EXECUTION['id'] + endpoint = "/actions/views/entry_point/%s" % EXECUTION["id"] httpclient.HTTPClient.get.assert_called_with(endpoint) self.assertEqual(ENTRYPOINT, actual_entrypoint) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse( - json.dumps({}), 404, '404 Client Error: Not Found' - ))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps({}), 404, "404 Client Error: Not Found" + ) + ), + ) def test_get_non_existent_action_entry_point(self): - with self.assertRaisesRegexp(Exception, '404 Client Error: Not Found'): - self.client.actions.get_entrypoint('nonexistentpack.nonexistentaction') + with self.assertRaisesRegexp(Exception, "404 Client Error: Not Found"): + self.client.actions.get_entrypoint("nonexistentpack.nonexistentaction") - endpoint = '/actions/views/entry_point/%s' % 'nonexistentpack.nonexistentaction' + endpoint = "/actions/views/entry_point/%s" % "nonexistentpack.nonexistentaction" httpclient.HTTPClient.get.assert_called_with(endpoint) diff --git a/st2client/tests/unit/test_client_executions.py b/st2client/tests/unit/test_client_executions.py index 0470347ee21..a9dc19e2c3e 100644 --- a/st2client/tests/unit/test_client_executions.py +++ b/st2client/tests/unit/test_client_executions.py @@ -34,9 +34,7 @@ RUNNER = { "enabled": True, "name": "marathon", - "runner_parameters": { - "var1": {"type": "string"} - } + "runner_parameters": {"var1": {"type": "string"}}, } ACTION = { @@ -46,185 +44,227 @@ "parameters": {}, "enabled": True, "entry_point": "", - "pack": "mocke" + "pack": "mocke", } EXECUTION = { "id": 12345, - "action": { - "ref": "mock.foobar" - }, + "action": {"ref": "mock.foobar"}, "status": "failed", - "result": "non-empty" + "result": "non-empty", } class TestExecutionResourceManager(unittest2.TestCase): - @classmethod def setUpClass(cls): super(TestExecutionResourceManager, cls).setUpClass() cls.client = client.Client() @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=models.Execution(**EXECUTION))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=models.Execution(**EXECUTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=models.Action(**ACTION))) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(return_value=models.Action(**ACTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=models.RunnerType(**RUNNER))) + models.ResourceManager, + "get_by_name", + mock.MagicMock(return_value=models.RunnerType(**RUNNER)), + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK") + ), + ) def test_rerun_with_no_params(self): - self.client.executions.re_run(EXECUTION['id'], tasks=['foobar']) + self.client.executions.re_run(EXECUTION["id"], tasks=["foobar"]) - endpoint = '/executions/%s/re_run' % EXECUTION['id'] + endpoint = "/executions/%s/re_run" % EXECUTION["id"] - data = { - 'tasks': ['foobar'], - 'reset': ['foobar'], - 'parameters': {}, - 'delay': 0 - } + data = {"tasks": ["foobar"], "reset": ["foobar"], "parameters": {}, "delay": 0} httpclient.HTTPClient.post.assert_called_with(endpoint, data) @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=models.Execution(**EXECUTION))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=models.Execution(**EXECUTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=models.Action(**ACTION))) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(return_value=models.Action(**ACTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=models.RunnerType(**RUNNER))) + models.ResourceManager, + "get_by_name", + mock.MagicMock(return_value=models.RunnerType(**RUNNER)), + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK") + ), + ) def test_rerun_with_params(self): - params = { - 'var1': 'testing...' - } + params = {"var1": "testing..."} self.client.executions.re_run( - EXECUTION['id'], - tasks=['foobar'], - parameters=params + EXECUTION["id"], tasks=["foobar"], parameters=params ) - endpoint = '/executions/%s/re_run' % EXECUTION['id'] + endpoint = "/executions/%s/re_run" % EXECUTION["id"] data = { - 'tasks': ['foobar'], - 'reset': ['foobar'], - 'parameters': params, - 'delay': 0 + "tasks": ["foobar"], + "reset": ["foobar"], + "parameters": params, + "delay": 0, } httpclient.HTTPClient.post.assert_called_with(endpoint, data) @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=models.Execution(**EXECUTION))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=models.Execution(**EXECUTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=models.Action(**ACTION))) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(return_value=models.Action(**ACTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=models.RunnerType(**RUNNER))) + models.ResourceManager, + "get_by_name", + mock.MagicMock(return_value=models.RunnerType(**RUNNER)), + ) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK") + ), + ) def test_rerun_with_delay(self): - self.client.executions.re_run(EXECUTION['id'], tasks=['foobar'], delay=100) + self.client.executions.re_run(EXECUTION["id"], tasks=["foobar"], delay=100) - endpoint = '/executions/%s/re_run' % EXECUTION['id'] + endpoint = "/executions/%s/re_run" % EXECUTION["id"] data = { - 'tasks': ['foobar'], - 'reset': ['foobar'], - 'parameters': {}, - 'delay': 100 + "tasks": ["foobar"], + "reset": ["foobar"], + "parameters": {}, + "delay": 100, } httpclient.HTTPClient.post.assert_called_with(endpoint, data) @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=models.Execution(**EXECUTION))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=models.Execution(**EXECUTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=models.Action(**ACTION))) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(return_value=models.Action(**ACTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=models.RunnerType(**RUNNER))) + models.ResourceManager, + "get_by_name", + mock.MagicMock(return_value=models.RunnerType(**RUNNER)), + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK") + ), + ) def test_pause(self): - self.client.executions.pause(EXECUTION['id']) + self.client.executions.pause(EXECUTION["id"]) - endpoint = '/executions/%s' % EXECUTION['id'] + endpoint = "/executions/%s" % EXECUTION["id"] - data = { - 'status': 'pausing' - } + data = {"status": "pausing"} httpclient.HTTPClient.put.assert_called_with(endpoint, data) @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=models.Execution(**EXECUTION))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=models.Execution(**EXECUTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=models.Action(**ACTION))) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(return_value=models.Action(**ACTION)), + ) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=models.RunnerType(**RUNNER))) + models.ResourceManager, + "get_by_name", + mock.MagicMock(return_value=models.RunnerType(**RUNNER)), + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK") + ), + ) def test_resume(self): - self.client.executions.resume(EXECUTION['id']) + self.client.executions.resume(EXECUTION["id"]) - endpoint = '/executions/%s' % EXECUTION['id'] + endpoint = "/executions/%s" % EXECUTION["id"] - data = { - 'status': 'resuming' - } + data = {"status": "resuming"} httpclient.HTTPClient.put.assert_called_with(endpoint, data) @mock.patch.object( - models.core.Resource, 'get_url_path_name', - mock.MagicMock(return_value='executions')) + models.core.Resource, + "get_url_path_name", + mock.MagicMock(return_value="executions"), + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, "OK") + ), + ) def test_get_children(self): - self.client.executions.get_children(EXECUTION['id']) + self.client.executions.get_children(EXECUTION["id"]) - endpoint = '/executions/%s/children' % EXECUTION['id'] + endpoint = "/executions/%s/children" % EXECUTION["id"] - data = { - 'depth': -1 - } + data = {"depth": -1} httpclient.HTTPClient.get.assert_called_with(url=endpoint, params=data) @mock.patch.object( - models.ResourceManager, 'get_all', - mock.MagicMock(return_value=[models.Execution(**EXECUTION)])) - @mock.patch.object(warnings, 'warn') - def test_st2client_liveactions_has_been_deprecated_and_emits_warning(self, mock_warn): + models.ResourceManager, + "get_all", + mock.MagicMock(return_value=[models.Execution(**EXECUTION)]), + ) + @mock.patch.object(warnings, "warn") + def test_st2client_liveactions_has_been_deprecated_and_emits_warning( + self, mock_warn + ): self.assertEqual(mock_warn.call_args, None) self.client.liveactions.get_all() - expected_msg = 'st2client.liveactions has been renamed' + expected_msg = "st2client.liveactions has been renamed" self.assertTrue(len(mock_warn.call_args_list) >= 1) self.assertIn(expected_msg, mock_warn.call_args_list[0][0][0]) self.assertEqual(mock_warn.call_args_list[0][0][1], DeprecationWarning) diff --git a/st2client/tests/unit/test_command_actionrun.py b/st2client/tests/unit/test_command_actionrun.py index 763ac649a68..1e312e0786b 100644 --- a/st2client/tests/unit/test_command_actionrun.py +++ b/st2client/tests/unit/test_command_actionrun.py @@ -21,73 +21,79 @@ import mock from st2client.commands.action import ActionRunCommand -from st2client.models.action import (Action, RunnerType) +from st2client.models.action import Action, RunnerType class ActionRunCommandTest(unittest2.TestCase): - def test_get_params_types(self): runner = RunnerType() runner_params = { - 'foo': {'immutable': True, 'required': True}, - 'bar': {'description': 'Some param.', 'type': 'string'} + "foo": {"immutable": True, "required": True}, + "bar": {"description": "Some param.", "type": "string"}, } runner.runner_parameters = runner_params orig_runner_params = copy.deepcopy(runner.runner_parameters) action = Action() action.parameters = { - 'foo': {'immutable': False}, # Should not be allowed by API. - 'stuff': {'description': 'Some param.', 'type': 'string', 'required': True} + "foo": {"immutable": False}, # Should not be allowed by API. + "stuff": {"description": "Some param.", "type": "string", "required": True}, } orig_action_params = copy.deepcopy(action.parameters) params, rqd, opt, imm = ActionRunCommand._get_params_types(runner, action) self.assertEqual(len(list(params.keys())), 3) - self.assertIn('foo', imm, '"foo" param should be in immutable set.') - self.assertNotIn('foo', rqd, '"foo" param should not be in required set.') - self.assertNotIn('foo', opt, '"foo" param should not be in optional set.') + self.assertIn("foo", imm, '"foo" param should be in immutable set.') + self.assertNotIn("foo", rqd, '"foo" param should not be in required set.') + self.assertNotIn("foo", opt, '"foo" param should not be in optional set.') - self.assertIn('bar', opt, '"bar" param should be in optional set.') - self.assertNotIn('bar', rqd, '"bar" param should not be in required set.') - self.assertNotIn('bar', imm, '"bar" param should not be in immutable set.') + self.assertIn("bar", opt, '"bar" param should be in optional set.') + self.assertNotIn("bar", rqd, '"bar" param should not be in required set.') + self.assertNotIn("bar", imm, '"bar" param should not be in immutable set.') - self.assertIn('stuff', rqd, '"stuff" param should be in required set.') - self.assertNotIn('stuff', opt, '"stuff" param should not be in optional set.') - self.assertNotIn('stuff', imm, '"stuff" param should not be in immutable set.') - self.assertEqual(runner.runner_parameters, orig_runner_params, 'Runner params modified.') - self.assertEqual(action.parameters, orig_action_params, 'Action params modified.') + self.assertIn("stuff", rqd, '"stuff" param should be in required set.') + self.assertNotIn("stuff", opt, '"stuff" param should not be in optional set.') + self.assertNotIn("stuff", imm, '"stuff" param should not be in immutable set.') + self.assertEqual( + runner.runner_parameters, orig_runner_params, "Runner params modified." + ) + self.assertEqual( + action.parameters, orig_action_params, "Action params modified." + ) def test_opt_in_dict_auto_convert(self): - """Test ability for user to opt-in to dict convert functionality - """ + """Test ability for user to opt-in to dict convert functionality""" runner = RunnerType() runner.runner_parameters = {} action = Action() - action.ref = 'test.action' + action.ref = "test.action" action.parameters = { - 'param_array': {'type': 'array'}, + "param_array": {"type": "array"}, } subparser = mock.Mock() - command = ActionRunCommand(action, self, subparser, name='test') + command = ActionRunCommand(action, self, subparser, name="test") mockarg = mock.Mock() mockarg.inherit_env = False mockarg.parameters = [ - 'param_array=foo:bar,foo2:bar2', + "param_array=foo:bar,foo2:bar2", ] mockarg.auto_dict = False - param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg) - self.assertEqual(param['param_array'], ['foo:bar', 'foo2:bar2']) + param = command._get_action_parameters_from_args( + action=action, runner=runner, args=mockarg + ) + self.assertEqual(param["param_array"], ["foo:bar", "foo2:bar2"]) mockarg.auto_dict = True - param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg) - self.assertEqual(param['param_array'], [{'foo': 'bar', 'foo2': 'bar2'}]) + param = command._get_action_parameters_from_args( + action=action, runner=runner, args=mockarg + ) + self.assertEqual(param["param_array"], [{"foo": "bar", "foo2": "bar2"}]) # set auto_dict back to default mockarg.auto_dict = False @@ -104,60 +110,65 @@ def test_get_params_from_args(self): runner.runner_parameters = {} action = Action() - action.ref = 'test.action' + action.ref = "test.action" action.parameters = { - 'param_string': {'type': 'string'}, - 'param_integer': {'type': 'integer'}, - 'param_number': {'type': 'number'}, - 'param_object': {'type': 'object'}, - 'param_boolean': {'type': 'boolean'}, - 'param_array': {'type': 'array'}, - 'param_array_of_dicts': {'type': 'array', 'properties': { - 'foo': {'type': 'string'}, - 'bar': {'type': 'integer'}, - 'baz': {'type': 'number'}, - 'qux': {'type': 'object'}, - 'quux': {'type': 'boolean'}} + "param_string": {"type": "string"}, + "param_integer": {"type": "integer"}, + "param_number": {"type": "number"}, + "param_object": {"type": "object"}, + "param_boolean": {"type": "boolean"}, + "param_array": {"type": "array"}, + "param_array_of_dicts": { + "type": "array", + "properties": { + "foo": {"type": "string"}, + "bar": {"type": "integer"}, + "baz": {"type": "number"}, + "qux": {"type": "object"}, + "quux": {"type": "boolean"}, + }, }, } subparser = mock.Mock() - command = ActionRunCommand(action, self, subparser, name='test') + command = ActionRunCommand(action, self, subparser, name="test") mockarg = mock.Mock() mockarg.inherit_env = False mockarg.auto_dict = True mockarg.parameters = [ - 'param_string=hoge', - 'param_integer=123', - 'param_number=1.23', - 'param_object=hoge=1,fuga=2', - 'param_boolean=False', - 'param_array=foo,bar,baz', - 'param_array_of_dicts=foo:HOGE,bar:1,baz:1.23,qux:foo=bar,quux:True', - 'param_array_of_dicts=foo:FUGA,bar:2,baz:2.34,qux:bar=baz,quux:False' + "param_string=hoge", + "param_integer=123", + "param_number=1.23", + "param_object=hoge=1,fuga=2", + "param_boolean=False", + "param_array=foo,bar,baz", + "param_array_of_dicts=foo:HOGE,bar:1,baz:1.23,qux:foo=bar,quux:True", + "param_array_of_dicts=foo:FUGA,bar:2,baz:2.34,qux:bar=baz,quux:False", ] - param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg) + param = command._get_action_parameters_from_args( + action=action, runner=runner, args=mockarg + ) self.assertIsInstance(param, dict) - self.assertEqual(param['param_string'], 'hoge') - self.assertEqual(param['param_integer'], 123) - self.assertEqual(param['param_number'], 1.23) - self.assertEqual(param['param_object'], {'hoge': '1', 'fuga': '2'}) - self.assertFalse(param['param_boolean']) - self.assertEqual(param['param_array'], ['foo', 'bar', 'baz']) + self.assertEqual(param["param_string"], "hoge") + self.assertEqual(param["param_integer"], 123) + self.assertEqual(param["param_number"], 1.23) + self.assertEqual(param["param_object"], {"hoge": "1", "fuga": "2"}) + self.assertFalse(param["param_boolean"]) + self.assertEqual(param["param_array"], ["foo", "bar", "baz"]) # checking the result of parsing for array of objects - self.assertIsInstance(param['param_array_of_dicts'], list) - self.assertEqual(len(param['param_array_of_dicts']), 2) - for param in param['param_array_of_dicts']: + self.assertIsInstance(param["param_array_of_dicts"], list) + self.assertEqual(len(param["param_array_of_dicts"]), 2) + for param in param["param_array_of_dicts"]: self.assertIsInstance(param, dict) - self.assertIsInstance(param['foo'], str) - self.assertIsInstance(param['bar'], int) - self.assertIsInstance(param['baz'], float) - self.assertIsInstance(param['qux'], dict) - self.assertIsInstance(param['quux'], bool) + self.assertIsInstance(param["foo"], str) + self.assertIsInstance(param["bar"], int) + self.assertIsInstance(param["baz"], float) + self.assertIsInstance(param["qux"], dict) + self.assertIsInstance(param["quux"], bool) # set auto_dict back to default mockarg.auto_dict = False @@ -167,36 +178,38 @@ def test_get_params_from_args_read_content_from_file(self): runner.runner_parameters = {} action = Action() - action.ref = 'test.action' + action.ref = "test.action" action.parameters = { - 'param_object': {'type': 'object'}, + "param_object": {"type": "object"}, } subparser = mock.Mock() - command = ActionRunCommand(action, self, subparser, name='test') + command = ActionRunCommand(action, self, subparser, name="test") # 1. File doesn't exist mockarg = mock.Mock() mockarg.inherit_env = False mockarg.auto_dict = True - mockarg.parameters = [ - '@param_object=doesnt-exist.json' - ] + mockarg.parameters = ["@param_object=doesnt-exist.json"] - self.assertRaisesRegex(ValueError, "doesn't exist", - command._get_action_parameters_from_args, action=action, - runner=runner, args=mockarg) + self.assertRaisesRegex( + ValueError, + "doesn't exist", + command._get_action_parameters_from_args, + action=action, + runner=runner, + args=mockarg, + ) # 2. Valid file path (we simply read this file) mockarg = mock.Mock() mockarg.inherit_env = False mockarg.auto_dict = True - mockarg.parameters = [ - '@param_string=%s' % (__file__) - ] + mockarg.parameters = ["@param_string=%s" % (__file__)] - params = command._get_action_parameters_from_args(action=action, - runner=runner, args=mockarg) + params = command._get_action_parameters_from_args( + action=action, runner=runner, args=mockarg + ) self.assertTrue(isinstance(params["param_string"], six.text_type)) self.assertTrue(params["param_string"].startswith("# Copyright")) @@ -212,37 +225,39 @@ def test_get_params_from_args_with_multiple_declarations(self): runner.runner_parameters = {} action = Action() - action.ref = 'test.action' + action.ref = "test.action" action.parameters = { - 'param_string': {'type': 'string'}, - 'param_array': {'type': 'array'}, - 'param_array_of_dicts': {'type': 'array'}, + "param_string": {"type": "string"}, + "param_array": {"type": "array"}, + "param_array_of_dicts": {"type": "array"}, } subparser = mock.Mock() - command = ActionRunCommand(action, self, subparser, name='test') + command = ActionRunCommand(action, self, subparser, name="test") mockarg = mock.Mock() mockarg.inherit_env = False mockarg.auto_dict = True mockarg.parameters = [ - 'param_string=hoge', # This value will be overwritten with the next declaration. - 'param_string=fuga', - 'param_array=foo', - 'param_array=bar', - 'param_array_of_dicts=foo:1,bar:2', - 'param_array_of_dicts=hoge:A,fuga:B' + "param_string=hoge", # This value will be overwritten with the next declaration. + "param_string=fuga", + "param_array=foo", + "param_array=bar", + "param_array_of_dicts=foo:1,bar:2", + "param_array_of_dicts=hoge:A,fuga:B", ] - param = command._get_action_parameters_from_args(action=action, runner=runner, args=mockarg) + param = command._get_action_parameters_from_args( + action=action, runner=runner, args=mockarg + ) # checks to accept multiple declaration only if the array type - self.assertEqual(param['param_string'], 'fuga') - self.assertEqual(param['param_array'], ['foo', 'bar']) - self.assertEqual(param['param_array_of_dicts'], [ - {'foo': '1', 'bar': '2'}, - {'hoge': 'A', 'fuga': 'B'} - ]) + self.assertEqual(param["param_string"], "fuga") + self.assertEqual(param["param_array"], ["foo", "bar"]) + self.assertEqual( + param["param_array_of_dicts"], + [{"foo": "1", "bar": "2"}, {"hoge": "A", "fuga": "B"}], + ) # set auto_dict back to default mockarg.auto_dict = False diff --git a/st2client/tests/unit/test_commands.py b/st2client/tests/unit/test_commands.py index 0748a4aeecc..de84f7883f0 100644 --- a/st2client/tests/unit/test_commands.py +++ b/st2client/tests/unit/test_commands.py @@ -32,97 +32,117 @@ from st2client.commands import resource from st2client.commands.resource import ResourceViewCommand -__all__ = [ - 'TestResourceCommand', - 'ResourceViewCommandTestCase' -] +__all__ = ["TestResourceCommand", "ResourceViewCommandTestCase"] LOG = logging.getLogger(__name__) class TestResourceCommand(unittest2.TestCase): - def __init__(self, *args, **kwargs): super(TestResourceCommand, self).__init__(*args, **kwargs) self.parser = argparse.ArgumentParser() self.subparsers = self.parser.add_subparsers() self.branch = resource.ResourceBranch( - base.FakeResource, 'Test Command', base.FakeApp(), self.subparsers) + base.FakeResource, "Test Command", base.FakeApp(), self.subparsers + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) def test_command_list(self): - args = self.parser.parse_args(['fakeresource', 'list']) - self.assertEqual(args.func, self.branch.commands['list'].run_and_print) - instances = self.branch.commands['list'].run(args) + args = self.parser.parse_args(["fakeresource", "list"]) + self.assertEqual(args.func, self.branch.commands["list"].run_and_print) + instances = self.branch.commands["list"].run(args) actual = [instance.serialize() for instance in instances] expected = json.loads(json.dumps(base.RESOURCES)) self.assertListEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_command_list_failed(self): - args = self.parser.parse_args(['fakeresource', 'list']) - self.assertRaises(Exception, self.branch.commands['list'].run, args) + args = self.parser.parse_args(["fakeresource", "list"]) + self.assertRaises(Exception, self.branch.commands["list"].run, args) @mock.patch.object( - models.ResourceManager, 'get_by_name', - mock.MagicMock(return_value=None)) + models.ResourceManager, "get_by_name", mock.MagicMock(return_value=None) + ) @mock.patch.object( - models.ResourceManager, 'get_by_id', - mock.MagicMock(return_value=base.FakeResource(**base.RESOURCES[0]))) + models.ResourceManager, + "get_by_id", + mock.MagicMock(return_value=base.FakeResource(**base.RESOURCES[0])), + ) def test_command_get_by_id(self): - args = self.parser.parse_args(['fakeresource', 'get', '123']) - self.assertEqual(args.func, self.branch.commands['get'].run_and_print) - instance = self.branch.commands['get'].run(args) + args = self.parser.parse_args(["fakeresource", "get", "123"]) + self.assertEqual(args.func, self.branch.commands["get"].run_and_print) + instance = self.branch.commands["get"].run(args) actual = instance.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_command_get(self): - args = self.parser.parse_args(['fakeresource', 'get', 'abc']) - self.assertEqual(args.func, self.branch.commands['get'].run_and_print) - instance = self.branch.commands['get'].run(args) + args = self.parser.parse_args(["fakeresource", "get", "abc"]) + self.assertEqual(args.func, self.branch.commands["get"].run_and_print) + instance = self.branch.commands["get"].run(args) actual = instance.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND'))) + httpclient.HTTPClient, + "get", + mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")), + ) def test_command_get_404(self): - args = self.parser.parse_args(['fakeresource', 'get', 'cba']) - self.assertEqual(args.func, self.branch.commands['get'].run_and_print) - self.assertRaises(resource.ResourceNotFoundError, - self.branch.commands['get'].run, - args) + args = self.parser.parse_args(["fakeresource", "get", "cba"]) + self.assertEqual(args.func, self.branch.commands["get"].run_and_print) + self.assertRaises( + resource.ResourceNotFoundError, self.branch.commands["get"].run, args + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_command_get_failed(self): - args = self.parser.parse_args(['fakeresource', 'get', 'cba']) - self.assertRaises(Exception, self.branch.commands['get'].run, args) + args = self.parser.parse_args(["fakeresource", "get", "cba"]) + self.assertRaises(Exception, self.branch.commands["get"].run, args) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_command_create(self): - instance = base.FakeResource(name='abc') - fd, path = tempfile.mkstemp(suffix='.json') + instance = base.FakeResource(name="abc") + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(instance.serialize(), indent=4)) - args = self.parser.parse_args(['fakeresource', 'create', path]) - self.assertEqual(args.func, - self.branch.commands['create'].run_and_print) - instance = self.branch.commands['create'].run(args) + args = self.parser.parse_args(["fakeresource", "create", path]) + self.assertEqual(args.func, self.branch.commands["create"].run_and_print) + instance = self.branch.commands["create"].run(args) actual = instance.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @@ -131,40 +151,49 @@ def test_command_create(self): os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_command_create_failed(self): - instance = base.FakeResource(name='abc') - fd, path = tempfile.mkstemp(suffix='.json') + instance = base.FakeResource(name="abc") + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(instance.serialize(), indent=4)) - args = self.parser.parse_args(['fakeresource', 'create', path]) - self.assertRaises(Exception, - self.branch.commands['create'].run, - args) + args = self.parser.parse_args(["fakeresource", "create", path]) + self.assertRaises(Exception, self.branch.commands["create"].run, args) finally: os.close(fd) os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_command_update(self): - instance = base.FakeResource(id='123', name='abc') - fd, path = tempfile.mkstemp(suffix='.json') + instance = base.FakeResource(id="123", name="abc") + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(instance.serialize(), indent=4)) - args = self.parser.parse_args( - ['fakeresource', 'update', '123', path]) - self.assertEqual(args.func, - self.branch.commands['update'].run_and_print) - instance = self.branch.commands['update'].run(args) + args = self.parser.parse_args(["fakeresource", "update", "123", path]) + self.assertEqual(args.func, self.branch.commands["update"].run_and_print) + instance = self.branch.commands["update"].run(args) actual = instance.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @@ -173,122 +202,142 @@ def test_command_update(self): os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, "OK") + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_command_update_failed(self): - instance = base.FakeResource(id='123', name='abc') - fd, path = tempfile.mkstemp(suffix='.json') + instance = base.FakeResource(id="123", name="abc") + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(instance.serialize(), indent=4)) - args = self.parser.parse_args( - ['fakeresource', 'update', '123', path]) - self.assertRaises(Exception, - self.branch.commands['update'].run, - args) + args = self.parser.parse_args(["fakeresource", "update", "123", path]) + self.assertRaises(Exception, self.branch.commands["update"].run, args) finally: os.close(fd) os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, "OK") + ), + ) def test_command_update_id_mismatch(self): - instance = base.FakeResource(id='789', name='abc') - fd, path = tempfile.mkstemp(suffix='.json') + instance = base.FakeResource(id="789", name="abc") + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(instance.serialize(), indent=4)) - args = self.parser.parse_args( - ['fakeresource', 'update', '123', path]) - self.assertRaises(Exception, - self.branch.commands['update'].run, - args) + args = self.parser.parse_args(["fakeresource", "update", "123", path]) + self.assertRaises(Exception, self.branch.commands["update"].run, args) finally: os.close(fd) os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 204, 'NO CONTENT'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock(return_value=base.FakeResponse("", 204, "NO CONTENT")), + ) def test_command_delete(self): - args = self.parser.parse_args(['fakeresource', 'delete', 'abc']) - self.assertEqual(args.func, - self.branch.commands['delete'].run_and_print) - self.branch.commands['delete'].run(args) + args = self.parser.parse_args(["fakeresource", "delete", "abc"]) + self.assertEqual(args.func, self.branch.commands["delete"].run_and_print) + self.branch.commands["delete"].run(args) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND'))) + httpclient.HTTPClient, + "get", + mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")), + ) def test_command_delete_404(self): - args = self.parser.parse_args(['fakeresource', 'delete', 'cba']) - self.assertEqual(args.func, - self.branch.commands['delete'].run_and_print) - self.assertRaises(resource.ResourceNotFoundError, - self.branch.commands['delete'].run, - args) + args = self.parser.parse_args(["fakeresource", "delete", "cba"]) + self.assertEqual(args.func, self.branch.commands["delete"].run_and_print) + self.assertRaises( + resource.ResourceNotFoundError, self.branch.commands["delete"].run, args + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, "OK") + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_command_delete_failed(self): - args = self.parser.parse_args(['fakeresource', 'delete', 'cba']) - self.assertRaises(Exception, self.branch.commands['delete'].run, args) + args = self.parser.parse_args(["fakeresource", "delete", "cba"]) + self.assertRaises(Exception, self.branch.commands["delete"].run, args) class ResourceViewCommandTestCase(unittest2.TestCase): - def setUp(self): ResourceViewCommand.display_attributes = [] def test_get_include_attributes(self): - cls = namedtuple('Args', 'attr') + cls = namedtuple("Args", "attr") args = cls(attr=[]) result = ResourceViewCommand._get_include_attributes(args=args) self.assertEqual(result, []) - args = cls(attr=['result']) + args = cls(attr=["result"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(result, ['result']) + self.assertEqual(result, ["result"]) - args = cls(attr=['result', 'trigger_instance']) + args = cls(attr=["result", "trigger_instance"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(result, ['result', 'trigger_instance']) + self.assertEqual(result, ["result", "trigger_instance"]) - args = cls(attr=['result.stdout']) + args = cls(attr=["result.stdout"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(result, ['result.stdout']) + self.assertEqual(result, ["result.stdout"]) - args = cls(attr=['result.stdout', 'result.stderr']) + args = cls(attr=["result.stdout", "result.stderr"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(result, ['result.stdout', 'result.stderr']) + self.assertEqual(result, ["result.stdout", "result.stderr"]) - args = cls(attr=['result.stdout', 'trigger_instance.id']) + args = cls(attr=["result.stdout", "trigger_instance.id"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(result, ['result.stdout', 'trigger_instance.id']) + self.assertEqual(result, ["result.stdout", "trigger_instance.id"]) - ResourceViewCommand.display_attributes = ['id', 'status'] + ResourceViewCommand.display_attributes = ["id", "status"] args = cls(attr=[]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(set(result), set(['id', 'status'])) + self.assertEqual(set(result), set(["id", "status"])) - args = cls(attr=['trigger_instance']) + args = cls(attr=["trigger_instance"]) result = ResourceViewCommand._get_include_attributes(args=args) - self.assertEqual(set(result), set(['trigger_instance'])) + self.assertEqual(set(result), set(["trigger_instance"])) - args = cls(attr=['all']) + args = cls(attr=["all"]) result = ResourceViewCommand._get_include_attributes(args=args) self.assertEqual(result, None) @@ -303,20 +352,19 @@ class CommandsHelpStringTestCase(BaseCLITestCase): # TODO: Automatically iterate all the available commands COMMANDS = [ # action - ['action', 'list'], - ['action', 'get'], - ['action', 'create'], - ['action', 'update'], - ['action', 'delete'], - ['action', 'enable'], - ['action', 'disable'], - ['action', 'execute'], - + ["action", "list"], + ["action", "get"], + ["action", "create"], + ["action", "update"], + ["action", "delete"], + ["action", "enable"], + ["action", "disable"], + ["action", "execute"], # execution - ['execution', 'cancel'], - ['execution', 'pause'], - ['execution', 'resume'], - ['execution', 'tail'] + ["execution", "cancel"], + ["execution", "pause"], + ["execution", "resume"], + ["execution", "tail"], ] def test_help_command_line_arg_works_for_supported_commands(self): @@ -324,7 +372,7 @@ def test_help_command_line_arg_works_for_supported_commands(self): for command in self.COMMANDS: # First test longhang notation - argv = command + ['--help'] + argv = command + ["--help"] try: result = shell.run(argv) @@ -335,16 +383,16 @@ def test_help_command_line_arg_works_for_supported_commands(self): stdout = self.stdout.getvalue() - self.assertIn('usage:', stdout) - self.assertIn(' '.join(command), stdout) + self.assertIn("usage:", stdout) + self.assertIn(" ".join(command), stdout) # self.assertIn('positional arguments:', stdout) - self.assertIn('optional arguments:', stdout) + self.assertIn("optional arguments:", stdout) # Reset stdout and stderr after each iteration self._reset_output_streams() # Then shorthand notation - argv = command + ['-h'] + argv = command + ["-h"] try: result = shell.run(argv) @@ -355,14 +403,14 @@ def test_help_command_line_arg_works_for_supported_commands(self): stdout = self.stdout.getvalue() - self.assertIn('usage:', stdout) - self.assertIn(' '.join(command), stdout) + self.assertIn("usage:", stdout) + self.assertIn(" ".join(command), stdout) # self.assertIn('positional arguments:', stdout) - self.assertIn('optional arguments:', stdout) + self.assertIn("optional arguments:", stdout) # Verify that the actual help usage string was triggered and not the invalid # "too few arguments" which would indicate command doesn't actually correctly handle # --help flag - self.assertNotIn('too few arguments', stdout) + self.assertNotIn("too few arguments", stdout) self._reset_output_streams() diff --git a/st2client/tests/unit/test_config_parser.py b/st2client/tests/unit/test_config_parser.py index 35a125ebebb..9cea63ee5aa 100644 --- a/st2client/tests/unit/test_config_parser.py +++ b/st2client/tests/unit/test_config_parser.py @@ -26,80 +26,77 @@ from st2client.config_parser import CONFIG_DEFAULT_VALUES BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, '../fixtures/st2rc.full.ini') -CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, '../fixtures/st2rc.partial.ini') -CONFIG_FILE_PATH_UNICODE = os.path.join(BASE_DIR, '../fixtures/test_unicode.ini') +CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, "../fixtures/st2rc.full.ini") +CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, "../fixtures/st2rc.partial.ini") +CONFIG_FILE_PATH_UNICODE = os.path.join(BASE_DIR, "../fixtures/test_unicode.ini") class CLIConfigParserTestCase(unittest2.TestCase): def test_constructor(self): - parser = CLIConfigParser(config_file_path='doesnotexist', validate_config_exists=False) + parser = CLIConfigParser( + config_file_path="doesnotexist", validate_config_exists=False + ) self.assertTrue(parser) - self.assertRaises(ValueError, CLIConfigParser, config_file_path='doestnotexist', - validate_config_exists=True) + self.assertRaises( + ValueError, + CLIConfigParser, + config_file_path="doestnotexist", + validate_config_exists=True, + ) def test_parse(self): # File doesn't exist - parser = CLIConfigParser(config_file_path='doesnotexist', validate_config_exists=False) + parser = CLIConfigParser( + config_file_path="doesnotexist", validate_config_exists=False + ) result = parser.parse() self.assertEqual(CONFIG_DEFAULT_VALUES, result) # File exists - all the options specified expected = { - 'general': { - 'base_url': 'http://127.0.0.1', - 'api_version': 'v1', - 'cacert': 'cacartpath', - 'silence_ssl_warnings': False, - 'silence_schema_output': True + "general": { + "base_url": "http://127.0.0.1", + "api_version": "v1", + "cacert": "cacartpath", + "silence_ssl_warnings": False, + "silence_schema_output": True, }, - 'cli': { - 'debug': True, - 'cache_token': False, - 'timezone': 'UTC' - }, - 'credentials': { - 'username': 'test1', - 'password': 'test1', - 'api_key': None - }, - 'api': { - 'url': 'http://127.0.0.1:9101/v1' - }, - 'auth': { - 'url': 'http://127.0.0.1:9100/' - }, - 'stream': { - 'url': 'http://127.0.0.1:9102/v1/stream' - } + "cli": {"debug": True, "cache_token": False, "timezone": "UTC"}, + "credentials": {"username": "test1", "password": "test1", "api_key": None}, + "api": {"url": "http://127.0.0.1:9101/v1"}, + "auth": {"url": "http://127.0.0.1:9100/"}, + "stream": {"url": "http://127.0.0.1:9102/v1/stream"}, } - parser = CLIConfigParser(config_file_path=CONFIG_FILE_PATH_FULL, - validate_config_exists=False) + parser = CLIConfigParser( + config_file_path=CONFIG_FILE_PATH_FULL, validate_config_exists=False + ) result = parser.parse() self.assertEqual(expected, result) # File exists - missing options, test defaults - parser = CLIConfigParser(config_file_path=CONFIG_FILE_PATH_PARTIAL, - validate_config_exists=False) + parser = CLIConfigParser( + config_file_path=CONFIG_FILE_PATH_PARTIAL, validate_config_exists=False + ) result = parser.parse() - self.assertTrue(result['cli']['cache_token'], True) + self.assertTrue(result["cli"]["cache_token"], True) def test_get_config_for_unicode_char(self): - parser = CLIConfigParser(config_file_path=CONFIG_FILE_PATH_UNICODE, - validate_config_exists=False) + parser = CLIConfigParser( + config_file_path=CONFIG_FILE_PATH_UNICODE, validate_config_exists=False + ) config = parser.parse() if six.PY3: - self.assertEqual(config['credentials']['password'], '密码') + self.assertEqual(config["credentials"]["password"], "密码") else: - self.assertEqual(config['credentials']['password'], u'\u5bc6\u7801') + self.assertEqual(config["credentials"]["password"], "\u5bc6\u7801") class CLIConfigPermissionsTestCase(unittest2.TestCase): def setUp(self): - self.TEMP_FILE_PATH = os.path.join('st2config', '.st2', 'config') + self.TEMP_FILE_PATH = os.path.join("st2config", ".st2", "config") self.TEMP_CONFIG_DIR = os.path.dirname(self.TEMP_FILE_PATH) if os.path.exists(self.TEMP_FILE_PATH): @@ -135,7 +132,9 @@ def test_correct_permissions_emit_no_warnings(self): self.assertEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o660) - parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True) + parser = CLIConfigParser( + config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True + ) parser.LOG = mock.Mock() result = parser.parse() # noqa F841 @@ -159,7 +158,9 @@ def test_weird_but_correct_permissions_emit_no_warnings(self): self.assertEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o640) - parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True) + parser = CLIConfigParser( + config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True + ) parser.LOG = mock.Mock() result = parser.parse() # noqa F841 @@ -175,7 +176,9 @@ def test_weird_but_correct_permissions_emit_no_warnings(self): self.assertEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o600) - parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True) + parser = CLIConfigParser( + config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True + ) parser.LOG = mock.Mock() result = parser.parse() # noqa F841 @@ -200,7 +203,9 @@ def test_warn_on_bad_config_permissions(self): self.assertNotEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o770) - parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True) + parser = CLIConfigParser( + config_file_path=self.TEMP_FILE_PATH, validate_config_exists=True + ) parser.LOG = mock.Mock() result = parser.parse() # noqa F841 @@ -209,17 +214,20 @@ def test_warn_on_bad_config_permissions(self): self.assertEqual( "The SGID bit is not set on the StackStorm configuration directory.", - parser.LOG.info.call_args_list[0][0][0]) + parser.LOG.info.call_args_list[0][0][0], + ) self.assertEqual(parser.LOG.warn.call_count, 2) self.assertEqual( "The StackStorm configuration directory permissions are insecure " "(too permissive): others have access.", - parser.LOG.warn.call_args_list[0][0][0]) + parser.LOG.warn.call_args_list[0][0][0], + ) self.assertEqual( "The StackStorm configuration file permissions are insecure: others have access.", - parser.LOG.warn.call_args_list[1][0][0]) + parser.LOG.warn.call_args_list[1][0][0], + ) # Make sure we left the file alone self.assertTrue(os.path.exists(self.TEMP_FILE_PATH)) @@ -239,9 +247,11 @@ def test_disable_permissions_warnings(self): self.assertNotEqual(os.stat(self.TEMP_FILE_PATH).st_mode & 0o777, 0o770) - parser = CLIConfigParser(config_file_path=self.TEMP_FILE_PATH, - validate_config_exists=True, - validate_config_permissions=False) + parser = CLIConfigParser( + config_file_path=self.TEMP_FILE_PATH, + validate_config_exists=True, + validate_config_permissions=False, + ) parser.LOG = mock.Mock() result = parser.parse() # noqa F841 diff --git a/st2client/tests/unit/test_execution_tail_command.py b/st2client/tests/unit/test_execution_tail_command.py index 15500767f2b..08957ddbf16 100644 --- a/st2client/tests/unit/test_execution_tail_command.py +++ b/st2client/tests/unit/test_execution_tail_command.py @@ -27,247 +27,180 @@ from st2client.commands.action import LIVEACTION_STATUS_TIMED_OUT from st2client.shell import Shell -__all__ = [ - 'ActionExecutionTailCommandTestCase' -] +__all__ = ["ActionExecutionTailCommandTestCase"] # Mock objects -MOCK_LIVEACTION_1_RUNNING = { - 'id': 'idfoo1', - 'status': LIVEACTION_STATUS_RUNNING -} +MOCK_LIVEACTION_1_RUNNING = {"id": "idfoo1", "status": LIVEACTION_STATUS_RUNNING} -MOCK_LIVEACTION_1_SUCCEEDED = { - 'id': 'idfoo1', - 'status': LIVEACTION_STATUS_SUCCEEDED -} +MOCK_LIVEACTION_1_SUCCEEDED = {"id": "idfoo1", "status": LIVEACTION_STATUS_SUCCEEDED} -MOCK_LIVEACTION_2_FAILED = { - 'id': 'idfoo2', - 'status': LIVEACTION_STATUS_FAILED -} +MOCK_LIVEACTION_2_FAILED = {"id": "idfoo2", "status": LIVEACTION_STATUS_FAILED} # Mock liveaction objects for ActionChain workflow -MOCK_LIVEACTION_3_RUNNING = { - 'id': 'idfoo3', - 'status': LIVEACTION_STATUS_RUNNING -} +MOCK_LIVEACTION_3_RUNNING = {"id": "idfoo3", "status": LIVEACTION_STATUS_RUNNING} MOCK_LIVEACTION_3_CHILD_1_RUNNING = { - 'id': 'idchild1', - 'context': { - 'parent': { - 'execution_id': 'idfoo3' - }, - 'chain': { - 'name': 'task_1' - } - }, - 'status': LIVEACTION_STATUS_RUNNING + "id": "idchild1", + "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_1"}}, + "status": LIVEACTION_STATUS_RUNNING, } MOCK_LIVEACTION_3_CHILD_1_SUCCEEDED = { - 'id': 'idchild1', - 'context': { - 'parent': { - 'execution_id': 'idfoo3' - }, - 'chain': { - 'name': 'task_1' - } - }, - 'status': LIVEACTION_STATUS_SUCCEEDED + "id": "idchild1", + "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_1"}}, + "status": LIVEACTION_STATUS_SUCCEEDED, } MOCK_LIVEACTION_3_CHILD_1_OUTPUT_1 = { - 'execution_id': 'idchild1', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line ac 4\n' + "execution_id": "idchild1", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line ac 4\n", } MOCK_LIVEACTION_3_CHILD_1_OUTPUT_2 = { - 'execution_id': 'idchild1', - 'timestamp': '1505732598', - 'output_type': 'stderr', - 'data': 'line ac 5\n' + "execution_id": "idchild1", + "timestamp": "1505732598", + "output_type": "stderr", + "data": "line ac 5\n", } MOCK_LIVEACTION_3_CHILD_2_RUNNING = { - 'id': 'idchild2', - 'context': { - 'parent': { - 'execution_id': 'idfoo3' - }, - 'chain': { - 'name': 'task_2' - } - }, - 'status': LIVEACTION_STATUS_RUNNING + "id": "idchild2", + "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_2"}}, + "status": LIVEACTION_STATUS_RUNNING, } MOCK_LIVEACTION_3_CHILD_2_FAILED = { - 'id': 'idchild2', - 'context': { - 'parent': { - 'execution_id': 'idfoo3' - }, - 'chain': { - 'name': 'task_2' - } - }, - 'status': LIVEACTION_STATUS_FAILED + "id": "idchild2", + "context": {"parent": {"execution_id": "idfoo3"}, "chain": {"name": "task_2"}}, + "status": LIVEACTION_STATUS_FAILED, } MOCK_LIVEACTION_3_CHILD_2_OUTPUT_1 = { - 'execution_id': 'idchild2', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line ac 100\n' + "execution_id": "idchild2", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line ac 100\n", } -MOCK_LIVEACTION_3_SUCCEDED = { - 'id': 'idfoo3', - 'status': LIVEACTION_STATUS_SUCCEEDED -} +MOCK_LIVEACTION_3_SUCCEDED = {"id": "idfoo3", "status": LIVEACTION_STATUS_SUCCEEDED} # Mock objects for Orquesta workflow execution -MOCK_LIVEACTION_4_RUNNING = { - 'id': 'idfoo4', - 'status': LIVEACTION_STATUS_RUNNING -} +MOCK_LIVEACTION_4_RUNNING = {"id": "idfoo4", "status": LIVEACTION_STATUS_RUNNING} MOCK_LIVEACTION_4_CHILD_1_RUNNING = { - 'id': 'idorquestachild1', - 'context': { - 'orquesta': { - 'task_name': 'task_1' - }, - 'parent': { - 'execution_id': 'idfoo4' - } + "id": "idorquestachild1", + "context": { + "orquesta": {"task_name": "task_1"}, + "parent": {"execution_id": "idfoo4"}, }, - 'status': LIVEACTION_STATUS_RUNNING + "status": LIVEACTION_STATUS_RUNNING, } MOCK_LIVEACTION_4_CHILD_1_1_RUNNING = { - 'id': 'idorquestachild1_1', - 'context': { - 'orquesta': { - 'task_name': 'task_1' - }, - 'parent': { - 'execution_id': 'idorquestachild1' - } + "id": "idorquestachild1_1", + "context": { + "orquesta": {"task_name": "task_1"}, + "parent": {"execution_id": "idorquestachild1"}, }, - 'status': LIVEACTION_STATUS_RUNNING + "status": LIVEACTION_STATUS_RUNNING, } MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED = { - 'id': 'idorquestachild1', - 'context': { - 'orquesta': { - 'task_name': 'task_1', + "id": "idorquestachild1", + "context": { + "orquesta": { + "task_name": "task_1", }, - 'parent': { - 'execution_id': 'idfoo4' - } + "parent": {"execution_id": "idfoo4"}, }, - 'status': LIVEACTION_STATUS_SUCCEEDED + "status": LIVEACTION_STATUS_SUCCEEDED, } MOCK_LIVEACTION_4_CHILD_1_1_SUCCEEDED = { - 'id': 'idorquestachild1_1', - 'context': { - 'orquesta': { - 'task_name': 'task_1', + "id": "idorquestachild1_1", + "context": { + "orquesta": { + "task_name": "task_1", }, - 'parent': { - 'execution_id': 'idorquestachild1' - } + "parent": {"execution_id": "idorquestachild1"}, }, - 'status': LIVEACTION_STATUS_SUCCEEDED + "status": LIVEACTION_STATUS_SUCCEEDED, } MOCK_LIVEACTION_4_CHILD_1_OUTPUT_1 = { - 'execution_id': 'idorquestachild1', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line orquesta 4\n' + "execution_id": "idorquestachild1", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line orquesta 4\n", } MOCK_LIVEACTION_4_CHILD_1_OUTPUT_2 = { - 'execution_id': 'idorquestachild1', - 'timestamp': '1505732598', - 'output_type': 'stderr', - 'data': 'line orquesta 5\n' + "execution_id": "idorquestachild1", + "timestamp": "1505732598", + "output_type": "stderr", + "data": "line orquesta 5\n", } MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_1 = { - 'execution_id': 'idorquestachild1_1', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line orquesta 4\n' + "execution_id": "idorquestachild1_1", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line orquesta 4\n", } MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_2 = { - 'execution_id': 'idorquestachild1_1', - 'timestamp': '1505732598', - 'output_type': 'stderr', - 'data': 'line orquesta 5\n' + "execution_id": "idorquestachild1_1", + "timestamp": "1505732598", + "output_type": "stderr", + "data": "line orquesta 5\n", } MOCK_LIVEACTION_4_CHILD_2_RUNNING = { - 'id': 'idorquestachild2', - 'context': { - 'orquesta': { - 'task_name': 'task_2', + "id": "idorquestachild2", + "context": { + "orquesta": { + "task_name": "task_2", }, - 'parent': { - 'execution_id': 'idfoo4' - } + "parent": {"execution_id": "idfoo4"}, }, - 'status': LIVEACTION_STATUS_RUNNING + "status": LIVEACTION_STATUS_RUNNING, } MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT = { - 'id': 'idorquestachild2', - 'context': { - 'orquesta': { - 'task_name': 'task_2', + "id": "idorquestachild2", + "context": { + "orquesta": { + "task_name": "task_2", }, - 'parent': { - 'execution_id': 'idfoo4' - } + "parent": {"execution_id": "idfoo4"}, }, - 'status': LIVEACTION_STATUS_TIMED_OUT + "status": LIVEACTION_STATUS_TIMED_OUT, } MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1 = { - 'execution_id': 'idorquestachild2', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line orquesta 100\n' + "execution_id": "idorquestachild2", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line orquesta 100\n", } -MOCK_LIVEACTION_4_SUCCEDED = { - 'id': 'idfoo4', - 'status': LIVEACTION_STATUS_SUCCEEDED -} +MOCK_LIVEACTION_4_SUCCEDED = {"id": "idfoo4", "status": LIVEACTION_STATUS_SUCCEEDED} # Mock objects for simple actions MOCK_OUTPUT_1 = { - 'execution_id': 'idfoo3', - 'timestamp': '1505732598', - 'output_type': 'stdout', - 'data': 'line 1\n' + "execution_id": "idfoo3", + "timestamp": "1505732598", + "output_type": "stdout", + "data": "line 1\n", } MOCK_OUTPUT_2 = { - 'execution_id': 'idfoo3', - 'timestamp': '1505732598', - 'output_type': 'stderr', - 'data': 'line 2\n' + "execution_id": "idfoo3", + "timestamp": "1505732598", + "output_type": "stderr", + "data": "line 2\n", } @@ -279,42 +212,55 @@ def __init__(self, *args, **kwargs): self.shell = Shell() @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_1_SUCCEEDED), - 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_1_SUCCEEDED), 200, "OK" + ) + ), + ) def test_tail_simple_execution_already_finished_succeeded(self): - argv = ['execution', 'tail', 'idfoo1'] + argv = ["execution", "tail", "idfoo1"] self.assertEqual(self.shell.run(argv), 0) stdout = self.stdout.getvalue() stderr = self.stderr.getvalue() - self.assertIn('Execution idfoo1 has completed (status=succeeded)', stdout) - self.assertEqual(stderr, '') + self.assertIn("Execution idfoo1 has completed (status=succeeded)", stdout) + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_2_FAILED), - 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_2_FAILED), 200, "OK" + ) + ), + ) def test_tail_simple_execution_already_finished_failed(self): - argv = ['execution', 'tail', 'idfoo2'] + argv = ["execution", "tail", "idfoo2"] self.assertEqual(self.shell.run(argv), 0) stdout = self.stdout.getvalue() stderr = self.stderr.getvalue() - self.assertIn('Execution idfoo2 has completed (status=failed)', stdout) - self.assertEqual(stderr, '') + self.assertIn("Execution idfoo2 has completed (status=failed)", stdout) + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_1_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_1_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_simple_execution_running_no_data_produced(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo1'] + argv = ["execution", "tail", "idfoo1"] - MOCK_EVENTS = [ - MOCK_LIVEACTION_1_SUCCEEDED - ] + MOCK_EVENTS = [MOCK_LIVEACTION_1_SUCCEEDED] mock_cls = mock.Mock() mock_cls.listen = mock.Mock() @@ -333,21 +279,26 @@ def test_tail_simple_execution_running_no_data_produced(self, mock_stream_manage Execution idfoo1 has completed (status=succeeded). """ self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_3_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_3_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_simple_execution_running_with_data(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo3'] + argv = ["execution", "tail", "idfoo3"] MOCK_EVENTS = [ MOCK_LIVEACTION_3_RUNNING, MOCK_OUTPUT_1, MOCK_OUTPUT_2, - MOCK_LIVEACTION_3_SUCCEDED + MOCK_LIVEACTION_3_SUCCEDED, ] mock_cls = mock.Mock() @@ -372,41 +323,39 @@ def test_tail_simple_execution_running_with_data(self, mock_stream_manager): Execution idfoo3 has completed (status=succeeded). """.lstrip() self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_3_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_3_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_action_chain_workflow_execution(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo3'] + argv = ["execution", "tail", "idfoo3"] MOCK_EVENTS = [ # Workflow started running MOCK_LIVEACTION_3_RUNNING, - # Child task 1 started running MOCK_LIVEACTION_3_CHILD_1_RUNNING, - # Output produced by the child task MOCK_LIVEACTION_3_CHILD_1_OUTPUT_1, MOCK_LIVEACTION_3_CHILD_1_OUTPUT_2, - # Child task 1 finished MOCK_LIVEACTION_3_CHILD_1_SUCCEEDED, - # Child task 2 started running MOCK_LIVEACTION_3_CHILD_2_RUNNING, - # Output produced by child task MOCK_LIVEACTION_3_CHILD_2_OUTPUT_1, - # Child task 2 finished MOCK_LIVEACTION_3_CHILD_2_FAILED, - # Parent workflow task finished - MOCK_LIVEACTION_3_SUCCEDED + MOCK_LIVEACTION_3_SUCCEDED, ] mock_cls = mock.Mock() @@ -440,41 +389,39 @@ def test_tail_action_chain_workflow_execution(self, mock_stream_manager): Execution idfoo3 has completed (status=succeeded). """.lstrip() self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_4_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_4_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_orquesta_workflow_execution(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo4'] + argv = ["execution", "tail", "idfoo4"] MOCK_EVENTS = [ # Workflow started running MOCK_LIVEACTION_4_RUNNING, - # Child task 1 started running MOCK_LIVEACTION_4_CHILD_1_RUNNING, - # Output produced by the child task MOCK_LIVEACTION_4_CHILD_1_OUTPUT_1, MOCK_LIVEACTION_4_CHILD_1_OUTPUT_2, - # Child task 1 finished MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED, - # Child task 2 started running MOCK_LIVEACTION_4_CHILD_2_RUNNING, - # Output produced by child task MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1, - # Child task 2 finished MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT, - # Parent workflow task finished - MOCK_LIVEACTION_4_SUCCEDED + MOCK_LIVEACTION_4_SUCCEDED, ] mock_cls = mock.Mock() @@ -508,64 +455,55 @@ def test_tail_orquesta_workflow_execution(self, mock_stream_manager): Execution idfoo4 has completed (status=succeeded). """.lstrip() self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_4_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_4_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_double_nested_orquesta_workflow_execution(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo4'] + argv = ["execution", "tail", "idfoo4"] MOCK_EVENTS = [ # Workflow started running MOCK_LIVEACTION_4_RUNNING, - # Child task 1 started running (sub workflow) MOCK_LIVEACTION_4_CHILD_1_RUNNING, - # Child task 1 started running MOCK_LIVEACTION_4_CHILD_1_1_RUNNING, - # Output produced by the child task MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_1, MOCK_LIVEACTION_4_CHILD_1_1_OUTPUT_2, - # Another execution has started, this output should not be included MOCK_LIVEACTION_3_RUNNING, - # Child task 1 started running MOCK_LIVEACTION_3_CHILD_1_RUNNING, - # Output produced by the child task MOCK_LIVEACTION_3_CHILD_1_OUTPUT_1, MOCK_LIVEACTION_3_CHILD_1_OUTPUT_2, - # Child task 1 finished MOCK_LIVEACTION_3_CHILD_1_SUCCEEDED, - # Parent workflow task finished MOCK_LIVEACTION_3_SUCCEDED, # End another execution - # Child task 1 has finished MOCK_LIVEACTION_4_CHILD_1_1_SUCCEEDED, - # Child task 1 finished (sub workflow) MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED, - # Child task 2 started running MOCK_LIVEACTION_4_CHILD_2_RUNNING, - # Output produced by child task MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1, - # Child task 2 finished MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT, - # Parent workflow task finished - MOCK_LIVEACTION_4_SUCCEDED + MOCK_LIVEACTION_4_SUCCEDED, ] mock_cls = mock.Mock() @@ -604,32 +542,33 @@ def test_tail_double_nested_orquesta_workflow_execution(self, mock_stream_manage """.lstrip() self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(MOCK_LIVEACTION_4_CHILD_2_RUNNING), - 200, 'OK'))) - @mock.patch('st2client.client.StreamManager', autospec=True) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(MOCK_LIVEACTION_4_CHILD_2_RUNNING), 200, "OK" + ) + ), + ) + @mock.patch("st2client.client.StreamManager", autospec=True) def test_tail_child_execution_directly(self, mock_stream_manager): - argv = ['execution', 'tail', 'idfoo4'] + argv = ["execution", "tail", "idfoo4"] MOCK_EVENTS = [ # Child task 2 started running MOCK_LIVEACTION_4_CHILD_2_RUNNING, - # Output produced by child task MOCK_LIVEACTION_4_CHILD_2_OUTPUT_1, - # Other executions should not interfere # Child task 1 started running MOCK_LIVEACTION_3_CHILD_1_RUNNING, - # Child task 1 finished (sub workflow) MOCK_LIVEACTION_4_CHILD_1_SUCCEEDED, - # Child task 2 finished - MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT + MOCK_LIVEACTION_4_CHILD_2_TIMED_OUT, ] mock_cls = mock.Mock() @@ -654,4 +593,4 @@ def test_tail_child_execution_directly(self, mock_stream_manager): """.lstrip() self.assertEqual(stdout, expected_result) - self.assertEqual(stderr, '') + self.assertEqual(stderr, "") diff --git a/st2client/tests/unit/test_formatters.py b/st2client/tests/unit/test_formatters.py index b3733faba52..fe0370aea12 100644 --- a/st2client/tests/unit/test_formatters.py +++ b/st2client/tests/unit/test_formatters.py @@ -39,38 +39,43 @@ LOG = logging.getLogger(__name__) FIXTURES_MANIFEST = { - 'executions': ['execution.json', - 'execution_result_has_carriage_return.json', - 'execution_unicode.json', - 'execution_double_backslash.json', - 'execution_with_stack_trace.json', - 'execution_with_schema.json'], - 'results': ['execution_get_default.txt', - 'execution_get_detail.txt', - 'execution_get_result_by_key.txt', - 'execution_result_has_carriage_return.txt', - 'execution_result_has_carriage_return_py3.txt', - 'execution_get_attributes.txt', - 'execution_list_attr_start_timestamp.txt', - 'execution_list_empty_response_start_timestamp_attr.txt', - 'execution_unescape_newline.txt', - 'execution_unicode.txt', - 'execution_double_backslash.txt', - 'execution_unicode_py3.txt', - 'execution_get_has_schema.txt'] + "executions": [ + "execution.json", + "execution_result_has_carriage_return.json", + "execution_unicode.json", + "execution_double_backslash.json", + "execution_with_stack_trace.json", + "execution_with_schema.json", + ], + "results": [ + "execution_get_default.txt", + "execution_get_detail.txt", + "execution_get_result_by_key.txt", + "execution_result_has_carriage_return.txt", + "execution_result_has_carriage_return_py3.txt", + "execution_get_attributes.txt", + "execution_list_attr_start_timestamp.txt", + "execution_list_empty_response_start_timestamp_attr.txt", + "execution_unescape_newline.txt", + "execution_unicode.txt", + "execution_double_backslash.txt", + "execution_unicode_py3.txt", + "execution_get_has_schema.txt", + ], } FIXTURES = loader.load_fixtures(fixtures_dict=FIXTURES_MANIFEST) -EXECUTION = FIXTURES['executions']['execution.json'] -UNICODE = FIXTURES['executions']['execution_unicode.json'] -DOUBLE_BACKSLASH = FIXTURES['executions']['execution_double_backslash.json'] -OUTPUT_SCHEMA = FIXTURES['executions']['execution_with_schema.json'] -NEWLINE = FIXTURES['executions']['execution_with_stack_trace.json'] -HAS_CARRIAGE_RETURN = FIXTURES['executions']['execution_result_has_carriage_return.json'] +EXECUTION = FIXTURES["executions"]["execution.json"] +UNICODE = FIXTURES["executions"]["execution_unicode.json"] +DOUBLE_BACKSLASH = FIXTURES["executions"]["execution_double_backslash.json"] +OUTPUT_SCHEMA = FIXTURES["executions"]["execution_with_schema.json"] +NEWLINE = FIXTURES["executions"]["execution_with_stack_trace.json"] +HAS_CARRIAGE_RETURN = FIXTURES["executions"][ + "execution_result_has_carriage_return.json" +] class TestExecutionResultFormatter(unittest2.TestCase): - def __init__(self, *args, **kwargs): super(TestExecutionResultFormatter, self).__init__(*args, **kwargs) self.shell = shell.Shell() @@ -88,212 +93,278 @@ def tearDown(self): os.unlink(self.path) def _redirect_console(self, path): - sys.stdout = open(path, 'w') - sys.stderr = open(path, 'w') + sys.stdout = open(path, "w") + sys.stderr = open(path, "w") def _undo_console_redirect(self): sys.stdout = sys.__stdout__ sys.stderr = sys.__stderr__ def test_console_redirect(self): - message = 'Hello, World!' + message = "Hello, World!" print(message) self._undo_console_redirect() - with open(self.path, 'r') as fd: - content = fd.read().replace('\n', '') + with open(self.path, "r") as fd: + content = fd.read().replace("\n", "") self.assertEqual(content, message) def test_execution_get_default(self): - argv = ['execution', 'get', EXECUTION['id']] + argv = ["execution", "get", EXECUTION["id"]] content = self._get_execution(argv) - self.assertEqual(content, FIXTURES['results']['execution_get_default.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_get_default.txt"]) def test_execution_get_attributes(self): - argv = ['execution', 'get', EXECUTION['id'], '--attr', 'status', 'end_timestamp'] + argv = [ + "execution", + "get", + EXECUTION["id"], + "--attr", + "status", + "end_timestamp", + ] content = self._get_execution(argv) - self.assertEqual(content, FIXTURES['results']['execution_get_attributes.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_get_attributes.txt"]) def test_execution_get_default_in_json(self): - argv = ['execution', 'get', EXECUTION['id'], '-j'] + argv = ["execution", "get", EXECUTION["id"], "-j"] content = self._get_execution(argv) - self.assertEqual(json.loads(content), - jsutil.get_kvps(EXECUTION, ['id', 'action.ref', 'context.user', - 'start_timestamp', 'end_timestamp', 'status', - 'parameters', 'result'])) + self.assertEqual( + json.loads(content), + jsutil.get_kvps( + EXECUTION, + [ + "id", + "action.ref", + "context.user", + "start_timestamp", + "end_timestamp", + "status", + "parameters", + "result", + ], + ), + ) def test_execution_get_detail(self): - argv = ['execution', 'get', EXECUTION['id'], '-d'] + argv = ["execution", "get", EXECUTION["id"], "-d"] content = self._get_execution(argv) - self.assertEqual(content, FIXTURES['results']['execution_get_detail.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_get_detail.txt"]) def test_execution_with_schema(self): - argv = ['execution', 'get', OUTPUT_SCHEMA['id']] + argv = ["execution", "get", OUTPUT_SCHEMA["id"]] content = self._get_schema_execution(argv) - self.assertEqual(content, FIXTURES['results']['execution_get_has_schema.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_get_has_schema.txt"]) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(NEWLINE), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(NEWLINE), 200, "OK", {}) + ), + ) def test_execution_unescape_newline(self): - """Ensure client renders newline characters - """ + """Ensure client renders newline characters""" - argv = ['execution', 'get', NEWLINE['id']] + argv = ["execution", "get", NEWLINE["id"]] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() - self.assertEqual(content, FIXTURES['results']['execution_unescape_newline.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_unescape_newline.txt"]) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(UNICODE), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(UNICODE), 200, "OK", {}) + ), + ) def test_execution_unicode(self): - """Ensure client renders unicode escape sequences - """ + """Ensure client renders unicode escape sequences""" - argv = ['execution', 'get', UNICODE['id']] + argv = ["execution", "get", UNICODE["id"]] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() if six.PY2: - self.assertEqual(content, FIXTURES['results']['execution_unicode.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_unicode.txt"]) else: - content = content.replace(r'\xE2\x80\xA1', r'\u2021') - self.assertEqual(content, FIXTURES['results']['execution_unicode_py3.txt']) + content = content.replace(r"\xE2\x80\xA1", r"\u2021") + self.assertEqual(content, FIXTURES["results"]["execution_unicode_py3.txt"]) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(DOUBLE_BACKSLASH), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(DOUBLE_BACKSLASH), 200, "OK", {}) + ), + ) def test_execution_double_backslash_not_unicode_escape_sequence(self): - argv = ['execution', 'get', DOUBLE_BACKSLASH['id']] + argv = ["execution", "get", DOUBLE_BACKSLASH["id"]] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() - self.assertEqual(content, FIXTURES['results']['execution_double_backslash.txt']) + self.assertEqual(content, FIXTURES["results"]["execution_double_backslash.txt"]) def test_execution_get_detail_in_json(self): - argv = ['execution', 'get', EXECUTION['id'], '-d', '-j'] + argv = ["execution", "get", EXECUTION["id"], "-d", "-j"] content = self._get_execution(argv) content_dict = json.loads(content) # Sufficient to check if output contains all expected keys. The entire result will not # match as content will contain characters which improve rendering. for k in six.iterkeys(EXECUTION): - if k in ['liveaction', 'callback']: + if k in ["liveaction", "callback"]: continue if k in content: continue - self.assertTrue(False, 'Missing key %s. %s != %s' % (k, EXECUTION, content_dict)) + self.assertTrue( + False, "Missing key %s. %s != %s" % (k, EXECUTION, content_dict) + ) def test_execution_get_result_by_key(self): - argv = ['execution', 'get', EXECUTION['id'], '-k', 'localhost.stdout'] + argv = ["execution", "get", EXECUTION["id"], "-k", "localhost.stdout"] content = self._get_execution(argv) - self.assertEqual(content, FIXTURES['results']['execution_get_result_by_key.txt']) + self.assertEqual( + content, FIXTURES["results"]["execution_get_result_by_key.txt"] + ) def test_execution_get_result_by_key_in_json(self): - argv = ['execution', 'get', EXECUTION['id'], '-k', 'localhost.stdout', '-j'] + argv = ["execution", "get", EXECUTION["id"], "-k", "localhost.stdout", "-j"] content = self._get_execution(argv) - self.assertDictEqual(json.loads(content), - jsutil.get_kvps(EXECUTION, ['result.localhost.stdout'])) + self.assertDictEqual( + json.loads(content), jsutil.get_kvps(EXECUTION, ["result.localhost.stdout"]) + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(HAS_CARRIAGE_RETURN), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(HAS_CARRIAGE_RETURN), 200, "OK", {} + ) + ), + ) def test_execution_get_detail_with_carriage_return(self): - argv = ['execution', 'get', HAS_CARRIAGE_RETURN['id'], '-d'] + argv = ["execution", "get", HAS_CARRIAGE_RETURN["id"], "-d"] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() if six.PY2: self.assertEqual( - content, FIXTURES['results']['execution_result_has_carriage_return.txt']) + content, FIXTURES["results"]["execution_result_has_carriage_return.txt"] + ) else: self.assertEqual( content, - FIXTURES['results']['execution_result_has_carriage_return_py3.txt']) + FIXTURES["results"]["execution_result_has_carriage_return_py3.txt"], + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps([EXECUTION]), 200, "OK", {}) + ), + ) def test_execution_list_attribute_provided(self): # Client shouldn't throw if "-a" flag is provided when listing executions - argv = ['execution', 'list', '-a', 'start_timestamp'] + argv = ["execution", "list", "-a", "start_timestamp"] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() self.assertEqual( - content, FIXTURES['results']['execution_list_attr_start_timestamp.txt']) + content, FIXTURES["results"]["execution_list_attr_start_timestamp.txt"] + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, "OK", {})), + ) def test_execution_list_attribute_provided_empty_response(self): # Client shouldn't throw if "-a" flag is provided, but there are no executions - argv = ['execution', 'list', '-a', 'start_timestamp'] + argv = ["execution", "list", "-a", "start_timestamp"] self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() self.assertEqual( - content, FIXTURES['results']['execution_list_empty_response_start_timestamp_attr.txt']) + content, + FIXTURES["results"][ + "execution_list_empty_response_start_timestamp_attr.txt" + ], + ) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(EXECUTION), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(EXECUTION), 200, "OK", {}) + ), + ) def _get_execution(self, argv): self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() return content @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(OUTPUT_SCHEMA), 200, 'OK', {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(OUTPUT_SCHEMA), 200, "OK", {}) + ), + ) def _get_schema_execution(self, argv): self.assertEqual(self.shell.run(argv), 0) self._undo_console_redirect() - with open(self.path, 'r') as fd: + with open(self.path, "r") as fd: content = fd.read() return content def test_SinlgeRowTable_notebox_one(self): - with mock.patch('sys.stderr', new=StringIO()) as fackety_fake: - expected = "Note: Only one action execution is displayed. Use -n/--last flag for " \ + with mock.patch("sys.stderr", new=StringIO()) as fackety_fake: + expected = ( + "Note: Only one action execution is displayed. Use -n/--last flag for " "more results." + ) print(self.table.note_box("action executions", 1)) - content = (fackety_fake.getvalue().split("|")[1].strip()) + content = fackety_fake.getvalue().split("|")[1].strip() self.assertEqual(content, expected) def test_SinlgeRowTable_notebox_zero(self): - with mock.patch('sys.stderr', new=BytesIO()) as fackety_fake: - contents = (fackety_fake.getvalue()) - self.assertEqual(contents, b'') + with mock.patch("sys.stderr", new=BytesIO()) as fackety_fake: + contents = fackety_fake.getvalue() + self.assertEqual(contents, b"") def test_SinlgeRowTable_notebox_default(self): - with mock.patch('sys.stderr', new=StringIO()) as fackety_fake: - expected = "Note: Only first 50 action executions are displayed. Use -n/--last flag " \ + with mock.patch("sys.stderr", new=StringIO()) as fackety_fake: + expected = ( + "Note: Only first 50 action executions are displayed. Use -n/--last flag " "for more results." + ) print(self.table.note_box("action executions", 50)) - content = (fackety_fake.getvalue().split("|")[1].strip()) + content = fackety_fake.getvalue().split("|")[1].strip() self.assertEqual(content, expected) - with mock.patch('sys.stderr', new=StringIO()) as fackety_fake: - expected = "Note: Only first 15 action executions are displayed. Use -n/--last flag " \ + with mock.patch("sys.stderr", new=StringIO()) as fackety_fake: + expected = ( + "Note: Only first 15 action executions are displayed. Use -n/--last flag " "for more results." + ) print(self.table.note_box("action executions", 15)) - content = (fackety_fake.getvalue().split("|")[1].strip()) + content = fackety_fake.getvalue().split("|")[1].strip() self.assertEqual(content, expected) diff --git a/st2client/tests/unit/test_inquiry.py b/st2client/tests/unit/test_inquiry.py index 138f1da8991..4132fda0d12 100644 --- a/st2client/tests/unit/test_inquiry.py +++ b/st2client/tests/unit/test_inquiry.py @@ -31,12 +31,12 @@ def _randomize_inquiry_id(inquiry): newinquiry = copy.deepcopy(inquiry) - newinquiry['id'] = str(uuid.uuid4()) + newinquiry["id"] = str(uuid.uuid4()) # ID can't have '1440' in it, otherwise our `count()` fails # when inspecting the inquiry list output for test: # test_list_inquiries_limit() - while '1440' in newinquiry['id']: - newinquiry['id'] = str(uuid.uuid4()) + while "1440" in newinquiry["id"]: + newinquiry["id"] = str(uuid.uuid4()) return newinquiry @@ -45,8 +45,7 @@ def _generate_inquiries(count): class TestInquiryBase(base.BaseCLITestCase): - """Base class for "inquiry" CLI tests - """ + """Base class for "inquiry" CLI tests""" capture_output = True @@ -54,8 +53,8 @@ def __init__(self, *args, **kwargs): super(TestInquiryBase, self).__init__(*args, **kwargs) self.parser = argparse.ArgumentParser() - self.parser.add_argument('-t', '--token', dest='token') - self.parser.add_argument('--api-key', dest='api_key') + self.parser.add_argument("-t", "--token", dest="token") + self.parser.add_argument("--api-key", dest="api_key") self.shell = shell.Shell() def setUp(self): @@ -72,14 +71,12 @@ def tearDown(self): "continue": { "type": "boolean", "description": "Would you like to continue the workflow?", - "required": True + "required": True, } }, } -RESPONSE_DEFAULT = { - "continue": True -} +RESPONSE_DEFAULT = {"continue": True} SCHEMA_MULTIPLE = { "title": "response_data", @@ -88,30 +85,24 @@ def tearDown(self): "name": { "type": "string", "description": "What is your name?", - "required": True + "required": True, }, "pin": { "type": "integer", "description": "What is your PIN?", - "required": True + "required": True, }, "paradox": { "type": "boolean", "description": "This statement is False.", - "required": True - } + "required": True, + }, }, } -RESPONSE_MULTIPLE = { - "name": "matt", - "pin": 1234, - "paradox": True -} +RESPONSE_MULTIPLE = {"name": "matt", "pin": 1234, "paradox": True} -RESPONSE_BAD = { - "foo": "bar" -} +RESPONSE_BAD = {"foo": "bar"} INQUIRY_1 = { "id": "abcdef", @@ -119,7 +110,7 @@ def tearDown(self): "roles": [], "users": [], "route": "", - "ttl": 1440 + "ttl": 1440, } INQUIRY_MULTIPLE = { @@ -128,145 +119,200 @@ def tearDown(self): "roles": [], "users": [], "route": "", - "ttl": 1440 + "ttl": 1440, } class TestInquirySubcommands(TestInquiryBase): - @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(INQUIRY_1), 200, 'OK'))) + requests, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK") + ), + ) def test_get_inquiry(self): - """Test retrieval of a single inquiry - """ - inquiry_id = 'abcdef' - args = ['inquiry', 'get', inquiry_id] + """Test retrieval of a single inquiry""" + inquiry_id = "abcdef" + args = ["inquiry", "get", inquiry_id] retcode = self.shell.run(args) self.assertEqual(retcode, 0) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 404, 'NOT FOUND'))) + requests, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps({}), 404, "NOT FOUND") + ), + ) def test_get_inquiry_not_found(self): - """Test retrieval of a inquiry that doesn't exist - """ - inquiry_id = 'asdbv' - args = ['inquiry', 'get', inquiry_id] + """Test retrieval of a inquiry that doesn't exist""" + inquiry_id = "asdbv" + args = ["inquiry", "get", inquiry_id] retcode = self.shell.run(args) - self.assertEqual('Inquiry "%s" is not found.\n\n' % inquiry_id, self.stdout.getvalue()) + self.assertEqual( + 'Inquiry "%s" is not found.\n\n' % inquiry_id, self.stdout.getvalue() + ) self.assertEqual(retcode, 2) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps([INQUIRY_1]), 200, 'OK', {'X-Total-Count': '1'} - )))) + requests, + "get", + mock.MagicMock( + return_value=( + base.FakeResponse( + json.dumps([INQUIRY_1]), 200, "OK", {"X-Total-Count": "1"} + ) + ) + ), + ) def test_list_inquiries(self): - """Test retrieval of a list of Inquiries - """ - args = ['inquiry', 'list'] + """Test retrieval of a list of Inquiries""" + args = ["inquiry", "list"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) - self.assertEqual(self.stdout.getvalue().count('1440'), 1) + self.assertEqual(self.stdout.getvalue().count("1440"), 1) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps(_generate_inquiries(50)), 200, 'OK', {'X-Total-Count': '55'} - )))) + requests, + "get", + mock.MagicMock( + return_value=( + base.FakeResponse( + json.dumps(_generate_inquiries(50)), + 200, + "OK", + {"X-Total-Count": "55"}, + ) + ) + ), + ) def test_list_inquiries_limit(self): - """Test retrieval of a list of Inquiries while using the "limit" option - """ - args = ['inquiry', 'list', '-n', '50'] + """Test retrieval of a list of Inquiries while using the "limit" option""" + args = ["inquiry", "list", "-n", "50"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) - self.assertEqual(self.stdout.getvalue().count('1440'), 50) - self.assertIn('Note: Only first 50 inquiries are displayed.', self.stderr.getvalue()) + self.assertEqual(self.stdout.getvalue().count("1440"), 50) + self.assertIn( + "Note: Only first 50 inquiries are displayed.", self.stderr.getvalue() + ) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps([]), 200, 'OK', {'X-Total-Count': '0'} - )))) + requests, + "get", + mock.MagicMock( + return_value=( + base.FakeResponse(json.dumps([]), 200, "OK", {"X-Total-Count": "0"}) + ) + ), + ) def test_list_empty_inquiries(self): - """Test empty list of Inquiries - """ - args = ['inquiry', 'list'] + """Test empty list of Inquiries""" + args = ["inquiry", "list"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps(INQUIRY_1), 200, 'OK' - )))) + requests, + "get", + mock.MagicMock( + return_value=(base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK")) + ), + ) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}), 200, 'OK' - )))) - @mock.patch('st2client.commands.inquiry.InteractiveForm') + requests, + "put", + mock.MagicMock( + return_value=( + base.FakeResponse( + json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}), + 200, + "OK", + ) + ) + ), + ) + @mock.patch("st2client.commands.inquiry.InteractiveForm") def test_respond(self, mock_form): - """Test interactive response - """ + """Test interactive response""" form_instance = mock_form.return_value form_instance.initiate_dialog.return_value = RESPONSE_DEFAULT - args = ['inquiry', 'respond', 'abcdef'] + args = ["inquiry", "respond", "abcdef"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps(INQUIRY_1), 200, 'OK' - )))) + requests, + "get", + mock.MagicMock( + return_value=(base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK")) + ), + ) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}), 200, 'OK' - )))) + requests, + "put", + mock.MagicMock( + return_value=( + base.FakeResponse( + json.dumps({"id": "abcdef", "response": RESPONSE_DEFAULT}), + 200, + "OK", + ) + ) + ), + ) def test_respond_response_flag(self): - """Test response without interactive mode - """ - args = ['inquiry', 'respond', '-r', '"%s"' % RESPONSE_DEFAULT, 'abcdef'] + """Test response without interactive mode""" + args = ["inquiry", "respond", "-r", '"%s"' % RESPONSE_DEFAULT, "abcdef"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps(INQUIRY_1), 200, 'OK' - )))) + requests, + "get", + mock.MagicMock( + return_value=(base.FakeResponse(json.dumps(INQUIRY_1), 200, "OK")) + ), + ) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps({}), 400, '400 Client Error: Bad Request' - )))) + requests, + "put", + mock.MagicMock( + return_value=( + base.FakeResponse(json.dumps({}), 400, "400 Client Error: Bad Request") + ) + ), + ) def test_respond_invalid(self): - """Test invalid response - """ - args = ['inquiry', 'respond', '-r', '"%s"' % RESPONSE_BAD, 'abcdef'] + """Test invalid response""" + args = ["inquiry", "respond", "-r", '"%s"' % RESPONSE_BAD, "abcdef"] retcode = self.shell.run(args) self.assertEqual(retcode, 1) - self.assertEqual('ERROR: 400 Client Error: Bad Request', self.stdout.getvalue().strip()) + self.assertEqual( + "ERROR: 400 Client Error: Bad Request", self.stdout.getvalue().strip() + ) def test_respond_nonexistent_inquiry(self): - """Test responding to an inquiry that doesn't exist - """ - inquiry_id = '134234' - args = ['inquiry', 'respond', '-r', '"%s"' % RESPONSE_DEFAULT, inquiry_id] + """Test responding to an inquiry that doesn't exist""" + inquiry_id = "134234" + args = ["inquiry", "respond", "-r", '"%s"' % RESPONSE_DEFAULT, inquiry_id] retcode = self.shell.run(args) self.assertEqual(retcode, 1) - self.assertEqual('ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id, - self.stdout.getvalue().strip()) + self.assertEqual( + 'ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id, + self.stdout.getvalue().strip(), + ) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=(base.FakeResponse( - json.dumps({}), 404, '404 Client Error: Not Found' - )))) - @mock.patch('st2client.commands.inquiry.InteractiveForm') + requests, + "get", + mock.MagicMock( + return_value=( + base.FakeResponse(json.dumps({}), 404, "404 Client Error: Not Found") + ) + ), + ) + @mock.patch("st2client.commands.inquiry.InteractiveForm") def test_respond_nonexistent_inquiry_interactive(self, mock_form): """Test interactively responding to an inquiry that doesn't exist @@ -274,11 +320,13 @@ def test_respond_nonexistent_inquiry_interactive(self, mock_form): responding with PUT, in order to retrieve the desired schema for this inquiry. So, we want to test that interaction separately. """ - inquiry_id = '253432' + inquiry_id = "253432" form_instance = mock_form.return_value form_instance.initiate_dialog.return_value = RESPONSE_DEFAULT - args = ['inquiry', 'respond', inquiry_id] + args = ["inquiry", "respond", inquiry_id] retcode = self.shell.run(args) self.assertEqual(retcode, 1) - self.assertEqual('ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id, - self.stdout.getvalue().strip()) + self.assertEqual( + 'ERROR: Resource with id "%s" doesn\'t exist.' % inquiry_id, + self.stdout.getvalue().strip(), + ) diff --git a/st2client/tests/unit/test_interactive.py b/st2client/tests/unit/test_interactive.py index 24f0080232e..dce4c6748dd 100644 --- a/st2client/tests/unit/test_interactive.py +++ b/st2client/tests/unit/test_interactive.py @@ -31,37 +31,32 @@ class TestInteractive(unittest2.TestCase): - def assertPromptMessage(self, prompt_mock, message, msg=None): self.assertEqual(prompt_mock.call_args[0], (message,), msg) def assertPromptDescription(self, prompt_mock, message, msg=None): - toolbar_factory = prompt_mock.call_args[1]['get_bottom_toolbar_tokens'] + toolbar_factory = prompt_mock.call_args[1]["get_bottom_toolbar_tokens"] self.assertEqual(toolbar_factory(None)[0][1], message, msg) def assertPromptValidate(self, prompt_mock, value): - validator = prompt_mock.call_args[1]['validator'] + validator = prompt_mock.call_args[1]["validator"] validator.validate(Document(text=six.text_type(value))) def assertPromptPassword(self, prompt_mock, value, msg=None): - self.assertEqual(prompt_mock.call_args[1]['is_password'], value, msg) + self.assertEqual(prompt_mock.call_args[1]["is_password"], value, msg) def test_interactive_form(self): reader = mock.MagicMock() Reader = mock.MagicMock(return_value=reader) Reader.condition = mock.MagicMock(return_value=True) - schema = { - 'string': { - 'type': 'string' - } - } + schema = {"string": {"type": "string"}} - with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]): + with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]): interactive.InteractiveForm(schema).initiate_dialog() - Reader.condition.assert_called_once_with(schema['string']) + Reader.condition.assert_called_once_with(schema["string"]) reader.read.assert_called_once_with() def test_interactive_form_no_match(self): @@ -69,35 +64,27 @@ def test_interactive_form_no_match(self): Reader = mock.MagicMock(return_value=reader) Reader.condition = mock.MagicMock(return_value=False) - schema = { - 'string': { - 'type': 'string' - } - } + schema = {"string": {"type": "string"}} - with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]): + with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]): interactive.InteractiveForm(schema).initiate_dialog() - Reader.condition.assert_called_once_with(schema['string']) + Reader.condition.assert_called_once_with(schema["string"]) reader.read.assert_not_called() - @mock.patch('sys.stdout', new_callable=StringIO) + @mock.patch("sys.stdout", new_callable=StringIO) def test_interactive_form_interrupted(self, stdout_mock): reader = mock.MagicMock() Reader = mock.MagicMock(return_value=reader) Reader.condition = mock.MagicMock(return_value=True) reader.read = mock.MagicMock(side_effect=KeyboardInterrupt) - schema = { - 'string': { - 'type': 'string' - } - } + schema = {"string": {"type": "string"}} - with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]): + with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]): interactive.InteractiveForm(schema).initiate_dialog() - self.assertEqual(stdout_mock.getvalue(), 'Dialog interrupted.\n') + self.assertEqual(stdout_mock.getvalue(), "Dialog interrupted.\n") def test_interactive_form_interrupted_reraised(self): reader = mock.MagicMock() @@ -105,285 +92,278 @@ def test_interactive_form_interrupted_reraised(self): Reader.condition = mock.MagicMock(return_value=True) reader.read = mock.MagicMock(side_effect=KeyboardInterrupt) - schema = { - 'string': { - 'type': 'string' - } - } + schema = {"string": {"type": "string"}} - with mock.patch.object(interactive.InteractiveForm, 'readers', [Reader]): - self.assertRaises(interactive.DialogInterrupted, - interactive.InteractiveForm(schema, reraise=True).initiate_dialog) + with mock.patch.object(interactive.InteractiveForm, "readers", [Reader]): + self.assertRaises( + interactive.DialogInterrupted, + interactive.InteractiveForm(schema, reraise=True).initiate_dialog, + ) - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_stringreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': 'hey' - } - Reader = interactive.StringReader('some', spec) + spec = {"description": "some description", "default": "hey"} + Reader = interactive.StringReader("some", spec) - prompt_mock.return_value = 'stuff' + prompt_mock.return_value = "stuff" result = Reader.read() - self.assertEqual(result, 'stuff') - self.assertPromptMessage(prompt_mock, 'some [hey]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, 'stuff') + self.assertEqual(result, "stuff") + self.assertPromptMessage(prompt_mock, "some [hey]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "stuff") - prompt_mock.return_value = '' + prompt_mock.return_value = "" result = Reader.read() - self.assertEqual(result, 'hey') - self.assertPromptValidate(prompt_mock, '') + self.assertEqual(result, "hey") + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_booleanreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': False - } - Reader = interactive.BooleanReader('some', spec) + spec = {"description": "some description", "default": False} + Reader = interactive.BooleanReader("some", spec) - prompt_mock.return_value = 'y' + prompt_mock.return_value = "y" result = Reader.read() self.assertEqual(result, True) - self.assertPromptMessage(prompt_mock, 'some (boolean) [n]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, 'y') - self.assertRaises(prompt_toolkit.validation.ValidationError, - self.assertPromptValidate, prompt_mock, 'some') - - prompt_mock.return_value = '' + self.assertPromptMessage(prompt_mock, "some (boolean) [n]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "y") + self.assertRaises( + prompt_toolkit.validation.ValidationError, + self.assertPromptValidate, + prompt_mock, + "some", + ) + + prompt_mock.return_value = "" result = Reader.read() self.assertEqual(result, False) - self.assertPromptValidate(prompt_mock, '') + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_numberreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': 3.2 - } - Reader = interactive.NumberReader('some', spec) + spec = {"description": "some description", "default": 3.2} + Reader = interactive.NumberReader("some", spec) - prompt_mock.return_value = '5.3' + prompt_mock.return_value = "5.3" result = Reader.read() self.assertEqual(result, 5.3) - self.assertPromptMessage(prompt_mock, 'some (float) [3.2]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, '5.3') - self.assertRaises(prompt_toolkit.validation.ValidationError, - self.assertPromptValidate, prompt_mock, 'some') - - prompt_mock.return_value = '' + self.assertPromptMessage(prompt_mock, "some (float) [3.2]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "5.3") + self.assertRaises( + prompt_toolkit.validation.ValidationError, + self.assertPromptValidate, + prompt_mock, + "some", + ) + + prompt_mock.return_value = "" result = Reader.read() self.assertEqual(result, 3.2) - self.assertPromptValidate(prompt_mock, '') + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_integerreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': 3 - } - Reader = interactive.IntegerReader('some', spec) + spec = {"description": "some description", "default": 3} + Reader = interactive.IntegerReader("some", spec) - prompt_mock.return_value = '5' + prompt_mock.return_value = "5" result = Reader.read() self.assertEqual(result, 5) - self.assertPromptMessage(prompt_mock, 'some (integer) [3]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, '5') - self.assertRaises(prompt_toolkit.validation.ValidationError, - self.assertPromptValidate, prompt_mock, '5.3') - - prompt_mock.return_value = '' + self.assertPromptMessage(prompt_mock, "some (integer) [3]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "5") + self.assertRaises( + prompt_toolkit.validation.ValidationError, + self.assertPromptValidate, + prompt_mock, + "5.3", + ) + + prompt_mock.return_value = "" result = Reader.read() self.assertEqual(result, 3) - self.assertPromptValidate(prompt_mock, '') + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_secretstringreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': 'hey' - } - Reader = interactive.SecretStringReader('some', spec) + spec = {"description": "some description", "default": "hey"} + Reader = interactive.SecretStringReader("some", spec) - prompt_mock.return_value = 'stuff' + prompt_mock.return_value = "stuff" result = Reader.read() - self.assertEqual(result, 'stuff') - self.assertPromptMessage(prompt_mock, 'some (secret) [hey]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, 'stuff') + self.assertEqual(result, "stuff") + self.assertPromptMessage(prompt_mock, "some (secret) [hey]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "stuff") self.assertPromptPassword(prompt_mock, True) - prompt_mock.return_value = '' + prompt_mock.return_value = "" result = Reader.read() - self.assertEqual(result, 'hey') - self.assertPromptValidate(prompt_mock, '') + self.assertEqual(result, "hey") + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_enumreader(self, prompt_mock): spec = { - 'enum': ['some', 'thing', 'else'], - 'description': 'some description', - 'default': 'thing' + "enum": ["some", "thing", "else"], + "description": "some description", + "default": "thing", } - Reader = interactive.EnumReader('some', spec) + Reader = interactive.EnumReader("some", spec) - prompt_mock.return_value = '2' + prompt_mock.return_value = "2" result = Reader.read() - self.assertEqual(result, 'else') - message = 'some: \n 0 - some\n 1 - thing\n 2 - else\nChoose from 0, 1, 2 [1]: ' + self.assertEqual(result, "else") + message = "some: \n 0 - some\n 1 - thing\n 2 - else\nChoose from 0, 1, 2 [1]: " self.assertPromptMessage(prompt_mock, message) - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, '0') - self.assertRaises(prompt_toolkit.validation.ValidationError, - self.assertPromptValidate, prompt_mock, 'some') - self.assertRaises(prompt_toolkit.validation.ValidationError, - self.assertPromptValidate, prompt_mock, '5') - - prompt_mock.return_value = '' + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "0") + self.assertRaises( + prompt_toolkit.validation.ValidationError, + self.assertPromptValidate, + prompt_mock, + "some", + ) + self.assertRaises( + prompt_toolkit.validation.ValidationError, + self.assertPromptValidate, + prompt_mock, + "5", + ) + + prompt_mock.return_value = "" result = Reader.read() - self.assertEqual(result, 'thing') - self.assertPromptValidate(prompt_mock, '') + self.assertEqual(result, "thing") + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_arrayreader(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': ['a', 'b'] - } - Reader = interactive.ArrayReader('some', spec) + spec = {"description": "some description", "default": ["a", "b"]} + Reader = interactive.ArrayReader("some", spec) - prompt_mock.return_value = 'some,thing,else' + prompt_mock.return_value = "some,thing,else" result = Reader.read() - self.assertEqual(result, ['some', 'thing', 'else']) - self.assertPromptMessage(prompt_mock, 'some (comma-separated list) [a,b]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, 'some,thing,else') + self.assertEqual(result, ["some", "thing", "else"]) + self.assertPromptMessage(prompt_mock, "some (comma-separated list) [a,b]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "some,thing,else") - prompt_mock.return_value = '' + prompt_mock.return_value = "" result = Reader.read() - self.assertEqual(result, ['a', 'b']) - self.assertPromptValidate(prompt_mock, '') + self.assertEqual(result, ["a", "b"]) + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_arrayreader_ends_with_comma(self, prompt_mock): - spec = { - 'description': 'some description', - 'default': ['a', 'b'] - } - Reader = interactive.ArrayReader('some', spec) + spec = {"description": "some description", "default": ["a", "b"]} + Reader = interactive.ArrayReader("some", spec) - prompt_mock.return_value = 'some,thing,else,' + prompt_mock.return_value = "some,thing,else," result = Reader.read() - self.assertEqual(result, ['some', 'thing', 'else', '']) - self.assertPromptMessage(prompt_mock, 'some (comma-separated list) [a,b]: ') - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, 'some,thing,else,') + self.assertEqual(result, ["some", "thing", "else", ""]) + self.assertPromptMessage(prompt_mock, "some (comma-separated list) [a,b]: ") + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "some,thing,else,") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_arrayenumreader(self, prompt_mock): spec = { - 'items': { - 'enum': ['a', 'b', 'c', 'd', 'e'] - }, - 'description': 'some description', - 'default': ['a', 'b'] + "items": {"enum": ["a", "b", "c", "d", "e"]}, + "description": "some description", + "default": ["a", "b"], } - Reader = interactive.ArrayEnumReader('some', spec) + Reader = interactive.ArrayEnumReader("some", spec) - prompt_mock.return_value = '0,2,4' + prompt_mock.return_value = "0,2,4" result = Reader.read() - self.assertEqual(result, ['a', 'c', 'e']) - message = 'some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: ' + self.assertEqual(result, ["a", "c", "e"]) + message = "some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: " self.assertPromptMessage(prompt_mock, message) - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, '0,2,4') + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "0,2,4") - prompt_mock.return_value = '' + prompt_mock.return_value = "" result = Reader.read() - self.assertEqual(result, ['a', 'b']) - self.assertPromptValidate(prompt_mock, '') + self.assertEqual(result, ["a", "b"]) + self.assertPromptValidate(prompt_mock, "") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_arrayenumreader_ends_with_comma(self, prompt_mock): spec = { - 'items': { - 'enum': ['a', 'b', 'c', 'd', 'e'] - }, - 'description': 'some description', - 'default': ['a', 'b'] + "items": {"enum": ["a", "b", "c", "d", "e"]}, + "description": "some description", + "default": ["a", "b"], } - Reader = interactive.ArrayEnumReader('some', spec) + Reader = interactive.ArrayEnumReader("some", spec) - prompt_mock.return_value = '0,2,4,' + prompt_mock.return_value = "0,2,4," result = Reader.read() - self.assertEqual(result, ['a', 'c', 'e']) - message = 'some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: ' + self.assertEqual(result, ["a", "c", "e"]) + message = "some: \n 0 - a\n 1 - b\n 2 - c\n 3 - d\n 4 - e\nChoose from 0, 1, 2... [0, 1]: " self.assertPromptMessage(prompt_mock, message) - self.assertPromptDescription(prompt_mock, 'some description') - self.assertPromptValidate(prompt_mock, '0,2,4,') + self.assertPromptDescription(prompt_mock, "some description") + self.assertPromptValidate(prompt_mock, "0,2,4,") - @mock.patch.object(interactive, 'prompt') + @mock.patch.object(interactive, "prompt") def test_arrayobjectreader(self, prompt_mock): spec = { - 'items': { - 'type': 'object', - 'properties': { - 'foo': { - 'type': 'string', - 'description': 'some description', + "items": { + "type": "object", + "properties": { + "foo": { + "type": "string", + "description": "some description", + }, + "bar": { + "type": "string", + "description": "some description", }, - 'bar': { - 'type': 'string', - 'description': 'some description', - } - } + }, }, - 'description': 'some description', + "description": "some description", } - Reader = interactive.ArrayObjectReader('some', spec) + Reader = interactive.ArrayObjectReader("some", spec) # To emulate continuing setting, this flag variable is needed self.is_continued = False def side_effect(msg, **kwargs): - if re.match(r'^~~~ Would you like to add another item to.*', msg): + if re.match(r"^~~~ Would you like to add another item to.*", msg): # prompt requires the input to judge continuing setting, or not if not self.is_continued: # continuing the configuration only once self.is_continued = True - return '' + return "" else: # finishing to configuration - return 'n' + return "n" else: # prompt requires the input of property value in the object - return 'value' + return "value" prompt_mock.side_effect = side_effect results = Reader.read() self.assertEqual(len(results), 2) self.assertTrue(all([len(list(x.keys())) == 2 for x in results])) - self.assertTrue(all(['foo' in x and 'bar' in x for x in results])) + self.assertTrue(all(["foo" in x and "bar" in x for x in results])) diff --git a/st2client/tests/unit/test_keyvalue.py b/st2client/tests/unit/test_keyvalue.py index bb5bf09d600..52c240a052c 100644 --- a/st2client/tests/unit/test_keyvalue.py +++ b/st2client/tests/unit/test_keyvalue.py @@ -29,77 +29,70 @@ LOG = logging.getLogger(__name__) KEYVALUE = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'super cool value', - 'scope': 'system' + "id": "kv_name", + "name": "kv_name.", + "value": "super cool value", + "scope": "system", } KEYVALUE_USER = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'super cool value', - 'scope': 'system', - 'user': 'stanley' + "id": "kv_name", + "name": "kv_name.", + "value": "super cool value", + "scope": "system", + "user": "stanley", } KEYVALUE_SECRET = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'super cool value', - 'scope': 'system', - 'secret': True + "id": "kv_name", + "name": "kv_name.", + "value": "super cool value", + "scope": "system", + "secret": True, } KEYVALUE_PRE_ENCRYPTED = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'AAABBBCCC1234', - 'scope': 'system', - 'encrypted': True, - 'secret': True + "id": "kv_name", + "name": "kv_name.", + "value": "AAABBBCCC1234", + "scope": "system", + "encrypted": True, + "secret": True, } KEYVALUE_TTL = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'super cool value', - 'scope': 'system', - 'ttl': 100 + "id": "kv_name", + "name": "kv_name.", + "value": "super cool value", + "scope": "system", + "ttl": 100, } KEYVALUE_OBJECT = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': {'obj': [1, True, 23.4, 'abc']}, - 'scope': 'system', + "id": "kv_name", + "name": "kv_name.", + "value": {"obj": [1, True, 23.4, "abc"]}, + "scope": "system", } KEYVALUE_ALL = { - 'id': 'kv_name', - 'name': 'kv_name.', - 'value': 'AAAAABBBBBCCCCCCDDDDD11122345', - 'scope': 'system', - 'user': 'stanley', - 'secret': True, - 'encrypted': True, - 'ttl': 100 + "id": "kv_name", + "name": "kv_name.", + "value": "AAAAABBBBBCCCCCCDDDDD11122345", + "scope": "system", + "user": "stanley", + "secret": True, + "encrypted": True, + "ttl": 100, } -KEYVALUE_MISSING_NAME = { - 'id': 'kv_name', - 'value': 'super cool value' -} +KEYVALUE_MISSING_NAME = {"id": "kv_name", "value": "super cool value"} -KEYVALUE_MISSING_VALUE = { - 'id': 'kv_name', - 'name': 'kv_name.' -} +KEYVALUE_MISSING_VALUE = {"id": "kv_name", "name": "kv_name."} class TestKeyValueBase(base.BaseCLITestCase): - """Base class for "key" CLI tests - """ + """Base class for "key" CLI tests""" capture_output = True @@ -107,8 +100,8 @@ def __init__(self, *args, **kwargs): super(TestKeyValueBase, self).__init__(*args, **kwargs) self.parser = argparse.ArgumentParser() - self.parser.add_argument('-t', '--token', dest='token') - self.parser.add_argument('--api-key', dest='api_key') + self.parser.add_argument("-t", "--token", dest="token") + self.parser.add_argument("--api-key", dest="api_key") self.shell = shell.Shell() def setUp(self): @@ -119,44 +112,49 @@ def tearDown(self): class TestKeyValueSet(TestKeyValueBase): - @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_PRE_ENCRYPTED), 200, - 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(KEYVALUE_PRE_ENCRYPTED), 200, "OK" + ) + ), + ) def test_set_keyvalue(self): - """Test setting key/value pair with optional pre_encrypted field - """ - args = ['key', 'set', '--encrypted', 'kv_name', 'AAABBBCCC1234'] + """Test setting key/value pair with optional pre_encrypted field""" + args = ["key", "set", "--encrypted", "kv_name", "AAABBBCCC1234"] retcode = self.shell.run(args) self.assertEqual(retcode, 0) def test_encrypt_and_encrypted_flags_are_mutually_exclusive(self): - args = ['key', 'set', '--encrypt', '--encrypted', 'kv_name', 'AAABBBCCC1234'] + args = ["key", "set", "--encrypt", "--encrypted", "kv_name", "AAABBBCCC1234"] - self.assertRaisesRegexp(SystemExit, '2', self.shell.run, args) + self.assertRaisesRegexp(SystemExit, "2", self.shell.run, args) self.stderr.seek(0) stderr = self.stderr.read() - expected_msg = ('error: argument --encrypted: not allowed with argument -e/--encrypt') + expected_msg = ( + "error: argument --encrypted: not allowed with argument -e/--encrypt" + ) self.assertIn(expected_msg, stderr) class TestKeyValueLoad(TestKeyValueBase): - @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, 'OK'))) + requests, + "put", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, "OK")), + ) def test_load_keyvalue_json(self): - """Test loading of key/value pair in JSON format - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair in JSON format""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -164,17 +162,18 @@ def test_load_keyvalue_json(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, 'OK'))) + requests, + "put", + mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE), 200, "OK")), + ) def test_load_keyvalue_yaml(self): - """Test loading of key/value pair in YAML format - """ - fd, path = tempfile.mkstemp(suffix='.yaml') + """Test loading of key/value pair in YAML format""" + fd, path = tempfile.mkstemp(suffix=".yaml") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(yaml.safe_dump(KEYVALUE)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -182,17 +181,20 @@ def test_load_keyvalue_yaml(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_USER), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_USER), 200, "OK") + ), + ) def test_load_keyvalue_user(self): - """Test loading of key/value pair with the optional user field - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair with the optional user field""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_USER, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -200,17 +202,20 @@ def test_load_keyvalue_user(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_SECRET), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_SECRET), 200, "OK") + ), + ) def test_load_keyvalue_secret(self): - """Test loading of key/value pair with the optional secret field - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair with the optional secret field""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_SECRET, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -218,18 +223,22 @@ def test_load_keyvalue_secret(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_PRE_ENCRYPTED), 200, - 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps(KEYVALUE_PRE_ENCRYPTED), 200, "OK" + ) + ), + ) def test_load_keyvalue_already_encrypted(self): - """Test loading of key/value pair with the pre-encrypted value - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair with the pre-encrypted value""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_PRE_ENCRYPTED, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -237,17 +246,20 @@ def test_load_keyvalue_already_encrypted(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_TTL), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_TTL), 200, "OK") + ), + ) def test_load_keyvalue_ttl(self): - """Test loading of key/value pair with the optional ttl field - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair with the optional ttl field""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_TTL, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -255,23 +267,26 @@ def test_load_keyvalue_ttl(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, "OK") + ), + ) def test_load_keyvalue_object(self): - """Test loading of key/value pair where the value is an object - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair where the value is an object""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_OBJECT, indent=4)) # test converting with short option - args = ['key', 'load', '-c', path] + args = ["key", "load", "-c", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) # test converting with long option - args = ['key', 'load', '--convert', path] + args = ["key", "load", "--convert", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -279,19 +294,23 @@ def test_load_keyvalue_object(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_OBJECT), 200, "OK") + ), + ) def test_load_keyvalue_object_fail(self): """Test failure to load key/value pair where the value is an object - and the -c/--convert option is not passed + and the -c/--convert option is not passed """ - fd, path = tempfile.mkstemp(suffix='.json') + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_OBJECT, indent=4)) # test converting with short option - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertNotEqual(retcode, 0) finally: @@ -299,17 +318,20 @@ def test_load_keyvalue_object_fail(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL), 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL), 200, "OK") + ), + ) def test_load_keyvalue_all(self): - """Test loading of key/value pair with all optional fields - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of key/value pair with all optional fields""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_ALL, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -317,21 +339,23 @@ def test_load_keyvalue_all(self): os.unlink(path) @mock.patch.object( - requests, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL), - 200, 'OK'))) + requests, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(KEYVALUE_ALL), 200, "OK") + ), + ) def test_load_keyvalue_array(self): - """Test loading an array of key/value pairs - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading an array of key/value pairs""" + fd, path = tempfile.mkstemp(suffix=".json") try: array = [KEYVALUE, KEYVALUE_ALL] json_str = json.dumps(array, indent=4) LOG.info(json_str) - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json_str) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 0) finally: @@ -339,14 +363,13 @@ def test_load_keyvalue_array(self): os.unlink(path) def test_load_keyvalue_missing_name(self): - """Test loading of a key/value pair with the required field 'name' missing - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of a key/value pair with the required field 'name' missing""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_MISSING_NAME, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 1) finally: @@ -354,14 +377,13 @@ def test_load_keyvalue_missing_name(self): os.unlink(path) def test_load_keyvalue_missing_value(self): - """Test loading of a key/value pair with the required field 'value' missing - """ - fd, path = tempfile.mkstemp(suffix='.json') + """Test loading of a key/value pair with the required field 'value' missing""" + fd, path = tempfile.mkstemp(suffix=".json") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(json.dumps(KEYVALUE_MISSING_VALUE, indent=4)) - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 1) finally: @@ -369,19 +391,17 @@ def test_load_keyvalue_missing_value(self): os.unlink(path) def test_load_keyvalue_missing_file(self): - """Test loading of a key/value pair with a missing file - """ - path = '/some/file/that/doesnt/exist.json' - args = ['key', 'load', path] + """Test loading of a key/value pair with a missing file""" + path = "/some/file/that/doesnt/exist.json" + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 1) def test_load_keyvalue_bad_file_extension(self): - """Test loading of a key/value pair with a bad file extension - """ - fd, path = tempfile.mkstemp(suffix='.badext') + """Test loading of a key/value pair with a bad file extension""" + fd, path = tempfile.mkstemp(suffix=".badext") try: - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) self.assertEqual(retcode, 1) finally: @@ -392,11 +412,11 @@ def test_load_keyvalue_empty_file(self): """ Loading K/V from an empty file shouldn't throw an error """ - fd, path = tempfile.mkstemp(suffix='.yaml') + fd, path = tempfile.mkstemp(suffix=".yaml") try: - args = ['key', 'load', path] + args = ["key", "load", path] retcode = self.shell.run(args) - self.assertIn('No matching items found', self.stdout.getvalue()) + self.assertIn("No matching items found", self.stdout.getvalue()) self.assertEqual(retcode, 0) finally: os.close(fd) diff --git a/st2client/tests/unit/test_models.py b/st2client/tests/unit/test_models.py index dd7f35d6b8d..8a137afa139 100644 --- a/st2client/tests/unit/test_models.py +++ b/st2client/tests/unit/test_models.py @@ -29,22 +29,24 @@ class TestSerialization(unittest2.TestCase): - def test_resource_serialize(self): - instance = base.FakeResource(id='123', name='abc') + instance = base.FakeResource(id="123", name="abc") self.assertDictEqual(instance.serialize(), base.RESOURCES[0]) def test_resource_deserialize(self): instance = base.FakeResource.deserialize(base.RESOURCES[0]) - self.assertEqual(instance.id, '123') - self.assertEqual(instance.name, 'abc') + self.assertEqual(instance.id, "123") + self.assertEqual(instance.name, "abc") class TestResourceManager(unittest2.TestCase): - @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) def test_resource_get_all(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) resources = mgr.get_all() @@ -53,8 +55,12 @@ def test_resource_get_all(self): self.assertListEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) def test_resource_get_all_with_limit(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) resources = mgr.get_all(limit=50) @@ -63,135 +69,197 @@ def test_resource_get_all_with_limit(self): self.assertListEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_get_all_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) self.assertRaises(Exception, mgr.get_all) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_resource_get_by_id(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resource = mgr.get_by_id('123') + resource = mgr.get_by_id("123") actual = resource.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND'))) + httpclient.HTTPClient, + "get", + mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")), + ) def test_resource_get_by_id_404(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resource = mgr.get_by_id('123') + resource = mgr.get_by_id("123") self.assertIsNone(resource) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_get_by_id_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) self.assertRaises(Exception, mgr.get_by_id) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) def test_resource_query(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resources = mgr.query(name='abc') + resources = mgr.query(name="abc") actual = [resource.serialize() for resource in resources] expected = json.loads(json.dumps([base.RESOURCES[0]])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {'X-Total-Count': '50'}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {"X-Total-Count": "50"} + ) + ), + ) def test_resource_query_with_count(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resources, count = mgr.query_with_count(name='abc') + resources, count = mgr.query_with_count(name="abc") actual = [resource.serialize() for resource in resources] expected = json.loads(json.dumps([base.RESOURCES[0]])) self.assertEqual(actual, expected) self.assertEqual(count, 50) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) def test_resource_query_with_limit(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resources = mgr.query(name='abc', limit=50) + resources = mgr.query(name="abc", limit=50) actual = [resource.serialize() for resource in resources] expected = json.loads(json.dumps([base.RESOURCES[0]])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND', - {'X-Total-Count': '30'}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + "", 404, "NOT FOUND", {"X-Total-Count": "30"} + ) + ), + ) def test_resource_query_404(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) # No X-Total-Count - resources = mgr.query(name='abc') + resources = mgr.query(name="abc") self.assertListEqual(resources, []) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND', - {'X-Total-Count': '30'}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + "", 404, "NOT FOUND", {"X-Total-Count": "30"} + ) + ), + ) def test_resource_query_with_count_404(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resources, count = mgr.query_with_count(name='abc') + resources, count = mgr.query_with_count(name="abc") self.assertListEqual(resources, []) self.assertIsNone(count) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_query_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - self.assertRaises(Exception, mgr.query, name='abc') + self.assertRaises(Exception, mgr.query, name="abc") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) def test_resource_get_by_name(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) # No X-Total-Count - resource = mgr.get_by_name('abc') + resource = mgr.get_by_name("abc") actual = resource.serialize() expected = json.loads(json.dumps(base.RESOURCES[0])) self.assertEqual(actual, expected) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND'))) + httpclient.HTTPClient, + "get", + mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")), + ) def test_resource_get_by_name_404(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - resource = mgr.get_by_name('abc') + resource = mgr.get_by_name("abc") self.assertIsNone(resource) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) def test_resource_get_by_name_ambiguous(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - self.assertRaises(Exception, mgr.get_by_name, 'abc') + self.assertRaises(Exception, mgr.get_by_name, "abc") @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_get_by_name_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) self.assertRaises(Exception, mgr.get_by_name) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_resource_create(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) instance = base.FakeResource.deserialize('{"name": "abc"}') @@ -199,16 +267,24 @@ def test_resource_create(self): self.assertIsNotNone(resource) @mock.patch.object( - httpclient.HTTPClient, 'post', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "post", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_create_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) instance = base.FakeResource.deserialize('{"name": "abc"}') self.assertRaises(Exception, mgr.create, instance) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, 'OK'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES[0]), 200, "OK") + ), + ) def test_resource_update(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) text = '{"id": "123", "name": "cba"}' @@ -217,8 +293,12 @@ def test_resource_update(self): self.assertIsNotNone(resource) @mock.patch.object( - httpclient.HTTPClient, 'put', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "put", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_update_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) text = '{"id": "123", "name": "cba"}' @@ -226,39 +306,57 @@ def test_resource_update_failed(self): self.assertRaises(Exception, mgr.update, instance) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 204, 'NO CONTENT'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock(return_value=base.FakeResponse("", 204, "NO CONTENT")), + ) def test_resource_delete(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - instance = mgr.get_by_name('abc') + instance = mgr.get_by_name("abc") mgr.delete(instance) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 404, 'NOT FOUND'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock(return_value=base.FakeResponse("", 404, "NOT FOUND")), + ) def test_resource_delete_404(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) instance = base.FakeResource.deserialize(base.RESOURCES[0]) mgr.delete(instance) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([base.RESOURCES[0]]), 200, 'OK', - {}))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse( + json.dumps([base.RESOURCES[0]]), 200, "OK", {} + ) + ), + ) @mock.patch.object( - httpclient.HTTPClient, 'delete', - mock.MagicMock(return_value=base.FakeResponse('', 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "delete", + mock.MagicMock( + return_value=base.FakeResponse("", 500, "INTERNAL SERVER ERROR") + ), + ) def test_resource_delete_failed(self): mgr = models.ResourceManager(base.FakeResource, base.FAKE_ENDPOINT) - instance = mgr.get_by_name('abc') + instance = mgr.get_by_name("abc") self.assertRaises(Exception, mgr.delete, instance) - @mock.patch('requests.get') - @mock.patch('sseclient.SSEClient') + @mock.patch("requests.get") + @mock.patch("sseclient.SSEClient") def test_stream_resource_listen(self, mock_sseclient, mock_requests): mock_msg = mock.Mock() mock_msg.data = json.dumps(base.RESOURCES) @@ -267,14 +365,16 @@ def test_stream_resource_listen(self, mock_sseclient, mock_requests): def side_effect_checking_verify_parameter_is(): return [mock_msg] - mock_sseclient.return_value.events.side_effect = side_effect_checking_verify_parameter_is - mgr = models.StreamManager('https://example.com', cacert='/path/ca.crt') + mock_sseclient.return_value.events.side_effect = ( + side_effect_checking_verify_parameter_is + ) + mgr = models.StreamManager("https://example.com", cacert="/path/ca.crt") - resp = mgr.listen(events=['foo', 'bar']) + resp = mgr.listen(events=["foo", "bar"]) self.assertEqual(list(resp), [base.RESOURCES]) - call_args = tuple(['https://example.com/stream?events=foo%2Cbar']) - call_kwargs = {'stream': True, 'verify': '/path/ca.crt'} + call_args = tuple(["https://example.com/stream?events=foo%2Cbar"]) + call_kwargs = {"stream": True, "verify": "/path/ca.crt"} self.assertEqual(mock_requests.call_args_list[0][0], call_args) self.assertEqual(mock_requests.call_args_list[0][1], call_kwargs) @@ -283,15 +383,16 @@ def side_effect_checking_verify_parameter_is(): def side_effect_checking_verify_parameter_is_not(): return [mock_msg] - mock_sseclient.return_value.events.side_effect = \ + mock_sseclient.return_value.events.side_effect = ( side_effect_checking_verify_parameter_is_not - mgr = models.StreamManager('https://example.com') + ) + mgr = models.StreamManager("https://example.com") resp = mgr.listen() self.assertEqual(list(resp), [base.RESOURCES]) - call_args = tuple(['https://example.com/stream?']) - call_kwargs = {'stream': True} + call_args = tuple(["https://example.com/stream?"]) + call_kwargs = {"stream": True} self.assertEqual(mock_requests.call_args_list[1][0], call_args) self.assertEqual(mock_requests.call_args_list[1][1], call_kwargs) diff --git a/st2client/tests/unit/test_shell.py b/st2client/tests/unit/test_shell.py index 83835266154..bce176b4adc 100644 --- a/st2client/tests/unit/test_shell.py +++ b/st2client/tests/unit/test_shell.py @@ -38,8 +38,8 @@ LOG = logging.getLogger(__name__) BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, '../fixtures/st2rc.full.ini') -CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, '../fixtures/st2rc.partial.ini') +CONFIG_FILE_PATH_FULL = os.path.join(BASE_DIR, "../fixtures/st2rc.full.ini") +CONFIG_FILE_PATH_PARTIAL = os.path.join(BASE_DIR, "../fixtures/st2rc.partial.ini") MOCK_CONFIG = """ [credentials] @@ -77,352 +77,383 @@ def test_commands_usage_and_help_strings(self): self.stderr.seek(0) stderr = self.stderr.read() - self.assertIn('Usage: ', stderr) - self.assertIn('For example:', stderr) - self.assertIn('CLI for StackStorm', stderr) - self.assertIn('positional arguments:', stderr) + self.assertIn("Usage: ", stderr) + self.assertIn("For example:", stderr) + self.assertIn("CLI for StackStorm", stderr) + self.assertIn("positional arguments:", stderr) self.stdout.truncate() self.stderr.truncate() # --help should result in the same output try: - self.assertEqual(self.shell.run(['--help']), 0) + self.assertEqual(self.shell.run(["--help"]), 0) except SystemExit as e: self.assertEqual(e.code, 0) self.stdout.seek(0) stdout = self.stdout.read() - self.assertIn('Usage: ', stdout) - self.assertIn('For example:', stdout) - self.assertIn('CLI for StackStorm', stdout) - self.assertIn('positional arguments:', stdout) + self.assertIn("Usage: ", stdout) + self.assertIn("For example:", stdout) + self.assertIn("CLI for StackStorm", stdout) + self.assertIn("positional arguments:", stdout) self.stdout.truncate() self.stderr.truncate() # Sub command with no args try: - self.assertEqual(self.shell.run(['action']), 2) + self.assertEqual(self.shell.run(["action"]), 2) except SystemExit as e: self.assertEqual(e.code, 2) self.stderr.seek(0) stderr = self.stderr.read() - self.assertIn('usage', stderr) + self.assertIn("usage", stderr) if six.PY2: - self.assertIn('{list,get,create,update', stderr) - self.assertIn('error: too few arguments', stderr) + self.assertIn("{list,get,create,update", stderr) + self.assertIn("error: too few arguments", stderr) def test_endpoints_default(self): - base_url = 'http://127.0.0.1' - auth_url = 'http://127.0.0.1:9100' - api_url = 'http://127.0.0.1:9101/v1' - stream_url = 'http://127.0.0.1:9102/v1' - args = ['trigger', 'list'] + base_url = "http://127.0.0.1" + auth_url = "http://127.0.0.1:9100" + api_url = "http://127.0.0.1:9101/v1" + stream_url = "http://127.0.0.1:9102/v1" + args = ["trigger", "list"] parsed_args = self.shell.parser.parse_args(args) client = self.shell.get_client(parsed_args) - self.assertEqual(client.endpoints['base'], base_url) - self.assertEqual(client.endpoints['auth'], auth_url) - self.assertEqual(client.endpoints['api'], api_url) - self.assertEqual(client.endpoints['stream'], stream_url) + self.assertEqual(client.endpoints["base"], base_url) + self.assertEqual(client.endpoints["auth"], auth_url) + self.assertEqual(client.endpoints["api"], api_url) + self.assertEqual(client.endpoints["stream"], stream_url) def test_endpoints_base_url_from_cli(self): - base_url = 'http://www.st2.com' - auth_url = 'http://www.st2.com:9100' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' - args = ['--url', base_url, 'trigger', 'list'] + base_url = "http://www.st2.com" + auth_url = "http://www.st2.com:9100" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" + args = ["--url", base_url, "trigger", "list"] parsed_args = self.shell.parser.parse_args(args) client = self.shell.get_client(parsed_args) - self.assertEqual(client.endpoints['base'], base_url) - self.assertEqual(client.endpoints['auth'], auth_url) - self.assertEqual(client.endpoints['api'], api_url) - self.assertEqual(client.endpoints['stream'], stream_url) + self.assertEqual(client.endpoints["base"], base_url) + self.assertEqual(client.endpoints["auth"], auth_url) + self.assertEqual(client.endpoints["api"], api_url) + self.assertEqual(client.endpoints["stream"], stream_url) def test_endpoints_base_url_from_env(self): - base_url = 'http://www.st2.com' - auth_url = 'http://www.st2.com:9100' - api_url = 'http://www.st2.com:9101/v1' - stream_url = 'http://www.st2.com:9102/v1' - os.environ['ST2_BASE_URL'] = base_url - args = ['trigger', 'list'] + base_url = "http://www.st2.com" + auth_url = "http://www.st2.com:9100" + api_url = "http://www.st2.com:9101/v1" + stream_url = "http://www.st2.com:9102/v1" + os.environ["ST2_BASE_URL"] = base_url + args = ["trigger", "list"] parsed_args = self.shell.parser.parse_args(args) client = self.shell.get_client(parsed_args) - self.assertEqual(client.endpoints['base'], base_url) - self.assertEqual(client.endpoints['auth'], auth_url) - self.assertEqual(client.endpoints['api'], api_url) - self.assertEqual(client.endpoints['stream'], stream_url) + self.assertEqual(client.endpoints["base"], base_url) + self.assertEqual(client.endpoints["auth"], auth_url) + self.assertEqual(client.endpoints["api"], api_url) + self.assertEqual(client.endpoints["stream"], stream_url) def test_endpoints_override_from_cli(self): - base_url = 'http://www.st2.com' - auth_url = 'http://www.st2.com:8888' - api_url = 'http://www.stackstorm1.com:9101/v1' - stream_url = 'http://www.stackstorm1.com:9102/v1' - args = ['--url', base_url, - '--auth-url', auth_url, - '--api-url', api_url, - '--stream-url', stream_url, - 'trigger', 'list'] + base_url = "http://www.st2.com" + auth_url = "http://www.st2.com:8888" + api_url = "http://www.stackstorm1.com:9101/v1" + stream_url = "http://www.stackstorm1.com:9102/v1" + args = [ + "--url", + base_url, + "--auth-url", + auth_url, + "--api-url", + api_url, + "--stream-url", + stream_url, + "trigger", + "list", + ] parsed_args = self.shell.parser.parse_args(args) client = self.shell.get_client(parsed_args) - self.assertEqual(client.endpoints['base'], base_url) - self.assertEqual(client.endpoints['auth'], auth_url) - self.assertEqual(client.endpoints['api'], api_url) - self.assertEqual(client.endpoints['stream'], stream_url) + self.assertEqual(client.endpoints["base"], base_url) + self.assertEqual(client.endpoints["auth"], auth_url) + self.assertEqual(client.endpoints["api"], api_url) + self.assertEqual(client.endpoints["stream"], stream_url) def test_endpoints_override_from_env(self): - base_url = 'http://www.st2.com' - auth_url = 'http://www.st2.com:8888' - api_url = 'http://www.stackstorm1.com:9101/v1' - stream_url = 'http://www.stackstorm1.com:9102/v1' - os.environ['ST2_BASE_URL'] = base_url - os.environ['ST2_AUTH_URL'] = auth_url - os.environ['ST2_API_URL'] = api_url - os.environ['ST2_STREAM_URL'] = stream_url - args = ['trigger', 'list'] + base_url = "http://www.st2.com" + auth_url = "http://www.st2.com:8888" + api_url = "http://www.stackstorm1.com:9101/v1" + stream_url = "http://www.stackstorm1.com:9102/v1" + os.environ["ST2_BASE_URL"] = base_url + os.environ["ST2_AUTH_URL"] = auth_url + os.environ["ST2_API_URL"] = api_url + os.environ["ST2_STREAM_URL"] = stream_url + args = ["trigger", "list"] parsed_args = self.shell.parser.parse_args(args) client = self.shell.get_client(parsed_args) - self.assertEqual(client.endpoints['base'], base_url) - self.assertEqual(client.endpoints['auth'], auth_url) - self.assertEqual(client.endpoints['api'], api_url) - self.assertEqual(client.endpoints['stream'], stream_url) + self.assertEqual(client.endpoints["base"], base_url) + self.assertEqual(client.endpoints["auth"], auth_url) + self.assertEqual(client.endpoints["api"], api_url) + self.assertEqual(client.endpoints["stream"], stream_url) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) def test_exit_code_on_success(self): - argv = ['trigger', 'list'] + argv = ["trigger", "list"] self.assertEqual(self.shell.run(argv), 0) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(None, 500, 'INTERNAL SERVER ERROR'))) + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(None, 500, "INTERNAL SERVER ERROR") + ), + ) def test_exit_code_on_error(self): - argv = ['trigger', 'list'] + argv = ["trigger", "list"] self.assertEqual(self.shell.run(argv), 1) def _validate_parser(self, args_list, is_subcommand=True): for args in args_list: ns = self.shell.parser.parse_args(args) - func = (self.shell.commands[args[0]].run_and_print - if not is_subcommand - else self.shell.commands[args[0]].commands[args[1]].run_and_print) + func = ( + self.shell.commands[args[0]].run_and_print + if not is_subcommand + else self.shell.commands[args[0]].commands[args[1]].run_and_print + ) self.assertEqual(ns.func, func) def test_action(self): args_list = [ - ['action', 'list'], - ['action', 'get', 'abc'], - ['action', 'create', '/tmp/action.json'], - ['action', 'update', '123', '/tmp/action.json'], - ['action', 'delete', 'abc'], - ['action', 'execute', '-h'], - ['action', 'execute', 'remote', '-h'], - ['action', 'execute', 'remote', 'hosts=192.168.1.1', 'user=st2', 'cmd="ls -l"'], - ['action', 'execute', 'remote-fib', 'hosts=192.168.1.1', '3', '8'] + ["action", "list"], + ["action", "get", "abc"], + ["action", "create", "/tmp/action.json"], + ["action", "update", "123", "/tmp/action.json"], + ["action", "delete", "abc"], + ["action", "execute", "-h"], + ["action", "execute", "remote", "-h"], + [ + "action", + "execute", + "remote", + "hosts=192.168.1.1", + "user=st2", + 'cmd="ls -l"', + ], + ["action", "execute", "remote-fib", "hosts=192.168.1.1", "3", "8"], ] self._validate_parser(args_list) def test_action_execution(self): args_list = [ - ['execution', 'list'], - ['execution', 'list', '-a', 'all'], - ['execution', 'list', '--attr=all'], - ['execution', 'get', '123'], - ['execution', 'get', '123', '-d'], - ['execution', 'get', '123', '-k', 'localhost.stdout'], - ['execution', 're-run', '123'], - ['execution', 're-run', '123', '--tasks', 'x', 'y', 'z'], - ['execution', 're-run', '123', '--tasks', 'x', 'y', 'z', '--no-reset', 'x'], - ['execution', 're-run', '123', 'a=1', 'b=x', 'c=True'], - ['execution', 'cancel', '123'], - ['execution', 'cancel', '123', '456'], - ['execution', 'pause', '123'], - ['execution', 'pause', '123', '456'], - ['execution', 'resume', '123'], - ['execution', 'resume', '123', '456'] + ["execution", "list"], + ["execution", "list", "-a", "all"], + ["execution", "list", "--attr=all"], + ["execution", "get", "123"], + ["execution", "get", "123", "-d"], + ["execution", "get", "123", "-k", "localhost.stdout"], + ["execution", "re-run", "123"], + ["execution", "re-run", "123", "--tasks", "x", "y", "z"], + ["execution", "re-run", "123", "--tasks", "x", "y", "z", "--no-reset", "x"], + ["execution", "re-run", "123", "a=1", "b=x", "c=True"], + ["execution", "cancel", "123"], + ["execution", "cancel", "123", "456"], + ["execution", "pause", "123"], + ["execution", "pause", "123", "456"], + ["execution", "resume", "123"], + ["execution", "resume", "123", "456"], ] self._validate_parser(args_list) # Test mutually exclusive argument groups - self.assertRaises(SystemExit, self._validate_parser, - [['execution', 'get', '123', '-d', '-k', 'localhost.stdout']]) + self.assertRaises( + SystemExit, + self._validate_parser, + [["execution", "get", "123", "-d", "-k", "localhost.stdout"]], + ) def test_key(self): args_list = [ - ['key', 'list'], - ['key', 'list', '-n', '2'], - ['key', 'get', 'abc'], - ['key', 'set', 'abc', '123'], - ['key', 'delete', 'abc'], - ['key', 'load', '/tmp/keys.json'] + ["key", "list"], + ["key", "list", "-n", "2"], + ["key", "get", "abc"], + ["key", "set", "abc", "123"], + ["key", "delete", "abc"], + ["key", "load", "/tmp/keys.json"], ] self._validate_parser(args_list) def test_policy(self): args_list = [ - ['policy', 'list'], - ['policy', 'list', '-p', 'core'], - ['policy', 'list', '--pack', 'core'], - ['policy', 'list', '-r', 'core.local'], - ['policy', 'list', '--resource-ref', 'core.local'], - ['policy', 'list', '-pt', 'action.type1'], - ['policy', 'list', '--policy-type', 'action.type1'], - ['policy', 'list', '-r', 'core.local', '-pt', 'action.type1'], - ['policy', 'list', '--resource-ref', 'core.local', '--policy-type', 'action.type1'], - ['policy', 'get', 'abc'], - ['policy', 'create', '/tmp/policy.json'], - ['policy', 'update', '123', '/tmp/policy.json'], - ['policy', 'delete', 'abc'] + ["policy", "list"], + ["policy", "list", "-p", "core"], + ["policy", "list", "--pack", "core"], + ["policy", "list", "-r", "core.local"], + ["policy", "list", "--resource-ref", "core.local"], + ["policy", "list", "-pt", "action.type1"], + ["policy", "list", "--policy-type", "action.type1"], + ["policy", "list", "-r", "core.local", "-pt", "action.type1"], + [ + "policy", + "list", + "--resource-ref", + "core.local", + "--policy-type", + "action.type1", + ], + ["policy", "get", "abc"], + ["policy", "create", "/tmp/policy.json"], + ["policy", "update", "123", "/tmp/policy.json"], + ["policy", "delete", "abc"], ] self._validate_parser(args_list) def test_policy_type(self): args_list = [ - ['policy-type', 'list'], - ['policy-type', 'list', '-r', 'action'], - ['policy-type', 'list', '--resource-type', 'action'], - ['policy-type', 'get', 'abc'] + ["policy-type", "list"], + ["policy-type", "list", "-r", "action"], + ["policy-type", "list", "--resource-type", "action"], + ["policy-type", "get", "abc"], ] self._validate_parser(args_list) def test_pack(self): args_list = [ - ['pack', 'list'], - ['pack', 'get', 'abc'], - ['pack', 'search', 'abc'], - ['pack', 'show', 'abc'], - ['pack', 'remove', 'abc'], - ['pack', 'remove', 'abc', '--detail'], - ['pack', 'install', 'abc'], - ['pack', 'install', 'abc', '--force'], - ['pack', 'install', 'abc', '--detail'], - ['pack', 'config', 'abc'] + ["pack", "list"], + ["pack", "get", "abc"], + ["pack", "search", "abc"], + ["pack", "show", "abc"], + ["pack", "remove", "abc"], + ["pack", "remove", "abc", "--detail"], + ["pack", "install", "abc"], + ["pack", "install", "abc", "--force"], + ["pack", "install", "abc", "--detail"], + ["pack", "config", "abc"], ] self._validate_parser(args_list) - @mock.patch('st2client.base.ST2_CONFIG_PATH', '/home/does/not/exist') + @mock.patch("st2client.base.ST2_CONFIG_PATH", "/home/does/not/exist") def test_print_config_default_config_no_config(self): - os.environ['ST2_CONFIG_FILE'] = '/home/does/not/exist' - argv = ['--print-config'] + os.environ["ST2_CONFIG_FILE"] = "/home/does/not/exist" + argv = ["--print-config"] self.assertEqual(self.shell.run(argv), 3) self.stdout.seek(0) stdout = self.stdout.read() - self.assertIn('username = None', stdout) - self.assertIn('cache_token = True', stdout) + self.assertIn("username = None", stdout) + self.assertIn("cache_token = True", stdout) def test_print_config_custom_config_as_env_variable(self): - os.environ['ST2_CONFIG_FILE'] = CONFIG_FILE_PATH_FULL - argv = ['--print-config'] + os.environ["ST2_CONFIG_FILE"] = CONFIG_FILE_PATH_FULL + argv = ["--print-config"] self.assertEqual(self.shell.run(argv), 3) self.stdout.seek(0) stdout = self.stdout.read() - self.assertIn('username = test1', stdout) - self.assertIn('cache_token = False', stdout) + self.assertIn("username = test1", stdout) + self.assertIn("cache_token = False", stdout) def test_print_config_custom_config_as_command_line_argument(self): - argv = ['--print-config', '--config-file=%s' % (CONFIG_FILE_PATH_FULL)] + argv = ["--print-config", "--config-file=%s" % (CONFIG_FILE_PATH_FULL)] self.assertEqual(self.shell.run(argv), 3) self.stdout.seek(0) stdout = self.stdout.read() - self.assertIn('username = test1', stdout) - self.assertIn('cache_token = False', stdout) + self.assertIn("username = test1", stdout) + self.assertIn("cache_token = False", stdout) def test_run(self): args_list = [ - ['run', '-h'], - ['run', 'abc', '-h'], - ['run', 'remote', 'hosts=192.168.1.1', 'user=st2', 'cmd="ls -l"'], - ['run', 'remote-fib', 'hosts=192.168.1.1', '3', '8'] + ["run", "-h"], + ["run", "abc", "-h"], + ["run", "remote", "hosts=192.168.1.1", "user=st2", 'cmd="ls -l"'], + ["run", "remote-fib", "hosts=192.168.1.1", "3", "8"], ] self._validate_parser(args_list, is_subcommand=False) def test_runner(self): - args_list = [ - ['runner', 'list'], - ['runner', 'get', 'abc'] - ] + args_list = [["runner", "list"], ["runner", "get", "abc"]] self._validate_parser(args_list) def test_rule(self): args_list = [ - ['rule', 'list'], - ['rule', 'list', '-n', '1'], - ['rule', 'get', 'abc'], - ['rule', 'create', '/tmp/rule.json'], - ['rule', 'update', '123', '/tmp/rule.json'], - ['rule', 'delete', 'abc'] + ["rule", "list"], + ["rule", "list", "-n", "1"], + ["rule", "get", "abc"], + ["rule", "create", "/tmp/rule.json"], + ["rule", "update", "123", "/tmp/rule.json"], + ["rule", "delete", "abc"], ] self._validate_parser(args_list) def test_trigger(self): args_list = [ - ['trigger', 'list'], - ['trigger', 'get', 'abc'], - ['trigger', 'create', '/tmp/trigger.json'], - ['trigger', 'update', '123', '/tmp/trigger.json'], - ['trigger', 'delete', 'abc'] + ["trigger", "list"], + ["trigger", "get", "abc"], + ["trigger", "create", "/tmp/trigger.json"], + ["trigger", "update", "123", "/tmp/trigger.json"], + ["trigger", "delete", "abc"], ] self._validate_parser(args_list) def test_workflow(self): args_list = [ - ['workflow', 'inspect', '--file', '/path/to/workflow/definition'], - ['workflow', 'inspect', '--action', 'mock.foobar'] + ["workflow", "inspect", "--file", "/path/to/workflow/definition"], + ["workflow", "inspect", "--action", "mock.foobar"], ] self._validate_parser(args_list) - @mock.patch('sys.exit', mock.Mock()) - @mock.patch('st2client.shell.__version__', 'v2.8.0') + @mock.patch("sys.exit", mock.Mock()) + @mock.patch("st2client.shell.__version__", "v2.8.0") def test_get_version_no_package_metadata_file_stable_version(self): # stable version, package metadata file doesn't exist on disk - no git revision should be # included shell = Shell() - shell.parser.parse_args(args=['--version']) + shell.parser.parse_args(args=["--version"]) self.version_output.seek(0) stderr = self.version_output.read() - self.assertIn('v2.8.0, on Python', stderr) + self.assertIn("v2.8.0, on Python", stderr) - @mock.patch('sys.exit', mock.Mock()) - @mock.patch('st2client.shell.__version__', 'v2.8.0') + @mock.patch("sys.exit", mock.Mock()) + @mock.patch("st2client.shell.__version__", "v2.8.0") def test_get_version_package_metadata_file_exists_stable_version(self): # stable version, package metadata file exists on disk - no git revision should be included package_metadata_path = self._write_mock_package_metadata_file() st2client.shell.PACKAGE_METADATA_FILE_PATH = package_metadata_path shell = Shell() - shell.run(argv=['--version']) + shell.run(argv=["--version"]) self.version_output.seek(0) stderr = self.version_output.read() - self.assertIn('v2.8.0, on Python', stderr) + self.assertIn("v2.8.0, on Python", stderr) - @mock.patch('sys.exit', mock.Mock()) - @mock.patch('st2client.shell.__version__', 'v2.9dev') - @mock.patch('st2client.shell.PACKAGE_METADATA_FILE_PATH', '/tmp/doesnt/exist.1') + @mock.patch("sys.exit", mock.Mock()) + @mock.patch("st2client.shell.__version__", "v2.9dev") + @mock.patch("st2client.shell.PACKAGE_METADATA_FILE_PATH", "/tmp/doesnt/exist.1") def test_get_version_no_package_metadata_file_dev_version(self): # dev version, package metadata file doesn't exist on disk - no git revision should be # included since package metadata file doesn't exist on disk shell = Shell() - shell.parser.parse_args(args=['--version']) + shell.parser.parse_args(args=["--version"]) self.version_output.seek(0) stderr = self.version_output.read() - self.assertIn('v2.9dev, on Python', stderr) + self.assertIn("v2.9dev, on Python", stderr) - @mock.patch('sys.exit', mock.Mock()) - @mock.patch('st2client.shell.__version__', 'v2.9dev') + @mock.patch("sys.exit", mock.Mock()) + @mock.patch("st2client.shell.__version__", "v2.9dev") def test_get_version_package_metadata_file_exists_dev_version(self): # dev version, package metadata file exists on disk - git revision should be included # since package metadata file exists on disk and contains server.git_sha attribute @@ -430,55 +461,67 @@ def test_get_version_package_metadata_file_exists_dev_version(self): st2client.shell.PACKAGE_METADATA_FILE_PATH = package_metadata_path shell = Shell() - shell.parser.parse_args(args=['--version']) + shell.parser.parse_args(args=["--version"]) self.version_output.seek(0) stderr = self.version_output.read() - self.assertIn('v2.9dev (abcdefg), on Python', stderr) + self.assertIn("v2.9dev (abcdefg), on Python", stderr) - @mock.patch('locale.getdefaultlocale', mock.Mock(return_value=['en_US'])) - @mock.patch('locale.getpreferredencoding', mock.Mock(return_value='iso')) + @mock.patch("locale.getdefaultlocale", mock.Mock(return_value=["en_US"])) + @mock.patch("locale.getpreferredencoding", mock.Mock(return_value="iso")) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) - @mock.patch('st2client.shell.LOGGER') + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) + @mock.patch("st2client.shell.LOGGER") def test_non_unicode_encoding_locale_warning_is_printed(self, mock_logger): shell = Shell() - shell.run(argv=['trigger', 'list']) + shell.run(argv=["trigger", "list"]) call_args = mock_logger.warn.call_args[0][0] - self.assertIn('Locale en_US with encoding iso which is not UTF-8 is used.', call_args) + self.assertIn( + "Locale en_US with encoding iso which is not UTF-8 is used.", call_args + ) - @mock.patch('locale.getdefaultlocale', mock.Mock(side_effect=ValueError('bar'))) - @mock.patch('locale.getpreferredencoding', mock.Mock(side_effect=ValueError('bar'))) + @mock.patch("locale.getdefaultlocale", mock.Mock(side_effect=ValueError("bar"))) + @mock.patch("locale.getpreferredencoding", mock.Mock(side_effect=ValueError("bar"))) @mock.patch.object( - httpclient.HTTPClient, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, 'OK'))) - @mock.patch('st2client.shell.LOGGER') + httpclient.HTTPClient, + "get", + mock.MagicMock( + return_value=base.FakeResponse(json.dumps(base.RESOURCES), 200, "OK") + ), + ) + @mock.patch("st2client.shell.LOGGER") def test_failed_to_get_locale_encoding_warning_is_printed(self, mock_logger): shell = Shell() - shell.run(argv=['trigger', 'list']) + shell.run(argv=["trigger", "list"]) call_args = mock_logger.warn.call_args[0][0] - self.assertTrue('Locale unknown with encoding unknown which is not UTF-8 is used.' in - call_args) + self.assertTrue( + "Locale unknown with encoding unknown which is not UTF-8 is used." + in call_args + ) def _write_mock_package_metadata_file(self): _, package_metadata_path = tempfile.mkstemp() - with open(package_metadata_path, 'w') as fp: + with open(package_metadata_path, "w") as fp: fp.write(MOCK_PACKAGE_METADATA) return package_metadata_path - @unittest2.skipIf(True, 'skipping until checks are re-enabled') + @unittest2.skipIf(True, "skipping until checks are re-enabled") @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse("{}", 200, 'OK'))) + requests, "get", mock.MagicMock(return_value=base.FakeResponse("{}", 200, "OK")) + ) def test_dont_warn_multiple_times(self): mock_temp_dir_path = tempfile.mkdtemp() - mock_config_dir_path = os.path.join(mock_temp_dir_path, 'testconfig') - mock_config_path = os.path.join(mock_config_dir_path, 'config') + mock_config_dir_path = os.path.join(mock_temp_dir_path, "testconfig") + mock_config_path = os.path.join(mock_config_dir_path, "config") # Make the temporary config directory os.makedirs(mock_config_dir_path) @@ -495,38 +538,46 @@ def test_dont_warn_multiple_times(self): shell.LOG = mock.Mock() # Test without token. - shell.run(['--config-file', mock_config_path, 'action', 'list']) + shell.run(["--config-file", mock_config_path, "action", "list"]) self.assertEqual(shell.LOG.warn.call_count, 2) self.assertEqual( shell.LOG.warn.call_args_list[0][0][0][:63], - 'The StackStorm configuration directory permissions are insecure') + "The StackStorm configuration directory permissions are insecure", + ) self.assertEqual( shell.LOG.warn.call_args_list[1][0][0][:58], - 'The StackStorm configuration file permissions are insecure') + "The StackStorm configuration file permissions are insecure", + ) self.assertEqual(shell.LOG.info.call_count, 2) self.assertEqual( - shell.LOG.info.call_args_list[0][0][0], "The SGID bit is not " - "set on the StackStorm configuration directory.") + shell.LOG.info.call_args_list[0][0][0], + "The SGID bit is not " "set on the StackStorm configuration directory.", + ) self.assertEqual( - shell.LOG.info.call_args_list[1][0][0], 'Skipping parsing CLI config') + shell.LOG.info.call_args_list[1][0][0], "Skipping parsing CLI config" + ) class CLITokenCachingTestCase(unittest2.TestCase): def setUp(self): super(CLITokenCachingTestCase, self).setUp() self._mock_temp_dir_path = tempfile.mkdtemp() - self._mock_config_directory_path = os.path.join(self._mock_temp_dir_path, 'testconfig') - self._mock_config_path = os.path.join(self._mock_config_directory_path, 'config') + self._mock_config_directory_path = os.path.join( + self._mock_temp_dir_path, "testconfig" + ) + self._mock_config_path = os.path.join( + self._mock_config_directory_path, "config" + ) os.makedirs(self._mock_config_directory_path) - self._p1 = mock.patch('st2client.base.ST2_CONFIG_DIRECTORY', - self._mock_config_directory_path) - self._p2 = mock.patch('st2client.base.ST2_CONFIG_PATH', - self._mock_config_path) + self._p1 = mock.patch( + "st2client.base.ST2_CONFIG_DIRECTORY", self._mock_config_directory_path + ) + self._p2 = mock.patch("st2client.base.ST2_CONFIG_PATH", self._mock_config_path) self._p1.start() self._p2.start() @@ -536,46 +587,46 @@ def tearDown(self): self._p2.stop() for var in [ - 'ST2_BASE_URL', - 'ST2_API_URL', - 'ST2_STREAM_URL', - 'ST2_DATASTORE_URL', - 'ST2_AUTH_TOKEN' + "ST2_BASE_URL", + "ST2_API_URL", + "ST2_STREAM_URL", + "ST2_DATASTORE_URL", + "ST2_AUTH_TOKEN", ]: if var in os.environ: del os.environ[var] def _write_mock_config(self): - with open(self._mock_config_path, 'w') as fp: + with open(self._mock_config_path, "w") as fp: fp.write(MOCK_CONFIG) def test_get_cached_auth_token_invalid_permissions(self): shell = Shell() client = Client() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" cached_token_path = shell._get_cached_token_path_for_user(username=username) - data = { - 'token': 'yayvalid', - 'expire_timestamp': (int(time.time()) + 20) - } - with open(cached_token_path, 'w') as fp: + data = {"token": "yayvalid", "expire_timestamp": (int(time.time()) + 20)} + with open(cached_token_path, "w") as fp: fp.write(json.dumps(data)) # 1. Current user doesn't have read access to the config directory os.chmod(self._mock_config_directory_path, 0o000) shell.LOG = mock.Mock() - result = shell._get_cached_auth_token(client=client, username=username, - password=password) + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) self.assertEqual(result, None) self.assertEqual(shell.LOG.warn.call_count, 1) log_message = shell.LOG.warn.call_args[0][0] - expected_msg = ('Unable to retrieve cached token from .*? read access to the parent ' - 'directory') + expected_msg = ( + "Unable to retrieve cached token from .*? read access to the parent " + "directory" + ) self.assertRegexpMatches(log_message, expected_msg) # 2. Read access on the directory, but not on the cached token file @@ -583,14 +634,17 @@ def test_get_cached_auth_token_invalid_permissions(self): os.chmod(cached_token_path, 0o000) shell.LOG = mock.Mock() - result = shell._get_cached_auth_token(client=client, username=username, - password=password) + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) self.assertEqual(result, None) self.assertEqual(shell.LOG.warn.call_count, 1) log_message = shell.LOG.warn.call_args[0][0] - expected_msg = ('Unable to retrieve cached token from .*? read access to this file') + expected_msg = ( + "Unable to retrieve cached token from .*? read access to this file" + ) self.assertRegexpMatches(log_message, expected_msg) # 3. Other users also have read access to the file @@ -598,31 +652,29 @@ def test_get_cached_auth_token_invalid_permissions(self): os.chmod(cached_token_path, 0o444) shell.LOG = mock.Mock() - result = shell._get_cached_auth_token(client=client, username=username, - password=password) - self.assertEqual(result, 'yayvalid') + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) + self.assertEqual(result, "yayvalid") self.assertEqual(shell.LOG.warn.call_count, 1) log_message = shell.LOG.warn.call_args[0][0] - expected_msg = ('Permissions .*? for cached token file .*? are too permissive.*') + expected_msg = "Permissions .*? for cached token file .*? are too permissive.*" self.assertRegexpMatches(log_message, expected_msg) def test_cache_auth_token_invalid_permissions(self): shell = Shell() - username = 'testu' + username = "testu" cached_token_path = shell._get_cached_token_path_for_user(username=username) expiry = datetime.datetime.utcnow() + datetime.timedelta(seconds=30) - token_db = TokenDB(user=username, token='fyeah', expiry=expiry) + token_db = TokenDB(user=username, token="fyeah", expiry=expiry) cached_token_path = shell._get_cached_token_path_for_user(username=username) - data = { - 'token': 'yayvalid', - 'expire_timestamp': (int(time.time()) + 20) - } - with open(cached_token_path, 'w') as fp: + data = {"token": "yayvalid", "expire_timestamp": (int(time.time()) + 20)} + with open(cached_token_path, "w") as fp: fp.write(json.dumps(data)) # 1. Current user has no write access to the parent directory @@ -634,8 +686,10 @@ def test_cache_auth_token_invalid_permissions(self): self.assertEqual(shell.LOG.warn.call_count, 1) log_message = shell.LOG.warn.call_args[0][0] - expected_msg = ('Unable to write token to .*? doesn\'t have write access to the parent ' - 'directory') + expected_msg = ( + "Unable to write token to .*? doesn't have write access to the parent " + "directory" + ) self.assertRegexpMatches(log_message, expected_msg) # 2. Current user has no write access to the cached token file @@ -648,86 +702,93 @@ def test_cache_auth_token_invalid_permissions(self): self.assertEqual(shell.LOG.warn.call_count, 1) log_message = shell.LOG.warn.call_args[0][0] - expected_msg = ('Unable to write token to .*? doesn\'t have write access to this file') + expected_msg = ( + "Unable to write token to .*? doesn't have write access to this file" + ) self.assertRegexpMatches(log_message, expected_msg) def test_get_cached_auth_token_no_token_cache_file(self): client = Client() shell = Shell() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" - result = shell._get_cached_auth_token(client=client, username=username, - password=password) + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) self.assertEqual(result, None) def test_get_cached_auth_token_corrupted_token_cache_file(self): client = Client() shell = Shell() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" cached_token_path = shell._get_cached_token_path_for_user(username=username) - with open(cached_token_path, 'w') as fp: - fp.write('CORRRRRUPTED!') - - expected_msg = 'File (.+) with cached token is corrupted or invalid' - self.assertRaisesRegexp(ValueError, expected_msg, shell._get_cached_auth_token, - client=client, username=username, password=password) + with open(cached_token_path, "w") as fp: + fp.write("CORRRRRUPTED!") + + expected_msg = "File (.+) with cached token is corrupted or invalid" + self.assertRaisesRegexp( + ValueError, + expected_msg, + shell._get_cached_auth_token, + client=client, + username=username, + password=password, + ) def test_get_cached_auth_token_expired_token_in_cache_file(self): client = Client() shell = Shell() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" cached_token_path = shell._get_cached_token_path_for_user(username=username) - data = { - 'token': 'expired', - 'expire_timestamp': (int(time.time()) - 10) - } - with open(cached_token_path, 'w') as fp: + data = {"token": "expired", "expire_timestamp": (int(time.time()) - 10)} + with open(cached_token_path, "w") as fp: fp.write(json.dumps(data)) - result = shell._get_cached_auth_token(client=client, username=username, - password=password) + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) self.assertEqual(result, None) def test_get_cached_auth_token_valid_token_in_cache_file(self): client = Client() shell = Shell() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" cached_token_path = shell._get_cached_token_path_for_user(username=username) - data = { - 'token': 'yayvalid', - 'expire_timestamp': (int(time.time()) + 20) - } - with open(cached_token_path, 'w') as fp: + data = {"token": "yayvalid", "expire_timestamp": (int(time.time()) + 20)} + with open(cached_token_path, "w") as fp: fp.write(json.dumps(data)) - result = shell._get_cached_auth_token(client=client, username=username, - password=password) - self.assertEqual(result, 'yayvalid') + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) + self.assertEqual(result, "yayvalid") def test_cache_auth_token_success(self): client = Client() shell = Shell() - username = 'testu' - password = 'testp' + username = "testu" + password = "testp" expiry = datetime.datetime.utcnow() + datetime.timedelta(seconds=30) - result = shell._get_cached_auth_token(client=client, username=username, - password=password) + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) self.assertEqual(result, None) - token_db = TokenDB(user=username, token='fyeah', expiry=expiry) + token_db = TokenDB(user=username, token="fyeah", expiry=expiry) shell._cache_auth_token(token_obj=token_db) - result = shell._get_cached_auth_token(client=client, username=username, - password=password) - self.assertEqual(result, 'fyeah') + result = shell._get_cached_auth_token( + client=client, username=username, password=password + ) + self.assertEqual(result, "fyeah") def test_automatic_auth_skipped_on_auth_command(self): self._write_mock_config() @@ -735,7 +796,7 @@ def test_automatic_auth_skipped_on_auth_command(self): shell = Shell() shell._get_auth_token = mock.Mock() - argv = ['auth', 'testu', '-p', 'testp'] + argv = ["auth", "testu", "-p", "testp"] args = shell.parser.parse_args(args=argv) shell.get_client(args=args) self.assertEqual(shell._get_auth_token.call_count, 0) @@ -746,8 +807,8 @@ def test_automatic_auth_skipped_if_token_provided_as_env_variable(self): shell = Shell() shell._get_auth_token = mock.Mock() - os.environ['ST2_AUTH_TOKEN'] = 'fooo' - argv = ['action', 'list'] + os.environ["ST2_AUTH_TOKEN"] = "fooo" + argv = ["action", "list"] args = shell.parser.parse_args(args=argv) shell.get_client(args=args) self.assertEqual(shell._get_auth_token.call_count, 0) @@ -758,12 +819,12 @@ def test_automatic_auth_skipped_if_token_provided_as_cli_argument(self): shell = Shell() shell._get_auth_token = mock.Mock() - argv = ['action', 'list', '--token=bar'] + argv = ["action", "list", "--token=bar"] args = shell.parser.parse_args(args=argv) shell.get_client(args=args) self.assertEqual(shell._get_auth_token.call_count, 0) - argv = ['action', 'list', '-t', 'bar'] + argv = ["action", "list", "-t", "bar"] args = shell.parser.parse_args(args=argv) shell.get_client(args=args) self.assertEqual(shell._get_auth_token.call_count, 0) diff --git a/st2client/tests/unit/test_ssl.py b/st2client/tests/unit/test_ssl.py index 5ed8bfbf289..5db836482ba 100644 --- a/st2client/tests/unit/test_ssl.py +++ b/st2client/tests/unit/test_ssl.py @@ -27,17 +27,18 @@ LOG = logging.getLogger(__name__) -USERNAME = 'stanley' -PASSWORD = 'ShhhDontTell' -HEADERS = {'content-type': 'application/json'} -AUTH_URL = 'https://127.0.0.1:9100/tokens' -GET_RULES_URL = ('http://127.0.0.1:9101/v1/rules/' - '?include_attributes=ref,pack,description,enabled&limit=50') -GET_RULES_URL = GET_RULES_URL.replace(',', '%2C') +USERNAME = "stanley" +PASSWORD = "ShhhDontTell" +HEADERS = {"content-type": "application/json"} +AUTH_URL = "https://127.0.0.1:9100/tokens" +GET_RULES_URL = ( + "http://127.0.0.1:9101/v1/rules/" + "?include_attributes=ref,pack,description,enabled&limit=50" +) +GET_RULES_URL = GET_RULES_URL.replace(",", "%2C") class TestHttps(base.BaseCLITestCase): - def __init__(self, *args, **kwargs): super(TestHttps, self).__init__(*args, **kwargs) self.shell = shell.Shell() @@ -46,11 +47,11 @@ def setUp(self): super(TestHttps, self).setUp() # Setup environment. - os.environ['ST2_BASE_URL'] = 'http://127.0.0.1' - os.environ['ST2_AUTH_URL'] = 'https://127.0.0.1:9100' + os.environ["ST2_BASE_URL"] = "http://127.0.0.1" + os.environ["ST2_AUTH_URL"] = "https://127.0.0.1:9100" - if 'ST2_CACERT' in os.environ: - del os.environ['ST2_CACERT'] + if "ST2_CACERT" in os.environ: + del os.environ["ST2_CACERT"] # Create a temp file to mock a cert file. self.cacert_fd, self.cacert_path = tempfile.mkstemp() @@ -59,58 +60,78 @@ def tearDown(self): super(TestHttps, self).tearDown() # Clean up environment. - if 'ST2_CACERT' in os.environ: - del os.environ['ST2_CACERT'] - if 'ST2_BASE_URL' in os.environ: - del os.environ['ST2_BASE_URL'] + if "ST2_CACERT" in os.environ: + del os.environ["ST2_CACERT"] + if "ST2_BASE_URL" in os.environ: + del os.environ["ST2_BASE_URL"] # Clean up temp files. os.close(self.cacert_fd) os.unlink(self.cacert_path) @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_https_without_cacert(self): - self.shell.run(['auth', USERNAME, '-p', PASSWORD]) - kwargs = {'verify': False, 'headers': HEADERS, 'auth': (USERNAME, PASSWORD)} + self.shell.run(["auth", USERNAME, "-p", PASSWORD]) + kwargs = {"verify": False, "headers": HEADERS, "auth": (USERNAME, PASSWORD)} requests.post.assert_called_with(AUTH_URL, json.dumps({}), **kwargs) @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_https_with_cacert_from_cli(self): - self.shell.run(['--cacert', self.cacert_path, 'auth', USERNAME, '-p', PASSWORD]) - kwargs = {'verify': self.cacert_path, 'headers': HEADERS, 'auth': (USERNAME, PASSWORD)} + self.shell.run(["--cacert", self.cacert_path, "auth", USERNAME, "-p", PASSWORD]) + kwargs = { + "verify": self.cacert_path, + "headers": HEADERS, + "auth": (USERNAME, PASSWORD), + } requests.post.assert_called_with(AUTH_URL, json.dumps({}), **kwargs) @mock.patch.object( - requests, 'post', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "post", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_https_with_cacert_from_env(self): - os.environ['ST2_CACERT'] = self.cacert_path - self.shell.run(['auth', USERNAME, '-p', PASSWORD]) - kwargs = {'verify': self.cacert_path, 'headers': HEADERS, 'auth': (USERNAME, PASSWORD)} + os.environ["ST2_CACERT"] = self.cacert_path + self.shell.run(["auth", USERNAME, "-p", PASSWORD]) + kwargs = { + "verify": self.cacert_path, + "headers": HEADERS, + "auth": (USERNAME, PASSWORD), + } requests.post.assert_called_with(AUTH_URL, json.dumps({}), **kwargs) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps([]), 200, "OK")), + ) def test_decorate_http_without_cacert(self): - self.shell.run(['rule', 'list']) + self.shell.run(["rule", "list"]) requests.get.assert_called_with(GET_RULES_URL) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_http_with_cacert_from_cli(self): - self.shell.run(['--cacert', self.cacert_path, 'rule', 'list']) + self.shell.run(["--cacert", self.cacert_path, "rule", "list"]) requests.get.assert_called_with(GET_RULES_URL) @mock.patch.object( - requests, 'get', - mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, 'OK'))) + requests, + "get", + mock.MagicMock(return_value=base.FakeResponse(json.dumps({}), 200, "OK")), + ) def test_decorate_http_with_cacert_from_env(self): - os.environ['ST2_CACERT'] = self.cacert_path - self.shell.run(['rule', 'list']) + os.environ["ST2_CACERT"] = self.cacert_path + self.shell.run(["rule", "list"]) requests.get.assert_called_with(GET_RULES_URL) diff --git a/st2client/tests/unit/test_trace_commands.py b/st2client/tests/unit/test_trace_commands.py index 99d60598a40..ea3b552d47c 100644 --- a/st2client/tests/unit/test_trace_commands.py +++ b/st2client/tests/unit/test_trace_commands.py @@ -23,23 +23,38 @@ class TraceCommandTestCase(base.BaseCLITestCase): - def test_trace_get_filter_trace_components_executions(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', - [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}]) - setattr(trace, 'rules', - [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}]) - setattr(trace, 'trigger_instances', - [{'object_id': 't1', 'caused_by': {}}, - {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}]) + setattr( + trace, + "action_executions", + [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}], + ) + setattr( + trace, + "rules", + [ + { + "object_id": "r1", + "caused_by": {"id": "t1", "type": "trigger_instance"}, + } + ], + ) + setattr( + trace, + "trigger_instances", + [ + {"object_id": "t1", "caused_by": {}}, + {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}}, + ], + ) args = argparse.Namespace() - setattr(args, 'execution', 'e1') - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "execution", "e1") + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._filter_trace_components(trace, args) self.assertEqual(len(trace.action_executions), 1) @@ -48,22 +63,38 @@ def test_trace_get_filter_trace_components_executions(self): def test_trace_get_filter_trace_components_rules(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', - [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}]) - setattr(trace, 'rules', - [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}]) - setattr(trace, 'trigger_instances', - [{'object_id': 't1', 'caused_by': {}}, - {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}]) + setattr( + trace, + "action_executions", + [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}], + ) + setattr( + trace, + "rules", + [ + { + "object_id": "r1", + "caused_by": {"id": "t1", "type": "trigger_instance"}, + } + ], + ) + setattr( + trace, + "trigger_instances", + [ + {"object_id": "t1", "caused_by": {}}, + {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}}, + ], + ) args = argparse.Namespace() - setattr(args, 'execution', None) - setattr(args, 'rule', 'r1') - setattr(args, 'trigger_instance', None) - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "execution", None) + setattr(args, "rule", "r1") + setattr(args, "trigger_instance", None) + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._filter_trace_components(trace, args) self.assertEqual(len(trace.action_executions), 0) @@ -72,22 +103,38 @@ def test_trace_get_filter_trace_components_rules(self): def test_trace_get_filter_trace_components_trigger_instances(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', - [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}]) - setattr(trace, 'rules', - [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}]) - setattr(trace, 'trigger_instances', - [{'object_id': 't1', 'caused_by': {}}, - {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}]) + setattr( + trace, + "action_executions", + [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}], + ) + setattr( + trace, + "rules", + [ + { + "object_id": "r1", + "caused_by": {"id": "t1", "type": "trigger_instance"}, + } + ], + ) + setattr( + trace, + "trigger_instances", + [ + {"object_id": "t1", "caused_by": {}}, + {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}}, + ], + ) args = argparse.Namespace() - setattr(args, 'execution', None) - setattr(args, 'rule', None) - setattr(args, 'trigger_instance', 't1') - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "execution", None) + setattr(args, "rule", None) + setattr(args, "trigger_instance", "t1") + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._filter_trace_components(trace, args) self.assertEqual(len(trace.action_executions), 0) @@ -96,15 +143,15 @@ def test_trace_get_filter_trace_components_trigger_instances(self): def test_trace_get_apply_display_filters_show_executions(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', ['1']) - setattr(trace, 'rules', ['1']) - setattr(trace, 'trigger_instances', ['1']) + setattr(trace, "action_executions", ["1"]) + setattr(trace, "rules", ["1"]) + setattr(trace, "trigger_instances", ["1"]) args = argparse.Namespace() - setattr(args, 'show_executions', True) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "show_executions", True) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertTrue(trace.action_executions) @@ -113,15 +160,15 @@ def test_trace_get_apply_display_filters_show_executions(self): def test_trace_get_apply_display_filters_show_rules(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', ['1']) - setattr(trace, 'rules', ['1']) - setattr(trace, 'trigger_instances', ['1']) + setattr(trace, "action_executions", ["1"]) + setattr(trace, "rules", ["1"]) + setattr(trace, "trigger_instances", ["1"]) args = argparse.Namespace() - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', True) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "show_executions", False) + setattr(args, "show_rules", True) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertFalse(trace.action_executions) @@ -130,15 +177,15 @@ def test_trace_get_apply_display_filters_show_rules(self): def test_trace_get_apply_display_filters_show_trigger_instances(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', ['1']) - setattr(trace, 'rules', ['1']) - setattr(trace, 'trigger_instances', ['1']) + setattr(trace, "action_executions", ["1"]) + setattr(trace, "rules", ["1"]) + setattr(trace, "trigger_instances", ["1"]) args = argparse.Namespace() - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', True) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", True) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertFalse(trace.action_executions) @@ -147,15 +194,15 @@ def test_trace_get_apply_display_filters_show_trigger_instances(self): def test_trace_get_apply_display_filters_show_multiple(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', ['1']) - setattr(trace, 'rules', ['1']) - setattr(trace, 'trigger_instances', ['1']) + setattr(trace, "action_executions", ["1"]) + setattr(trace, "rules", ["1"]) + setattr(trace, "trigger_instances", ["1"]) args = argparse.Namespace() - setattr(args, 'show_executions', True) - setattr(args, 'show_rules', True) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "show_executions", True) + setattr(args, "show_rules", True) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertTrue(trace.action_executions) @@ -164,15 +211,15 @@ def test_trace_get_apply_display_filters_show_multiple(self): def test_trace_get_apply_display_filters_show_all(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', ['1']) - setattr(trace, 'rules', ['1']) - setattr(trace, 'trigger_instances', ['1']) + setattr(trace, "action_executions", ["1"]) + setattr(trace, "rules", ["1"]) + setattr(trace, "trigger_instances", ["1"]) args = argparse.Namespace() - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', False) + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", False) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertEqual(len(trace.action_executions), 1) @@ -181,19 +228,35 @@ def test_trace_get_apply_display_filters_show_all(self): def test_trace_get_apply_display_filters_hide_noop(self): trace = trace_models.Trace() - setattr(trace, 'action_executions', - [{'object_id': 'e1', 'caused_by': {'id': 'r1:t1', 'type': 'rule'}}]) - setattr(trace, 'rules', - [{'object_id': 'r1', 'caused_by': {'id': 't1', 'type': 'trigger_instance'}}]) - setattr(trace, 'trigger_instances', - [{'object_id': 't1', 'caused_by': {}}, - {'object_id': 't2', 'caused_by': {'id': 'e1', 'type': 'execution'}}]) + setattr( + trace, + "action_executions", + [{"object_id": "e1", "caused_by": {"id": "r1:t1", "type": "rule"}}], + ) + setattr( + trace, + "rules", + [ + { + "object_id": "r1", + "caused_by": {"id": "t1", "type": "trigger_instance"}, + } + ], + ) + setattr( + trace, + "trigger_instances", + [ + {"object_id": "t1", "caused_by": {}}, + {"object_id": "t2", "caused_by": {"id": "e1", "type": "execution"}}, + ], + ) args = argparse.Namespace() - setattr(args, 'show_executions', False) - setattr(args, 'show_rules', False) - setattr(args, 'show_trigger_instances', False) - setattr(args, 'hide_noop_triggers', True) + setattr(args, "show_executions", False) + setattr(args, "show_rules", False) + setattr(args, "show_trigger_instances", False) + setattr(args, "hide_noop_triggers", True) trace = trace_commands.TraceGetCommand._apply_display_filters(trace, args) self.assertEqual(len(trace.action_executions), 1) diff --git a/st2client/tests/unit/test_util_date.py b/st2client/tests/unit/test_util_date.py index e29b840ed71..2cdeab95fc1 100644 --- a/st2client/tests/unit/test_util_date.py +++ b/st2client/tests/unit/test_util_date.py @@ -30,31 +30,31 @@ def test_format_dt(self): dt = datetime.datetime(2015, 10, 20, 8, 0, 0) dt = add_utc_tz(dt) result = format_dt(dt) - self.assertEqual(result, 'Tue, 20 Oct 2015 08:00:00 UTC') + self.assertEqual(result, "Tue, 20 Oct 2015 08:00:00 UTC") def test_format_isodate(self): # No timezone, defaults to UTC - value = 'Tue, 20 Oct 2015 08:00:00 UTC' + value = "Tue, 20 Oct 2015 08:00:00 UTC" result = format_isodate(value=value) - self.assertEqual(result, 'Tue, 20 Oct 2015 08:00:00 UTC') + self.assertEqual(result, "Tue, 20 Oct 2015 08:00:00 UTC") # Timezone provided - value = 'Tue, 20 Oct 2015 08:00:00 UTC' - result = format_isodate(value=value, timezone='Europe/Ljubljana') - self.assertEqual(result, 'Tue, 20 Oct 2015 10:00:00 CEST') + value = "Tue, 20 Oct 2015 08:00:00 UTC" + result = format_isodate(value=value, timezone="Europe/Ljubljana") + self.assertEqual(result, "Tue, 20 Oct 2015 10:00:00 CEST") - @mock.patch('st2client.utils.date.get_config') + @mock.patch("st2client.utils.date.get_config") def test_format_isodate_for_user_timezone(self, mock_get_config): # No timezone, defaults to UTC mock_get_config.return_value = {} - value = 'Tue, 20 Oct 2015 08:00:00 UTC' + value = "Tue, 20 Oct 2015 08:00:00 UTC" result = format_isodate_for_user_timezone(value=value) - self.assertEqual(result, 'Tue, 20 Oct 2015 08:00:00 UTC') + self.assertEqual(result, "Tue, 20 Oct 2015 08:00:00 UTC") # Timezone provided - mock_get_config.return_value = {'cli': {'timezone': 'Europe/Ljubljana'}} + mock_get_config.return_value = {"cli": {"timezone": "Europe/Ljubljana"}} - value = 'Tue, 20 Oct 2015 08:00:00 UTC' + value = "Tue, 20 Oct 2015 08:00:00 UTC" result = format_isodate_for_user_timezone(value=value) - self.assertEqual(result, 'Tue, 20 Oct 2015 10:00:00 CEST') + self.assertEqual(result, "Tue, 20 Oct 2015 10:00:00 CEST") diff --git a/st2client/tests/unit/test_util_json.py b/st2client/tests/unit/test_util_json.py index f44a4b9bf93..2333128c2e5 100644 --- a/st2client/tests/unit/test_util_json.py +++ b/st2client/tests/unit/test_util_json.py @@ -25,76 +25,67 @@ LOG = logging.getLogger(__name__) DOC = { - 'a01': 1, - 'b01': 2, - 'c01': { - 'c11': 3, - 'd12': 4, - 'c13': { - 'c21': 5, - 'c22': 6 - }, - 'c14': [7, 8, 9] - } + "a01": 1, + "b01": 2, + "c01": {"c11": 3, "d12": 4, "c13": {"c21": 5, "c22": 6}, "c14": [7, 8, 9]}, } DOC_IP_ADDRESS = { - 'ips': { - "192.168.1.1": { - "hostname": "router.domain.tld" - }, - "192.168.1.10": { - "hostname": "server.domain.tld" - } + "ips": { + "192.168.1.1": {"hostname": "router.domain.tld"}, + "192.168.1.10": {"hostname": "server.domain.tld"}, } } class TestGetValue(unittest2.TestCase): - def test_dot_notation(self): - self.assertEqual(jsutil.get_value(DOC, 'a01'), 1) - self.assertEqual(jsutil.get_value(DOC, 'c01.c11'), 3) - self.assertEqual(jsutil.get_value(DOC, 'c01.c13.c22'), 6) - self.assertEqual(jsutil.get_value(DOC, 'c01.c13'), {'c21': 5, 'c22': 6}) - self.assertListEqual(jsutil.get_value(DOC, 'c01.c14'), [7, 8, 9]) + self.assertEqual(jsutil.get_value(DOC, "a01"), 1) + self.assertEqual(jsutil.get_value(DOC, "c01.c11"), 3) + self.assertEqual(jsutil.get_value(DOC, "c01.c13.c22"), 6) + self.assertEqual(jsutil.get_value(DOC, "c01.c13"), {"c21": 5, "c22": 6}) + self.assertListEqual(jsutil.get_value(DOC, "c01.c14"), [7, 8, 9]) def test_dot_notation_with_val_error(self): self.assertRaises(ValueError, jsutil.get_value, DOC, None) - self.assertRaises(ValueError, jsutil.get_value, DOC, '') - self.assertRaises(ValueError, jsutil.get_value, json.dumps(DOC), 'a01') + self.assertRaises(ValueError, jsutil.get_value, DOC, "") + self.assertRaises(ValueError, jsutil.get_value, json.dumps(DOC), "a01") def test_dot_notation_with_key_error(self): - self.assertIsNone(jsutil.get_value(DOC, 'd01')) - self.assertIsNone(jsutil.get_value(DOC, 'a01.a11')) - self.assertIsNone(jsutil.get_value(DOC, 'c01.c11.c21.c31')) - self.assertIsNone(jsutil.get_value(DOC, 'c01.c14.c31')) + self.assertIsNone(jsutil.get_value(DOC, "d01")) + self.assertIsNone(jsutil.get_value(DOC, "a01.a11")) + self.assertIsNone(jsutil.get_value(DOC, "c01.c11.c21.c31")) + self.assertIsNone(jsutil.get_value(DOC, "c01.c14.c31")) def test_ip_address(self): - self.assertEqual(jsutil.get_value(DOC_IP_ADDRESS, 'ips."192.168.1.1"'), - {"hostname": "router.domain.tld"}) + self.assertEqual( + jsutil.get_value(DOC_IP_ADDRESS, 'ips."192.168.1.1"'), + {"hostname": "router.domain.tld"}, + ) def test_chars_nums_dashes_underscores_calls_simple(self): - for char in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_': + for char in "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-_": with mock.patch("st2client.utils.jsutil._get_value_simple") as mock_simple: jsutil.get_value(DOC, char) mock_simple.assert_called_with(DOC, char) def test_symbols_calls_complex(self): - for char in '`~!@#$%^&&*()=+{}[]|\\;:\'"<>,./?': - with mock.patch("st2client.utils.jsutil._get_value_complex") as mock_complex: + for char in "`~!@#$%^&&*()=+{}[]|\\;:'\"<>,./?": + with mock.patch( + "st2client.utils.jsutil._get_value_complex" + ) as mock_complex: jsutil.get_value(DOC, char) mock_complex.assert_called_with(DOC, char) @mock.patch("st2client.utils.jsutil._get_value_simple") def test_single_key_calls_simple(self, mock__get_value_simple): - jsutil.get_value(DOC, 'a01') - mock__get_value_simple.assert_called_with(DOC, 'a01') + jsutil.get_value(DOC, "a01") + mock__get_value_simple.assert_called_with(DOC, "a01") @mock.patch("st2client.utils.jsutil._get_value_simple") def test_dot_notation_calls_simple(self, mock__get_value_simple): - jsutil.get_value(DOC, 'c01.c11') - mock__get_value_simple.assert_called_with(DOC, 'c01.c11') + jsutil.get_value(DOC, "c01.c11") + mock__get_value_simple.assert_called_with(DOC, "c01.c11") @mock.patch("st2client.utils.jsutil._get_value_complex") def test_ip_address_calls_complex(self, mock__get_value_complex): @@ -103,54 +94,64 @@ def test_ip_address_calls_complex(self, mock__get_value_complex): @mock.patch("st2client.utils.jsutil._get_value_complex") def test_beginning_dot_calls_complex(self, mock__get_value_complex): - jsutil.get_value(DOC, '.c01.c11') - mock__get_value_complex.assert_called_with(DOC, '.c01.c11') + jsutil.get_value(DOC, ".c01.c11") + mock__get_value_complex.assert_called_with(DOC, ".c01.c11") @mock.patch("st2client.utils.jsutil._get_value_complex") def test_ending_dot_calls_complex(self, mock__get_value_complex): - jsutil.get_value(DOC, 'c01.c11.') - mock__get_value_complex.assert_called_with(DOC, 'c01.c11.') + jsutil.get_value(DOC, "c01.c11.") + mock__get_value_complex.assert_called_with(DOC, "c01.c11.") @mock.patch("st2client.utils.jsutil._get_value_complex") def test_double_dot_calls_complex(self, mock__get_value_complex): - jsutil.get_value(DOC, 'c01..c11') - mock__get_value_complex.assert_called_with(DOC, 'c01..c11') + jsutil.get_value(DOC, "c01..c11") + mock__get_value_complex.assert_called_with(DOC, "c01..c11") class TestGetKeyValuePairs(unittest2.TestCase): - def test_select_kvps(self): - self.assertEqual(jsutil.get_kvps(DOC, ['a01']), - {'a01': 1}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c11']), - {'c01': {'c11': 3}}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c13.c22']), - {'c01': {'c13': {'c22': 6}}}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c13']), - {'c01': {'c13': {'c21': 5, 'c22': 6}}}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c14']), - {'c01': {'c14': [7, 8, 9]}}) - self.assertEqual(jsutil.get_kvps(DOC, ['a01', 'c01.c11', 'c01.c13.c21']), - {'a01': 1, 'c01': {'c11': 3, 'c13': {'c21': 5}}}) - self.assertEqual(jsutil.get_kvps(DOC_IP_ADDRESS, - ['ips."192.168.1.1"', - 'ips."192.168.1.10".hostname']), - {'ips': - {'"192': - {'168': - {'1': - {'1"': {'hostname': 'router.domain.tld'}, - '10"': {'hostname': 'server.domain.tld'}}}}}}) + self.assertEqual(jsutil.get_kvps(DOC, ["a01"]), {"a01": 1}) + self.assertEqual(jsutil.get_kvps(DOC, ["c01.c11"]), {"c01": {"c11": 3}}) + self.assertEqual( + jsutil.get_kvps(DOC, ["c01.c13.c22"]), {"c01": {"c13": {"c22": 6}}} + ) + self.assertEqual( + jsutil.get_kvps(DOC, ["c01.c13"]), {"c01": {"c13": {"c21": 5, "c22": 6}}} + ) + self.assertEqual(jsutil.get_kvps(DOC, ["c01.c14"]), {"c01": {"c14": [7, 8, 9]}}) + self.assertEqual( + jsutil.get_kvps(DOC, ["a01", "c01.c11", "c01.c13.c21"]), + {"a01": 1, "c01": {"c11": 3, "c13": {"c21": 5}}}, + ) + self.assertEqual( + jsutil.get_kvps( + DOC_IP_ADDRESS, ['ips."192.168.1.1"', 'ips."192.168.1.10".hostname'] + ), + { + "ips": { + '"192': { + "168": { + "1": { + '1"': {"hostname": "router.domain.tld"}, + '10"': {"hostname": "server.domain.tld"}, + } + } + } + } + }, + ) def test_select_kvps_with_val_error(self): self.assertRaises(ValueError, jsutil.get_kvps, DOC, [None]) - self.assertRaises(ValueError, jsutil.get_kvps, DOC, ['']) - self.assertRaises(ValueError, jsutil.get_kvps, json.dumps(DOC), ['a01']) + self.assertRaises(ValueError, jsutil.get_kvps, DOC, [""]) + self.assertRaises(ValueError, jsutil.get_kvps, json.dumps(DOC), ["a01"]) def test_select_kvps_with_key_error(self): - self.assertEqual(jsutil.get_kvps(DOC, ['d01']), {}) - self.assertEqual(jsutil.get_kvps(DOC, ['a01.a11']), {}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c11.c21.c31']), {}) - self.assertEqual(jsutil.get_kvps(DOC, ['c01.c14.c31']), {}) - self.assertEqual(jsutil.get_kvps(DOC, ['a01', 'c01.c11', 'c01.c13.c23']), - {'a01': 1, 'c01': {'c11': 3}}) + self.assertEqual(jsutil.get_kvps(DOC, ["d01"]), {}) + self.assertEqual(jsutil.get_kvps(DOC, ["a01.a11"]), {}) + self.assertEqual(jsutil.get_kvps(DOC, ["c01.c11.c21.c31"]), {}) + self.assertEqual(jsutil.get_kvps(DOC, ["c01.c14.c31"]), {}) + self.assertEqual( + jsutil.get_kvps(DOC, ["a01", "c01.c11", "c01.c13.c23"]), + {"a01": 1, "c01": {"c11": 3}}, + ) diff --git a/st2client/tests/unit/test_util_misc.py b/st2client/tests/unit/test_util_misc.py index 6a2cf3a8fc3..2e33156adc4 100644 --- a/st2client/tests/unit/test_util_misc.py +++ b/st2client/tests/unit/test_util_misc.py @@ -21,37 +21,37 @@ class MiscUtilTestCase(unittest2.TestCase): def test_merge_dicts(self): - d1 = {'a': 1} - d2 = {'a': 2} - expected = {'a': 2} + d1 = {"a": 1} + d2 = {"a": 2} + expected = {"a": 2} result = merge_dicts(d1, d2) self.assertEqual(result, expected) - d1 = {'a': 1} - d2 = {'b': 1} - expected = {'a': 1, 'b': 1} + d1 = {"a": 1} + d2 = {"b": 1} + expected = {"a": 1, "b": 1} result = merge_dicts(d1, d2) self.assertEqual(result, expected) - d1 = {'a': 1} - d2 = {'a': 3, 'b': 1} - expected = {'a': 3, 'b': 1} + d1 = {"a": 1} + d2 = {"a": 3, "b": 1} + expected = {"a": 3, "b": 1} result = merge_dicts(d1, d2) self.assertEqual(result, expected) - d1 = {'a': 1, 'm': None} - d2 = {'a': None, 'b': 1, 'c': None} - expected = {'a': 1, 'b': 1, 'c': None, 'm': None} + d1 = {"a": 1, "m": None} + d2 = {"a": None, "b": 1, "c": None} + expected = {"a": 1, "b": 1, "c": None, "m": None} result = merge_dicts(d1, d2) self.assertEqual(result, expected) - d1 = {'a': 1, 'b': {'a': 1, 'b': 2, 'c': 3}} - d2 = {'b': {'b': 100}} - expected = {'a': 1, 'b': {'a': 1, 'b': 100, 'c': 3}} + d1 = {"a": 1, "b": {"a": 1, "b": 2, "c": 3}} + d2 = {"b": {"b": 100}} + expected = {"a": 1, "b": {"a": 1, "b": 100, "c": 3}} result = merge_dicts(d1, d2) self.assertEqual(result, expected) diff --git a/st2client/tests/unit/test_util_strutil.py b/st2client/tests/unit/test_util_strutil.py index 2d442013de3..585e88c389b 100644 --- a/st2client/tests/unit/test_util_strutil.py +++ b/st2client/tests/unit/test_util_strutil.py @@ -26,17 +26,17 @@ class StrUtilTestCase(unittest2.TestCase): def test_unescape(self): in_str = 'Action execution result double escape \\"stuffs\\".\\r\\n' - expected = 'Action execution result double escape \"stuffs\".\r\n' + expected = 'Action execution result double escape "stuffs".\r\n' out_str = strutil.unescape(in_str) self.assertEqual(out_str, expected) def test_unicode_string(self): - in_str = '\u8c03\u7528CMS\u63a5\u53e3\u5220\u9664\u865a\u62df\u76ee\u5f55' + in_str = "\u8c03\u7528CMS\u63a5\u53e3\u5220\u9664\u865a\u62df\u76ee\u5f55" out_str = strutil.unescape(in_str) self.assertEqual(out_str, in_str) def test_strip_carriage_returns(self): - in_str = 'Windows editors introduce\r\nlike a noob in 2017.' + in_str = "Windows editors introduce\r\nlike a noob in 2017." out_str = strutil.strip_carriage_returns(in_str) - exp_str = 'Windows editors introduce\nlike a noob in 2017.' + exp_str = "Windows editors introduce\nlike a noob in 2017." self.assertEqual(out_str, exp_str) diff --git a/st2client/tests/unit/test_util_terminal.py b/st2client/tests/unit/test_util_terminal.py index c9b6d82b27c..29a8386b0b7 100644 --- a/st2client/tests/unit/test_util_terminal.py +++ b/st2client/tests/unit/test_util_terminal.py @@ -23,20 +23,20 @@ from st2client.utils.terminal import DEFAULT_TERMINAL_SIZE_COLUMNS from st2client.utils.terminal import get_terminal_size_columns -__all__ = [ - 'TerminalUtilsTestCase' -] +__all__ = ["TerminalUtilsTestCase"] class TerminalUtilsTestCase(unittest2.TestCase): def setUp(self): super(TerminalUtilsTestCase, self).setUp() - if 'COLUMNS' in os.environ: - del os.environ['COLUMNS'] + if "COLUMNS" in os.environ: + del os.environ["COLUMNS"] - @mock.patch.dict(os.environ, {'LINES': '111', 'COLUMNS': '222'}) - def test_get_terminal_size_columns_columns_environment_variable_has_precedence(self): + @mock.patch.dict(os.environ, {"LINES": "111", "COLUMNS": "222"}) + def test_get_terminal_size_columns_columns_environment_variable_has_precedence( + self, + ): # Verify that COLUMNS environment variables has precedence over other approaches columns = get_terminal_size_columns() @@ -44,16 +44,16 @@ def test_get_terminal_size_columns_columns_environment_variable_has_precedence(s # make sure that os.environ['COLUMNS'] isn't set so it can't override/screw-up this test @mock.patch.dict(os.environ, {}) - @mock.patch('fcntl.ioctl', mock.Mock(return_value='dummy')) - @mock.patch('struct.unpack', mock.Mock(return_value=(333, 444))) + @mock.patch("fcntl.ioctl", mock.Mock(return_value="dummy")) + @mock.patch("struct.unpack", mock.Mock(return_value=(333, 444))) def test_get_terminal_size_columns_stdout_is_used(self): columns = get_terminal_size_columns() self.assertEqual(columns, 444) - @mock.patch('struct.unpack', mock.Mock(side_effect=Exception('a'))) - @mock.patch('subprocess.Popen') + @mock.patch("struct.unpack", mock.Mock(side_effect=Exception("a"))) + @mock.patch("subprocess.Popen") def test_get_terminal_size_subprocess_popen_is_used(self, mock_popen): - mock_communicate = mock.Mock(return_value=['555 666']) + mock_communicate = mock.Mock(return_value=["555 666"]) mock_process = mock.Mock() mock_process.returncode = 0 @@ -64,8 +64,8 @@ def test_get_terminal_size_subprocess_popen_is_used(self, mock_popen): columns = get_terminal_size_columns() self.assertEqual(columns, 666) - @mock.patch('struct.unpack', mock.Mock(side_effect=Exception('a'))) - @mock.patch('subprocess.Popen', mock.Mock(side_effect=Exception('b'))) + @mock.patch("struct.unpack", mock.Mock(side_effect=Exception("a"))) + @mock.patch("subprocess.Popen", mock.Mock(side_effect=Exception("b"))) def test_get_terminal_size_default_values_are_used(self): columns = get_terminal_size_columns() diff --git a/st2client/tests/unit/test_workflow.py b/st2client/tests/unit/test_workflow.py index 3896a27bc21..79d580f85d1 100644 --- a/st2client/tests/unit/test_workflow.py +++ b/st2client/tests/unit/test_workflow.py @@ -31,13 +31,13 @@ LOG = logging.getLogger(__name__) MOCK_ACTION = { - 'ref': 'mock.foobar', - 'runner_type': 'mock-runner', - 'pack': 'mock', - 'name': 'foobar', - 'parameters': {}, - 'enabled': True, - 'entry_point': 'workflows/foobar.yaml' + "ref": "mock.foobar", + "runner_type": "mock-runner", + "pack": "mock", + "name": "foobar", + "parameters": {}, + "enabled": True, + "entry_point": "workflows/foobar.yaml", } MOCK_WF_DEF = """ @@ -56,73 +56,88 @@ def get_by_ref(**kwargs): class WorkflowCommandTestCase(st2cli_tests.BaseCLITestCase): - def __init__(self, *args, **kwargs): super(WorkflowCommandTestCase, self).__init__(*args, **kwargs) self.shell = shell.Shell() @mock.patch.object( - httpclient.HTTPClient, 'post_raw', - mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK'))) + httpclient.HTTPClient, + "post_raw", + mock.MagicMock( + return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK") + ), + ) def test_inspect_file(self): - fd, path = tempfile.mkstemp(suffix='.yaml') + fd, path = tempfile.mkstemp(suffix=".yaml") try: - with open(path, 'a') as f: + with open(path, "a") as f: f.write(MOCK_WF_DEF) - retcode = self.shell.run(['workflow', 'inspect', '--file', path]) + retcode = self.shell.run(["workflow", "inspect", "--file", path]) self.assertEqual(retcode, 0) httpclient.HTTPClient.post_raw.assert_called_with( - '/inspect', - MOCK_WF_DEF, - headers={'content-type': 'text/plain'} + "/inspect", MOCK_WF_DEF, headers={"content-type": "text/plain"} ) finally: os.close(fd) os.unlink(path) @mock.patch.object( - httpclient.HTTPClient, 'post_raw', - mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK'))) + httpclient.HTTPClient, + "post_raw", + mock.MagicMock( + return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK") + ), + ) def test_inspect_bad_file(self): - retcode = self.shell.run(['workflow', 'inspect', '--file', '/tmp/foobar']) + retcode = self.shell.run(["workflow", "inspect", "--file", "/tmp/foobar"]) self.assertEqual(retcode, 1) - self.assertIn('does not exist', self.stdout.getvalue()) + self.assertIn("does not exist", self.stdout.getvalue()) self.assertFalse(httpclient.HTTPClient.post_raw.called) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(side_effect=get_by_ref)) + models.ResourceManager, + "get_by_ref_or_id", + mock.MagicMock(side_effect=get_by_ref), + ) @mock.patch.object( - workflow.WorkflowInspectionCommand, 'get_file_content', - mock.MagicMock(return_value=MOCK_WF_DEF)) + workflow.WorkflowInspectionCommand, + "get_file_content", + mock.MagicMock(return_value=MOCK_WF_DEF), + ) @mock.patch.object( - httpclient.HTTPClient, 'post_raw', - mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK'))) + httpclient.HTTPClient, + "post_raw", + mock.MagicMock( + return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK") + ), + ) def test_inspect_action(self): - retcode = self.shell.run(['workflow', 'inspect', '--action', 'mock.foobar']) + retcode = self.shell.run(["workflow", "inspect", "--action", "mock.foobar"]) self.assertEqual(retcode, 0) httpclient.HTTPClient.post_raw.assert_called_with( - '/inspect', - MOCK_WF_DEF, - headers={'content-type': 'text/plain'} + "/inspect", MOCK_WF_DEF, headers={"content-type": "text/plain"} ) @mock.patch.object( - models.ResourceManager, 'get_by_ref_or_id', - mock.MagicMock(return_value=None)) + models.ResourceManager, "get_by_ref_or_id", mock.MagicMock(return_value=None) + ) @mock.patch.object( - httpclient.HTTPClient, 'post_raw', - mock.MagicMock(return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, 'OK'))) + httpclient.HTTPClient, + "post_raw", + mock.MagicMock( + return_value=st2cli_tests.FakeResponse(json.dumps(MOCK_RESULT), 200, "OK") + ), + ) def test_inspect_bad_action(self): - retcode = self.shell.run(['workflow', 'inspect', '--action', 'mock.foobar']) + retcode = self.shell.run(["workflow", "inspect", "--action", "mock.foobar"]) self.assertEqual(retcode, 1) - self.assertIn('Unable to identify action', self.stdout.getvalue()) + self.assertIn("Unable to identify action", self.stdout.getvalue()) self.assertFalse(httpclient.HTTPClient.post_raw.called) diff --git a/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py b/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py index db5cbdcf671..b8de86a6611 100755 --- a/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py +++ b/st2common/bin/migrations/v1.5/st2-migrate-datastore-to-include-scope-secret.py @@ -35,16 +35,20 @@ def migrate_datastore(): try: for kvp in key_value_items: - kvp_id = getattr(kvp, 'id', None) - secret = getattr(kvp, 'secret', False) - scope = getattr(kvp, 'scope', SYSTEM_SCOPE) - new_kvp_db = KeyValuePairDB(id=kvp_id, name=kvp.name, - expire_timestamp=kvp.expire_timestamp, - value=kvp.value, secret=secret, - scope=scope) + kvp_id = getattr(kvp, "id", None) + secret = getattr(kvp, "secret", False) + scope = getattr(kvp, "scope", SYSTEM_SCOPE) + new_kvp_db = KeyValuePairDB( + id=kvp_id, + name=kvp.name, + expire_timestamp=kvp.expire_timestamp, + value=kvp.value, + secret=secret, + scope=scope, + ) KeyValuePair.add_or_update(new_kvp_db) except: - print('ERROR: Failed migrating datastore item with name: %s' % kvp.name) + print("ERROR: Failed migrating datastore item with name: %s" % kvp.name) tb.print_exc() raise @@ -58,10 +62,10 @@ def main(): # Migrate rules. try: migrate_datastore() - print('SUCCESS: Datastore items migrated successfully.') + print("SUCCESS: Datastore items migrated successfully.") exit_code = 0 except: - print('ABORTED: Datastore migration aborted on first failure.') + print("ABORTED: Datastore migration aborted on first failure.") exit_code = 1 # Disconnect from db. @@ -69,5 +73,5 @@ def main(): sys.exit(exit_code) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py b/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py index 24275f80dc5..a1a500ad967 100755 --- a/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py +++ b/st2common/bin/migrations/v2.1/st2-migrate-datastore-scopes.py @@ -32,9 +32,9 @@ def migrate_datastore(): try: for kvp in key_value_items: - kvp_id = getattr(kvp, 'id', None) - secret = getattr(kvp, 'secret', False) - scope = getattr(kvp, 'scope', SYSTEM_SCOPE) + kvp_id = getattr(kvp, "id", None) + secret = getattr(kvp, "secret", False) + scope = getattr(kvp, "scope", SYSTEM_SCOPE) if scope == USER_SCOPE: scope = FULL_USER_SCOPE @@ -42,13 +42,17 @@ def migrate_datastore(): if scope == SYSTEM_SCOPE: scope = FULL_SYSTEM_SCOPE - new_kvp_db = KeyValuePairDB(id=kvp_id, name=kvp.name, - expire_timestamp=kvp.expire_timestamp, - value=kvp.value, secret=secret, - scope=scope) + new_kvp_db = KeyValuePairDB( + id=kvp_id, + name=kvp.name, + expire_timestamp=kvp.expire_timestamp, + value=kvp.value, + secret=secret, + scope=scope, + ) KeyValuePair.add_or_update(new_kvp_db) except: - print('ERROR: Failed migrating datastore item with name: %s' % kvp.name) + print("ERROR: Failed migrating datastore item with name: %s" % kvp.name) tb.print_exc() raise @@ -62,10 +66,10 @@ def main(): # Migrate rules. try: migrate_datastore() - print('SUCCESS: Datastore items migrated successfully.') + print("SUCCESS: Datastore items migrated successfully.") exit_code = 0 except: - print('ABORTED: Datastore migration aborted on first failure.') + print("ABORTED: Datastore migration aborted on first failure.") exit_code = 1 # Disconnect from db. @@ -73,5 +77,5 @@ def main(): sys.exit(exit_code) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py b/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py index bb4ee666b98..9d097894135 100755 --- a/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py +++ b/st2common/bin/migrations/v3.1/st2-cleanup-policy-delayed.py @@ -39,12 +39,14 @@ def main(): try: handler = scheduler_handler.get_handler() handler._cleanup_policy_delayed() - LOG.info('SUCCESS: Completed clean up of executions with deprecated policy-delayed status.') + LOG.info( + "SUCCESS: Completed clean up of executions with deprecated policy-delayed status." + ) exit_code = 0 except Exception as e: LOG.error( - 'ABORTED: Clean up of executions with deprecated policy-delayed status aborted on ' - 'first failure. %s' % e.message + "ABORTED: Clean up of executions with deprecated policy-delayed status aborted on " + "first failure. %s" % e.message ) exit_code = 1 @@ -53,5 +55,5 @@ def main(): sys.exit(exit_code) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/st2common/bin/paramiko_ssh_evenlets_tester.py b/st2common/bin/paramiko_ssh_evenlets_tester.py index 49a42545f85..af30196de12 100755 --- a/st2common/bin/paramiko_ssh_evenlets_tester.py +++ b/st2common/bin/paramiko_ssh_evenlets_tester.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import argparse @@ -34,49 +35,54 @@ def main(user, pkey, password, hosts_str, cmd, file_path, dir_path, delete_dir): if file_path: if not os.path.exists(file_path): - raise Exception('File not found.') - results = client.put(file_path, '/home/lakshmi/test_file', mode="0660") - pp.pprint('Copy results: \n%s' % results) - results = client.run('ls -rlth') - pp.pprint('ls results: \n%s' % results) + raise Exception("File not found.") + results = client.put(file_path, "/home/lakshmi/test_file", mode="0660") + pp.pprint("Copy results: \n%s" % results) + results = client.run("ls -rlth") + pp.pprint("ls results: \n%s" % results) if dir_path: if not os.path.exists(dir_path): - raise Exception('File not found.') - results = client.put(dir_path, '/home/lakshmi/', mode="0660") - pp.pprint('Copy results: \n%s' % results) - results = client.run('ls -rlth') - pp.pprint('ls results: \n%s' % results) + raise Exception("File not found.") + results = client.put(dir_path, "/home/lakshmi/", mode="0660") + pp.pprint("Copy results: \n%s" % results) + results = client.run("ls -rlth") + pp.pprint("ls results: \n%s" % results) if cmd: results = client.run(cmd) - pp.pprint('cmd results: \n%s' % results) + pp.pprint("cmd results: \n%s" % results) if delete_dir: results = client.delete_dir(delete_dir, force=True) - pp.pprint('Delete results: \n%s' % results) + pp.pprint("Delete results: \n%s" % results) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Parallel SSH tester.') - parser.add_argument('--hosts', required=True, - help='List of hosts to connect to') - parser.add_argument('--private-key', required=False, - help='Private key to use.') - parser.add_argument('--password', required=False, - help='Password.') - parser.add_argument('--user', required=True, - help='SSH user name.') - parser.add_argument('--cmd', required=False, - help='Command to run on host.') - parser.add_argument('--file', required=False, - help='Path of file to copy to remote host.') - parser.add_argument('--dir', required=False, - help='Path of dir to copy to remote host.') - parser.add_argument('--delete-dir', required=False, - help='Path of dir to delete on remote host.') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Parallel SSH tester.") + parser.add_argument("--hosts", required=True, help="List of hosts to connect to") + parser.add_argument("--private-key", required=False, help="Private key to use.") + parser.add_argument("--password", required=False, help="Password.") + parser.add_argument("--user", required=True, help="SSH user name.") + parser.add_argument("--cmd", required=False, help="Command to run on host.") + parser.add_argument( + "--file", required=False, help="Path of file to copy to remote host." + ) + parser.add_argument( + "--dir", required=False, help="Path of dir to copy to remote host." + ) + parser.add_argument( + "--delete-dir", required=False, help="Path of dir to delete on remote host." + ) args = parser.parse_args() - main(user=args.user, pkey=args.private_key, password=args.password, - hosts_str=args.hosts, cmd=args.cmd, - file_path=args.file, dir_path=args.dir, delete_dir=args.delete_dir) + main( + user=args.user, + pkey=args.private_key, + password=args.password, + hosts_str=args.hosts, + cmd=args.cmd, + file_path=args.file, + dir_path=args.dir, + delete_dir=args.delete_dir, + ) diff --git a/st2common/dist_utils.py b/st2common/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/st2common/dist_utils.py +++ b/st2common/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2common/setup.py b/st2common/setup.py index f68679af8cc..908884260d4 100644 --- a/st2common/setup.py +++ b/st2common/setup.py @@ -23,10 +23,10 @@ from dist_utils import apply_vagrant_workaround from dist_utils import get_version_string -ST2_COMPONENT = 'st2common' +ST2_COMPONENT = "st2common" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') -INIT_FILE = os.path.join(BASE_DIR, 'st2common/__init__.py') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") +INIT_FILE = os.path.join(BASE_DIR, "st2common/__init__.py") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -34,41 +34,43 @@ setup( name=ST2_COMPONENT, version=get_version_string(INIT_FILE), - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), + packages=find_packages(exclude=["setuptools", "tests"]), scripts=[ - 'bin/st2-bootstrap-rmq', - 'bin/st2-cleanup-db', - 'bin/st2-register-content', - 'bin/st2-purge-executions', - 'bin/st2-purge-trigger-instances', - 'bin/st2-run-pack-tests', - 'bin/st2ctl', - 'bin/st2-generate-symmetric-crypto-key', - 'bin/st2-self-check', - 'bin/st2-track-result', - 'bin/st2-validate-pack-config', - 'bin/st2-pack-install', - 'bin/st2-pack-download', - 'bin/st2-pack-setup-virtualenv' + "bin/st2-bootstrap-rmq", + "bin/st2-cleanup-db", + "bin/st2-register-content", + "bin/st2-purge-executions", + "bin/st2-purge-trigger-instances", + "bin/st2-run-pack-tests", + "bin/st2ctl", + "bin/st2-generate-symmetric-crypto-key", + "bin/st2-self-check", + "bin/st2-track-result", + "bin/st2-validate-pack-config", + "bin/st2-pack-install", + "bin/st2-pack-download", + "bin/st2-pack-setup-virtualenv", ], entry_points={ - 'st2common.metrics.driver': [ - 'statsd = st2common.metrics.drivers.statsd_driver:StatsdDriver', - 'noop = st2common.metrics.drivers.noop_driver:NoopDriver', - 'echo = st2common.metrics.drivers.echo_driver:EchoDriver' + "st2common.metrics.driver": [ + "statsd = st2common.metrics.drivers.statsd_driver:StatsdDriver", + "noop = st2common.metrics.drivers.noop_driver:NoopDriver", + "echo = st2common.metrics.drivers.echo_driver:EchoDriver", ], - 'st2common.rbac.backend': [ - 'noop = st2common.rbac.backends.noop:NoOpRBACBackend' + "st2common.rbac.backend": [ + "noop = st2common.rbac.backends.noop:NoOpRBACBackend" ], - } + }, ) diff --git a/st2common/st2common/__init__.py b/st2common/st2common/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/st2common/st2common/__init__.py +++ b/st2common/st2common/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2common/st2common/bootstrap/actionsregistrar.py b/st2common/st2common/bootstrap/actionsregistrar.py index c21788fb141..f5265bac482 100644 --- a/st2common/st2common/bootstrap/actionsregistrar.py +++ b/st2common/st2common/bootstrap/actionsregistrar.py @@ -30,10 +30,7 @@ import st2common.util.action_db as action_utils import st2common.validators.api.action as action_validator -__all__ = [ - 'ActionsRegistrar', - 'register_actions' -] +__all__ = ["ActionsRegistrar", "register_actions"] LOG = logging.getLogger(__name__) @@ -53,15 +50,18 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='actions') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="actions" + ) for pack, actions_dir in six.iteritems(content): if not actions_dir: - LOG.debug('Pack %s does not contain actions.', pack) + LOG.debug("Pack %s does not contain actions.", pack) continue try: - LOG.debug('Registering actions from pack %s:, dir: %s', pack, actions_dir) + LOG.debug( + "Registering actions from pack %s:, dir: %s", pack, actions_dir + ) actions = self._get_actions_from_pack(actions_dir) count = self._register_actions_from_pack(pack=pack, actions=actions) registered_count += count @@ -69,7 +69,9 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all actions from pack: %s', actions_dir) + LOG.exception( + "Failed registering all actions from pack: %s", actions_dir + ) return registered_count @@ -80,10 +82,11 @@ def register_from_pack(self, pack_dir): :return: Number of actions registered. :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - actions_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='actions') + actions_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="actions" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -92,16 +95,18 @@ def register_from_pack(self, pack_dir): if not actions_dir: return registered_count - LOG.debug('Registering actions from pack %s:, dir: %s', pack, actions_dir) + LOG.debug("Registering actions from pack %s:, dir: %s", pack, actions_dir) try: actions = self._get_actions_from_pack(actions_dir=actions_dir) - registered_count = self._register_actions_from_pack(pack=pack, actions=actions) + registered_count = self._register_actions_from_pack( + pack=pack, actions=actions + ) except Exception as e: if self._fail_on_failure: raise e - LOG.exception('Failed registering all actions from pack: %s', actions_dir) + LOG.exception("Failed registering all actions from pack: %s", actions_dir) return registered_count @@ -109,29 +114,33 @@ def _get_actions_from_pack(self, actions_dir): actions = self.get_resources_from_pack(resources_dir=actions_dir) # Exclude global actions configuration file - config_files = ['actions/config' + ext for ext in self.ALLOWED_EXTENSIONS] + config_files = ["actions/config" + ext for ext in self.ALLOWED_EXTENSIONS] for config_file in config_files: - actions = [file_path for file_path in actions if config_file not in file_path] + actions = [ + file_path for file_path in actions if config_file not in file_path + ] return actions def _register_action(self, pack, action): content = self._meta_loader.load(action) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) # Add in "metadata_file" attribute which stores path to the pack metadata file relative to # the pack directory - metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack, - file_path=action, - use_pack_cache=True) - content['metadata_file'] = metadata_file + metadata_file = content_utils.get_relative_path_to_pack_file( + pack_ref=pack, file_path=action, use_pack_cache=True + ) + content["metadata_file"] = metadata_file action_api = ActionAPI(**content) @@ -141,25 +150,29 @@ def _register_action(self, pack, action): # We throw a more user-friendly exception on invalid parameter name msg = six.text_type(e) - is_invalid_parameter_name = 'does not match any of the regexes: ' in msg + is_invalid_parameter_name = "does not match any of the regexes: " in msg if is_invalid_parameter_name: - match = re.search('\'(.+?)\' does not match any of the regexes', msg) + match = re.search("'(.+?)' does not match any of the regexes", msg) if match: parameter_name = match.groups()[0] else: - parameter_name = 'unknown' + parameter_name = "unknown" - new_msg = ('Parameter name "%s" is invalid. Valid characters for parameter name ' - 'are [a-zA-Z0-0_].' % (parameter_name)) - new_msg += '\n\n' + msg + new_msg = ( + 'Parameter name "%s" is invalid. Valid characters for parameter name ' + "are [a-zA-Z0-0_]." % (parameter_name) + ) + new_msg += "\n\n" + msg raise jsonschema.ValidationError(new_msg) raise e # Use in-memory cached RunnerTypeDB objects to reduce load on the database if self._use_runners_cache: - runner_type_db = self._runner_type_db_cache.get(action_api.runner_type, None) + runner_type_db = self._runner_type_db_cache.get( + action_api.runner_type, None + ) if not runner_type_db: runner_type_db = action_validator.get_runner_model(action_api) @@ -170,36 +183,47 @@ def _register_action(self, pack, action): action_validator.validate_action(action_api, runner_type_db=runner_type_db) model = ActionAPI.to_model(action_api) - action_ref = ResourceReference.to_string_reference(pack=pack, name=str(content['name'])) + action_ref = ResourceReference.to_string_reference( + pack=pack, name=str(content["name"]) + ) existing = action_utils.get_action_by_ref(action_ref) if not existing: - LOG.debug('Action %s not found. Creating new one with: %s', action_ref, content) + LOG.debug( + "Action %s not found. Creating new one with: %s", action_ref, content + ) else: - LOG.debug('Action %s found. Will be updated from: %s to: %s', - action_ref, existing, model) + LOG.debug( + "Action %s found. Will be updated from: %s to: %s", + action_ref, + existing, + model, + ) model.id = existing.id try: model = Action.add_or_update(model) - extra = {'action_db': model} - LOG.audit('Action updated. Action %s from %s.', model, action, extra=extra) + extra = {"action_db": model} + LOG.audit("Action updated. Action %s from %s.", model, action, extra=extra) except Exception: - LOG.exception('Failed to write action to db %s.', model.name) + LOG.exception("Failed to write action to db %s.", model.name) raise def _register_actions_from_pack(self, pack, actions): registered_count = 0 for action in actions: try: - LOG.debug('Loading action from %s.', action) + LOG.debug("Loading action from %s.", action) self._register_action(pack=pack, action=action) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register action "%s" from pack "%s": %s' % (action, pack, - six.text_type(e))) + msg = 'Failed to register action "%s" from pack "%s": %s' % ( + action, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.exception('Unable to register action: %s', action) + LOG.exception("Unable to register action: %s", action) continue else: registered_count += 1 @@ -207,16 +231,18 @@ def _register_actions_from_pack(self, pack, actions): return registered_count -def register_actions(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_actions( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = ActionsRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = ActionsRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/aliasesregistrar.py b/st2common/st2common/bootstrap/aliasesregistrar.py index dbc9c3b0fce..c9d4ef70173 100644 --- a/st2common/st2common/bootstrap/aliasesregistrar.py +++ b/st2common/st2common/bootstrap/aliasesregistrar.py @@ -27,10 +27,7 @@ from st2common.persistence.actionalias import ActionAlias from st2common.exceptions.db import StackStormDBObjectNotFoundError -__all__ = [ - 'AliasesRegistrar', - 'register_aliases' -] +__all__ = ["AliasesRegistrar", "register_aliases"] LOG = logging.getLogger(__name__) @@ -50,15 +47,18 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='aliases') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="aliases" + ) for pack, aliases_dir in six.iteritems(content): if not aliases_dir: - LOG.debug('Pack %s does not contain aliases.', pack) + LOG.debug("Pack %s does not contain aliases.", pack) continue try: - LOG.debug('Registering aliases from pack %s:, dir: %s', pack, aliases_dir) + LOG.debug( + "Registering aliases from pack %s:, dir: %s", pack, aliases_dir + ) aliases = self._get_aliases_from_pack(aliases_dir) count = self._register_aliases_from_pack(pack=pack, aliases=aliases) registered_count += count @@ -66,7 +66,9 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all aliases from pack: %s', aliases_dir) + LOG.exception( + "Failed registering all aliases from pack: %s", aliases_dir + ) return registered_count @@ -77,10 +79,11 @@ def register_from_pack(self, pack_dir): :return: Number of aliases registered. :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - aliases_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='aliases') + aliases_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="aliases" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -89,16 +92,18 @@ def register_from_pack(self, pack_dir): if not aliases_dir: return registered_count - LOG.debug('Registering aliases from pack %s:, dir: %s', pack, aliases_dir) + LOG.debug("Registering aliases from pack %s:, dir: %s", pack, aliases_dir) try: aliases = self._get_aliases_from_pack(aliases_dir=aliases_dir) - registered_count = self._register_aliases_from_pack(pack=pack, aliases=aliases) + registered_count = self._register_aliases_from_pack( + pack=pack, aliases=aliases + ) except Exception as e: if self._fail_on_failure: raise e - LOG.exception('Failed registering all aliases from pack: %s', aliases_dir) + LOG.exception("Failed registering all aliases from pack: %s", aliases_dir) return registered_count return registered_count @@ -106,7 +111,9 @@ def register_from_pack(self, pack_dir): def _get_aliases_from_pack(self, aliases_dir): return self.get_resources_from_pack(resources_dir=aliases_dir) - def _get_action_alias_db(self, pack, action_alias, ignore_metadata_file_error=False): + def _get_action_alias_db( + self, pack, action_alias, ignore_metadata_file_error=False + ): """ Retrieve ActionAliasDB object. @@ -115,25 +122,27 @@ def _get_action_alias_db(self, pack, action_alias, ignore_metadata_file_error=Fa :type ignore_metadata_file_error: ``bool`` """ content = self._meta_loader.load(action_alias) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) # Add in "metadata_file" attribute which stores path to the pack metadata file relative to # the pack directory try: - metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack, - file_path=action_alias, - use_pack_cache=True) + metadata_file = content_utils.get_relative_path_to_pack_file( + pack_ref=pack, file_path=action_alias, use_pack_cache=True + ) except ValueError as e: if not ignore_metadata_file_error: raise e else: - content['metadata_file'] = metadata_file + content["metadata_file"] = metadata_file action_alias_api = ActionAliasAPI(**content) action_alias_api.validate() @@ -142,28 +151,35 @@ def _get_action_alias_db(self, pack, action_alias, ignore_metadata_file_error=Fa return action_alias_db def _register_action_alias(self, pack, action_alias): - action_alias_db = self._get_action_alias_db(pack=pack, - action_alias=action_alias) + action_alias_db = self._get_action_alias_db( + pack=pack, action_alias=action_alias + ) try: action_alias_db.id = ActionAlias.get_by_name(action_alias_db.name).id except StackStormDBObjectNotFoundError: - LOG.debug('ActionAlias %s not found. Creating new one.', action_alias) + LOG.debug("ActionAlias %s not found. Creating new one.", action_alias) action_ref = action_alias_db.action_ref action_db = Action.get_by_ref(action_ref) if not action_db: - LOG.warning('Action %s not found in DB. Did you forget to register the action?', - action_ref) + LOG.warning( + "Action %s not found in DB. Did you forget to register the action?", + action_ref, + ) try: action_alias_db = ActionAlias.add_or_update(action_alias_db) - extra = {'action_alias_db': action_alias_db} - LOG.audit('Action alias updated. Action alias %s from %s.', action_alias_db, - action_alias, extra=extra) + extra = {"action_alias_db": action_alias_db} + LOG.audit( + "Action alias updated. Action alias %s from %s.", + action_alias_db, + action_alias, + extra=extra, + ) except Exception: - LOG.exception('Failed to create action alias %s.', action_alias_db.name) + LOG.exception("Failed to create action alias %s.", action_alias_db.name) raise def _register_aliases_from_pack(self, pack, aliases): @@ -171,15 +187,18 @@ def _register_aliases_from_pack(self, pack, aliases): for alias in aliases: try: - LOG.debug('Loading alias from %s.', alias) + LOG.debug("Loading alias from %s.", alias) self._register_action_alias(pack, alias) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register alias "%s" from pack "%s": %s' % (alias, pack, - six.text_type(e))) + msg = 'Failed to register alias "%s" from pack "%s": %s' % ( + alias, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.exception('Unable to register alias: %s', alias) + LOG.exception("Unable to register alias: %s", alias) continue else: registered_count += 1 @@ -187,8 +206,9 @@ def _register_aliases_from_pack(self, pack, aliases): return registered_count -def register_aliases(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_aliases( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) @@ -196,8 +216,9 @@ def register_aliases(packs_base_paths=None, pack_dir=None, use_pack_cache=True, if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = AliasesRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = AliasesRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/base.py b/st2common/st2common/bootstrap/base.py index 1757a2fa8e3..1070a3af383 100644 --- a/st2common/st2common/bootstrap/base.py +++ b/st2common/st2common/bootstrap/base.py @@ -32,9 +32,7 @@ from st2common.util.pack import get_pack_ref_from_metadata from st2common.exceptions.db import StackStormDBObjectNotFoundError -__all__ = [ - 'ResourceRegistrar' -] +__all__ = ["ResourceRegistrar"] LOG = logging.getLogger(__name__) @@ -44,16 +42,15 @@ # a long running process. REGISTERED_PACKS_CACHE = {} -EXCLUDE_FILE_PATTERNS = [ - '*.pyc', - '.git/*' -] +EXCLUDE_FILE_PATTERNS = ["*.pyc", ".git/*"] class ResourceRegistrar(object): ALLOWED_EXTENSIONS = [] - def __init__(self, use_pack_cache=True, use_runners_cache=False, fail_on_failure=False): + def __init__( + self, use_pack_cache=True, use_runners_cache=False, fail_on_failure=False + ): """ :param use_pack_cache: True to cache which packs have been registered in memory and making sure packs are only registered once. @@ -81,10 +78,10 @@ def get_resources_from_pack(self, resources_dir): for ext in self.ALLOWED_EXTENSIONS: resources_glob = resources_dir - if resources_dir.endswith('/'): + if resources_dir.endswith("/"): resources_glob = resources_dir + ext else: - resources_glob = resources_dir + '/*' + ext + resources_glob = resources_dir + "/*" + ext resource_files = glob.glob(resources_glob) resources.extend(resource_files) @@ -121,7 +118,7 @@ def register_pack(self, pack_name, pack_dir): # This pack has already been registered during this register content run return - LOG.debug('Registering pack: %s' % (pack_name)) + LOG.debug("Registering pack: %s" % (pack_name)) REGISTERED_PACKS_CACHE[pack_name] = True try: @@ -148,19 +145,26 @@ def _register_pack(self, pack_name, pack_dir): # Display a warning if pack contains deprecated config.yaml file. Support for those files # will be fully removed in v2.4.0. - config_path = os.path.join(pack_dir, 'config.yaml') + config_path = os.path.join(pack_dir, "config.yaml") if os.path.isfile(config_path): - LOG.error('Pack "%s" contains a deprecated config.yaml file (%s). ' - 'Support for "config.yaml" files has been deprecated in StackStorm v1.6.0 ' - 'in favor of config.schema.yaml config schema files and config files in ' - '/opt/stackstorm/configs/ directory. Support for config.yaml files has ' - 'been removed in the release (v2.4.0) so please migrate. For more ' - 'information please refer to %s ' % (pack_db.name, config_path, - 'https://docs.stackstorm.com/reference/pack_configs.html')) + LOG.error( + 'Pack "%s" contains a deprecated config.yaml file (%s). ' + 'Support for "config.yaml" files has been deprecated in StackStorm v1.6.0 ' + "in favor of config.schema.yaml config schema files and config files in " + "/opt/stackstorm/configs/ directory. Support for config.yaml files has " + "been removed in the release (v2.4.0) so please migrate. For more " + "information please refer to %s " + % ( + pack_db.name, + config_path, + "https://docs.stackstorm.com/reference/pack_configs.html", + ) + ) # 2. Register corresponding pack config schema - config_schema_db = self._register_pack_config_schema_db(pack_name=pack_name, - pack_dir=pack_dir) + config_schema_db = self._register_pack_config_schema_db( + pack_name=pack_name, pack_dir=pack_dir + ) return pack_db, config_schema_db @@ -173,25 +177,28 @@ def _register_pack_db(self, pack_name, pack_dir): # 2hich are in sub-directories) # 2. If attribute is not available, but pack name is and pack name meets the valid name # criteria, we use that - content['ref'] = get_pack_ref_from_metadata(metadata=content, - pack_directory_name=pack_name) + content["ref"] = get_pack_ref_from_metadata( + metadata=content, pack_directory_name=pack_name + ) # Include a list of pack files - pack_file_list = get_file_list(directory=pack_dir, exclude_patterns=EXCLUDE_FILE_PATTERNS) - content['files'] = pack_file_list - content['path'] = pack_dir + pack_file_list = get_file_list( + directory=pack_dir, exclude_patterns=EXCLUDE_FILE_PATTERNS + ) + content["files"] = pack_file_list + content["path"] = pack_dir pack_api = PackAPI(**content) pack_api.validate() pack_db = PackAPI.to_model(pack_api) try: - pack_db.id = Pack.get_by_ref(content['ref']).id + pack_db.id = Pack.get_by_ref(content["ref"]).id except StackStormDBObjectNotFoundError: - LOG.debug('Pack %s not found. Creating new one.', pack_name) + LOG.debug("Pack %s not found. Creating new one.", pack_name) pack_db = Pack.add_or_update(pack_db) - LOG.debug('Pack %s registered.' % (pack_name)) + LOG.debug("Pack %s registered." % (pack_name)) return pack_db def _register_pack_config_schema_db(self, pack_name, pack_dir): @@ -204,11 +211,13 @@ def _register_pack_config_schema_db(self, pack_name, pack_dir): values = self._meta_loader.load(config_schema_path) if not values: - raise ValueError('Config schema "%s" is empty and invalid.' % (config_schema_path)) + raise ValueError( + 'Config schema "%s" is empty and invalid.' % (config_schema_path) + ) content = {} - content['pack'] = pack_name - content['attributes'] = values + content["pack"] = pack_name + content["attributes"] = values config_schema_api = ConfigSchemaAPI(**content) config_schema_api = config_schema_api.validate() @@ -217,8 +226,10 @@ def _register_pack_config_schema_db(self, pack_name, pack_dir): try: config_schema_db.id = ConfigSchema.get_by_pack(pack_name).id except StackStormDBObjectNotFoundError: - LOG.debug('Config schema for pack %s not found. Creating new one.', pack_name) + LOG.debug( + "Config schema for pack %s not found. Creating new one.", pack_name + ) config_schema_db = ConfigSchema.add_or_update(config_schema_db) - LOG.debug('Config schema for pack %s registered.' % (pack_name)) + LOG.debug("Config schema for pack %s registered." % (pack_name)) return config_schema_db diff --git a/st2common/st2common/bootstrap/configsregistrar.py b/st2common/st2common/bootstrap/configsregistrar.py index fc7e05eb986..3cbc5283fcd 100644 --- a/st2common/st2common/bootstrap/configsregistrar.py +++ b/st2common/st2common/bootstrap/configsregistrar.py @@ -28,9 +28,7 @@ from st2common.persistence.pack import Config from st2common.exceptions.db import StackStormDBObjectNotFoundError -__all__ = [ - 'ConfigsRegistrar' -] +__all__ = ["ConfigsRegistrar"] LOG = logging.getLogger(__name__) @@ -44,11 +42,18 @@ class ConfigsRegistrar(ResourceRegistrar): ALLOWED_EXTENSIONS = ALLOWED_EXTS - def __init__(self, use_pack_cache=True, use_runners_cache=False, fail_on_failure=False, - validate_configs=True): - super(ConfigsRegistrar, self).__init__(use_pack_cache=use_pack_cache, - use_runners_cache=use_runners_cache, - fail_on_failure=fail_on_failure) + def __init__( + self, + use_pack_cache=True, + use_runners_cache=False, + fail_on_failure=False, + validate_configs=True, + ): + super(ConfigsRegistrar, self).__init__( + use_pack_cache=use_pack_cache, + use_runners_cache=use_runners_cache, + fail_on_failure=fail_on_failure, + ) self._validate_configs = validate_configs @@ -68,21 +73,29 @@ def register_from_packs(self, base_dirs): if not os.path.isfile(config_path): # Config for that pack doesn't exist - LOG.debug('No config found for pack "%s" (file "%s" is not present).', pack_name, - config_path) + LOG.debug( + 'No config found for pack "%s" (file "%s" is not present).', + pack_name, + config_path, + ) continue try: self._register_config_for_pack(pack=pack_name, config_path=config_path) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register config "%s" for pack "%s": %s' % (config_path, - pack_name, - six.text_type(e))) + msg = 'Failed to register config "%s" for pack "%s": %s' % ( + config_path, + pack_name, + six.text_type(e), + ) raise ValueError(msg) - LOG.exception('Failed to register config for pack "%s": %s', pack_name, - six.text_type(e)) + LOG.exception( + 'Failed to register config for pack "%s": %s', + pack_name, + six.text_type(e), + ) else: registered_count += 1 @@ -92,7 +105,7 @@ def register_from_pack(self, pack_dir): """ Register config for a provided pack. """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack_name = os.path.split(pack_dir) # Register pack first @@ -106,8 +119,8 @@ def register_from_pack(self, pack_dir): return 1 def _get_config_path_for_pack(self, pack_name): - configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/') - config_path = os.path.join(configs_path, '%s.yaml' % (pack_name)) + configs_path = os.path.join(cfg.CONF.system.base_path, "configs/") + config_path = os.path.join(configs_path, "%s.yaml" % (pack_name)) return config_path @@ -115,8 +128,8 @@ def _register_config_for_pack(self, pack, config_path): content = {} values = self._meta_loader.load(config_path) - content['pack'] = pack - content['values'] = values + content["pack"] = pack + content["values"] = values config_api = ConfigAPI(**content) config_api.validate(validate_against_schema=self._validate_configs) @@ -136,17 +149,22 @@ def save_model(config_api): try: config_db = Config.add_or_update(config_db) - extra = {'config_db': config_db} + extra = {"config_db": config_db} LOG.audit('Config for pack "%s" is updated.', config_db.pack, extra=extra) except Exception: - LOG.exception('Failed to save config for pack %s.', pack) + LOG.exception("Failed to save config for pack %s.", pack) raise return config_db -def register_configs(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False, validate_configs=True): +def register_configs( + packs_base_paths=None, + pack_dir=None, + use_pack_cache=True, + fail_on_failure=False, + validate_configs=True, +): if packs_base_paths: assert isinstance(packs_base_paths, list) @@ -154,9 +172,11 @@ def register_configs(packs_base_paths=None, pack_dir=None, use_pack_cache=True, if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = ConfigsRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure, - validate_configs=validate_configs) + registrar = ConfigsRegistrar( + use_pack_cache=use_pack_cache, + fail_on_failure=fail_on_failure, + validate_configs=validate_configs, + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/policiesregistrar.py b/st2common/st2common/bootstrap/policiesregistrar.py index b963eaf0979..4f6f2476946 100644 --- a/st2common/st2common/bootstrap/policiesregistrar.py +++ b/st2common/st2common/bootstrap/policiesregistrar.py @@ -30,11 +30,7 @@ from st2common.util import loader -__all__ = [ - 'PolicyRegistrar', - 'register_policy_types', - 'register_policies' -] +__all__ = ["PolicyRegistrar", "register_policy_types", "register_policies"] LOG = logging.getLogger(__name__) @@ -55,15 +51,18 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='policies') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="policies" + ) for pack, policies_dir in six.iteritems(content): if not policies_dir: - LOG.debug('Pack %s does not contain policies.', pack) + LOG.debug("Pack %s does not contain policies.", pack) continue try: - LOG.debug('Registering policies from pack %s:, dir: %s', pack, policies_dir) + LOG.debug( + "Registering policies from pack %s:, dir: %s", pack, policies_dir + ) policies = self._get_policies_from_pack(policies_dir) count = self._register_policies_from_pack(pack=pack, policies=policies) registered_count += count @@ -71,7 +70,9 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all policies from pack: %s', policies_dir) + LOG.exception( + "Failed registering all policies from pack: %s", policies_dir + ) return registered_count @@ -82,11 +83,12 @@ def register_from_pack(self, pack_dir): :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - policies_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='policies') + policies_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="policies" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -95,16 +97,18 @@ def register_from_pack(self, pack_dir): if not policies_dir: return registered_count - LOG.debug('Registering policies from pack %s, dir: %s', pack, policies_dir) + LOG.debug("Registering policies from pack %s, dir: %s", pack, policies_dir) try: policies = self._get_policies_from_pack(policies_dir=policies_dir) - registered_count = self._register_policies_from_pack(pack=pack, policies=policies) + registered_count = self._register_policies_from_pack( + pack=pack, policies=policies + ) except Exception as e: if self._fail_on_failure: raise e - LOG.exception('Failed registering all policies from pack: %s', policies_dir) + LOG.exception("Failed registering all policies from pack: %s", policies_dir) return registered_count return registered_count @@ -117,15 +121,18 @@ def _register_policies_from_pack(self, pack, policies): for policy in policies: try: - LOG.debug('Loading policy from %s.', policy) + LOG.debug("Loading policy from %s.", policy) self._register_policy(pack=pack, policy=policy) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register policy "%s" from pack "%s": %s' % (policy, pack, - six.text_type(e))) + msg = 'Failed to register policy "%s" from pack "%s": %s' % ( + policy, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.exception('Unable to register policy: %s', policy) + LOG.exception("Unable to register policy: %s", policy) continue else: registered_count += 1 @@ -134,20 +141,22 @@ def _register_policies_from_pack(self, pack, policies): def _register_policy(self, pack, policy): content = self._meta_loader.load(policy) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) # Add in "metadata_file" attribute which stores path to the pack metadata file relative to # the pack directory - metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack, - file_path=policy, - use_pack_cache=True) - content['metadata_file'] = metadata_file + metadata_file = content_utils.get_relative_path_to_pack_file( + pack_ref=pack, file_path=policy, use_pack_cache=True + ) + content["metadata_file"] = metadata_file policy_api = PolicyAPI(**content) policy_api = policy_api.validate() @@ -160,21 +169,21 @@ def _register_policy(self, pack, policy): try: policy_db = Policy.add_or_update(policy_db) - extra = {'policy_db': policy_db} + extra = {"policy_db": policy_db} LOG.audit('Policy "%s" is updated.', policy_db.ref, extra=extra) except Exception: - LOG.exception('Failed to create policy %s.', policy_api.name) + LOG.exception("Failed to create policy %s.", policy_api.name) raise def register_policy_types(module): registered_count = 0 mod_path = os.path.dirname(os.path.realpath(sys.modules[module.__name__].__file__)) - path = os.path.join(mod_path, 'policies/meta') + path = os.path.join(mod_path, "policies/meta") files = [] for ext in ALLOWED_EXTS: - exp = '%s/*%s' % (path, ext) + exp = "%s/*%s" % (path, ext) files += glob.glob(exp) for f in files: @@ -189,11 +198,13 @@ def register_policy_types(module): if existing_entry: policy_type_db.id = existing_entry.id except StackStormDBObjectNotFoundError: - LOG.debug('Policy type "%s" is not found. Creating new entry.', - policy_type_db.ref) + LOG.debug( + 'Policy type "%s" is not found. Creating new entry.', + policy_type_db.ref, + ) policy_type_db = PolicyType.add_or_update(policy_type_db) - extra = {'policy_type_db': policy_type_db} + extra = {"policy_type_db": policy_type_db} LOG.audit('Policy type "%s" is updated.', policy_type_db.ref, extra=extra) registered_count += 1 @@ -203,16 +214,18 @@ def register_policy_types(module): return registered_count -def register_policies(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_policies( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = PolicyRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = PolicyRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/rulesregistrar.py b/st2common/st2common/bootstrap/rulesregistrar.py index c50b0d5eae4..505f3e53374 100644 --- a/st2common/st2common/bootstrap/rulesregistrar.py +++ b/st2common/st2common/bootstrap/rulesregistrar.py @@ -25,14 +25,14 @@ from st2common.models.api.rule import RuleAPI from st2common.models.system.common import ResourceReference from st2common.persistence.rule import Rule -from st2common.services.triggers import cleanup_trigger_db_for_rule, increment_trigger_ref_count +from st2common.services.triggers import ( + cleanup_trigger_db_for_rule, + increment_trigger_ref_count, +) from st2common.exceptions.db import StackStormDBObjectNotFoundError import st2common.content.utils as content_utils -__all__ = [ - 'RulesRegistrar', - 'register_rules' -] +__all__ = ["RulesRegistrar", "register_rules"] LOG = logging.getLogger(__name__) @@ -49,14 +49,15 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='rules') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="rules" + ) for pack, rules_dir in six.iteritems(content): if not rules_dir: - LOG.debug('Pack %s does not contain rules.', pack) + LOG.debug("Pack %s does not contain rules.", pack) continue try: - LOG.debug('Registering rules from pack: %s', pack) + LOG.debug("Registering rules from pack: %s", pack) rules = self._get_rules_from_pack(rules_dir) count = self._register_rules_from_pack(pack, rules) registered_count += count @@ -64,7 +65,7 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all rules from pack: %s', rules_dir) + LOG.exception("Failed registering all rules from pack: %s", rules_dir) return registered_count @@ -75,10 +76,11 @@ def register_from_pack(self, pack_dir): :return: Number of rules registered. :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - rules_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='rules') + rules_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="rules" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -87,7 +89,7 @@ def register_from_pack(self, pack_dir): if not rules_dir: return registered_count - LOG.debug('Registering rules from pack %s:, dir: %s', pack, rules_dir) + LOG.debug("Registering rules from pack %s:, dir: %s", pack, rules_dir) try: rules = self._get_rules_from_pack(rules_dir=rules_dir) @@ -96,7 +98,7 @@ def register_from_pack(self, pack_dir): if self._fail_on_failure: raise e - LOG.exception('Failed registering all rules from pack: %s', rules_dir) + LOG.exception("Failed registering all rules from pack: %s", rules_dir) return registered_count @@ -108,21 +110,23 @@ def _register_rules_from_pack(self, pack, rules): # TODO: Refactor this monstrosity for rule in rules: - LOG.debug('Loading rule from %s.', rule) + LOG.debug("Loading rule from %s.", rule) try: content = self._meta_loader.load(rule) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) - metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack, - file_path=rule, - use_pack_cache=True) - content['metadata_file'] = metadata_file + metadata_file = content_utils.get_relative_path_to_pack_file( + pack_ref=pack, file_path=rule, use_pack_cache=True + ) + content["metadata_file"] = metadata_file rule_api = RuleAPI(**content) rule_api.validate() @@ -134,35 +138,48 @@ def _register_rules_from_pack(self, pack, rules): # delete so we don't have duplicates. if pack_field != DEFAULT_PACK_NAME: try: - rule_ref = ResourceReference.to_string_reference(name=content['name'], - pack=DEFAULT_PACK_NAME) - LOG.debug('Looking for rule %s in pack %s', content['name'], - DEFAULT_PACK_NAME) + rule_ref = ResourceReference.to_string_reference( + name=content["name"], pack=DEFAULT_PACK_NAME + ) + LOG.debug( + "Looking for rule %s in pack %s", + content["name"], + DEFAULT_PACK_NAME, + ) existing = Rule.get_by_ref(rule_ref) - LOG.debug('Existing = %s', existing) + LOG.debug("Existing = %s", existing) if existing: - LOG.debug('Found rule in pack default: %s; Deleting.', rule_ref) + LOG.debug( + "Found rule in pack default: %s; Deleting.", rule_ref + ) Rule.delete(existing) except: - LOG.exception('Exception deleting rule from %s pack.', DEFAULT_PACK_NAME) + LOG.exception( + "Exception deleting rule from %s pack.", DEFAULT_PACK_NAME + ) try: - rule_ref = ResourceReference.to_string_reference(name=content['name'], - pack=content['pack']) + rule_ref = ResourceReference.to_string_reference( + name=content["name"], pack=content["pack"] + ) existing = Rule.get_by_ref(rule_ref) if existing: rule_db.id = existing.id - LOG.debug('Found existing rule: %s with id: %s', rule_ref, existing.id) + LOG.debug( + "Found existing rule: %s with id: %s", rule_ref, existing.id + ) except StackStormDBObjectNotFoundError: - LOG.debug('Rule %s not found. Creating new one.', rule) + LOG.debug("Rule %s not found. Creating new one.", rule) try: rule_db = Rule.add_or_update(rule_db) increment_trigger_ref_count(rule_api=rule_api) - extra = {'rule_db': rule_db} - LOG.audit('Rule updated. Rule %s from %s.', rule_db, rule, extra=extra) + extra = {"rule_db": rule_db} + LOG.audit( + "Rule updated. Rule %s from %s.", rule_db, rule, extra=extra + ) except Exception: - LOG.exception('Failed to create rule %s.', rule_api.name) + LOG.exception("Failed to create rule %s.", rule_api.name) # If there was an existing rule then the ref count was updated in # to_model so it needs to be adjusted down here. Also, update could @@ -171,27 +188,32 @@ def _register_rules_from_pack(self, pack, rules): cleanup_trigger_db_for_rule(existing) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register rule "%s" from pack "%s": %s' % (rule, pack, - six.text_type(e))) + msg = 'Failed to register rule "%s" from pack "%s": %s' % ( + rule, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.exception('Failed registering rule from %s.', rule) + LOG.exception("Failed registering rule from %s.", rule) else: registered_count += 1 return registered_count -def register_rules(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_rules( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = RulesRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = RulesRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/ruletypesregistrar.py b/st2common/st2common/bootstrap/ruletypesregistrar.py index 735294cd23b..90d4018a40b 100644 --- a/st2common/st2common/bootstrap/ruletypesregistrar.py +++ b/st2common/st2common/bootstrap/ruletypesregistrar.py @@ -22,41 +22,36 @@ from st2common.persistence.rule import RuleType from st2common.exceptions.db import StackStormDBObjectNotFoundError -__all__ = [ - 'register_rule_types', - 'RULE_TYPES' -] +__all__ = ["register_rule_types", "RULE_TYPES"] LOG = logging.getLogger(__name__) RULE_TYPES = [ { - 'name': RULE_TYPE_STANDARD, - 'description': 'standard rule that is always applicable.', - 'enabled': True, - 'parameters': { - } + "name": RULE_TYPE_STANDARD, + "description": "standard rule that is always applicable.", + "enabled": True, + "parameters": {}, }, { - 'name': RULE_TYPE_BACKSTOP, - 'description': 'Rule that applies when no other rule has matched for a specific Trigger.', - 'enabled': True, - 'parameters': { - } + "name": RULE_TYPE_BACKSTOP, + "description": "Rule that applies when no other rule has matched for a specific Trigger.", + "enabled": True, + "parameters": {}, }, ] def register_rule_types(): - LOG.debug('Start : register default RuleTypes.') + LOG.debug("Start : register default RuleTypes.") registered_count = 0 for rule_type in RULE_TYPES: rule_type = copy.deepcopy(rule_type) try: - rule_type_db = RuleType.get_by_name(rule_type['name']) + rule_type_db = RuleType.get_by_name(rule_type["name"]) update = True except StackStormDBObjectNotFoundError: rule_type_db = None @@ -72,16 +67,16 @@ def register_rule_types(): try: rule_type_db = RuleType.add_or_update(rule_type_model) - extra = {'rule_type_db': rule_type_db} + extra = {"rule_type_db": rule_type_db} if update: - LOG.audit('RuleType updated. RuleType %s', rule_type_db, extra=extra) + LOG.audit("RuleType updated. RuleType %s", rule_type_db, extra=extra) else: - LOG.audit('RuleType created. RuleType %s', rule_type_db, extra=extra) + LOG.audit("RuleType created. RuleType %s", rule_type_db, extra=extra) except Exception: - LOG.exception('Unable to register RuleType %s.', rule_type['name']) + LOG.exception("Unable to register RuleType %s.", rule_type["name"]) else: registered_count += 1 - LOG.debug('End : register default RuleTypes.') + LOG.debug("End : register default RuleTypes.") return registered_count diff --git a/st2common/st2common/bootstrap/runnersregistrar.py b/st2common/st2common/bootstrap/runnersregistrar.py index 3aa93da9b1f..bb993894330 100644 --- a/st2common/st2common/bootstrap/runnersregistrar.py +++ b/st2common/st2common/bootstrap/runnersregistrar.py @@ -26,7 +26,7 @@ from st2common.util.action_db import get_runnertype_by_name __all__ = [ - 'register_runner_types', + "register_runner_types", ] @@ -37,7 +37,7 @@ def register_runners(experimental=False, fail_on_failure=True): """ Register runners """ - LOG.debug('Start : register runners') + LOG.debug("Start : register runners") runner_count = 0 manager = ExtensionManager(namespace=RUNNERS_NAMESPACE, invoke_on_load=False) @@ -46,28 +46,30 @@ def register_runners(experimental=False, fail_on_failure=True): for name in extension_names: LOG.debug('Found runner "%s"' % (name)) - manager = DriverManager(namespace=RUNNERS_NAMESPACE, invoke_on_load=False, name=name) + manager = DriverManager( + namespace=RUNNERS_NAMESPACE, invoke_on_load=False, name=name + ) runner_metadata = manager.driver.get_metadata() runner_count += register_runner(runner_metadata, experimental) - LOG.debug('End : register runners') + LOG.debug("End : register runners") return runner_count def register_runner(runner_type, experimental): # For backward compatibility reasons, we also register runners under the old names - runner_names = [runner_type['name']] + runner_type.get('aliases', []) + runner_names = [runner_type["name"]] + runner_type.get("aliases", []) for runner_name in runner_names: - runner_type['name'] = runner_name - runner_experimental = runner_type.get('experimental', False) + runner_type["name"] = runner_name + runner_experimental = runner_type.get("experimental", False) if runner_experimental and not experimental: LOG.debug('Skipping experimental runner "%s"' % (runner_name)) continue # Remove additional, non db-model attributes - non_db_attributes = ['experimental', 'aliases'] + non_db_attributes = ["experimental", "aliases"] for attribute in non_db_attributes: if attribute in runner_type: del runner_type[attribute] @@ -81,13 +83,13 @@ def register_runner(runner_type, experimental): # Note: We don't want to overwrite "enabled" attribute which is already in the database # (aka we don't want to re-enable runner which has been disabled by the user) - if runner_type_db and runner_type_db['enabled'] != runner_type['enabled']: - runner_type['enabled'] = runner_type_db['enabled'] + if runner_type_db and runner_type_db["enabled"] != runner_type["enabled"]: + runner_type["enabled"] = runner_type_db["enabled"] # If package is not provided, assume it's the same as module name for backward # compatibility reasons - if not runner_type.get('runner_package', None): - runner_type['runner_package'] = runner_type['runner_module'] + if not runner_type.get("runner_package", None): + runner_type["runner_package"] = runner_type["runner_module"] runner_type_api = RunnerTypeAPI(**runner_type) runner_type_api.validate() @@ -100,13 +102,17 @@ def register_runner(runner_type, experimental): runner_type_db = RunnerType.add_or_update(runner_type_model) - extra = {'runner_type_db': runner_type_db} + extra = {"runner_type_db": runner_type_db} if update: - LOG.audit('RunnerType updated. RunnerType %s', runner_type_db, extra=extra) + LOG.audit( + "RunnerType updated. RunnerType %s", runner_type_db, extra=extra + ) else: - LOG.audit('RunnerType created. RunnerType %s', runner_type_db, extra=extra) + LOG.audit( + "RunnerType created. RunnerType %s", runner_type_db, extra=extra + ) except Exception: - LOG.exception('Unable to register runner type %s.', runner_type['name']) + LOG.exception("Unable to register runner type %s.", runner_type["name"]) return 0 return 1 diff --git a/st2common/st2common/bootstrap/sensorsregistrar.py b/st2common/st2common/bootstrap/sensorsregistrar.py index 5181270d793..8a91e23eeaf 100644 --- a/st2common/st2common/bootstrap/sensorsregistrar.py +++ b/st2common/st2common/bootstrap/sensorsregistrar.py @@ -26,10 +26,7 @@ from st2common.models.api.sensor import SensorTypeAPI from st2common.persistence.sensor import SensorType -__all__ = [ - 'SensorsRegistrar', - 'register_sensors' -] +__all__ = ["SensorsRegistrar", "register_sensors"] LOG = logging.getLogger(__name__) @@ -51,15 +48,18 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='sensors') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="sensors" + ) for pack, sensors_dir in six.iteritems(content): if not sensors_dir: - LOG.debug('Pack %s does not contain sensors.', pack) + LOG.debug("Pack %s does not contain sensors.", pack) continue try: - LOG.debug('Registering sensors from pack %s:, dir: %s', pack, sensors_dir) + LOG.debug( + "Registering sensors from pack %s:, dir: %s", pack, sensors_dir + ) sensors = self._get_sensors_from_pack(sensors_dir) count = self._register_sensors_from_pack(pack=pack, sensors=sensors) registered_count += count @@ -67,8 +67,11 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all sensors from pack "%s": %s', sensors_dir, - six.text_type(e)) + LOG.exception( + 'Failed registering all sensors from pack "%s": %s', + sensors_dir, + six.text_type(e), + ) return registered_count @@ -79,10 +82,11 @@ def register_from_pack(self, pack_dir): :return: Number of sensors registered. :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - sensors_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='sensors') + sensors_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="sensors" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -91,17 +95,22 @@ def register_from_pack(self, pack_dir): if not sensors_dir: return registered_count - LOG.debug('Registering sensors from pack %s:, dir: %s', pack, sensors_dir) + LOG.debug("Registering sensors from pack %s:, dir: %s", pack, sensors_dir) try: sensors = self._get_sensors_from_pack(sensors_dir=sensors_dir) - registered_count = self._register_sensors_from_pack(pack=pack, sensors=sensors) + registered_count = self._register_sensors_from_pack( + pack=pack, sensors=sensors + ) except Exception as e: if self._fail_on_failure: raise e - LOG.exception('Failed registering all sensors from pack "%s": %s', sensors_dir, - six.text_type(e)) + LOG.exception( + 'Failed registering all sensors from pack "%s": %s', + sensors_dir, + six.text_type(e), + ) return registered_count @@ -115,11 +124,16 @@ def _register_sensors_from_pack(self, pack, sensors): self._register_sensor_from_pack(pack=pack, sensor=sensor) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register sensor "%s" from pack "%s": %s' % (sensor, pack, - six.text_type(e))) + msg = 'Failed to register sensor "%s" from pack "%s": %s' % ( + sensor, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.debug('Failed to register sensor "%s": %s', sensor, six.text_type(e)) + LOG.debug( + 'Failed to register sensor "%s": %s', sensor, six.text_type(e) + ) else: LOG.debug('Sensor "%s" successfully registered', sensor) registered_count += 1 @@ -129,33 +143,35 @@ def _register_sensors_from_pack(self, pack, sensors): def _register_sensor_from_pack(self, pack, sensor): sensor_metadata_file_path = sensor - LOG.debug('Loading sensor from %s.', sensor_metadata_file_path) + LOG.debug("Loading sensor from %s.", sensor_metadata_file_path) content = self._meta_loader.load(file_path=sensor_metadata_file_path) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) - entry_point = content.get('entry_point', None) + entry_point = content.get("entry_point", None) if not entry_point: - raise ValueError('Sensor definition missing entry_point') + raise ValueError("Sensor definition missing entry_point") # Add in "metadata_file" attribute which stores path to the pack metadata file relative to # the pack directory - metadata_file = content_utils.get_relative_path_to_pack_file(pack_ref=pack, - file_path=sensor, - use_pack_cache=True) - content['metadata_file'] = metadata_file + metadata_file = content_utils.get_relative_path_to_pack_file( + pack_ref=pack, file_path=sensor, use_pack_cache=True + ) + content["metadata_file"] = metadata_file sensors_dir = os.path.dirname(sensor_metadata_file_path) sensor_file_path = os.path.join(sensors_dir, entry_point) - artifact_uri = 'file://%s' % (sensor_file_path) - content['artifact_uri'] = artifact_uri - content['entry_point'] = entry_point + artifact_uri = "file://%s" % (sensor_file_path) + content["artifact_uri"] = artifact_uri + content["entry_point"] = entry_point sensor_api = SensorTypeAPI(**content) sensor_model = SensorTypeAPI.to_model(sensor_api) @@ -163,28 +179,33 @@ def _register_sensor_from_pack(self, pack, sensor): sensor_types = SensorType.query(pack=sensor_model.pack, name=sensor_model.name) if len(sensor_types) >= 1: sensor_type = sensor_types[0] - LOG.debug('Found existing sensor id:%s with name:%s. Will update it.', - sensor_type.id, sensor_type.name) + LOG.debug( + "Found existing sensor id:%s with name:%s. Will update it.", + sensor_type.id, + sensor_type.name, + ) sensor_model.id = sensor_type.id try: sensor_model = SensorType.add_or_update(sensor_model) except: - LOG.exception('Failed creating sensor model for %s', sensor) + LOG.exception("Failed creating sensor model for %s", sensor) return sensor_model -def register_sensors(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_sensors( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = SensorsRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = SensorsRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/bootstrap/triggersregistrar.py b/st2common/st2common/bootstrap/triggersregistrar.py index 4f95a6d0a3f..180c9cb8856 100644 --- a/st2common/st2common/bootstrap/triggersregistrar.py +++ b/st2common/st2common/bootstrap/triggersregistrar.py @@ -24,10 +24,7 @@ import st2common.content.utils as content_utils from st2common.models.utils import sensor_type_utils -__all__ = [ - 'TriggersRegistrar', - 'register_triggers' -] +__all__ = ["TriggersRegistrar", "register_triggers"] LOG = logging.getLogger(__name__) @@ -47,15 +44,18 @@ def register_from_packs(self, base_dirs): self.register_packs(base_dirs=base_dirs) registered_count = 0 - content = self._pack_loader.get_content(base_dirs=base_dirs, - content_type='triggers') + content = self._pack_loader.get_content( + base_dirs=base_dirs, content_type="triggers" + ) for pack, triggers_dir in six.iteritems(content): if not triggers_dir: - LOG.debug('Pack %s does not contain triggers.', pack) + LOG.debug("Pack %s does not contain triggers.", pack) continue try: - LOG.debug('Registering triggers from pack %s:, dir: %s', pack, triggers_dir) + LOG.debug( + "Registering triggers from pack %s:, dir: %s", pack, triggers_dir + ) triggers = self._get_triggers_from_pack(triggers_dir) count = self._register_triggers_from_pack(pack=pack, triggers=triggers) registered_count += count @@ -63,8 +63,11 @@ def register_from_packs(self, base_dirs): if self._fail_on_failure: raise e - LOG.exception('Failed registering all triggers from pack "%s": %s', triggers_dir, - six.text_type(e)) + LOG.exception( + 'Failed registering all triggers from pack "%s": %s', + triggers_dir, + six.text_type(e), + ) return registered_count @@ -75,10 +78,11 @@ def register_from_pack(self, pack_dir): :return: Number of triggers registered. :rtype: ``int`` """ - pack_dir = pack_dir[:-1] if pack_dir.endswith('/') else pack_dir + pack_dir = pack_dir[:-1] if pack_dir.endswith("/") else pack_dir _, pack = os.path.split(pack_dir) - triggers_dir = self._pack_loader.get_content_from_pack(pack_dir=pack_dir, - content_type='triggers') + triggers_dir = self._pack_loader.get_content_from_pack( + pack_dir=pack_dir, content_type="triggers" + ) # Register pack first self.register_pack(pack_name=pack, pack_dir=pack_dir) @@ -87,17 +91,22 @@ def register_from_pack(self, pack_dir): if not triggers_dir: return registered_count - LOG.debug('Registering triggers from pack %s:, dir: %s', pack, triggers_dir) + LOG.debug("Registering triggers from pack %s:, dir: %s", pack, triggers_dir) try: triggers = self._get_triggers_from_pack(triggers_dir=triggers_dir) - registered_count = self._register_triggers_from_pack(pack=pack, triggers=triggers) + registered_count = self._register_triggers_from_pack( + pack=pack, triggers=triggers + ) except Exception as e: if self._fail_on_failure: raise e - LOG.exception('Failed registering all triggers from pack "%s": %s', triggers_dir, - six.text_type(e)) + LOG.exception( + 'Failed registering all triggers from pack "%s": %s', + triggers_dir, + six.text_type(e), + ) return registered_count @@ -107,20 +116,27 @@ def _get_triggers_from_pack(self, triggers_dir): def _register_triggers_from_pack(self, pack, triggers): registered_count = 0 - pack_base_path = content_utils.get_pack_base_path(pack_name=pack, - include_trailing_slash=True) + pack_base_path = content_utils.get_pack_base_path( + pack_name=pack, include_trailing_slash=True + ) for trigger in triggers: try: - self._register_trigger_from_pack(pack_base_path=pack_base_path, pack=pack, - trigger=trigger) + self._register_trigger_from_pack( + pack_base_path=pack_base_path, pack=pack, trigger=trigger + ) except Exception as e: if self._fail_on_failure: - msg = ('Failed to register trigger "%s" from pack "%s": %s' % (trigger, pack, - six.text_type(e))) + msg = 'Failed to register trigger "%s" from pack "%s": %s' % ( + trigger, + pack, + six.text_type(e), + ) raise ValueError(msg) - LOG.debug('Failed to register trigger "%s": %s', trigger, six.text_type(e)) + LOG.debug( + 'Failed to register trigger "%s": %s', trigger, six.text_type(e) + ) else: LOG.debug('Trigger "%s" successfully registered', trigger) registered_count += 1 @@ -130,37 +146,41 @@ def _register_triggers_from_pack(self, pack, triggers): def _register_trigger_from_pack(self, pack_base_path, pack, trigger): trigger_metadata_file_path = trigger - LOG.debug('Loading trigger from %s.', trigger_metadata_file_path) + LOG.debug("Loading trigger from %s.", trigger_metadata_file_path) content = self._meta_loader.load(file_path=trigger_metadata_file_path) - pack_field = content.get('pack', None) + pack_field = content.get("pack", None) if not pack_field: - content['pack'] = pack + content["pack"] = pack pack_field = pack if pack_field != pack: - raise Exception('Model is in pack "%s" but field "pack" is different: %s' % - (pack, pack_field)) + raise Exception( + 'Model is in pack "%s" but field "pack" is different: %s' + % (pack, pack_field) + ) # Add in "metadata_file" attribute which stores path to the pack metadata file relative to # the pack directory - metadata_file = trigger.replace(pack_base_path, '') - content['metadata_file'] = metadata_file + metadata_file = trigger.replace(pack_base_path, "") + content["metadata_file"] = metadata_file trigger_types = [content] result = sensor_type_utils.create_trigger_types(trigger_types=trigger_types) return result[0] if result else None -def register_triggers(packs_base_paths=None, pack_dir=None, use_pack_cache=True, - fail_on_failure=False): +def register_triggers( + packs_base_paths=None, pack_dir=None, use_pack_cache=True, fail_on_failure=False +): if packs_base_paths: assert isinstance(packs_base_paths, list) if not packs_base_paths: packs_base_paths = content_utils.get_packs_base_paths() - registrar = TriggersRegistrar(use_pack_cache=use_pack_cache, - fail_on_failure=fail_on_failure) + registrar = TriggersRegistrar( + use_pack_cache=use_pack_cache, fail_on_failure=fail_on_failure + ) if pack_dir: result = registrar.register_from_pack(pack_dir=pack_dir) diff --git a/st2common/st2common/callback/base.py b/st2common/st2common/callback/base.py index ae1b55e5010..a48fcbecb96 100644 --- a/st2common/st2common/callback/base.py +++ b/st2common/st2common/callback/base.py @@ -21,7 +21,7 @@ __all__ = [ - 'AsyncActionExecutionCallbackHandler', + "AsyncActionExecutionCallbackHandler", ] @@ -30,7 +30,6 @@ @six.add_metaclass(abc.ABCMeta) class AsyncActionExecutionCallbackHandler(object): - @staticmethod @abc.abstractmethod def callback(liveaction): diff --git a/st2common/st2common/cmd/download_pack.py b/st2common/st2common/cmd/download_pack.py index b22a3f0467e..5ef0fb72b94 100644 --- a/st2common/st2common/cmd/download_pack.py +++ b/st2common/st2common/cmd/download_pack.py @@ -24,23 +24,34 @@ from st2common.util.pack_management import download_pack from st2common.util.pack_management import get_and_set_proxy_config -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.MultiStrOpt('pack', default=None, required=True, positional=True, - help='Name of the pack to install (download).'), - cfg.BoolOpt('verify-ssl', default=True, - help=('Verify SSL certificate of the Git repo from which the pack is ' - 'installed.')), - cfg.BoolOpt('force', default=False, - help='True to force pack download and ignore download ' - 'lock file if it exists.'), + cfg.MultiStrOpt( + "pack", + default=None, + required=True, + positional=True, + help="Name of the pack to install (download).", + ), + cfg.BoolOpt( + "verify-ssl", + default=True, + help=( + "Verify SSL certificate of the Git repo from which the pack is " + "installed." + ), + ), + cfg.BoolOpt( + "force", + default=False, + help="True to force pack download and ignore download " + "lock file if it exists.", + ), ] do_register_cli_opts(cli_opts) @@ -49,8 +60,12 @@ def main(argv): _register_cli_opts() # Parse CLI args, set up logging - common_setup(config=config, setup_db=False, register_mq_exchanges=False, - register_internal_trigger_types=False) + common_setup( + config=config, + setup_db=False, + register_mq_exchanges=False, + register_internal_trigger_types=False, + ) packs = cfg.CONF.pack verify_ssl = cfg.CONF.verify_ssl @@ -60,8 +75,13 @@ def main(argv): for pack in packs: LOG.info('Installing pack "%s"' % (pack)) - result = download_pack(pack=pack, verify_ssl=verify_ssl, force=force, - proxy_config=proxy_config, force_permissions=True) + result = download_pack( + pack=pack, + verify_ssl=verify_ssl, + force=force, + proxy_config=proxy_config, + force_permissions=True, + ) # Raw pack name excluding the version pack_name = result[1] diff --git a/st2common/st2common/cmd/generate_api_spec.py b/st2common/st2common/cmd/generate_api_spec.py index 1b0a65ec8f3..7ff7757b71d 100644 --- a/st2common/st2common/cmd/generate_api_spec.py +++ b/st2common/st2common/cmd/generate_api_spec.py @@ -25,9 +25,7 @@ from st2common.script_setup import setup as common_setup from st2common.script_setup import teardown as common_teardown -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) @@ -37,7 +35,7 @@ def setup(): def generate_spec(): - spec_string = spec_loader.generate_spec('st2common', 'openapi.yaml.j2') + spec_string = spec_loader.generate_spec("st2common", "openapi.yaml.j2") print(spec_string) @@ -52,7 +50,7 @@ def main(): generate_spec() ret = 0 except Exception: - LOG.error('Failed to generate openapi.yaml file', exc_info=True) + LOG.error("Failed to generate openapi.yaml file", exc_info=True) ret = 1 finally: teartown() diff --git a/st2common/st2common/cmd/install_pack.py b/st2common/st2common/cmd/install_pack.py index 861d0d4041b..42c2267012b 100644 --- a/st2common/st2common/cmd/install_pack.py +++ b/st2common/st2common/cmd/install_pack.py @@ -25,23 +25,34 @@ from st2common.util.pack_management import get_and_set_proxy_config from st2common.util.virtualenvs import setup_pack_virtualenv -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.MultiStrOpt('pack', default=None, required=True, positional=True, - help='Name of the pack to install.'), - cfg.BoolOpt('verify-ssl', default=True, - help=('Verify SSL certificate of the Git repo from which the pack is ' - 'downloaded.')), - cfg.BoolOpt('force', default=False, - help='True to force pack installation and ignore install ' - 'lock file if it exists.'), + cfg.MultiStrOpt( + "pack", + default=None, + required=True, + positional=True, + help="Name of the pack to install.", + ), + cfg.BoolOpt( + "verify-ssl", + default=True, + help=( + "Verify SSL certificate of the Git repo from which the pack is " + "downloaded." + ), + ), + cfg.BoolOpt( + "force", + default=False, + help="True to force pack installation and ignore install " + "lock file if it exists.", + ), ] do_register_cli_opts(cli_opts) @@ -50,8 +61,12 @@ def main(argv): _register_cli_opts() # Parse CLI args, set up logging - common_setup(config=config, setup_db=False, register_mq_exchanges=False, - register_internal_trigger_types=False) + common_setup( + config=config, + setup_db=False, + register_mq_exchanges=False, + register_internal_trigger_types=False, + ) packs = cfg.CONF.pack verify_ssl = cfg.CONF.verify_ssl @@ -62,8 +77,13 @@ def main(argv): for pack in packs: # 1. Download the pack LOG.info('Installing pack "%s"' % (pack)) - result = download_pack(pack=pack, verify_ssl=verify_ssl, force=force, - proxy_config=proxy_config, force_permissions=True) + result = download_pack( + pack=pack, + verify_ssl=verify_ssl, + force=force, + proxy_config=proxy_config, + force_permissions=True, + ) # Raw pack name excluding the version pack_name = result[1] @@ -78,9 +98,13 @@ def main(argv): # 2. Setup pack virtual environment LOG.info('Setting up virtualenv for pack "%s"' % (pack_name)) - setup_pack_virtualenv(pack_name=pack_name, update=False, logger=LOG, - proxy_config=proxy_config, - no_download=True) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + logger=LOG, + proxy_config=proxy_config, + no_download=True, + ) LOG.info('Successfully set up virtualenv for pack "%s"' % (pack_name)) return 0 diff --git a/st2common/st2common/cmd/purge_executions.py b/st2common/st2common/cmd/purge_executions.py index dcf7b47b403..27225d661c9 100755 --- a/st2common/st2common/cmd/purge_executions.py +++ b/st2common/st2common/cmd/purge_executions.py @@ -38,25 +38,30 @@ from st2common.constants.exit_codes import FAILURE_EXIT_CODE from st2common.garbage_collection.executions import purge_executions -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.StrOpt('timestamp', default=None, - help='Will delete execution and liveaction models older than ' + - 'this UTC timestamp. ' + - 'Example value: 2015-03-13T19:01:27.255542Z.'), - cfg.StrOpt('action-ref', default='', - help='action-ref to delete executions for.'), - cfg.BoolOpt('purge-incomplete', default=False, - help='Purge all models irrespective of their ``status``.' + - 'By default, only executions in completed states such as "succeeeded" ' + - ', "failed", "canceled" and "timed_out" are deleted.'), + cfg.StrOpt( + "timestamp", + default=None, + help="Will delete execution and liveaction models older than " + + "this UTC timestamp. " + + "Example value: 2015-03-13T19:01:27.255542Z.", + ), + cfg.StrOpt( + "action-ref", default="", help="action-ref to delete executions for." + ), + cfg.BoolOpt( + "purge-incomplete", + default=False, + help="Purge all models irrespective of their ``status``." + + 'By default, only executions in completed states such as "succeeeded" ' + + ', "failed", "canceled" and "timed_out" are deleted.', + ), ] do_register_cli_opts(cli_opts) @@ -71,15 +76,19 @@ def main(): purge_incomplete = cfg.CONF.purge_incomplete if not timestamp: - LOG.error('Please supply a timestamp for purging models. Aborting.') + LOG.error("Please supply a timestamp for purging models. Aborting.") return 1 else: - timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ') + timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ") timestamp = timestamp.replace(tzinfo=pytz.UTC) try: - purge_executions(logger=LOG, timestamp=timestamp, action_ref=action_ref, - purge_incomplete=purge_incomplete) + purge_executions( + logger=LOG, + timestamp=timestamp, + action_ref=action_ref, + purge_incomplete=purge_incomplete, + ) except Exception as e: LOG.exception(six.text_type(e)) return FAILURE_EXIT_CODE diff --git a/st2common/st2common/cmd/purge_trigger_instances.py b/st2common/st2common/cmd/purge_trigger_instances.py index e0908e9f8d4..529b786678e 100755 --- a/st2common/st2common/cmd/purge_trigger_instances.py +++ b/st2common/st2common/cmd/purge_trigger_instances.py @@ -38,19 +38,20 @@ from st2common.constants.exit_codes import FAILURE_EXIT_CODE from st2common.garbage_collection.trigger_instances import purge_trigger_instances -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.StrOpt('timestamp', default=None, - help='Will delete trigger instances older than ' + - 'this UTC timestamp. ' + - 'Example value: 2015-03-13T19:01:27.255542Z') + cfg.StrOpt( + "timestamp", + default=None, + help="Will delete trigger instances older than " + + "this UTC timestamp. " + + "Example value: 2015-03-13T19:01:27.255542Z", + ) ] do_register_cli_opts(cli_opts) @@ -63,10 +64,10 @@ def main(): timestamp = cfg.CONF.timestamp if not timestamp: - LOG.error('Please supply a timestamp for purging models. Aborting.') + LOG.error("Please supply a timestamp for purging models. Aborting.") return 1 else: - timestamp = datetime.strptime(timestamp, '%Y-%m-%dT%H:%M:%S.%fZ') + timestamp = datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ") timestamp = timestamp.replace(tzinfo=pytz.UTC) # Purge models. diff --git a/st2common/st2common/cmd/setup_pack_virtualenv.py b/st2common/st2common/cmd/setup_pack_virtualenv.py index 626bb389af0..514b1cf2e0f 100644 --- a/st2common/st2common/cmd/setup_pack_virtualenv.py +++ b/st2common/st2common/cmd/setup_pack_virtualenv.py @@ -22,23 +22,31 @@ from st2common.util.pack_management import get_and_set_proxy_config from st2common.util.virtualenvs import setup_pack_virtualenv -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.MultiStrOpt('pack', default=None, required=True, positional=True, - help='Name of the pack to setup the virtual environment for.'), - cfg.BoolOpt('update', default=False, - help=('Check this option if the virtual environment already exists and if you ' - 'only want to perform an update and installation of new dependencies. If ' - 'you don\'t check this option, the virtual environment will be destroyed ' - 'then re-created. If you check this and the virtual environment doesn\'t ' - 'exist, it will create it..')), + cfg.MultiStrOpt( + "pack", + default=None, + required=True, + positional=True, + help="Name of the pack to setup the virtual environment for.", + ), + cfg.BoolOpt( + "update", + default=False, + help=( + "Check this option if the virtual environment already exists and if you " + "only want to perform an update and installation of new dependencies. If " + "you don't check this option, the virtual environment will be destroyed " + "then re-created. If you check this and the virtual environment doesn't " + "exist, it will create it.." + ), + ), ] do_register_cli_opts(cli_opts) @@ -47,8 +55,12 @@ def main(argv): _register_cli_opts() # Parse CLI args, set up logging - common_setup(config=config, setup_db=False, register_mq_exchanges=False, - register_internal_trigger_types=False) + common_setup( + config=config, + setup_db=False, + register_mq_exchanges=False, + register_internal_trigger_types=False, + ) packs = cfg.CONF.pack update = cfg.CONF.update @@ -58,9 +70,13 @@ def main(argv): for pack in packs: # Setup pack virtual environment LOG.info('Setting up virtualenv for pack "%s"' % (pack)) - setup_pack_virtualenv(pack_name=pack, update=update, logger=LOG, - proxy_config=proxy_config, - no_download=True) + setup_pack_virtualenv( + pack_name=pack, + update=update, + logger=LOG, + proxy_config=proxy_config, + no_download=True, + ) LOG.info('Successfully set up virtualenv for pack "%s"' % (pack)) return 0 diff --git a/st2common/st2common/cmd/validate_api_spec.py b/st2common/st2common/cmd/validate_api_spec.py index 743b3e467a3..4f317db4a4c 100644 --- a/st2common/st2common/cmd/validate_api_spec.py +++ b/st2common/st2common/cmd/validate_api_spec.py @@ -33,19 +33,20 @@ import six -__all__ = [ - 'main' -] +__all__ = ["main"] cfg.CONF.register_cli_opt( - cfg.StrOpt('spec-file', short='f', required=False, - default='st2common/st2common/openapi.yaml') + cfg.StrOpt( + "spec-file", + short="f", + required=False, + default="st2common/st2common/openapi.yaml", + ) ) cfg.CONF.register_cli_opt( - cfg.BoolOpt('generate', short='-c', required=False, - default=False) + cfg.BoolOpt("generate", short="-c", required=False, default=False) ) LOG = logging.getLogger(__name__) @@ -56,12 +57,12 @@ def setup(): def _validate_definitions(spec): - defs = spec.get('definitions', None) + defs = spec.get("definitions", None) error = False verbose = cfg.CONF.verbose for (model, definition) in six.iteritems(defs): - api_model = definition.get('x-api-model', None) + api_model = definition.get("x-api-model", None) if not api_model: msg = ( @@ -69,7 +70,7 @@ def _validate_definitions(spec): ) if verbose: - LOG.info('Supplied definition for model %s: \n\n%s.', model, definition) + LOG.info("Supplied definition for model %s: \n\n%s.", model, definition) error = True LOG.error(msg) @@ -82,18 +83,20 @@ def validate_spec(): generate_spec = cfg.CONF.generate if not os.path.exists(spec_file) and not generate_spec: - msg = ('No spec file found in location %s. ' % spec_file + - 'Provide a valid spec file or ' + - 'pass --generate-api-spec to genrate a spec.') + msg = ( + "No spec file found in location %s. " % spec_file + + "Provide a valid spec file or " + + "pass --generate-api-spec to genrate a spec." + ) raise Exception(msg) if generate_spec: if not spec_file: - raise Exception('Supply a path to write to spec file to.') + raise Exception("Supply a path to write to spec file to.") - spec_string = spec_loader.generate_spec('st2common', 'openapi.yaml.j2') + spec_string = spec_loader.generate_spec("st2common", "openapi.yaml.j2") - with open(spec_file, 'w') as f: + with open(spec_file, "w") as f: f.write(spec_string) f.flush() @@ -112,13 +115,15 @@ def main(): try: # 1. Validate there are no duplicates keys in the YAML file - spec_loader.load_spec('st2common', 'openapi.yaml.j2', allow_duplicate_keys=False) + spec_loader.load_spec( + "st2common", "openapi.yaml.j2", allow_duplicate_keys=False + ) # 2. Validate schema (currently broken) # validate_spec() ret = 0 except Exception: - LOG.error('Failed to validate openapi.yaml file', exc_info=True) + LOG.error("Failed to validate openapi.yaml file", exc_info=True) ret = 1 finally: teartown() diff --git a/st2common/st2common/cmd/validate_config.py b/st2common/st2common/cmd/validate_config.py index 2bd5b58d0de..6b7bedd32f1 100644 --- a/st2common/st2common/cmd/validate_config.py +++ b/st2common/st2common/cmd/validate_config.py @@ -31,9 +31,7 @@ from st2common.constants.exit_codes import FAILURE_EXIT_CODE from st2common.util.pack import validate_config_against_schema -__all__ = [ - 'main' -] +__all__ = ["main"] def _do_register_cli_opts(opts, ignore_errors=False): @@ -47,10 +45,18 @@ def _do_register_cli_opts(opts, ignore_errors=False): def _register_cli_opts(): cli_opts = [ - cfg.StrOpt('schema-path', default=None, required=True, - help='Path to the config schema to use for validation.'), - cfg.StrOpt('config-path', default=None, required=True, - help='Path to the config file to validate.'), + cfg.StrOpt( + "schema-path", + default=None, + required=True, + help="Path to the config schema to use for validation.", + ), + cfg.StrOpt( + "config-path", + default=None, + required=True, + help="Path to the config file to validate.", + ), ] do_register_cli_opts(cli_opts) @@ -65,18 +71,24 @@ def main(): print('Validating config "%s" against schema in "%s"' % (config_path, schema_path)) - with open(schema_path, 'r') as fp: + with open(schema_path, "r") as fp: config_schema = yaml.safe_load(fp.read()) - with open(config_path, 'r') as fp: + with open(config_path, "r") as fp: config_object = yaml.safe_load(fp.read()) try: - validate_config_against_schema(config_schema=config_schema, config_object=config_object, - config_path=config_path) + validate_config_against_schema( + config_schema=config_schema, + config_object=config_object, + config_path=config_path, + ) except Exception as e: - print('Failed to validate pack config.\n%s' % six.text_type(e)) + print("Failed to validate pack config.\n%s" % six.text_type(e)) return FAILURE_EXIT_CODE - print('Config "%s" successfully validated against schema in %s.' % (config_path, schema_path)) + print( + 'Config "%s" successfully validated against schema in %s.' + % (config_path, schema_path) + ) return SUCCESS_EXIT_CODE diff --git a/st2common/st2common/config.py b/st2common/st2common/config.py index 8ad77fa626d..e7b30a9a7cf 100644 --- a/st2common/st2common/config.py +++ b/st2common/st2common/config.py @@ -25,12 +25,7 @@ from st2common.constants.runners import PYTHON_RUNNER_DEFAULT_LOG_LEVEL from st2common.constants.action import LIVEACTION_COMPLETED_STATES -__all__ = [ - 'do_register_opts', - 'do_register_cli_opts', - - 'parse_args' -] +__all__ = ["do_register_opts", "do_register_cli_opts", "parse_args"] def do_register_opts(opts, group=None, ignore_errors=False): @@ -57,447 +52,550 @@ def do_register_cli_opts(opt, ignore_errors=False): def register_opts(ignore_errors=False): rbac_opts = [ + cfg.BoolOpt("enable", default=False, help="Enable RBAC."), + cfg.StrOpt("backend", default="noop", help="RBAC backend to use."), cfg.BoolOpt( - 'enable', default=False, - help='Enable RBAC.'), - cfg.StrOpt( - 'backend', default='noop', - help='RBAC backend to use.'), - cfg.BoolOpt( - 'sync_remote_groups', default=False, - help='True to synchronize remote groups returned by the auth backed for each ' - 'StackStorm user with local StackStorm roles based on the group to role ' - 'mapping definition files.'), + "sync_remote_groups", + default=False, + help="True to synchronize remote groups returned by the auth backed for each " + "StackStorm user with local StackStorm roles based on the group to role " + "mapping definition files.", + ), cfg.BoolOpt( - 'permission_isolation', default=False, - help='Isolate resources by user. For now, these resources only include rules and ' - 'executions. All resources can only be viewed or executed by the owning user ' - 'except the admin and system_user who can view or run everything.') + "permission_isolation", + default=False, + help="Isolate resources by user. For now, these resources only include rules and " + "executions. All resources can only be viewed or executed by the owning user " + "except the admin and system_user who can view or run everything.", + ), ] - do_register_opts(rbac_opts, 'rbac', ignore_errors) + do_register_opts(rbac_opts, "rbac", ignore_errors) system_user_opts = [ + cfg.StrOpt("user", default="stanley", help="Default system user."), cfg.StrOpt( - 'user', default='stanley', - help='Default system user.'), - cfg.StrOpt( - 'ssh_key_file', default='/home/stanley/.ssh/stanley_rsa', - help='SSH private key for the system user.') + "ssh_key_file", + default="/home/stanley/.ssh/stanley_rsa", + help="SSH private key for the system user.", + ), ] - do_register_opts(system_user_opts, 'system_user', ignore_errors) + do_register_opts(system_user_opts, "system_user", ignore_errors) schema_opts = [ - cfg.IntOpt( - 'version', default=4, - help='Version of JSON schema to use.'), + cfg.IntOpt("version", default=4, help="Version of JSON schema to use."), cfg.StrOpt( - 'draft', default='http://json-schema.org/draft-04/schema#', - help='URL to the JSON schema draft.') + "draft", + default="http://json-schema.org/draft-04/schema#", + help="URL to the JSON schema draft.", + ), ] - do_register_opts(schema_opts, 'schema', ignore_errors) + do_register_opts(schema_opts, "schema", ignore_errors) system_opts = [ - cfg.BoolOpt( - 'debug', default=False, - help='Enable debug mode.'), + cfg.BoolOpt("debug", default=False, help="Enable debug mode."), cfg.StrOpt( - 'base_path', default='/opt/stackstorm', - help='Base path to all st2 artifacts.'), + "base_path", + default="/opt/stackstorm", + help="Base path to all st2 artifacts.", + ), cfg.BoolOpt( - 'validate_trigger_parameters', default=True, - help='True to validate parameters for non-system trigger types when creating' - 'a rule. By default, only parameters for system triggers are validated.'), + "validate_trigger_parameters", + default=True, + help="True to validate parameters for non-system trigger types when creating" + "a rule. By default, only parameters for system triggers are validated.", + ), cfg.BoolOpt( - 'validate_trigger_payload', default=True, - help='True to validate payload for non-system trigger types when dispatching a trigger ' - 'inside the sensor. By default, only payload for system triggers is validated.'), + "validate_trigger_payload", + default=True, + help="True to validate payload for non-system trigger types when dispatching a trigger " + "inside the sensor. By default, only payload for system triggers is validated.", + ), cfg.BoolOpt( - 'validate_output_schema', default=False, - help='True to validate action and runner output against schema.') + "validate_output_schema", + default=False, + help="True to validate action and runner output against schema.", + ), ] - do_register_opts(system_opts, 'system', ignore_errors) + do_register_opts(system_opts, "system", ignore_errors) - system_packs_base_path = os.path.join(cfg.CONF.system.base_path, 'packs') - system_runners_base_path = os.path.join(cfg.CONF.system.base_path, 'runners') + system_packs_base_path = os.path.join(cfg.CONF.system.base_path, "packs") + system_runners_base_path = os.path.join(cfg.CONF.system.base_path, "runners") content_opts = [ cfg.StrOpt( - 'pack_group', default='st2packs', - help='User group that can write to packs directory.'), - cfg.StrOpt( - 'system_packs_base_path', default=system_packs_base_path, - help='Path to the directory which contains system packs.'), - cfg.StrOpt( - 'system_runners_base_path', default=system_runners_base_path, - help='Path to the directory which contains system runners. ' - 'NOTE: This option has been deprecated and it\'s unused since StackStorm v3.0.0'), - cfg.StrOpt( - 'packs_base_paths', default=None, - help='Paths which will be searched for integration packs.'), - cfg.StrOpt( - 'runners_base_paths', default=None, - help='Paths which will be searched for runners. ' - 'NOTE: This option has been deprecated and it\'s unused since StackStorm v3.0.0'), + "pack_group", + default="st2packs", + help="User group that can write to packs directory.", + ), + cfg.StrOpt( + "system_packs_base_path", + default=system_packs_base_path, + help="Path to the directory which contains system packs.", + ), + cfg.StrOpt( + "system_runners_base_path", + default=system_runners_base_path, + help="Path to the directory which contains system runners. " + "NOTE: This option has been deprecated and it's unused since StackStorm v3.0.0", + ), + cfg.StrOpt( + "packs_base_paths", + default=None, + help="Paths which will be searched for integration packs.", + ), + cfg.StrOpt( + "runners_base_paths", + default=None, + help="Paths which will be searched for runners. " + "NOTE: This option has been deprecated and it's unused since StackStorm v3.0.0", + ), cfg.ListOpt( - 'index_url', default=['https://index.stackstorm.org/v1/index.json'], - help='A URL pointing to the pack index. StackStorm Exchange is used by ' - 'default. Use a comma-separated list for multiple indexes if you ' - 'want to get other packs discovered with "st2 pack search".'), + "index_url", + default=["https://index.stackstorm.org/v1/index.json"], + help="A URL pointing to the pack index. StackStorm Exchange is used by " + "default. Use a comma-separated list for multiple indexes if you " + 'want to get other packs discovered with "st2 pack search".', + ), ] - do_register_opts(content_opts, 'content', ignore_errors) + do_register_opts(content_opts, "content", ignore_errors) webui_opts = [ cfg.StrOpt( - 'webui_base_url', default='https://%s' % socket.getfqdn(), - help='Base https URL to access st2 Web UI. This is used to construct history URLs ' - 'that are sent out when chatops is used to kick off executions.') + "webui_base_url", + default="https://%s" % socket.getfqdn(), + help="Base https URL to access st2 Web UI. This is used to construct history URLs " + "that are sent out when chatops is used to kick off executions.", + ) ] - do_register_opts(webui_opts, 'webui', ignore_errors) + do_register_opts(webui_opts, "webui", ignore_errors) db_opts = [ - cfg.StrOpt( - 'host', default='127.0.0.1', - help='host of db server'), + cfg.StrOpt("host", default="127.0.0.1", help="host of db server"), + cfg.IntOpt("port", default=27017, help="port of db server"), + cfg.StrOpt("db_name", default="st2", help="name of database"), + cfg.StrOpt("username", help="username for db login"), + cfg.StrOpt("password", help="password for db login"), cfg.IntOpt( - 'port', default=27017, - help='port of db server'), - cfg.StrOpt( - 'db_name', default='st2', - help='name of database'), - cfg.StrOpt( - 'username', - help='username for db login'), - cfg.StrOpt( - 'password', - help='password for db login'), + "connection_timeout", + default=3 * 1000, + help="Connection and server selection timeout (in ms).", + ), cfg.IntOpt( - 'connection_timeout', default=3 * 1000, - help='Connection and server selection timeout (in ms).'), + "connection_retry_max_delay_m", + default=3, + help="Connection retry total time (minutes).", + ), cfg.IntOpt( - 'connection_retry_max_delay_m', default=3, - help='Connection retry total time (minutes).'), + "connection_retry_backoff_max_s", + default=10, + help="Connection retry backoff max (seconds).", + ), cfg.IntOpt( - 'connection_retry_backoff_max_s', default=10, - help='Connection retry backoff max (seconds).'), - cfg.IntOpt( - 'connection_retry_backoff_mul', default=1, - help='Backoff multiplier (seconds).'), + "connection_retry_backoff_mul", + default=1, + help="Backoff multiplier (seconds).", + ), cfg.BoolOpt( - 'ssl', default=False, - help='Create the connection to mongodb using SSL'), - cfg.StrOpt( - 'ssl_keyfile', default=None, - help='Private keyfile used to identify the local connection against MongoDB.'), - cfg.StrOpt( - 'ssl_certfile', default=None, - help='Certificate file used to identify the localconnection'), - cfg.StrOpt( - 'ssl_cert_reqs', default=None, choices='none, optional, required', - help='Specifies whether a certificate is required from the other side of the ' - 'connection, and whether it will be validated if provided'), - cfg.StrOpt( - 'ssl_ca_certs', default=None, - help='ca_certs file contains a set of concatenated CA certificates, which are ' - 'used to validate certificates passed from MongoDB.'), + "ssl", default=False, help="Create the connection to mongodb using SSL" + ), + cfg.StrOpt( + "ssl_keyfile", + default=None, + help="Private keyfile used to identify the local connection against MongoDB.", + ), + cfg.StrOpt( + "ssl_certfile", + default=None, + help="Certificate file used to identify the localconnection", + ), + cfg.StrOpt( + "ssl_cert_reqs", + default=None, + choices="none, optional, required", + help="Specifies whether a certificate is required from the other side of the " + "connection, and whether it will be validated if provided", + ), + cfg.StrOpt( + "ssl_ca_certs", + default=None, + help="ca_certs file contains a set of concatenated CA certificates, which are " + "used to validate certificates passed from MongoDB.", + ), cfg.BoolOpt( - 'ssl_match_hostname', default=True, - help='If True and `ssl_cert_reqs` is not None, enables hostname verification'), - cfg.StrOpt( - 'authentication_mechanism', default=None, - help='Specifies database authentication mechanisms. ' - 'By default, it use SCRAM-SHA-1 with MongoDB 3.0 and later, ' - 'MONGODB-CR (MongoDB Challenge Response protocol) for older servers.') + "ssl_match_hostname", + default=True, + help="If True and `ssl_cert_reqs` is not None, enables hostname verification", + ), + cfg.StrOpt( + "authentication_mechanism", + default=None, + help="Specifies database authentication mechanisms. " + "By default, it use SCRAM-SHA-1 with MongoDB 3.0 and later, " + "MONGODB-CR (MongoDB Challenge Response protocol) for older servers.", + ), ] - do_register_opts(db_opts, 'database', ignore_errors) + do_register_opts(db_opts, "database", ignore_errors) messaging_opts = [ # It would be nice to be able to deprecate url and completely switch to using # url. However, this will be a breaking change and will have impact so allowing both. cfg.StrOpt( - 'url', default='amqp://guest:guest@127.0.0.1:5672//', - help='URL of the messaging server.'), + "url", + default="amqp://guest:guest@127.0.0.1:5672//", + help="URL of the messaging server.", + ), cfg.ListOpt( - 'cluster_urls', default=[], - help='URL of all the nodes in a messaging service cluster.'), + "cluster_urls", + default=[], + help="URL of all the nodes in a messaging service cluster.", + ), cfg.IntOpt( - 'connection_retries', default=10, - help='How many times should we retry connection before failing.'), + "connection_retries", + default=10, + help="How many times should we retry connection before failing.", + ), cfg.IntOpt( - 'connection_retry_wait', default=10000, - help='How long should we wait between connection retries.'), + "connection_retry_wait", + default=10000, + help="How long should we wait between connection retries.", + ), cfg.BoolOpt( - 'ssl', default=False, - help='Use SSL / TLS to connect to the messaging server. Same as ' - 'appending "?ssl=true" at the end of the connection URL string.'), - cfg.StrOpt( - 'ssl_keyfile', default=None, - help='Private keyfile used to identify the local connection against RabbitMQ.'), - cfg.StrOpt( - 'ssl_certfile', default=None, - help='Certificate file used to identify the local connection (client).'), - cfg.StrOpt( - 'ssl_cert_reqs', default=None, choices='none, optional, required', - help='Specifies whether a certificate is required from the other side of the ' - 'connection, and whether it will be validated if provided.'), - cfg.StrOpt( - 'ssl_ca_certs', default=None, - help='ca_certs file contains a set of concatenated CA certificates, which are ' - 'used to validate certificates passed from RabbitMQ.'), - cfg.StrOpt( - 'login_method', default=None, - help='Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).') + "ssl", + default=False, + help="Use SSL / TLS to connect to the messaging server. Same as " + 'appending "?ssl=true" at the end of the connection URL string.', + ), + cfg.StrOpt( + "ssl_keyfile", + default=None, + help="Private keyfile used to identify the local connection against RabbitMQ.", + ), + cfg.StrOpt( + "ssl_certfile", + default=None, + help="Certificate file used to identify the local connection (client).", + ), + cfg.StrOpt( + "ssl_cert_reqs", + default=None, + choices="none, optional, required", + help="Specifies whether a certificate is required from the other side of the " + "connection, and whether it will be validated if provided.", + ), + cfg.StrOpt( + "ssl_ca_certs", + default=None, + help="ca_certs file contains a set of concatenated CA certificates, which are " + "used to validate certificates passed from RabbitMQ.", + ), + cfg.StrOpt( + "login_method", + default=None, + help="Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).", + ), ] - do_register_opts(messaging_opts, 'messaging', ignore_errors) + do_register_opts(messaging_opts, "messaging", ignore_errors) syslog_opts = [ + cfg.StrOpt("host", default="127.0.0.1", help="Host for the syslog server."), + cfg.IntOpt("port", default=514, help="Port for the syslog server."), + cfg.StrOpt("facility", default="local7", help="Syslog facility level."), cfg.StrOpt( - 'host', default='127.0.0.1', - help='Host for the syslog server.'), - cfg.IntOpt( - 'port', default=514, - help='Port for the syslog server.'), - cfg.StrOpt( - 'facility', default='local7', - help='Syslog facility level.'), - cfg.StrOpt( - 'protocol', default='udp', - help='Transport protocol to use (udp / tcp).') + "protocol", default="udp", help="Transport protocol to use (udp / tcp)." + ), ] - do_register_opts(syslog_opts, 'syslog', ignore_errors) + do_register_opts(syslog_opts, "syslog", ignore_errors) log_opts = [ - cfg.ListOpt( - 'excludes', default='', - help='Exclusion list of loggers to omit.'), + cfg.ListOpt("excludes", default="", help="Exclusion list of loggers to omit."), cfg.BoolOpt( - 'redirect_stderr', default=False, - help='Controls if stderr should be redirected to the logs.'), + "redirect_stderr", + default=False, + help="Controls if stderr should be redirected to the logs.", + ), cfg.BoolOpt( - 'mask_secrets', default=True, - help='True to mask secrets in the log files.'), + "mask_secrets", default=True, help="True to mask secrets in the log files." + ), cfg.ListOpt( - 'mask_secrets_blacklist', default=[], - help='Blacklist of additional attribute names to mask in the log messages.') + "mask_secrets_blacklist", + default=[], + help="Blacklist of additional attribute names to mask in the log messages.", + ), ] - do_register_opts(log_opts, 'log', ignore_errors) + do_register_opts(log_opts, "log", ignore_errors) # Common API options api_opts = [ - cfg.StrOpt( - 'host', default='127.0.0.1', - help='StackStorm API server host'), - cfg.IntOpt( - 'port', default=9101, - help='StackStorm API server port'), + cfg.StrOpt("host", default="127.0.0.1", help="StackStorm API server host"), + cfg.IntOpt("port", default=9101, help="StackStorm API server port"), cfg.ListOpt( - 'allow_origin', default=['http://127.0.0.1:3000'], - help='List of origins allowed for api, auth and stream'), + "allow_origin", + default=["http://127.0.0.1:3000"], + help="List of origins allowed for api, auth and stream", + ), cfg.BoolOpt( - 'mask_secrets', default=True, - help='True to mask secrets in the API responses') + "mask_secrets", + default=True, + help="True to mask secrets in the API responses", + ), ] - do_register_opts(api_opts, 'api', ignore_errors) + do_register_opts(api_opts, "api", ignore_errors) # Key Value store options keyvalue_opts = [ cfg.BoolOpt( - 'enable_encryption', default=True, - help='Allow encryption of values in key value stored qualified as "secret".'), - cfg.StrOpt( - 'encryption_key_path', default='', - help='Location of the symmetric encryption key for encrypting values in kvstore. ' - 'This key should be in JSON and should\'ve been generated using ' - 'st2-generate-symmetric-crypto-key tool.') + "enable_encryption", + default=True, + help='Allow encryption of values in key value stored qualified as "secret".', + ), + cfg.StrOpt( + "encryption_key_path", + default="", + help="Location of the symmetric encryption key for encrypting values in kvstore. " + "This key should be in JSON and should've been generated using " + "st2-generate-symmetric-crypto-key tool.", + ), ] - do_register_opts(keyvalue_opts, group='keyvalue') + do_register_opts(keyvalue_opts, group="keyvalue") # Common auth options auth_opts = [ cfg.StrOpt( - 'api_url', default=None, - help='Base URL to the API endpoint excluding the version'), - cfg.BoolOpt( - 'enable', default=True, - help='Enable authentication middleware.'), + "api_url", + default=None, + help="Base URL to the API endpoint excluding the version", + ), + cfg.BoolOpt("enable", default=True, help="Enable authentication middleware."), cfg.IntOpt( - 'token_ttl', default=(24 * 60 * 60), - help='Access token ttl in seconds.'), + "token_ttl", default=(24 * 60 * 60), help="Access token ttl in seconds." + ), # This TTL is used for tokens which belong to StackStorm services cfg.IntOpt( - 'service_token_ttl', default=(24 * 60 * 60), - help='Service token ttl in seconds.') + "service_token_ttl", + default=(24 * 60 * 60), + help="Service token ttl in seconds.", + ), ] - do_register_opts(auth_opts, 'auth', ignore_errors) + do_register_opts(auth_opts, "auth", ignore_errors) # Runner options default_python_bin_path = sys.executable base_dir = os.path.dirname(os.path.realpath(default_python_bin_path)) - default_virtualenv_bin_path = os.path.join(base_dir, 'virtualenv') + default_virtualenv_bin_path = os.path.join(base_dir, "virtualenv") action_runner_opts = [ # Common runner options cfg.StrOpt( - 'logging', default='/etc/st2/logging.actionrunner.conf', - help='location of the logging.conf file'), - + "logging", + default="/etc/st2/logging.actionrunner.conf", + help="location of the logging.conf file", + ), # Python runner options cfg.StrOpt( - 'python_binary', default=default_python_bin_path, - help='Python binary which will be used by Python actions.'), - cfg.StrOpt( - 'virtualenv_binary', default=default_virtualenv_bin_path, - help='Virtualenv binary which should be used to create pack virtualenvs.'), - cfg.StrOpt( - 'python_runner_log_level', default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL, - help='Default log level to use for Python runner actions. Can be overriden on ' - 'invocation basis using "log_level" runner parameter.'), + "python_binary", + default=default_python_bin_path, + help="Python binary which will be used by Python actions.", + ), + cfg.StrOpt( + "virtualenv_binary", + default=default_virtualenv_bin_path, + help="Virtualenv binary which should be used to create pack virtualenvs.", + ), + cfg.StrOpt( + "python_runner_log_level", + default=PYTHON_RUNNER_DEFAULT_LOG_LEVEL, + help="Default log level to use for Python runner actions. Can be overriden on " + 'invocation basis using "log_level" runner parameter.', + ), cfg.ListOpt( - 'virtualenv_opts', default=['--system-site-packages'], + "virtualenv_opts", + default=["--system-site-packages"], help='List of virtualenv options to be passsed to "virtualenv" command that ' - 'creates pack virtualenv.'), + "creates pack virtualenv.", + ), cfg.ListOpt( - 'pip_opts', default=[], + "pip_opts", + default=[], help='List of pip options to be passed to "pip install" command when installing pack ' - 'dependencies into pack virtual environment.'), + "dependencies into pack virtual environment.", + ), cfg.BoolOpt( - 'stream_output', default=True, - help='True to store and stream action output (stdout and stderr) in real-time.'), + "stream_output", + default=True, + help="True to store and stream action output (stdout and stderr) in real-time.", + ), cfg.IntOpt( - 'stream_output_buffer_size', default=-1, - help=('Buffer size to use for real time action output streaming. 0 means unbuffered ' - '1 means line buffered, -1 means system default, which usually means fully ' - 'buffered and any other positive value means use a buffer of (approximately) ' - 'that size')) + "stream_output_buffer_size", + default=-1, + help=( + "Buffer size to use for real time action output streaming. 0 means unbuffered " + "1 means line buffered, -1 means system default, which usually means fully " + "buffered and any other positive value means use a buffer of (approximately) " + "that size" + ), + ), ] - do_register_opts(action_runner_opts, group='actionrunner') + do_register_opts(action_runner_opts, group="actionrunner") dispatcher_pool_opts = [ cfg.IntOpt( - 'workflows_pool_size', default=40, - help='Internal pool size for dispatcher used by workflow actions.'), + "workflows_pool_size", + default=40, + help="Internal pool size for dispatcher used by workflow actions.", + ), cfg.IntOpt( - 'actions_pool_size', default=60, - help='Internal pool size for dispatcher used by regular actions.') + "actions_pool_size", + default=60, + help="Internal pool size for dispatcher used by regular actions.", + ), ] - do_register_opts(dispatcher_pool_opts, group='actionrunner') + do_register_opts(dispatcher_pool_opts, group="actionrunner") ssh_runner_opts = [ cfg.StrOpt( - 'remote_dir', default='/tmp', - help='Location of the script on the remote filesystem.'), + "remote_dir", + default="/tmp", + help="Location of the script on the remote filesystem.", + ), cfg.BoolOpt( - 'allow_partial_failure', default=False, - help='How partial success of actions run on multiple nodes should be treated.'), + "allow_partial_failure", + default=False, + help="How partial success of actions run on multiple nodes should be treated.", + ), cfg.IntOpt( - 'max_parallel_actions', default=50, - help='Max number of parallel remote SSH actions that should be run. ' - 'Works only with Paramiko SSH runner.'), + "max_parallel_actions", + default=50, + help="Max number of parallel remote SSH actions that should be run. " + "Works only with Paramiko SSH runner.", + ), cfg.BoolOpt( - 'use_ssh_config', default=False, - help='Use the .ssh/config file. Useful to override ports etc.'), - cfg.StrOpt( - 'ssh_config_file_path', default='~/.ssh/config', - help='Path to the ssh config file.'), + "use_ssh_config", + default=False, + help="Use the .ssh/config file. Useful to override ports etc.", + ), + cfg.StrOpt( + "ssh_config_file_path", + default="~/.ssh/config", + help="Path to the ssh config file.", + ), cfg.IntOpt( - 'ssh_connect_timeout', default=60, - help='Max time in seconds to establish the SSH connection.') + "ssh_connect_timeout", + default=60, + help="Max time in seconds to establish the SSH connection.", + ), ] - do_register_opts(ssh_runner_opts, group='ssh_runner') + do_register_opts(ssh_runner_opts, group="ssh_runner") # Common options (used by action runner and sensor container) action_sensor_opts = [ cfg.BoolOpt( - 'enable', default=True, - help='Whether to enable or disable the ability to post a trigger on action.'), + "enable", + default=True, + help="Whether to enable or disable the ability to post a trigger on action.", + ), cfg.ListOpt( - 'emit_when', default=LIVEACTION_COMPLETED_STATES, - help='List of execution statuses for which a trigger will be emitted. ') + "emit_when", + default=LIVEACTION_COMPLETED_STATES, + help="List of execution statuses for which a trigger will be emitted. ", + ), ] - do_register_opts(action_sensor_opts, group='action_sensor') + do_register_opts(action_sensor_opts, group="action_sensor") # Common options for content pack_lib_opts = [ cfg.BoolOpt( - 'enable_common_libs', default=False, - help='Enable/Disable support for pack common libs. ' - 'Setting this config to ``True`` would allow you to ' - 'place common library code for sensors and actions in lib/ folder ' - 'in packs and use them in python sensors and actions. ' - 'See https://docs.stackstorm.com/reference/' - 'sharing_code_sensors_actions.html ' - 'for details.') + "enable_common_libs", + default=False, + help="Enable/Disable support for pack common libs. " + "Setting this config to ``True`` would allow you to " + "place common library code for sensors and actions in lib/ folder " + "in packs and use them in python sensors and actions. " + "See https://docs.stackstorm.com/reference/" + "sharing_code_sensors_actions.html " + "for details.", + ) ] - do_register_opts(pack_lib_opts, group='packs') + do_register_opts(pack_lib_opts, group="packs") # Coordination options coord_opts = [ - cfg.StrOpt( - 'url', default=None, - help='Endpoint for the coordination server.'), + cfg.StrOpt("url", default=None, help="Endpoint for the coordination server."), cfg.IntOpt( - 'lock_timeout', default=60, - help='TTL for the lock if backend suports it.'), + "lock_timeout", default=60, help="TTL for the lock if backend suports it." + ), cfg.BoolOpt( - 'service_registry', default=False, - help='True to register StackStorm services in a service registry.'), + "service_registry", + default=False, + help="True to register StackStorm services in a service registry.", + ), ] - do_register_opts(coord_opts, 'coordination', ignore_errors) + do_register_opts(coord_opts, "coordination", ignore_errors) # XXX: This is required for us to support deprecated config group results_tracker query_opts = [ cfg.IntOpt( - 'thread_pool_size', - help='Number of threads to use to query external workflow systems.'), + "thread_pool_size", + help="Number of threads to use to query external workflow systems.", + ), cfg.FloatOpt( - 'query_interval', - help='Time interval between subsequent queries for a context ' - 'to external workflow system.') + "query_interval", + help="Time interval between subsequent queries for a context " + "to external workflow system.", + ), ] - do_register_opts(query_opts, group='results_tracker', ignore_errors=ignore_errors) + do_register_opts(query_opts, group="results_tracker", ignore_errors=ignore_errors) # Common stream options stream_opts = [ cfg.IntOpt( - 'heartbeat', default=25, - help='Send empty message every N seconds to keep connection open') + "heartbeat", + default=25, + help="Send empty message every N seconds to keep connection open", + ) ] - do_register_opts(stream_opts, group='stream', ignore_errors=ignore_errors) + do_register_opts(stream_opts, group="stream", ignore_errors=ignore_errors) # Common CLI options cli_opts = [ cfg.BoolOpt( - 'debug', default=False, - help='Enable debug mode. By default this will set all log levels to DEBUG.'), + "debug", + default=False, + help="Enable debug mode. By default this will set all log levels to DEBUG.", + ), cfg.BoolOpt( - 'profile', default=False, - help='Enable profile mode. In the profile mode all the MongoDB queries and ' - 'related profile data are logged.'), + "profile", + default=False, + help="Enable profile mode. In the profile mode all the MongoDB queries and " + "related profile data are logged.", + ), cfg.BoolOpt( - 'use-debugger', default=True, - help='Enables debugger. Note that using this option changes how the ' - 'eventlet library is used to support async IO. This could result in ' - 'failures that do not occur under normal operation.') + "use-debugger", + default=True, + help="Enables debugger. Note that using this option changes how the " + "eventlet library is used to support async IO. This could result in " + "failures that do not occur under normal operation.", + ), ] do_register_cli_opts(cli_opts, ignore_errors=ignore_errors) @@ -505,92 +603,121 @@ def register_opts(ignore_errors=False): # Metrics Options stream options metrics_opts = [ cfg.StrOpt( - 'driver', default='noop', - help='Driver type for metrics collection.'), + "driver", default="noop", help="Driver type for metrics collection." + ), cfg.StrOpt( - 'host', default='127.0.0.1', - help='Destination server to connect to if driver requires connection.'), + "host", + default="127.0.0.1", + help="Destination server to connect to if driver requires connection.", + ), cfg.IntOpt( - 'port', default=8125, - help='Destination port to connect to if driver requires connection.'), - cfg.StrOpt( - 'prefix', default=None, - help='Optional prefix which is prepended to all the metric names. Comes handy when ' - 'you want to submit metrics from various environment to the same metric ' - 'backend instance.'), + "port", + default=8125, + help="Destination port to connect to if driver requires connection.", + ), + cfg.StrOpt( + "prefix", + default=None, + help="Optional prefix which is prepended to all the metric names. Comes handy when " + "you want to submit metrics from various environment to the same metric " + "backend instance.", + ), cfg.FloatOpt( - 'sample_rate', default=1, - help='Randomly sample and only send metrics for X% of metric operations to the ' - 'backend. Default value of 1 means no sampling is done and all the metrics are ' - 'sent to the backend. E.g. 0.1 would mean 10% of operations are sampled.') - + "sample_rate", + default=1, + help="Randomly sample and only send metrics for X% of metric operations to the " + "backend. Default value of 1 means no sampling is done and all the metrics are " + "sent to the backend. E.g. 0.1 would mean 10% of operations are sampled.", + ), ] - do_register_opts(metrics_opts, group='metrics', ignore_errors=ignore_errors) + do_register_opts(metrics_opts, group="metrics", ignore_errors=ignore_errors) # Common timers engine options timer_logging_opts = [ cfg.StrOpt( - 'logging', default=None, - help='Location of the logging configuration file. ' - 'NOTE: Deprecated in favor of timersengine.logging'), + "logging", + default=None, + help="Location of the logging configuration file. " + "NOTE: Deprecated in favor of timersengine.logging", + ), ] timers_engine_logging_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.timersengine.conf', - help='Location of the logging configuration file.') + "logging", + default="/etc/st2/logging.timersengine.conf", + help="Location of the logging configuration file.", + ) ] - do_register_opts(timer_logging_opts, group='timer', ignore_errors=ignore_errors) - do_register_opts(timers_engine_logging_opts, group='timersengine', ignore_errors=ignore_errors) + do_register_opts(timer_logging_opts, group="timer", ignore_errors=ignore_errors) + do_register_opts( + timers_engine_logging_opts, group="timersengine", ignore_errors=ignore_errors + ) # NOTE: We default old style deprecated "timer" options to None so our code # works correclty and "timersengine" has precedence over "timers" # NOTE: "timer" section will be removed in v3.1 timer_opts = [ cfg.StrOpt( - 'local_timezone', default=None, - help='Timezone pertaining to the location where st2 is run. ' - 'NOTE: Deprecated in favor of timersengine.local_timezone'), + "local_timezone", + default=None, + help="Timezone pertaining to the location where st2 is run. " + "NOTE: Deprecated in favor of timersengine.local_timezone", + ), cfg.BoolOpt( - 'enable', default=None, - help='Specify to enable timer service. ' - 'NOTE: Deprecated in favor of timersengine.enable'), + "enable", + default=None, + help="Specify to enable timer service. " + "NOTE: Deprecated in favor of timersengine.enable", + ), ] timers_engine_opts = [ cfg.StrOpt( - 'local_timezone', default='America/Los_Angeles', - help='Timezone pertaining to the location where st2 is run.'), - cfg.BoolOpt( - 'enable', default=True, - help='Specify to enable timer service.') + "local_timezone", + default="America/Los_Angeles", + help="Timezone pertaining to the location where st2 is run.", + ), + cfg.BoolOpt("enable", default=True, help="Specify to enable timer service."), ] - do_register_opts(timer_opts, group='timer', ignore_errors=ignore_errors) - do_register_opts(timers_engine_opts, group='timersengine', ignore_errors=ignore_errors) + do_register_opts(timer_opts, group="timer", ignore_errors=ignore_errors) + do_register_opts( + timers_engine_opts, group="timersengine", ignore_errors=ignore_errors + ) # Workflow engine options workflow_engine_opts = [ cfg.IntOpt( - 'retry_stop_max_msec', default=60000, - help='Max time to stop retrying.'), + "retry_stop_max_msec", default=60000, help="Max time to stop retrying." + ), cfg.IntOpt( - 'retry_wait_fixed_msec', default=1000, - help='Interval inbetween retries.'), + "retry_wait_fixed_msec", default=1000, help="Interval inbetween retries." + ), cfg.FloatOpt( - 'retry_max_jitter_msec', default=1000, - help='Max jitter interval to smooth out retries.'), + "retry_max_jitter_msec", + default=1000, + help="Max jitter interval to smooth out retries.", + ), cfg.IntOpt( - 'gc_max_idle_sec', default=0, - help='Max seconds to allow workflow execution be idled before it is identified as ' - 'orphaned and cancelled by the garbage collector. A value of zero means the ' - 'feature is disabled. This is disabled by default.') + "gc_max_idle_sec", + default=0, + help="Max seconds to allow workflow execution be idled before it is identified as " + "orphaned and cancelled by the garbage collector. A value of zero means the " + "feature is disabled. This is disabled by default.", + ), ] - do_register_opts(workflow_engine_opts, group='workflow_engine', ignore_errors=ignore_errors) + do_register_opts( + workflow_engine_opts, group="workflow_engine", ignore_errors=ignore_errors + ) def parse_args(args=None): register_opts() - cfg.CONF(args=args, version=VERSION_STRING, default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) diff --git a/st2common/st2common/constants/action.py b/st2common/st2common/constants/action.py index c28725f225b..5587b0be91f 100644 --- a/st2common/st2common/constants/action.py +++ b/st2common/st2common/constants/action.py @@ -14,61 +14,56 @@ # limitations under the License. __all__ = [ - 'ACTION_NAME', - 'ACTION_ID', - - 'LIBS_DIR', - - 'LIVEACTION_STATUS_REQUESTED', - 'LIVEACTION_STATUS_SCHEDULED', - 'LIVEACTION_STATUS_DELAYED', - 'LIVEACTION_STATUS_RUNNING', - 'LIVEACTION_STATUS_SUCCEEDED', - 'LIVEACTION_STATUS_FAILED', - 'LIVEACTION_STATUS_TIMED_OUT', - 'LIVEACTION_STATUS_CANCELING', - 'LIVEACTION_STATUS_CANCELED', - 'LIVEACTION_STATUS_PENDING', - 'LIVEACTION_STATUS_PAUSING', - 'LIVEACTION_STATUS_PAUSED', - 'LIVEACTION_STATUS_RESUMING', - - 'LIVEACTION_STATUSES', - 'LIVEACTION_RUNNABLE_STATES', - 'LIVEACTION_DELAYED_STATES', - 'LIVEACTION_CANCELABLE_STATES', - 'LIVEACTION_FAILED_STATES', - 'LIVEACTION_COMPLETED_STATES', - - 'ACTION_OUTPUT_RESULT_DELIMITER', - 'ACTION_CONTEXT_KV_PREFIX', - 'ACTION_PARAMETERS_KV_PREFIX', - 'ACTION_RESULTS_KV_PREFIX', - - 'WORKFLOW_RUNNER_TYPES' + "ACTION_NAME", + "ACTION_ID", + "LIBS_DIR", + "LIVEACTION_STATUS_REQUESTED", + "LIVEACTION_STATUS_SCHEDULED", + "LIVEACTION_STATUS_DELAYED", + "LIVEACTION_STATUS_RUNNING", + "LIVEACTION_STATUS_SUCCEEDED", + "LIVEACTION_STATUS_FAILED", + "LIVEACTION_STATUS_TIMED_OUT", + "LIVEACTION_STATUS_CANCELING", + "LIVEACTION_STATUS_CANCELED", + "LIVEACTION_STATUS_PENDING", + "LIVEACTION_STATUS_PAUSING", + "LIVEACTION_STATUS_PAUSED", + "LIVEACTION_STATUS_RESUMING", + "LIVEACTION_STATUSES", + "LIVEACTION_RUNNABLE_STATES", + "LIVEACTION_DELAYED_STATES", + "LIVEACTION_CANCELABLE_STATES", + "LIVEACTION_FAILED_STATES", + "LIVEACTION_COMPLETED_STATES", + "ACTION_OUTPUT_RESULT_DELIMITER", + "ACTION_CONTEXT_KV_PREFIX", + "ACTION_PARAMETERS_KV_PREFIX", + "ACTION_RESULTS_KV_PREFIX", + "WORKFLOW_RUNNER_TYPES", ] -ACTION_NAME = 'name' -ACTION_ID = 'id' -ACTION_PACK = 'pack' +ACTION_NAME = "name" +ACTION_ID = "id" +ACTION_PACK = "pack" -LIBS_DIR = 'lib' +LIBS_DIR = "lib" -LIVEACTION_STATUS_REQUESTED = 'requested' -LIVEACTION_STATUS_SCHEDULED = 'scheduled' -LIVEACTION_STATUS_DELAYED = 'delayed' -LIVEACTION_STATUS_RUNNING = 'running' -LIVEACTION_STATUS_SUCCEEDED = 'succeeded' -LIVEACTION_STATUS_FAILED = 'failed' -LIVEACTION_STATUS_TIMED_OUT = 'timeout' -LIVEACTION_STATUS_ABANDONED = 'abandoned' -LIVEACTION_STATUS_CANCELING = 'canceling' -LIVEACTION_STATUS_CANCELED = 'canceled' -LIVEACTION_STATUS_PENDING = 'pending' -LIVEACTION_STATUS_PAUSING = 'pausing' -LIVEACTION_STATUS_PAUSED = 'paused' -LIVEACTION_STATUS_RESUMING = 'resuming' +LIVEACTION_STATUS_REQUESTED = "requested" +LIVEACTION_STATUS_SCHEDULED = "scheduled" +LIVEACTION_STATUS_DELAYED = "delayed" +LIVEACTION_STATUS_RUNNING = "running" +LIVEACTION_STATUS_SUCCEEDED = "succeeded" +LIVEACTION_STATUS_FAILED = "failed" +LIVEACTION_STATUS_TIMED_OUT = "timeout" +LIVEACTION_STATUS_ABANDONED = "abandoned" +LIVEACTION_STATUS_CANCELING = "canceling" +LIVEACTION_STATUS_CANCELED = "canceled" +LIVEACTION_STATUS_PENDING = "pending" +LIVEACTION_STATUS_PAUSING = "pausing" +LIVEACTION_STATUS_PAUSED = "paused" +LIVEACTION_STATUS_RESUMING = "resuming" LIVEACTION_STATUSES = [ LIVEACTION_STATUS_REQUESTED, @@ -84,25 +79,23 @@ LIVEACTION_STATUS_PENDING, LIVEACTION_STATUS_PAUSING, LIVEACTION_STATUS_PAUSED, - LIVEACTION_STATUS_RESUMING + LIVEACTION_STATUS_RESUMING, ] -ACTION_OUTPUT_RESULT_DELIMITER = '%%%%%~=~=~=************=~=~=~%%%%' -ACTION_CONTEXT_KV_PREFIX = 'action_context' -ACTION_PARAMETERS_KV_PREFIX = 'action_parameters' -ACTION_RESULTS_KV_PREFIX = 'action_results' +ACTION_OUTPUT_RESULT_DELIMITER = "%%%%%~=~=~=************=~=~=~%%%%" +ACTION_CONTEXT_KV_PREFIX = "action_context" +ACTION_PARAMETERS_KV_PREFIX = "action_parameters" +ACTION_RESULTS_KV_PREFIX = "action_results" LIVEACTION_RUNNABLE_STATES = [ LIVEACTION_STATUS_REQUESTED, LIVEACTION_STATUS_SCHEDULED, LIVEACTION_STATUS_PAUSING, LIVEACTION_STATUS_PAUSED, - LIVEACTION_STATUS_RESUMING + LIVEACTION_STATUS_RESUMING, ] -LIVEACTION_DELAYED_STATES = [ - LIVEACTION_STATUS_DELAYED -] +LIVEACTION_DELAYED_STATES = [LIVEACTION_STATUS_DELAYED] LIVEACTION_CANCELABLE_STATES = [ LIVEACTION_STATUS_REQUESTED, @@ -111,7 +104,7 @@ LIVEACTION_STATUS_RUNNING, LIVEACTION_STATUS_PAUSING, LIVEACTION_STATUS_PAUSED, - LIVEACTION_STATUS_RESUMING + LIVEACTION_STATUS_RESUMING, ] LIVEACTION_COMPLETED_STATES = [ @@ -119,29 +112,20 @@ LIVEACTION_STATUS_FAILED, LIVEACTION_STATUS_TIMED_OUT, LIVEACTION_STATUS_CANCELED, - LIVEACTION_STATUS_ABANDONED + LIVEACTION_STATUS_ABANDONED, ] LIVEACTION_FAILED_STATES = [ LIVEACTION_STATUS_FAILED, LIVEACTION_STATUS_TIMED_OUT, - LIVEACTION_STATUS_ABANDONED + LIVEACTION_STATUS_ABANDONED, ] -LIVEACTION_PAUSE_STATES = [ - LIVEACTION_STATUS_PAUSING, - LIVEACTION_STATUS_PAUSED -] +LIVEACTION_PAUSE_STATES = [LIVEACTION_STATUS_PAUSING, LIVEACTION_STATUS_PAUSED] -LIVEACTION_CANCEL_STATES = [ - LIVEACTION_STATUS_CANCELING, - LIVEACTION_STATUS_CANCELED -] +LIVEACTION_CANCEL_STATES = [LIVEACTION_STATUS_CANCELING, LIVEACTION_STATUS_CANCELED] -WORKFLOW_RUNNER_TYPES = [ - 'action-chain', - 'orquesta' -] +WORKFLOW_RUNNER_TYPES = ["action-chain", "orquesta"] # Linux's limit for param size _LINUX_PARAM_LIMIT = 131072 diff --git a/st2common/st2common/constants/api.py b/st2common/st2common/constants/api.py index c1df81fb0d1..2690133314f 100644 --- a/st2common/st2common/constants/api.py +++ b/st2common/st2common/constants/api.py @@ -13,11 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'DEFAULT_API_VERSION' -] +__all__ = ["DEFAULT_API_VERSION"] -DEFAULT_API_VERSION = 'v1' +DEFAULT_API_VERSION = "v1" -REQUEST_ID_HEADER = 'X-Request-ID' +REQUEST_ID_HEADER = "X-Request-ID" diff --git a/st2common/st2common/constants/auth.py b/st2common/st2common/constants/auth.py index f0664739ceb..7b4003c0ef3 100644 --- a/st2common/st2common/constants/auth.py +++ b/st2common/st2common/constants/auth.py @@ -14,26 +14,22 @@ # limitations under the License. __all__ = [ - 'VALID_MODES', - 'DEFAULT_MODE', - 'DEFAULT_BACKEND', - - 'HEADER_ATTRIBUTE_NAME', - 'QUERY_PARAM_ATTRIBUTE_NAME' + "VALID_MODES", + "DEFAULT_MODE", + "DEFAULT_BACKEND", + "HEADER_ATTRIBUTE_NAME", + "QUERY_PARAM_ATTRIBUTE_NAME", ] -VALID_MODES = [ - 'proxy', - 'standalone' -] +VALID_MODES = ["proxy", "standalone"] -HEADER_ATTRIBUTE_NAME = 'X-Auth-Token' -QUERY_PARAM_ATTRIBUTE_NAME = 'x-auth-token' +HEADER_ATTRIBUTE_NAME = "X-Auth-Token" +QUERY_PARAM_ATTRIBUTE_NAME = "x-auth-token" -HEADER_API_KEY_ATTRIBUTE_NAME = 'St2-Api-Key' -QUERY_PARAM_API_KEY_ATTRIBUTE_NAME = 'st2-api-key' +HEADER_API_KEY_ATTRIBUTE_NAME = "St2-Api-Key" +QUERY_PARAM_API_KEY_ATTRIBUTE_NAME = "st2-api-key" -DEFAULT_MODE = 'standalone' +DEFAULT_MODE = "standalone" -DEFAULT_BACKEND = 'flat_file' -DEFAULT_SSO_BACKEND = 'noop' +DEFAULT_BACKEND = "flat_file" +DEFAULT_SSO_BACKEND = "noop" diff --git a/st2common/st2common/constants/error_messages.py b/st2common/st2common/constants/error_messages.py index 7aa56c4025d..7c70377721a 100644 --- a/st2common/st2common/constants/error_messages.py +++ b/st2common/st2common/constants/error_messages.py @@ -13,21 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'PACK_VIRTUALENV_DOESNT_EXIST', - 'PYTHON2_DEPRECATION' -] +__all__ = ["PACK_VIRTUALENV_DOESNT_EXIST", "PYTHON2_DEPRECATION"] -PACK_VIRTUALENV_DOESNT_EXIST = ''' +PACK_VIRTUALENV_DOESNT_EXIST = """ The virtual environment (%(virtualenv_path)s) for pack "%(pack)s" does not exist. Normally this is created when you install a pack using "st2 pack install". If you installed your pack by some other means, you can create a new virtual environment using the command: "st2 run packs.setup_virtualenv packs=%(pack)s" -''' +""" -PYTHON2_DEPRECATION = 'DEPRECATION WARNING. Support for python 2 will be removed in future ' \ - 'StackStorm releases. Please ensure that all packs used are python ' \ - '3 compatible. Your StackStorm installation may be upgraded from ' \ - 'python 2 to python 3 in future platform releases. It is recommended ' \ - 'to plan the manual migration to a python 3 native platform, e.g. ' \ - 'Ubuntu 18.04 LTS or CentOS/RHEL 8.' +PYTHON2_DEPRECATION = ( + "DEPRECATION WARNING. Support for python 2 will be removed in future " + "StackStorm releases. Please ensure that all packs used are python " + "3 compatible. Your StackStorm installation may be upgraded from " + "python 2 to python 3 in future platform releases. It is recommended " + "to plan the manual migration to a python 3 native platform, e.g. " + "Ubuntu 18.04 LTS or CentOS/RHEL 8." +) diff --git a/st2common/st2common/constants/exit_codes.py b/st2common/st2common/constants/exit_codes.py index 8fd1efd9a78..1b32e89e269 100644 --- a/st2common/st2common/constants/exit_codes.py +++ b/st2common/st2common/constants/exit_codes.py @@ -14,10 +14,10 @@ # limitations under the License. __all__ = [ - 'SUCCESS_EXIT_CODE', - 'FAILURE_EXIT_CODE', - 'SIGKILL_EXIT_CODE', - 'SIGTERM_EXIT_CODE' + "SUCCESS_EXIT_CODE", + "FAILURE_EXIT_CODE", + "SIGKILL_EXIT_CODE", + "SIGTERM_EXIT_CODE", ] SUCCESS_EXIT_CODE = 0 diff --git a/st2common/st2common/constants/garbage_collection.py b/st2common/st2common/constants/garbage_collection.py index dad31218968..ac8a2aac5f8 100644 --- a/st2common/st2common/constants/garbage_collection.py +++ b/st2common/st2common/constants/garbage_collection.py @@ -14,10 +14,10 @@ # limitations under the License. __all__ = [ - 'DEFAULT_COLLECTION_INTERVAL', - 'DEFAULT_SLEEP_DELAY', - 'MINIMUM_TTL_DAYS', - 'MINIMUM_TTL_DAYS_EXECUTION_OUTPUT' + "DEFAULT_COLLECTION_INTERVAL", + "DEFAULT_SLEEP_DELAY", + "MINIMUM_TTL_DAYS", + "MINIMUM_TTL_DAYS_EXECUTION_OUTPUT", ] diff --git a/st2common/st2common/constants/keyvalue.py b/st2common/st2common/constants/keyvalue.py index 2897f1e32df..7a21eab8ec4 100644 --- a/st2common/st2common/constants/keyvalue.py +++ b/st2common/st2common/constants/keyvalue.py @@ -14,46 +14,49 @@ # limitations under the License. __all__ = [ - 'ALLOWED_SCOPES', - 'SYSTEM_SCOPE', - 'FULL_SYSTEM_SCOPE', - 'SYSTEM_SCOPES', - 'USER_SCOPE', - 'FULL_USER_SCOPE', - 'USER_SCOPES', - 'USER_SEPARATOR', - - 'DATASTORE_SCOPE_SEPARATOR', - 'DATASTORE_KEY_SEPARATOR' + "ALLOWED_SCOPES", + "SYSTEM_SCOPE", + "FULL_SYSTEM_SCOPE", + "SYSTEM_SCOPES", + "USER_SCOPE", + "FULL_USER_SCOPE", + "USER_SCOPES", + "USER_SEPARATOR", + "DATASTORE_SCOPE_SEPARATOR", + "DATASTORE_KEY_SEPARATOR", ] -ALL_SCOPE = 'all' +ALL_SCOPE = "all" # Parent namespace for all items in key-value store -DATASTORE_PARENT_SCOPE = 'st2kv' -DATASTORE_SCOPE_SEPARATOR = '.' # To separate scope from datastore namespace. E.g. st2kv.system +DATASTORE_PARENT_SCOPE = "st2kv" +DATASTORE_SCOPE_SEPARATOR = ( + "." # To separate scope from datastore namespace. E.g. st2kv.system +) # Namespace to contain all system/global scoped variables in key-value store. -SYSTEM_SCOPE = 'system' -FULL_SYSTEM_SCOPE = '%s%s%s' % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, SYSTEM_SCOPE) +SYSTEM_SCOPE = "system" +FULL_SYSTEM_SCOPE = "%s%s%s" % ( + DATASTORE_PARENT_SCOPE, + DATASTORE_SCOPE_SEPARATOR, + SYSTEM_SCOPE, +) SYSTEM_SCOPES = [SYSTEM_SCOPE] # Namespace to contain all user scoped variables in key-value store. -USER_SCOPE = 'user' -FULL_USER_SCOPE = '%s%s%s' % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, USER_SCOPE) +USER_SCOPE = "user" +FULL_USER_SCOPE = "%s%s%s" % ( + DATASTORE_PARENT_SCOPE, + DATASTORE_SCOPE_SEPARATOR, + USER_SCOPE, +) USER_SCOPES = [USER_SCOPE] -USER_SEPARATOR = ':' +USER_SEPARATOR = ":" # Separator for keys in the datastore -DATASTORE_KEY_SEPARATOR = ':' - -ALLOWED_SCOPES = [ - SYSTEM_SCOPE, - USER_SCOPE, +DATASTORE_KEY_SEPARATOR = ":" - FULL_SYSTEM_SCOPE, - FULL_USER_SCOPE -] +ALLOWED_SCOPES = [SYSTEM_SCOPE, USER_SCOPE, FULL_SYSTEM_SCOPE, FULL_USER_SCOPE] diff --git a/st2common/st2common/constants/logging.py b/st2common/st2common/constants/logging.py index b62a59bd00d..0985a039473 100644 --- a/st2common/st2common/constants/logging.py +++ b/st2common/st2common/constants/logging.py @@ -16,11 +16,9 @@ from __future__ import absolute_import import os -__all__ = [ - 'DEFAULT_LOGGING_CONF_PATH' -] +__all__ = ["DEFAULT_LOGGING_CONF_PATH"] BASE_PATH = os.path.dirname(os.path.abspath(__file__)) -DEFAULT_LOGGING_CONF_PATH = os.path.join(BASE_PATH, '../conf/base.logging.conf') +DEFAULT_LOGGING_CONF_PATH = os.path.join(BASE_PATH, "../conf/base.logging.conf") DEFAULT_LOGGING_CONF_PATH = os.path.abspath(DEFAULT_LOGGING_CONF_PATH) diff --git a/st2common/st2common/constants/meta.py b/st2common/st2common/constants/meta.py index ac4859b5e1a..acd348a3555 100644 --- a/st2common/st2common/constants/meta.py +++ b/st2common/st2common/constants/meta.py @@ -16,10 +16,7 @@ from __future__ import absolute_import import yaml -__all__ = [ - 'ALLOWED_EXTS', - 'PARSER_FUNCS' -] +__all__ = ["ALLOWED_EXTS", "PARSER_FUNCS"] -ALLOWED_EXTS = ['.yaml', '.yml'] -PARSER_FUNCS = {'.yml': yaml.safe_load, '.yaml': yaml.safe_load} +ALLOWED_EXTS = [".yaml", ".yml"] +PARSER_FUNCS = {".yml": yaml.safe_load, ".yaml": yaml.safe_load} diff --git a/st2common/st2common/constants/pack.py b/st2common/st2common/constants/pack.py index 91ae5a5e2c4..f782a6920cd 100644 --- a/st2common/st2common/constants/pack.py +++ b/st2common/st2common/constants/pack.py @@ -14,81 +14,74 @@ # limitations under the License. __all__ = [ - 'PACKS_PACK_NAME', - 'PACK_REF_WHITELIST_REGEX', - 'PACK_RESERVED_CHARACTERS', - 'PACK_VERSION_SEPARATOR', - 'PACK_VERSION_REGEX', - 'ST2_VERSION_REGEX', - 'SYSTEM_PACK_NAME', - 'PACKS_PACK_NAME', - 'LINUX_PACK_NAME', - 'SYSTEM_PACK_NAMES', - 'CHATOPS_PACK_NAME', - 'USER_PACK_NAME_BLACKLIST', - 'BASE_PACK_REQUIREMENTS', - 'MANIFEST_FILE_NAME', - 'CONFIG_SCHEMA_FILE_NAME' + "PACKS_PACK_NAME", + "PACK_REF_WHITELIST_REGEX", + "PACK_RESERVED_CHARACTERS", + "PACK_VERSION_SEPARATOR", + "PACK_VERSION_REGEX", + "ST2_VERSION_REGEX", + "SYSTEM_PACK_NAME", + "PACKS_PACK_NAME", + "LINUX_PACK_NAME", + "SYSTEM_PACK_NAMES", + "CHATOPS_PACK_NAME", + "USER_PACK_NAME_BLACKLIST", + "BASE_PACK_REQUIREMENTS", + "MANIFEST_FILE_NAME", + "CONFIG_SCHEMA_FILE_NAME", ] # Prefix for render context w/ config -PACK_CONFIG_CONTEXT_KV_PREFIX = 'config_context' +PACK_CONFIG_CONTEXT_KV_PREFIX = "config_context" # A list of allowed characters for the pack name -PACK_REF_WHITELIST_REGEX = r'^[a-z0-9_]+$' +PACK_REF_WHITELIST_REGEX = r"^[a-z0-9_]+$" # Check for a valid semver string -PACK_VERSION_REGEX = r'^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(?:-[\da-z\-]+(?:\.[\da-z\-]+)*)?(?:\+[\da-z\-]+(?:\.[\da-z\-]+)*)?$' # noqa +PACK_VERSION_REGEX = r"^(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)\.(?:0|[1-9]\d*)(?:-[\da-z\-]+(?:\.[\da-z\-]+)*)?(?:\+[\da-z\-]+(?:\.[\da-z\-]+)*)?$" # noqa # Special characters which can't be used in pack names -PACK_RESERVED_CHARACTERS = [ - '.' -] +PACK_RESERVED_CHARACTERS = ["."] # Version sperator when version is supplied in pack name # Example: libcloud@1.0.1 -PACK_VERSION_SEPARATOR = '=' +PACK_VERSION_SEPARATOR = "=" # Check for st2 version in engines -ST2_VERSION_REGEX = r'^((>?>|>=|=|<=|?>|>=|=|<=|=1.9.0,<2.0' -] +BASE_PACK_REQUIREMENTS = ["six>=1.9.0,<2.0"] # Name of the pack manifest file -MANIFEST_FILE_NAME = 'pack.yaml' +MANIFEST_FILE_NAME = "pack.yaml" # File name for the config schema file -CONFIG_SCHEMA_FILE_NAME = 'config.schema.yaml' +CONFIG_SCHEMA_FILE_NAME = "config.schema.yaml" diff --git a/st2common/st2common/constants/policy.py b/st2common/st2common/constants/policy.py index e36ce8fc126..7ce7093ed5f 100644 --- a/st2common/st2common/constants/policy.py +++ b/st2common/st2common/constants/policy.py @@ -13,13 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'POLICY_TYPES_REQUIRING_LOCK' -] +__all__ = ["POLICY_TYPES_REQUIRING_LOCK"] # Concurrency policies require scheduler to acquire a distributed lock to prevent race # in scheduling when there are multiple scheduler instances. -POLICY_TYPES_REQUIRING_LOCK = [ - 'action.concurrency', - 'action.concurrency.attr' -] +POLICY_TYPES_REQUIRING_LOCK = ["action.concurrency", "action.concurrency.attr"] diff --git a/st2common/st2common/constants/rule_enforcement.py b/st2common/st2common/constants/rule_enforcement.py index fced450304d..ceece2d6e1c 100644 --- a/st2common/st2common/constants/rule_enforcement.py +++ b/st2common/st2common/constants/rule_enforcement.py @@ -14,16 +14,15 @@ # limitations under the License. __all__ = [ - 'RULE_ENFORCEMENT_STATUS_SUCCEEDED', - 'RULE_ENFORCEMENT_STATUS_FAILED', - - 'RULE_ENFORCEMENT_STATUSES' + "RULE_ENFORCEMENT_STATUS_SUCCEEDED", + "RULE_ENFORCEMENT_STATUS_FAILED", + "RULE_ENFORCEMENT_STATUSES", ] -RULE_ENFORCEMENT_STATUS_SUCCEEDED = 'succeeded' -RULE_ENFORCEMENT_STATUS_FAILED = 'failed' +RULE_ENFORCEMENT_STATUS_SUCCEEDED = "succeeded" +RULE_ENFORCEMENT_STATUS_FAILED = "failed" RULE_ENFORCEMENT_STATUSES = [ RULE_ENFORCEMENT_STATUS_SUCCEEDED, - RULE_ENFORCEMENT_STATUS_FAILED + RULE_ENFORCEMENT_STATUS_FAILED, ] diff --git a/st2common/st2common/constants/rules.py b/st2common/st2common/constants/rules.py index 393e94aebbc..929e4b5e920 100644 --- a/st2common/st2common/constants/rules.py +++ b/st2common/st2common/constants/rules.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -TRIGGER_PAYLOAD_PREFIX = 'trigger' -TRIGGER_ITEM_PAYLOAD_PREFIX = 'item' +TRIGGER_PAYLOAD_PREFIX = "trigger" +TRIGGER_ITEM_PAYLOAD_PREFIX = "item" -RULE_TYPE_STANDARD = 'standard' -RULE_TYPE_BACKSTOP = 'backstop' +RULE_TYPE_STANDARD = "standard" +RULE_TYPE_BACKSTOP = "backstop" -MATCH_CRITERIA = r'({{)\s*(.*)\s*(}})' +MATCH_CRITERIA = r"({{)\s*(.*)\s*(}})" diff --git a/st2common/st2common/constants/runners.py b/st2common/st2common/constants/runners.py index fe78a6497fc..52ec7383842 100644 --- a/st2common/st2common/constants/runners.py +++ b/st2common/st2common/constants/runners.py @@ -17,36 +17,28 @@ from oslo_config import cfg __all__ = [ - 'RUNNER_NAME_WHITELIST', - - 'MANIFEST_FILE_NAME', - - 'LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT', - - 'REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT', - 'REMOTE_RUNNER_DEFAULT_REMOTE_DIR', - 'REMOTE_RUNNER_PRIVATE_KEY_HEADER', - - 'PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT', - 'PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE', - - 'WINDOWS_RUNNER_DEFAULT_ACTION_TIMEOUT', - - 'COMMON_ACTION_ENV_VARIABLE_PREFIX', - 'COMMON_ACTION_ENV_VARIABLES', - - 'DEFAULT_SSH_PORT', - - 'RUNNERS_NAMESPACE' + "RUNNER_NAME_WHITELIST", + "MANIFEST_FILE_NAME", + "LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT", + "REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT", + "REMOTE_RUNNER_DEFAULT_REMOTE_DIR", + "REMOTE_RUNNER_PRIVATE_KEY_HEADER", + "PYTHON_RUNNER_DEFAULT_ACTION_TIMEOUT", + "PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE", + "WINDOWS_RUNNER_DEFAULT_ACTION_TIMEOUT", + "COMMON_ACTION_ENV_VARIABLE_PREFIX", + "COMMON_ACTION_ENV_VARIABLES", + "DEFAULT_SSH_PORT", + "RUNNERS_NAMESPACE", ] DEFAULT_SSH_PORT = 22 # A list of allowed characters for the pack name -RUNNER_NAME_WHITELIST = r'^[A-Za-z0-9_-]+' +RUNNER_NAME_WHITELIST = r"^[A-Za-z0-9_-]+" # Manifest file name for runners -MANIFEST_FILE_NAME = 'runner.yaml' +MANIFEST_FILE_NAME = "runner.yaml" # Local runner LOCAL_RUNNER_DEFAULT_ACTION_TIMEOUT = 60 @@ -57,9 +49,9 @@ try: REMOTE_RUNNER_DEFAULT_REMOTE_DIR = cfg.CONF.ssh_runner.remote_dir except: - REMOTE_RUNNER_DEFAULT_REMOTE_DIR = '/tmp' + REMOTE_RUNNER_DEFAULT_REMOTE_DIR = "/tmp" -REMOTE_RUNNER_PRIVATE_KEY_HEADER = 'PRIVATE KEY-----'.lower() +REMOTE_RUNNER_PRIVATE_KEY_HEADER = "PRIVATE KEY-----".lower() # Python runner # Default timeout (in seconds) for actions executed by Python runner @@ -69,20 +61,20 @@ # action returns invalid status from the run() method PYTHON_RUNNER_INVALID_ACTION_STATUS_EXIT_CODE = 220 -PYTHON_RUNNER_DEFAULT_LOG_LEVEL = 'DEBUG' +PYTHON_RUNNER_DEFAULT_LOG_LEVEL = "DEBUG" # Windows runner WINDOWS_RUNNER_DEFAULT_ACTION_TIMEOUT = 10 * 60 # Prefix for common st2 environment variables which are available to the actions -COMMON_ACTION_ENV_VARIABLE_PREFIX = 'ST2_ACTION_' +COMMON_ACTION_ENV_VARIABLE_PREFIX = "ST2_ACTION_" # Common st2 environment variables which are available to the actions COMMON_ACTION_ENV_VARIABLES = [ - 'ST2_ACTION_PACK_NAME', - 'ST2_ACTION_EXECUTION_ID', - 'ST2_ACTION_API_URL', - 'ST2_ACTION_AUTH_TOKEN' + "ST2_ACTION_PACK_NAME", + "ST2_ACTION_EXECUTION_ID", + "ST2_ACTION_API_URL", + "ST2_ACTION_AUTH_TOKEN", ] # Namespaces for dynamically loaded runner modules diff --git a/st2common/st2common/constants/scheduler.py b/st2common/st2common/constants/scheduler.py index d825d2aed09..fb97971a3cb 100644 --- a/st2common/st2common/constants/scheduler.py +++ b/st2common/st2common/constants/scheduler.py @@ -13,12 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'SCHEDULER_ENABLED_LOG_LINE', - 'SCHEDULER_DISABLED_LOG_LINE' -] +__all__ = ["SCHEDULER_ENABLED_LOG_LINE", "SCHEDULER_DISABLED_LOG_LINE"] # Integration tests look for these loglines to validate scheduler enable/disable -SCHEDULER_ENABLED_LOG_LINE = 'Scheduler is enabled.' -SCHEDULER_DISABLED_LOG_LINE = 'Scheduler is disabled.' +SCHEDULER_ENABLED_LOG_LINE = "Scheduler is enabled." +SCHEDULER_DISABLED_LOG_LINE = "Scheduler is disabled." diff --git a/st2common/st2common/constants/secrets.py b/st2common/st2common/constants/secrets.py index d3f9e53b9ed..ef9a02d5ee8 100644 --- a/st2common/st2common/constants/secrets.py +++ b/st2common/st2common/constants/secrets.py @@ -13,22 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'MASKED_ATTRIBUTES_BLACKLIST', - 'MASKED_ATTRIBUTE_VALUE' -] +__all__ = ["MASKED_ATTRIBUTES_BLACKLIST", "MASKED_ATTRIBUTE_VALUE"] # A blacklist of attributes which should be masked in the log messages by default. # Note: If an attribute is an object or a dict, we try to recursively process it and mask the # values. MASKED_ATTRIBUTES_BLACKLIST = [ - 'password', - 'auth_token', - 'token', - 'secret', - 'credentials', - 'st2_auth_token' + "password", + "auth_token", + "token", + "secret", + "credentials", + "st2_auth_token", ] # Value with which the masked attribute values are replaced -MASKED_ATTRIBUTE_VALUE = '********' +MASKED_ATTRIBUTE_VALUE = "********" diff --git a/st2common/st2common/constants/sensors.py b/st2common/st2common/constants/sensors.py index 3ba4f9487d3..a2d7903d187 100644 --- a/st2common/st2common/constants/sensors.py +++ b/st2common/st2common/constants/sensors.py @@ -17,7 +17,7 @@ MINIMUM_POLL_INTERVAL = 4 # keys for PARTITION loaders -DEFAULT_PARTITION_LOADER = 'default' -KVSTORE_PARTITION_LOADER = 'kvstore' -FILE_PARTITION_LOADER = 'file' -HASH_PARTITION_LOADER = 'hash' +DEFAULT_PARTITION_LOADER = "default" +KVSTORE_PARTITION_LOADER = "kvstore" +FILE_PARTITION_LOADER = "file" +HASH_PARTITION_LOADER = "hash" diff --git a/st2common/st2common/constants/system.py b/st2common/st2common/constants/system.py index dcb8ee699cb..9736527171c 100644 --- a/st2common/st2common/constants/system.py +++ b/st2common/st2common/constants/system.py @@ -20,15 +20,14 @@ from st2common import __version__ __all__ = [ - 'VERSION_STRING', - 'DEFAULT_CONFIG_FILE_PATH', - - 'API_URL_ENV_VARIABLE_NAME', - 'AUTH_TOKEN_ENV_VARIABLE_NAME', + "VERSION_STRING", + "DEFAULT_CONFIG_FILE_PATH", + "API_URL_ENV_VARIABLE_NAME", + "AUTH_TOKEN_ENV_VARIABLE_NAME", ] -VERSION_STRING = 'StackStorm v%s' % (__version__) -DEFAULT_CONFIG_FILE_PATH = os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf') +VERSION_STRING = "StackStorm v%s" % (__version__) +DEFAULT_CONFIG_FILE_PATH = os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf") -API_URL_ENV_VARIABLE_NAME = 'ST2_API_URL' -AUTH_TOKEN_ENV_VARIABLE_NAME = 'ST2_AUTH_TOKEN' +API_URL_ENV_VARIABLE_NAME = "ST2_API_URL" +AUTH_TOKEN_ENV_VARIABLE_NAME = "ST2_AUTH_TOKEN" diff --git a/st2common/st2common/constants/timer.py b/st2common/st2common/constants/timer.py index 0f191a8027e..97727437926 100644 --- a/st2common/st2common/constants/timer.py +++ b/st2common/st2common/constants/timer.py @@ -13,12 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'TIMER_ENABLED_LOG_LINE', - 'TIMER_DISABLED_LOG_LINE' -] +__all__ = ["TIMER_ENABLED_LOG_LINE", "TIMER_DISABLED_LOG_LINE"] # Integration tests look for these loglines to validate timer enable/disable -TIMER_ENABLED_LOG_LINE = 'Timer is enabled.' -TIMER_DISABLED_LOG_LINE = 'Timer is disabled.' +TIMER_ENABLED_LOG_LINE = "Timer is enabled." +TIMER_DISABLED_LOG_LINE = "Timer is disabled." diff --git a/st2common/st2common/constants/trace.py b/st2common/st2common/constants/trace.py index d900912c608..f7e4242da12 100644 --- a/st2common/st2common/constants/trace.py +++ b/st2common/st2common/constants/trace.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ['TRACE_CONTEXT', 'TRACE_ID'] +__all__ = ["TRACE_CONTEXT", "TRACE_ID"] -TRACE_CONTEXT = 'trace_context' -TRACE_ID = 'trace_tag' +TRACE_CONTEXT = "trace_context" +TRACE_ID = "trace_tag" diff --git a/st2common/st2common/constants/triggers.py b/st2common/st2common/constants/triggers.py index 4a0ccc8e4e0..14ab861fd53 100644 --- a/st2common/st2common/constants/triggers.py +++ b/st2common/st2common/constants/triggers.py @@ -18,244 +18,200 @@ from st2common.models.system.common import ResourceReference __all__ = [ - 'WEBHOOKS_PARAMETERS_SCHEMA', - 'WEBHOOKS_PAYLOAD_SCHEMA', - 'INTERVAL_PARAMETERS_SCHEMA', - 'DATE_PARAMETERS_SCHEMA', - 'CRON_PARAMETERS_SCHEMA', - 'TIMER_PAYLOAD_SCHEMA', - - 'ACTION_SENSOR_TRIGGER', - 'NOTIFY_TRIGGER', - 'ACTION_FILE_WRITTEN_TRIGGER', - 'INQUIRY_TRIGGER', - - 'TIMER_TRIGGER_TYPES', - 'WEBHOOK_TRIGGER_TYPES', - 'WEBHOOK_TRIGGER_TYPE', - 'INTERNAL_TRIGGER_TYPES', - 'SYSTEM_TRIGGER_TYPES', - - 'INTERVAL_TIMER_TRIGGER_REF', - 'DATE_TIMER_TRIGGER_REF', - 'CRON_TIMER_TRIGGER_REF', - - 'TRIGGER_INSTANCE_STATUSES', - 'TRIGGER_INSTANCE_PENDING', - 'TRIGGER_INSTANCE_PROCESSING', - 'TRIGGER_INSTANCE_PROCESSED', - 'TRIGGER_INSTANCE_PROCESSING_FAILED' + "WEBHOOKS_PARAMETERS_SCHEMA", + "WEBHOOKS_PAYLOAD_SCHEMA", + "INTERVAL_PARAMETERS_SCHEMA", + "DATE_PARAMETERS_SCHEMA", + "CRON_PARAMETERS_SCHEMA", + "TIMER_PAYLOAD_SCHEMA", + "ACTION_SENSOR_TRIGGER", + "NOTIFY_TRIGGER", + "ACTION_FILE_WRITTEN_TRIGGER", + "INQUIRY_TRIGGER", + "TIMER_TRIGGER_TYPES", + "WEBHOOK_TRIGGER_TYPES", + "WEBHOOK_TRIGGER_TYPE", + "INTERNAL_TRIGGER_TYPES", + "SYSTEM_TRIGGER_TYPES", + "INTERVAL_TIMER_TRIGGER_REF", + "DATE_TIMER_TRIGGER_REF", + "CRON_TIMER_TRIGGER_REF", + "TRIGGER_INSTANCE_STATUSES", + "TRIGGER_INSTANCE_PENDING", + "TRIGGER_INSTANCE_PROCESSING", + "TRIGGER_INSTANCE_PROCESSED", + "TRIGGER_INSTANCE_PROCESSING_FAILED", ] # Action resource triggers ACTION_SENSOR_TRIGGER = { - 'name': 'st2.generic.actiontrigger', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating the completion of an action execution.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'execution_id': {}, - 'status': {}, - 'start_timestamp': {}, - 'action_name': {}, - 'action_ref': {}, - 'runner_ref': {}, - 'parameters': {}, - 'result': {} - } - } + "name": "st2.generic.actiontrigger", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating the completion of an action execution.", + "payload_schema": { + "type": "object", + "properties": { + "execution_id": {}, + "status": {}, + "start_timestamp": {}, + "action_name": {}, + "action_ref": {}, + "runner_ref": {}, + "parameters": {}, + "result": {}, + }, + }, } ACTION_FILE_WRITTEN_TRIGGER = { - 'name': 'st2.action.file_written', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating action file being written on disk.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'ref': {}, - 'file_path': {}, - 'host_info': {} - } - } + "name": "st2.action.file_written", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating action file being written on disk.", + "payload_schema": { + "type": "object", + "properties": {"ref": {}, "file_path": {}, "host_info": {}}, + }, } NOTIFY_TRIGGER = { - 'name': 'st2.generic.notifytrigger', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Notification trigger.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'execution_id': {}, - 'status': {}, - 'start_timestamp': {}, - 'end_timestamp': {}, - 'action_ref': {}, - 'runner_ref': {}, - 'channel': {}, - 'route': {}, - 'message': {}, - 'data': {} - } - } + "name": "st2.generic.notifytrigger", + "pack": SYSTEM_PACK_NAME, + "description": "Notification trigger.", + "payload_schema": { + "type": "object", + "properties": { + "execution_id": {}, + "status": {}, + "start_timestamp": {}, + "end_timestamp": {}, + "action_ref": {}, + "runner_ref": {}, + "channel": {}, + "route": {}, + "message": {}, + "data": {}, + }, + }, } INQUIRY_TRIGGER = { - 'name': 'st2.generic.inquiry', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger indicating a new "inquiry" has entered "pending" status', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'description': 'ID of the new inquiry.', - 'required': True + "name": "st2.generic.inquiry", + "pack": SYSTEM_PACK_NAME, + "description": 'Trigger indicating a new "inquiry" has entered "pending" status', + "payload_schema": { + "type": "object", + "properties": { + "id": { + "type": "string", + "description": "ID of the new inquiry.", + "required": True, + }, + "route": { + "type": "string", + "description": "An arbitrary value for allowing rules " + "to route to proper notification channel.", + "required": True, }, - 'route': { - 'type': 'string', - 'description': 'An arbitrary value for allowing rules ' - 'to route to proper notification channel.', - 'required': True - } }, - "additionalProperties": False - } + "additionalProperties": False, + }, } # Sensor spawn/exit triggers. SENSOR_SPAWN_TRIGGER = { - 'name': 'st2.sensor.process_spawn', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger indicating sensor process is started up.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'object': {} - } - } + "name": "st2.sensor.process_spawn", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger indicating sensor process is started up.", + "payload_schema": {"type": "object", "properties": {"object": {}}}, } SENSOR_EXIT_TRIGGER = { - 'name': 'st2.sensor.process_exit', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger indicating sensor process is stopped.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'object': {} - } - } + "name": "st2.sensor.process_exit", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger indicating sensor process is stopped.", + "payload_schema": {"type": "object", "properties": {"object": {}}}, } # KeyValuePair resource triggers KEY_VALUE_PAIR_CREATE_TRIGGER = { - 'name': 'st2.key_value_pair.create', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating datastore item creation.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'object': {} - } - } + "name": "st2.key_value_pair.create", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating datastore item creation.", + "payload_schema": {"type": "object", "properties": {"object": {}}}, } KEY_VALUE_PAIR_UPDATE_TRIGGER = { - 'name': 'st2.key_value_pair.update', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating datastore set action.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'object': {} - } - } + "name": "st2.key_value_pair.update", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating datastore set action.", + "payload_schema": {"type": "object", "properties": {"object": {}}}, } KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER = { - 'name': 'st2.key_value_pair.value_change', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating a change of datastore item value.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'old_object': {}, - 'new_object': {} - } - } + "name": "st2.key_value_pair.value_change", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating a change of datastore item value.", + "payload_schema": { + "type": "object", + "properties": {"old_object": {}, "new_object": {}}, + }, } KEY_VALUE_PAIR_DELETE_TRIGGER = { - 'name': 'st2.key_value_pair.delete', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Trigger encapsulating datastore item deletion.', - 'payload_schema': { - 'type': 'object', - 'properties': { - 'object': {} - } - } + "name": "st2.key_value_pair.delete", + "pack": SYSTEM_PACK_NAME, + "description": "Trigger encapsulating datastore item deletion.", + "payload_schema": {"type": "object", "properties": {"object": {}}}, } # Internal system triggers which are available for each resource INTERNAL_TRIGGER_TYPES = { - 'action': [ + "action": [ ACTION_SENSOR_TRIGGER, NOTIFY_TRIGGER, ACTION_FILE_WRITTEN_TRIGGER, - INQUIRY_TRIGGER - ], - 'sensor': [ - SENSOR_SPAWN_TRIGGER, - SENSOR_EXIT_TRIGGER + INQUIRY_TRIGGER, ], - 'key_value_pair': [ + "sensor": [SENSOR_SPAWN_TRIGGER, SENSOR_EXIT_TRIGGER], + "key_value_pair": [ KEY_VALUE_PAIR_CREATE_TRIGGER, KEY_VALUE_PAIR_UPDATE_TRIGGER, KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER, - KEY_VALUE_PAIR_DELETE_TRIGGER - ] + KEY_VALUE_PAIR_DELETE_TRIGGER, + ], } WEBHOOKS_PARAMETERS_SCHEMA = { - 'type': 'object', - 'properties': { - 'url': { - 'type': 'string', - 'required': True - } - }, - 'additionalProperties': False + "type": "object", + "properties": {"url": {"type": "string", "required": True}}, + "additionalProperties": False, } WEBHOOKS_PAYLOAD_SCHEMA = { - 'type': 'object', - 'properties': { - 'headers': { - 'type': 'object' - }, - 'body': { - 'anyOf': [ - {'type': 'array'}, - {'type': 'object'}, + "type": "object", + "properties": { + "headers": {"type": "object"}, + "body": { + "anyOf": [ + {"type": "array"}, + {"type": "object"}, ] - } - } + }, + }, } WEBHOOK_TRIGGER_TYPES = { - ResourceReference.to_string_reference(SYSTEM_PACK_NAME, 'st2.webhook'): { - 'name': 'st2.webhook', - 'pack': SYSTEM_PACK_NAME, - 'description': ('Trigger type for registering webhooks that can consume' - ' arbitrary payload.'), - 'parameters_schema': WEBHOOKS_PARAMETERS_SCHEMA, - 'payload_schema': WEBHOOKS_PAYLOAD_SCHEMA + ResourceReference.to_string_reference(SYSTEM_PACK_NAME, "st2.webhook"): { + "name": "st2.webhook", + "pack": SYSTEM_PACK_NAME, + "description": ( + "Trigger type for registering webhooks that can consume" + " arbitrary payload." + ), + "parameters_schema": WEBHOOKS_PARAMETERS_SCHEMA, + "payload_schema": WEBHOOKS_PAYLOAD_SCHEMA, } } WEBHOOK_TRIGGER_TYPE = list(WEBHOOK_TRIGGER_TYPES.keys())[0] @@ -265,107 +221,69 @@ INTERVAL_PARAMETERS_SCHEMA = { "type": "object", "properties": { - "timezone": { - "type": "string" - }, + "timezone": {"type": "string"}, "unit": { "enum": ["weeks", "days", "hours", "minutes", "seconds"], - "required": True + "required": True, }, - "delta": { - "type": "integer", - "required": True - - } + "delta": {"type": "integer", "required": True}, }, - "additionalProperties": False + "additionalProperties": False, } DATE_PARAMETERS_SCHEMA = { "type": "object", "properties": { - "timezone": { - "type": "string" - }, - "date": { - "type": "string", - "format": "date-time", - "required": True - } + "timezone": {"type": "string"}, + "date": {"type": "string", "format": "date-time", "required": True}, }, - "additionalProperties": False + "additionalProperties": False, } CRON_PARAMETERS_SCHEMA = { "type": "object", "properties": { - "timezone": { - "type": "string" - }, + "timezone": {"type": "string"}, "year": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], }, "month": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 1, - "maximum": 12 + "maximum": 12, }, "day": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 1, - "maximum": 31 + "maximum": 31, }, "week": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 1, - "maximum": 53 + "maximum": 53, }, "day_of_week": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 0, - "maximum": 6 + "maximum": 6, }, "hour": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 0, - "maximum": 23 + "maximum": 23, }, "minute": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 0, - "maximum": 59 + "maximum": 59, }, "second": { - "anyOf": [ - {"type": "string"}, - {"type": "integer"} - ], + "anyOf": [{"type": "string"}, {"type": "integer"}], "minimum": 0, - "maximum": 59 - } + "maximum": 59, + }, }, - "additionalProperties": False + "additionalProperties": False, } TIMER_PAYLOAD_SCHEMA = { @@ -374,61 +292,62 @@ "executed_at": { "type": "string", "format": "date-time", - "default": "2014-07-30 05:04:24.578325" + "default": "2014-07-30 05:04:24.578325", }, - "schedule": { - "type": "object", - "default": { - "delta": 30, - "units": "seconds" - } - } - } + "schedule": {"type": "object", "default": {"delta": 30, "units": "seconds"}}, + }, } -INTERVAL_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(SYSTEM_PACK_NAME, - 'st2.IntervalTimer') -DATE_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(SYSTEM_PACK_NAME, 'st2.DateTimer') -CRON_TIMER_TRIGGER_REF = ResourceReference.to_string_reference(SYSTEM_PACK_NAME, 'st2.CronTimer') +INTERVAL_TIMER_TRIGGER_REF = ResourceReference.to_string_reference( + SYSTEM_PACK_NAME, "st2.IntervalTimer" +) +DATE_TIMER_TRIGGER_REF = ResourceReference.to_string_reference( + SYSTEM_PACK_NAME, "st2.DateTimer" +) +CRON_TIMER_TRIGGER_REF = ResourceReference.to_string_reference( + SYSTEM_PACK_NAME, "st2.CronTimer" +) TIMER_TRIGGER_TYPES = { INTERVAL_TIMER_TRIGGER_REF: { - 'name': 'st2.IntervalTimer', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Triggers on specified intervals. e.g. every 30s, 1week etc.', - 'payload_schema': TIMER_PAYLOAD_SCHEMA, - 'parameters_schema': INTERVAL_PARAMETERS_SCHEMA + "name": "st2.IntervalTimer", + "pack": SYSTEM_PACK_NAME, + "description": "Triggers on specified intervals. e.g. every 30s, 1week etc.", + "payload_schema": TIMER_PAYLOAD_SCHEMA, + "parameters_schema": INTERVAL_PARAMETERS_SCHEMA, }, DATE_TIMER_TRIGGER_REF: { - 'name': 'st2.DateTimer', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Triggers exactly once when the current time matches the specified time. ' - 'e.g. timezone:UTC date:2014-12-31 23:59:59.', - 'payload_schema': TIMER_PAYLOAD_SCHEMA, - 'parameters_schema': DATE_PARAMETERS_SCHEMA + "name": "st2.DateTimer", + "pack": SYSTEM_PACK_NAME, + "description": "Triggers exactly once when the current time matches the specified time. " + "e.g. timezone:UTC date:2014-12-31 23:59:59.", + "payload_schema": TIMER_PAYLOAD_SCHEMA, + "parameters_schema": DATE_PARAMETERS_SCHEMA, }, CRON_TIMER_TRIGGER_REF: { - 'name': 'st2.CronTimer', - 'pack': SYSTEM_PACK_NAME, - 'description': 'Triggers whenever current time matches the specified time constaints like ' - 'a UNIX cron scheduler.', - 'payload_schema': TIMER_PAYLOAD_SCHEMA, - 'parameters_schema': CRON_PARAMETERS_SCHEMA - } + "name": "st2.CronTimer", + "pack": SYSTEM_PACK_NAME, + "description": "Triggers whenever current time matches the specified time constaints like " + "a UNIX cron scheduler.", + "payload_schema": TIMER_PAYLOAD_SCHEMA, + "parameters_schema": CRON_PARAMETERS_SCHEMA, + }, } -SYSTEM_TRIGGER_TYPES = dict(list(WEBHOOK_TRIGGER_TYPES.items()) + list(TIMER_TRIGGER_TYPES.items())) +SYSTEM_TRIGGER_TYPES = dict( + list(WEBHOOK_TRIGGER_TYPES.items()) + list(TIMER_TRIGGER_TYPES.items()) +) # various status to record lifecycle of a TriggerInstance -TRIGGER_INSTANCE_PENDING = 'pending' -TRIGGER_INSTANCE_PROCESSING = 'processing' -TRIGGER_INSTANCE_PROCESSED = 'processed' -TRIGGER_INSTANCE_PROCESSING_FAILED = 'processing_failed' +TRIGGER_INSTANCE_PENDING = "pending" +TRIGGER_INSTANCE_PROCESSING = "processing" +TRIGGER_INSTANCE_PROCESSED = "processed" +TRIGGER_INSTANCE_PROCESSING_FAILED = "processing_failed" TRIGGER_INSTANCE_STATUSES = [ TRIGGER_INSTANCE_PENDING, TRIGGER_INSTANCE_PROCESSING, TRIGGER_INSTANCE_PROCESSED, - TRIGGER_INSTANCE_PROCESSING_FAILED + TRIGGER_INSTANCE_PROCESSING_FAILED, ] diff --git a/st2common/st2common/constants/types.py b/st2common/st2common/constants/types.py index 7873d5b6650..01ec79605fe 100644 --- a/st2common/st2common/constants/types.py +++ b/st2common/st2common/constants/types.py @@ -16,9 +16,7 @@ from __future__ import absolute_import from st2common.util.enum import Enum -__all__ = [ - 'ResourceType' -] +__all__ = ["ResourceType"] class ResourceType(Enum): @@ -27,37 +25,37 @@ class ResourceType(Enum): """ # System resources - RUNNER_TYPE = 'runner_type' + RUNNER_TYPE = "runner_type" # Pack resources - PACK = 'pack' - ACTION = 'action' - ACTION_ALIAS = 'action_alias' - SENSOR_TYPE = 'sensor_type' - TRIGGER_TYPE = 'trigger_type' - TRIGGER = 'trigger' - TRIGGER_INSTANCE = 'trigger_instance' - RULE = 'rule' - RULE_ENFORCEMENT = 'rule_enforcement' + PACK = "pack" + ACTION = "action" + ACTION_ALIAS = "action_alias" + SENSOR_TYPE = "sensor_type" + TRIGGER_TYPE = "trigger_type" + TRIGGER = "trigger" + TRIGGER_INSTANCE = "trigger_instance" + RULE = "rule" + RULE_ENFORCEMENT = "rule_enforcement" # Note: Policy type is a global resource and policy belong to a pack - POLICY_TYPE = 'policy_type' - POLICY = 'policy' + POLICY_TYPE = "policy_type" + POLICY = "policy" # Other resources - EXECUTION = 'execution' - EXECUTION_REQUEST = 'execution_request' - KEY_VALUE_PAIR = 'key_value_pair' + EXECUTION = "execution" + EXECUTION_REQUEST = "execution_request" + KEY_VALUE_PAIR = "key_value_pair" - WEBHOOK = 'webhook' - TIMER = 'timer' - API_KEY = 'api_key' - TRACE = 'trace' - TIMER = 'timer' + WEBHOOK = "webhook" + TIMER = "timer" + API_KEY = "api_key" + TRACE = "trace" + TIMER = "timer" # Special resource type for stream related stuff - STREAM = 'stream' + STREAM = "stream" - INQUIRY = 'inquiry' + INQUIRY = "inquiry" - UNKNOWN = 'unknown' + UNKNOWN = "unknown" diff --git a/st2common/st2common/content/bootstrap.py b/st2common/st2common/content/bootstrap.py index 1072d350534..d38296bd5a9 100644 --- a/st2common/st2common/content/bootstrap.py +++ b/st2common/st2common/content/bootstrap.py @@ -38,46 +38,55 @@ from st2common.metrics.base import Timer from st2common.util.virtualenvs import setup_pack_virtualenv -__all__ = [ - 'main' -] +__all__ = ["main"] -LOG = logging.getLogger('st2common.content.bootstrap') +LOG = logging.getLogger("st2common.content.bootstrap") -cfg.CONF.register_cli_opt(cfg.BoolOpt('experimental', default=False)) +cfg.CONF.register_cli_opt(cfg.BoolOpt("experimental", default=False)) def register_opts(): content_opts = [ - cfg.BoolOpt('all', default=False, help='Register sensors, actions and rules.'), - cfg.BoolOpt('triggers', default=False, help='Register triggers.'), - cfg.BoolOpt('sensors', default=False, help='Register sensors.'), - cfg.BoolOpt('actions', default=False, help='Register actions.'), - cfg.BoolOpt('runners', default=False, help='Register runners.'), - cfg.BoolOpt('rules', default=False, help='Register rules.'), - cfg.BoolOpt('aliases', default=False, help='Register aliases.'), - cfg.BoolOpt('policies', default=False, help='Register policies.'), - cfg.BoolOpt('configs', default=False, help='Register and load pack configs.'), - - cfg.StrOpt('pack', default=None, help='Directory to the pack to register content from.'), - cfg.StrOpt('runner-dir', default=None, help='Directory to load runners from.'), - cfg.BoolOpt('setup-virtualenvs', default=False, help=('Setup Python virtual environments ' - 'all the Python runner actions.')), - + cfg.BoolOpt("all", default=False, help="Register sensors, actions and rules."), + cfg.BoolOpt("triggers", default=False, help="Register triggers."), + cfg.BoolOpt("sensors", default=False, help="Register sensors."), + cfg.BoolOpt("actions", default=False, help="Register actions."), + cfg.BoolOpt("runners", default=False, help="Register runners."), + cfg.BoolOpt("rules", default=False, help="Register rules."), + cfg.BoolOpt("aliases", default=False, help="Register aliases."), + cfg.BoolOpt("policies", default=False, help="Register policies."), + cfg.BoolOpt("configs", default=False, help="Register and load pack configs."), + cfg.StrOpt( + "pack", default=None, help="Directory to the pack to register content from." + ), + cfg.StrOpt("runner-dir", default=None, help="Directory to load runners from."), + cfg.BoolOpt( + "setup-virtualenvs", + default=False, + help=( + "Setup Python virtual environments " "all the Python runner actions." + ), + ), # General options # Note: This value should default to False since we want fail on failure behavior by # default. - cfg.BoolOpt('no-fail-on-failure', default=False, - help=('Don\'t exit with non-zero if some resource registration fails.')), + cfg.BoolOpt( + "no-fail-on-failure", + default=False, + help=("Don't exit with non-zero if some resource registration fails."), + ), # Note: Fail on failure is now a default behavior. This flag is only left here for backward # compatibility reasons, but it's not actually used. - cfg.BoolOpt('fail-on-failure', default=True, - help=('Exit with non-zero if some resource registration fails.')) + cfg.BoolOpt( + "fail-on-failure", + default=True, + help=("Exit with non-zero if some resource registration fails."), + ), ] try: - cfg.CONF.register_cli_opts(content_opts, group='register') + cfg.CONF.register_cli_opts(content_opts, group="register") except: - sys.stderr.write('Failed registering opts.\n') + sys.stderr.write("Failed registering opts.\n") register_opts() @@ -88,9 +97,9 @@ def setup_virtualenvs(): Setup Python virtual environments for all the registered or the provided pack. """ - LOG.info('=========================================================') - LOG.info('########### Setting up virtual environments #############') - LOG.info('=========================================================') + LOG.info("=========================================================") + LOG.info("########### Setting up virtual environments #############") + LOG.info("=========================================================") pack_dir = cfg.CONF.register.pack fail_on_failure = not cfg.CONF.register.no_fail_on_failure @@ -116,15 +125,19 @@ def setup_virtualenvs(): setup_pack_virtualenv(pack_name=pack_name, update=True, logger=LOG) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to setup virtualenv for pack "%s": %s', pack_name, e, - exc_info=exc_info) + LOG.warning( + 'Failed to setup virtualenv for pack "%s": %s', + pack_name, + e, + exc_info=exc_info, + ) if fail_on_failure: raise e else: setup_count += 1 - LOG.info('Setup virtualenv for %s pack(s).' % (setup_count)) + LOG.info("Setup virtualenv for %s pack(s)." % (setup_count)) def register_triggers(): @@ -134,22 +147,21 @@ def register_triggers(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering triggers #####################') - LOG.info('=========================================================') - with Timer(key='st2.register.triggers'): + LOG.info("=========================================================") + LOG.info("############## Registering triggers #####################") + LOG.info("=========================================================") + with Timer(key="st2.register.triggers"): registered_count = triggers_registrar.register_triggers( - pack_dir=pack_dir, - fail_on_failure=fail_on_failure + pack_dir=pack_dir, fail_on_failure=fail_on_failure ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register sensors: %s', e, exc_info=exc_info) + LOG.warning("Failed to register sensors: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s triggers.' % (registered_count)) + LOG.info("Registered %s triggers." % (registered_count)) def register_sensors(): @@ -159,22 +171,21 @@ def register_sensors(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering sensors ######################') - LOG.info('=========================================================') - with Timer(key='st2.register.sensors'): + LOG.info("=========================================================") + LOG.info("############## Registering sensors ######################") + LOG.info("=========================================================") + with Timer(key="st2.register.sensors"): registered_count = sensors_registrar.register_sensors( - pack_dir=pack_dir, - fail_on_failure=fail_on_failure + pack_dir=pack_dir, fail_on_failure=fail_on_failure ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register sensors: %s', e, exc_info=exc_info) + LOG.warning("Failed to register sensors: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s sensors.' % (registered_count)) + LOG.info("Registered %s sensors." % (registered_count)) def register_runners(): @@ -184,24 +195,23 @@ def register_runners(): # 1. Register runner types try: - LOG.info('=========================================================') - LOG.info('############## Registering runners ######################') - LOG.info('=========================================================') - with Timer(key='st2.register.runners'): + LOG.info("=========================================================") + LOG.info("############## Registering runners ######################") + LOG.info("=========================================================") + with Timer(key="st2.register.runners"): registered_count = runners_registrar.register_runners( - fail_on_failure=fail_on_failure, - experimental=False + fail_on_failure=fail_on_failure, experimental=False ) except Exception as error: exc_info = not fail_on_failure # TODO: Narrow exception window - LOG.warning('Failed to register runners: %s', error, exc_info=exc_info) + LOG.warning("Failed to register runners: %s", error, exc_info=exc_info) if fail_on_failure: raise error - LOG.info('Registered %s runners.', registered_count) + LOG.info("Registered %s runners.", registered_count) def register_actions(): @@ -213,22 +223,21 @@ def register_actions(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering actions ######################') - LOG.info('=========================================================') - with Timer(key='st2.register.actions'): + LOG.info("=========================================================") + LOG.info("############## Registering actions ######################") + LOG.info("=========================================================") + with Timer(key="st2.register.actions"): registered_count = actions_registrar.register_actions( - pack_dir=pack_dir, - fail_on_failure=fail_on_failure + pack_dir=pack_dir, fail_on_failure=fail_on_failure ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register actions: %s', e, exc_info=exc_info) + LOG.warning("Failed to register actions: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s actions.' % (registered_count)) + LOG.info("Registered %s actions." % (registered_count)) def register_rules(): @@ -239,28 +248,27 @@ def register_rules(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering rules ########################') - LOG.info('=========================================================') + LOG.info("=========================================================") + LOG.info("############## Registering rules ########################") + LOG.info("=========================================================") rule_types_registrar.register_rule_types() except Exception as e: - LOG.warning('Failed to register rule types: %s', e, exc_info=True) + LOG.warning("Failed to register rule types: %s", e, exc_info=True) return try: - with Timer(key='st2.register.rules'): + with Timer(key="st2.register.rules"): registered_count = rules_registrar.register_rules( - pack_dir=pack_dir, - fail_on_failure=fail_on_failure + pack_dir=pack_dir, fail_on_failure=fail_on_failure ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register rules: %s', e, exc_info=exc_info) + LOG.warning("Failed to register rules: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s rules.', registered_count) + LOG.info("Registered %s rules.", registered_count) def register_aliases(): @@ -270,21 +278,20 @@ def register_aliases(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering aliases ######################') - LOG.info('=========================================================') - with Timer(key='st2.register.aliases'): + LOG.info("=========================================================") + LOG.info("############## Registering aliases ######################") + LOG.info("=========================================================") + with Timer(key="st2.register.aliases"): registered_count = aliases_registrar.register_aliases( - pack_dir=pack_dir, - fail_on_failure=fail_on_failure + pack_dir=pack_dir, fail_on_failure=fail_on_failure ) except Exception as e: if fail_on_failure: raise e - LOG.warning('Failed to register aliases.', exc_info=True) + LOG.warning("Failed to register aliases.", exc_info=True) - LOG.info('Registered %s aliases.', registered_count) + LOG.info("Registered %s aliases.", registered_count) def register_policies(): @@ -295,31 +302,32 @@ def register_policies(): registered_type_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering policy types #################') - LOG.info('=========================================================') - with Timer(key='st2.register.policies'): + LOG.info("=========================================================") + LOG.info("############## Registering policy types #################") + LOG.info("=========================================================") + with Timer(key="st2.register.policies"): registered_type_count = policies_registrar.register_policy_types(st2common) except Exception: - LOG.warning('Failed to register policy types.', exc_info=True) + LOG.warning("Failed to register policy types.", exc_info=True) - LOG.info('Registered %s policy types.', registered_type_count) + LOG.info("Registered %s policy types.", registered_type_count) registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering policies #####################') - LOG.info('=========================================================') - registered_count = policies_registrar.register_policies(pack_dir=pack_dir, - fail_on_failure=fail_on_failure) + LOG.info("=========================================================") + LOG.info("############## Registering policies #####################") + LOG.info("=========================================================") + registered_count = policies_registrar.register_policies( + pack_dir=pack_dir, fail_on_failure=fail_on_failure + ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register policies: %s', e, exc_info=exc_info) + LOG.warning("Failed to register policies: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s policies.', registered_count) + LOG.info("Registered %s policies.", registered_count) def register_configs(): @@ -329,23 +337,23 @@ def register_configs(): registered_count = 0 try: - LOG.info('=========================================================') - LOG.info('############## Registering configs ######################') - LOG.info('=========================================================') - with Timer(key='st2.register.configs'): + LOG.info("=========================================================") + LOG.info("############## Registering configs ######################") + LOG.info("=========================================================") + with Timer(key="st2.register.configs"): registered_count = configs_registrar.register_configs( pack_dir=pack_dir, fail_on_failure=fail_on_failure, - validate_configs=True + validate_configs=True, ) except Exception as e: exc_info = not fail_on_failure - LOG.warning('Failed to register configs: %s', e, exc_info=exc_info) + LOG.warning("Failed to register configs: %s", e, exc_info=exc_info) if fail_on_failure: raise e - LOG.info('Registered %s configs.' % (registered_count)) + LOG.info("Registered %s configs." % (registered_count)) def register_content(): @@ -395,8 +403,12 @@ def register_content(): def setup(argv): - common_setup(config=config, setup_db=True, register_mq_exchanges=True, - register_internal_trigger_types=True) + common_setup( + config=config, + setup_db=True, + register_mq_exchanges=True, + register_internal_trigger_types=True, + ) def teardown(): @@ -410,5 +422,5 @@ def main(argv): # This script registers actions and rules from content-packs. -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv[1:]) diff --git a/st2common/st2common/content/loader.py b/st2common/st2common/content/loader.py index 0dfae4c0b6a..420323fd76e 100644 --- a/st2common/st2common/content/loader.py +++ b/st2common/st2common/content/loader.py @@ -28,10 +28,7 @@ if six.PY2: from io import open -__all__ = [ - 'ContentPackLoader', - 'MetaLoader' -] +__all__ = ["ContentPackLoader", "MetaLoader"] LOG = logging.getLogger(__name__) @@ -45,12 +42,12 @@ class ContentPackLoader(object): # content - they just return a path ALLOWED_CONTENT_TYPES = [ - 'triggers', - 'sensors', - 'actions', - 'rules', - 'aliases', - 'policies' + "triggers", + "sensors", + "actions", + "rules", + "aliases", + "policies", ] def get_packs(self, base_dirs): @@ -91,7 +88,7 @@ def get_content(self, base_dirs, content_type): assert isinstance(base_dirs, list) if content_type not in self.ALLOWED_CONTENT_TYPES: - raise ValueError('Unsupported content_type: %s' % (content_type)) + raise ValueError("Unsupported content_type: %s" % (content_type)) content = {} pack_to_dir_map = {} @@ -99,14 +96,18 @@ def get_content(self, base_dirs, content_type): if not os.path.isdir(base_dir): raise ValueError('Directory "%s" doesn\'t exist' % (base_dir)) - dir_content = self._get_content_from_dir(base_dir=base_dir, content_type=content_type) + dir_content = self._get_content_from_dir( + base_dir=base_dir, content_type=content_type + ) # Check for duplicate packs for pack_name, pack_content in six.iteritems(dir_content): if pack_name in content: pack_dir = pack_to_dir_map[pack_name] - LOG.warning('Pack "%s" already found in "%s", ignoring content from "%s"' % - (pack_name, pack_dir, base_dir)) + LOG.warning( + 'Pack "%s" already found in "%s", ignoring content from "%s"' + % (pack_name, pack_dir, base_dir) + ) else: content[pack_name] = pack_content pack_to_dir_map[pack_name] = base_dir @@ -126,13 +127,14 @@ def get_content_from_pack(self, pack_dir, content_type): :rtype: ``str`` """ if content_type not in self.ALLOWED_CONTENT_TYPES: - raise ValueError('Unsupported content_type: %s' % (content_type)) + raise ValueError("Unsupported content_type: %s" % (content_type)) if not os.path.isdir(pack_dir): raise ValueError('Directory "%s" doesn\'t exist' % (pack_dir)) - content = self._get_content_from_pack_dir(pack_dir=pack_dir, - content_type=content_type) + content = self._get_content_from_pack_dir( + pack_dir=pack_dir, content_type=content_type + ) return content def _get_packs_from_dir(self, base_dir): @@ -154,8 +156,9 @@ def _get_content_from_dir(self, base_dir, content_type): # Ignore missing or non directories try: - pack_content = self._get_content_from_pack_dir(pack_dir=pack_dir, - content_type=content_type) + pack_content = self._get_content_from_pack_dir( + pack_dir=pack_dir, content_type=content_type + ) except ValueError: continue else: @@ -170,13 +173,13 @@ def _get_content_from_pack_dir(self, pack_dir, content_type): actions=self._get_actions, rules=self._get_rules, aliases=self._get_aliases, - policies=self._get_policies + policies=self._get_policies, ) get_func = content_types.get(content_type) if get_func is None: - raise ValueError('Invalid content_type: %s' % (content_type)) + raise ValueError("Invalid content_type: %s" % (content_type)) if not os.path.isdir(pack_dir): raise ValueError('Directory "%s" doesn\'t exist' % (pack_dir)) @@ -185,22 +188,22 @@ def _get_content_from_pack_dir(self, pack_dir, content_type): return pack_content def _get_triggers(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='triggers') + return self._get_folder(pack_dir=pack_dir, content_type="triggers") def _get_sensors(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='sensors') + return self._get_folder(pack_dir=pack_dir, content_type="sensors") def _get_actions(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='actions') + return self._get_folder(pack_dir=pack_dir, content_type="actions") def _get_rules(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='rules') + return self._get_folder(pack_dir=pack_dir, content_type="rules") def _get_aliases(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='aliases') + return self._get_folder(pack_dir=pack_dir, content_type="aliases") def _get_policies(self, pack_dir): - return self._get_folder(pack_dir=pack_dir, content_type='policies') + return self._get_folder(pack_dir=pack_dir, content_type="policies") def _get_folder(self, pack_dir, content_type): path = os.path.join(pack_dir, content_type) @@ -233,8 +236,10 @@ def load(self, file_path, expected_type=None): file_name, file_ext = os.path.splitext(file_path) if file_ext not in ALLOWED_EXTS: - raise Exception('Unsupported meta type %s, file %s. Allowed: %s' % - (file_ext, file_path, ALLOWED_EXTS)) + raise Exception( + "Unsupported meta type %s, file %s. Allowed: %s" + % (file_ext, file_path, ALLOWED_EXTS) + ) result = self._load(PARSER_FUNCS[file_ext], file_path) @@ -246,12 +251,12 @@ def load(self, file_path, expected_type=None): return result def _load(self, parser_func, file_path): - with open(file_path, 'r', encoding='utf-8') as fd: + with open(file_path, "r", encoding="utf-8") as fd: try: return parser_func(fd) except ValueError: - LOG.exception('Failed loading content from %s.', file_path) + LOG.exception("Failed loading content from %s.", file_path) raise except ParserError: - LOG.exception('Failed loading content from %s.', file_path) + LOG.exception("Failed loading content from %s.", file_path) raise diff --git a/st2common/st2common/content/utils.py b/st2common/st2common/content/utils.py index 3bd5e2b12ca..ad9386acf6a 100644 --- a/st2common/st2common/content/utils.py +++ b/st2common/st2common/content/utils.py @@ -24,22 +24,24 @@ from st2common.util.shell import quote_unix __all__ = [ - 'get_pack_group', - 'get_system_packs_base_path', - 'get_packs_base_paths', - 'get_pack_base_path', - 'get_pack_directory', - 'get_pack_file_abs_path', - 'get_pack_resource_file_abs_path', - 'get_relative_path_to_pack_file', - 'check_pack_directory_exists', - 'check_pack_content_directory_exists' + "get_pack_group", + "get_system_packs_base_path", + "get_packs_base_paths", + "get_pack_base_path", + "get_pack_directory", + "get_pack_file_abs_path", + "get_pack_resource_file_abs_path", + "get_relative_path_to_pack_file", + "check_pack_directory_exists", + "check_pack_content_directory_exists", ] INVALID_FILE_PATH_ERROR = """ Invalid file path: "%s". File path needs to be relative to the pack%sdirectory (%s). For example "my_%s.py". -""".strip().replace('\n', ' ') +""".strip().replace( + "\n", " " +) # Cache which stores pack name -> pack base path mappings PACK_NAME_TO_BASE_PATH_CACHE = {} @@ -70,10 +72,10 @@ def get_packs_base_paths(): :rtype: ``list`` """ system_packs_base_path = get_system_packs_base_path() - packs_base_paths = cfg.CONF.content.packs_base_paths or '' + packs_base_paths = cfg.CONF.content.packs_base_paths or "" # Remove trailing colon (if present) - if packs_base_paths.endswith(':'): + if packs_base_paths.endswith(":"): packs_base_paths = packs_base_paths[:-1] result = [] @@ -81,7 +83,7 @@ def get_packs_base_paths(): if system_packs_base_path: result.append(system_packs_base_path) - packs_base_paths = packs_base_paths.split(':') + packs_base_paths = packs_base_paths.split(":") result = result + packs_base_paths result = [path for path in result if path] @@ -223,22 +225,28 @@ def get_entry_point_abs_path(pack=None, entry_point=None, use_pack_cache=False): return None if os.path.isabs(entry_point): - pack_base_path = get_pack_base_path(pack_name=pack, use_pack_cache=use_pack_cache) + pack_base_path = get_pack_base_path( + pack_name=pack, use_pack_cache=use_pack_cache + ) common_prefix = os.path.commonprefix([pack_base_path, entry_point]) if common_prefix != pack_base_path: - raise ValueError('Entry point file "%s" is located outside of the pack directory' % - (entry_point)) + raise ValueError( + 'Entry point file "%s" is located outside of the pack directory' + % (entry_point) + ) return entry_point - entry_point_abs_path = get_pack_resource_file_abs_path(pack_ref=pack, - resource_type='action', - file_path=entry_point) + entry_point_abs_path = get_pack_resource_file_abs_path( + pack_ref=pack, resource_type="action", file_path=entry_point + ) return entry_point_abs_path -def get_pack_file_abs_path(pack_ref, file_path, resource_type=None, use_pack_cache=False): +def get_pack_file_abs_path( + pack_ref, file_path, resource_type=None, use_pack_cache=False +): """ Retrieve full absolute path to the pack file. @@ -258,36 +266,46 @@ def get_pack_file_abs_path(pack_ref, file_path, resource_type=None, use_pack_cac :rtype: ``str`` """ - pack_base_path = get_pack_base_path(pack_name=pack_ref, use_pack_cache=use_pack_cache) + pack_base_path = get_pack_base_path( + pack_name=pack_ref, use_pack_cache=use_pack_cache + ) if resource_type: - resource_type_plural = ' %ss ' % (resource_type) - resource_base_path = os.path.join(pack_base_path, '%ss/' % (resource_type)) + resource_type_plural = " %ss " % (resource_type) + resource_base_path = os.path.join(pack_base_path, "%ss/" % (resource_type)) else: - resource_type_plural = ' ' + resource_type_plural = " " resource_base_path = pack_base_path path_components = [] path_components.append(pack_base_path) # Normalize the path to prevent directory traversal - normalized_file_path = os.path.normpath('/' + file_path).lstrip('/') + normalized_file_path = os.path.normpath("/" + file_path).lstrip("/") if normalized_file_path != file_path: - msg = INVALID_FILE_PATH_ERROR % (file_path, resource_type_plural, resource_base_path, - resource_type or 'action') + msg = INVALID_FILE_PATH_ERROR % ( + file_path, + resource_type_plural, + resource_base_path, + resource_type or "action", + ) raise ValueError(msg) path_components.append(normalized_file_path) - result = os.path.join(*path_components) # pylint: disable=E1120 + result = os.path.join(*path_components) # pylint: disable=E1120 assert normalized_file_path in result # Final safety check for common prefix to avoid traversal attack common_prefix = os.path.commonprefix([pack_base_path, result]) if common_prefix != pack_base_path: - msg = INVALID_FILE_PATH_ERROR % (file_path, resource_type_plural, resource_base_path, - resource_type or 'action') + msg = INVALID_FILE_PATH_ERROR % ( + file_path, + resource_type_plural, + resource_base_path, + resource_type or "action", + ) raise ValueError(msg) return result @@ -313,19 +331,20 @@ def get_pack_resource_file_abs_path(pack_ref, resource_type, file_path): :rtype: ``str`` """ path_components = [] - if resource_type == 'action': - path_components.append('actions/') - elif resource_type == 'sensor': - path_components.append('sensors/') - elif resource_type == 'rule': - path_components.append('rules/') + if resource_type == "action": + path_components.append("actions/") + elif resource_type == "sensor": + path_components.append("sensors/") + elif resource_type == "rule": + path_components.append("rules/") else: - raise ValueError('Invalid resource type: %s' % (resource_type)) + raise ValueError("Invalid resource type: %s" % (resource_type)) path_components.append(file_path) file_path = os.path.join(*path_components) # pylint: disable=E1120 - result = get_pack_file_abs_path(pack_ref=pack_ref, file_path=file_path, - resource_type=resource_type) + result = get_pack_file_abs_path( + pack_ref=pack_ref, file_path=file_path, resource_type=resource_type + ) return result @@ -341,7 +360,9 @@ def get_relative_path_to_pack_file(pack_ref, file_path, use_pack_cache=False): :rtype: ``str`` """ - pack_base_path = get_pack_base_path(pack_name=pack_ref, use_pack_cache=use_pack_cache) + pack_base_path = get_pack_base_path( + pack_name=pack_ref, use_pack_cache=use_pack_cache + ) if not os.path.isabs(file_path): return file_path @@ -350,8 +371,10 @@ def get_relative_path_to_pack_file(pack_ref, file_path, use_pack_cache=False): common_prefix = os.path.commonprefix([pack_base_path, file_path]) if common_prefix != pack_base_path: - raise ValueError('file_path (%s) is not located inside the pack directory (%s)' % - (file_path, pack_base_path)) + raise ValueError( + "file_path (%s) is not located inside the pack directory (%s)" + % (file_path, pack_base_path) + ) relative_path = os.path.relpath(file_path, common_prefix) return relative_path @@ -381,15 +404,15 @@ def get_aliases_base_paths(): :rtype: ``list`` """ - aliases_base_paths = cfg.CONF.content.aliases_base_paths or '' + aliases_base_paths = cfg.CONF.content.aliases_base_paths or "" # Remove trailing colon (if present) - if aliases_base_paths.endswith(':'): + if aliases_base_paths.endswith(":"): aliases_base_paths = aliases_base_paths[:-1] result = [] - aliases_base_paths = aliases_base_paths.split(':') + aliases_base_paths = aliases_base_paths.split(":") result = aliases_base_paths result = [path for path in result if path] diff --git a/st2common/st2common/content/validators.py b/st2common/st2common/content/validators.py index bba9c446e36..8b1ab822c0f 100644 --- a/st2common/st2common/content/validators.py +++ b/st2common/st2common/content/validators.py @@ -19,20 +19,16 @@ from st2common.constants.pack import USER_PACK_NAME_BLACKLIST -__all__ = [ - 'RequirementsValidator', - 'validate_pack_name' -] +__all__ = ["RequirementsValidator", "validate_pack_name"] class RequirementsValidator(object): - @staticmethod def validate(requirements_file): if not os.path.exists(requirements_file): - raise Exception('Requirements file %s not found.' % requirements_file) + raise Exception("Requirements file %s not found." % requirements_file) missing = [] - with open(requirements_file, 'r') as f: + with open(requirements_file, "r") as f: for line in f: rqmnt = line.strip() try: @@ -54,10 +50,9 @@ def validate_pack_name(name): :rtype: ``str`` """ if not name: - raise ValueError('Content pack name cannot be empty') + raise ValueError("Content pack name cannot be empty") if name.lower() in USER_PACK_NAME_BLACKLIST: - raise ValueError('Name "%s" is blacklisted and can\'t be used' % - (name.lower())) + raise ValueError('Name "%s" is blacklisted and can\'t be used' % (name.lower())) return name diff --git a/st2common/st2common/database_setup.py b/st2common/st2common/database_setup.py index 2678ecbf2ed..2e2e7d2a17c 100644 --- a/st2common/st2common/database_setup.py +++ b/st2common/st2common/database_setup.py @@ -23,29 +23,27 @@ from st2common.models import db from st2common.persistence import db_init -__all__ = [ - 'db_config', - 'db_setup', - 'db_teardown' -] +__all__ = ["db_config", "db_setup", "db_teardown"] def db_config(): - username = getattr(cfg.CONF.database, 'username', None) - password = getattr(cfg.CONF.database, 'password', None) - - return {'db_name': cfg.CONF.database.db_name, - 'db_host': cfg.CONF.database.host, - 'db_port': cfg.CONF.database.port, - 'username': username, - 'password': password, - 'ssl': cfg.CONF.database.ssl, - 'ssl_keyfile': cfg.CONF.database.ssl_keyfile, - 'ssl_certfile': cfg.CONF.database.ssl_certfile, - 'ssl_cert_reqs': cfg.CONF.database.ssl_cert_reqs, - 'ssl_ca_certs': cfg.CONF.database.ssl_ca_certs, - 'authentication_mechanism': cfg.CONF.database.authentication_mechanism, - 'ssl_match_hostname': cfg.CONF.database.ssl_match_hostname} + username = getattr(cfg.CONF.database, "username", None) + password = getattr(cfg.CONF.database, "password", None) + + return { + "db_name": cfg.CONF.database.db_name, + "db_host": cfg.CONF.database.host, + "db_port": cfg.CONF.database.port, + "username": username, + "password": password, + "ssl": cfg.CONF.database.ssl, + "ssl_keyfile": cfg.CONF.database.ssl_keyfile, + "ssl_certfile": cfg.CONF.database.ssl_certfile, + "ssl_cert_reqs": cfg.CONF.database.ssl_cert_reqs, + "ssl_ca_certs": cfg.CONF.database.ssl_ca_certs, + "authentication_mechanism": cfg.CONF.database.authentication_mechanism, + "ssl_match_hostname": cfg.CONF.database.ssl_match_hostname, + } def db_setup(ensure_indexes=True): @@ -53,7 +51,7 @@ def db_setup(ensure_indexes=True): Creates the database and indexes (optional). """ db_cfg = db_config() - db_cfg['ensure_indexes'] = ensure_indexes + db_cfg["ensure_indexes"] = ensure_indexes connection = db_init.db_setup_with_retry(**db_cfg) return connection diff --git a/st2common/st2common/exceptions/__init__.py b/st2common/st2common/exceptions/__init__.py index ec4e9430e93..065d3ff0fea 100644 --- a/st2common/st2common/exceptions/__init__.py +++ b/st2common/st2common/exceptions/__init__.py @@ -16,24 +16,26 @@ class StackStormBaseException(Exception): """ - The root of the exception class hierarchy for all - StackStorm server exceptions. + The root of the exception class hierarchy for all + StackStorm server exceptions. - For exceptions raised by plug-ins, see StackStormPluginException - class. + For exceptions raised by plug-ins, see StackStormPluginException + class. """ + pass class StackStormPluginException(StackStormBaseException): """ - The root of the exception class hierarchy for all - exceptions that are defined as part of a StackStorm - plug-in API. - - It is recommended that each API define a root exception - class for the API. This root exception class for the - API should inherit from the StackStormPluginException - class. + The root of the exception class hierarchy for all + exceptions that are defined as part of a StackStorm + plug-in API. + + It is recommended that each API define a root exception + class for the API. This root exception class for the + API should inherit from the StackStormPluginException + class. """ + pass diff --git a/st2common/st2common/exceptions/action.py b/st2common/st2common/exceptions/action.py index f7ed4302665..f4bba2ee752 100644 --- a/st2common/st2common/exceptions/action.py +++ b/st2common/st2common/exceptions/action.py @@ -17,9 +17,9 @@ from st2common.exceptions import StackStormBaseException __all__ = [ - 'ParameterRenderingFailedException', - 'InvalidActionReferencedException', - 'InvalidActionParameterException' + "ParameterRenderingFailedException", + "InvalidActionReferencedException", + "InvalidActionParameterException", ] diff --git a/st2common/st2common/exceptions/actionalias.py b/st2common/st2common/exceptions/actionalias.py index 1c01cd5736a..3172a72dc67 100644 --- a/st2common/st2common/exceptions/actionalias.py +++ b/st2common/st2common/exceptions/actionalias.py @@ -16,9 +16,7 @@ from __future__ import absolute_import from st2common.exceptions import StackStormBaseException -__all__ = [ - 'ActionAliasAmbiguityException' -] +__all__ = ["ActionAliasAmbiguityException"] class ActionAliasAmbiguityException(ValueError, StackStormBaseException): diff --git a/st2common/st2common/exceptions/api.py b/st2common/st2common/exceptions/api.py index f5aee1c1c05..054eb1bcf18 100644 --- a/st2common/st2common/exceptions/api.py +++ b/st2common/st2common/exceptions/api.py @@ -16,8 +16,7 @@ from __future__ import absolute_import from st2common.exceptions import StackStormBaseException -__all__ = [ -] +__all__ = [] class InternalServerErrorException(StackStormBaseException): diff --git a/st2common/st2common/exceptions/auth.py b/st2common/st2common/exceptions/auth.py index 429d597abd4..5eab1915f5f 100644 --- a/st2common/st2common/exceptions/auth.py +++ b/st2common/st2common/exceptions/auth.py @@ -18,19 +18,19 @@ from st2common.exceptions.db import StackStormDBObjectNotFoundError __all__ = [ - 'TokenNotProvidedError', - 'TokenNotFoundError', - 'TokenExpiredError', - 'TTLTooLargeException', - 'ApiKeyNotProvidedError', - 'ApiKeyNotFoundError', - 'MultipleAuthSourcesError', - 'NoAuthSourceProvidedError', - 'NoNicknameOriginProvidedError', - 'UserNotFoundError', - 'AmbiguousUserError', - 'NotServiceUserError', - 'SSOVerificationError' + "TokenNotProvidedError", + "TokenNotFoundError", + "TokenExpiredError", + "TTLTooLargeException", + "ApiKeyNotProvidedError", + "ApiKeyNotFoundError", + "MultipleAuthSourcesError", + "NoAuthSourceProvidedError", + "NoNicknameOriginProvidedError", + "UserNotFoundError", + "AmbiguousUserError", + "NotServiceUserError", + "SSOVerificationError", ] diff --git a/st2common/st2common/exceptions/connection.py b/st2common/st2common/exceptions/connection.py index 8cb9681b41f..806d6e1046b 100644 --- a/st2common/st2common/exceptions/connection.py +++ b/st2common/st2common/exceptions/connection.py @@ -16,14 +16,17 @@ class UnknownHostException(Exception): """Raised when a host is unknown (dns failure)""" + pass class ConnectionErrorException(Exception): """Raised on error connecting (connection refused/timed out)""" + pass class AuthenticationException(Exception): """Raised on authentication error (user/password/ssh key error)""" + pass diff --git a/st2common/st2common/exceptions/db.py b/st2common/st2common/exceptions/db.py index fcd607e9648..776d927e0f5 100644 --- a/st2common/st2common/exceptions/db.py +++ b/st2common/st2common/exceptions/db.py @@ -29,6 +29,7 @@ class StackStormDBObjectConflictError(StackStormBaseException): """ Exception that captures a DB object conflict error. """ + def __init__(self, message, conflict_id, model_object): super(StackStormDBObjectConflictError, self).__init__(message) self.conflict_id = conflict_id @@ -36,7 +37,9 @@ def __init__(self, message, conflict_id, model_object): class StackStormDBObjectWriteConflictError(StackStormBaseException): - def __init__(self, instance): - msg = 'Conflict saving DB object with id "%s" and rev "%s".' % (instance.id, instance.rev) + msg = 'Conflict saving DB object with id "%s" and rev "%s".' % ( + instance.id, + instance.rev, + ) super(StackStormDBObjectWriteConflictError, self).__init__(msg) diff --git a/st2common/st2common/exceptions/inquiry.py b/st2common/st2common/exceptions/inquiry.py index 0636d0f985b..b5c3f306462 100644 --- a/st2common/st2common/exceptions/inquiry.py +++ b/st2common/st2common/exceptions/inquiry.py @@ -23,32 +23,33 @@ class InvalidInquiryInstance(st2_exc.StackStormBaseException): - def __init__(self, inquiry_id): - Exception.__init__(self, 'Action execution "%s" is not an inquiry.' % inquiry_id) + Exception.__init__( + self, 'Action execution "%s" is not an inquiry.' % inquiry_id + ) class InquiryTimedOut(st2_exc.StackStormBaseException): - def __init__(self, inquiry_id): - Exception.__init__(self, 'Inquiry "%s" timed out and cannot be responded to.' % inquiry_id) + Exception.__init__( + self, 'Inquiry "%s" timed out and cannot be responded to.' % inquiry_id + ) class InquiryAlreadyResponded(st2_exc.StackStormBaseException): - def __init__(self, inquiry_id): - Exception.__init__(self, 'Inquiry "%s" has already been responded to.' % inquiry_id) + Exception.__init__( + self, 'Inquiry "%s" has already been responded to.' % inquiry_id + ) class InquiryResponseUnauthorized(st2_exc.StackStormBaseException): - def __init__(self, inquiry_id, user): msg = 'User "%s" does not have permission to respond to inquiry "%s".' Exception.__init__(self, msg % (user, inquiry_id)) class InvalidInquiryResponse(st2_exc.StackStormBaseException): - def __init__(self, inquiry_id, error): msg = 'Response for inquiry "%s" did not pass schema validation. %s' Exception.__init__(self, msg % (inquiry_id, error)) diff --git a/st2common/st2common/exceptions/keyvalue.py b/st2common/st2common/exceptions/keyvalue.py index 6ef2702fe83..7fccb8b8192 100644 --- a/st2common/st2common/exceptions/keyvalue.py +++ b/st2common/st2common/exceptions/keyvalue.py @@ -18,9 +18,9 @@ from st2common.exceptions.db import StackStormDBObjectNotFoundError __all__ = [ - 'CryptoKeyNotSetupException', - 'DataStoreKeyNotFoundError', - 'InvalidScopeException' + "CryptoKeyNotSetupException", + "DataStoreKeyNotFoundError", + "InvalidScopeException", ] diff --git a/st2common/st2common/exceptions/rbac.py b/st2common/st2common/exceptions/rbac.py index 308110c267a..957b0fe5bea 100644 --- a/st2common/st2common/exceptions/rbac.py +++ b/st2common/st2common/exceptions/rbac.py @@ -18,10 +18,10 @@ from st2common.rbac.types import GLOBAL_PERMISSION_TYPES __all__ = [ - 'AccessDeniedError', - 'ResourceTypeAccessDeniedError', - 'ResourceAccessDeniedError', - 'ResourceAccessDeniedPermissionIsolationError' + "AccessDeniedError", + "ResourceTypeAccessDeniedError", + "ResourceAccessDeniedError", + "ResourceAccessDeniedPermissionIsolationError", ] @@ -45,9 +45,13 @@ class ResourceTypeAccessDeniedError(AccessDeniedError): def __init__(self, user_db, permission_type): self.permission_type = permission_type - message = ('User "%s" doesn\'t have required permission "%s"' % (user_db.name, - permission_type)) - super(ResourceTypeAccessDeniedError, self).__init__(message=message, user_db=user_db) + message = 'User "%s" doesn\'t have required permission "%s"' % ( + user_db.name, + permission_type, + ) + super(ResourceTypeAccessDeniedError, self).__init__( + message=message, user_db=user_db + ) class ResourceAccessDeniedError(AccessDeniedError): @@ -59,15 +63,25 @@ def __init__(self, user_db, resource_api_or_db, permission_type): self.resource_api_db = resource_api_or_db self.permission_type = permission_type - resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else 'unknown' + resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else "unknown" if resource_api_or_db and permission_type not in GLOBAL_PERMISSION_TYPES: - message = ('User "%s" doesn\'t have required permission "%s" on resource "%s"' % - (user_db.name, permission_type, resource_uid)) + message = ( + 'User "%s" doesn\'t have required permission "%s" on resource "%s"' + % ( + user_db.name, + permission_type, + resource_uid, + ) + ) else: - message = ('User "%s" doesn\'t have required permission "%s"' % - (user_db.name, permission_type)) - super(ResourceAccessDeniedError, self).__init__(message=message, user_db=user_db) + message = 'User "%s" doesn\'t have required permission "%s"' % ( + user_db.name, + permission_type, + ) + super(ResourceAccessDeniedError, self).__init__( + message=message, user_db=user_db + ) class ResourceAccessDeniedPermissionIsolationError(AccessDeniedError): @@ -80,9 +94,12 @@ def __init__(self, user_db, resource_api_or_db, permission_type): self.resource_api_db = resource_api_or_db self.permission_type = permission_type - resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else 'unknown' + resource_uid = resource_api_or_db.get_uid() if resource_api_or_db else "unknown" - message = ('User "%s" doesn\'t have access to resource "%s" due to resource permission ' - 'isolation.' % (user_db.name, resource_uid)) - super(ResourceAccessDeniedPermissionIsolationError, self).__init__(message=message, - user_db=user_db) + message = ( + 'User "%s" doesn\'t have access to resource "%s" due to resource permission ' + "isolation." % (user_db.name, resource_uid) + ) + super(ResourceAccessDeniedPermissionIsolationError, self).__init__( + message=message, user_db=user_db + ) diff --git a/st2common/st2common/exceptions/ssh.py b/st2common/st2common/exceptions/ssh.py index f720e54b8a4..7a4e1ee5165 100644 --- a/st2common/st2common/exceptions/ssh.py +++ b/st2common/st2common/exceptions/ssh.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'InvalidCredentialsException' -] +__all__ = ["InvalidCredentialsException"] class InvalidCredentialsException(Exception): diff --git a/st2common/st2common/exceptions/workflow.py b/st2common/st2common/exceptions/workflow.py index dd787417c27..2a346819bed 100644 --- a/st2common/st2common/exceptions/workflow.py +++ b/st2common/st2common/exceptions/workflow.py @@ -27,28 +27,25 @@ def retry_on_connection_errors(exc): - LOG.warning('Determining if exception %s should be retried.', type(exc)) + LOG.warning("Determining if exception %s should be retried.", type(exc)) - retrying = ( - isinstance(exc, tooz.coordination.ToozConnectionError) or - isinstance(exc, mongoengine.connection.MongoEngineConnectionError) + retrying = isinstance(exc, tooz.coordination.ToozConnectionError) or isinstance( + exc, mongoengine.connection.MongoEngineConnectionError ) if retrying: - LOG.warning('Retrying operation due to connection error: %s', type(exc)) + LOG.warning("Retrying operation due to connection error: %s", type(exc)) return retrying def retry_on_transient_db_errors(exc): - LOG.warning('Determining if exception %s should be retried.', type(exc)) + LOG.warning("Determining if exception %s should be retried.", type(exc)) - retrying = ( - isinstance(exc, db_exc.StackStormDBObjectWriteConflictError) - ) + retrying = isinstance(exc, db_exc.StackStormDBObjectWriteConflictError) if retrying: - LOG.warning('Retrying operation due to transient database error: %s', type(exc)) + LOG.warning("Retrying operation due to transient database error: %s", type(exc)) return retrying @@ -62,38 +59,37 @@ class WorkflowExecutionException(st2_exc.StackStormBaseException): class WorkflowExecutionNotFoundException(st2_exc.StackStormBaseException): - def __init__(self, ac_ex_id): Exception.__init__( self, - 'Unable to identify any workflow execution that is ' - 'associated to action execution "%s".' % ac_ex_id + "Unable to identify any workflow execution that is " + 'associated to action execution "%s".' % ac_ex_id, ) class AmbiguousWorkflowExecutionException(st2_exc.StackStormBaseException): - def __init__(self, ac_ex_id): Exception.__init__( self, - 'More than one workflow execution is associated ' - 'to action execution "%s".' % ac_ex_id + "More than one workflow execution is associated " + 'to action execution "%s".' % ac_ex_id, ) class WorkflowExecutionIsCompletedException(st2_exc.StackStormBaseException): - def __init__(self, wf_ex_id): - Exception.__init__(self, 'Workflow execution "%s" is already completed.' % wf_ex_id) + Exception.__init__( + self, 'Workflow execution "%s" is already completed.' % wf_ex_id + ) class WorkflowExecutionIsRunningException(st2_exc.StackStormBaseException): - def __init__(self, wf_ex_id): - Exception.__init__(self, 'Workflow execution "%s" is already active.' % wf_ex_id) + Exception.__init__( + self, 'Workflow execution "%s" is already active.' % wf_ex_id + ) class WorkflowExecutionRerunException(st2_exc.StackStormBaseException): - def __init__(self, msg): Exception.__init__(self, msg) diff --git a/st2common/st2common/expressions/functions/data.py b/st2common/st2common/expressions/functions/data.py index d3783e652e6..b240cb72386 100644 --- a/st2common/st2common/expressions/functions/data.py +++ b/st2common/st2common/expressions/functions/data.py @@ -24,13 +24,13 @@ __all__ = [ - 'from_json_string', - 'from_yaml_string', - 'json_escape', - 'jsonpath_query', - 'to_complex', - 'to_json_string', - 'to_yaml_string', + "from_json_string", + "from_yaml_string", + "json_escape", + "jsonpath_query", + "to_complex", + "to_json_string", + "to_yaml_string", ] @@ -42,19 +42,19 @@ def from_yaml_string(value): return yaml.safe_load(six.text_type(value)) -def to_json_string(value, indent=None, sort_keys=False, separators=(',', ': ')): +def to_json_string(value, indent=None, sort_keys=False, separators=(",", ": ")): value = db_util.mongodb_to_python_types(value) options = {} if indent is not None: - options['indent'] = indent + options["indent"] = indent if sort_keys is not None: - options['sort_keys'] = sort_keys + options["sort_keys"] = sort_keys if separators is not None: - options['separators'] = separators + options["separators"] = separators return json.dumps(value, **options) @@ -62,19 +62,19 @@ def to_json_string(value, indent=None, sort_keys=False, separators=(',', ': ')): def to_yaml_string(value, indent=None, allow_unicode=True): value = db_util.mongodb_to_python_types(value) - options = {'default_flow_style': False} + options = {"default_flow_style": False} if indent is not None: - options['indent'] = indent + options["indent"] = indent if allow_unicode is not None: - options['allow_unicode'] = allow_unicode + options["allow_unicode"] = allow_unicode return yaml.safe_dump(value, **options) def json_escape(value): - """ Adds escape sequences to problematic characters in the string + """Adds escape sequences to problematic characters in the string This filter simply passes the value to json.dumps as a convenient way of escaping characters in it However, before returning, we want to strip the double @@ -110,7 +110,7 @@ def to_complex(value): # Magic string to which None type is serialized when using use_none filter -NONE_MAGIC_VALUE = '%*****__%NONE%__*****%' +NONE_MAGIC_VALUE = "%*****__%NONE%__*****%" def use_none(value): diff --git a/st2common/st2common/expressions/functions/datastore.py b/st2common/st2common/expressions/functions/datastore.py index a8e903c3772..bd0e5fbb09d 100644 --- a/st2common/st2common/expressions/functions/datastore.py +++ b/st2common/st2common/expressions/functions/datastore.py @@ -22,9 +22,7 @@ from st2common.util.crypto import read_crypto_key from st2common.util.crypto import symmetric_decrypt -__all__ = [ - 'decrypt_kv' -] +__all__ = ["decrypt_kv"] def decrypt_kv(value): @@ -41,11 +39,13 @@ def decrypt_kv(value): # NOTE: If value is None this indicate key value item doesn't exist and we hrow a more # user-friendly error - if is_kv_item and value == '': + if is_kv_item and value == "": # Build original key name key_name = original_value.get_key_name() - raise ValueError('Referenced datastore item "%s" doesn\'t exist or it contains an empty ' - 'string' % (key_name)) + raise ValueError( + 'Referenced datastore item "%s" doesn\'t exist or it contains an empty ' + "string" % (key_name) + ) crypto_key_path = cfg.CONF.keyvalue.encryption_key_path crypto_key = read_crypto_key(key_path=crypto_key_path) diff --git a/st2common/st2common/expressions/functions/path.py b/st2common/st2common/expressions/functions/path.py index 6081be895cd..d21f301aa18 100644 --- a/st2common/st2common/expressions/functions/path.py +++ b/st2common/st2common/expressions/functions/path.py @@ -16,10 +16,7 @@ from __future__ import absolute_import import os -__all__ = [ - 'basename', - 'dirname' -] +__all__ = ["basename", "dirname"] def basename(path): diff --git a/st2common/st2common/expressions/functions/regex.py b/st2common/st2common/expressions/functions/regex.py index 4db7fe0f65f..4b17b7372f0 100644 --- a/st2common/st2common/expressions/functions/regex.py +++ b/st2common/st2common/expressions/functions/regex.py @@ -17,12 +17,7 @@ import re import six -__all__ = [ - 'regex_match', - 'regex_replace', - 'regex_search', - 'regex_substring' -] +__all__ = ["regex_match", "regex_replace", "regex_search", "regex_substring"] def _get_regex_flags(ignorecase=False): diff --git a/st2common/st2common/expressions/functions/time.py b/st2common/st2common/expressions/functions/time.py index 543fc80938d..d25b8acecca 100644 --- a/st2common/st2common/expressions/functions/time.py +++ b/st2common/st2common/expressions/functions/time.py @@ -19,14 +19,12 @@ import datetime -__all__ = [ - 'to_human_time_from_seconds' -] +__all__ = ["to_human_time_from_seconds"] if six.PY3: long_int = int else: - long_int = long # noqa # pylint: disable=E0602 + long_int = long # noqa # pylint: disable=E0602 def to_human_time_from_seconds(seconds): @@ -39,8 +37,11 @@ def to_human_time_from_seconds(seconds): :rtype: ``str`` """ - assert (isinstance(seconds, int) or isinstance(seconds, int) or - isinstance(seconds, float)) + assert ( + isinstance(seconds, int) + or isinstance(seconds, int) + or isinstance(seconds, float) + ) return _get_human_time(seconds) @@ -59,10 +60,10 @@ def _get_human_time(seconds): return None if seconds == 0: - return '0s' + return "0s" if seconds < 1: - return '%s\u03BCs' % seconds # Microseconds + return "%s\u03BCs" % seconds # Microseconds if isinstance(seconds, float): seconds = long_int(round(seconds)) # Let's lose microseconds. @@ -81,17 +82,17 @@ def _get_human_time(seconds): first_non_zero_pos = next((i for i, x in enumerate(time_parts) if x), None) if first_non_zero_pos is None: - return '0s' + return "0s" else: time_parts = time_parts[first_non_zero_pos:] if len(time_parts) == 1: - return '%ss' % tuple(time_parts) + return "%ss" % tuple(time_parts) elif len(time_parts) == 2: - return '%sm%ss' % tuple(time_parts) + return "%sm%ss" % tuple(time_parts) elif len(time_parts) == 3: - return '%sh%sm%ss' % tuple(time_parts) + return "%sh%sm%ss" % tuple(time_parts) elif len(time_parts) == 4: - return '%sd%sh%sm%ss' % tuple(time_parts) + return "%sd%sh%sm%ss" % tuple(time_parts) elif len(time_parts) == 5: - return '%sy%sd%sh%sm%ss' % tuple(time_parts) + return "%sy%sd%sh%sm%ss" % tuple(time_parts) diff --git a/st2common/st2common/expressions/functions/version.py b/st2common/st2common/expressions/functions/version.py index 2dc8d353f13..825d5965e3c 100644 --- a/st2common/st2common/expressions/functions/version.py +++ b/st2common/st2common/expressions/functions/version.py @@ -17,13 +17,13 @@ import semver __all__ = [ - 'version_compare', - 'version_more_than', - 'version_less_than', - 'version_equal', - 'version_match', - 'version_bump_major', - 'version_bump_minor' + "version_compare", + "version_more_than", + "version_less_than", + "version_equal", + "version_match", + "version_bump_major", + "version_bump_minor", ] diff --git a/st2common/st2common/fields.py b/st2common/st2common/fields.py index 7217365874b..b968e2fdb76 100644 --- a/st2common/st2common/fields.py +++ b/st2common/st2common/fields.py @@ -21,9 +21,7 @@ from st2common.util import date as date_utils -__all__ = [ - 'ComplexDateTimeField' -] +__all__ = ["ComplexDateTimeField"] SECOND_TO_MICROSECONDS = 1000000 @@ -60,7 +58,7 @@ def _microseconds_since_epoch_to_datetime(self, data): :type data: ``int`` """ result = datetime.datetime.utcfromtimestamp(data // SECOND_TO_MICROSECONDS) - microseconds_reminder = (data % SECOND_TO_MICROSECONDS) + microseconds_reminder = data % SECOND_TO_MICROSECONDS result = result.replace(microsecond=microseconds_reminder) result = date_utils.add_utc_tz(result) return result @@ -77,11 +75,13 @@ def _datetime_to_microseconds_since_epoch(self, value): # Verify that the value which is passed in contains UTC timezone # information. if not value.tzinfo or (value.tzinfo.utcoffset(value) != datetime.timedelta(0)): - raise ValueError('Value passed to this function needs to be in UTC timezone') + raise ValueError( + "Value passed to this function needs to be in UTC timezone" + ) seconds = calendar.timegm(value.timetuple()) microseconds_reminder = value.time().microsecond - result = (int(seconds * SECOND_TO_MICROSECONDS) + microseconds_reminder) + result = int(seconds * SECOND_TO_MICROSECONDS) + microseconds_reminder return result def __get__(self, instance, owner): @@ -99,8 +99,7 @@ def __set__(self, instance, value): def validate(self, value): value = self.to_python(value) if not isinstance(value, datetime.datetime): - self.error('Only datetime objects may used in a ' - 'ComplexDateTimeField') + self.error("Only datetime objects may used in a " "ComplexDateTimeField") def to_python(self, value): original_value = value diff --git a/st2common/st2common/garbage_collection/executions.py b/st2common/st2common/garbage_collection/executions.py index ba924e76f23..ae0f3296f40 100644 --- a/st2common/st2common/garbage_collection/executions.py +++ b/st2common/st2common/garbage_collection/executions.py @@ -32,15 +32,14 @@ from st2common.services import action as action_service from st2common.services import workflows as workflow_service -__all__ = [ - 'purge_executions', - 'purge_execution_output_objects' -] +__all__ = ["purge_executions", "purge_execution_output_objects"] -DONE_STATES = [action_constants.LIVEACTION_STATUS_SUCCEEDED, - action_constants.LIVEACTION_STATUS_FAILED, - action_constants.LIVEACTION_STATUS_TIMED_OUT, - action_constants.LIVEACTION_STATUS_CANCELED] +DONE_STATES = [ + action_constants.LIVEACTION_STATUS_SUCCEEDED, + action_constants.LIVEACTION_STATUS_FAILED, + action_constants.LIVEACTION_STATUS_TIMED_OUT, + action_constants.LIVEACTION_STATUS_CANCELED, +] def purge_executions(logger, timestamp, action_ref=None, purge_incomplete=False): @@ -57,90 +56,118 @@ def purge_executions(logger, timestamp, action_ref=None, purge_incomplete=False) :type purge_incomplete: ``bool`` """ if not timestamp: - raise ValueError('Specify a valid timestamp to purge.') + raise ValueError("Specify a valid timestamp to purge.") - logger.info('Purging executions older than timestamp: %s' % - timestamp.strftime('%Y-%m-%dT%H:%M:%S.%fZ')) + logger.info( + "Purging executions older than timestamp: %s" + % timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + ) filters = {} if purge_incomplete: - filters['start_timestamp__lt'] = timestamp + filters["start_timestamp__lt"] = timestamp else: - filters['end_timestamp__lt'] = timestamp - filters['start_timestamp__lt'] = timestamp - filters['status'] = {'$in': DONE_STATES} + filters["end_timestamp__lt"] = timestamp + filters["start_timestamp__lt"] = timestamp + filters["status"] = {"$in": DONE_STATES} exec_filters = copy.copy(filters) if action_ref: - exec_filters['action__ref'] = action_ref + exec_filters["action__ref"] = action_ref liveaction_filters = copy.deepcopy(filters) if action_ref: - liveaction_filters['action'] = action_ref + liveaction_filters["action"] = action_ref to_delete_execution_dbs = [] # 1. Delete ActionExecutionDB objects try: # Note: We call list() on the query set object because it's lazyily evaluated otherwise - to_delete_execution_dbs = list(ActionExecution.query(only_fields=['id'], - no_dereference=True, - **exec_filters)) + to_delete_execution_dbs = list( + ActionExecution.query( + only_fields=["id"], no_dereference=True, **exec_filters + ) + ) deleted_count = ActionExecution.delete_by_query(**exec_filters) except InvalidQueryError as e: - msg = ('Bad query (%s) used to delete execution instances: %s' - 'Please contact support.' % (exec_filters, six.text_type(e))) + msg = ( + "Bad query (%s) used to delete execution instances: %s" + "Please contact support." + % ( + exec_filters, + six.text_type(e), + ) + ) raise InvalidQueryError(msg) except: - logger.exception('Deletion of execution models failed for query with filters: %s.', - exec_filters) + logger.exception( + "Deletion of execution models failed for query with filters: %s.", + exec_filters, + ) else: - logger.info('Deleted %s action execution objects' % (deleted_count)) + logger.info("Deleted %s action execution objects" % (deleted_count)) # 2. Delete LiveActionDB objects try: deleted_count = LiveAction.delete_by_query(**liveaction_filters) except InvalidQueryError as e: - msg = ('Bad query (%s) used to delete liveaction instances: %s' - 'Please contact support.' % (liveaction_filters, six.text_type(e))) + msg = ( + "Bad query (%s) used to delete liveaction instances: %s" + "Please contact support." + % ( + liveaction_filters, + six.text_type(e), + ) + ) raise InvalidQueryError(msg) except: - logger.exception('Deletion of liveaction models failed for query with filters: %s.', - liveaction_filters) + logger.exception( + "Deletion of liveaction models failed for query with filters: %s.", + liveaction_filters, + ) else: - logger.info('Deleted %s liveaction objects' % (deleted_count)) + logger.info("Deleted %s liveaction objects" % (deleted_count)) # 3. Delete ActionExecutionOutputDB objects - to_delete_exection_ids = [str(execution_db.id) for execution_db in to_delete_execution_dbs] + to_delete_exection_ids = [ + str(execution_db.id) for execution_db in to_delete_execution_dbs + ] output_dbs_filters = {} - output_dbs_filters['execution_id'] = {'$in': to_delete_exection_ids} + output_dbs_filters["execution_id"] = {"$in": to_delete_exection_ids} try: deleted_count = ActionExecutionOutput.delete_by_query(**output_dbs_filters) except InvalidQueryError as e: - msg = ('Bad query (%s) used to delete execution output instances: %s' - 'Please contact support.' % (output_dbs_filters, six.text_type(e))) + msg = ( + "Bad query (%s) used to delete execution output instances: %s" + "Please contact support." % (output_dbs_filters, six.text_type(e)) + ) raise InvalidQueryError(msg) except: - logger.exception('Deletion of execution output models failed for query with filters: %s.', - output_dbs_filters) + logger.exception( + "Deletion of execution output models failed for query with filters: %s.", + output_dbs_filters, + ) else: - logger.info('Deleted %s execution output objects' % (deleted_count)) + logger.info("Deleted %s execution output objects" % (deleted_count)) - zombie_execution_instances = len(ActionExecution.query(only_fields=['id'], - no_dereference=True, - **exec_filters)) - zombie_liveaction_instances = len(LiveAction.query(only_fields=['id'], - no_dereference=True, - **liveaction_filters)) + zombie_execution_instances = len( + ActionExecution.query(only_fields=["id"], no_dereference=True, **exec_filters) + ) + zombie_liveaction_instances = len( + LiveAction.query(only_fields=["id"], no_dereference=True, **liveaction_filters) + ) if (zombie_execution_instances > 0) or (zombie_liveaction_instances > 0): - logger.error('Zombie execution instances left: %d.', zombie_execution_instances) - logger.error('Zombie liveaction instances left: %s.', zombie_liveaction_instances) + logger.error("Zombie execution instances left: %d.", zombie_execution_instances) + logger.error( + "Zombie liveaction instances left: %s.", zombie_liveaction_instances + ) # Print stats - logger.info('All execution models older than timestamp %s were deleted.', timestamp) + logger.info("All execution models older than timestamp %s were deleted.", timestamp) def purge_execution_output_objects(logger, timestamp, action_ref=None): @@ -154,28 +181,34 @@ def purge_execution_output_objects(logger, timestamp, action_ref=None): :type action_ref: ``str`` """ if not timestamp: - raise ValueError('Specify a valid timestamp to purge.') + raise ValueError("Specify a valid timestamp to purge.") - logger.info('Purging action execution output objects older than timestamp: %s' % - timestamp.strftime('%Y-%m-%dT%H:%M:%S.%fZ')) + logger.info( + "Purging action execution output objects older than timestamp: %s" + % timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + ) filters = {} - filters['timestamp__lt'] = timestamp + filters["timestamp__lt"] = timestamp if action_ref: - filters['action_ref'] = action_ref + filters["action_ref"] = action_ref try: deleted_count = ActionExecutionOutput.delete_by_query(**filters) except InvalidQueryError as e: - msg = ('Bad query (%s) used to delete execution output instances: %s' - 'Please contact support.' % (filters, six.text_type(e))) + msg = ( + "Bad query (%s) used to delete execution output instances: %s" + "Please contact support." % (filters, six.text_type(e)) + ) raise InvalidQueryError(msg) except: - logger.exception('Deletion of execution output models failed for query with filters: %s.', - filters) + logger.exception( + "Deletion of execution output models failed for query with filters: %s.", + filters, + ) else: - logger.info('Deleted %s execution output objects' % (deleted_count)) + logger.info("Deleted %s execution output objects" % (deleted_count)) def purge_orphaned_workflow_executions(logger): @@ -190,5 +223,5 @@ def purge_orphaned_workflow_executions(logger): # as a result of the original failure, the garbage collection routine here cancels # the workflow execution so it cannot be rerun from failed task(s). for ac_ex_db in workflow_service.identify_orphaned_workflows(): - lv_ac_db = LiveAction.get(id=ac_ex_db.liveaction['id']) + lv_ac_db = LiveAction.get(id=ac_ex_db.liveaction["id"]) action_service.request_cancellation(lv_ac_db, None) diff --git a/st2common/st2common/garbage_collection/inquiries.py b/st2common/st2common/garbage_collection/inquiries.py index 724033853f7..ad95126b216 100644 --- a/st2common/st2common/garbage_collection/inquiries.py +++ b/st2common/st2common/garbage_collection/inquiries.py @@ -27,7 +27,7 @@ from st2common.util.date import get_datetime_utc_now __all__ = [ - 'purge_inquiries', + "purge_inquiries", ] @@ -44,7 +44,10 @@ def purge_inquiries(logger): """ # Get all existing Inquiries - filters = {'runner__name': 'inquirer', 'status': action_constants.LIVEACTION_STATUS_PENDING} + filters = { + "runner__name": "inquirer", + "status": action_constants.LIVEACTION_STATUS_PENDING, + } inquiries = list(ActionExecution.query(**filters)) gc_count = 0 @@ -52,7 +55,7 @@ def purge_inquiries(logger): # Inspect each Inquiry, and determine if TTL is expired for inquiry in inquiries: - ttl = int(inquiry.result.get('ttl')) + ttl = int(inquiry.result.get("ttl")) if ttl <= 0: logger.debug("Inquiry %s has a TTL of %s. Skipping." % (inquiry.id, ttl)) continue @@ -61,17 +64,22 @@ def purge_inquiries(logger): (get_datetime_utc_now() - inquiry.start_timestamp).total_seconds() / 60 ) - logger.debug("Inquiry %s has a TTL of %s and was started %s minute(s) ago" % ( - inquiry.id, ttl, min_since_creation)) + logger.debug( + "Inquiry %s has a TTL of %s and was started %s minute(s) ago" + % (inquiry.id, ttl, min_since_creation) + ) if min_since_creation > ttl: gc_count += 1 - logger.info("TTL expired for Inquiry %s. Marking as timed out." % inquiry.id) + logger.info( + "TTL expired for Inquiry %s. Marking as timed out." % inquiry.id + ) liveaction_db = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_TIMED_OUT, result=inquiry.result, - liveaction_id=inquiry.liveaction.get('id')) + liveaction_id=inquiry.liveaction.get("id"), + ) executions.update_execution(liveaction_db) # Call Inquiry runner's post_run to trigger callback to workflow @@ -82,8 +90,7 @@ def purge_inquiries(logger): # Request that root workflow resumes root_liveaction = action_service.get_root_liveaction(liveaction_db) action_service.request_resume( - root_liveaction, - UserDB(cfg.CONF.system_user.user) + root_liveaction, UserDB(cfg.CONF.system_user.user) ) logger.info('Marked %s ttl-expired Inquiries as "timed out".' % (gc_count)) diff --git a/st2common/st2common/garbage_collection/trigger_instances.py b/st2common/st2common/garbage_collection/trigger_instances.py index 47996614dd4..0fbabb5e727 100644 --- a/st2common/st2common/garbage_collection/trigger_instances.py +++ b/st2common/st2common/garbage_collection/trigger_instances.py @@ -25,9 +25,7 @@ from st2common.persistence.trigger import TriggerInstance from st2common.util import isotime -__all__ = [ - 'purge_trigger_instances' -] +__all__ = ["purge_trigger_instances"] def purge_trigger_instances(logger, timestamp): @@ -36,23 +34,35 @@ def purge_trigger_instances(logger, timestamp): :type timestamp: ``datetime.datetime """ if not timestamp: - raise ValueError('Specify a valid timestamp to purge.') + raise ValueError("Specify a valid timestamp to purge.") - logger.info('Purging trigger instances older than timestamp: %s' % - timestamp.strftime('%Y-%m-%dT%H:%M:%S.%fZ')) + logger.info( + "Purging trigger instances older than timestamp: %s" + % timestamp.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + ) - query_filters = {'occurrence_time__lt': isotime.parse(timestamp)} + query_filters = {"occurrence_time__lt": isotime.parse(timestamp)} try: deleted_count = TriggerInstance.delete_by_query(**query_filters) except InvalidQueryError as e: - msg = ('Bad query (%s) used to delete trigger instances: %s' - 'Please contact support.' % (query_filters, six.text_type(e))) + msg = ( + "Bad query (%s) used to delete trigger instances: %s" + "Please contact support." + % ( + query_filters, + six.text_type(e), + ) + ) raise InvalidQueryError(msg) except: - logger.exception('Deleting instances using query_filters %s failed.', query_filters) + logger.exception( + "Deleting instances using query_filters %s failed.", query_filters + ) else: - logger.info('Deleted %s trigger instance objects' % (deleted_count)) + logger.info("Deleted %s trigger instance objects" % (deleted_count)) # Print stats - logger.info('All trigger instance models older than timestamp %s were deleted.', timestamp) + logger.info( + "All trigger instance models older than timestamp %s were deleted.", timestamp + ) diff --git a/st2common/st2common/log.py b/st2common/st2common/log.py index 5335af5f535..fbf6205bb91 100644 --- a/st2common/st2common/log.py +++ b/st2common/st2common/log.py @@ -35,34 +35,30 @@ from st2common.util.misc import get_normalized_file_path __all__ = [ - 'getLogger', - 'setup', - - 'FormatNamedFileHandler', - 'ConfigurableSyslogHandler', - - 'LoggingStream', - - 'ignore_lib2to3_log_messages', - 'ignore_statsd_log_messages' + "getLogger", + "setup", + "FormatNamedFileHandler", + "ConfigurableSyslogHandler", + "LoggingStream", + "ignore_lib2to3_log_messages", + "ignore_statsd_log_messages", ] # NOTE: We set AUDIT to the highest log level which means AUDIT log messages will always be # included (e.g. also if log level is set to INFO). To avoid that, we need to explicitly filter # out AUDIT log level in service setup code. logging.AUDIT = logging.CRITICAL + 10 -logging.addLevelName(logging.AUDIT, 'AUDIT') +logging.addLevelName(logging.AUDIT, "AUDIT") LOGGER_KEYS = [ - 'debug', - 'info', - 'warning', - 'error', - 'critical', - 'exception', - 'log', - - 'audit' + "debug", + "info", + "warning", + "error", + "critical", + "exception", + "log", + "audit", ] # Note: This attribute is used by "find_caller" so it can correctly exclude this file when looking @@ -89,10 +85,10 @@ def find_caller(stack_info=False, stacklevel=1): on what runtine we're working in. """ if six.PY2: - rv = '(unknown file)', 0, '(unknown function)' + rv = "(unknown file)", 0, "(unknown function)" else: # python 3, has extra tuple element at the end for stack information - rv = '(unknown file)', 0, '(unknown function)', None + rv = "(unknown file)", 0, "(unknown function)", None try: f = logging.currentframe() @@ -107,7 +103,7 @@ def find_caller(stack_info=False, stacklevel=1): if not f: f = orig_f - while hasattr(f, 'f_code'): + while hasattr(f, "f_code"): co = f.f_code filename = os.path.normcase(co.co_filename) if filename in (_srcfile, logging._srcfile): # This line is modified. @@ -121,10 +117,10 @@ def find_caller(stack_info=False, stacklevel=1): sinfo = None if stack_info: sio = io.StringIO() - sio.write('Stack (most recent call last):\n') + sio.write("Stack (most recent call last):\n") traceback.print_stack(f, file=sio) sinfo = sio.getvalue() - if sinfo[-1] == '\n': + if sinfo[-1] == "\n": sinfo = sinfo[:-1] sio.close() rv = (filename, f.f_lineno, co.co_name, sinfo) @@ -139,8 +135,8 @@ def decorate_log_method(func): @wraps(func) def func_wrapper(*args, **kwargs): # Prefix extra keys with underscore - if 'extra' in kwargs: - kwargs['extra'] = prefix_dict_keys(dictionary=kwargs['extra'], prefix='_') + if "extra" in kwargs: + kwargs["extra"] = prefix_dict_keys(dictionary=kwargs["extra"], prefix="_") try: return func(*args, **kwargs) @@ -150,10 +146,11 @@ def func_wrapper(*args, **kwargs): # See: # - https://docs.python.org/release/2.7.3/library/logging.html#logging.Logger.exception # - https://docs.python.org/release/2.7.7/library/logging.html#logging.Logger.exception - if 'got an unexpected keyword argument \'extra\'' in six.text_type(e): - kwargs.pop('extra', None) + if "got an unexpected keyword argument 'extra'" in six.text_type(e): + kwargs.pop("extra", None) return func(*args, **kwargs) raise e + return func_wrapper @@ -179,11 +176,11 @@ def decorate_logger_methods(logger): def getLogger(name): # make sure that prefix isn't appended multiple times to preserve logging name hierarchy - prefix = 'st2.' + prefix = "st2." if name.startswith(prefix): logger = logging.getLogger(name) else: - logger_name = '{}{}'.format(prefix, name) + logger_name = "{}{}".format(prefix, name) logger = logging.getLogger(logger_name) logger = decorate_logger_methods(logger=logger) @@ -191,7 +188,6 @@ def getLogger(name): class LoggingStream(object): - def __init__(self, name, level=logging.ERROR): self._logger = getLogger(name) self._level = level @@ -219,11 +215,16 @@ def _add_exclusion_filters(handlers, excludes=None): def _redirect_stderr(): # It is ok to redirect stderr as none of the st2 handlers write to stderr. - sys.stderr = LoggingStream('STDERR') + sys.stderr = LoggingStream("STDERR") -def setup(config_file, redirect_stderr=True, excludes=None, disable_existing_loggers=False, - st2_conf_path=None): +def setup( + config_file, + redirect_stderr=True, + excludes=None, + disable_existing_loggers=False, + st2_conf_path=None, +): """ Configure logging from file. @@ -232,16 +233,18 @@ def setup(config_file, redirect_stderr=True, excludes=None, disable_existing_log absolute path relative to st2.conf. :type st2_conf_path: ``str`` """ - if st2_conf_path and config_file[:2] == './' and not os.path.isfile(config_file): + if st2_conf_path and config_file[:2] == "./" and not os.path.isfile(config_file): # Logging config path is relative to st2.conf, resolve it to full absolute path directory = os.path.dirname(st2_conf_path) config_file_name = os.path.basename(config_file) config_file = os.path.join(directory, config_file_name) try: - logging.config.fileConfig(config_file, - defaults=None, - disable_existing_loggers=disable_existing_loggers) + logging.config.fileConfig( + config_file, + defaults=None, + disable_existing_loggers=disable_existing_loggers, + ) handlers = logging.getLoggerClass().manager.root.handlers _add_exclusion_filters(handlers=handlers, excludes=excludes) if redirect_stderr: @@ -251,13 +254,13 @@ def setup(config_file, redirect_stderr=True, excludes=None, disable_existing_log tb_msg = traceback.format_exc() msg = str(exc) - msg += '\n\n' + tb_msg + msg += "\n\n" + tb_msg # revert stderr redirection since there is no logger in place. sys.stderr = sys.__stderr__ # No logger yet therefore write to stderr - sys.stderr.write('ERROR: %s' % (msg)) + sys.stderr.write("ERROR: %s" % (msg)) raise exc_cls(six.text_type(msg)) @@ -271,10 +274,10 @@ def ignore_lib2to3_log_messages(): class MockLoggingModule(object): def getLogger(self, *args, **kwargs): - return logging.getLogger('lib2to3') + return logging.getLogger("lib2to3") lib2to3.pgen2.driver.logging = MockLoggingModule() - logging.getLogger('lib2to3').setLevel(logging.ERROR) + logging.getLogger("lib2to3").setLevel(logging.ERROR) def ignore_statsd_log_messages(): @@ -288,8 +291,8 @@ def ignore_statsd_log_messages(): class MockLoggingModule(object): def getLogger(self, *args, **kwargs): - return logging.getLogger('statsd') + return logging.getLogger("statsd") statsd.connection.logging = MockLoggingModule() statsd.client.logging = MockLoggingModule() - logging.getLogger('statsd').setLevel(logging.ERROR) + logging.getLogger("statsd").setLevel(logging.ERROR) diff --git a/st2common/st2common/logging/filters.py b/st2common/st2common/logging/filters.py index d997589a0e7..1fef1640287 100644 --- a/st2common/st2common/logging/filters.py +++ b/st2common/st2common/logging/filters.py @@ -17,9 +17,9 @@ import logging __all__ = [ - 'LoggerNameExclusionFilter', - 'LoggerFunctionNameExclusionFilter', - 'LogLevelFilter', + "LoggerNameExclusionFilter", + "LoggerFunctionNameExclusionFilter", + "LogLevelFilter", ] @@ -36,8 +36,11 @@ def filter(self, record): if len(self._exclusions) < 1: return True - module_decomposition = record.name.split('.') - exclude = len(module_decomposition) > 0 and module_decomposition[0] in self._exclusions + module_decomposition = record.name.split(".") + exclude = ( + len(module_decomposition) > 0 + and module_decomposition[0] in self._exclusions + ) return not exclude @@ -54,7 +57,7 @@ def filter(self, record): if len(self._exclusions) < 1: return True - function_name = getattr(record, 'funcName', None) + function_name = getattr(record, "funcName", None) if function_name in self._exclusions: return False diff --git a/st2common/st2common/logging/formatters.py b/st2common/st2common/logging/formatters.py index d20b240a5ac..7c30e780a98 100644 --- a/st2common/st2common/logging/formatters.py +++ b/st2common/st2common/logging/formatters.py @@ -28,8 +28,8 @@ from st2common.constants.secrets import MASKED_ATTRIBUTE_VALUE __all__ = [ - 'ConsoleLogFormatter', - 'GelfLogFormatter', + "ConsoleLogFormatter", + "GelfLogFormatter", ] SIMPLE_TYPES = (int, float) + six.string_types @@ -37,16 +37,16 @@ # GELF logger specific constants HOSTNAME = socket.gethostname() -GELF_SPEC_VERSION = '1.1' +GELF_SPEC_VERSION = "1.1" COMMON_ATTRIBUTE_NAMES = [ - 'name', - 'process', - 'processName', - 'module', - 'filename', - 'funcName', - 'lineno' + "name", + "process", + "processName", + "module", + "filename", + "funcName", + "lineno", ] @@ -60,9 +60,9 @@ def serialize_object(obj): :rtype: ``str`` """ # Try to serialize the object - if getattr(obj, 'to_dict', None): + if getattr(obj, "to_dict", None): value = obj.to_dict() - elif getattr(obj, 'to_serializable_dict', None): + elif getattr(obj, "to_serializable_dict", None): value = obj.to_serializable_dict(mask_secrets=True) else: value = repr(obj) @@ -77,7 +77,9 @@ def process_attribute_value(key, value): if not cfg.CONF.log.mask_secrets: return value - blacklisted_attribute_names = MASKED_ATTRIBUTES_BLACKLIST + cfg.CONF.log.mask_secrets_blacklist + blacklisted_attribute_names = ( + MASKED_ATTRIBUTES_BLACKLIST + cfg.CONF.log.mask_secrets_blacklist + ) # NOTE: This can be expensive when processing large dicts or objects if isinstance(value, SIMPLE_TYPES): @@ -121,11 +123,16 @@ class BaseExtraLogFormatter(logging.Formatter): dictionary need to be prefixed with a slash ('_'). """ - PREFIX = '_' # Prefix for user provided attributes in the extra dict + PREFIX = "_" # Prefix for user provided attributes in the extra dict def _get_extra_attributes(self, record): - attributes = dict([(k, v) for k, v in six.iteritems(record.__dict__) - if k.startswith(self.PREFIX)]) + attributes = dict( + [ + (k, v) + for k, v in six.iteritems(record.__dict__) + if k.startswith(self.PREFIX) + ] + ) return attributes def _get_common_extra_attributes(self, record): @@ -182,17 +189,17 @@ def format(self, record): msg = super(ConsoleLogFormatter, self).format(record) if attributes: - msg = '%s (%s)' % (msg, attributes) + msg = "%s (%s)" % (msg, attributes) return msg def _dict_to_str(self, attributes): result = [] for key, value in six.iteritems(attributes): - item = '%s=%s' % (key[1:], repr(value)) + item = "%s=%s" % (key[1:], repr(value)) result.append(item) - result = ','.join(result) + result = ",".join(result) return result @@ -245,30 +252,32 @@ def format(self, record): exc_info = record.exc_info time_now_float = record.created time_now_sec = int(time_now_float) - level = self.PYTHON_TO_GELF_LEVEL_MAP.get(record.levelno, self.DEFAULT_LOG_LEVEL) + level = self.PYTHON_TO_GELF_LEVEL_MAP.get( + record.levelno, self.DEFAULT_LOG_LEVEL + ) common_attributes = self._get_common_extra_attributes(record=record) full_msg = super(GelfLogFormatter, self).format(record) data = { - 'version': GELF_SPEC_VERSION, - 'host': HOSTNAME, - 'short_message': msg, - 'full_message': full_msg, - 'timestamp': time_now_sec, - 'timestamp_f': time_now_float, - 'level': level + "version": GELF_SPEC_VERSION, + "host": HOSTNAME, + "short_message": msg, + "full_message": full_msg, + "timestamp": time_now_sec, + "timestamp_f": time_now_float, + "level": level, } if exc_info: # Include exception information exc_type, exc_value, exc_tb = exc_info - tb_str = ''.join(traceback.format_tb(exc_tb)) - data['_exception'] = six.text_type(exc_value) - data['_traceback'] = tb_str + tb_str = "".join(traceback.format_tb(exc_tb)) + data["_exception"] = six.text_type(exc_value) + data["_traceback"] = tb_str # Include common Python log record attributes - data['_python'] = common_attributes + data["_python"] = common_attributes # Include user extra attributes data.update(attributes) diff --git a/st2common/st2common/logging/handlers.py b/st2common/st2common/logging/handlers.py index ade4dfbb046..963ac197a02 100644 --- a/st2common/st2common/logging/handlers.py +++ b/st2common/st2common/logging/handlers.py @@ -24,26 +24,29 @@ from st2common.util import date as date_utils __all__ = [ - 'FormatNamedFileHandler', - 'ConfigurableSyslogHandler', + "FormatNamedFileHandler", + "ConfigurableSyslogHandler", ] class FormatNamedFileHandler(logging.handlers.RotatingFileHandler): - def __init__(self, filename, mode='a', maxBytes=0, backupCount=0, encoding=None, delay=False): + def __init__( + self, filename, mode="a", maxBytes=0, backupCount=0, encoding=None, delay=False + ): # We add aditional values to the context which can be used in the log filename timestamp = int(time.time()) - isotime_str = str(date_utils.get_datetime_utc_now()).replace(' ', '_') + isotime_str = str(date_utils.get_datetime_utc_now()).replace(" ", "_") pid = os.getpid() - format_values = { - 'timestamp': timestamp, - 'ts': isotime_str, - 'pid': pid - } + format_values = {"timestamp": timestamp, "ts": isotime_str, "pid": pid} filename = filename.format(**format_values) - super(FormatNamedFileHandler, self).__init__(filename, mode=mode, maxBytes=maxBytes, - backupCount=backupCount, encoding=encoding, - delay=delay) + super(FormatNamedFileHandler, self).__init__( + filename, + mode=mode, + maxBytes=maxBytes, + backupCount=backupCount, + encoding=encoding, + delay=delay, + ) class ConfigurableSyslogHandler(logging.handlers.SysLogHandler): @@ -55,12 +58,12 @@ def __init__(self, address=None, facility=None, socktype=None): if not socktype: protocol = cfg.CONF.syslog.protocol.lower() - if protocol == 'udp': + if protocol == "udp": socktype = socket.SOCK_DGRAM - elif protocol == 'tcp': + elif protocol == "tcp": socktype = socket.SOCK_STREAM else: - raise ValueError('Unsupported protocol: %s' % (protocol)) + raise ValueError("Unsupported protocol: %s" % (protocol)) if socktype: super(ConfigurableSyslogHandler, self).__init__(address, facility, socktype) diff --git a/st2common/st2common/logging/misc.py b/st2common/st2common/logging/misc.py index 36f8b179865..de7f673431e 100644 --- a/st2common/st2common/logging/misc.py +++ b/st2common/st2common/logging/misc.py @@ -23,32 +23,26 @@ from st2common.logging.filters import LoggerFunctionNameExclusionFilter __all__ = [ - 'reopen_log_files', - - 'set_log_level_for_all_handlers', - 'set_log_level_for_all_loggers', - - 'add_global_filters_for_all_loggers' + "reopen_log_files", + "set_log_level_for_all_handlers", + "set_log_level_for_all_loggers", + "add_global_filters_for_all_loggers", ] LOG = logging.getLogger(__name__) # Because some loggers are just waste of attention span -SPECIAL_LOGGERS = { - 'swagger_spec_validator.ref_validators': logging.INFO -} +SPECIAL_LOGGERS = {"swagger_spec_validator.ref_validators": logging.INFO} # Log messages for function names which are very spammy and we want to filter out when DEBUG log # level is enabled IGNORED_FUNCTION_NAMES = [ # Used by pyamqp, logs every heartbit tick every 2 ms by default - 'heartbeat_tick' + "heartbeat_tick" ] # List of global filters which apply to all the loggers -GLOBAL_FILTERS = [ - LoggerFunctionNameExclusionFilter(exclusions=IGNORED_FUNCTION_NAMES) -] +GLOBAL_FILTERS = [LoggerFunctionNameExclusionFilter(exclusions=IGNORED_FUNCTION_NAMES)] def reopen_log_files(handlers): @@ -65,8 +59,10 @@ def reopen_log_files(handlers): if not isinstance(handler, logging.FileHandler): continue - LOG.info('Re-opening log file "%s" with mode "%s"\n' % - (handler.baseFilename, handler.mode)) + LOG.info( + 'Re-opening log file "%s" with mode "%s"\n' + % (handler.baseFilename, handler.mode) + ) try: handler.acquire() @@ -76,10 +72,10 @@ def reopen_log_files(handlers): try: handler.release() except RuntimeError as e: - if 'cannot release' in six.text_type(e): + if "cannot release" in six.text_type(e): # Release failed which most likely indicates that acquire failed # and lock was never acquired - LOG.warn('Failed to release lock', exc_info=True) + LOG.warn("Failed to release lock", exc_info=True) else: raise e @@ -112,7 +108,9 @@ def set_log_level_for_all_loggers(level=logging.DEBUG): logger = add_filters_for_logger(logger=logger, filters=GLOBAL_FILTERS) if logger.name in SPECIAL_LOGGERS: - set_log_level_for_all_handlers(logger=logger, level=SPECIAL_LOGGERS.get(logger.name)) + set_log_level_for_all_handlers( + logger=logger, level=SPECIAL_LOGGERS.get(logger.name) + ) else: set_log_level_for_all_handlers(logger=logger, level=level) @@ -152,7 +150,7 @@ def add_filters_for_logger(logger, filters): if not isinstance(logger, logging.Logger): return logger - if not hasattr(logger, 'addFilter'): + if not hasattr(logger, "addFilter"): return logger for logger_filter in filters: @@ -170,7 +168,7 @@ def get_logger_name_for_module(module, exclude_module_name=False): module_file = module.__file__ base_dir = os.path.dirname(os.path.abspath(module_file)) module_name = os.path.basename(module_file) - module_name = module_name.replace('.pyc', '').replace('.py', '') + module_name = module_name.replace(".pyc", "").replace(".py", "") split = base_dir.split(os.path.sep) split = [component for component in split if component] @@ -178,15 +176,15 @@ def get_logger_name_for_module(module, exclude_module_name=False): # Find first component which starts with st2 and use that as a starting point start_index = 0 for index, component in enumerate(reversed(split)): - if component.startswith('st2'): - start_index = ((len(split) - 1) - index) + if component.startswith("st2"): + start_index = (len(split) - 1) - index break split = split[start_index:] if exclude_module_name: - name = '.'.join(split) + name = ".".join(split) else: - name = '.'.join(split) + '.' + module_name + name = ".".join(split) + "." + module_name return name diff --git a/st2common/st2common/metrics/base.py b/st2common/st2common/metrics/base.py index 18801c901db..215780b86f6 100644 --- a/st2common/st2common/metrics/base.py +++ b/st2common/st2common/metrics/base.py @@ -28,23 +28,22 @@ from st2common.exceptions.plugins import PluginLoadError __all__ = [ - 'BaseMetricsDriver', - - 'Timer', - 'Counter', - 'CounterWithTimer', - - 'metrics_initialize', - 'get_driver' + "BaseMetricsDriver", + "Timer", + "Counter", + "CounterWithTimer", + "metrics_initialize", + "get_driver", ] -if not hasattr(cfg.CONF, 'metrics'): +if not hasattr(cfg.CONF, "metrics"): from st2common.config import register_opts + register_opts() LOG = logging.getLogger(__name__) -PLUGIN_NAMESPACE = 'st2common.metrics.driver' +PLUGIN_NAMESPACE = "st2common.metrics.driver" # Stores reference to the metrics driver class instance. # NOTE: This value is populated lazily on the first get_driver() function call @@ -97,6 +96,7 @@ class Timer(object): """ Timer context manager for easily sending timer statistics. """ + def __init__(self, key, include_parameter=False): check_key(key) @@ -136,8 +136,9 @@ def __call__(self, func): def wrapper(*args, **kw): with self as metrics_timer: if self._include_parameter: - kw['metrics_timer'] = metrics_timer + kw["metrics_timer"] = metrics_timer return func(*args, **kw) + return wrapper @@ -145,6 +146,7 @@ class Counter(object): """ Counter context manager for easily sending counter statistics. """ + def __init__(self, key): check_key(key) self.key = key @@ -162,6 +164,7 @@ def __call__(self, func): def wrapper(*args, **kw): with self: return func(*args, **kw) + return wrapper @@ -209,8 +212,9 @@ def __call__(self, func): def wrapper(*args, **kw): with self as counter_with_timer: if self._include_parameter: - kw['metrics_counter_with_timer'] = counter_with_timer + kw["metrics_counter_with_timer"] = counter_with_timer return func(*args, **kw) + return wrapper @@ -223,7 +227,9 @@ def metrics_initialize(): try: METRICS = get_plugin_instance(PLUGIN_NAMESPACE, cfg.CONF.metrics.driver) except (NoMatches, MultipleMatches, NoSuchOptError) as error: - raise PluginLoadError('Error loading metrics driver. Check configuration: %s' % error) + raise PluginLoadError( + "Error loading metrics driver. Check configuration: %s" % error + ) return METRICS diff --git a/st2common/st2common/metrics/drivers/echo_driver.py b/st2common/st2common/metrics/drivers/echo_driver.py index 40b2ed3947f..7cb115aab68 100644 --- a/st2common/st2common/metrics/drivers/echo_driver.py +++ b/st2common/st2common/metrics/drivers/echo_driver.py @@ -16,9 +16,7 @@ from st2common import log as logging from st2common.metrics.base import BaseMetricsDriver -__all__ = [ - 'EchoDriver' -] +__all__ = ["EchoDriver"] LOG = logging.getLogger(__name__) @@ -29,19 +27,19 @@ class EchoDriver(BaseMetricsDriver): """ def time(self, key, time): - LOG.debug('[metrics] time(key=%s, time=%s)' % (key, time)) + LOG.debug("[metrics] time(key=%s, time=%s)" % (key, time)) def inc_counter(self, key, amount=1): - LOG.debug('[metrics] counter.incr(%s, %s)' % (key, amount)) + LOG.debug("[metrics] counter.incr(%s, %s)" % (key, amount)) def dec_counter(self, key, amount=1): - LOG.debug('[metrics] counter.decr(%s, %s)' % (key, amount)) + LOG.debug("[metrics] counter.decr(%s, %s)" % (key, amount)) def set_gauge(self, key, value): - LOG.debug('[metrics] set_gauge(%s, %s)' % (key, value)) + LOG.debug("[metrics] set_gauge(%s, %s)" % (key, value)) def inc_gauge(self, key, amount=1): - LOG.debug('[metrics] gauge.incr(%s, %s)' % (key, amount)) + LOG.debug("[metrics] gauge.incr(%s, %s)" % (key, amount)) def dec_gauge(self, key, amount=1): - LOG.debug('[metrics] gauge.decr(%s, %s)' % (key, amount)) + LOG.debug("[metrics] gauge.decr(%s, %s)" % (key, amount)) diff --git a/st2common/st2common/metrics/drivers/noop_driver.py b/st2common/st2common/metrics/drivers/noop_driver.py index 6f816f2a693..658ee10a40f 100644 --- a/st2common/st2common/metrics/drivers/noop_driver.py +++ b/st2common/st2common/metrics/drivers/noop_driver.py @@ -15,9 +15,7 @@ from st2common.metrics.base import BaseMetricsDriver -__all__ = [ - 'NoopDriver' -] +__all__ = ["NoopDriver"] class NoopDriver(BaseMetricsDriver): diff --git a/st2common/st2common/metrics/drivers/statsd_driver.py b/st2common/st2common/metrics/drivers/statsd_driver.py index c334837e9bd..efbefde6016 100644 --- a/st2common/st2common/metrics/drivers/statsd_driver.py +++ b/st2common/st2common/metrics/drivers/statsd_driver.py @@ -30,15 +30,9 @@ LOG = logging.getLogger(__name__) # Which exceptions thrown by statsd library should be considered as non-fatal -NON_FATAL_EXC_CLASSES = [ - socket.error, - IOError, - OSError -] +NON_FATAL_EXC_CLASSES = [socket.error, IOError, OSError] -__all__ = [ - 'StatsdDriver' -] +__all__ = ["StatsdDriver"] class StatsdDriver(BaseMetricsDriver): @@ -55,11 +49,15 @@ class StatsdDriver(BaseMetricsDriver): """ def __init__(self): - statsd.Connection.set_defaults(host=cfg.CONF.metrics.host, port=cfg.CONF.metrics.port, - sample_rate=cfg.CONF.metrics.sample_rate) - - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + statsd.Connection.set_defaults( + host=cfg.CONF.metrics.host, + port=cfg.CONF.metrics.port, + sample_rate=cfg.CONF.metrics.sample_rate, + ) + + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def time(self, key, time): """ Timer metric @@ -68,11 +66,12 @@ def time(self, key, time): assert isinstance(time, Number) key = get_full_key_name(key) - timer = statsd.Timer('') + timer = statsd.Timer("") timer.send(key, time) - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def inc_counter(self, key, amount=1): """ Increment counter @@ -84,8 +83,9 @@ def inc_counter(self, key, amount=1): counter = statsd.Counter(key) counter.increment(delta=amount) - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def dec_counter(self, key, amount=1): """ Decrement metric @@ -97,8 +97,9 @@ def dec_counter(self, key, amount=1): counter = statsd.Counter(key) counter.decrement(delta=amount) - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def set_gauge(self, key, value): """ Set gauge value. @@ -110,8 +111,9 @@ def set_gauge(self, key, value): gauge = statsd.Gauge(key) gauge.send(None, value) - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def inc_gauge(self, key, amount=1): """ Increment gauge value. @@ -123,8 +125,9 @@ def inc_gauge(self, key, amount=1): gauge = statsd.Gauge(key) gauge.increment(None, amount) - @ignore_and_log_exception(exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, - level=stdlib_logging.WARNING) + @ignore_and_log_exception( + exc_classes=NON_FATAL_EXC_CLASSES, logger=LOG, level=stdlib_logging.WARNING + ) def dec_gauge(self, key, amount=1): """ Decrement gauge value. diff --git a/st2common/st2common/metrics/utils.py b/st2common/st2common/metrics/utils.py index f741743cd23..710aceff159 100644 --- a/st2common/st2common/metrics/utils.py +++ b/st2common/st2common/metrics/utils.py @@ -16,10 +16,7 @@ import six from oslo_config import cfg -__all__ = [ - 'get_full_key_name', - 'check_key' -] +__all__ = ["get_full_key_name", "check_key"] def get_full_key_name(key): @@ -27,14 +24,14 @@ def get_full_key_name(key): Return full metric key name, taking into account optional prefix which can be specified in the config. """ - parts = ['st2'] + parts = ["st2"] if cfg.CONF.metrics.prefix: parts.append(cfg.CONF.metrics.prefix) parts.append(key) - return '.'.join(parts) + return ".".join(parts) def check_key(key): diff --git a/st2common/st2common/middleware/cors.py b/st2common/st2common/middleware/cors.py index 1388e65e632..eaeac86f074 100644 --- a/st2common/st2common/middleware/cors.py +++ b/st2common/st2common/middleware/cors.py @@ -42,18 +42,18 @@ def __call__(self, environ, start_response): def custom_start_response(status, headers, exc_info=None): headers = ResponseHeaders(headers) - origin = request.headers.get('Origin') + origin = request.headers.get("Origin") origins = OrderedSet(cfg.CONF.api.allow_origin) # Build a list of the default allowed origins public_api_url = cfg.CONF.auth.api_url # Default gulp development server WebUI URL - origins.add('http://127.0.0.1:3000') + origins.add("http://127.0.0.1:3000") # By default WebUI simple http server listens on 8080 - origins.add('http://localhost:8080') - origins.add('http://127.0.0.1:8080') + origins.add("http://localhost:8080") + origins.add("http://127.0.0.1:8080") if public_api_url: # Public API URL @@ -62,7 +62,7 @@ def custom_start_response(status, headers, exc_info=None): origins = list(origins) if origin: - if '*' in origins: + if "*" in origins: origin_allowed = origin else: # See http://www.w3.org/TR/cors/#access-control-allow-origin-response-header @@ -70,21 +70,32 @@ def custom_start_response(status, headers, exc_info=None): else: origin_allowed = list(origins)[0] - methods_allowed = ['GET', 'POST', 'PUT', 'DELETE', 'OPTIONS'] - request_headers_allowed = ['Content-Type', 'Authorization', HEADER_ATTRIBUTE_NAME, - HEADER_API_KEY_ATTRIBUTE_NAME, REQUEST_ID_HEADER] - response_headers_allowed = ['Content-Type', 'X-Limit', 'X-Total-Count', - REQUEST_ID_HEADER] - - headers['Access-Control-Allow-Origin'] = origin_allowed - headers['Access-Control-Allow-Methods'] = ','.join(methods_allowed) - headers['Access-Control-Allow-Headers'] = ','.join(request_headers_allowed) - headers['Access-Control-Allow-Credentials'] = 'true' - headers['Access-Control-Expose-Headers'] = ','.join(response_headers_allowed) + methods_allowed = ["GET", "POST", "PUT", "DELETE", "OPTIONS"] + request_headers_allowed = [ + "Content-Type", + "Authorization", + HEADER_ATTRIBUTE_NAME, + HEADER_API_KEY_ATTRIBUTE_NAME, + REQUEST_ID_HEADER, + ] + response_headers_allowed = [ + "Content-Type", + "X-Limit", + "X-Total-Count", + REQUEST_ID_HEADER, + ] + + headers["Access-Control-Allow-Origin"] = origin_allowed + headers["Access-Control-Allow-Methods"] = ",".join(methods_allowed) + headers["Access-Control-Allow-Headers"] = ",".join(request_headers_allowed) + headers["Access-Control-Allow-Credentials"] = "true" + headers["Access-Control-Expose-Headers"] = ",".join( + response_headers_allowed + ) return start_response(status, headers._items, exc_info) - if request.method == 'OPTIONS': + if request.method == "OPTIONS": return Response()(environ, custom_start_response) else: return self.app(environ, custom_start_response) diff --git a/st2common/st2common/middleware/error_handling.py b/st2common/st2common/middleware/error_handling.py index 478cf3691bd..d7ae59cde56 100644 --- a/st2common/st2common/middleware/error_handling.py +++ b/st2common/st2common/middleware/error_handling.py @@ -50,13 +50,13 @@ def __call__(self, environ, start_response): except NotFoundException: raise exc.HTTPNotFound() except Exception as e: - status = getattr(e, 'code', exc.HTTPInternalServerError.code) + status = getattr(e, "code", exc.HTTPInternalServerError.code) - if hasattr(e, 'detail') and not getattr(e, 'comment'): - setattr(e, 'comment', getattr(e, 'detail')) + if hasattr(e, "detail") and not getattr(e, "comment"): + setattr(e, "comment", getattr(e, "detail")) - if hasattr(e, 'body') and isinstance(getattr(e, 'body', None), dict): - body = getattr(e, 'body', None) + if hasattr(e, "body") and isinstance(getattr(e, "body", None), dict): + body = getattr(e, "body", None) else: body = {} @@ -69,40 +69,40 @@ def __call__(self, environ, start_response): elif isinstance(e, db_exceptions.StackStormDBObjectConflictError): status_code = exc.HTTPConflict.code message = six.text_type(e) - body['conflict-id'] = getattr(e, 'conflict_id', None) + body["conflict-id"] = getattr(e, "conflict_id", None) elif isinstance(e, rbac_exceptions.AccessDeniedError): status_code = exc.HTTPForbidden.code message = six.text_type(e) elif isinstance(e, (ValueValidationException, ValueError, ValidationError)): status_code = exc.HTTPBadRequest.code - message = getattr(e, 'message', six.text_type(e)) + message = getattr(e, "message", six.text_type(e)) else: status_code = exc.HTTPInternalServerError.code - message = 'Internal Server Error' + message = "Internal Server Error" # Log the error is_internal_server_error = status_code == exc.HTTPInternalServerError.code - error_msg = getattr(e, 'comment', six.text_type(e)) + error_msg = getattr(e, "comment", six.text_type(e)) extra = { - 'exception_class': e.__class__.__name__, - 'exception_message': six.text_type(e), - 'exception_data': e.__dict__ + "exception_class": e.__class__.__name__, + "exception_message": six.text_type(e), + "exception_data": e.__dict__, } if is_internal_server_error: - LOG.exception('API call failed: %s', error_msg, extra=extra) + LOG.exception("API call failed: %s", error_msg, extra=extra) else: - LOG.debug('API call failed: %s', error_msg, extra=extra) + LOG.debug("API call failed: %s", error_msg, extra=extra) if is_debugging_enabled(): LOG.debug(traceback.format_exc()) - body['faultstring'] = message + body["faultstring"] = message response_body = json_encode(body) headers = { - 'Content-Type': 'application/json', - 'Content-Length': str(len(response_body)) + "Content-Type": "application/json", + "Content-Length": str(len(response_body)), } resp = Response(response_body, status=status_code, headers=headers) diff --git a/st2common/st2common/middleware/instrumentation.py b/st2common/st2common/middleware/instrumentation.py index 8ff7445f75a..e5d01d2223c 100644 --- a/st2common/st2common/middleware/instrumentation.py +++ b/st2common/st2common/middleware/instrumentation.py @@ -21,10 +21,7 @@ from st2common.util.date import get_datetime_utc_now from st2common.router import NotFoundException -__all__ = [ - 'RequestInstrumentationMiddleware', - 'ResponseInstrumentationMiddleware' -] +__all__ = ["RequestInstrumentationMiddleware", "ResponseInstrumentationMiddleware"] LOG = logging.getLogger(__name__) @@ -54,10 +51,11 @@ def __call__(self, environ, start_response): # NOTE: We don't track per request and response metrics for /v1/executions/ and some # other endpoints because this would result in a lot of unique metrics which is an # anti-pattern and causes unnecessary load on the metrics server. - submit_metrics = endpoint.get('x-submit-metrics', True) - operation_id = endpoint.get('operationId', None) - is_get_one_endpoint = bool(operation_id) and (operation_id.endswith('.get') or - operation_id.endswith('.get_one')) + submit_metrics = endpoint.get("x-submit-metrics", True) + operation_id = endpoint.get("operationId", None) + is_get_one_endpoint = bool(operation_id) and ( + operation_id.endswith(".get") or operation_id.endswith(".get_one") + ) if is_get_one_endpoint: # NOTE: We don't submit metrics for any get one API endpoint since this would result @@ -65,22 +63,22 @@ def __call__(self, environ, start_response): submit_metrics = False if not submit_metrics: - LOG.debug('Not submitting request metrics for path: %s' % (request.path)) + LOG.debug("Not submitting request metrics for path: %s" % (request.path)) return self.app(environ, start_response) metrics_driver = get_driver() - key = '%s.request.total' % (self._service_name) + key = "%s.request.total" % (self._service_name) metrics_driver.inc_counter(key) - key = '%s.request.method.%s' % (self._service_name, request.method) + key = "%s.request.method.%s" % (self._service_name, request.method) metrics_driver.inc_counter(key) - path = request.path.replace('/', '_') - key = '%s.request.path.%s' % (self._service_name, path) + path = request.path.replace("/", "_") + key = "%s.request.path.%s" % (self._service_name, path) metrics_driver.inc_counter(key) - if self._service_name == 'stream': + if self._service_name == "stream": # For stream service, we also record current number of open connections. # Due to the way stream service works, we need to utilize eventlet posthook to # correctly set the counter when the connection is closed / full response is returned. @@ -88,34 +86,34 @@ def __call__(self, environ, start_response): # hooks for details # Increase request counter - key = '%s.request' % (self._service_name) + key = "%s.request" % (self._service_name) metrics_driver.inc_counter(key) # Increase "total number of connections" gauge - metrics_driver.inc_gauge('stream.connections', 1) + metrics_driver.inc_gauge("stream.connections", 1) start_time = get_datetime_utc_now() def update_metrics_hook(env): # Hook which is called at the very end after all the response has been sent and # connection closed - time_delta = (get_datetime_utc_now() - start_time) + time_delta = get_datetime_utc_now() - start_time duration = time_delta.total_seconds() # Send total request time metrics_driver.time(key, duration) # Decrease "current number of connections" gauge - metrics_driver.dec_gauge('stream.connections', 1) + metrics_driver.dec_gauge("stream.connections", 1) # NOTE: Some tests mock environ and there 'eventlet.posthooks' key is not available - if 'eventlet.posthooks' in environ: - environ['eventlet.posthooks'].append((update_metrics_hook, (), {})) + if "eventlet.posthooks" in environ: + environ["eventlet.posthooks"].append((update_metrics_hook, (), {})) return self.app(environ, start_response) else: # Track and time current number of processing requests - key = '%s.request' % (self._service_name) + key = "%s.request" % (self._service_name) with CounterWithTimer(key=key): return self.app(environ, start_response) @@ -138,11 +136,12 @@ def __init__(self, app, router, service_name): def __call__(self, environ, start_response): # Track and time current number of processing requests def custom_start_response(status, headers, exc_info=None): - status_code = int(status.split(' ')[0]) + status_code = int(status.split(" ")[0]) metrics_driver = get_driver() - metrics_driver.inc_counter('%s.response.status.%s' % (self._service_name, - status_code)) + metrics_driver.inc_counter( + "%s.response.status.%s" % (self._service_name, status_code) + ) return start_response(status, headers, exc_info) diff --git a/st2common/st2common/middleware/logging.py b/st2common/st2common/middleware/logging.py index d41622ff295..a044e2c59bd 100644 --- a/st2common/st2common/middleware/logging.py +++ b/st2common/st2common/middleware/logging.py @@ -33,7 +33,7 @@ SECRET_QUERY_PARAMS = [ QUERY_PARAM_ATTRIBUTE_NAME, - QUERY_PARAM_API_KEY_ATTRIBUTE_NAME + QUERY_PARAM_API_KEY_ATTRIBUTE_NAME, ] + MASKED_ATTRIBUTES_BLACKLIST try: @@ -68,21 +68,23 @@ def __call__(self, environ, start_response): # Log the incoming request values = { - 'method': request.method, - 'path': request.path, - 'remote_addr': request.remote_addr, - 'query': query_params, - 'request_id': request.headers.get(REQUEST_ID_HEADER, None) + "method": request.method, + "path": request.path, + "remote_addr": request.remote_addr, + "query": query_params, + "request_id": request.headers.get(REQUEST_ID_HEADER, None), } - LOG.info('%(request_id)s - %(method)s %(path)s with query=%(query)s' % - values, extra=values) + LOG.info( + "%(request_id)s - %(method)s %(path)s with query=%(query)s" % values, + extra=values, + ) def custom_start_response(status, headers, exc_info=None): - status_code.append(int(status.split(' ')[0])) + status_code.append(int(status.split(" ")[0])) for name, value in headers: - if name.lower() == 'content-length': + if name.lower() == "content-length": content_length.append(int(value)) break @@ -95,7 +97,7 @@ def custom_start_response(status, headers, exc_info=None): except NotFoundException: endpoint = {} - log_result = endpoint.get('x-log-result', True) + log_result = endpoint.get("x-log-result", True) if isinstance(retval, (types.GeneratorType, itertools.chain)): # Note: We don't log the result when return value is a generator, because this would @@ -105,22 +107,28 @@ def custom_start_response(status, headers, exc_info=None): # Log the response values = { - 'method': request.method, - 'path': request.path, - 'remote_addr': request.remote_addr, - 'status': status_code[0], - 'runtime': float("{0:.3f}".format((clock() - start_time) * 10**3)), - 'content_length': content_length[0] if content_length else len(b''.join(retval)), - 'request_id': request.headers.get(REQUEST_ID_HEADER, None) + "method": request.method, + "path": request.path, + "remote_addr": request.remote_addr, + "status": status_code[0], + "runtime": float("{0:.3f}".format((clock() - start_time) * 10 ** 3)), + "content_length": content_length[0] + if content_length + else len(b"".join(retval)), + "request_id": request.headers.get(REQUEST_ID_HEADER, None), } - log_msg = '%(request_id)s - %(status)s %(content_length)s %(runtime)sms' % (values) + log_msg = "%(request_id)s - %(status)s %(content_length)s %(runtime)sms" % ( + values + ) LOG.info(log_msg, extra=values) if log_result: - values['result'] = retval[0] - log_msg = ('%(request_id)s - %(status)s %(content_length)s %(runtime)sms\n%(result)s' % - (values)) + values["result"] = retval[0] + log_msg = ( + "%(request_id)s - %(status)s %(content_length)s %(runtime)sms\n%(result)s" + % (values) + ) LOG.debug(log_msg, extra=values) return retval diff --git a/st2common/st2common/middleware/streaming.py b/st2common/st2common/middleware/streaming.py index 8f48dedbcf5..eb09084b304 100644 --- a/st2common/st2common/middleware/streaming.py +++ b/st2common/st2common/middleware/streaming.py @@ -16,9 +16,7 @@ from __future__ import absolute_import import fnmatch -__all__ = [ - 'StreamingMiddleware' -] +__all__ = ["StreamingMiddleware"] class StreamingMiddleware(object): @@ -32,7 +30,7 @@ def __call__(self, environ, start_response): # middleware is not important since it acts as pass-through. matches = False - req_path = environ.get('PATH_INFO', None) + req_path = environ.get("PATH_INFO", None) if not self._path_whitelist: matches = True @@ -43,6 +41,6 @@ def __call__(self, environ, start_response): break if matches: - environ['eventlet.minimum_write_chunk_size'] = 0 + environ["eventlet.minimum_write_chunk_size"] = 0 return self.app(environ, start_response) diff --git a/st2common/st2common/models/api/action.py b/st2common/st2common/models/api/action.py index 70eaeddad9e..1924f544601 100644 --- a/st2common/st2common/models/api/action.py +++ b/st2common/st2common/models/api/action.py @@ -23,7 +23,10 @@ from st2common.models.api.base import BaseAPI from st2common.models.api.base import APIUIDMixin from st2common.models.api.tag import TagsHelper -from st2common.models.api.notification import (NotificationSubSchemaAPI, NotificationsHelper) +from st2common.models.api.notification import ( + NotificationSubSchemaAPI, + NotificationsHelper, +) from st2common.models.db.action import ActionDB from st2common.models.db.actionalias import ActionAliasDB from st2common.models.db.executionstate import ActionExecutionStateDB @@ -34,17 +37,16 @@ __all__ = [ - 'ActionAPI', - 'ActionCreateAPI', - 'LiveActionAPI', - 'LiveActionCreateAPI', - 'RunnerTypeAPI', - - 'AliasExecutionAPI', - 'AliasMatchAndExecuteInputAPI', - 'ActionAliasAPI', - 'ActionAliasMatchAPI', - 'ActionAliasHelpAPI' + "ActionAPI", + "ActionCreateAPI", + "LiveActionAPI", + "LiveActionCreateAPI", + "RunnerTypeAPI", + "AliasExecutionAPI", + "AliasMatchAndExecuteInputAPI", + "ActionAliasAPI", + "ActionAliasMatchAPI", + "ActionAliasHelpAPI", ] @@ -56,6 +58,7 @@ class RunnerTypeAPI(BaseAPI): The representation of an RunnerType in the system. An RunnerType has a one-to-one mapping to a particular ActionRunner implementation. """ + model = RunnerTypeDB schema = { "title": "Runner", @@ -65,42 +68,40 @@ class RunnerTypeAPI(BaseAPI): "id": { "description": "The unique identifier for the action runner.", "type": "string", - "default": None - }, - "uid": { - "type": "string" + "default": None, }, + "uid": {"type": "string"}, "name": { "description": "The name of the action runner.", "type": "string", - "required": True + "required": True, }, "description": { "description": "The description of the action runner.", - "type": "string" + "type": "string", }, "enabled": { "description": "Enable or disable the action runner.", "type": "boolean", - "default": True + "default": True, }, "runner_package": { "description": "The python package that implements the " - "action runner for this type.", + "action runner for this type.", "type": "string", - "required": False + "required": False, }, "runner_module": { "description": "The python module that implements the " - "action runner for this type.", + "action runner for this type.", "type": "string", - "required": True + "required": True, }, "query_module": { "description": "The python module that implements the " - "results tracker (querier) for the runner.", + "results tracker (querier) for the runner.", "type": "string", - "required": False + "required": False, }, "runner_parameters": { "description": "Input parameters for the action runner.", @@ -108,24 +109,22 @@ class RunnerTypeAPI(BaseAPI): "patternProperties": { r"^\w+$": util_schema.get_action_parameters_schema() }, - 'additionalProperties': False + "additionalProperties": False, }, "output_key": { "description": "Default key to expect results to be published to.", "type": "string", - "required": False + "required": False, }, "output_schema": { "description": "Schema for the runner's output.", "type": "object", - "patternProperties": { - r"^\w+$": util_schema.get_action_output_schema() - }, - 'additionalProperties': False, - "default": {} + "patternProperties": {r"^\w+$": util_schema.get_action_output_schema()}, + "additionalProperties": False, + "default": {}, }, }, - "additionalProperties": False + "additionalProperties": False, } def __init__(self, **kw): @@ -138,25 +137,34 @@ def __init__(self, **kw): # modified one for key, value in kw.items(): setattr(self, key, value) - if not hasattr(self, 'runner_parameters'): - setattr(self, 'runner_parameters', dict()) + if not hasattr(self, "runner_parameters"): + setattr(self, "runner_parameters", dict()) @classmethod def to_model(cls, runner_type): name = runner_type.name description = runner_type.description - enabled = getattr(runner_type, 'enabled', True) - runner_package = getattr(runner_type, 'runner_package', runner_type.runner_module) + enabled = getattr(runner_type, "enabled", True) + runner_package = getattr( + runner_type, "runner_package", runner_type.runner_module + ) runner_module = str(runner_type.runner_module) - runner_parameters = getattr(runner_type, 'runner_parameters', dict()) - output_key = getattr(runner_type, 'output_key', None) - output_schema = getattr(runner_type, 'output_schema', dict()) - query_module = getattr(runner_type, 'query_module', None) - - model = cls.model(name=name, description=description, enabled=enabled, - runner_package=runner_package, runner_module=runner_module, - runner_parameters=runner_parameters, output_schema=output_schema, - query_module=query_module, output_key=output_key) + runner_parameters = getattr(runner_type, "runner_parameters", dict()) + output_key = getattr(runner_type, "output_key", None) + output_schema = getattr(runner_type, "output_schema", dict()) + query_module = getattr(runner_type, "query_module", None) + + model = cls.model( + name=name, + description=description, + enabled=enabled, + runner_package=runner_package, + runner_module=runner_module, + runner_parameters=runner_parameters, + output_schema=output_schema, + query_module=query_module, + output_key=output_key, + ) return model @@ -174,44 +182,42 @@ class ActionAPI(BaseAPI, APIUIDMixin): "properties": { "id": { "description": "The unique identifier for the action.", - "type": "string" + "type": "string", }, "ref": { "description": "System computed user friendly reference for the action. \ Provided value will be overridden by computed value.", - "type": "string" - }, - "uid": { - "type": "string" + "type": "string", }, + "uid": {"type": "string"}, "name": { "description": "The name of the action.", "type": "string", - "required": True + "required": True, }, "description": { "description": "The description of the action.", - "type": "string" + "type": "string", }, "enabled": { "description": "Enable or disable the action from invocation.", "type": "boolean", - "default": True + "default": True, }, "runner_type": { "description": "The type of runner that executes the action.", "type": "string", - "required": True + "required": True, }, "entry_point": { "description": "The entry point for the action.", "type": "string", - "default": "" + "default": "", }, "pack": { "description": "The content pack this action belongs to.", "type": "string", - "default": DEFAULT_PACK_NAME + "default": DEFAULT_PACK_NAME, }, "parameters": { "description": "Input parameters for the action.", @@ -219,22 +225,20 @@ class ActionAPI(BaseAPI, APIUIDMixin): "patternProperties": { r"^\w+$": util_schema.get_action_parameters_schema() }, - 'additionalProperties': False, - "default": {} + "additionalProperties": False, + "default": {}, }, "output_schema": { "description": "Schema for the action's output.", "type": "object", - "patternProperties": { - r"^\w+$": util_schema.get_action_output_schema() - }, - 'additionalProperties': False, - "default": {} + "patternProperties": {r"^\w+$": util_schema.get_action_output_schema()}, + "additionalProperties": False, + "default": {}, }, "tags": { "description": "User associated metadata assigned to this object.", "type": "array", - "items": {"type": "object"} + "items": {"type": "object"}, }, "notify": { "description": "Notification settings for action.", @@ -242,52 +246,52 @@ class ActionAPI(BaseAPI, APIUIDMixin): "properties": { "on-complete": NotificationSubSchemaAPI, "on-failure": NotificationSubSchemaAPI, - "on-success": NotificationSubSchemaAPI + "on-success": NotificationSubSchemaAPI, }, - "additionalProperties": False + "additionalProperties": False, }, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - "additionalProperties": False + "additionalProperties": False, } def __init__(self, **kw): for key, value in kw.items(): setattr(self, key, value) - if not hasattr(self, 'parameters'): - setattr(self, 'parameters', dict()) - if not hasattr(self, 'entry_point'): - setattr(self, 'entry_point', '') + if not hasattr(self, "parameters"): + setattr(self, "parameters", dict()) + if not hasattr(self, "entry_point"): + setattr(self, "entry_point", "") @classmethod def from_model(cls, model, mask_secrets=False): action = cls._from_model(model) - action['runner_type'] = action.get('runner_type', {}).get('name', None) - action['tags'] = TagsHelper.from_model(model.tags) + action["runner_type"] = action.get("runner_type", {}).get("name", None) + action["tags"] = TagsHelper.from_model(model.tags) - if getattr(model, 'notify', None): - action['notify'] = NotificationsHelper.from_model(model.notify) + if getattr(model, "notify", None): + action["notify"] = NotificationsHelper.from_model(model.notify) return cls(**action) @classmethod def to_model(cls, action): - name = getattr(action, 'name', None) - description = getattr(action, 'description', None) - enabled = bool(getattr(action, 'enabled', True)) + name = getattr(action, "name", None) + description = getattr(action, "description", None) + enabled = bool(getattr(action, "enabled", True)) entry_point = str(action.entry_point) pack = str(action.pack) - runner_type = {'name': str(action.runner_type)} - parameters = getattr(action, 'parameters', dict()) - output_schema = getattr(action, 'output_schema', dict()) - tags = TagsHelper.to_model(getattr(action, 'tags', [])) + runner_type = {"name": str(action.runner_type)} + parameters = getattr(action, "parameters", dict()) + output_schema = getattr(action, "output_schema", dict()) + tags = TagsHelper.to_model(getattr(action, "tags", [])) ref = ResourceReference.to_string_reference(pack=pack, name=name) - if getattr(action, 'notify', None): + if getattr(action, "notify", None): notify = NotificationsHelper.to_model(action.notify) else: # We use embedded document model for ``notify`` in action model. If notify is @@ -296,12 +300,22 @@ def to_model(cls, action): # to use an empty document. notify = NotificationsHelper.to_model({}) - metadata_file = getattr(action, 'metadata_file', None) - - model = cls.model(name=name, description=description, enabled=enabled, - entry_point=entry_point, pack=pack, runner_type=runner_type, - tags=tags, parameters=parameters, output_schema=output_schema, - notify=notify, ref=ref, metadata_file=metadata_file) + metadata_file = getattr(action, "metadata_file", None) + + model = cls.model( + name=name, + description=description, + enabled=enabled, + entry_point=entry_point, + pack=pack, + runner_type=runner_type, + tags=tags, + parameters=parameters, + output_schema=output_schema, + notify=notify, + ref=ref, + metadata_file=metadata_file, + ) return model @@ -310,28 +324,31 @@ class ActionCreateAPI(ActionAPI, APIUIDMixin): """ API model for create action operation. """ + schema = copy.deepcopy(ActionAPI.schema) - schema['properties']['data_files'] = { - 'description': 'Optional action script and data files which are written to the filesystem.', - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'file_path': { - 'type': 'string', - 'description': ('Path to the file relative to the pack actions directory ' - '(e.g. my_action.py)'), - 'required': True + schema["properties"]["data_files"] = { + "description": "Optional action script and data files which are written to the filesystem.", + "type": "array", + "items": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": ( + "Path to the file relative to the pack actions directory " + "(e.g. my_action.py)" + ), + "required": True, }, - 'content': { - 'type': 'string', - 'description': 'Raw file content.', - 'required': True + "content": { + "type": "string", + "description": "Raw file content.", + "required": True, }, }, - 'additionalProperties': False + "additionalProperties": False, }, - 'default': [] + "default": [], } @@ -339,8 +356,9 @@ class ActionUpdateAPI(ActionAPI, APIUIDMixin): """ API model for update action operation. """ + schema = copy.deepcopy(ActionCreateAPI.schema) - del schema['properties']['pack']['default'] + del schema["properties"]["pack"]["default"] class LiveActionAPI(BaseAPI): @@ -356,27 +374,27 @@ class LiveActionAPI(BaseAPI): "properties": { "id": { "description": "The unique identifier for the action execution.", - "type": "string" + "type": "string", }, "status": { "description": "The current status of the action execution.", "type": "string", - "enum": LIVEACTION_STATUSES + "enum": LIVEACTION_STATUSES, }, "start_timestamp": { "description": "The start time when the action is executed.", "type": "string", - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, "end_timestamp": { "description": "The timestamp when the action has finished.", "type": "string", - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, "action": { "description": "Reference to the action to be executed.", "type": "string", - "required": True + "required": True, }, "parameters": { "description": "Input parameters for the action.", @@ -390,58 +408,56 @@ class LiveActionAPI(BaseAPI): {"type": "number"}, {"type": "object"}, {"type": "string"}, - {"type": "null"} + {"type": "null"}, ] } }, - 'additionalProperties': False + "additionalProperties": False, }, "result": { - "anyOf": [{"type": "array"}, - {"type": "boolean"}, - {"type": "integer"}, - {"type": "number"}, - {"type": "object"}, - {"type": "string"}] - }, - "context": { - "type": "object" - }, - "callback": { - "type": "object" - }, - "runner_info": { - "type": "object" - }, + "anyOf": [ + {"type": "array"}, + {"type": "boolean"}, + {"type": "integer"}, + {"type": "number"}, + {"type": "object"}, + {"type": "string"}, + ] + }, + "context": {"type": "object"}, + "callback": {"type": "object"}, + "runner_info": {"type": "object"}, "notify": { "description": "Notification settings for liveaction.", "type": "object", "properties": { "on-complete": NotificationSubSchemaAPI, "on-failure": NotificationSubSchemaAPI, - "on-success": NotificationSubSchemaAPI + "on-success": NotificationSubSchemaAPI, }, - "additionalProperties": False + "additionalProperties": False, }, "delay": { - "description": ("How long (in milliseconds) to delay the execution before" - "scheduling."), + "description": ( + "How long (in milliseconds) to delay the execution before" + "scheduling." + ), "type": "integer", - } + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): doc = super(cls, cls)._from_model(model, mask_secrets=mask_secrets) if model.start_timestamp: - doc['start_timestamp'] = isotime.format(model.start_timestamp, offset=False) + doc["start_timestamp"] = isotime.format(model.start_timestamp, offset=False) if model.end_timestamp: - doc['end_timestamp'] = isotime.format(model.end_timestamp, offset=False) + doc["end_timestamp"] = isotime.format(model.end_timestamp, offset=False) - if getattr(model, 'notify', None): - doc['notify'] = NotificationsHelper.from_model(model.notify) + if getattr(model, "notify", None): + doc["notify"] = NotificationsHelper.from_model(model.notify) return cls(**doc) @@ -449,32 +465,40 @@ def from_model(cls, model, mask_secrets=False): def to_model(cls, live_action): action = live_action.action - if getattr(live_action, 'start_timestamp', None): + if getattr(live_action, "start_timestamp", None): start_timestamp = isotime.parse(live_action.start_timestamp) else: start_timestamp = None - if getattr(live_action, 'end_timestamp', None): + if getattr(live_action, "end_timestamp", None): end_timestamp = isotime.parse(live_action.end_timestamp) else: end_timestamp = None - status = getattr(live_action, 'status', None) - parameters = getattr(live_action, 'parameters', dict()) - context = getattr(live_action, 'context', dict()) - callback = getattr(live_action, 'callback', dict()) - result = getattr(live_action, 'result', None) - delay = getattr(live_action, 'delay', None) + status = getattr(live_action, "status", None) + parameters = getattr(live_action, "parameters", dict()) + context = getattr(live_action, "context", dict()) + callback = getattr(live_action, "callback", dict()) + result = getattr(live_action, "result", None) + delay = getattr(live_action, "delay", None) - if getattr(live_action, 'notify', None): + if getattr(live_action, "notify", None): notify = NotificationsHelper.to_model(live_action.notify) else: notify = None - model = cls.model(action=action, - start_timestamp=start_timestamp, end_timestamp=end_timestamp, - status=status, parameters=parameters, context=context, - callback=callback, result=result, notify=notify, delay=delay) + model = cls.model( + action=action, + start_timestamp=start_timestamp, + end_timestamp=end_timestamp, + status=status, + parameters=parameters, + context=context, + callback=callback, + result=result, + notify=notify, + delay=delay, + ) return model @@ -483,11 +507,12 @@ class LiveActionCreateAPI(LiveActionAPI): """ API model for action execution create (run action) operations. """ + schema = copy.deepcopy(LiveActionAPI.schema) - schema['properties']['user'] = { - 'description': 'User context under which action should run (admins only)', - 'type': 'string', - 'default': None + schema["properties"]["user"] = { + "description": "User context under which action should run (admins only)", + "type": "string", + "default": None, } @@ -496,6 +521,7 @@ class ActionExecutionStateAPI(BaseAPI): System entity that represents state of an action in the system. This is used only in tests for now. """ + model = ActionExecutionStateDB schema = { "title": "ActionExecutionState", @@ -504,25 +530,25 @@ class ActionExecutionStateAPI(BaseAPI): "properties": { "id": { "description": "The unique identifier for the action execution state.", - "type": "string" + "type": "string", }, "execution_id": { "type": "string", "description": "ID of the action execution.", - "required": True + "required": True, }, "query_context": { "type": "object", "description": "query context to be used by querier.", - "required": True + "required": True, }, "query_module": { "type": "string", "description": "Name of the query module.", - "required": True - } + "required": True, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -531,8 +557,11 @@ def to_model(cls, state): query_module = state.query_module query_context = state.query_context - model = cls.model(execution_id=execution_id, query_module=query_module, - query_context=query_context) + model = cls.model( + execution_id=execution_id, + query_module=query_module, + query_context=query_context, + ) return model @@ -540,6 +569,7 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin): """ Alias for an action in the system. """ + model = ActionAliasDB schema = { "title": "ActionAlias", @@ -548,42 +578,40 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin): "properties": { "id": { "description": "The unique identifier for the action alias.", - "type": "string" + "type": "string", }, "ref": { "description": ( "System computed user friendly reference for the alias. " "Provided value will be overridden by computed value." ), - "type": "string" - }, - "uid": { - "type": "string" + "type": "string", }, + "uid": {"type": "string"}, "name": { "type": "string", "description": "Name of the action alias.", - "required": True + "required": True, }, "pack": { "description": "The content pack this actionalias belongs to.", "type": "string", - "required": True + "required": True, }, "description": { "type": "string", "description": "Description of the action alias.", - "default": None + "default": None, }, "enabled": { "description": "Flag indicating of action alias is enabled.", "type": "boolean", - "default": True + "default": True, }, "action_ref": { "type": "string", "description": "Reference to the aliased action.", - "required": True + "required": True, }, "formats": { "type": "array", @@ -596,13 +624,13 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin): "display": {"type": "string"}, "representation": { "type": "array", - "items": {"type": "string"} - } - } - } + "items": {"type": "string"}, + }, + }, + }, ] }, - "description": "Possible parameter format." + "description": "Possible parameter format.", }, "ack": { "type": "object", @@ -610,56 +638,65 @@ class ActionAliasAPI(BaseAPI, APIUIDMixin): "enabled": {"type": "boolean"}, "format": {"type": "string"}, "extra": {"type": "object"}, - "append_url": {"type": "boolean"} + "append_url": {"type": "boolean"}, }, - "description": "Acknowledgement message format." + "description": "Acknowledgement message format.", }, "result": { "type": "object", "properties": { "enabled": {"type": "boolean"}, "format": {"type": "string"}, - "extra": {"type": "object"} + "extra": {"type": "object"}, }, - "description": "Execution message format." + "description": "Execution message format.", }, "extra": { "type": "object", - "description": "Extra parameters, usually adapter-specific." + "description": "Extra parameters, usually adapter-specific.", }, "immutable_parameters": { "type": "object", - "description": "Parameters to be passed to the action on every execution." + "description": "Parameters to be passed to the action on every execution.", }, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def to_model(cls, alias): name = alias.name - description = getattr(alias, 'description', None) + description = getattr(alias, "description", None) pack = alias.pack ref = ResourceReference.to_string_reference(pack=pack, name=name) - enabled = getattr(alias, 'enabled', True) + enabled = getattr(alias, "enabled", True) action_ref = alias.action_ref formats = alias.formats - ack = getattr(alias, 'ack', None) - result = getattr(alias, 'result', None) - extra = getattr(alias, 'extra', None) - immutable_parameters = getattr(alias, 'immutable_parameters', None) - metadata_file = getattr(alias, 'metadata_file', None) - - model = cls.model(name=name, description=description, pack=pack, ref=ref, - enabled=enabled, action_ref=action_ref, formats=formats, - ack=ack, result=result, extra=extra, - immutable_parameters=immutable_parameters, - metadata_file=metadata_file) + ack = getattr(alias, "ack", None) + result = getattr(alias, "result", None) + extra = getattr(alias, "extra", None) + immutable_parameters = getattr(alias, "immutable_parameters", None) + metadata_file = getattr(alias, "metadata_file", None) + + model = cls.model( + name=name, + description=description, + pack=pack, + ref=ref, + enabled=enabled, + action_ref=action_ref, + formats=formats, + ack=ack, + result=result, + extra=extra, + immutable_parameters=immutable_parameters, + metadata_file=metadata_file, + ) return model @@ -667,6 +704,7 @@ class AliasExecutionAPI(BaseAPI): """ Alias for an action in the system. """ + model = None schema = { "title": "AliasExecution", @@ -676,48 +714,48 @@ class AliasExecutionAPI(BaseAPI): "name": { "type": "string", "description": "Name of the action alias which matched.", - "required": True + "required": True, }, "format": { "type": "string", "description": "Format string which matched.", - "required": True + "required": True, }, "command": { "type": "string", "description": "Command used in chat.", - "required": True + "required": True, }, "user": { "type": "string", "description": "User that requested the execution.", - "default": "channel" # TODO: This value doesnt get set + "default": "channel", # TODO: This value doesnt get set }, "source_channel": { "type": "string", "description": "Channel from which the execution was requested. This is not the " - "channel as defined by the notification system.", - "required": True + "channel as defined by the notification system.", + "required": True, }, "source_context": { "type": "object", "description": "ALL data included with the message (also called the message " - "envelope). This is currently only used by the Microsoft Teams " - "adapter.", - "required": False + "envelope). This is currently only used by the Microsoft Teams " + "adapter.", + "required": False, }, "notification_channel": { "type": "string", "description": "StackStorm notification channel to use to respond.", - "required": False + "required": False, }, "notification_route": { "type": "string", "description": "StackStorm notification route to use to respond.", - "required": False - } + "required": False, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -734,6 +772,7 @@ class AliasMatchAndExecuteInputAPI(BaseAPI): """ API object used for alias execution "match and execute" API endpoint request payload. """ + model = None schema = { "title": "ActionAliasMatchAndExecuteInputAPI", @@ -743,7 +782,7 @@ class AliasMatchAndExecuteInputAPI(BaseAPI): "command": { "type": "string", "description": "Command used in chat.", - "required": True + "required": True, }, "user": { "type": "string", @@ -753,22 +792,22 @@ class AliasMatchAndExecuteInputAPI(BaseAPI): "type": "string", "description": "Channel from which the execution was requested. This is not the \ channel as defined by the notification system.", - "required": True + "required": True, }, "notification_channel": { "type": "string", "description": "StackStorm notification channel to use to respond.", "required": False, - "default": None + "default": None, }, "notification_route": { "type": "string", "description": "StackStorm notification route to use to respond.", "required": False, - "default": None - } + "default": None, + }, }, - "additionalProperties": False + "additionalProperties": False, } @@ -776,6 +815,7 @@ class ActionAliasMatchAPI(BaseAPI): """ API model used for alias match API endpoint. """ + model = None schema = { @@ -786,10 +826,10 @@ class ActionAliasMatchAPI(BaseAPI): "command": { "type": "string", "description": "Command string to try to match the aliases against.", - "required": True + "required": True, } }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -805,6 +845,7 @@ class ActionAliasHelpAPI(BaseAPI): """ API model used to display action-alias help API endpoint. """ + model = None schema = { @@ -816,28 +857,28 @@ class ActionAliasHelpAPI(BaseAPI): "type": "string", "description": "Find help strings containing keyword.", "required": False, - "default": "" + "default": "", }, "pack": { "type": "string", "description": "List help strings for a specific pack.", "required": False, - "default": "" + "default": "", }, "offset": { "type": "integer", "description": "List help strings from the offset position.", "required": False, - "default": 0 + "default": 0, }, "limit": { "type": "integer", "description": "Limit the number of help strings returned.", "required": False, - "default": 0 - } + "default": 0, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod diff --git a/st2common/st2common/models/api/actionrunner.py b/st2common/st2common/models/api/actionrunner.py index d2a2029e323..7b580e1c9bc 100644 --- a/st2common/st2common/models/api/actionrunner.py +++ b/st2common/st2common/models/api/actionrunner.py @@ -17,7 +17,7 @@ from st2common import log as logging from st2common.models.api.base import BaseAPI -__all__ = ['ActionRunnerAPI'] +__all__ = ["ActionRunnerAPI"] LOG = logging.getLogger(__name__) @@ -29,12 +29,9 @@ class ActionRunnerAPI(BaseAPI): Attribute: ... """ + schema = { - 'type': 'object', - 'parameters': { - 'id': { - 'type': 'string' - } - }, - 'additionalProperties': False + "type": "object", + "parameters": {"id": {"type": "string"}}, + "additionalProperties": False, } diff --git a/st2common/st2common/models/api/auth.py b/st2common/st2common/models/api/auth.py index 8e5ed34e347..10672e99ecd 100644 --- a/st2common/st2common/models/api/auth.py +++ b/st2common/st2common/models/api/auth.py @@ -36,13 +36,8 @@ class UserAPI(BaseAPI): schema = { "title": "User", "type": "object", - "properties": { - "name": { - "type": "string", - "required": True - } - }, - "additionalProperties": False + "properties": {"name": {"type": "string", "required": True}}, + "additionalProperties": False, } @classmethod @@ -58,34 +53,25 @@ class TokenAPI(BaseAPI): "title": "Token", "type": "object", "properties": { - "id": { - "type": "string" - }, - "user": { - "type": ["string", "null"] - }, - "token": { - "type": ["string", "null"] - }, - "ttl": { - "type": "integer", - "minimum": 1 - }, + "id": {"type": "string"}, + "user": {"type": ["string", "null"]}, + "token": {"type": ["string", "null"]}, + "ttl": {"type": "integer", "minimum": 1}, "expiry": { "type": ["string", "null"], - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, - "metadata": { - "type": ["object", "null"] - } + "metadata": {"type": ["object", "null"]}, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): doc = super(cls, cls)._from_model(model, mask_secrets=mask_secrets) - doc['expiry'] = isotime.format(model.expiry, offset=False) if model.expiry else None + doc["expiry"] = ( + isotime.format(model.expiry, offset=False) if model.expiry else None + ) return cls(**doc) @classmethod @@ -104,52 +90,44 @@ class ApiKeyAPI(BaseAPI, APIUIDMixin): "title": "ApiKey", "type": "object", "properties": { - "id": { - "type": "string" - }, - "uid": { - "type": "string" - }, - "user": { - "type": ["string", "null"], - "default": "" - }, - "key_hash": { - "type": ["string", "null"] - }, - "metadata": { - "type": ["object", "null"] - }, - 'created_at': { - 'description': 'The start time when the action is executed.', - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX + "id": {"type": "string"}, + "uid": {"type": "string"}, + "user": {"type": ["string", "null"], "default": ""}, + "key_hash": {"type": ["string", "null"]}, + "metadata": {"type": ["object", "null"]}, + "created_at": { + "description": "The start time when the action is executed.", + "type": "string", + "pattern": isotime.ISO8601_UTC_REGEX, }, "enabled": { "description": "Enable or disable the action from invocation.", "type": "boolean", - "default": True - } + "default": True, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): doc = super(cls, cls)._from_model(model, mask_secrets=mask_secrets) - doc['created_at'] = isotime.format(model.created_at, offset=False) if model.created_at \ - else None + doc["created_at"] = ( + isotime.format(model.created_at, offset=False) if model.created_at else None + ) return cls(**doc) @classmethod def to_model(cls, instance): # If PrimaryKey ID is provided, - we want to work with existing ST2 API key - id = getattr(instance, 'id', None) + id = getattr(instance, "id", None) user = str(instance.user) if instance.user else None - key_hash = getattr(instance, 'key_hash', None) - metadata = getattr(instance, 'metadata', {}) - enabled = bool(getattr(instance, 'enabled', True)) - model = cls.model(id=id, user=user, key_hash=key_hash, metadata=metadata, enabled=enabled) + key_hash = getattr(instance, "key_hash", None) + metadata = getattr(instance, "metadata", {}) + enabled = bool(getattr(instance, "enabled", True)) + model = cls.model( + id=id, user=user, key_hash=key_hash, metadata=metadata, enabled=enabled + ) return model @@ -158,45 +136,35 @@ class ApiKeyCreateResponseAPI(BaseAPI): "title": "APIKeyCreateResponse", "type": "object", "properties": { - "id": { - "type": "string" - }, - "uid": { - "type": "string" - }, - "user": { - "type": ["string", "null"], - "default": "" - }, - "key": { - "type": ["string", "null"] - }, - "metadata": { - "type": ["object", "null"] - }, - 'created_at': { - 'description': 'The start time when the action is executed.', - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX + "id": {"type": "string"}, + "uid": {"type": "string"}, + "user": {"type": ["string", "null"], "default": ""}, + "key": {"type": ["string", "null"]}, + "metadata": {"type": ["object", "null"]}, + "created_at": { + "description": "The start time when the action is executed.", + "type": "string", + "pattern": isotime.ISO8601_UTC_REGEX, }, "enabled": { "description": "Enable or disable the action from invocation.", "type": "boolean", - "default": True - } + "default": True, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): doc = cls._from_model(model=model, mask_secrets=mask_secrets) attrs = {attr: value for attr, value in six.iteritems(doc) if value is not None} - attrs['created_at'] = isotime.format(model.created_at, offset=False) if model.created_at \ - else None + attrs["created_at"] = ( + isotime.format(model.created_at, offset=False) if model.created_at else None + ) # key_hash is ignored. - attrs.pop('key_hash', None) + attrs.pop("key_hash", None) # key is unknown so the calling code will have to update after conversion. - attrs['key'] = None + attrs["key"] = None return cls(**attrs) diff --git a/st2common/st2common/models/api/base.py b/st2common/st2common/models/api/base.py index 3669291e9cd..6c052a43e33 100644 --- a/st2common/st2common/models/api/base.py +++ b/st2common/st2common/models/api/base.py @@ -23,10 +23,7 @@ from st2common.util import mongoescape as util_mongodb from st2common import log as logging -__all__ = [ - 'BaseAPI', - 'APIUIDMixin' -] +__all__ = ["BaseAPI", "APIUIDMixin"] LOG = logging.getLogger(__name__) @@ -43,13 +40,13 @@ def __init__(self, **kw): def __repr__(self): name = type(self).__name__ - attrs = ', '.join("'%s': %r" % item for item in six.iteritems(vars(self))) + attrs = ", ".join("'%s': %r" % item for item in six.iteritems(vars(self))) # The format here is so that eval can be applied. return "%s(**{%s})" % (name, attrs) def __str__(self): name = type(self).__name__ - attrs = ', '.join("%s=%r" % item for item in six.iteritems(vars(self))) + attrs = ", ".join("%s=%r" % item for item in six.iteritems(vars(self))) return "%s[%s]" % (name, attrs) @@ -66,12 +63,16 @@ def validate(self): """ from st2common.util import schema as util_schema - schema = getattr(self, 'schema', {}) + schema = getattr(self, "schema", {}) attributes = vars(self) - cleaned = util_schema.validate(instance=attributes, schema=schema, - cls=util_schema.CustomValidator, use_default=True, - allow_default_none=True) + cleaned = util_schema.validate( + instance=attributes, + schema=schema, + cls=util_schema.CustomValidator, + use_default=True, + allow_default_none=True, + ) # Note: We use type() instead of self.__class__ since self.__class__ confuses pylint return type(self)(**cleaned) @@ -80,8 +81,8 @@ def validate(self): def _from_model(cls, model, mask_secrets=False): doc = model.to_mongo() - if '_id' in doc: - doc['id'] = str(doc.pop('_id')) + if "_id" in doc: + doc["id"] = str(doc.pop("_id")) doc = util_mongodb.unescape_chars(doc) @@ -117,7 +118,7 @@ def to_model(cls, doc): class APIUIDMixin(object): - """" + """ " Mixin class for retrieving UID for API objects. """ @@ -142,9 +143,11 @@ def has_valid_uid(self): def cast_argument_value(value_type, value): if value_type == bool: + def cast_func(value): value = str(value) - return value.lower() in ['1', 'true'] + return value.lower() in ["1", "true"] + else: cast_func = value_type diff --git a/st2common/st2common/models/api/execution.py b/st2common/st2common/models/api/execution.py index 447a8679eb6..87a5ff52c0e 100644 --- a/st2common/st2common/models/api/execution.py +++ b/st2common/st2common/models/api/execution.py @@ -28,10 +28,7 @@ from st2common.models.api.action import RunnerTypeAPI, ActionAPI, LiveActionAPI from st2common import log as logging -__all__ = [ - 'ActionExecutionAPI', - 'ActionExecutionOutputAPI' -] +__all__ = ["ActionExecutionAPI", "ActionExecutionOutputAPI"] LOG = logging.getLogger(__name__) @@ -48,47 +45,44 @@ class ActionExecutionAPI(BaseAPI): model = ActionExecutionDB - SKIP = ['start_timestamp', 'end_timestamp'] + SKIP = ["start_timestamp", "end_timestamp"] schema = { "title": "ActionExecution", "description": "Record of the execution of an action.", "type": "object", "properties": { - "id": { - "type": "string", - "required": True - }, + "id": {"type": "string", "required": True}, "trigger": TriggerAPI.schema, "trigger_type": TriggerTypeAPI.schema, "trigger_instance": TriggerInstanceAPI.schema, "rule": RuleAPI.schema, - "action": REQUIRED_ATTR_SCHEMAS['action'], - "runner": REQUIRED_ATTR_SCHEMAS['runner'], - "liveaction": REQUIRED_ATTR_SCHEMAS['liveaction'], + "action": REQUIRED_ATTR_SCHEMAS["action"], + "runner": REQUIRED_ATTR_SCHEMAS["runner"], + "liveaction": REQUIRED_ATTR_SCHEMAS["liveaction"], "status": { "description": "The current status of the action execution.", "type": "string", - "enum": LIVEACTION_STATUSES + "enum": LIVEACTION_STATUSES, }, "start_timestamp": { "description": "The start time when the action is executed.", "type": "string", - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, "end_timestamp": { "description": "The timestamp when the action has finished.", "type": "string", - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, "elapsed_seconds": { "description": "Time duration in seconds taken for completion of this execution.", "type": "number", - "required": False + "required": False, }, "web_url": { "description": "History URL for this execution if you want to view in UI.", "type": "string", - "required": False + "required": False, }, "parameters": { "description": "Input parameters for the action.", @@ -101,28 +95,28 @@ class ActionExecutionAPI(BaseAPI): {"type": "integer"}, {"type": "number"}, {"type": "object"}, - {"type": "string"} + {"type": "string"}, ] } }, - 'additionalProperties': False - }, - "context": { - "type": "object" + "additionalProperties": False, }, + "context": {"type": "object"}, "result": { - "anyOf": [{"type": "array"}, - {"type": "boolean"}, - {"type": "integer"}, - {"type": "number"}, - {"type": "object"}, - {"type": "string"}] + "anyOf": [ + {"type": "array"}, + {"type": "boolean"}, + {"type": "integer"}, + {"type": "number"}, + {"type": "object"}, + {"type": "string"}, + ] }, "parent": {"type": "string"}, "children": { "type": "array", "items": {"type": "string"}, - "uniqueItems": True + "uniqueItems": True, }, "log": { "description": "Contains information about execution state transitions.", @@ -132,22 +126,21 @@ class ActionExecutionAPI(BaseAPI): "properties": { "timestamp": { "type": "string", - "pattern": isotime.ISO8601_UTC_REGEX + "pattern": isotime.ISO8601_UTC_REGEX, }, - "status": { - "type": "string", - "enum": LIVEACTION_STATUSES - } - } - } + "status": {"type": "string", "enum": LIVEACTION_STATUSES}, + }, + }, }, "delay": { - "description": ("How long (in milliseconds) to delay the execution before" - "scheduling."), + "description": ( + "How long (in milliseconds) to delay the execution before" + "scheduling." + ), "type": "integer", - } + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -155,16 +148,16 @@ def from_model(cls, model, mask_secrets=False): doc = cls._from_model(model, mask_secrets=mask_secrets) start_timestamp = model.start_timestamp start_timestamp_iso = isotime.format(start_timestamp, offset=False) - doc['start_timestamp'] = start_timestamp_iso + doc["start_timestamp"] = start_timestamp_iso end_timestamp = model.end_timestamp if end_timestamp: end_timestamp_iso = isotime.format(end_timestamp, offset=False) - doc['end_timestamp'] = end_timestamp_iso - doc['elapsed_seconds'] = (end_timestamp - start_timestamp).total_seconds() + doc["end_timestamp"] = end_timestamp_iso + doc["elapsed_seconds"] = (end_timestamp - start_timestamp).total_seconds() - for entry in doc.get('log', []): - entry['timestamp'] = isotime.format(entry['timestamp'], offset=False) + for entry in doc.get("log", []): + entry["timestamp"] = isotime.format(entry["timestamp"], offset=False) attrs = {attr: value for attr, value in six.iteritems(doc) if value} return cls(**attrs) @@ -172,11 +165,11 @@ def from_model(cls, model, mask_secrets=False): @classmethod def to_model(cls, instance): values = {} - for attr, meta in six.iteritems(cls.schema.get('properties', dict())): + for attr, meta in six.iteritems(cls.schema.get("properties", dict())): if not getattr(instance, attr, None): continue - default = copy.deepcopy(meta.get('default', None)) + default = copy.deepcopy(meta.get("default", None)) value = getattr(instance, attr, default) # pylint: disable=no-member @@ -188,8 +181,8 @@ def to_model(cls, instance): if attr not in ActionExecutionAPI.SKIP: values[attr] = value - values['start_timestamp'] = isotime.parse(instance.start_timestamp) - values['end_timestamp'] = isotime.parse(instance.end_timestamp) + values["start_timestamp"] = isotime.parse(instance.start_timestamp) + values["end_timestamp"] = isotime.parse(instance.end_timestamp) model = cls.model(**values) return model @@ -198,41 +191,24 @@ def to_model(cls, instance): class ActionExecutionOutputAPI(BaseAPI): model = ActionExecutionOutputDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string' - }, - 'execution_id': { - 'type': 'string' - }, - 'action_ref': { - 'type': 'string' - }, - 'runner_ref': { - 'type': 'string' - }, - 'timestamp': { - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX - }, - 'output_type': { - 'type': 'string' - }, - 'data': { - 'type': 'string' - }, - 'delay': { - 'type': 'integer' - } + "type": "object", + "properties": { + "id": {"type": "string"}, + "execution_id": {"type": "string"}, + "action_ref": {"type": "string"}, + "runner_ref": {"type": "string"}, + "timestamp": {"type": "string", "pattern": isotime.ISO8601_UTC_REGEX}, + "output_type": {"type": "string"}, + "data": {"type": "string"}, + "delay": {"type": "integer"}, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=True): doc = cls._from_model(model, mask_secrets=mask_secrets) - doc['timestamp'] = isotime.format(model.timestamp, offset=False) + doc["timestamp"] = isotime.format(model.timestamp, offset=False) attrs = {attr: value for attr, value in six.iteritems(doc) if value is not None} return cls(**attrs) diff --git a/st2common/st2common/models/api/inquiry.py b/st2common/st2common/models/api/inquiry.py index e3194df28c7..a45327aaa7e 100644 --- a/st2common/st2common/models/api/inquiry.py +++ b/st2common/st2common/models/api/inquiry.py @@ -54,30 +54,11 @@ class InquiryAPI(BaseAPI): "description": "Record of an Inquiry", "type": "object", "properties": { - "id": { - "type": "string", - "required": True - }, - "route": { - "type": "string", - "default": "", - "required": True - }, - "ttl": { - "type": "integer", - "default": 1440, - "required": True - }, - "users": { - "type": "array", - "default": [], - "required": True - }, - "roles": { - "type": "array", - "default": [], - "required": True - }, + "id": {"type": "string", "required": True}, + "route": {"type": "string", "default": "", "required": True}, + "ttl": {"type": "integer", "default": 1440, "required": True}, + "users": {"type": "array", "default": [], "required": True}, + "roles": {"type": "array", "default": [], "required": True}, "schema": { "type": "object", "default": { @@ -87,30 +68,32 @@ class InquiryAPI(BaseAPI): "continue": { "type": "boolean", "description": "Would you like to continue the workflow?", - "required": True + "required": True, } }, }, - "required": True + "required": True, }, - "liveaction": REQUIRED_ATTR_SCHEMAS['liveaction'], - "runner": REQUIRED_ATTR_SCHEMAS['runner'], + "liveaction": REQUIRED_ATTR_SCHEMAS["liveaction"], + "runner": REQUIRED_ATTR_SCHEMAS["runner"], "status": { "description": "The current status of the action execution.", "type": "string", - "enum": LIVEACTION_STATUSES + "enum": LIVEACTION_STATUSES, }, "parent": {"type": "string"}, "result": { - "anyOf": [{"type": "array"}, - {"type": "boolean"}, - {"type": "integer"}, - {"type": "number"}, - {"type": "object"}, - {"type": "string"}] - } + "anyOf": [ + {"type": "array"}, + {"type": "boolean"}, + {"type": "integer"}, + {"type": "number"}, + {"type": "object"}, + {"type": "string"}, + ] + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -118,23 +101,22 @@ def from_model(cls, model, mask_secrets=False): doc = cls._from_model(model, mask_secrets=mask_secrets) newdoc = { - 'id': doc['id'], - 'runner': doc.get('runner', None), - 'status': doc.get('status', None), - 'liveaction': doc.get('liveaction', None), - 'parent': doc.get('parent', None), - 'result': doc.get('result', None) + "id": doc["id"], + "runner": doc.get("runner", None), + "status": doc.get("status", None), + "liveaction": doc.get("liveaction", None), + "parent": doc.get("parent", None), + "result": doc.get("result", None), } - for field in ['route', 'ttl', 'users', 'roles', 'schema']: - newdoc[field] = doc['result'].get(field, None) + for field in ["route", "ttl", "users", "roles", "schema"]: + newdoc[field] = doc["result"].get(field, None) return cls(**newdoc) class InquiryResponseAPI(BaseAPI): - """A more pruned Inquiry model, containing only the fields needed for an API response - """ + """A more pruned Inquiry model, containing only the fields needed for an API response""" model = ActionExecutionDB schema = { @@ -142,30 +124,11 @@ class InquiryResponseAPI(BaseAPI): "description": "Record of an Inquiry", "type": "object", "properties": { - "id": { - "type": "string", - "required": True - }, - "route": { - "type": "string", - "default": "", - "required": True - }, - "ttl": { - "type": "integer", - "default": 1440, - "required": True - }, - "users": { - "type": "array", - "default": [], - "required": True - }, - "roles": { - "type": "array", - "default": [], - "required": True - }, + "id": {"type": "string", "required": True}, + "route": {"type": "string", "default": "", "required": True}, + "ttl": {"type": "integer", "default": 1440, "required": True}, + "users": {"type": "array", "default": [], "required": True}, + "roles": {"type": "array", "default": [], "required": True}, "schema": { "type": "object", "default": { @@ -175,14 +138,14 @@ class InquiryResponseAPI(BaseAPI): "continue": { "type": "boolean", "description": "Would you like to continue the workflow?", - "required": True + "required": True, } }, }, - "required": True - } + "required": True, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -201,9 +164,7 @@ def from_model(cls, model, mask_secrets=False, skip_db=False): else: doc = model - newdoc = { - "id": doc["id"] - } + newdoc = {"id": doc["id"]} for field in ["route", "ttl", "users", "roles", "schema"]: newdoc[field] = doc["result"].get(field) @@ -211,16 +172,16 @@ def from_model(cls, model, mask_secrets=False, skip_db=False): @classmethod def from_inquiry_api(cls, inquiry_api, mask_secrets=False): - """ Allows translation of InquiryAPI directly to InquiryResponseAPI + """Allows translation of InquiryAPI directly to InquiryResponseAPI This bypasses the DB modeling, since there's no DB model for Inquiries yet. """ return cls( - id=getattr(inquiry_api, 'id', None), - route=getattr(inquiry_api, 'route', None), - ttl=getattr(inquiry_api, 'ttl', None), - users=getattr(inquiry_api, 'users', None), - roles=getattr(inquiry_api, 'roles', None), - schema=getattr(inquiry_api, 'schema', None) + id=getattr(inquiry_api, "id", None), + route=getattr(inquiry_api, "route", None), + ttl=getattr(inquiry_api, "ttl", None), + users=getattr(inquiry_api, "users", None), + roles=getattr(inquiry_api, "roles", None), + schema=getattr(inquiry_api, "schema", None), ) diff --git a/st2common/st2common/models/api/keyvalue.py b/st2common/st2common/models/api/keyvalue.py index 8365350ef72..a19cfcc33e1 100644 --- a/st2common/st2common/models/api/keyvalue.py +++ b/st2common/st2common/models/api/keyvalue.py @@ -21,9 +21,16 @@ from oslo_config import cfg import six -from st2common.constants.keyvalue import FULL_SYSTEM_SCOPE, FULL_USER_SCOPE, ALLOWED_SCOPES +from st2common.constants.keyvalue import ( + FULL_SYSTEM_SCOPE, + FULL_USER_SCOPE, + ALLOWED_SCOPES, +) from st2common.constants.keyvalue import SYSTEM_SCOPE, USER_SCOPE -from st2common.exceptions.keyvalue import CryptoKeyNotSetupException, InvalidScopeException +from st2common.exceptions.keyvalue import ( + CryptoKeyNotSetupException, + InvalidScopeException, +) from st2common.log import logging from st2common.util import isotime from st2common.util import date as date_utils @@ -32,10 +39,7 @@ from st2common.models.system.keyvalue import UserKeyReference from st2common.models.db.keyvalue import KeyValuePairDB -__all__ = [ - 'KeyValuePairAPI', - 'KeyValuePairSetAPI' -] +__all__ = ["KeyValuePairAPI", "KeyValuePairSetAPI"] LOG = logging.getLogger(__name__) @@ -44,50 +48,29 @@ class KeyValuePairAPI(BaseAPI): crypto_setup = False model = KeyValuePairDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string' + "type": "object", + "properties": { + "id": {"type": "string"}, + "uid": {"type": "string"}, + "name": {"type": "string"}, + "description": {"type": "string"}, + "value": {"type": "string", "required": True}, + "secret": {"type": "boolean", "required": False, "default": False}, + "encrypted": {"type": "boolean", "required": False, "default": False}, + "scope": { + "type": "string", + "required": False, + "default": FULL_SYSTEM_SCOPE, }, - "uid": { - "type": "string" - }, - 'name': { - 'type': 'string' - }, - 'description': { - 'type': 'string' - }, - 'value': { - 'type': 'string', - 'required': True - }, - 'secret': { - 'type': 'boolean', - 'required': False, - 'default': False - }, - 'encrypted': { - 'type': 'boolean', - 'required': False, - 'default': False - }, - 'scope': { - 'type': 'string', - 'required': False, - 'default': FULL_SYSTEM_SCOPE - }, - 'expire_timestamp': { - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX + "expire_timestamp": { + "type": "string", + "pattern": isotime.ISO8601_UTC_REGEX, }, # Note: Those values are only used for input # TODO: Improve - 'ttl': { - 'type': 'integer' - } + "ttl": {"type": "integer"}, }, - 'additionalProperties': False + "additionalProperties": False, } @staticmethod @@ -96,19 +79,25 @@ def _setup_crypto(): # Crypto already set up return - LOG.info('Checking if encryption is enabled for key-value store.') + LOG.info("Checking if encryption is enabled for key-value store.") KeyValuePairAPI.is_encryption_enabled = cfg.CONF.keyvalue.enable_encryption - LOG.debug('Encryption enabled? : %s', KeyValuePairAPI.is_encryption_enabled) + LOG.debug("Encryption enabled? : %s", KeyValuePairAPI.is_encryption_enabled) if KeyValuePairAPI.is_encryption_enabled: KeyValuePairAPI.crypto_key_path = cfg.CONF.keyvalue.encryption_key_path - LOG.info('Encryption enabled. Looking for key in path %s', - KeyValuePairAPI.crypto_key_path) + LOG.info( + "Encryption enabled. Looking for key in path %s", + KeyValuePairAPI.crypto_key_path, + ) if not os.path.exists(KeyValuePairAPI.crypto_key_path): - msg = ('Encryption key file does not exist in path %s.' % - KeyValuePairAPI.crypto_key_path) + msg = ( + "Encryption key file does not exist in path %s." + % KeyValuePairAPI.crypto_key_path + ) LOG.exception(msg) - LOG.info('All API requests will now send out BAD_REQUEST ' + - 'if you ask to store secrets in key value store.') + LOG.info( + "All API requests will now send out BAD_REQUEST " + + "if you ask to store secrets in key value store." + ) KeyValuePairAPI.crypto_key = None else: KeyValuePairAPI.crypto_key = read_crypto_key( @@ -123,28 +112,30 @@ def from_model(cls, model, mask_secrets=True): doc = cls._from_model(model, mask_secrets=mask_secrets) - if getattr(model, 'expire_timestamp', None) and model.expire_timestamp: - doc['expire_timestamp'] = isotime.format(model.expire_timestamp, offset=False) + if getattr(model, "expire_timestamp", None) and model.expire_timestamp: + doc["expire_timestamp"] = isotime.format( + model.expire_timestamp, offset=False + ) encrypted = False - secret = getattr(model, 'secret', False) + secret = getattr(model, "secret", False) if secret: encrypted = True if not mask_secrets and secret: - doc['value'] = symmetric_decrypt(KeyValuePairAPI.crypto_key, model.value) + doc["value"] = symmetric_decrypt(KeyValuePairAPI.crypto_key, model.value) encrypted = False - scope = getattr(model, 'scope', SYSTEM_SCOPE) + scope = getattr(model, "scope", SYSTEM_SCOPE) if scope: - doc['scope'] = scope + doc["scope"] = scope - key = doc.get('name', None) + key = doc.get("name", None) if (scope == USER_SCOPE or scope == FULL_USER_SCOPE) and key: - doc['user'] = UserKeyReference.get_user(key) - doc['name'] = UserKeyReference.get_name(key) + doc["user"] = UserKeyReference.get_user(key) + doc["name"] = UserKeyReference.get_name(key) - doc['encrypted'] = encrypted + doc["encrypted"] = encrypted attrs = {attr: value for attr, value in six.iteritems(doc) if value is not None} return cls(**attrs) @@ -153,21 +144,22 @@ def to_model(cls, kvp): if not KeyValuePairAPI.crypto_setup: KeyValuePairAPI._setup_crypto() - kvp_id = getattr(kvp, 'id', None) - name = getattr(kvp, 'name', None) - description = getattr(kvp, 'description', None) + kvp_id = getattr(kvp, "id", None) + name = getattr(kvp, "name", None) + description = getattr(kvp, "description", None) value = kvp.value original_value = value secret = False - if getattr(kvp, 'ttl', None): - expire_timestamp = (date_utils.get_datetime_utc_now() + - datetime.timedelta(seconds=kvp.ttl)) + if getattr(kvp, "ttl", None): + expire_timestamp = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=kvp.ttl + ) else: expire_timestamp = None - encrypted = getattr(kvp, 'encrypted', False) - secret = getattr(kvp, 'secret', False) + encrypted = getattr(kvp, "encrypted", False) + secret = getattr(kvp, "secret", False) # If user transmitted the value in an pre-encrypted format, we perform the decryption here # to ensure data integrity. Besides that, we store data as-is. @@ -182,9 +174,11 @@ def to_model(cls, kvp): try: symmetric_decrypt(KeyValuePairAPI.crypto_key, value) except Exception: - msg = ('Failed to verify the integrity of the provided value for key "%s". Ensure ' - 'that the value is encrypted with the correct key and not corrupted.' % - (name)) + msg = ( + 'Failed to verify the integrity of the provided value for key "%s". Ensure ' + "that the value is encrypted with the correct key and not corrupted." + % (name) + ) raise ValueError(msg) # Additional safety check to ensure that the value hasn't been decrypted @@ -194,30 +188,39 @@ def to_model(cls, kvp): value = symmetric_encrypt(KeyValuePairAPI.crypto_key, value) - scope = getattr(kvp, 'scope', FULL_SYSTEM_SCOPE) + scope = getattr(kvp, "scope", FULL_SYSTEM_SCOPE) if scope not in ALLOWED_SCOPES: - raise InvalidScopeException('Invalid scope "%s"! Allowed scopes are %s.' % ( - scope, ALLOWED_SCOPES) + raise InvalidScopeException( + 'Invalid scope "%s"! Allowed scopes are %s.' % (scope, ALLOWED_SCOPES) ) # NOTE: For security reasons, encrypted always implies secret=True. See comment # above for explanation. if encrypted and not secret: - raise ValueError('encrypted option can only be used in combination with secret ' - 'option') + raise ValueError( + "encrypted option can only be used in combination with secret " "option" + ) - model = cls.model(id=kvp_id, name=name, description=description, value=value, - secret=secret, scope=scope, - expire_timestamp=expire_timestamp) + model = cls.model( + id=kvp_id, + name=name, + description=description, + value=value, + secret=secret, + scope=scope, + expire_timestamp=expire_timestamp, + ) return model @classmethod def _verif_key_is_set_up(cls, name): if not KeyValuePairAPI.crypto_key: - msg = ('Crypto key not found in %s. Unable to encrypt / decrypt value for key %s.' % - (KeyValuePairAPI.crypto_key_path, name)) + msg = "Crypto key not found in %s. Unable to encrypt / decrypt value for key %s." % ( + KeyValuePairAPI.crypto_key_path, + name, + ) raise CryptoKeyNotSetupException(msg) @@ -227,13 +230,12 @@ class KeyValuePairSetAPI(KeyValuePairAPI): """ schema = copy.deepcopy(KeyValuePairAPI.schema) - schema['properties']['ttl'] = { - 'description': 'Items TTL', - 'type': 'integer' - } - schema['properties']['user'] = { - 'description': ('User to which the value should be scoped to. Only applicable to ' - 'scope == user'), - 'type': 'string', - 'default': None + schema["properties"]["ttl"] = {"description": "Items TTL", "type": "integer"} + schema["properties"]["user"] = { + "description": ( + "User to which the value should be scoped to. Only applicable to " + "scope == user" + ), + "type": "string", + "default": None, } diff --git a/st2common/st2common/models/api/notification.py b/st2common/st2common/models/api/notification.py index fef0545f26f..9d80ddbf7fb 100644 --- a/st2common/st2common/models/api/notification.py +++ b/st2common/st2common/models/api/notification.py @@ -19,57 +19,60 @@ NotificationSubSchemaAPI = { "type": "object", "properties": { - "message": { - "type": "string", - "description": "Message to use for notification" - }, + "message": {"type": "string", "description": "Message to use for notification"}, "data": { "type": "object", - "description": "Data to be sent as part of notification" + "description": "Data to be sent as part of notification", }, "routes": { "type": "array", - "description": "Channels to post notifications to." + "description": "Channels to post notifications to.", }, "channels": { # Deprecated. Only here for backward compatibility. "type": "array", - "description": "Channels to post notifications to." + "description": "Channels to post notifications to.", }, }, - "additionalProperties": False + "additionalProperties": False, } class NotificationsHelper(object): - @staticmethod def to_model(notify_api_object): - if notify_api_object.get('on-success', None): - on_success = NotificationsHelper._to_model_sub_schema(notify_api_object['on-success']) + if notify_api_object.get("on-success", None): + on_success = NotificationsHelper._to_model_sub_schema( + notify_api_object["on-success"] + ) else: on_success = None - if notify_api_object.get('on-complete', None): + if notify_api_object.get("on-complete", None): on_complete = NotificationsHelper._to_model_sub_schema( - notify_api_object['on-complete']) + notify_api_object["on-complete"] + ) else: on_complete = None - if notify_api_object.get('on-failure', None): - on_failure = NotificationsHelper._to_model_sub_schema(notify_api_object['on-failure']) + if notify_api_object.get("on-failure", None): + on_failure = NotificationsHelper._to_model_sub_schema( + notify_api_object["on-failure"] + ) else: on_failure = None - model = NotificationSchema(on_success=on_success, on_failure=on_failure, - on_complete=on_complete) + model = NotificationSchema( + on_success=on_success, on_failure=on_failure, on_complete=on_complete + ) return model @staticmethod def _to_model_sub_schema(notification_settings_json): - message = notification_settings_json.get('message', None) - data = notification_settings_json.get('data', {}) - routes = (notification_settings_json.get('routes', None) or - notification_settings_json.get('channels', [])) + message = notification_settings_json.get("message", None) + data = notification_settings_json.get("data", {}) + routes = notification_settings_json.get( + "routes", None + ) or notification_settings_json.get("channels", []) model = NotificationSubSchema(message=message, data=data, routes=routes) return model @@ -77,15 +80,18 @@ def _to_model_sub_schema(notification_settings_json): @staticmethod def from_model(notify_model): notify = {} - if getattr(notify_model, 'on_complete', None): - notify['on-complete'] = NotificationsHelper._from_model_sub_schema( - notify_model.on_complete) - if getattr(notify_model, 'on_success', None): - notify['on-success'] = NotificationsHelper._from_model_sub_schema( - notify_model.on_success) - if getattr(notify_model, 'on_failure', None): - notify['on-failure'] = NotificationsHelper._from_model_sub_schema( - notify_model.on_failure) + if getattr(notify_model, "on_complete", None): + notify["on-complete"] = NotificationsHelper._from_model_sub_schema( + notify_model.on_complete + ) + if getattr(notify_model, "on_success", None): + notify["on-success"] = NotificationsHelper._from_model_sub_schema( + notify_model.on_success + ) + if getattr(notify_model, "on_failure", None): + notify["on-failure"] = NotificationsHelper._from_model_sub_schema( + notify_model.on_failure + ) return notify @@ -93,13 +99,14 @@ def from_model(notify_model): def _from_model_sub_schema(notify_sub_schema_model): notify_sub_schema = {} - if getattr(notify_sub_schema_model, 'message', None): - notify_sub_schema['message'] = notify_sub_schema_model.message - if getattr(notify_sub_schema_model, 'data', None): - notify_sub_schema['data'] = notify_sub_schema_model.data - routes = (getattr(notify_sub_schema_model, 'routes') or - getattr(notify_sub_schema_model, 'channels')) + if getattr(notify_sub_schema_model, "message", None): + notify_sub_schema["message"] = notify_sub_schema_model.message + if getattr(notify_sub_schema_model, "data", None): + notify_sub_schema["data"] = notify_sub_schema_model.data + routes = getattr(notify_sub_schema_model, "routes") or getattr( + notify_sub_schema_model, "channels" + ) if routes: - notify_sub_schema['routes'] = routes + notify_sub_schema["routes"] = routes return notify_sub_schema diff --git a/st2common/st2common/models/api/pack.py b/st2common/st2common/models/api/pack.py index 02c6d00f637..6de2893427e 100644 --- a/st2common/st2common/models/api/pack.py +++ b/st2common/st2common/models/api/pack.py @@ -37,16 +37,14 @@ from st2common.util.pack import validate_config_against_schema __all__ = [ - 'PackAPI', - 'ConfigSchemaAPI', - 'ConfigAPI', - - 'ConfigItemSetAPI', - - 'PackInstallRequestAPI', - 'PackRegisterRequestAPI', - 'PackSearchRequestAPI', - 'PackAsyncAPI' + "PackAPI", + "ConfigSchemaAPI", + "ConfigAPI", + "ConfigItemSetAPI", + "PackInstallRequestAPI", + "PackRegisterRequestAPI", + "PackSearchRequestAPI", + "PackAsyncAPI", ] LOG = logging.getLogger(__name__) @@ -55,124 +53,117 @@ class PackAPI(BaseAPI): model = PackDB schema = { - 'type': 'object', - 'description': 'Content pack schema.', - 'properties': { - 'id': { - 'type': 'string', - 'description': 'Unique identifier for the pack.', - 'default': None + "type": "object", + "description": "Content pack schema.", + "properties": { + "id": { + "type": "string", + "description": "Unique identifier for the pack.", + "default": None, }, - 'name': { - 'type': 'string', - 'description': 'Display name of the pack. If the name only contains lowercase' - 'letters, digits and underscores, the "ref" field is not required.', - 'required': True + "name": { + "type": "string", + "description": "Display name of the pack. If the name only contains lowercase" + 'letters, digits and underscores, the "ref" field is not required.', + "required": True, }, - 'ref': { - 'type': 'string', - 'description': 'Reference for the pack, used as an internal id.', - 'default': None, - 'pattern': PACK_REF_WHITELIST_REGEX + "ref": { + "type": "string", + "description": "Reference for the pack, used as an internal id.", + "default": None, + "pattern": PACK_REF_WHITELIST_REGEX, }, - 'uid': { - 'type': 'string' + "uid": {"type": "string"}, + "description": { + "type": "string", + "description": "Brief description of the pack and the service it integrates with.", + "required": True, }, - 'description': { - 'type': 'string', - 'description': 'Brief description of the pack and the service it integrates with.', - 'required': True + "keywords": { + "type": "array", + "description": "Keywords describing the pack.", + "items": {"type": "string"}, + "default": [], }, - 'keywords': { - 'type': 'array', - 'description': 'Keywords describing the pack.', - 'items': {'type': 'string'}, - 'default': [] + "version": { + "type": "string", + "description": "Pack version. Must follow the semver format " + '(for instance, "0.1.0").', + "pattern": PACK_VERSION_REGEX, + "required": True, }, - 'version': { - 'type': 'string', - 'description': 'Pack version. Must follow the semver format ' - '(for instance, "0.1.0").', - 'pattern': PACK_VERSION_REGEX, - 'required': True + "stackstorm_version": { + "type": "string", + "description": 'Required StackStorm version. Examples: ">1.6.0", ' + '">=1.8.0, <2.2.0"', + "pattern": ST2_VERSION_REGEX, }, - 'stackstorm_version': { - 'type': 'string', - 'description': 'Required StackStorm version. Examples: ">1.6.0", ' - '">=1.8.0, <2.2.0"', - 'pattern': ST2_VERSION_REGEX, + "python_versions": { + "type": "array", + "description": ( + "Major Python versions supported by this pack. E.g. " + '"2" for Python 2.7.x and "3" for Python 3.6.x' + ), + "items": {"type": "string", "enum": ["2", "3"]}, + "minItems": 1, + "maxItems": 2, + "uniqueItems": True, + "additionalItems": True, }, - 'python_versions': { - 'type': 'array', - 'description': ('Major Python versions supported by this pack. E.g. ' - '"2" for Python 2.7.x and "3" for Python 3.6.x'), - 'items': { - 'type': 'string', - 'enum': [ - '2', - '3' - ] - }, - 'minItems': 1, - 'maxItems': 2, - 'uniqueItems': True, - 'additionalItems': True + "author": { + "type": "string", + "description": "Pack author or authors.", + "required": True, }, - 'author': { - 'type': 'string', - 'description': 'Pack author or authors.', - 'required': True + "email": { + "type": "string", + "description": "E-mail of the pack author.", + "format": "email", }, - 'email': { - 'type': 'string', - 'description': 'E-mail of the pack author.', - 'format': 'email' + "contributors": { + "type": "array", + "items": {"type": "string", "maxLength": 100}, + "description": ( + "A list of people who have contributed to the pack. Format is: " + "Name e.g. Tomaz Muraus ." + ), }, - 'contributors': { - 'type': 'array', - 'items': { - 'type': 'string', - 'maxLength': 100 - }, - 'description': ('A list of people who have contributed to the pack. Format is: ' - 'Name e.g. Tomaz Muraus .') + "files": { + "type": "array", + "description": "A list of files inside the pack.", + "items": {"type": "string"}, + "default": [], }, - 'files': { - 'type': 'array', - 'description': 'A list of files inside the pack.', - 'items': {'type': 'string'}, - 'default': [] + "dependencies": { + "type": "array", + "description": "A list of other StackStorm packs this pack depends upon. " + 'The same format as in "st2 pack install" is used: ' + '"[=]".', + "items": {"type": "string"}, + "default": [], }, - 'dependencies': { - 'type': 'array', - 'description': 'A list of other StackStorm packs this pack depends upon. ' - 'The same format as in "st2 pack install" is used: ' - '"[=]".', - 'items': {'type': 'string'}, - 'default': [] + "system": { + "type": "object", + "description": "Specification for the system components and packages " + "required for the pack.", + "default": {}, }, - 'system': { - 'type': 'object', - 'description': 'Specification for the system components and packages ' - 'required for the pack.', - 'default': {} + "path": { + "type": "string", + "description": "Location of the pack on disk in st2 system.", + "required": False, }, - 'path': { - 'type': 'string', - 'description': 'Location of the pack on disk in st2 system.', - 'required': False - } }, # NOTE: We add this here explicitly so we can gracefuly add new attributs to pack.yaml # without breaking existing installations - 'additionalProperties': True + "additionalProperties": True, } def __init__(self, **values): # Note: If some version values are not explicitly surrounded by quotes they are recognized # as numbers so we cast them to string - if values.get('version', None): - values['version'] = str(values['version']) + if values.get("version", None): + values["version"] = str(values["version"]) super(PackAPI, self).__init__(**values) @@ -186,17 +177,21 @@ def validate(self): # Invalid version if "Failed validating 'pattern' in schema['properties']['version']" in msg: - new_msg = ('Pack version "%s" doesn\'t follow a valid semver format. Valid ' - 'versions and formats include: 0.1.0, 0.2.1, 1.1.0, etc.' % - (self.version)) - new_msg += '\n\n' + msg + new_msg = ( + 'Pack version "%s" doesn\'t follow a valid semver format. Valid ' + "versions and formats include: 0.1.0, 0.2.1, 1.1.0, etc." + % (self.version) + ) + new_msg += "\n\n" + msg raise jsonschema.ValidationError(new_msg) # Invalid ref / name if "Failed validating 'pattern' in schema['properties']['ref']" in msg: - new_msg = ('Pack ref / name can only contain valid word characters (a-z, 0-9 and ' - '_), dashes are not allowed.') - new_msg += '\n\n' + msg + new_msg = ( + "Pack ref / name can only contain valid word characters (a-z, 0-9 and " + "_), dashes are not allowed." + ) + new_msg += "\n\n" + msg raise jsonschema.ValidationError(new_msg) raise e @@ -206,24 +201,35 @@ def to_model(cls, pack): ref = pack.ref name = pack.name description = pack.description - keywords = getattr(pack, 'keywords', []) + keywords = getattr(pack, "keywords", []) version = str(pack.version) - stackstorm_version = getattr(pack, 'stackstorm_version', None) - python_versions = getattr(pack, 'python_versions', []) + stackstorm_version = getattr(pack, "stackstorm_version", None) + python_versions = getattr(pack, "python_versions", []) author = pack.author email = pack.email - contributors = getattr(pack, 'contributors', []) - files = getattr(pack, 'files', []) - pack_dir = getattr(pack, 'path', None) - dependencies = getattr(pack, 'dependencies', []) - system = getattr(pack, 'system', {}) - - model = cls.model(ref=ref, name=name, description=description, keywords=keywords, - version=version, author=author, email=email, contributors=contributors, - files=files, dependencies=dependencies, system=system, - stackstorm_version=stackstorm_version, path=pack_dir, - python_versions=python_versions) + contributors = getattr(pack, "contributors", []) + files = getattr(pack, "files", []) + pack_dir = getattr(pack, "path", None) + dependencies = getattr(pack, "dependencies", []) + system = getattr(pack, "system", {}) + + model = cls.model( + ref=ref, + name=name, + description=description, + keywords=keywords, + version=version, + author=author, + email=email, + contributors=contributors, + files=files, + dependencies=dependencies, + system=system, + stackstorm_version=stackstorm_version, + path=pack_dir, + python_versions=python_versions, + ) return model @@ -236,11 +242,11 @@ class ConfigSchemaAPI(BaseAPI): "properties": { "id": { "description": "The unique identifier for the config schema.", - "type": "string" + "type": "string", }, "pack": { "description": "The content pack this config schema belongs to.", - "type": "string" + "type": "string", }, "attributes": { "description": "Config schema attributes.", @@ -248,11 +254,11 @@ class ConfigSchemaAPI(BaseAPI): "patternProperties": { r"^\w+$": util_schema.get_action_parameters_schema() }, - 'additionalProperties': False, - "default": {} - } + "additionalProperties": False, + "default": {}, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod @@ -273,19 +279,19 @@ class ConfigAPI(BaseAPI): "properties": { "id": { "description": "The unique identifier for the config.", - "type": "string" + "type": "string", }, "pack": { "description": "The content pack this config belongs to.", - "type": "string" + "type": "string", }, "values": { "description": "Config values.", "type": "object", - "default": {} - } + "default": {}, + }, }, - "additionalProperties": False + "additionalProperties": False, } def validate(self, validate_against_schema=False): @@ -310,13 +316,15 @@ def _validate_config_values_against_schema(self): instance = self.values or {} schema = config_schema_db.attributes or {} - configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/') - config_path = os.path.join(configs_path, '%s.yaml' % (self.pack)) + configs_path = os.path.join(cfg.CONF.system.base_path, "configs/") + config_path = os.path.join(configs_path, "%s.yaml" % (self.pack)) - cleaned = validate_config_against_schema(config_schema=schema, - config_object=instance, - config_path=config_path, - pack_name=self.pack) + cleaned = validate_config_against_schema( + config_schema=schema, + config_object=instance, + config_path=config_path, + pack_name=self.pack, + ) return cleaned @@ -330,15 +338,14 @@ def to_model(cls, config): class ConfigUpdateRequestAPI(BaseAPI): - schema = { - "type": "object" - } + schema = {"type": "object"} class ConfigItemSetAPI(BaseAPI): """ API class used with the config set API endpoint. """ + model = None schema = { "title": "", @@ -348,30 +355,27 @@ class ConfigItemSetAPI(BaseAPI): "name": { "description": "Config item name (key)", "type": "string", - "required": True + "required": True, }, "value": { "description": "Config item value.", "type": ["string", "number", "boolean", "array", "object"], - "required": True + "required": True, }, "scope": { "description": "Config item scope (system / user)", "type": "string", "default": SYSTEM_SCOPE, - "enum": [ - SYSTEM_SCOPE, - USER_SCOPE - ] + "enum": [SYSTEM_SCOPE, USER_SCOPE], }, "user": { "description": "User for user-scoped items (only available to admins).", "type": "string", "required": False, - "default": None - } + "default": None, + }, }, - "additionalProperties": False + "additionalProperties": False, } @@ -379,15 +383,13 @@ class PackInstallRequestAPI(BaseAPI): schema = { "type": "object", "properties": { - "packs": { - "type": "array" - }, + "packs": {"type": "array"}, "force": { "type": "boolean", "description": "Force pack installation", - "default": False - } - } + "default": False, + }, + }, } @@ -395,24 +397,14 @@ class PackRegisterRequestAPI(BaseAPI): schema = { "type": "object", "properties": { - "types": { - "type": "array", - "items": { - "type": "string" - } - }, - "packs": { - "type": "array", - "items": { - "type": "string" - } - }, + "types": {"type": "array", "items": {"type": "string"}}, + "packs": {"type": "array", "items": {"type": "string"}}, "fail_on_failure": { "type": "boolean", "description": "True to fail on failure", - "default": True - } - } + "default": True, + }, + }, } @@ -438,18 +430,13 @@ class PackSearchRequestAPI(BaseAPI): }, "additionalProperties": False, }, - ] + ], } class PackAsyncAPI(BaseAPI): schema = { "type": "object", - "properties": { - "execution_id": { - "type": "string", - "required": True - } - }, - "additionalProperties": False + "properties": {"execution_id": {"type": "string", "required": True}}, + "additionalProperties": False, } diff --git a/st2common/st2common/models/api/policy.py b/st2common/st2common/models/api/policy.py index a46dad9eda4..211560d453b 100644 --- a/st2common/st2common/models/api/policy.py +++ b/st2common/st2common/models/api/policy.py @@ -22,7 +22,7 @@ from st2common.util import schema as util_schema -__all__ = ['PolicyTypeAPI'] +__all__ = ["PolicyTypeAPI"] LOG = logging.getLogger(__name__) @@ -33,55 +33,34 @@ class PolicyTypeAPI(BaseAPI, APIUIDMixin): "title": "Policy Type", "type": "object", "properties": { - "id": { - "type": "string", - "default": None - }, - 'uid': { - 'type': 'string' - }, - "name": { - "type": "string", - "required": True - }, - "resource_type": { - "enum": ["action"], - "required": True - }, - "ref": { - "type": "string" - }, - "description": { - "type": "string" - }, - "enabled": { - "type": "boolean", - "default": True - }, - "module": { - "type": "string", - "required": True - }, + "id": {"type": "string", "default": None}, + "uid": {"type": "string"}, + "name": {"type": "string", "required": True}, + "resource_type": {"enum": ["action"], "required": True}, + "ref": {"type": "string"}, + "description": {"type": "string"}, + "enabled": {"type": "boolean", "default": True}, + "module": {"type": "string", "required": True}, "parameters": { "type": "object", - "patternProperties": { - r"^\w+$": util_schema.get_draft_schema() - }, - 'additionalProperties': False - } + "patternProperties": {r"^\w+$": util_schema.get_draft_schema()}, + "additionalProperties": False, + }, }, - "additionalProperties": False + "additionalProperties": False, } @classmethod def to_model(cls, instance): - return cls.model(name=str(instance.name), - description=getattr(instance, 'description', None), - resource_type=str(instance.resource_type), - ref=getattr(instance, 'ref', None), - enabled=getattr(instance, 'enabled', None), - module=str(instance.module), - parameters=getattr(instance, 'parameters', dict())) + return cls.model( + name=str(instance.name), + description=getattr(instance, "description", None), + resource_type=str(instance.resource_type), + ref=getattr(instance, "ref", None), + enabled=getattr(instance, "enabled", None), + module=str(instance.module), + parameters=getattr(instance, "parameters", dict()), + ) class PolicyAPI(BaseAPI, APIUIDMixin): @@ -90,38 +69,15 @@ class PolicyAPI(BaseAPI, APIUIDMixin): "title": "Policy", "type": "object", "properties": { - "id": { - "type": "string", - "default": None - }, - 'uid': { - 'type': 'string' - }, - "name": { - "type": "string", - "required": True - }, - "pack": { - "type": "string" - }, - "ref": { - "type": "string" - }, - "description": { - "type": "string" - }, - "enabled": { - "type": "boolean", - "default": True - }, - "resource_ref": { - "type": "string", - "required": True - }, - "policy_type": { - "type": "string", - "required": True - }, + "id": {"type": "string", "default": None}, + "uid": {"type": "string"}, + "name": {"type": "string", "required": True}, + "pack": {"type": "string"}, + "ref": {"type": "string"}, + "description": {"type": "string"}, + "enabled": {"type": "boolean", "default": True}, + "resource_ref": {"type": "string", "required": True}, + "policy_type": {"type": "string", "required": True}, "parameters": { "type": "object", "patternProperties": { @@ -132,20 +88,19 @@ class PolicyAPI(BaseAPI, APIUIDMixin): {"type": "integer"}, {"type": "number"}, {"type": "object"}, - {"type": "string"} + {"type": "string"}, ] } }, - 'additionalProperties': False - + "additionalProperties": False, }, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - "additionalProperties": False + "additionalProperties": False, } def validate(self): @@ -156,15 +111,19 @@ def validate(self): # pylint: disable=no-member policy_type_db = PolicyType.get_by_ref(cleaned.policy_type) if not policy_type_db: - raise ValueError('Referenced policy_type "%s" doesnt exist' % (cleaned.policy_type)) + raise ValueError( + 'Referenced policy_type "%s" doesnt exist' % (cleaned.policy_type) + ) parameters_schema = policy_type_db.parameters - parameters = getattr(cleaned, 'parameters', {}) + parameters = getattr(cleaned, "parameters", {}) schema = util_schema.get_schema_for_resource_parameters( - parameters_schema=parameters_schema) + parameters_schema=parameters_schema + ) validator = util_schema.get_validator() - cleaned_parameters = util_schema.validate(parameters, schema, validator, use_default=True, - allow_default_none=True) + cleaned_parameters = util_schema.validate( + parameters, schema, validator, use_default=True, allow_default_none=True + ) cleaned.parameters = cleaned_parameters @@ -172,13 +131,15 @@ def validate(self): @classmethod def to_model(cls, instance): - return cls.model(id=getattr(instance, 'id', None), - name=str(instance.name), - description=getattr(instance, 'description', None), - pack=str(instance.pack), - ref=getattr(instance, 'ref', None), - enabled=getattr(instance, 'enabled', None), - resource_ref=str(instance.resource_ref), - policy_type=str(instance.policy_type), - parameters=getattr(instance, 'parameters', dict()), - metadata_file=getattr(instance, 'metadata_file', None)) + return cls.model( + id=getattr(instance, "id", None), + name=str(instance.name), + description=getattr(instance, "description", None), + pack=str(instance.pack), + ref=getattr(instance, "ref", None), + enabled=getattr(instance, "enabled", None), + resource_ref=str(instance.resource_ref), + policy_type=str(instance.policy_type), + parameters=getattr(instance, "parameters", dict()), + metadata_file=getattr(instance, "metadata_file", None), + ) diff --git a/st2common/st2common/models/api/rbac.py b/st2common/st2common/models/api/rbac.py index 556793b7a63..bd269ce3d68 100644 --- a/st2common/st2common/models/api/rbac.py +++ b/st2common/st2common/models/api/rbac.py @@ -25,67 +25,55 @@ from st2common.util.uid import parse_uid __all__ = [ - 'RoleAPI', - 'UserRoleAssignmentAPI', - - 'RoleDefinitionFileFormatAPI', - 'UserRoleAssignmentFileFormatAPI', - - 'AuthGroupToRoleMapAssignmentFileFormatAPI' + "RoleAPI", + "UserRoleAssignmentAPI", + "RoleDefinitionFileFormatAPI", + "UserRoleAssignmentFileFormatAPI", + "AuthGroupToRoleMapAssignmentFileFormatAPI", ] class RoleAPI(BaseAPI): model = RoleDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'name': { - 'type': 'string', - 'required': True - }, - 'description': { - 'type': 'string' - }, - 'permission_grant_ids': { - 'type': 'array', - 'items': { - 'type': 'string' - } - }, - 'permission_grant_objects': { - 'type': 'array', - 'items': { - 'type': 'object' - } - } + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "name": {"type": "string", "required": True}, + "description": {"type": "string"}, + "permission_grant_ids": {"type": "array", "items": {"type": "string"}}, + "permission_grant_objects": {"type": "array", "items": {"type": "object"}}, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod - def from_model(cls, model, mask_secrets=False, retrieve_permission_grant_objects=True): + def from_model( + cls, model, mask_secrets=False, retrieve_permission_grant_objects=True + ): role = cls._from_model(model, mask_secrets=mask_secrets) # Convert ObjectIDs to strings - role['permission_grant_ids'] = [str(permission_grant) for permission_grant in - model.permission_grants] + role["permission_grant_ids"] = [ + str(permission_grant) for permission_grant in model.permission_grants + ] # Retrieve and include corresponding permission grant objects if retrieve_permission_grant_objects: from st2common.persistence.rbac import PermissionGrant - permission_grant_dbs = PermissionGrant.query(id__in=role['permission_grants']) + + permission_grant_dbs = PermissionGrant.query( + id__in=role["permission_grants"] + ) permission_grant_apis = [] for permission_grant_db in permission_grant_dbs: - permission_grant_api = PermissionGrantAPI.from_model(permission_grant_db) + permission_grant_api = PermissionGrantAPI.from_model( + permission_grant_db + ) permission_grant_apis.append(permission_grant_api) - role['permission_grant_objects'] = permission_grant_apis + role["permission_grant_objects"] = permission_grant_apis return cls(**role) @@ -93,56 +81,30 @@ def from_model(cls, model, mask_secrets=False, retrieve_permission_grant_objects class UserRoleAssignmentAPI(BaseAPI): model = UserRoleAssignmentDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'user': { - 'type': 'string', - 'required': True - }, - 'role': { - 'type': 'string', - 'required': True - }, - 'description': { - 'type': 'string' - }, - 'is_remote': { - 'type': 'boolean' - }, - 'source': { - 'type': 'string' - } + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "user": {"type": "string", "required": True}, + "role": {"type": "string", "required": True}, + "description": {"type": "string"}, + "is_remote": {"type": "boolean"}, + "source": {"type": "string"}, }, - 'additionalProperties': False + "additionalProperties": False, } class PermissionGrantAPI(BaseAPI): model = PermissionGrantDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'resource_uid': { - 'type': 'string', - 'required': True - }, - 'resource_type': { - 'type': 'string', - 'required': True - }, - 'permission_types': { - 'type': 'array' - } + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "resource_uid": {"type": "string", "required": True}, + "resource_type": {"type": "string", "required": True}, + "permission_types": {"type": "array"}, }, - 'additionalProperties': False + "additionalProperties": False, } @@ -152,53 +114,55 @@ class RoleDefinitionFileFormatAPI(BaseAPI): """ schema = { - 'type': 'object', - 'properties': { - 'name': { - 'type': 'string', - 'description': 'Role name', - 'required': True, - 'default': None + "type": "object", + "properties": { + "name": { + "type": "string", + "description": "Role name", + "required": True, + "default": None, }, - 'description': { - 'type': 'string', - 'description': 'Role description', - 'required': False + "description": { + "type": "string", + "description": "Role description", + "required": False, }, - 'enabled': { - 'type': 'boolean', - 'description': ('Flag indicating if this role is enabled. Note: Disabled roles ' - 'are simply ignored when loading definitions from disk.'), - 'default': True + "enabled": { + "type": "boolean", + "description": ( + "Flag indicating if this role is enabled. Note: Disabled roles " + "are simply ignored when loading definitions from disk." + ), + "default": True, }, - 'permission_grants': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'resource_uid': { - 'type': 'string', - 'description': 'UID of a resource to which this grant applies to.', - 'required': False, - 'default': None + "permission_grants": { + "type": "array", + "items": { + "type": "object", + "properties": { + "resource_uid": { + "type": "string", + "description": "UID of a resource to which this grant applies to.", + "required": False, + "default": None, }, - 'permission_types': { - 'type': 'array', - 'description': 'A list of permission types to grant', - 'uniqueItems': True, - 'items': { - 'type': 'string', + "permission_types": { + "type": "array", + "description": "A list of permission types to grant", + "uniqueItems": True, + "items": { + "type": "string", # Note: We permission aditional validation for based on the # resource type in other place - 'enum': PermissionType.get_valid_values() + "enum": PermissionType.get_valid_values(), }, - 'default': [] - } - } - } - } + "default": [], + }, + }, + }, + }, }, - 'additionalProperties': False + "additionalProperties": False, } def validate(self): @@ -208,31 +172,43 @@ def validate(self): # Custom validation # Validate that only the correct permission types are used - permission_grants = getattr(self, 'permission_grants', []) + permission_grants = getattr(self, "permission_grants", []) for permission_grant in permission_grants: - resource_uid = permission_grant.get('resource_uid', None) - permission_types = permission_grant.get('permission_types', []) + resource_uid = permission_grant.get("resource_uid", None) + permission_types = permission_grant.get("permission_types", []) if resource_uid: # Permission types which apply to a resource resource_type, _ = parse_uid(uid=resource_uid) - valid_permission_types = PermissionType.get_valid_permissions_for_resource_type( - resource_type=resource_type) + valid_permission_types = ( + PermissionType.get_valid_permissions_for_resource_type( + resource_type=resource_type + ) + ) for permission_type in permission_types: if permission_type not in valid_permission_types: - message = ('Invalid permission type "%s" for resource type "%s"' % - (permission_type, resource_type)) + message = ( + 'Invalid permission type "%s" for resource type "%s"' + % ( + permission_type, + resource_type, + ) + ) raise ValueError(message) else: # Right now we only support single permission type (list) which is global and # doesn't apply to a resource for permission_type in permission_types: if permission_type not in GLOBAL_PERMISSION_TYPES: - valid_global_permission_types = ', '.join(GLOBAL_PERMISSION_TYPES) - message = ('Invalid permission type "%s". Valid global permission types ' - 'which can be used without a resource id are: %s' % - (permission_type, valid_global_permission_types)) + valid_global_permission_types = ", ".join( + GLOBAL_PERMISSION_TYPES + ) + message = ( + 'Invalid permission type "%s". Valid global permission types ' + "which can be used without a resource id are: %s" + % (permission_type, valid_global_permission_types) + ) raise ValueError(message) return cleaned @@ -252,52 +228,53 @@ def validate(self, validate_role_exists=False): if validate_role_exists: # Validate that the referenced roles exist in the db rbac_service = get_rbac_backend().get_service_class() - rbac_service.validate_roles_exists(role_names=self.roles) # pylint: disable=no-member + rbac_service.validate_roles_exists( + role_names=self.roles + ) # pylint: disable=no-member return cleaned class UserRoleAssignmentFileFormatAPI(BaseAPI): schema = { - 'type': 'object', - 'properties': { - 'username': { - 'type': 'string', - 'description': 'Username', - 'required': True, - 'default': None + "type": "object", + "properties": { + "username": { + "type": "string", + "description": "Username", + "required": True, + "default": None, }, - 'description': { - 'type': 'string', - 'description': 'Assignment description', - 'required': False, - 'default': None + "description": { + "type": "string", + "description": "Assignment description", + "required": False, + "default": None, }, - 'enabled': { - 'type': 'boolean', - 'description': ('Flag indicating if this assignment is enabled. Note: Disabled ' - 'assignments are simply ignored when loading definitions from ' - ' disk.'), - 'default': True + "enabled": { + "type": "boolean", + "description": ( + "Flag indicating if this assignment is enabled. Note: Disabled " + "assignments are simply ignored when loading definitions from " + " disk." + ), + "default": True, }, - 'roles': { - 'type': 'array', - 'description': 'Roles assigned to this user', - 'uniqueItems': True, - 'items': { - 'type': 'string' - }, - 'required': True + "roles": { + "type": "array", + "description": "Roles assigned to this user", + "uniqueItems": True, + "items": {"type": "string"}, + "required": True, + }, + "file_path": { + "type": "string", + "description": "Path of the file of where this assignment comes from.", + "default": None, + "required": False, }, - 'file_path': { - 'type': 'string', - 'description': 'Path of the file of where this assignment comes from.', - 'default': None, - 'required': False - } - }, - 'additionalProperties': False + "additionalProperties": False, } def validate(self, validate_role_exists=False): @@ -307,44 +284,46 @@ def validate(self, validate_role_exists=False): class AuthGroupToRoleMapAssignmentFileFormatAPI(BaseAPI): schema = { - 'type': 'object', - 'properties': { - 'group': { - 'type': 'string', - 'description': 'Name of the group as returned by auth backend.', - 'required': True + "type": "object", + "properties": { + "group": { + "type": "string", + "description": "Name of the group as returned by auth backend.", + "required": True, }, - 'description': { - 'type': 'string', - 'description': 'Mapping description', - 'required': False, - 'default': None + "description": { + "type": "string", + "description": "Mapping description", + "required": False, + "default": None, }, - 'enabled': { - 'type': 'boolean', - 'description': ('Flag indicating if this mapping is enabled. Note: Disabled ' - 'assignments are simply ignored when loading definitions from ' - ' disk.'), - 'default': True + "enabled": { + "type": "boolean", + "description": ( + "Flag indicating if this mapping is enabled. Note: Disabled " + "assignments are simply ignored when loading definitions from " + " disk." + ), + "default": True, }, - 'roles': { - 'type': 'array', - 'description': ('StackStorm roles which are assigned to each user which belongs ' - 'to that group.'), - 'uniqueItems': True, - 'items': { - 'type': 'string' - }, - 'required': True + "roles": { + "type": "array", + "description": ( + "StackStorm roles which are assigned to each user which belongs " + "to that group." + ), + "uniqueItems": True, + "items": {"type": "string"}, + "required": True, + }, + "file_path": { + "type": "string", + "description": "Path of the file of where this assignment comes from.", + "default": None, + "required": False, }, - 'file_path': { - 'type': 'string', - 'description': 'Path of the file of where this assignment comes from.', - 'default': None, - 'required': False - } }, - 'additionalProperties': False + "additionalProperties": False, } def validate(self, validate_role_exists=False): diff --git a/st2common/st2common/models/api/rule.py b/st2common/st2common/models/api/rule.py index 716eeec2070..8919a2ffc9a 100644 --- a/st2common/st2common/models/api/rule.py +++ b/st2common/st2common/models/api/rule.py @@ -20,7 +20,12 @@ from st2common.models.api.base import BaseAPI from st2common.models.api.base import APIUIDMixin from st2common.models.api.tag import TagsHelper -from st2common.models.db.rule import RuleDB, RuleTypeDB, RuleTypeSpecDB, ActionExecutionSpecDB +from st2common.models.db.rule import ( + RuleDB, + RuleTypeDB, + RuleTypeSpecDB, + ActionExecutionSpecDB, +) from st2common.models.system.common import ResourceReference from st2common.persistence.trigger import Trigger import st2common.services.triggers as TriggerService @@ -30,61 +35,52 @@ class RuleTypeSpec(BaseAPI): schema = { - 'type': 'object', - 'properties': { - 'ref': { - 'type': 'string', - 'required': True - }, - 'parameters': { - 'type': 'object' - } + "type": "object", + "properties": { + "ref": {"type": "string", "required": True}, + "parameters": {"type": "object"}, }, - 'additionalProperties': False + "additionalProperties": False, } class RuleTypeAPI(BaseAPI): model = RuleTypeDB schema = { - 'title': 'RuleType', - 'description': 'A specific type of rule.', - 'type': 'object', - 'properties': { - 'id': { - 'description': 'The unique identifier for the rule type.', - 'type': 'string', - 'default': None - }, - 'name': { - 'description': 'The name for the rule type.', - 'type': 'string', - 'required': True + "title": "RuleType", + "description": "A specific type of rule.", + "type": "object", + "properties": { + "id": { + "description": "The unique identifier for the rule type.", + "type": "string", + "default": None, }, - 'description': { - 'description': 'The description of the rule type.', - 'type': 'string' + "name": { + "description": "The name for the rule type.", + "type": "string", + "required": True, }, - 'enabled': { - 'type': 'boolean', - 'default': True + "description": { + "description": "The description of the rule type.", + "type": "string", }, - 'parameters': { - 'type': 'object' - } + "enabled": {"type": "boolean", "default": True}, + "parameters": {"type": "object"}, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def to_model(cls, rule_type): - name = getattr(rule_type, 'name', None) - description = getattr(rule_type, 'description', None) - enabled = getattr(rule_type, 'enabled', False) - parameters = getattr(rule_type, 'parameters', {}) + name = getattr(rule_type, "name", None) + description = getattr(rule_type, "description", None) + enabled = getattr(rule_type, "enabled", False) + parameters = getattr(rule_type, "parameters", {}) - return cls.model(name=name, description=description, enabled=enabled, - parameters=parameters) + return cls.model( + name=name, description=description, enabled=enabled, parameters=parameters + ) class RuleAPI(BaseAPI, APIUIDMixin): @@ -113,100 +109,60 @@ class RuleAPI(BaseAPI, APIUIDMixin): status: enabled or disabled. If disabled occurrence of the trigger does not lead to execution of a action and vice-versa. """ + model = RuleDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, "ref": { "description": ( "System computed user friendly reference for the rule. " "Provided value will be overridden by computed value." ), - "type": "string" - }, - 'uid': { - 'type': 'string' - }, - 'name': { - 'type': 'string', - 'required': True - }, - 'pack': { - 'type': 'string', - 'default': DEFAULT_PACK_NAME - }, - 'description': { - 'type': 'string' + "type": "string", }, - 'type': RuleTypeSpec.schema, - 'trigger': { - 'type': 'object', - 'required': True, - 'properties': { - 'type': { - 'type': 'string', - 'required': True - }, - 'description': { - 'type': 'string', - 'require': False - }, - 'parameters': { - 'type': 'object', - 'default': {} - }, - 'ref': { - 'type': 'string', - 'required': False - } + "uid": {"type": "string"}, + "name": {"type": "string", "required": True}, + "pack": {"type": "string", "default": DEFAULT_PACK_NAME}, + "description": {"type": "string"}, + "type": RuleTypeSpec.schema, + "trigger": { + "type": "object", + "required": True, + "properties": { + "type": {"type": "string", "required": True}, + "description": {"type": "string", "require": False}, + "parameters": {"type": "object", "default": {}}, + "ref": {"type": "string", "required": False}, }, - 'additionalProperties': True - }, - 'criteria': { - 'type': 'object', - 'default': {} - }, - 'action': { - 'type': 'object', - 'required': True, - 'properties': { - 'ref': { - 'type': 'string', - 'required': True - }, - 'description': { - 'type': 'string', - 'require': False - }, - 'parameters': { - 'type': 'object' - } + "additionalProperties": True, + }, + "criteria": {"type": "object", "default": {}}, + "action": { + "type": "object", + "required": True, + "properties": { + "ref": {"type": "string", "required": True}, + "description": {"type": "string", "require": False}, + "parameters": {"type": "object"}, }, - 'additionalProperties': False - }, - 'enabled': { - 'type': 'boolean', - 'default': False - }, - 'context': { - 'type': 'object' + "additionalProperties": False, }, + "enabled": {"type": "boolean", "default": False}, + "context": {"type": "object"}, "tags": { "description": "User associated metadata assigned to this object.", "type": "array", - "items": {"type": "object"} + "items": {"type": "object"}, }, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod @@ -215,58 +171,62 @@ def from_model(cls, model, mask_secrets=False, ignore_missing_trigger=False): trigger_db = reference.get_model_by_resource_ref(Trigger, model.trigger) if not ignore_missing_trigger and not trigger_db: - raise ValueError('Missing TriggerDB object for rule %s' % (rule['id'])) + raise ValueError("Missing TriggerDB object for rule %s" % (rule["id"])) if trigger_db: - rule['trigger'] = { - 'type': trigger_db.type, - 'parameters': trigger_db.parameters, - 'ref': model.trigger + rule["trigger"] = { + "type": trigger_db.type, + "parameters": trigger_db.parameters, + "ref": model.trigger, } - rule['tags'] = TagsHelper.from_model(model.tags) + rule["tags"] = TagsHelper.from_model(model.tags) return cls(**rule) @classmethod def to_model(cls, rule): kwargs = {} - kwargs['name'] = getattr(rule, 'name', None) - kwargs['description'] = getattr(rule, 'description', None) + kwargs["name"] = getattr(rule, "name", None) + kwargs["description"] = getattr(rule, "description", None) # Validate trigger parameters # Note: This must happen before we create a trigger, otherwise create trigger could fail # with a cryptic error - trigger = getattr(rule, 'trigger', {}) - trigger_type_ref = trigger.get('type', None) - parameters = trigger.get('parameters', {}) + trigger = getattr(rule, "trigger", {}) + trigger_type_ref = trigger.get("type", None) + parameters = trigger.get("parameters", {}) - validator.validate_trigger_parameters(trigger_type_ref=trigger_type_ref, - parameters=parameters) + validator.validate_trigger_parameters( + trigger_type_ref=trigger_type_ref, parameters=parameters + ) # Create a trigger for the provided rule trigger_db = TriggerService.create_trigger_db_from_rule(rule) - kwargs['trigger'] = reference.get_str_resource_ref_from_model(trigger_db) + kwargs["trigger"] = reference.get_str_resource_ref_from_model(trigger_db) - kwargs['pack'] = getattr(rule, 'pack', DEFAULT_PACK_NAME) - kwargs['ref'] = ResourceReference.to_string_reference(pack=kwargs['pack'], - name=kwargs['name']) + kwargs["pack"] = getattr(rule, "pack", DEFAULT_PACK_NAME) + kwargs["ref"] = ResourceReference.to_string_reference( + pack=kwargs["pack"], name=kwargs["name"] + ) # Validate criteria - kwargs['criteria'] = dict(getattr(rule, 'criteria', {})) - validator.validate_criteria(kwargs['criteria']) + kwargs["criteria"] = dict(getattr(rule, "criteria", {})) + validator.validate_criteria(kwargs["criteria"]) - kwargs['action'] = ActionExecutionSpecDB(ref=rule.action['ref'], - parameters=rule.action.get('parameters', {})) + kwargs["action"] = ActionExecutionSpecDB( + ref=rule.action["ref"], parameters=rule.action.get("parameters", {}) + ) - rule_type = dict(getattr(rule, 'type', {})) + rule_type = dict(getattr(rule, "type", {})) if rule_type: - kwargs['type'] = RuleTypeSpecDB(ref=rule_type['ref'], - parameters=rule_type.get('parameters', {})) + kwargs["type"] = RuleTypeSpecDB( + ref=rule_type["ref"], parameters=rule_type.get("parameters", {}) + ) - kwargs['enabled'] = getattr(rule, 'enabled', False) - kwargs['context'] = getattr(rule, 'context', dict()) - kwargs['tags'] = TagsHelper.to_model(getattr(rule, 'tags', [])) - kwargs['metadata_file'] = getattr(rule, 'metadata_file', None) + kwargs["enabled"] = getattr(rule, "enabled", False) + kwargs["context"] = getattr(rule, "context", dict()) + kwargs["tags"] = TagsHelper.to_model(getattr(rule, "tags", [])) + kwargs["metadata_file"] = getattr(rule, "metadata_file", None) model = cls.model(**kwargs) return model @@ -277,13 +237,5 @@ class RuleViewAPI(RuleAPI): # Always deep-copy to avoid breaking the original. schema = copy.deepcopy(RuleAPI.schema) # Update the schema to include the description properties - schema['properties']['action'].update({ - 'description': { - 'type': 'string' - } - }) - schema['properties']['trigger'].update({ - 'description': { - 'type': 'string' - } - }) + schema["properties"]["action"].update({"description": {"type": "string"}}) + schema["properties"]["trigger"].update({"description": {"type": "string"}}) diff --git a/st2common/st2common/models/api/rule_enforcement.py b/st2common/st2common/models/api/rule_enforcement.py index c950b59bfe0..d7aa1bc8731 100644 --- a/st2common/st2common/models/api/rule_enforcement.py +++ b/st2common/st2common/models/api/rule_enforcement.py @@ -28,95 +28,98 @@ from st2common.constants.rule_enforcement import RULE_ENFORCEMENT_STATUSES from st2common.util import isotime -__all__ = [ - 'RuleEnforcementAPI', - 'RuleEnforcementViewAPI', - - 'RuleReferenceSpecDB' -] +__all__ = ["RuleEnforcementAPI", "RuleEnforcementViewAPI", "RuleReferenceSpecDB"] class RuleReferenceSpec(BaseAPI): schema = { - 'type': 'object', - 'properties': { - 'ref': { - 'type': 'string', - 'required': True, + "type": "object", + "properties": { + "ref": { + "type": "string", + "required": True, }, - 'uid': { - 'type': 'string', - 'required': True, + "uid": { + "type": "string", + "required": True, }, - 'id': { - 'type': 'string', - 'required': False, + "id": { + "type": "string", + "required": False, }, }, - 'additionalProperties': False + "additionalProperties": False, } class RuleEnforcementAPI(BaseAPI): model = RuleEnforcementDB schema = { - 'title': 'RuleEnforcement', - 'description': 'A specific instance of rule enforcement.', - 'type': 'object', - 'properties': { - 'trigger_instance_id': { - 'description': 'The unique identifier for the trigger instance ' + - 'that flipped the rule.', - 'type': 'string', - 'required': True + "title": "RuleEnforcement", + "description": "A specific instance of rule enforcement.", + "type": "object", + "properties": { + "trigger_instance_id": { + "description": "The unique identifier for the trigger instance " + + "that flipped the rule.", + "type": "string", + "required": True, }, - 'execution_id': { - 'description': 'ID of the action execution that was invoked as a response.', - 'type': 'string' + "execution_id": { + "description": "ID of the action execution that was invoked as a response.", + "type": "string", }, - 'failure_reason': { - 'description': 'Reason for failure to execute the action specified in the rule.', - 'type': 'string' + "failure_reason": { + "description": "Reason for failure to execute the action specified in the rule.", + "type": "string", }, - 'rule': RuleReferenceSpec.schema, - 'enforced_at': { - 'description': 'Timestamp when rule enforcement happened.', - 'type': 'string', - 'required': True + "rule": RuleReferenceSpec.schema, + "enforced_at": { + "description": "Timestamp when rule enforcement happened.", + "type": "string", + "required": True, }, "status": { "description": "Rule enforcement status.", "type": "string", - "enum": RULE_ENFORCEMENT_STATUSES + "enum": RULE_ENFORCEMENT_STATUSES, }, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def to_model(cls, rule_enforcement): - trigger_instance_id = getattr(rule_enforcement, 'trigger_instance_id', None) - execution_id = getattr(rule_enforcement, 'execution_id', None) - enforced_at = getattr(rule_enforcement, 'enforced_at', None) - failure_reason = getattr(rule_enforcement, 'failure_reason', None) - status = getattr(rule_enforcement, 'status', RULE_ENFORCEMENT_STATUS_SUCCEEDED) - - rule_ref_model = dict(getattr(rule_enforcement, 'rule', {})) - rule = RuleReferenceSpecDB(ref=rule_ref_model['ref'], id=rule_ref_model['id'], - uid=rule_ref_model['uid']) + trigger_instance_id = getattr(rule_enforcement, "trigger_instance_id", None) + execution_id = getattr(rule_enforcement, "execution_id", None) + enforced_at = getattr(rule_enforcement, "enforced_at", None) + failure_reason = getattr(rule_enforcement, "failure_reason", None) + status = getattr(rule_enforcement, "status", RULE_ENFORCEMENT_STATUS_SUCCEEDED) + + rule_ref_model = dict(getattr(rule_enforcement, "rule", {})) + rule = RuleReferenceSpecDB( + ref=rule_ref_model["ref"], + id=rule_ref_model["id"], + uid=rule_ref_model["uid"], + ) if enforced_at: enforced_at = isotime.parse(enforced_at) - return cls.model(trigger_instance_id=trigger_instance_id, execution_id=execution_id, - failure_reason=failure_reason, enforced_at=enforced_at, rule=rule, - status=status) + return cls.model( + trigger_instance_id=trigger_instance_id, + execution_id=execution_id, + failure_reason=failure_reason, + enforced_at=enforced_at, + rule=rule, + status=status, + ) @classmethod def from_model(cls, model, mask_secrets=False): doc = cls._from_model(model, mask_secrets=mask_secrets) enforced_at = isotime.format(model.enforced_at, offset=False) - doc['enforced_at'] = enforced_at + doc["enforced_at"] = enforced_at attrs = {attr: value for attr, value in six.iteritems(doc) if value} return cls(**attrs) @@ -126,7 +129,7 @@ class RuleEnforcementViewAPI(RuleEnforcementAPI): schema = copy.deepcopy(RuleEnforcementAPI.schema) # Update the schema to include additional execution properties - schema['properties']['execution'] = copy.deepcopy(ActionExecutionAPI.schema) + schema["properties"]["execution"] = copy.deepcopy(ActionExecutionAPI.schema) # Update the schema to include additional trigger instance properties - schema['properties']['trigger_instance'] = copy.deepcopy(TriggerInstanceAPI.schema) + schema["properties"]["trigger_instance"] = copy.deepcopy(TriggerInstanceAPI.schema) diff --git a/st2common/st2common/models/api/sensor.py b/st2common/st2common/models/api/sensor.py index af9c6876110..a2ba978adf3 100644 --- a/st2common/st2common/models/api/sensor.py +++ b/st2common/st2common/models/api/sensor.py @@ -22,53 +22,34 @@ class SensorTypeAPI(BaseAPI): model = SensorTypeDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'ref': { - 'type': 'string' - }, - 'uid': { - 'type': 'string' - }, - 'name': { - 'type': 'string', - 'required': True - }, - 'pack': { - 'type': 'string' - }, - 'description': { - 'type': 'string' - }, - 'artifact_uri': { - 'type': 'string', - }, - 'entry_point': { - 'type': 'string', - }, - 'enabled': { - 'description': 'Enable or disable the sensor.', - 'type': 'boolean', - 'default': True + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "ref": {"type": "string"}, + "uid": {"type": "string"}, + "name": {"type": "string", "required": True}, + "pack": {"type": "string"}, + "description": {"type": "string"}, + "artifact_uri": { + "type": "string", }, - 'trigger_types': { - 'type': 'array', - 'default': [] + "entry_point": { + "type": "string", }, - 'poll_interval': { - 'type': 'number' + "enabled": { + "description": "Enable or disable the sensor.", + "type": "boolean", + "default": True, }, + "trigger_types": {"type": "array", "default": []}, + "poll_interval": {"type": "number"}, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod diff --git a/st2common/st2common/models/api/tag.py b/st2common/st2common/models/api/tag.py index 78d92568e0e..0ed763a51f1 100644 --- a/st2common/st2common/models/api/tag.py +++ b/st2common/st2common/models/api/tag.py @@ -16,19 +16,19 @@ from __future__ import absolute_import from st2common.models.db.stormbase import TagField -__all__ = [ - 'TagsHelper' -] +__all__ = ["TagsHelper"] class TagsHelper(object): - @staticmethod def to_model(tags): tags = tags or [] - return [TagField(name=tag.get('name', ''), value=tag.get('value', '')) for tag in tags] + return [ + TagField(name=tag.get("name", ""), value=tag.get("value", "")) + for tag in tags + ] @staticmethod def from_model(tags): tags = tags or [] - return [{'name': tag.name, 'value': tag.value} for tag in tags] + return [{"name": tag.name, "value": tag.value} for tag in tags] diff --git a/st2common/st2common/models/api/trace.py b/st2common/st2common/models/api/trace.py index f09faf6d366..4ce8ec89420 100644 --- a/st2common/st2common/models/api/trace.py +++ b/st2common/st2common/models/api/trace.py @@ -21,141 +21,148 @@ TraceComponentAPISchema = { - 'type': 'object', - 'properties': { - 'object_id': { - 'type': 'string', - 'description': 'Id of the component', - 'required': True + "type": "object", + "properties": { + "object_id": { + "type": "string", + "description": "Id of the component", + "required": True, }, - 'ref': { - 'type': 'string', - 'description': 'ref of the component', - 'required': False + "ref": { + "type": "string", + "description": "ref of the component", + "required": False, }, - 'updated_at': { - 'description': 'The start time when the action is executed.', - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX + "updated_at": { + "description": "The start time when the action is executed.", + "type": "string", + "pattern": isotime.ISO8601_UTC_REGEX, }, - 'caused_by': { - 'type': 'object', - 'description': 'Component that is the cause or the predecesor.', - 'properties': { - 'id': { - 'description': 'Id of the causal component.', - 'type': 'string' + "caused_by": { + "type": "object", + "description": "Component that is the cause or the predecesor.", + "properties": { + "id": {"description": "Id of the causal component.", "type": "string"}, + "type": { + "description": "Type of the causal component.", + "type": "string", }, - 'type': { - 'description': 'Type of the causal component.', - 'type': 'string' - } - } - } + }, + }, }, - 'additionalProperties': False + "additionalProperties": False, } class TraceAPI(BaseAPI, APIUIDMixin): model = TraceDB schema = { - 'title': 'Trace', - 'desciption': 'Trace is a collection of all TriggerInstances, Rules and ActionExecutions \ + "title": "Trace", + "desciption": "Trace is a collection of all TriggerInstances, Rules and ActionExecutions \ that represent an activity which begins with the introduction of a \ TriggerInstance or request of an ActionExecution and ends with the \ - completion of an ActionExecution.', - 'type': 'object', - 'properties': { - 'id': { - 'description': 'The unique identifier for a Trace.', - 'type': 'string', - 'default': None + completion of an ActionExecution.", + "type": "object", + "properties": { + "id": { + "description": "The unique identifier for a Trace.", + "type": "string", + "default": None, }, - 'trace_tag': { - 'description': 'User assigned identifier for each Trace.', - 'type': 'string', - 'required': True + "trace_tag": { + "description": "User assigned identifier for each Trace.", + "type": "string", + "required": True, }, - 'action_executions': { - 'description': 'All ActionExecutions belonging to a Trace.', - 'type': 'array', - 'items': TraceComponentAPISchema + "action_executions": { + "description": "All ActionExecutions belonging to a Trace.", + "type": "array", + "items": TraceComponentAPISchema, }, - 'rules': { - 'description': 'All rules that applied as part of a Trace.', - 'type': 'array', - 'items': TraceComponentAPISchema + "rules": { + "description": "All rules that applied as part of a Trace.", + "type": "array", + "items": TraceComponentAPISchema, }, - 'trigger_instances': { - 'description': 'All TriggerInstances fired during a Trace.', - 'type': 'array', - 'items': TraceComponentAPISchema + "trigger_instances": { + "description": "All TriggerInstances fired during a Trace.", + "type": "array", + "items": TraceComponentAPISchema, }, - 'start_timestamp': { - 'description': 'Timestamp when the Trace is started.', - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX + "start_timestamp": { + "description": "Timestamp when the Trace is started.", + "type": "string", + "pattern": isotime.ISO8601_UTC_REGEX, }, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def to_component_model(cls, component): values = { - 'object_id': component['object_id'], - 'ref': component['ref'], - 'caused_by': component.get('caused_by', {}) + "object_id": component["object_id"], + "ref": component["ref"], + "caused_by": component.get("caused_by", {}), } - updated_at = component.get('updated_at', None) + updated_at = component.get("updated_at", None) if updated_at: - values['updated_at'] = isotime.parse(updated_at) + values["updated_at"] = isotime.parse(updated_at) return TraceComponentDB(**values) @classmethod def to_model(cls, instance): - values = { - 'trace_tag': instance.trace_tag - } - action_executions = getattr(instance, 'action_executions', []) - action_executions = [TraceAPI.to_component_model(component=action_execution) - for action_execution in action_executions] - values['action_executions'] = action_executions - - rules = getattr(instance, 'rules', []) + values = {"trace_tag": instance.trace_tag} + action_executions = getattr(instance, "action_executions", []) + action_executions = [ + TraceAPI.to_component_model(component=action_execution) + for action_execution in action_executions + ] + values["action_executions"] = action_executions + + rules = getattr(instance, "rules", []) rules = [TraceAPI.to_component_model(component=rule) for rule in rules] - values['rules'] = rules + values["rules"] = rules - trigger_instances = getattr(instance, 'trigger_instances', []) - trigger_instances = [TraceAPI.to_component_model(component=trigger_instance) - for trigger_instance in trigger_instances] - values['trigger_instances'] = trigger_instances + trigger_instances = getattr(instance, "trigger_instances", []) + trigger_instances = [ + TraceAPI.to_component_model(component=trigger_instance) + for trigger_instance in trigger_instances + ] + values["trigger_instances"] = trigger_instances - start_timestamp = getattr(instance, 'start_timestamp', None) + start_timestamp = getattr(instance, "start_timestamp", None) if start_timestamp: - values['start_timestamp'] = isotime.parse(start_timestamp) + values["start_timestamp"] = isotime.parse(start_timestamp) return cls.model(**values) @classmethod def from_component_model(cls, component_model): - return {'object_id': component_model.object_id, - 'ref': component_model.ref, - 'updated_at': isotime.format(component_model.updated_at, offset=False), - 'caused_by': component_model.caused_by} + return { + "object_id": component_model.object_id, + "ref": component_model.ref, + "updated_at": isotime.format(component_model.updated_at, offset=False), + "caused_by": component_model.caused_by, + } @classmethod def from_model(cls, model, mask_secrets=False): instance = cls._from_model(model, mask_secrets=mask_secrets) - instance['start_timestamp'] = isotime.format(model.start_timestamp, offset=False) + instance["start_timestamp"] = isotime.format( + model.start_timestamp, offset=False + ) if model.action_executions: - instance['action_executions'] = [cls.from_component_model(action_execution) - for action_execution in model.action_executions] + instance["action_executions"] = [ + cls.from_component_model(action_execution) + for action_execution in model.action_executions + ] if model.rules: - instance['rules'] = [cls.from_component_model(rule) for rule in model.rules] + instance["rules"] = [cls.from_component_model(rule) for rule in model.rules] if model.trigger_instances: - instance['trigger_instances'] = [cls.from_component_model(trigger_instance) - for trigger_instance in model.trigger_instances] + instance["trigger_instances"] = [ + cls.from_component_model(trigger_instance) + for trigger_instance in model.trigger_instances + ] return cls(**instance) @@ -173,12 +180,13 @@ class TraceContext(object): Optional property. :type trace_tag: ``str`` """ + def __init__(self, id_=None, trace_tag=None): self.id_ = id_ self.trace_tag = trace_tag def __str__(self): - return '{id_: %s, trace_tag: %s}' % (self.id_, self.trace_tag) + return "{id_: %s, trace_tag: %s}" % (self.id_, self.trace_tag) def __json__(self): return vars(self) diff --git a/st2common/st2common/models/api/trigger.py b/st2common/st2common/models/api/trigger.py index af88027fe08..cdb2cd9ddd1 100644 --- a/st2common/st2common/models/api/trigger.py +++ b/st2common/st2common/models/api/trigger.py @@ -23,140 +23,113 @@ from st2common.models.db.trigger import TriggerTypeDB, TriggerDB, TriggerInstanceDB from st2common.models.system.common import ResourceReference -DATE_FORMAT = '%Y-%m-%d %H:%M:%S.%f' +DATE_FORMAT = "%Y-%m-%d %H:%M:%S.%f" class TriggerTypeAPI(BaseAPI): model = TriggerTypeDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'ref': { - 'type': 'string' - }, - 'uid': { - 'type': 'string' - }, - 'name': { - 'type': 'string', - 'required': True - }, - 'pack': { - 'type': 'string' - }, - 'description': { - 'type': 'string' - }, - 'payload_schema': { - 'type': 'object', - 'default': {} - }, - 'parameters_schema': { - 'type': 'object', - 'default': {} - }, - 'tags': { - 'description': 'User associated metadata assigned to this object.', - 'type': 'array', - 'items': {'type': 'object'} + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "ref": {"type": "string"}, + "uid": {"type": "string"}, + "name": {"type": "string", "required": True}, + "pack": {"type": "string"}, + "description": {"type": "string"}, + "payload_schema": {"type": "object", "default": {}}, + "parameters_schema": {"type": "object", "default": {}}, + "tags": { + "description": "User associated metadata assigned to this object.", + "type": "array", + "items": {"type": "object"}, }, "metadata_file": { "description": "Path to the metadata file relative to the pack directory.", "type": "string", - "default": "" - } + "default": "", + }, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def to_model(cls, trigger_type): - name = getattr(trigger_type, 'name', None) - description = getattr(trigger_type, 'description', None) - pack = getattr(trigger_type, 'pack', None) - payload_schema = getattr(trigger_type, 'payload_schema', {}) - parameters_schema = getattr(trigger_type, 'parameters_schema', {}) - tags = TagsHelper.to_model(getattr(trigger_type, 'tags', [])) - metadata_file = getattr(trigger_type, 'metadata_file', None) - - model = cls.model(name=name, description=description, pack=pack, - payload_schema=payload_schema, parameters_schema=parameters_schema, - tags=tags, metadata_file=metadata_file) + name = getattr(trigger_type, "name", None) + description = getattr(trigger_type, "description", None) + pack = getattr(trigger_type, "pack", None) + payload_schema = getattr(trigger_type, "payload_schema", {}) + parameters_schema = getattr(trigger_type, "parameters_schema", {}) + tags = TagsHelper.to_model(getattr(trigger_type, "tags", [])) + metadata_file = getattr(trigger_type, "metadata_file", None) + + model = cls.model( + name=name, + description=description, + pack=pack, + payload_schema=payload_schema, + parameters_schema=parameters_schema, + tags=tags, + metadata_file=metadata_file, + ) return model @classmethod def from_model(cls, model, mask_secrets=False): triggertype = cls._from_model(model, mask_secrets=mask_secrets) - triggertype['tags'] = TagsHelper.from_model(model.tags) + triggertype["tags"] = TagsHelper.from_model(model.tags) return cls(**triggertype) class TriggerAPI(BaseAPI): model = TriggerDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'ref': { - 'type': 'string' - }, - 'uid': { - 'type': 'string' - }, - 'name': { - 'type': 'string' - }, - 'pack': { - 'type': 'string' - }, - 'type': { - 'type': 'string', - 'required': True - }, - 'parameters': { - 'type': 'object' - }, - 'description': { - 'type': 'string' - } + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "ref": {"type": "string"}, + "uid": {"type": "string"}, + "name": {"type": "string"}, + "pack": {"type": "string"}, + "type": {"type": "string", "required": True}, + "parameters": {"type": "object"}, + "description": {"type": "string"}, }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): trigger = cls._from_model(model, mask_secrets=mask_secrets) # Hide ref count from API. - trigger.pop('ref_count', None) + trigger.pop("ref_count", None) return cls(**trigger) @classmethod def to_model(cls, trigger): - name = getattr(trigger, 'name', None) - description = getattr(trigger, 'description', None) - pack = getattr(trigger, 'pack', None) - _type = getattr(trigger, 'type', None) - parameters = getattr(trigger, 'parameters', {}) + name = getattr(trigger, "name", None) + description = getattr(trigger, "description", None) + pack = getattr(trigger, "pack", None) + _type = getattr(trigger, "type", None) + parameters = getattr(trigger, "parameters", {}) if _type and not parameters: trigger_type_ref = ResourceReference.from_string_reference(_type) name = trigger_type_ref.name - if hasattr(trigger, 'name') and trigger.name: + if hasattr(trigger, "name") and trigger.name: name = trigger.name else: # assign a name if none is provided. name = str(uuid.uuid4()) - model = cls.model(name=name, description=description, pack=pack, type=_type, - parameters=parameters) + model = cls.model( + name=name, + description=description, + pack=pack, + type=_type, + parameters=parameters, + ) return model def to_dict(self): @@ -167,38 +140,29 @@ def to_dict(self): class TriggerInstanceAPI(BaseAPI): model = TriggerInstanceDB schema = { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string' - }, - 'occurrence_time': { - 'type': 'string', - 'pattern': isotime.ISO8601_UTC_REGEX - }, - 'payload': { - 'type': 'object' - }, - 'trigger': { - 'type': 'string', - 'default': None, - 'required': True + "type": "object", + "properties": { + "id": {"type": "string"}, + "occurrence_time": {"type": "string", "pattern": isotime.ISO8601_UTC_REGEX}, + "payload": {"type": "object"}, + "trigger": {"type": "string", "default": None, "required": True}, + "status": { + "type": "string", + "default": None, + "enum": TRIGGER_INSTANCE_STATUSES, }, - 'status': { - 'type': 'string', - 'default': None, - 'enum': TRIGGER_INSTANCE_STATUSES - } }, - 'additionalProperties': False + "additionalProperties": False, } @classmethod def from_model(cls, model, mask_secrets=False): instance = cls._from_model(model, mask_secrets=mask_secrets) - if instance.get('occurrence_time', None): - instance['occurrence_time'] = isotime.format(instance['occurrence_time'], offset=False) + if instance.get("occurrence_time", None): + instance["occurrence_time"] = isotime.format( + instance["occurrence_time"], offset=False + ) return cls(**instance) @@ -209,6 +173,10 @@ def to_model(cls, instance): occurrence_time = isotime.parse(instance.occurrence_time) status = instance.status - model = cls.model(trigger=trigger, payload=payload, occurrence_time=occurrence_time, - status=status) + model = cls.model( + trigger=trigger, + payload=payload, + occurrence_time=occurrence_time, + status=status, + ) return model diff --git a/st2common/st2common/models/api/webhook.py b/st2common/st2common/models/api/webhook.py index 9d1a37ed1d4..eb7b04a29b9 100644 --- a/st2common/st2common/models/api/webhook.py +++ b/st2common/st2common/models/api/webhook.py @@ -15,20 +15,15 @@ from st2common.models.api.base import BaseAPI -__all___ = [ - 'WebhookBodyAPI' -] +__all___ = ["WebhookBodyAPI"] class WebhookBodyAPI(BaseAPI): schema = { - 'type': 'object', - 'properties': { + "type": "object", + "properties": { # Holds actual webhook body - 'data': { - 'type': ['object', 'array'], - 'required': True - } + "data": {"type": ["object", "array"], "required": True} }, - 'additionalProperties': False + "additionalProperties": False, } diff --git a/st2common/st2common/models/base.py b/st2common/st2common/models/base.py index 342daf70280..35d5c884a7a 100644 --- a/st2common/st2common/models/base.py +++ b/st2common/st2common/models/base.py @@ -17,9 +17,7 @@ Common model related classes. """ -__all__ = [ - 'DictSerializableClassMixin' -] +__all__ = ["DictSerializableClassMixin"] class DictSerializableClassMixin(object): diff --git a/st2common/st2common/models/db/__init__.py b/st2common/st2common/models/db/__init__.py index 4fd51b4f61b..ee7261facd0 100644 --- a/st2common/st2common/models/db/__init__.py +++ b/st2common/st2common/models/db/__init__.py @@ -40,32 +40,30 @@ LOG = logging.getLogger(__name__) MODEL_MODULE_NAMES = [ - 'st2common.models.db.auth', - 'st2common.models.db.action', - 'st2common.models.db.actionalias', - 'st2common.models.db.keyvalue', - 'st2common.models.db.execution', - 'st2common.models.db.executionstate', - 'st2common.models.db.execution_queue', - 'st2common.models.db.liveaction', - 'st2common.models.db.notification', - 'st2common.models.db.pack', - 'st2common.models.db.policy', - 'st2common.models.db.rbac', - 'st2common.models.db.rule', - 'st2common.models.db.rule_enforcement', - 'st2common.models.db.runner', - 'st2common.models.db.sensor', - 'st2common.models.db.trace', - 'st2common.models.db.trigger', - 'st2common.models.db.webhook', - 'st2common.models.db.workflow' + "st2common.models.db.auth", + "st2common.models.db.action", + "st2common.models.db.actionalias", + "st2common.models.db.keyvalue", + "st2common.models.db.execution", + "st2common.models.db.executionstate", + "st2common.models.db.execution_queue", + "st2common.models.db.liveaction", + "st2common.models.db.notification", + "st2common.models.db.pack", + "st2common.models.db.policy", + "st2common.models.db.rbac", + "st2common.models.db.rule", + "st2common.models.db.rule_enforcement", + "st2common.models.db.runner", + "st2common.models.db.sensor", + "st2common.models.db.trace", + "st2common.models.db.trigger", + "st2common.models.db.webhook", + "st2common.models.db.workflow", ] # A list of model names for which we don't perform extra index cleanup -INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = [ - 'PermissionGrantDB' -] +INDEX_CLEANUP_MODEL_NAMES_BLACKLIST = ["PermissionGrantDB"] # Reference to DB model classes used for db_ensure_indexes # NOTE: This variable is populated lazily inside get_model_classes() @@ -86,55 +84,78 @@ def get_model_classes(): result = [] for module_name in MODEL_MODULE_NAMES: module = importlib.import_module(module_name) - model_classes = getattr(module, 'MODELS', []) + model_classes = getattr(module, "MODELS", []) result.extend(model_classes) MODEL_CLASSES = result return MODEL_CLASSES -def _db_connect(db_name, db_host, db_port, username=None, password=None, - ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None, - ssl_ca_certs=None, authentication_mechanism=None, ssl_match_hostname=True): - - if '://' in db_host: +def _db_connect( + db_name, + db_host, + db_port, + username=None, + password=None, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): + + if "://" in db_host: # Hostname is provided as a URI string. Make sure we don't log the password in case one is # included as part of the URI string. uri_dict = uri_parser.parse_uri(db_host) - username_string = uri_dict.get('username', username) or username + username_string = uri_dict.get("username", username) or username - if uri_dict.get('username', None) and username: + if uri_dict.get("username", None) and username: # Username argument has precedence over connection string username username_string = username hostnames = get_host_names_for_uri_dict(uri_dict=uri_dict) - if len(uri_dict['nodelist']) > 1: - host_string = '%s (replica set)' % (hostnames) + if len(uri_dict["nodelist"]) > 1: + host_string = "%s (replica set)" % (hostnames) else: host_string = hostnames else: - host_string = '%s:%s' % (db_host, db_port) + host_string = "%s:%s" % (db_host, db_port) username_string = username - LOG.info('Connecting to database "%s" @ "%s" as user "%s".' % (db_name, host_string, - str(username_string))) - - ssl_kwargs = _get_ssl_kwargs(ssl=ssl, ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile, - ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs, - authentication_mechanism=authentication_mechanism, - ssl_match_hostname=ssl_match_hostname) + LOG.info( + 'Connecting to database "%s" @ "%s" as user "%s".' + % (db_name, host_string, str(username_string)) + ) + + ssl_kwargs = _get_ssl_kwargs( + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + authentication_mechanism=authentication_mechanism, + ssl_match_hostname=ssl_match_hostname, + ) # NOTE: We intentionally set "serverSelectionTimeoutMS" to 3 seconds. By default it's set to # 30 seconds, which means it will block up to 30 seconds and fail if there are any SSL related # or other errors connection_timeout = cfg.CONF.database.connection_timeout - connection = mongoengine.connection.connect(db_name, host=db_host, - port=db_port, tz_aware=True, - username=username, password=password, - connectTimeoutMS=connection_timeout, - serverSelectionTimeoutMS=connection_timeout, - **ssl_kwargs) + connection = mongoengine.connection.connect( + db_name, + host=db_host, + port=db_port, + tz_aware=True, + username=username, + password=password, + connectTimeoutMS=connection_timeout, + serverSelectionTimeoutMS=connection_timeout, + **ssl_kwargs, + ) # NOTE: Since pymongo 3.0, connect() method is lazy and not blocking (always returns success) # so we need to issue a command / query to check if connection has been @@ -142,32 +163,55 @@ def _db_connect(db_name, db_host, db_port, username=None, password=None, # See http://api.mongodb.com/python/current/api/pymongo/mongo_client.html for details try: # The ismaster command is cheap and does not require auth - connection.admin.command('ismaster') + connection.admin.command("ismaster") except (ConnectionFailure, ServerSelectionTimeoutError) as e: # NOTE: ServerSelectionTimeoutError can also be thrown if SSLHandShake fails in the server # Sadly the client doesn't include more information about the error so in such scenarios # user needs to check MongoDB server log - LOG.error('Failed to connect to database "%s" @ "%s" as user "%s": %s' % - (db_name, host_string, str(username_string), six.text_type(e))) + LOG.error( + 'Failed to connect to database "%s" @ "%s" as user "%s": %s' + % (db_name, host_string, str(username_string), six.text_type(e)) + ) raise e - LOG.info('Successfully connected to database "%s" @ "%s" as user "%s".' % ( - db_name, host_string, str(username_string))) + LOG.info( + 'Successfully connected to database "%s" @ "%s" as user "%s".' + % (db_name, host_string, str(username_string)) + ) return connection -def db_setup(db_name, db_host, db_port, username=None, password=None, ensure_indexes=True, - ssl=False, ssl_keyfile=None, ssl_certfile=None, - ssl_cert_reqs=None, ssl_ca_certs=None, - authentication_mechanism=None, ssl_match_hostname=True): - - connection = _db_connect(db_name, db_host, db_port, username=username, - password=password, ssl=ssl, ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs, - authentication_mechanism=authentication_mechanism, - ssl_match_hostname=ssl_match_hostname) +def db_setup( + db_name, + db_host, + db_port, + username=None, + password=None, + ensure_indexes=True, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): + + connection = _db_connect( + db_name, + db_host, + db_port, + username=username, + password=password, + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + authentication_mechanism=authentication_mechanism, + ssl_match_hostname=ssl_match_hostname, + ) # Create all the indexes upfront to prevent race-conditions caused by # lazy index creation @@ -192,7 +236,7 @@ def db_ensure_indexes(model_classes=None): ensured for all the models. :type model_classes: ``list`` """ - LOG.debug('Ensuring database indexes...') + LOG.debug("Ensuring database indexes...") if not model_classes: model_classes = get_model_classes() @@ -210,34 +254,44 @@ def db_ensure_indexes(model_classes=None): # Note: This condition would only be encountered when upgrading existing StackStorm # installation from MongoDB 3.2 to 3.4. msg = six.text_type(e) - if 'already exists with different options' in msg and 'uid_1' in msg: + if "already exists with different options" in msg and "uid_1" in msg: drop_obsolete_types_indexes(model_class=model_class) else: raise e except Exception as e: tb_msg = traceback.format_exc() - msg = 'Failed to ensure indexes for model "%s": %s' % (class_name, six.text_type(e)) - msg += '\n\n' + tb_msg + msg = 'Failed to ensure indexes for model "%s": %s' % ( + class_name, + six.text_type(e), + ) + msg += "\n\n" + tb_msg exc_cls = type(e) raise exc_cls(msg) if model_class.__name__ in INDEX_CLEANUP_MODEL_NAMES_BLACKLIST: - LOG.debug('Skipping index cleanup for blacklisted model "%s"...' % (class_name)) + LOG.debug( + 'Skipping index cleanup for blacklisted model "%s"...' % (class_name) + ) continue removed_count = cleanup_extra_indexes(model_class=model_class) if removed_count: - LOG.debug('Removed "%s" extra indexes for model "%s"' % (removed_count, class_name)) + LOG.debug( + 'Removed "%s" extra indexes for model "%s"' + % (removed_count, class_name) + ) - LOG.debug('Indexes are ensured for models: %s' % - ', '.join(sorted((model_class.__name__ for model_class in model_classes)))) + LOG.debug( + "Indexes are ensured for models: %s" + % ", ".join(sorted((model_class.__name__ for model_class in model_classes))) + ) def cleanup_extra_indexes(model_class): """ Finds any extra indexes and removes those from mongodb. """ - extra_indexes = model_class.compare_indexes().get('extra', None) + extra_indexes = model_class.compare_indexes().get("extra", None) if not extra_indexes: return 0 @@ -248,10 +302,14 @@ def cleanup_extra_indexes(model_class): for extra_index in extra_indexes: try: c.drop_index(extra_index) - LOG.debug('Dropped index %s for model %s.', extra_index, model_class.__name__) + LOG.debug( + "Dropped index %s for model %s.", extra_index, model_class.__name__ + ) removed_count += 1 except OperationFailure: - LOG.warning('Attempt to cleanup index %s failed.', extra_index, exc_info=True) + LOG.warning( + "Attempt to cleanup index %s failed.", extra_index, exc_info=True + ) return removed_count @@ -266,14 +324,19 @@ def drop_obsolete_types_indexes(model_class): LOG.debug('Dropping obsolete types index for model "%s"' % (class_name)) collection = model_class._get_collection() - collection.update({}, {'$unset': {'_types': 1}}, multi=True) + collection.update({}, {"$unset": {"_types": 1}}, multi=True) info = collection.index_information() - indexes_to_drop = [key for key, value in six.iteritems(info) - if '_types' in dict(value['key']) or 'types' in value] + indexes_to_drop = [ + key + for key, value in six.iteritems(info) + if "_types" in dict(value["key"]) or "types" in value + ] - LOG.debug('Will drop obsolete types indexes for model "%s": %s' % (class_name, - str(indexes_to_drop))) + LOG.debug( + 'Will drop obsolete types indexes for model "%s": %s' + % (class_name, str(indexes_to_drop)) + ) for index in indexes_to_drop: collection.drop_index(index) @@ -286,57 +349,87 @@ def db_teardown(): mongoengine.connection.disconnect() -def db_cleanup(db_name, db_host, db_port, username=None, password=None, - ssl=False, ssl_keyfile=None, ssl_certfile=None, - ssl_cert_reqs=None, ssl_ca_certs=None, - authentication_mechanism=None, ssl_match_hostname=True): - - connection = _db_connect(db_name, db_host, db_port, username=username, - password=password, ssl=ssl, ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, - ssl_cert_reqs=ssl_cert_reqs, ssl_ca_certs=ssl_ca_certs, - authentication_mechanism=authentication_mechanism, - ssl_match_hostname=ssl_match_hostname) - - LOG.info('Dropping database "%s" @ "%s:%s" as user "%s".', - db_name, db_host, db_port, str(username)) +def db_cleanup( + db_name, + db_host, + db_port, + username=None, + password=None, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): + + connection = _db_connect( + db_name, + db_host, + db_port, + username=username, + password=password, + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + authentication_mechanism=authentication_mechanism, + ssl_match_hostname=ssl_match_hostname, + ) + + LOG.info( + 'Dropping database "%s" @ "%s:%s" as user "%s".', + db_name, + db_host, + db_port, + str(username), + ) connection.drop_database(db_name) return connection -def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None, - ssl_ca_certs=None, authentication_mechanism=None, ssl_match_hostname=True): +def _get_ssl_kwargs( + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): # NOTE: In pymongo 3.9.0 some of the ssl related arguments have been renamed - # https://api.mongodb.com/python/current/changelog.html#changes-in-version-3-9-0 # Old names still work, but we should eventually update to new argument names. ssl_kwargs = { - 'ssl': ssl, + "ssl": ssl, } if ssl_keyfile: - ssl_kwargs['ssl'] = True - ssl_kwargs['ssl_keyfile'] = ssl_keyfile + ssl_kwargs["ssl"] = True + ssl_kwargs["ssl_keyfile"] = ssl_keyfile if ssl_certfile: - ssl_kwargs['ssl'] = True - ssl_kwargs['ssl_certfile'] = ssl_certfile + ssl_kwargs["ssl"] = True + ssl_kwargs["ssl_certfile"] = ssl_certfile if ssl_cert_reqs: - if ssl_cert_reqs == 'none': + if ssl_cert_reqs == "none": ssl_cert_reqs = ssl_lib.CERT_NONE - elif ssl_cert_reqs == 'optional': + elif ssl_cert_reqs == "optional": ssl_cert_reqs = ssl_lib.CERT_OPTIONAL - elif ssl_cert_reqs == 'required': + elif ssl_cert_reqs == "required": ssl_cert_reqs = ssl_lib.CERT_REQUIRED - ssl_kwargs['ssl_cert_reqs'] = ssl_cert_reqs + ssl_kwargs["ssl_cert_reqs"] = ssl_cert_reqs if ssl_ca_certs: - ssl_kwargs['ssl'] = True - ssl_kwargs['ssl_ca_certs'] = ssl_ca_certs + ssl_kwargs["ssl"] = True + ssl_kwargs["ssl_ca_certs"] = ssl_ca_certs if authentication_mechanism: - ssl_kwargs['ssl'] = True - ssl_kwargs['authentication_mechanism'] = authentication_mechanism - if ssl_kwargs.get('ssl', False): + ssl_kwargs["ssl"] = True + ssl_kwargs["authentication_mechanism"] = authentication_mechanism + if ssl_kwargs.get("ssl", False): # pass in ssl_match_hostname only if ssl is True. The right default value # for ssl_match_hostname in almost all cases is True. - ssl_kwargs['ssl_match_hostname'] = ssl_match_hostname + ssl_kwargs["ssl_match_hostname"] = ssl_match_hostname return ssl_kwargs @@ -362,9 +455,9 @@ def get_by_pack(self, value): return self.get(pack=value, raise_exception=True) def get(self, *args, **kwargs): - exclude_fields = kwargs.pop('exclude_fields', None) - raise_exception = kwargs.pop('raise_exception', False) - only_fields = kwargs.pop('only_fields', None) + exclude_fields = kwargs.pop("exclude_fields", None) + raise_exception = kwargs.pop("raise_exception", False) + only_fields = kwargs.pop("only_fields", None) args = self._process_arg_filters(args) @@ -377,14 +470,17 @@ def get(self, *args, **kwargs): try: instances = instances.only(*only_fields) except (mongoengine.errors.LookUpError, AttributeError) as e: - msg = ('Invalid or unsupported include attribute specified: %s' % six.text_type(e)) + msg = ( + "Invalid or unsupported include attribute specified: %s" + % six.text_type(e) + ) raise ValueError(msg) instance = instances[0] if instances else None log_query_and_profile_data_for_queryset(queryset=instances) if not instance and raise_exception: - msg = 'Unable to find the %s instance. %s' % (self.model.__name__, kwargs) + msg = "Unable to find the %s instance. %s" % (self.model.__name__, kwargs) raise db_exc.StackStormDBObjectNotFoundError(msg) return instance @@ -404,12 +500,12 @@ def count(self, *args, **kwargs): # **filters): def query(self, *args, **filters): # Python 2: Pop keyword parameters that aren't actually filters off of the kwargs - offset = filters.pop('offset', 0) - limit = filters.pop('limit', None) - order_by = filters.pop('order_by', None) - exclude_fields = filters.pop('exclude_fields', None) - only_fields = filters.pop('only_fields', None) - no_dereference = filters.pop('no_dereference', None) + offset = filters.pop("offset", 0) + limit = filters.pop("limit", None) + order_by = filters.pop("order_by", None) + exclude_fields = filters.pop("exclude_fields", None) + only_fields = filters.pop("only_fields", None) + no_dereference = filters.pop("no_dereference", None) order_by = order_by or [] exclude_fields = exclude_fields or [] @@ -419,7 +515,9 @@ def query(self, *args, **filters): # Process the filters # Note: Both of those functions manipulate "filters" variable so the order in which they # are called matters - filters, order_by = self._process_datetime_range_filters(filters=filters, order_by=order_by) + filters, order_by = self._process_datetime_range_filters( + filters=filters, order_by=order_by + ) filters = self._process_null_filters(filters=filters) result = self.model.objects(*args, **filters) @@ -429,7 +527,7 @@ def query(self, *args, **filters): result = result.exclude(*exclude_fields) except (mongoengine.errors.LookUpError, AttributeError) as e: field = get_field_name_from_mongoengine_error(e) - msg = ('Invalid or unsupported exclude attribute specified: %s' % field) + msg = "Invalid or unsupported exclude attribute specified: %s" % field raise ValueError(msg) if only_fields: @@ -437,7 +535,7 @@ def query(self, *args, **filters): result = result.only(*only_fields) except (mongoengine.errors.LookUpError, AttributeError) as e: field = get_field_name_from_mongoengine_error(e) - msg = ('Invalid or unsupported include attribute specified: %s' % field) + msg = "Invalid or unsupported include attribute specified: %s" % field raise ValueError(msg) if no_dereference: @@ -450,7 +548,7 @@ def query(self, *args, **filters): return result def distinct(self, *args, **kwargs): - field = kwargs.pop('field') + field = kwargs.pop("field") result = self.model.objects(**kwargs).distinct(field) log_query_and_profile_data_for_queryset(queryset=result) return result @@ -513,8 +611,10 @@ def _process_arg_filters(self, args): # Create a new QCombination object with the same operation and fixed filters _args += (visitor.QCombination(arg.operation, children),) else: - raise TypeError("Unknown argument type '%s' of argument '%s'" - % (type(arg), repr(arg))) + raise TypeError( + "Unknown argument type '%s' of argument '%s'" + % (type(arg), repr(arg)) + ) return _args @@ -526,35 +626,38 @@ def _process_null_filters(self, filters): for key, value in six.iteritems(filters): if value is None: null_filters[key] = value - elif isinstance(value, (str, six.text_type)) and value.lower() == 'null': + elif isinstance(value, (str, six.text_type)) and value.lower() == "null": null_filters[key] = value else: continue for key in null_filters.keys(): - result['%s__exists' % (key)] = False + result["%s__exists" % (key)] = False del result[key] return result def _process_datetime_range_filters(self, filters, order_by=None): - ranges = {k: v for k, v in six.iteritems(filters) - if type(v) in [str, six.text_type] and '..' in v} + ranges = { + k: v + for k, v in six.iteritems(filters) + if type(v) in [str, six.text_type] and ".." in v + } order_by_list = copy.deepcopy(order_by) if order_by else [] for k, v in six.iteritems(ranges): - values = v.split('..') + values = v.split("..") dt1 = isotime.parse(values[0]) dt2 = isotime.parse(values[1]) - k__gte = '%s__gte' % k - k__lte = '%s__lte' % k + k__gte = "%s__gte" % k + k__lte = "%s__lte" % k if dt1 < dt2: query = {k__gte: dt1, k__lte: dt2} - sort_key, reverse_sort_key = k, '-' + k + sort_key, reverse_sort_key = k, "-" + k else: query = {k__gte: dt2, k__lte: dt1} - sort_key, reverse_sort_key = '-' + k, k + sort_key, reverse_sort_key = "-" + k, k del filters[k] filters.update(query) @@ -569,7 +672,6 @@ def _process_datetime_range_filters(self, filters, order_by=None): class ChangeRevisionMongoDBAccess(MongoDBAccess): - def insert(self, instance): instance = self.model.objects.insert(instance) @@ -585,11 +687,11 @@ def update(self, instance, **kwargs): return self.save(instance) def save(self, instance, validate=True): - if not hasattr(instance, 'id') or not instance.id: + if not hasattr(instance, "id") or not instance.id: return self.insert(instance) else: try: - save_condition = {'id': instance.id, 'rev': instance.rev} + save_condition = {"id": instance.id, "rev": instance.rev} instance.rev = instance.rev + 1 instance.save(save_condition=save_condition, validate=validate) except mongoengine.SaveConditionError: @@ -601,8 +703,8 @@ def save(self, instance, validate=True): def get_host_names_for_uri_dict(uri_dict): hosts = [] - for host, port in uri_dict['nodelist']: - hosts.append('%s:%s' % (host, port)) + for host, port in uri_dict["nodelist"]: + hosts.append("%s:%s" % (host, port)) - hosts = ','.join(hosts) + hosts = ",".join(hosts) return hosts diff --git a/st2common/st2common/models/db/action.py b/st2common/st2common/models/db/action.py index 1c28b207c2d..52a1ed0374c 100644 --- a/st2common/st2common/models/db/action.py +++ b/st2common/st2common/models/db/action.py @@ -29,22 +29,26 @@ from st2common.constants.types import ResourceType __all__ = [ - 'RunnerTypeDB', - 'ActionDB', - 'LiveActionDB', - 'ActionExecutionDB', - 'ActionExecutionStateDB', - 'ActionAliasDB' + "RunnerTypeDB", + "ActionDB", + "LiveActionDB", + "ActionExecutionDB", + "ActionExecutionStateDB", + "ActionAliasDB", ] LOG = logging.getLogger(__name__) -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." -class ActionDB(stormbase.StormFoundationDB, stormbase.TagsMixin, - stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin): +class ActionDB( + stormbase.StormFoundationDB, + stormbase.TagsMixin, + stormbase.ContentPackResourceMixin, + stormbase.UIDFieldMixin, +): """ The system entity that represents a Stack Action/Automation in the system. @@ -56,38 +60,46 @@ class ActionDB(stormbase.StormFoundationDB, stormbase.TagsMixin, """ RESOURCE_TYPE = ResourceType.ACTION - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) ref = me.StringField(required=True) description = me.StringField() enabled = me.BooleanField( - required=True, default=True, - help_text='A flag indicating whether the action is enabled.') - entry_point = me.StringField( required=True, - help_text='The entry point to the action.') + default=True, + help_text="A flag indicating whether the action is enabled.", + ) + entry_point = me.StringField( + required=True, help_text="The entry point to the action." + ) pack = me.StringField( - required=False, - help_text='Name of the content pack.', - unique_with='name') + required=False, help_text="Name of the content pack.", unique_with="name" + ) runner_type = me.DictField( - required=True, default={}, - help_text='The action runner to use for executing the action.') + required=True, + default={}, + help_text="The action runner to use for executing the action.", + ) parameters = stormbase.EscapedDynamicField( - help_text='The specification for parameters for the action.') + help_text="The specification for parameters for the action." + ) output_schema = stormbase.EscapedDynamicField( - help_text='The schema for output of the action.') + help_text="The schema for output of the action." + ) notify = me.EmbeddedDocumentField(NotificationSchema) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['pack']}, - {'fields': ['ref']}, - ] + (stormbase.ContentPackResourceMixin.get_indexes() + - stormbase.TagsMixin.get_indexes() + - stormbase.UIDFieldMixin.get_indexes()) + "indexes": [ + {"fields": ["name"]}, + {"fields": ["pack"]}, + {"fields": ["ref"]}, + ] + + ( + stormbase.ContentPackResourceMixin.get_indexes() + + stormbase.TagsMixin.get_indexes() + + stormbase.UIDFieldMixin.get_indexes() + ) } def __init__(self, *args, **values): @@ -102,11 +114,17 @@ def is_workflow(self): :rtype: ``bool`` """ # pylint: disable=unsubscriptable-object - return self.runner_type['name'] in WORKFLOW_RUNNER_TYPES + return self.runner_type["name"] in WORKFLOW_RUNNER_TYPES # specialized access objects action_access = MongoDBAccess(ActionDB) -MODELS = [ActionDB, ActionExecutionDB, ActionExecutionStateDB, ActionAliasDB, - LiveActionDB, RunnerTypeDB] +MODELS = [ + ActionDB, + ActionExecutionDB, + ActionExecutionStateDB, + ActionAliasDB, + LiveActionDB, + RunnerTypeDB, +] diff --git a/st2common/st2common/models/db/actionalias.py b/st2common/st2common/models/db/actionalias.py index a696ff08b4b..765630d8a4c 100644 --- a/st2common/st2common/models/db/actionalias.py +++ b/st2common/st2common/models/db/actionalias.py @@ -21,18 +21,19 @@ from st2common.models.db import stormbase from st2common.constants.types import ResourceType -__all__ = [ - 'ActionAliasDB' -] +__all__ = ["ActionAliasDB"] LOG = logging.getLogger(__name__) -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." -class ActionAliasDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMixin, - stormbase.UIDFieldMixin): +class ActionAliasDB( + stormbase.StormFoundationDB, + stormbase.ContentPackResourceMixin, + stormbase.UIDFieldMixin, +): """ Database entity that represent an Alias for an action. @@ -46,42 +47,48 @@ class ActionAliasDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMi """ RESOURCE_TYPE = ResourceType.ACTION_ALIAS - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) ref = me.StringField(required=True) description = me.StringField() pack = me.StringField( - required=True, - help_text='Name of the content pack.', - unique_with='name') + required=True, help_text="Name of the content pack.", unique_with="name" + ) enabled = me.BooleanField( - required=True, default=True, - help_text='A flag indicating whether the action alias is enabled.') - action_ref = me.StringField( required=True, - help_text='Reference of the Action map this alias.') + default=True, + help_text="A flag indicating whether the action alias is enabled.", + ) + action_ref = me.StringField( + required=True, help_text="Reference of the Action map this alias." + ) formats = me.ListField( - help_text='Possible parameter formats that an alias supports.') + help_text="Possible parameter formats that an alias supports." + ) ack = me.DictField( - help_text='Parameters pertaining to the acknowledgement message.' + help_text="Parameters pertaining to the acknowledgement message." ) result = me.DictField( - help_text='Parameters pertaining to the execution result message.' + help_text="Parameters pertaining to the execution result message." ) extra = me.DictField( - help_text='Additional parameters (usually adapter-specific) not covered in the schema.' + help_text="Additional parameters (usually adapter-specific) not covered in the schema." ) immutable_parameters = me.DictField( - help_text='Parameters to be passed to the action on every execution.') + help_text="Parameters to be passed to the action on every execution." + ) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['enabled']}, - {'fields': ['formats']}, - ] + (stormbase.ContentPackResourceMixin().get_indexes() + - stormbase.UIDFieldMixin.get_indexes()) + "indexes": [ + {"fields": ["name"]}, + {"fields": ["enabled"]}, + {"fields": ["formats"]}, + ] + + ( + stormbase.ContentPackResourceMixin().get_indexes() + + stormbase.UIDFieldMixin.get_indexes() + ) } def __init__(self, *args, **values): @@ -97,10 +104,12 @@ def get_format_strings(self): """ result = [] - formats = getattr(self, 'formats', []) + formats = getattr(self, "formats", []) for format_string in formats: - if isinstance(format_string, dict) and format_string.get('representation', None): - result.extend(format_string['representation']) + if isinstance(format_string, dict) and format_string.get( + "representation", None + ): + result.extend(format_string["representation"]) else: result.append(format_string) diff --git a/st2common/st2common/models/db/auth.py b/st2common/st2common/models/db/auth.py index 7ef30ee0172..2531ecb11ab 100644 --- a/st2common/st2common/models/db/auth.py +++ b/st2common/st2common/models/db/auth.py @@ -25,11 +25,7 @@ from st2common.rbac.backends import get_rbac_backend from st2common.util import date as date_utils -__all__ = [ - 'UserDB', - 'TokenDB', - 'ApiKeyDB' -] +__all__ = ["UserDB", "TokenDB", "ApiKeyDB"] class UserDB(stormbase.StormFoundationDB): @@ -42,10 +38,12 @@ class UserDB(stormbase.StormFoundationDB): is_service: True if this is a service account. nicknames: Nickname + origin pairs for ChatOps auth. """ + name = me.StringField(required=True, unique=True) is_service = me.BooleanField(required=True, default=False) - nicknames = me.DictField(required=False, - help_text='"Nickname + origin" pairs for ChatOps auth') + nicknames = me.DictField( + required=False, help_text='"Nickname + origin" pairs for ChatOps auth' + ) def get_roles(self, include_remote=True): """ @@ -57,7 +55,9 @@ def get_roles(self, include_remote=True): :rtype: ``list`` of :class:`RoleDB` """ rbac_service = get_rbac_backend().get_service_class() - result = rbac_service.get_roles_for_user(user_db=self, include_remote=include_remote) + result = rbac_service.get_roles_for_user( + user_db=self, include_remote=include_remote + ) return result def get_permission_assignments(self): @@ -75,11 +75,13 @@ class TokenDB(stormbase.StormFoundationDB): expiry: Date when this token expires. service: True if this is a service (system) token. """ + user = me.StringField(required=True) token = me.StringField(required=True, unique=True) expiry = me.DateTimeField(required=True) - metadata = me.DictField(required=False, - help_text='Arbitrary metadata associated with this token') + metadata = me.DictField( + required=False, help_text="Arbitrary metadata associated with this token" + ) service = me.BooleanField(required=True, default=False) @@ -91,23 +93,24 @@ class ApiKeyDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin): """ RESOURCE_TYPE = ResourceType.API_KEY - UID_FIELDS = ['key_hash'] + UID_FIELDS = ["key_hash"] user = me.StringField(required=True) key_hash = me.StringField(required=True, unique=True) - metadata = me.DictField(required=False, - help_text='Arbitrary metadata associated with this token') - created_at = ComplexDateTimeField(default=date_utils.get_datetime_utc_now, - help_text='The creation time of this ApiKey.') - enabled = me.BooleanField(required=True, default=True, - help_text='A flag indicating whether the ApiKey is enabled.') - - meta = { - 'indexes': [ - {'fields': ['user']}, - {'fields': ['key_hash']} - ] - } + metadata = me.DictField( + required=False, help_text="Arbitrary metadata associated with this token" + ) + created_at = ComplexDateTimeField( + default=date_utils.get_datetime_utc_now, + help_text="The creation time of this ApiKey.", + ) + enabled = me.BooleanField( + required=True, + default=True, + help_text="A flag indicating whether the ApiKey is enabled.", + ) + + meta = {"indexes": [{"fields": ["user"]}, {"fields": ["key_hash"]}]} def __init__(self, *args, **values): super(ApiKeyDB, self).__init__(*args, **values) @@ -119,8 +122,8 @@ def mask_secrets(self, value): # In theory the key_hash is safe to return as it is one way. On the other # hand given that this is actually a secret no real point in letting the hash # escape. Since uid contains key_hash masking that as well. - result['key_hash'] = MASKED_ATTRIBUTE_VALUE - result['uid'] = MASKED_ATTRIBUTE_VALUE + result["key_hash"] = MASKED_ATTRIBUTE_VALUE + result["uid"] = MASKED_ATTRIBUTE_VALUE return result diff --git a/st2common/st2common/models/db/execution.py b/st2common/st2common/models/db/execution.py index 3e8f3c7742d..a44e5072d6e 100644 --- a/st2common/st2common/models/db/execution.py +++ b/st2common/st2common/models/db/execution.py @@ -27,10 +27,7 @@ from st2common.util.secrets import mask_secret_parameters from st2common.constants.types import ResourceType -__all__ = [ - 'ActionExecutionDB', - 'ActionExecutionOutputDB' -] +__all__ = ["ActionExecutionDB", "ActionExecutionOutputDB"] LOG = logging.getLogger(__name__) @@ -38,7 +35,7 @@ class ActionExecutionDB(stormbase.StormFoundationDB): RESOURCE_TYPE = ResourceType.EXECUTION - UID_FIELDS = ['id'] + UID_FIELDS = ["id"] trigger = stormbase.EscapedDictField() trigger_type = stormbase.EscapedDictField() @@ -52,22 +49,25 @@ class ActionExecutionDB(stormbase.StormFoundationDB): workflow_execution = me.StringField() task_execution = me.StringField() status = me.StringField( - required=True, - help_text='The current status of the liveaction.') + required=True, help_text="The current status of the liveaction." + ) start_timestamp = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the liveaction was created.') + help_text="The timestamp when the liveaction was created.", + ) end_timestamp = ComplexDateTimeField( - help_text='The timestamp when the liveaction has finished.') + help_text="The timestamp when the liveaction has finished." + ) parameters = stormbase.EscapedDynamicField( default={}, - help_text='The key-value pairs passed as to the action runner & action.') + help_text="The key-value pairs passed as to the action runner & action.", + ) result = stormbase.EscapedDynamicField( - default={}, - help_text='Action defined result.') + default={}, help_text="Action defined result." + ) context = me.DictField( - default={}, - help_text='Contextual information on the action execution.') + default={}, help_text="Contextual information on the action execution." + ) parent = me.StringField() children = me.ListField(field=me.StringField()) log = me.ListField(field=me.DictField()) @@ -76,49 +76,51 @@ class ActionExecutionDB(stormbase.StormFoundationDB): web_url = me.StringField(required=False) meta = { - 'indexes': [ - {'fields': ['rule.ref']}, - {'fields': ['action.ref']}, - {'fields': ['liveaction.id']}, - {'fields': ['start_timestamp']}, - {'fields': ['end_timestamp']}, - {'fields': ['status']}, - {'fields': ['parent']}, - {'fields': ['rule.name']}, - {'fields': ['runner.name']}, - {'fields': ['trigger.name']}, - {'fields': ['trigger_type.name']}, - {'fields': ['trigger_instance.id']}, - {'fields': ['context.user']}, - {'fields': ['-start_timestamp', 'action.ref', 'status']}, - {'fields': ['workflow_execution']}, - {'fields': ['task_execution']} + "indexes": [ + {"fields": ["rule.ref"]}, + {"fields": ["action.ref"]}, + {"fields": ["liveaction.id"]}, + {"fields": ["start_timestamp"]}, + {"fields": ["end_timestamp"]}, + {"fields": ["status"]}, + {"fields": ["parent"]}, + {"fields": ["rule.name"]}, + {"fields": ["runner.name"]}, + {"fields": ["trigger.name"]}, + {"fields": ["trigger_type.name"]}, + {"fields": ["trigger_instance.id"]}, + {"fields": ["context.user"]}, + {"fields": ["-start_timestamp", "action.ref", "status"]}, + {"fields": ["workflow_execution"]}, + {"fields": ["task_execution"]}, ] } def get_uid(self): # TODO Construct id from non id field: uid = [self.RESOURCE_TYPE, str(self.id)] # pylint: disable=no-member - return ':'.join(uid) + return ":".join(uid) def mask_secrets(self, value): result = copy.deepcopy(value) - liveaction = result['liveaction'] + liveaction = result["liveaction"] parameters = {} # pylint: disable=no-member - parameters.update(value.get('action', {}).get('parameters', {})) - parameters.update(value.get('runner', {}).get('runner_parameters', {})) + parameters.update(value.get("action", {}).get("parameters", {})) + parameters.update(value.get("runner", {}).get("runner_parameters", {})) secret_parameters = get_secret_parameters(parameters=parameters) - result['parameters'] = mask_secret_parameters(parameters=result.get('parameters', {}), - secret_parameters=secret_parameters) + result["parameters"] = mask_secret_parameters( + parameters=result.get("parameters", {}), secret_parameters=secret_parameters + ) - if 'parameters' in liveaction: - liveaction['parameters'] = mask_secret_parameters(parameters=liveaction['parameters'], - secret_parameters=secret_parameters) + if "parameters" in liveaction: + liveaction["parameters"] = mask_secret_parameters( + parameters=liveaction["parameters"], secret_parameters=secret_parameters + ) - if liveaction.get('action', '') == 'st2.inquiry.respond': + if liveaction.get("action", "") == "st2.inquiry.respond": # Special case to mask parameters for `st2.inquiry.respond` action # In this case, this execution is just a plain python action, not # an inquiry, so we don't natively have a handle on the response @@ -130,22 +132,24 @@ def mask_secrets(self, value): # it's just a placeholder to tell mask_secret_parameters() # that this parameter is indeed a secret parameter and to # mask it. - result['parameters']['response'] = mask_secret_parameters( - parameters=liveaction['parameters']['response'], - secret_parameters={p: 'string' for p in liveaction['parameters']['response']} + result["parameters"]["response"] = mask_secret_parameters( + parameters=liveaction["parameters"]["response"], + secret_parameters={ + p: "string" for p in liveaction["parameters"]["response"] + }, ) # TODO(mierdin): This logic should be moved to the dedicated Inquiry # data model once it exists. - if self.runner.get('name') == "inquirer": + if self.runner.get("name") == "inquirer": - schema = result['result'].get('schema', {}) - response = result['result'].get('response', {}) + schema = result["result"].get("schema", {}) + response = result["result"].get("response", {}) # We can only mask response secrets if response and schema exist and are # not empty if response and schema: - result['result']['response'] = mask_inquiry_response(response, schema) + result["result"]["response"] = mask_inquiry_response(response, schema) return result def get_masked_parameters(self): @@ -155,7 +159,7 @@ def get_masked_parameters(self): :rtype: ``dict`` """ serializable_dict = self.to_serializable_dict(mask_secrets=True) - return serializable_dict['parameters'] + return serializable_dict["parameters"] class ActionExecutionOutputDB(stormbase.StormFoundationDB): @@ -174,22 +178,25 @@ class ActionExecutionOutputDB(stormbase.StormFoundationDB): data: Actual output data. This could either be line, chunk or similar, depending on the runner. """ + execution_id = me.StringField(required=True) action_ref = me.StringField(required=True) runner_ref = me.StringField(required=True) - timestamp = ComplexDateTimeField(required=True, default=date_utils.get_datetime_utc_now) - output_type = me.StringField(required=True, default='output') + timestamp = ComplexDateTimeField( + required=True, default=date_utils.get_datetime_utc_now + ) + output_type = me.StringField(required=True, default="output") delay = me.IntField() data = me.StringField() meta = { - 'indexes': [ - {'fields': ['execution_id']}, - {'fields': ['action_ref']}, - {'fields': ['runner_ref']}, - {'fields': ['timestamp']}, - {'fields': ['output_type']} + "indexes": [ + {"fields": ["execution_id"]}, + {"fields": ["action_ref"]}, + {"fields": ["runner_ref"]}, + {"fields": ["timestamp"]}, + {"fields": ["output_type"]}, ] } diff --git a/st2common/st2common/models/db/execution_queue.py b/st2common/st2common/models/db/execution_queue.py index 31dcebbd1a3..8db09933631 100644 --- a/st2common/st2common/models/db/execution_queue.py +++ b/st2common/st2common/models/db/execution_queue.py @@ -25,15 +25,16 @@ from st2common.constants.types import ResourceType __all__ = [ - 'ActionExecutionSchedulingQueueItemDB', + "ActionExecutionSchedulingQueueItemDB", ] LOG = logging.getLogger(__name__) -class ActionExecutionSchedulingQueueItemDB(stormbase.StormFoundationDB, - stormbase.ChangeRevisionFieldMixin): +class ActionExecutionSchedulingQueueItemDB( + stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin +): """ A model which represents a request for execution to be scheduled. @@ -42,36 +43,45 @@ class ActionExecutionSchedulingQueueItemDB(stormbase.StormFoundationDB, """ RESOURCE_TYPE = ResourceType.EXECUTION_REQUEST - UID_FIELDS = ['id'] + UID_FIELDS = ["id"] - liveaction_id = me.StringField(required=True, - help_text='Foreign key to the LiveActionDB which is to be scheduled') + liveaction_id = me.StringField( + required=True, + help_text="Foreign key to the LiveActionDB which is to be scheduled", + ) action_execution_id = me.StringField( - help_text='Foreign key to the ActionExecutionDB which is to be scheduled') + help_text="Foreign key to the ActionExecutionDB which is to be scheduled" + ) original_start_timestamp = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the liveaction was created and originally be scheduled to ' - 'run.') + help_text="The timestamp when the liveaction was created and originally be scheduled to " + "run.", + ) scheduled_start_timestamp = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when liveaction is scheduled to run.') + help_text="The timestamp when liveaction is scheduled to run.", + ) delay = me.IntField() - handling = me.BooleanField(default=False, - help_text='Flag indicating if this item is currently being handled / ' - 'processed by a scheduler service') + handling = me.BooleanField( + default=False, + help_text="Flag indicating if this item is currently being handled / " + "processed by a scheduler service", + ) meta = { - 'indexes': [ + "indexes": [ # NOTE: We limit index names to 65 characters total for compatibility with AWS # DocumentDB. # See https://github.com/StackStorm/st2/pull/4690 for details. - {'fields': ['action_execution_id'], 'name': 'ac_exc_id'}, - {'fields': ['liveaction_id'], 'name': 'lv_ac_id'}, - {'fields': ['original_start_timestamp'], 'name': 'orig_s_ts'}, - {'fields': ['scheduled_start_timestamp'], 'name': 'schd_s_ts'}, + {"fields": ["action_execution_id"], "name": "ac_exc_id"}, + {"fields": ["liveaction_id"], "name": "lv_ac_id"}, + {"fields": ["original_start_timestamp"], "name": "orig_s_ts"}, + {"fields": ["scheduled_start_timestamp"], "name": "schd_s_ts"}, ] } MODELS = [ActionExecutionSchedulingQueueItemDB] -EXECUTION_QUEUE_ACCESS = ChangeRevisionMongoDBAccess(ActionExecutionSchedulingQueueItemDB) +EXECUTION_QUEUE_ACCESS = ChangeRevisionMongoDBAccess( + ActionExecutionSchedulingQueueItemDB +) diff --git a/st2common/st2common/models/db/executionstate.py b/st2common/st2common/models/db/executionstate.py index db949b66581..94b883038d3 100644 --- a/st2common/st2common/models/db/executionstate.py +++ b/st2common/st2common/models/db/executionstate.py @@ -21,33 +21,32 @@ from st2common.models.db import stormbase __all__ = [ - 'ActionExecutionStateDB', + "ActionExecutionStateDB", ] LOG = logging.getLogger(__name__) -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." class ActionExecutionStateDB(stormbase.StormFoundationDB): """ - Database entity that represents the state of Action execution. + Database entity that represents the state of Action execution. """ + execution_id = me.ObjectIdField( - required=True, - unique=True, - help_text='liveaction ID.') + required=True, unique=True, help_text="liveaction ID." + ) query_module = me.StringField( - required=True, - help_text='Reference to the runner model.') + required=True, help_text="Reference to the runner model." + ) query_context = me.DictField( required=True, - help_text='Context about the action execution that is needed for results query.') + help_text="Context about the action execution that is needed for results query.", + ) - meta = { - 'indexes': ['query_module'] - } + meta = {"indexes": ["query_module"]} # specialized access objects diff --git a/st2common/st2common/models/db/keyvalue.py b/st2common/st2common/models/db/keyvalue.py index debe58ebbb0..ea7fda3b9d3 100644 --- a/st2common/st2common/models/db/keyvalue.py +++ b/st2common/st2common/models/db/keyvalue.py @@ -21,9 +21,7 @@ from st2common.models.db import MongoDBAccess from st2common.models.db import stormbase -__all__ = [ - 'KeyValuePairDB' -] +__all__ = ["KeyValuePairDB"] class KeyValuePairDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): @@ -34,22 +32,20 @@ class KeyValuePairDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): """ RESOURCE_TYPE = ResourceType.KEY_VALUE_PAIR - UID_FIELDS = ['scope', 'name'] + UID_FIELDS = ["scope", "name"] - scope = me.StringField(default=FULL_SYSTEM_SCOPE, unique_with='name') + scope = me.StringField(default=FULL_SYSTEM_SCOPE, unique_with="name") name = me.StringField(required=True) value = me.StringField() secret = me.BooleanField(default=False) expire_timestamp = me.DateTimeField() meta = { - 'indexes': [ - {'fields': ['name']}, - { - 'fields': ['expire_timestamp'], - 'expireAfterSeconds': 0 - } - ] + stormbase.UIDFieldMixin.get_indexes() + "indexes": [ + {"fields": ["name"]}, + {"fields": ["expire_timestamp"], "expireAfterSeconds": 0}, + ] + + stormbase.UIDFieldMixin.get_indexes() } def __init__(self, *args, **values): diff --git a/st2common/st2common/models/db/liveaction.py b/st2common/st2common/models/db/liveaction.py index 6bc5fd77fa0..29f5a13bfc0 100644 --- a/st2common/st2common/models/db/liveaction.py +++ b/st2common/st2common/models/db/liveaction.py @@ -28,12 +28,12 @@ from st2common.util.secrets import mask_secret_parameters __all__ = [ - 'LiveActionDB', + "LiveActionDB", ] LOG = logging.getLogger(__name__) -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." class LiveActionDB(stormbase.StormFoundationDB): @@ -41,50 +41,56 @@ class LiveActionDB(stormbase.StormFoundationDB): task_execution = me.StringField() # TODO: Can status be an enum at the Mongo layer? status = me.StringField( - required=True, - help_text='The current status of the liveaction.') + required=True, help_text="The current status of the liveaction." + ) start_timestamp = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the liveaction was created.') + help_text="The timestamp when the liveaction was created.", + ) end_timestamp = ComplexDateTimeField( - help_text='The timestamp when the liveaction has finished.') + help_text="The timestamp when the liveaction has finished." + ) action = me.StringField( - required=True, - help_text='Reference to the action that has to be executed.') + required=True, help_text="Reference to the action that has to be executed." + ) action_is_workflow = me.BooleanField( default=False, - help_text='A flag indicating whether the referenced action is a workflow.') + help_text="A flag indicating whether the referenced action is a workflow.", + ) parameters = stormbase.EscapedDynamicField( default={}, - help_text='The key-value pairs passed as to the action runner & execution.') + help_text="The key-value pairs passed as to the action runner & execution.", + ) result = stormbase.EscapedDynamicField( - default={}, - help_text='Action defined result.') + default={}, help_text="Action defined result." + ) context = me.DictField( - default={}, - help_text='Contextual information on the action execution.') + default={}, help_text="Contextual information on the action execution." + ) callback = me.DictField( default={}, - help_text='Callback information for the on completion of action execution.') + help_text="Callback information for the on completion of action execution.", + ) runner_info = me.DictField( default={}, - help_text='Information about the runner which executed this live action (hostname, pid).') + help_text="Information about the runner which executed this live action (hostname, pid).", + ) notify = me.EmbeddedDocumentField(NotificationSchema) delay = me.IntField( min_value=0, - help_text='How long (in milliseconds) to delay the execution before scheduling.' + help_text="How long (in milliseconds) to delay the execution before scheduling.", ) meta = { - 'indexes': [ - {'fields': ['-start_timestamp', 'action']}, - {'fields': ['start_timestamp']}, - {'fields': ['end_timestamp']}, - {'fields': ['action']}, - {'fields': ['status']}, - {'fields': ['context.trigger_instance.id']}, - {'fields': ['workflow_execution']}, - {'fields': ['task_execution']} + "indexes": [ + {"fields": ["-start_timestamp", "action"]}, + {"fields": ["start_timestamp"]}, + {"fields": ["end_timestamp"]}, + {"fields": ["action"]}, + {"fields": ["status"]}, + {"fields": ["context.trigger_instance.id"]}, + {"fields": ["workflow_execution"]}, + {"fields": ["task_execution"]}, ] } @@ -92,7 +98,7 @@ def mask_secrets(self, value): from st2common.util import action_db result = copy.deepcopy(value) - execution_parameters = value['parameters'] + execution_parameters = value["parameters"] # TODO: This results into two DB looks, we should cache action and runner type object # for each liveaction... @@ -104,8 +110,9 @@ def mask_secrets(self, value): parameters = action_db.get_action_parameters_specs(action_ref=self.action) secret_parameters = get_secret_parameters(parameters=parameters) - result['parameters'] = mask_secret_parameters(parameters=execution_parameters, - secret_parameters=secret_parameters) + result["parameters"] = mask_secret_parameters( + parameters=execution_parameters, secret_parameters=secret_parameters + ) return result def get_masked_parameters(self): @@ -115,7 +122,7 @@ def get_masked_parameters(self): :rtype: ``dict`` """ serializable_dict = self.to_serializable_dict(mask_secrets=True) - return serializable_dict['parameters'] + return serializable_dict["parameters"] # specialized access objects diff --git a/st2common/st2common/models/db/marker.py b/st2common/st2common/models/db/marker.py index 1bddf3f6047..7a053e54904 100644 --- a/st2common/st2common/models/db/marker.py +++ b/st2common/st2common/models/db/marker.py @@ -20,10 +20,7 @@ from st2common.models.db import stormbase from st2common.util import date as date_utils -__all__ = [ - 'MarkerDB', - 'DumperMarkerDB' -] +__all__ = ["MarkerDB", "DumperMarkerDB"] class MarkerDB(stormbase.StormFoundationDB): @@ -37,20 +34,21 @@ class MarkerDB(stormbase.StormFoundationDB): :param updated_at: Timestamp when marker was updated. :type updated_at: ``datetime.datetime`` """ + marker = me.StringField(required=True) updated_at = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the liveaction was created.') + help_text="The timestamp when the liveaction was created.", + ) - meta = { - 'abstract': True - } + meta = {"abstract": True} class DumperMarkerDB(MarkerDB): """ Marker model used by Dumper (in exporter). """ + pass diff --git a/st2common/st2common/models/db/notification.py b/st2common/st2common/models/db/notification.py index e311f46b75f..8ef793887ba 100644 --- a/st2common/st2common/models/db/notification.py +++ b/st2common/st2common/models/db/notification.py @@ -21,43 +21,47 @@ class NotificationSubSchema(me.EmbeddedDocument): """ - Schema for notification settings to be specified for action success/failure. + Schema for notification settings to be specified for action success/failure. """ + message = me.StringField() data = stormbase.EscapedDynamicField( - default={}, - help_text='Payload to be sent as part of notification.') + default={}, help_text="Payload to be sent as part of notification." + ) routes = me.ListField( - default=['notify.default'], - help_text='Routes to post notifications to.') - channels = me.ListField( # Deprecated. Only here for backward compatibility reasons. - default=['notify.default'], - help_text='Routes to post notifications to.') + default=["notify.default"], help_text="Routes to post notifications to." + ) + channels = ( + me.ListField( # Deprecated. Only here for backward compatibility reasons. + default=["notify.default"], help_text="Routes to post notifications to." + ) + ) def __str__(self): result = [] - result.append('NotificationSubSchema@') + result.append("NotificationSubSchema@") result.append(str(id(self))) result.append('(message="%s", ' % str(self.message)) result.append('data="%s", ' % str(self.data)) result.append('routes="%s", ' % str(self.routes)) result.append('[**deprecated**]channels="%s")' % str(self.channels)) - return ''.join(result) + return "".join(result) class NotificationSchema(me.EmbeddedDocument): """ - Schema for notification settings to be specified for actions. + Schema for notification settings to be specified for actions. """ + on_success = me.EmbeddedDocumentField(NotificationSubSchema) on_failure = me.EmbeddedDocumentField(NotificationSubSchema) on_complete = me.EmbeddedDocumentField(NotificationSubSchema) def __str__(self): result = [] - result.append('NotifySchema@') + result.append("NotifySchema@") result.append(str(id(self))) result.append('(on_complete="%s", ' % str(self.on_complete)) result.append('on_success="%s", ' % str(self.on_success)) result.append('on_failure="%s")' % str(self.on_failure)) - return ''.join(result) + return "".join(result) diff --git a/st2common/st2common/models/db/pack.py b/st2common/st2common/models/db/pack.py index cf169109874..c92b0096243 100644 --- a/st2common/st2common/models/db/pack.py +++ b/st2common/st2common/models/db/pack.py @@ -25,21 +25,16 @@ from st2common.util.secrets import get_secret_parameters from st2common.util.secrets import mask_secret_parameters -__all__ = [ - 'PackDB', - 'ConfigSchemaDB', - 'ConfigDB' -] +__all__ = ["PackDB", "ConfigSchemaDB", "ConfigDB"] -class PackDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin, - me.DynamicDocument): +class PackDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin, me.DynamicDocument): """ System entity which represents a pack. """ RESOURCE_TYPE = ResourceType.PACK - UID_FIELDS = ['ref'] + UID_FIELDS = ["ref"] ref = me.StringField(required=True, unique=True) name = me.StringField(required=True, unique=True) @@ -56,9 +51,7 @@ class PackDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin, dependencies = me.ListField(field=me.StringField()) system = me.DictField() - meta = { - 'indexes': stormbase.UIDFieldMixin.get_indexes() - } + meta = {"indexes": stormbase.UIDFieldMixin.get_indexes()} def __init__(self, *args, **values): super(PackDB, self).__init__(*args, **values) @@ -73,22 +66,24 @@ class ConfigSchemaDB(stormbase.StormFoundationDB): pack = me.StringField( required=True, unique=True, - help_text='Name of the content pack this schema belongs to.') + help_text="Name of the content pack this schema belongs to.", + ) attributes = stormbase.EscapedDynamicField( - help_text='The specification for config schema attributes.') + help_text="The specification for config schema attributes." + ) class ConfigDB(stormbase.StormFoundationDB): """ System entity representing pack config. """ + pack = me.StringField( required=True, unique=True, - help_text='Name of the content pack this config belongs to.') - values = stormbase.EscapedDynamicField( - help_text='Config values.', - default={}) + help_text="Name of the content pack this config belongs to.", + ) + values = stormbase.EscapedDynamicField(help_text="Config values.", default={}) def mask_secrets(self, value): """ @@ -101,11 +96,12 @@ def mask_secrets(self, value): """ result = copy.deepcopy(value) - config_schema = config_schema_access.get_by_pack(result['pack']) + config_schema = config_schema_access.get_by_pack(result["pack"]) secret_parameters = get_secret_parameters(parameters=config_schema.attributes) - result['values'] = mask_secret_parameters(parameters=result['values'], - secret_parameters=secret_parameters) + result["values"] = mask_secret_parameters( + parameters=result["values"], secret_parameters=secret_parameters + ) return result diff --git a/st2common/st2common/models/db/policy.py b/st2common/st2common/models/db/policy.py index 69f709093c9..8b9fcafef0a 100644 --- a/st2common/st2common/models/db/policy.py +++ b/st2common/st2common/models/db/policy.py @@ -23,9 +23,7 @@ from st2common.constants.types import ResourceType -__all__ = ['PolicyTypeReference', - 'PolicyTypeDB', - 'PolicyDB'] +__all__ = ["PolicyTypeReference", "PolicyTypeDB", "PolicyDB"] LOG = logging.getLogger(__name__) @@ -34,7 +32,8 @@ class PolicyTypeReference(object): """ Class used for referring to policy types which belong to a resource type. """ - separator = '.' + + separator = "." def __init__(self, resource_type=None, name=None): self.resource_type = self.validate_resource_type(resource_type) @@ -54,14 +53,15 @@ def is_reference(cls, ref): @classmethod def from_string_reference(cls, ref): - return cls(resource_type=cls.get_resource_type(ref), - name=cls.get_name(ref)) + return cls(resource_type=cls.get_resource_type(ref), name=cls.get_name(ref)) @classmethod def to_string_reference(cls, resource_type=None, name=None): if not resource_type or not name: - raise ValueError('Both resource_type and name are required for building ref. ' - 'resource_type=%s, name=%s' % (resource_type, name)) + raise ValueError( + "Both resource_type and name are required for building ref. " + "resource_type=%s, name=%s" % (resource_type, name) + ) resource_type = cls.validate_resource_type(resource_type) return cls.separator.join([resource_type, name]) @@ -69,7 +69,7 @@ def to_string_reference(cls, resource_type=None, name=None): @classmethod def validate_resource_type(cls, resource_type): if not resource_type: - raise ValueError('Resource type should not be empty.') + raise ValueError("Resource type should not be empty.") if cls.separator in resource_type: raise ValueError('Resource type should not contain "%s".' % cls.separator) @@ -80,7 +80,7 @@ def validate_resource_type(cls, resource_type): def get_resource_type(cls, ref): try: if not cls.is_reference(ref): - raise ValueError('%s is not a valid reference.' % ref) + raise ValueError("%s is not a valid reference." % ref) return ref.split(cls.separator, 1)[0] except (ValueError, IndexError, AttributeError): @@ -90,15 +90,19 @@ def get_resource_type(cls, ref): def get_name(cls, ref): try: if not cls.is_reference(ref): - raise ValueError('%s is not a valid reference.' % ref) + raise ValueError("%s is not a valid reference." % ref) return ref.split(cls.separator, 1)[1] except (ValueError, IndexError, AttributeError): raise common_models.InvalidReferenceError(ref=ref) def __repr__(self): - return ('<%s resource_type=%s,name=%s,ref=%s>' % - (self.__class__.__name__, self.resource_type, self.name, self.ref)) + return "<%s resource_type=%s,name=%s,ref=%s>" % ( + self.__class__.__name__, + self.resource_type, + self.name, + self.ref, + ) class PolicyTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): @@ -114,29 +118,35 @@ class PolicyTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): module: The python module that implements the policy for this type. parameters: The specification for parameters for the policy type. """ + RESOURCE_TYPE = ResourceType.POLICY_TYPE - UID_FIELDS = ['resource_type', 'name'] + UID_FIELDS = ["resource_type", "name"] ref = me.StringField(required=True) resource_type = me.StringField( required=True, - unique_with='name', - help_text='The type of resource that this policy type can be applied to.') + unique_with="name", + help_text="The type of resource that this policy type can be applied to.", + ) enabled = me.BooleanField( required=True, default=True, - help_text='A flag indicating whether the runner for this type is enabled.') + help_text="A flag indicating whether the runner for this type is enabled.", + ) module = me.StringField( required=True, - help_text='The python module that implements the policy for this type.') + help_text="The python module that implements the policy for this type.", + ) parameters = me.DictField( - help_text='The specification for parameters for the policy type.') + help_text="The specification for parameters for the policy type." + ) def __init__(self, *args, **kwargs): super(PolicyTypeDB, self).__init__(*args, **kwargs) self.uid = self.get_uid() - self.ref = PolicyTypeReference.to_string_reference(resource_type=self.resource_type, - name=self.name) + self.ref = PolicyTypeReference.to_string_reference( + resource_type=self.resource_type, name=self.name + ) def get_reference(self): """ @@ -147,8 +157,11 @@ def get_reference(self): return PolicyTypeReference(resource_type=self.resource_type, name=self.name) -class PolicyDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMixin, - stormbase.UIDFieldMixin): +class PolicyDB( + stormbase.StormFoundationDB, + stormbase.ContentPackResourceMixin, + stormbase.UIDFieldMixin, +): """ The representation for a policy in the system. @@ -158,43 +171,47 @@ class PolicyDB(stormbase.StormFoundationDB, stormbase.ContentPackResourceMixin, policy_type: The type of policy. parameters: The specification of input parameters for the policy. """ + RESOURCE_TYPE = ResourceType.POLICY - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) ref = me.StringField(required=True) pack = me.StringField( required=False, default=pack_constants.DEFAULT_PACK_NAME, - unique_with='name', - help_text='Name of the content pack.') + unique_with="name", + help_text="Name of the content pack.", + ) description = me.StringField() enabled = me.BooleanField( required=True, default=True, - help_text='A flag indicating whether this policy is enabled in the system.') + help_text="A flag indicating whether this policy is enabled in the system.", + ) resource_ref = me.StringField( - required=True, - help_text='The resource that this policy is applied to.') + required=True, help_text="The resource that this policy is applied to." + ) policy_type = me.StringField( - required=True, - unique_with='resource_ref', - help_text='The type of policy.') + required=True, unique_with="resource_ref", help_text="The type of policy." + ) parameters = me.DictField( - help_text='The specification of input parameters for the policy.') + help_text="The specification of input parameters for the policy." + ) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['resource_ref']}, + "indexes": [ + {"fields": ["name"]}, + {"fields": ["resource_ref"]}, ] } def __init__(self, *args, **kwargs): super(PolicyDB, self).__init__(*args, **kwargs) self.uid = self.get_uid() - self.ref = common_models.ResourceReference.to_string_reference(pack=self.pack, - name=self.name) + self.ref = common_models.ResourceReference.to_string_reference( + pack=self.pack, name=self.name + ) MODELS = [PolicyTypeDB, PolicyDB] diff --git a/st2common/st2common/models/db/rbac.py b/st2common/st2common/models/db/rbac.py index 68b41ea3142..bb82ba88cb3 100644 --- a/st2common/st2common/models/db/rbac.py +++ b/st2common/st2common/models/db/rbac.py @@ -21,14 +21,13 @@ __all__ = [ - 'RoleDB', - 'UserRoleAssignmentDB', - 'PermissionGrantDB', - 'GroupToRoleMappingDB', - - 'role_access', - 'user_role_assignment_access', - 'permission_grant_access' + "RoleDB", + "UserRoleAssignmentDB", + "PermissionGrantDB", + "GroupToRoleMappingDB", + "role_access", + "user_role_assignment_access", + "permission_grant_access", ] @@ -43,15 +42,16 @@ class RoleDB(stormbase.StormFoundationDB): permission_grants: A list of IDs to the permission grant which apply to this role. """ + name = me.StringField(required=True, unique=True) description = me.StringField() system = me.BooleanField(default=False) permission_grants = me.ListField(field=me.StringField()) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['system']}, + "indexes": [ + {"fields": ["name"]}, + {"fields": ["system"]}, ] } @@ -67,9 +67,10 @@ class UserRoleAssignmentDB(stormbase.StormFoundationDB): and "API" for API assignments. description: Optional assigment description. """ + user = me.StringField(required=True) - role = me.StringField(required=True, unique_with=['user', 'source']) - source = me.StringField(required=True, unique_with=['user', 'role']) + role = me.StringField(required=True, unique_with=["user", "source"]) + source = me.StringField(required=True, unique_with=["user", "role"]) description = me.StringField() # True if this is assigned created on authentication based on the remote groups provided by # the auth backends. @@ -78,12 +79,12 @@ class UserRoleAssignmentDB(stormbase.StormFoundationDB): is_remote = me.BooleanField(default=False) meta = { - 'indexes': [ - {'fields': ['user']}, - {'fields': ['role']}, - {'fields': ['source']}, - {'fields': ['is_remote']}, - {'fields': ['user', 'role']}, + "indexes": [ + {"fields": ["user"]}, + {"fields": ["role"]}, + {"fields": ["source"]}, + {"fields": ["is_remote"]}, + {"fields": ["user", "role"]}, ] } @@ -98,13 +99,14 @@ class PermissionGrantDB(stormbase.StormFoundationDB): convenience and to allow for more efficient queries. permission_types: A list of permission type granted to that resources. """ + resource_uid = me.StringField(required=False) resource_type = me.StringField(required=False) permission_types = me.ListField(field=me.StringField()) meta = { - 'indexes': [ - {'fields': ['resource_uid']}, + "indexes": [ + {"fields": ["resource_uid"]}, ] } @@ -120,12 +122,16 @@ class GroupToRoleMappingDB(stormbase.StormFoundationDB): and "API" for API assignments. description: Optional description for this mapping. """ + group = me.StringField(required=True, unique=True) roles = me.ListField(field=me.StringField()) source = me.StringField() description = me.StringField() - enabled = me.BooleanField(required=True, default=True, - help_text='A flag indicating whether the mapping is enabled.') + enabled = me.BooleanField( + required=True, + default=True, + help_text="A flag indicating whether the mapping is enabled.", + ) # Specialized access objects diff --git a/st2common/st2common/models/db/reactor.py b/st2common/st2common/models/db/reactor.py index dc9f08b58ef..8b8032654b1 100644 --- a/st2common/st2common/models/db/reactor.py +++ b/st2common/st2common/models/db/reactor.py @@ -14,18 +14,17 @@ # limitations under the License. from __future__ import absolute_import -from st2common.models.db.rule import (ActionExecutionSpecDB, RuleDB) +from st2common.models.db.rule import ActionExecutionSpecDB, RuleDB from st2common.models.db.sensor import SensorTypeDB -from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB, TriggerInstanceDB) +from st2common.models.db.trigger import TriggerDB, TriggerTypeDB, TriggerInstanceDB __all__ = [ - 'ActionExecutionSpecDB', - 'RuleDB', - 'SensorTypeDB', - 'TriggerTypeDB', - 'TriggerDB', - 'TriggerInstanceDB' + "ActionExecutionSpecDB", + "RuleDB", + "SensorTypeDB", + "TriggerTypeDB", + "TriggerDB", + "TriggerInstanceDB", ] -MODELS = [RuleDB, SensorTypeDB, TriggerDB, TriggerInstanceDB, - TriggerTypeDB] +MODELS = [RuleDB, SensorTypeDB, TriggerDB, TriggerInstanceDB, TriggerTypeDB] diff --git a/st2common/st2common/models/db/rule.py b/st2common/st2common/models/db/rule.py index f056734f8cb..f4f26ec6690 100644 --- a/st2common/st2common/models/db/rule.py +++ b/st2common/st2common/models/db/rule.py @@ -28,25 +28,24 @@ class RuleTypeDB(stormbase.StormBaseDB): enabled = me.BooleanField( default=True, - help_text='A flag indicating whether the runner for this type is enabled.') + help_text="A flag indicating whether the runner for this type is enabled.", + ) parameters = me.DictField( - help_text='The specification for parameters for the action.', - default={}) + help_text="The specification for parameters for the action.", default={} + ) class RuleTypeSpecDB(me.EmbeddedDocument): - ref = me.StringField(unique=False, - help_text='Type of rule.', - default='standard') + ref = me.StringField(unique=False, help_text="Type of rule.", default="standard") parameters = me.DictField(default={}) def __str__(self): result = [] - result.append('RuleTypeSpecDB@') + result.append("RuleTypeSpecDB@") result.append(str(id(self))) result.append('(ref="%s", ' % self.ref) result.append('parameters="%s")' % self.parameters) - return ''.join(result) + return "".join(result) class ActionExecutionSpecDB(me.EmbeddedDocument): @@ -55,15 +54,19 @@ class ActionExecutionSpecDB(me.EmbeddedDocument): def __str__(self): result = [] - result.append('ActionExecutionSpecDB@') + result.append("ActionExecutionSpecDB@") result.append(str(id(self))) result.append('(ref="%s", ' % self.ref) result.append('parameters="%s")' % self.parameters) - return ''.join(result) + return "".join(result) -class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin, - stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin): +class RuleDB( + stormbase.StormFoundationDB, + stormbase.TagsMixin, + stormbase.ContentPackResourceMixin, + stormbase.UIDFieldMixin, +): """Specifies the action to invoke on the occurrence of a Trigger. It also includes the transformation to perform to match the impedance between the payload of a TriggerInstance and input of a action. @@ -74,36 +77,39 @@ class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin, status: enabled or disabled. If disabled occurrence of the trigger does not lead to execution of a action and vice-versa. """ + RESOURCE_TYPE = ResourceType.RULE - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) ref = me.StringField(required=True) description = me.StringField() pack = me.StringField( - required=False, - help_text='Name of the content pack.', - unique_with='name') + required=False, help_text="Name of the content pack.", unique_with="name" + ) type = me.EmbeddedDocumentField(RuleTypeSpecDB, default=RuleTypeSpecDB()) trigger = me.StringField() criteria = stormbase.EscapedDictField() action = me.EmbeddedDocumentField(ActionExecutionSpecDB) - context = me.DictField( - default={}, - help_text='Contextual info on the rule' + context = me.DictField(default={}, help_text="Contextual info on the rule") + enabled = me.BooleanField( + required=True, + default=True, + help_text="Flag indicating whether the rule is enabled.", ) - enabled = me.BooleanField(required=True, default=True, - help_text=u'Flag indicating whether the rule is enabled.') meta = { - 'indexes': [ - {'fields': ['enabled']}, - {'fields': ['action.ref']}, - {'fields': ['trigger']}, - {'fields': ['context.user']}, - ] + (stormbase.ContentPackResourceMixin.get_indexes() + - stormbase.TagsMixin.get_indexes() + - stormbase.UIDFieldMixin.get_indexes()) + "indexes": [ + {"fields": ["enabled"]}, + {"fields": ["action.ref"]}, + {"fields": ["trigger"]}, + {"fields": ["context.user"]}, + ] + + ( + stormbase.ContentPackResourceMixin.get_indexes() + + stormbase.TagsMixin.get_indexes() + + stormbase.UIDFieldMixin.get_indexes() + ) } def mask_secrets(self, value): @@ -120,7 +126,7 @@ def mask_secrets(self, value): """ result = copy.deepcopy(value) - action_ref = result.get('action', {}).get('ref', None) + action_ref = result.get("action", {}).get("ref", None) if not action_ref: return result @@ -131,9 +137,10 @@ def mask_secrets(self, value): return result secret_parameters = get_secret_parameters(parameters=action_db.parameters) - result['action']['parameters'] = mask_secret_parameters( - parameters=result['action']['parameters'], - secret_parameters=secret_parameters) + result["action"]["parameters"] = mask_secret_parameters( + parameters=result["action"]["parameters"], + secret_parameters=secret_parameters, + ) return result @@ -147,8 +154,9 @@ def _get_referenced_action_model(self, action_ref): :rtype: ``ActionDB`` """ # NOTE: We need to retrieve pack and name since that's needed for the PK - action_dbs = Action.query(only_fields=['pack', 'ref', 'name', 'parameters'], - ref=action_ref, limit=1) + action_dbs = Action.query( + only_fields=["pack", "ref", "name", "parameters"], ref=action_ref, limit=1 + ) if action_dbs: return action_dbs[0] diff --git a/st2common/st2common/models/db/rule_enforcement.py b/st2common/st2common/models/db/rule_enforcement.py index 80ea1f14fef..62d2a21faf0 100644 --- a/st2common/st2common/models/db/rule_enforcement.py +++ b/st2common/st2common/models/db/rule_enforcement.py @@ -24,34 +24,27 @@ from st2common.constants.rule_enforcement import RULE_ENFORCEMENT_STATUS_SUCCEEDED from st2common.constants.rule_enforcement import RULE_ENFORCEMENT_STATUS_FAILED -__all__ = [ - 'RuleReferenceSpecDB', - 'RuleEnforcementDB' -] +__all__ = ["RuleReferenceSpecDB", "RuleEnforcementDB"] class RuleReferenceSpecDB(me.EmbeddedDocument): - ref = me.StringField(unique=False, - help_text='Reference to rule.', - required=True) - id = me.StringField(required=False, - help_text='Rule ID.') - uid = me.StringField(required=True, - help_text='Rule UID.') + ref = me.StringField(unique=False, help_text="Reference to rule.", required=True) + id = me.StringField(required=False, help_text="Rule ID.") + uid = me.StringField(required=True, help_text="Rule UID.") def __str__(self): result = [] - result.append('RuleReferenceSpecDB@') + result.append("RuleReferenceSpecDB@") result.append(str(id(self))) result.append('(ref="%s", ' % self.ref) result.append('id="%s", ' % self.id) result.append('uid="%s")' % self.uid) - return ''.join(result) + return "".join(result) class RuleEnforcementDB(stormbase.StormFoundationDB, stormbase.TagsMixin): - UID_FIELDS = ['id'] + UID_FIELDS = ["id"] trigger_instance_id = me.StringField(required=True) execution_id = me.StringField(required=False) @@ -59,31 +52,34 @@ class RuleEnforcementDB(stormbase.StormFoundationDB, stormbase.TagsMixin): rule = me.EmbeddedDocumentField(RuleReferenceSpecDB, required=True) enforced_at = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the rule enforcement happened.') + help_text="The timestamp when the rule enforcement happened.", + ) status = me.StringField( required=True, default=RULE_ENFORCEMENT_STATUS_SUCCEEDED, - help_text='Rule enforcement status.') + help_text="Rule enforcement status.", + ) meta = { - 'indexes': [ - {'fields': ['trigger_instance_id']}, - {'fields': ['execution_id']}, - {'fields': ['rule.id']}, - {'fields': ['rule.ref']}, - {'fields': ['enforced_at']}, - {'fields': ['-enforced_at']}, - {'fields': ['-enforced_at', 'rule.ref']}, - {'fields': ['status']}, - ] + stormbase.TagsMixin.get_indexes() + "indexes": [ + {"fields": ["trigger_instance_id"]}, + {"fields": ["execution_id"]}, + {"fields": ["rule.id"]}, + {"fields": ["rule.ref"]}, + {"fields": ["enforced_at"]}, + {"fields": ["-enforced_at"]}, + {"fields": ["-enforced_at", "rule.ref"]}, + {"fields": ["status"]}, + ] + + stormbase.TagsMixin.get_indexes() } def __init__(self, *args, **values): super(RuleEnforcementDB, self).__init__(*args, **values) # Set status to succeeded for old / existing RuleEnforcementDB which predate status field - status = getattr(self, 'status', None) - failure_reason = getattr(self, 'failure_reason', None) + status = getattr(self, "status", None) + failure_reason = getattr(self, "failure_reason", None) if status in [None, RULE_ENFORCEMENT_STATUS_SUCCEEDED] and failure_reason: self.status = RULE_ENFORCEMENT_STATUS_FAILED @@ -92,8 +88,8 @@ def __init__(self, *args, **values): # with a consistent get_uid interface. def get_uid(self): # TODO Construct uid from non id field: - uid = [self.RESOURCE_TYPE, str(self.id)] # pylint: disable=E1101 - return ':'.join(uid) + uid = [self.RESOURCE_TYPE, str(self.id)] # pylint: disable=E1101 + return ":".join(uid) rule_enforcement_access = MongoDBAccess(RuleEnforcementDB) diff --git a/st2common/st2common/models/db/runner.py b/st2common/st2common/models/db/runner.py index c2f290f5b45..9097d35be69 100644 --- a/st2common/st2common/models/db/runner.py +++ b/st2common/st2common/models/db/runner.py @@ -22,13 +22,13 @@ from st2common.constants.types import ResourceType __all__ = [ - 'RunnerTypeDB', + "RunnerTypeDB", ] LOG = logging.getLogger(__name__) -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." class RunnerTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): @@ -46,31 +46,37 @@ class RunnerTypeDB(stormbase.StormBaseDB, stormbase.UIDFieldMixin): """ RESOURCE_TYPE = ResourceType.RUNNER_TYPE - UID_FIELDS = ['name'] + UID_FIELDS = ["name"] enabled = me.BooleanField( - required=True, default=True, - help_text='A flag indicating whether the runner for this type is enabled.') + required=True, + default=True, + help_text="A flag indicating whether the runner for this type is enabled.", + ) runner_package = me.StringField( required=False, - help_text=('The python package that implements the action runner for this type. If' - 'not provided it assumes package name equals module name.')) + help_text=( + "The python package that implements the action runner for this type. If" + "not provided it assumes package name equals module name." + ), + ) runner_module = me.StringField( required=True, - help_text='The python module that implements the action runner for this type.') + help_text="The python module that implements the action runner for this type.", + ) runner_parameters = me.DictField( - help_text='The specification for parameters for the action runner.') + help_text="The specification for parameters for the action runner." + ) output_key = me.StringField( - help_text='Default key to expect results to be published to.') - output_schema = me.DictField( - help_text='The schema for runner output.') + help_text="Default key to expect results to be published to." + ) + output_schema = me.DictField(help_text="The schema for runner output.") query_module = me.StringField( required=False, - help_text='The python module that implements the query module for this runner.') + help_text="The python module that implements the query module for this runner.", + ) - meta = { - 'indexes': stormbase.UIDFieldMixin.get_indexes() - } + meta = {"indexes": stormbase.UIDFieldMixin.get_indexes()} def __init__(self, *args, **values): super(RunnerTypeDB, self).__init__(*args, **values) diff --git a/st2common/st2common/models/db/sensor.py b/st2common/st2common/models/db/sensor.py index 6517fb3a758..31437ad3215 100644 --- a/st2common/st2common/models/db/sensor.py +++ b/st2common/st2common/models/db/sensor.py @@ -20,13 +20,12 @@ from st2common.models.db import stormbase from st2common.constants.types import ResourceType -__all__ = [ - 'SensorTypeDB' -] +__all__ = ["SensorTypeDB"] -class SensorTypeDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, - stormbase.UIDFieldMixin): +class SensorTypeDB( + stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin +): """ Description of a specific type of a sensor (think of it as a sensor template). @@ -40,25 +39,29 @@ class SensorTypeDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, """ RESOURCE_TYPE = ResourceType.SENSOR_TYPE - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) ref = me.StringField(required=True) - pack = me.StringField(required=True, unique_with='name') + pack = me.StringField(required=True, unique_with="name") artifact_uri = me.StringField() entry_point = me.StringField() trigger_types = me.ListField(field=me.StringField()) poll_interval = me.IntField() - enabled = me.BooleanField(default=True, - help_text=u'Flag indicating whether the sensor is enabled.') + enabled = me.BooleanField( + default=True, help_text="Flag indicating whether the sensor is enabled." + ) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['enabled']}, - {'fields': ['trigger_types']}, - ] + (stormbase.ContentPackResourceMixin.get_indexes() + - stormbase.UIDFieldMixin.get_indexes()) + "indexes": [ + {"fields": ["name"]}, + {"fields": ["enabled"]}, + {"fields": ["trigger_types"]}, + ] + + ( + stormbase.ContentPackResourceMixin.get_indexes() + + stormbase.UIDFieldMixin.get_indexes() + ) } def __init__(self, *args, **values): diff --git a/st2common/st2common/models/db/stormbase.py b/st2common/st2common/models/db/stormbase.py index bf312c6e4f5..50f79dde781 100644 --- a/st2common/st2common/models/db/stormbase.py +++ b/st2common/st2common/models/db/stormbase.py @@ -29,17 +29,15 @@ from st2common.constants.types import ResourceType __all__ = [ - 'StormFoundationDB', - 'StormBaseDB', - - 'EscapedDictField', - 'EscapedDynamicField', - 'TagField', - - 'RefFieldMixin', - 'UIDFieldMixin', - 'TagsMixin', - 'ContentPackResourceMixin' + "StormFoundationDB", + "StormBaseDB", + "EscapedDictField", + "EscapedDynamicField", + "TagField", + "RefFieldMixin", + "UIDFieldMixin", + "TagsMixin", + "ContentPackResourceMixin", ] JSON_UNFRIENDLY_TYPES = (datetime.datetime, bson.ObjectId) @@ -62,17 +60,19 @@ class StormFoundationDB(me.Document, DictSerializableClassMixin): # don't do that # see http://docs.mongoengine.org/guide/defining-documents.html#abstract-classes - meta = { - 'abstract': True - } + meta = {"abstract": True} def __str__(self): attrs = list() - for k in sorted(self._fields.keys()): # pylint: disable=E1101 + for k in sorted(self._fields.keys()): # pylint: disable=E1101 v = getattr(self, k) - v = '"%s"' % str(v) if type(v) in [str, six.text_type, datetime.datetime] else str(v) - attrs.append('%s=%s' % (k, v)) - return '%s(%s)' % (self.__class__.__name__, ', '.join(attrs)) + v = ( + '"%s"' % str(v) + if type(v) in [str, six.text_type, datetime.datetime] + else str(v) + ) + attrs.append("%s=%s" % (k, v)) + return "%s(%s)" % (self.__class__.__name__, ", ".join(attrs)) def get_resource_type(self): return self.RESOURCE_TYPE @@ -98,7 +98,7 @@ def to_serializable_dict(self, mask_secrets=False): :rtype: ``dict`` """ serializable_dict = {} - for k in sorted(six.iterkeys(self._fields)): # pylint: disable=E1101 + for k in sorted(six.iterkeys(self._fields)): # pylint: disable=E1101 v = getattr(self, k) if isinstance(v, JSON_UNFRIENDLY_TYPES): v = str(v) @@ -120,17 +120,15 @@ class StormBaseDB(StormFoundationDB): description = me.StringField() # see http://docs.mongoengine.org/guide/defining-documents.html#abstract-classes - meta = { - 'abstract': True - } + meta = {"abstract": True} class EscapedDictField(me.DictField): - def to_mongo(self, value, use_db_field=True, fields=None): value = mongoescape.escape_chars(value) - return super(EscapedDictField, self).to_mongo(value=value, use_db_field=use_db_field, - fields=fields) + return super(EscapedDictField, self).to_mongo( + value=value, use_db_field=use_db_field, fields=fields + ) def to_python(self, value): value = super(EscapedDictField, self).to_python(value) @@ -138,18 +136,18 @@ def to_python(self, value): def validate(self, value): if not isinstance(value, dict): - self.error('Only dictionaries may be used in a DictField') + self.error("Only dictionaries may be used in a DictField") if me.fields.key_not_string(value): self.error("Invalid dictionary key - documents must have only string keys") me.base.ComplexBaseField.validate(self, value) class EscapedDynamicField(me.DynamicField): - def to_mongo(self, value, use_db_field=True, fields=None): value = mongoescape.escape_chars(value) - return super(EscapedDynamicField, self).to_mongo(value=value, use_db_field=use_db_field, - fields=fields) + return super(EscapedDynamicField, self).to_mongo( + value=value, use_db_field=use_db_field, fields=fields + ) def to_python(self, value): value = super(EscapedDynamicField, self).to_python(value) @@ -161,6 +159,7 @@ class TagField(me.EmbeddedDocument): To be attached to a db model object for the purpose of providing supplemental information. """ + name = me.StringField(max_length=1024) value = me.StringField(max_length=1024) @@ -169,11 +168,12 @@ class TagsMixin(object): """ Mixin to include tags on an object. """ + tags = me.ListField(field=me.EmbeddedDocumentField(TagField)) @classmethod def get_indexes(cls): - return ['tags.name', 'tags.value'] + return ["tags.name", "tags.value"] class RefFieldMixin(object): @@ -192,7 +192,7 @@ class UIDFieldMixin(object): the system. """ - UID_SEPARATOR = ':' # TODO: Move to constants + UID_SEPARATOR = ":" # TODO: Move to constants RESOURCE_TYPE = abc.abstractproperty UID_FIELDS = abc.abstractproperty @@ -205,13 +205,7 @@ def get_indexes(cls): # models in the database before ensure_indexes() is called. # This field gets populated in the constructor which means it will be lazily assigned next # time the model is saved (e.g. once register-content is ran). - indexes = [ - { - 'fields': ['uid'], - 'unique': True, - 'sparse': True - } - ] + indexes = [{"fields": ["uid"], "unique": True, "sparse": True}] return indexes def get_uid(self): @@ -224,7 +218,7 @@ def get_uid(self): parts.append(self.RESOURCE_TYPE) for field in self.UID_FIELDS: - value = getattr(self, field, None) or '' + value = getattr(self, field, None) or "" parts.append(value) uid = self.UID_SEPARATOR.join(parts) @@ -257,8 +251,11 @@ class ContentPackResourceMixin(object): metadata_file = me.StringField( required=False, - help_text=('Path to the metadata file (file on disk which contains resource definition) ' - 'relative to the pack directory.')) + help_text=( + "Path to the metadata file (file on disk which contains resource definition) " + "relative to the pack directory." + ), + ) def get_pack_uid(self): """ @@ -276,7 +273,7 @@ def get_reference(self): :rtype: :class:`ResourceReference` """ - if getattr(self, 'ref', None): + if getattr(self, "ref", None): ref = ResourceReference.from_string_reference(ref=self.ref) else: ref = ResourceReference(pack=self.pack, name=self.name) @@ -287,7 +284,7 @@ def get_reference(self): def get_indexes(cls): return [ { - 'fields': ['metadata_file'], + "fields": ["metadata_file"], } ] @@ -298,9 +295,4 @@ class ChangeRevisionFieldMixin(object): @classmethod def get_indexes(cls): - return [ - { - 'fields': ['id', 'rev'], - 'unique': True - } - ] + return [{"fields": ["id", "rev"], "unique": True}] diff --git a/st2common/st2common/models/db/timer.py b/st2common/st2common/models/db/timer.py index 98bb7952e1e..652d6a056a4 100644 --- a/st2common/st2common/models/db/timer.py +++ b/st2common/st2common/models/db/timer.py @@ -30,10 +30,10 @@ class TimerDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin): """ RESOURCE_TYPE = ResourceType.TIMER - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] name = me.StringField(required=True) - pack = me.StringField(required=True, unique_with='name') + pack = me.StringField(required=True, unique_with="name") type = me.StringField() parameters = me.DictField() diff --git a/st2common/st2common/models/db/trace.py b/st2common/st2common/models/db/trace.py index 00b7010d912..fe358e90c93 100644 --- a/st2common/st2common/models/db/trace.py +++ b/st2common/st2common/models/db/trace.py @@ -25,25 +25,24 @@ from st2common.models.db import MongoDBAccess -__all__ = [ - 'TraceDB', - 'TraceComponentDB' -] +__all__ = ["TraceDB", "TraceComponentDB"] class TraceComponentDB(me.EmbeddedDocument): - """ - """ + """""" + object_id = me.StringField() - ref = me.StringField(default='') + ref = me.StringField(default="") updated_at = ComplexDateTimeField( default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the TraceComponent was included.') - caused_by = me.DictField(help_text='Causal component.') + help_text="The timestamp when the TraceComponent was included.", + ) + caused_by = me.DictField(help_text="Causal component.") def __str__(self): - return 'TraceComponentDB@(object_id:{}, updated_at:{})'.format( - self.object_id, self.updated_at) + return "TraceComponentDB@(object_id:{}, updated_at:{})".format( + self.object_id, self.updated_at + ) class TraceDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin): @@ -66,28 +65,37 @@ class TraceDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin): RESOURCE_TYPE = ResourceType.TRACE - trace_tag = me.StringField(required=True, - help_text='A user specified reference to the trace.') - trigger_instances = me.ListField(field=me.EmbeddedDocumentField(TraceComponentDB), - required=False, - help_text='Associated TriggerInstances.') - rules = me.ListField(field=me.EmbeddedDocumentField(TraceComponentDB), - required=False, - help_text='Associated Rules.') - action_executions = me.ListField(field=me.EmbeddedDocumentField(TraceComponentDB), - required=False, - help_text='Associated ActionExecutions.') - start_timestamp = ComplexDateTimeField(default=date_utils.get_datetime_utc_now, - help_text='The timestamp when the Trace was created.') + trace_tag = me.StringField( + required=True, help_text="A user specified reference to the trace." + ) + trigger_instances = me.ListField( + field=me.EmbeddedDocumentField(TraceComponentDB), + required=False, + help_text="Associated TriggerInstances.", + ) + rules = me.ListField( + field=me.EmbeddedDocumentField(TraceComponentDB), + required=False, + help_text="Associated Rules.", + ) + action_executions = me.ListField( + field=me.EmbeddedDocumentField(TraceComponentDB), + required=False, + help_text="Associated ActionExecutions.", + ) + start_timestamp = ComplexDateTimeField( + default=date_utils.get_datetime_utc_now, + help_text="The timestamp when the Trace was created.", + ) meta = { - 'indexes': [ - {'fields': ['trace_tag']}, - {'fields': ['start_timestamp']}, - {'fields': ['action_executions.object_id']}, - {'fields': ['trigger_instances.object_id']}, - {'fields': ['rules.object_id']}, - {'fields': ['-start_timestamp', 'trace_tag']}, + "indexes": [ + {"fields": ["trace_tag"]}, + {"fields": ["start_timestamp"]}, + {"fields": ["action_executions.object_id"]}, + {"fields": ["trigger_instances.object_id"]}, + {"fields": ["rules.object_id"]}, + {"fields": ["-start_timestamp", "trace_tag"]}, ] } diff --git a/st2common/st2common/models/db/trigger.py b/st2common/st2common/models/db/trigger.py index 0546c3b739a..9b749c52418 100644 --- a/st2common/st2common/models/db/trigger.py +++ b/st2common/st2common/models/db/trigger.py @@ -24,16 +24,18 @@ from st2common.constants.types import ResourceType __all__ = [ - 'TriggerTypeDB', - 'TriggerDB', - 'TriggerInstanceDB', + "TriggerTypeDB", + "TriggerDB", + "TriggerInstanceDB", ] -class TriggerTypeDB(stormbase.StormBaseDB, - stormbase.ContentPackResourceMixin, - stormbase.UIDFieldMixin, - stormbase.TagsMixin): +class TriggerTypeDB( + stormbase.StormBaseDB, + stormbase.ContentPackResourceMixin, + stormbase.UIDFieldMixin, + stormbase.TagsMixin, +): """Description of a specific kind/type of a trigger. The (pack, name) tuple is expected uniquely identify a trigger in the namespace of all triggers provided by a specific trigger_source. @@ -45,18 +47,20 @@ class TriggerTypeDB(stormbase.StormBaseDB, """ RESOURCE_TYPE = ResourceType.TRIGGER_TYPE - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] ref = me.StringField(required=False) name = me.StringField(required=True) - pack = me.StringField(required=True, unique_with='name') + pack = me.StringField(required=True, unique_with="name") payload_schema = me.DictField() parameters_schema = me.DictField(default={}) meta = { - 'indexes': (stormbase.ContentPackResourceMixin.get_indexes() + - stormbase.TagsMixin.get_indexes() + - stormbase.UIDFieldMixin.get_indexes()) + "indexes": ( + stormbase.ContentPackResourceMixin.get_indexes() + + stormbase.TagsMixin.get_indexes() + + stormbase.UIDFieldMixin.get_indexes() + ) } def __init__(self, *args, **values): @@ -66,8 +70,9 @@ def __init__(self, *args, **values): self.uid = self.get_uid() -class TriggerDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, - stormbase.UIDFieldMixin): +class TriggerDB( + stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, stormbase.UIDFieldMixin +): """ Attribute: name - Trigger name. @@ -77,21 +82,22 @@ class TriggerDB(stormbase.StormBaseDB, stormbase.ContentPackResourceMixin, """ RESOURCE_TYPE = ResourceType.TRIGGER - UID_FIELDS = ['pack', 'name'] + UID_FIELDS = ["pack", "name"] ref = me.StringField(required=False) name = me.StringField(required=True) - pack = me.StringField(required=True, unique_with='name') + pack = me.StringField(required=True, unique_with="name") type = me.StringField() parameters = me.DictField() ref_count = me.IntField(default=0) meta = { - 'indexes': [ - {'fields': ['name']}, - {'fields': ['type']}, - {'fields': ['parameters']}, - ] + stormbase.UIDFieldMixin.get_indexes() + "indexes": [ + {"fields": ["name"]}, + {"fields": ["type"]}, + {"fields": ["parameters"]}, + ] + + stormbase.UIDFieldMixin.get_indexes() } def __init__(self, *args, **values): @@ -106,7 +112,7 @@ def get_uid(self): # Note: We sort the resulting JSON object so that the same dictionary always results # in the same hash - parameters = getattr(self, 'parameters', {}) + parameters = getattr(self, "parameters", {}) parameters = json.dumps(parameters, sort_keys=True) parameters = hashlib.md5(parameters.encode()).hexdigest() @@ -126,19 +132,20 @@ class TriggerInstanceDB(stormbase.StormFoundationDB): payload (dict): payload specific to the occurrence. occurrence_time (datetime): time of occurrence of the trigger. """ + trigger = me.StringField() payload = stormbase.EscapedDictField() occurrence_time = me.DateTimeField() status = me.StringField( - required=True, - help_text='Processing status of TriggerInstance.') + required=True, help_text="Processing status of TriggerInstance." + ) meta = { - 'indexes': [ - {'fields': ['occurrence_time']}, - {'fields': ['trigger']}, - {'fields': ['-occurrence_time', 'trigger']}, - {'fields': ['status']} + "indexes": [ + {"fields": ["occurrence_time"]}, + {"fields": ["trigger"]}, + {"fields": ["-occurrence_time", "trigger"]}, + {"fields": ["status"]}, ] } diff --git a/st2common/st2common/models/db/webhook.py b/st2common/st2common/models/db/webhook.py index 0ef2906b909..b608f6c3553 100644 --- a/st2common/st2common/models/db/webhook.py +++ b/st2common/st2common/models/db/webhook.py @@ -29,7 +29,7 @@ class WebhookDB(stormbase.StormFoundationDB, stormbase.UIDFieldMixin): """ RESOURCE_TYPE = ResourceType.WEBHOOK - UID_FIELDS = ['name'] + UID_FIELDS = ["name"] name = me.StringField(required=True) @@ -40,7 +40,7 @@ def __init__(self, *args, **values): def _normalize_name(self, name): # Remove trailing slash if present - if name.endswith('/'): + if name.endswith("/"): name = name[:-1] return name diff --git a/st2common/st2common/models/db/workflow.py b/st2common/st2common/models/db/workflow.py index dc73c1c55c0..fd5cdb111eb 100644 --- a/st2common/st2common/models/db/workflow.py +++ b/st2common/st2common/models/db/workflow.py @@ -24,16 +24,15 @@ from st2common.util import date as date_utils -__all__ = [ - 'WorkflowExecutionDB', - 'TaskExecutionDB' -] +__all__ = ["WorkflowExecutionDB", "TaskExecutionDB"] LOG = logging.getLogger(__name__) -class WorkflowExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin): +class WorkflowExecutionDB( + stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin +): RESOURCE_TYPE = types.ResourceType.EXECUTION action_execution = me.StringField(required=True) @@ -46,14 +45,12 @@ class WorkflowExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionF status = me.StringField(required=True) output = stormbase.EscapedDictField() errors = stormbase.EscapedDynamicField() - start_timestamp = db_field_types.ComplexDateTimeField(default=date_utils.get_datetime_utc_now) + start_timestamp = db_field_types.ComplexDateTimeField( + default=date_utils.get_datetime_utc_now + ) end_timestamp = db_field_types.ComplexDateTimeField() - meta = { - 'indexes': [ - {'fields': ['action_execution']} - ] - } + meta = {"indexes": [{"fields": ["action_execution"]}]} class TaskExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionFieldMixin): @@ -71,21 +68,20 @@ class TaskExecutionDB(stormbase.StormFoundationDB, stormbase.ChangeRevisionField context = stormbase.EscapedDictField() status = me.StringField(required=True) result = stormbase.EscapedDictField() - start_timestamp = db_field_types.ComplexDateTimeField(default=date_utils.get_datetime_utc_now) + start_timestamp = db_field_types.ComplexDateTimeField( + default=date_utils.get_datetime_utc_now + ) end_timestamp = db_field_types.ComplexDateTimeField() meta = { - 'indexes': [ - {'fields': ['workflow_execution']}, - {'fields': ['task_id']}, - {'fields': ['task_id', 'task_route']}, - {'fields': ['workflow_execution', 'task_id']}, - {'fields': ['workflow_execution', 'task_id', 'task_route']} + "indexes": [ + {"fields": ["workflow_execution"]}, + {"fields": ["task_id"]}, + {"fields": ["task_id", "task_route"]}, + {"fields": ["workflow_execution", "task_id"]}, + {"fields": ["workflow_execution", "task_id", "task_route"]}, ] } -MODELS = [ - WorkflowExecutionDB, - TaskExecutionDB -] +MODELS = [WorkflowExecutionDB, TaskExecutionDB] diff --git a/st2common/st2common/models/system/action.py b/st2common/st2common/models/system/action.py index b5efe124f5e..2afcbf649b5 100644 --- a/st2common/st2common/models/system/action.py +++ b/st2common/st2common/models/system/action.py @@ -35,11 +35,11 @@ from st2common.constants.secrets import MASKED_ATTRIBUTE_VALUE __all__ = [ - 'ShellCommandAction', - 'ShellScriptAction', - 'RemoteAction', - 'RemoteScriptAction', - 'ResolvedActionParameters' + "ShellCommandAction", + "ShellScriptAction", + "RemoteAction", + "RemoteScriptAction", + "ResolvedActionParameters", ] LOG = logging.getLogger(__name__) @@ -48,21 +48,31 @@ # Flags which are passed to every sudo invocation SUDO_COMMON_OPTIONS = [ - '-E' # we want to preserve the environment of the user which ran sudo -] + "-E" +] # we want to preserve the environment of the user which ran sudo # Flags which are only passed to sudo when not running as current user and when # -u flag is used SUDO_DIFFERENT_USER_OPTIONS = [ - '-H' # we want $HOME to reflect the home directory of the requested / target user + "-H" # we want $HOME to reflect the home directory of the requested / target user ] class ShellCommandAction(object): - EXPORT_CMD = 'export' - - def __init__(self, name, action_exec_id, command, user, env_vars=None, sudo=False, - timeout=None, cwd=None, sudo_password=None): + EXPORT_CMD = "export" + + def __init__( + self, + name, + action_exec_id, + command, + user, + env_vars=None, + sudo=False, + timeout=None, + cwd=None, + sudo_password=None, + ): self.name = name self.action_exec_id = action_exec_id self.command = command @@ -77,15 +87,15 @@ def get_full_command_string(self): # Note: We pass -E to sudo because we want to preserve user provided environment variables if self.sudo: command = quote_unix(self.command) - sudo_arguments = ' '.join(self._get_common_sudo_arguments()) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + sudo_arguments = " ".join(self._get_common_sudo_arguments()) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) else: if self.user and self.user != LOGGED_USER_USERNAME: # Need to use sudo to run as a different (requested) user user = quote_unix(self.user) - sudo_arguments = ' '.join(self._get_user_sudo_arguments(user=user)) + sudo_arguments = " ".join(self._get_user_sudo_arguments(user=user)) command = quote_unix(self.command) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) else: command = self.command @@ -103,7 +113,10 @@ def get_sanitized_full_command_string(self): if self.sudo_password: # Mask sudo password - command_string = 'echo -e \'%s\n\' | %s' % (MASKED_ATTRIBUTE_VALUE, command_string) + command_string = "echo -e '%s\n' | %s" % ( + MASKED_ATTRIBUTE_VALUE, + command_string, + ) return command_string @@ -124,7 +137,7 @@ def _get_common_sudo_arguments(self): if self.sudo_password: # Note: We use subprocess.Popen in local runner so we provide password via subprocess # stdin (using echo -e won't work when using subprocess.Popen) - flags.append('-S') + flags.append("-S") flags = flags + SUDO_COMMON_OPTIONS @@ -139,7 +152,7 @@ def _get_user_sudo_arguments(self, user): """ flags = self._get_common_sudo_arguments() flags += SUDO_DIFFERENT_USER_OPTIONS - flags += ['-u', user] + flags += ["-u", user] return flags @@ -150,21 +163,21 @@ def _get_env_vars_export_string(self): # If sudo_password is provided, explicitly disable bash history to make sure password # is not logged, because password is provided via command line if self.sudo and self.sudo_password: - env_vars['HISTFILE'] = '/dev/null' - env_vars['HISTSIZE'] = '0' + env_vars["HISTFILE"] = "/dev/null" + env_vars["HISTSIZE"] = "0" # Sort the dict to guarantee consistent order env_vars = collections.OrderedDict(sorted(env_vars.items())) # Environment variables could contain spaces and open us to shell # injection attacks. Always quote the key and the value. - exports = ' '.join( - '%s=%s' % (quote_unix(k), quote_unix(v)) + exports = " ".join( + "%s=%s" % (quote_unix(k), quote_unix(v)) for k, v in six.iteritems(env_vars) ) - shell_env_str = '%s %s' % (ShellCommandAction.EXPORT_CMD, exports) + shell_env_str = "%s %s" % (ShellCommandAction.EXPORT_CMD, exports) else: - shell_env_str = '' + shell_env_str = "" return shell_env_str @@ -180,8 +193,8 @@ def _get_command_string(self, cmd, args): assert isinstance(args, (list, tuple)) args = [quote_unix(arg) for arg in args] - args = ' '.join(args) - result = '%s %s' % (cmd, args) + args = " ".join(args) + result = "%s %s" % (cmd, args) return result def _get_error_result(self): @@ -195,24 +208,42 @@ def _get_error_result(self): _, exc_value, exc_traceback = sys.exc_info() exc_value = str(exc_value) - exc_traceback = ''.join(traceback.format_tb(exc_traceback)) + exc_traceback = "".join(traceback.format_tb(exc_traceback)) result = {} - result['failed'] = True - result['succeeded'] = False - result['error'] = exc_value - result['traceback'] = exc_traceback + result["failed"] = True + result["succeeded"] = False + result["error"] = exc_value + result["traceback"] = exc_traceback return result class ShellScriptAction(ShellCommandAction): - def __init__(self, name, action_exec_id, script_local_path_abs, named_args=None, - positional_args=None, env_vars=None, user=None, sudo=False, timeout=None, - cwd=None, sudo_password=None): - super(ShellScriptAction, self).__init__(name=name, action_exec_id=action_exec_id, - command=None, user=user, env_vars=env_vars, - sudo=sudo, timeout=timeout, - cwd=cwd, sudo_password=sudo_password) + def __init__( + self, + name, + action_exec_id, + script_local_path_abs, + named_args=None, + positional_args=None, + env_vars=None, + user=None, + sudo=False, + timeout=None, + cwd=None, + sudo_password=None, + ): + super(ShellScriptAction, self).__init__( + name=name, + action_exec_id=action_exec_id, + command=None, + user=user, + env_vars=env_vars, + sudo=sudo, + timeout=timeout, + cwd=cwd, + sudo_password=sudo_password, + ) self.script_local_path_abs = script_local_path_abs self.named_args = named_args self.positional_args = positional_args @@ -221,33 +252,38 @@ def get_full_command_string(self): return self._format_command() def _format_command(self): - script_arguments = self._get_script_arguments(named_args=self.named_args, - positional_args=self.positional_args) + script_arguments = self._get_script_arguments( + named_args=self.named_args, positional_args=self.positional_args + ) if self.sudo: if script_arguments: - command = quote_unix('%s %s' % (self.script_local_path_abs, script_arguments)) + command = quote_unix( + "%s %s" % (self.script_local_path_abs, script_arguments) + ) else: command = quote_unix(self.script_local_path_abs) - sudo_arguments = ' '.join(self._get_common_sudo_arguments()) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + sudo_arguments = " ".join(self._get_common_sudo_arguments()) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) else: if self.user and self.user != LOGGED_USER_USERNAME: # Need to use sudo to run as a different user user = quote_unix(self.user) if script_arguments: - command = quote_unix('%s %s' % (self.script_local_path_abs, script_arguments)) + command = quote_unix( + "%s %s" % (self.script_local_path_abs, script_arguments) + ) else: command = quote_unix(self.script_local_path_abs) - sudo_arguments = ' '.join(self._get_user_sudo_arguments(user=user)) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + sudo_arguments = " ".join(self._get_user_sudo_arguments(user=user)) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) else: script_path = quote_unix(self.script_local_path_abs) if script_arguments: - command = '%s %s' % (script_path, script_arguments) + command = "%s %s" % (script_path, script_arguments) else: command = script_path return command @@ -270,8 +306,10 @@ def _get_script_arguments(self, named_args=None, positional_args=None): # add all named_args in the format name=value (e.g. --name=value) if named_args is not None: for (arg, value) in six.iteritems(named_args): - if value is None or (isinstance(value, (str, six.text_type)) and len(value) < 1): - LOG.debug('Ignoring arg %s as its value is %s.', arg, value) + if value is None or ( + isinstance(value, (str, six.text_type)) and len(value) < 1 + ): + LOG.debug("Ignoring arg %s as its value is %s.", arg, value) continue if isinstance(value, bool): @@ -279,24 +317,45 @@ def _get_script_arguments(self, named_args=None, positional_args=None): command_parts.append(arg) else: values = (quote_unix(arg), quote_unix(six.text_type(value))) - command_parts.append(six.text_type('%s=%s' % values)) + command_parts.append(six.text_type("%s=%s" % values)) # add the positional args if positional_args: quoted_pos_args = [quote_unix(pos_arg) for pos_arg in positional_args] - pos_args_string = ' '.join(quoted_pos_args) + pos_args_string = " ".join(quoted_pos_args) command_parts.append(pos_args_string) - return ' '.join(command_parts) + return " ".join(command_parts) class SSHCommandAction(ShellCommandAction): - def __init__(self, name, action_exec_id, command, env_vars, user, password=None, pkey=None, - hosts=None, parallel=True, sudo=False, timeout=None, cwd=None, passphrase=None, - sudo_password=None): - super(SSHCommandAction, self).__init__(name=name, action_exec_id=action_exec_id, - command=command, env_vars=env_vars, user=user, - sudo=sudo, timeout=timeout, cwd=cwd, - sudo_password=sudo_password) + def __init__( + self, + name, + action_exec_id, + command, + env_vars, + user, + password=None, + pkey=None, + hosts=None, + parallel=True, + sudo=False, + timeout=None, + cwd=None, + passphrase=None, + sudo_password=None, + ): + super(SSHCommandAction, self).__init__( + name=name, + action_exec_id=action_exec_id, + command=command, + env_vars=env_vars, + user=user, + sudo=sudo, + timeout=timeout, + cwd=cwd, + sudo_password=sudo_password, + ) self.hosts = hosts self.parallel = parallel self.pkey = pkey @@ -329,25 +388,51 @@ def get_command(self): def __str__(self): str_rep = [] - str_rep.append('%s@%s(name: %s' % (self.__class__.__name__, id(self), self.name)) - str_rep.append('id: %s' % self.action_exec_id) - str_rep.append('command: %s' % self.command) - str_rep.append('user: %s' % self.user) - str_rep.append('sudo: %s' % str(self.sudo)) - str_rep.append('parallel: %s' % str(self.parallel)) - str_rep.append('hosts: %s)' % str(self.hosts)) - return ', '.join(str_rep) + str_rep.append( + "%s@%s(name: %s" % (self.__class__.__name__, id(self), self.name) + ) + str_rep.append("id: %s" % self.action_exec_id) + str_rep.append("command: %s" % self.command) + str_rep.append("user: %s" % self.user) + str_rep.append("sudo: %s" % str(self.sudo)) + str_rep.append("parallel: %s" % str(self.parallel)) + str_rep.append("hosts: %s)" % str(self.hosts)) + return ", ".join(str_rep) class RemoteAction(SSHCommandAction): - def __init__(self, name, action_exec_id, command, env_vars=None, on_behalf_user=None, - user=None, password=None, private_key=None, hosts=None, parallel=True, sudo=False, - timeout=None, cwd=None, passphrase=None, sudo_password=None): - super(RemoteAction, self).__init__(name=name, action_exec_id=action_exec_id, - command=command, env_vars=env_vars, user=user, - hosts=hosts, parallel=parallel, sudo=sudo, - timeout=timeout, cwd=cwd, passphrase=passphrase, - sudo_password=sudo_password) + def __init__( + self, + name, + action_exec_id, + command, + env_vars=None, + on_behalf_user=None, + user=None, + password=None, + private_key=None, + hosts=None, + parallel=True, + sudo=False, + timeout=None, + cwd=None, + passphrase=None, + sudo_password=None, + ): + super(RemoteAction, self).__init__( + name=name, + action_exec_id=action_exec_id, + command=command, + env_vars=env_vars, + user=user, + hosts=hosts, + parallel=parallel, + sudo=sudo, + timeout=timeout, + cwd=cwd, + passphrase=passphrase, + sudo_password=sudo_password, + ) self.password = password self.private_key = private_key self.passphrase = passphrase @@ -359,34 +444,61 @@ def get_on_behalf_user(self): def __str__(self): str_rep = [] - str_rep.append('%s@%s(name: %s' % (self.__class__.__name__, id(self), self.name)) - str_rep.append('id: %s' % self.action_exec_id) - str_rep.append('command: %s' % self.command) - str_rep.append('user: %s' % self.user) - str_rep.append('on_behalf_user: %s' % self.on_behalf_user) - str_rep.append('sudo: %s' % str(self.sudo)) - str_rep.append('parallel: %s' % str(self.parallel)) - str_rep.append('hosts: %s)' % str(self.hosts)) - str_rep.append('timeout: %s)' % str(self.timeout)) + str_rep.append( + "%s@%s(name: %s" % (self.__class__.__name__, id(self), self.name) + ) + str_rep.append("id: %s" % self.action_exec_id) + str_rep.append("command: %s" % self.command) + str_rep.append("user: %s" % self.user) + str_rep.append("on_behalf_user: %s" % self.on_behalf_user) + str_rep.append("sudo: %s" % str(self.sudo)) + str_rep.append("parallel: %s" % str(self.parallel)) + str_rep.append("hosts: %s)" % str(self.hosts)) + str_rep.append("timeout: %s)" % str(self.timeout)) - return ', '.join(str_rep) + return ", ".join(str_rep) class RemoteScriptAction(ShellScriptAction): - def __init__(self, name, action_exec_id, script_local_path_abs, script_local_libs_path_abs, - named_args=None, positional_args=None, env_vars=None, on_behalf_user=None, - user=None, password=None, private_key=None, remote_dir=None, hosts=None, - parallel=True, sudo=False, timeout=None, cwd=None, sudo_password=None): - super(RemoteScriptAction, self).__init__(name=name, action_exec_id=action_exec_id, - script_local_path_abs=script_local_path_abs, - user=user, - named_args=named_args, - positional_args=positional_args, env_vars=env_vars, - sudo=sudo, timeout=timeout, cwd=cwd, - sudo_password=sudo_password) + def __init__( + self, + name, + action_exec_id, + script_local_path_abs, + script_local_libs_path_abs, + named_args=None, + positional_args=None, + env_vars=None, + on_behalf_user=None, + user=None, + password=None, + private_key=None, + remote_dir=None, + hosts=None, + parallel=True, + sudo=False, + timeout=None, + cwd=None, + sudo_password=None, + ): + super(RemoteScriptAction, self).__init__( + name=name, + action_exec_id=action_exec_id, + script_local_path_abs=script_local_path_abs, + user=user, + named_args=named_args, + positional_args=positional_args, + env_vars=env_vars, + sudo=sudo, + timeout=timeout, + cwd=cwd, + sudo_password=sudo_password, + ) self.script_local_libs_path_abs = script_local_libs_path_abs - self.script_local_dir, self.script_name = os.path.split(self.script_local_path_abs) - self.remote_dir = remote_dir if remote_dir is not None else '/tmp' + self.script_local_dir, self.script_name = os.path.split( + self.script_local_path_abs + ) + self.remote_dir = remote_dir if remote_dir is not None else "/tmp" self.remote_libs_path_abs = os.path.join(self.remote_dir, ACTION_LIBS_DIR) self.on_behalf_user = on_behalf_user self.password = password @@ -395,7 +507,7 @@ def __init__(self, name, action_exec_id, script_local_path_abs, script_local_lib self.hosts = hosts self.parallel = parallel self.command = self._format_command() - LOG.debug('RemoteScriptAction: command to run on remote box: %s', self.command) + LOG.debug("RemoteScriptAction: command to run on remote box: %s", self.command) def get_remote_script_abs_path(self): return self.remote_script @@ -413,11 +525,12 @@ def get_remote_base_dir(self): return self.remote_dir def _format_command(self): - script_arguments = self._get_script_arguments(named_args=self.named_args, - positional_args=self.positional_args) + script_arguments = self._get_script_arguments( + named_args=self.named_args, positional_args=self.positional_args + ) if script_arguments: - command = '%s %s' % (self.remote_script, script_arguments) + command = "%s %s" % (self.remote_script, script_arguments) else: command = self.remote_script @@ -425,21 +538,23 @@ def _format_command(self): def __str__(self): str_rep = [] - str_rep.append('%s@%s(name: %s' % (self.__class__.__name__, id(self), self.name)) - str_rep.append('id: %s' % self.action_exec_id) - str_rep.append('local_script: %s' % self.script_local_path_abs) - str_rep.append('local_libs: %s' % self.script_local_libs_path_abs) - str_rep.append('remote_dir: %s' % self.remote_dir) - str_rep.append('remote_libs: %s' % self.remote_libs_path_abs) - str_rep.append('named_args: %s' % self.named_args) - str_rep.append('positional_args: %s' % self.positional_args) - str_rep.append('user: %s' % self.user) - str_rep.append('on_behalf_user: %s' % self.on_behalf_user) - str_rep.append('sudo: %s' % self.sudo) - str_rep.append('parallel: %s' % self.parallel) - str_rep.append('hosts: %s)' % self.hosts) - - return ', '.join(str_rep) + str_rep.append( + "%s@%s(name: %s" % (self.__class__.__name__, id(self), self.name) + ) + str_rep.append("id: %s" % self.action_exec_id) + str_rep.append("local_script: %s" % self.script_local_path_abs) + str_rep.append("local_libs: %s" % self.script_local_libs_path_abs) + str_rep.append("remote_dir: %s" % self.remote_dir) + str_rep.append("remote_libs: %s" % self.remote_libs_path_abs) + str_rep.append("named_args: %s" % self.named_args) + str_rep.append("positional_args: %s" % self.positional_args) + str_rep.append("user: %s" % self.user) + str_rep.append("on_behalf_user: %s" % self.on_behalf_user) + str_rep.append("sudo: %s" % self.sudo) + str_rep.append("parallel: %s" % self.parallel) + str_rep.append("hosts: %s)" % self.hosts) + + return ", ".join(str_rep) class ResolvedActionParameters(DictSerializableClassMixin): @@ -447,7 +562,9 @@ class ResolvedActionParameters(DictSerializableClassMixin): Class which contains resolved runner and action parameters for a particular action. """ - def __init__(self, action_db, runner_type_db, runner_parameters=None, action_parameters=None): + def __init__( + self, action_db, runner_type_db, runner_parameters=None, action_parameters=None + ): self._action_db = action_db self._runner_type_db = runner_type_db self._runner_parameters = runner_parameters @@ -456,28 +573,34 @@ def __init__(self, action_db, runner_type_db, runner_parameters=None, action_par def mask_secrets(self, value): result = copy.deepcopy(value) - runner_parameters = result['runner_parameters'] - action_parameters = result['action_parameters'] + runner_parameters = result["runner_parameters"] + action_parameters = result["action_parameters"] runner_parameters_specs = self._runner_type_db.runner_parameters action_parameters_sepcs = self._action_db.parameters - secret_runner_parameters = get_secret_parameters(parameters=runner_parameters_specs) - secret_action_parameters = get_secret_parameters(parameters=action_parameters_sepcs) - - runner_parameters = mask_secret_parameters(parameters=runner_parameters, - secret_parameters=secret_runner_parameters) - action_parameters = mask_secret_parameters(parameters=action_parameters, - secret_parameters=secret_action_parameters) - result['runner_parameters'] = runner_parameters - result['action_parameters'] = action_parameters + secret_runner_parameters = get_secret_parameters( + parameters=runner_parameters_specs + ) + secret_action_parameters = get_secret_parameters( + parameters=action_parameters_sepcs + ) + + runner_parameters = mask_secret_parameters( + parameters=runner_parameters, secret_parameters=secret_runner_parameters + ) + action_parameters = mask_secret_parameters( + parameters=action_parameters, secret_parameters=secret_action_parameters + ) + result["runner_parameters"] = runner_parameters + result["action_parameters"] = action_parameters return result def to_serializable_dict(self, mask_secrets=False): result = {} - result['runner_parameters'] = self._runner_parameters - result['action_parameters'] = self._action_parameters + result["runner_parameters"] = self._runner_parameters + result["action_parameters"] = self._action_parameters if mask_secrets and cfg.CONF.log.mask_secrets: result = self.mask_secrets(value=result) diff --git a/st2common/st2common/models/system/actionchain.py b/st2common/st2common/models/system/actionchain.py index 2c5ce24c3da..24a84cc6b6b 100644 --- a/st2common/st2common/models/system/actionchain.py +++ b/st2common/st2common/models/system/actionchain.py @@ -31,45 +31,45 @@ class Node(object): "name": { "description": "The name of this node.", "type": "string", - "required": True + "required": True, }, "ref": { "type": "string", "description": "Ref of the action to be executed.", - "required": True + "required": True, }, "params": { "type": "object", - "description": ("Parameter for the execution (old name, here for backward " - "compatibility reasons)."), - "default": {} + "description": ( + "Parameter for the execution (old name, here for backward " + "compatibility reasons)." + ), + "default": {}, }, "parameters": { "type": "object", "description": "Parameter for the execution.", - "default": {} + "default": {}, }, "on-success": { "type": "string", "description": "Name of the node to invoke on successful completion of action" - " executed for this node.", - "default": "" + " executed for this node.", + "default": "", }, "on-failure": { "type": "string", "description": "Name of the node to invoke on failure of action executed for this" - " node.", - "default": "" + " node.", + "default": "", }, "publish": { "description": "The variables to publish from the result. Should be of the form" - " name.foo. o1: {{node_name.foo}} will result in creation of a" - " variable o1 which is now available for reference through" - " remainder of the chain as a global variable.", + " name.foo. o1: {{node_name.foo}} will result in creation of a" + " variable o1 which is now available for reference through" + " remainder of the chain as a global variable.", "type": "object", - "patternProperties": { - r"^\w+$": {} - } + "patternProperties": {r"^\w+$": {}}, }, "notify": { "description": "Notification settings for action.", @@ -77,43 +77,49 @@ class Node(object): "properties": { "on-complete": NotificationSubSchemaAPI, "on-failure": NotificationSubSchemaAPI, - "on-success": NotificationSubSchemaAPI + "on-success": NotificationSubSchemaAPI, }, - "additionalProperties": False - } + "additionalProperties": False, + }, }, - "additionalProperties": False + "additionalProperties": False, } def __init__(self, **kw): - for prop in six.iterkeys(self.schema.get('properties', [])): + for prop in six.iterkeys(self.schema.get("properties", [])): value = kw.get(prop, None) # having '-' in the property name lead to challenges in referencing the property. # At hindsight the schema property should've been on_success rather than on-success. - prop = prop.replace('-', '_') + prop = prop.replace("-", "_") setattr(self, prop, value) def validate(self): - params = getattr(self, 'params', {}) - parameters = getattr(self, 'parameters', {}) + params = getattr(self, "params", {}) + parameters = getattr(self, "parameters", {}) if params and parameters: - msg = ('Either "params" or "parameters" attribute needs to be provided, but not ' - 'both') + msg = ( + 'Either "params" or "parameters" attribute needs to be provided, but not ' + "both" + ) raise ValueError(msg) return self def get_parameters(self): # Note: "params" is old deprecated attribute which will be removed in a future release - params = getattr(self, 'params', {}) - parameters = getattr(self, 'parameters', {}) + params = getattr(self, "params", {}) + parameters = getattr(self, "parameters", {}) return parameters or params def __repr__(self): - return ('' % - (self.name, self.ref, self.on_success, self.on_failure)) + return "" % ( + self.name, + self.ref, + self.on_success, + self.on_failure, + ) class ActionChain(object): @@ -127,31 +133,34 @@ class ActionChain(object): "description": "The chain.", "type": "array", "items": [Node.schema], - "required": True + "required": True, }, "default": { "type": "string", - "description": "name of the action to be executed." + "description": "name of the action to be executed.", }, "vars": { "description": "", "type": "object", - "patternProperties": { - r"^\w+$": {} - } - } + "patternProperties": {r"^\w+$": {}}, + }, }, - "additionalProperties": False + "additionalProperties": False, } def __init__(self, **kw): - util_schema.validate(instance=kw, schema=self.schema, cls=util_schema.CustomValidator, - use_default=False, allow_default_none=True) - - for prop in six.iterkeys(self.schema.get('properties', [])): + util_schema.validate( + instance=kw, + schema=self.schema, + cls=util_schema.CustomValidator, + use_default=False, + allow_default_none=True, + ) + + for prop in six.iterkeys(self.schema.get("properties", [])): value = kw.get(prop, None) # special handling for chain property to create the Node object - if prop == 'chain': + if prop == "chain": nodes = [] for node in value: ac_node = Node(**node) diff --git a/st2common/st2common/models/system/common.py b/st2common/st2common/models/system/common.py index a56f6701acb..72ad6c3f84a 100644 --- a/st2common/st2common/models/system/common.py +++ b/st2common/st2common/models/system/common.py @@ -14,17 +14,17 @@ # limitations under the License. __all__ = [ - 'InvalidReferenceError', - 'InvalidResourceReferenceError', - 'ResourceReference', + "InvalidReferenceError", + "InvalidResourceReferenceError", + "ResourceReference", ] -PACK_SEPARATOR = '.' +PACK_SEPARATOR = "." class InvalidReferenceError(ValueError): def __init__(self, ref): - message = 'Invalid reference: %s' % (ref) + message = "Invalid reference: %s" % (ref) self.ref = ref self.message = message super(InvalidReferenceError, self).__init__(message) @@ -32,7 +32,7 @@ def __init__(self, ref): class InvalidResourceReferenceError(ValueError): def __init__(self, ref): - message = 'Invalid resource reference: %s' % (ref) + message = "Invalid resource reference: %s" % (ref) self.ref = ref self.message = message super(InvalidResourceReferenceError, self).__init__(message) @@ -42,6 +42,7 @@ class ResourceReference(object): """ Class used for referring to resources which belong to a content pack. """ + def __init__(self, pack=None, name=None): self.pack = self.validate_pack_name(pack=pack) self.name = name @@ -72,8 +73,10 @@ def to_string_reference(pack=None, name=None): pack = ResourceReference.validate_pack_name(pack=pack) return PACK_SEPARATOR.join([pack, name]) else: - raise ValueError('Both pack and name needed for building ref. pack=%s, name=%s' % - (pack, name)) + raise ValueError( + "Both pack and name needed for building ref. pack=%s, name=%s" + % (pack, name) + ) @staticmethod def validate_pack_name(pack): @@ -97,5 +100,8 @@ def get_name(ref): raise InvalidResourceReferenceError(ref=ref) def __repr__(self): - return ('' % - (self.pack, self.name, self.ref)) + return "" % ( + self.pack, + self.name, + self.ref, + ) diff --git a/st2common/st2common/models/system/keyvalue.py b/st2common/st2common/models/system/keyvalue.py index 0bac5949d83..018df956029 100644 --- a/st2common/st2common/models/system/keyvalue.py +++ b/st2common/st2common/models/system/keyvalue.py @@ -17,13 +17,13 @@ from st2common.constants.keyvalue import USER_SEPARATOR __all__ = [ - 'InvalidUserKeyReferenceError', + "InvalidUserKeyReferenceError", ] class InvalidUserKeyReferenceError(ValueError): def __init__(self, ref): - message = 'Invalid resource reference: %s' % (ref) + message = "Invalid resource reference: %s" % (ref) self.ref = ref self.message = message super(InvalidUserKeyReferenceError, self).__init__(message) @@ -38,7 +38,7 @@ class UserKeyReference(object): def __init__(self, user, name): self._user = user self._name = name - self.ref = ('%s%s%s' % (self._user, USER_SEPARATOR, self._name)) + self.ref = "%s%s%s" % (self._user, USER_SEPARATOR, self._name) def __str__(self): return self.ref diff --git a/st2common/st2common/models/system/paramiko_command_action.py b/st2common/st2common/models/system/paramiko_command_action.py index a96183ef9e1..685ffeb67ce 100644 --- a/st2common/st2common/models/system/paramiko_command_action.py +++ b/st2common/st2common/models/system/paramiko_command_action.py @@ -23,7 +23,7 @@ from st2common.util.shell import quote_unix __all__ = [ - 'ParamikoRemoteCommandAction', + "ParamikoRemoteCommandAction", ] LOG = logging.getLogger(__name__) @@ -32,7 +32,6 @@ class ParamikoRemoteCommandAction(RemoteAction): - def get_full_command_string(self): # Note: We pass -E to sudo because we want to preserve user provided environment variables env_str = self._get_env_vars_export_string() @@ -40,24 +39,25 @@ def get_full_command_string(self): if self.sudo: if env_str: - command = quote_unix('%s && cd %s && %s' % (env_str, cwd, self.command)) + command = quote_unix("%s && cd %s && %s" % (env_str, cwd, self.command)) else: - command = quote_unix('cd %s && %s' % (cwd, self.command)) + command = quote_unix("cd %s && %s" % (cwd, self.command)) - sudo_arguments = ' '.join(self._get_common_sudo_arguments()) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + sudo_arguments = " ".join(self._get_common_sudo_arguments()) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) if self.sudo_password: - command = ('set +o history ; echo -e %s | %s' % - (quote_unix('%s\n' % (self.sudo_password)), command)) + command = "set +o history ; echo -e %s | %s" % ( + quote_unix("%s\n" % (self.sudo_password)), + command, + ) else: if env_str: - command = '%s && cd %s && %s' % (env_str, cwd, - self.command) + command = "%s && cd %s && %s" % (env_str, cwd, self.command) else: - command = 'cd %s && %s' % (cwd, self.command) + command = "cd %s && %s" % (cwd, self.command) - LOG.debug('Command to run on remote host will be: %s', command) + LOG.debug("Command to run on remote host will be: %s", command) return command def _get_common_sudo_arguments(self): @@ -69,7 +69,7 @@ def _get_common_sudo_arguments(self): flags = [] if self.sudo_password: - flags.append('-S') + flags.append("-S") flags = flags + SUDO_COMMON_OPTIONS diff --git a/st2common/st2common/models/system/paramiko_script_action.py b/st2common/st2common/models/system/paramiko_script_action.py index a6ff26a751c..284e87a708c 100644 --- a/st2common/st2common/models/system/paramiko_script_action.py +++ b/st2common/st2common/models/system/paramiko_script_action.py @@ -20,7 +20,7 @@ from st2common.util.shell import quote_unix __all__ = [ - 'ParamikoRemoteScriptAction', + "ParamikoRemoteScriptAction", ] @@ -28,10 +28,10 @@ class ParamikoRemoteScriptAction(RemoteScriptAction): - def _format_command(self): - script_arguments = self._get_script_arguments(named_args=self.named_args, - positional_args=self.positional_args) + script_arguments = self._get_script_arguments( + named_args=self.named_args, positional_args=self.positional_args + ) env_str = self._get_env_vars_export_string() cwd = quote_unix(self.get_cwd()) script_path = quote_unix(self.remote_script) @@ -39,36 +39,46 @@ def _format_command(self): if self.sudo: if script_arguments: if env_str: - command = quote_unix('%s && cd %s && %s %s' % ( - env_str, cwd, script_path, script_arguments)) + command = quote_unix( + "%s && cd %s && %s %s" + % (env_str, cwd, script_path, script_arguments) + ) else: - command = quote_unix('cd %s && %s %s' % ( - cwd, script_path, script_arguments)) + command = quote_unix( + "cd %s && %s %s" % (cwd, script_path, script_arguments) + ) else: if env_str: - command = quote_unix('%s && cd %s && %s' % ( - env_str, cwd, script_path)) + command = quote_unix( + "%s && cd %s && %s" % (env_str, cwd, script_path) + ) else: - command = quote_unix('cd %s && %s' % (cwd, script_path)) + command = quote_unix("cd %s && %s" % (cwd, script_path)) - sudo_arguments = ' '.join(self._get_common_sudo_arguments()) - command = 'sudo %s -- bash -c %s' % (sudo_arguments, command) + sudo_arguments = " ".join(self._get_common_sudo_arguments()) + command = "sudo %s -- bash -c %s" % (sudo_arguments, command) if self.sudo_password: - command = ('set +o history ; echo -e %s | %s' % - (quote_unix('%s\n' % (self.sudo_password)), command)) + command = "set +o history ; echo -e %s | %s" % ( + quote_unix("%s\n" % (self.sudo_password)), + command, + ) else: if script_arguments: if env_str: - command = '%s && cd %s && %s %s' % (env_str, cwd, - script_path, script_arguments) + command = "%s && cd %s && %s %s" % ( + env_str, + cwd, + script_path, + script_arguments, + ) else: - command = 'cd %s && %s %s' % (cwd, script_path, script_arguments) + command = "cd %s && %s %s" % (cwd, script_path, script_arguments) else: if env_str: - command = '%s && cd %s && %s' % (env_str, cwd, script_path) + command = "%s && cd %s && %s" % (env_str, cwd, script_path) else: - command = 'cd %s && %s' % (cwd, script_path) + command = "cd %s && %s" % (cwd, script_path) return command @@ -81,7 +91,7 @@ def _get_common_sudo_arguments(self): flags = [] if self.sudo_password: - flags.append('-S') + flags.append("-S") flags = flags + SUDO_COMMON_OPTIONS diff --git a/st2common/st2common/models/utils/action_alias_utils.py b/st2common/st2common/models/utils/action_alias_utils.py index 06106a27943..bf6d47c8b46 100644 --- a/st2common/st2common/models/utils/action_alias_utils.py +++ b/st2common/st2common/models/utils/action_alias_utils.py @@ -18,9 +18,15 @@ import re import sys -from sre_parse import ( # pylint: disable=E0611 - parse, AT, AT_BEGINNING, AT_BEGINNING_STRING, - AT_END, AT_END_STRING, BRANCH, SUBPATTERN, +from sre_parse import ( # pylint: disable=E0611 + parse, + AT, + AT_BEGINNING, + AT_BEGINNING_STRING, + AT_END, + AT_END_STRING, + BRANCH, + SUBPATTERN, ) from st2common.util.jinja import render_values @@ -30,11 +36,10 @@ from st2common import log __all__ = [ - 'ActionAliasFormatParser', - - 'extract_parameters_for_action_alias_db', - 'extract_parameters', - 'search_regex_tokens', + "ActionAliasFormatParser", + "extract_parameters_for_action_alias_db", + "extract_parameters", + "search_regex_tokens", ] @@ -48,10 +53,9 @@ class ActionAliasFormatParser(object): - def __init__(self, alias_format=None, param_stream=None): - self._format = alias_format or '' - self._original_param_stream = param_stream or '' + self._format = alias_format or "" + self._original_param_stream = param_stream or "" self._param_stream = self._original_param_stream self._snippets = self.generate_snippets() @@ -76,26 +80,26 @@ def generate_snippets(self): # Formats for keys and values: key is a non-spaced string, # value is anything in quotes or curly braces, or a single word. - snippets['key'] = r'\s*(\S+?)\s*' - snippets['value'] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(\S+)' + snippets["key"] = r"\s*(\S+?)\s*" + snippets["value"] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(\S+)' # Extended value: also matches unquoted text (caution). - snippets['ext_value'] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(.+?)' + snippets["ext_value"] = r'""|\'\'|"(.+?)"|\'(.+?)\'|({.+?})|(.+?)' # Key-value pair: - snippets['pairs'] = r'(?:^|\s+){key}=({value})'.format(**snippets) + snippets["pairs"] = r"(?:^|\s+){key}=({value})".format(**snippets) # End of string: multiple space-separated key-value pairs: - snippets['ending'] = r'.*?(({pairs}\s*)*)$'.format(**snippets) + snippets["ending"] = r".*?(({pairs}\s*)*)$".format(**snippets) # Default value in optional parameters: - snippets['default'] = r'\s*=\s*(?:{ext_value})\s*'.format(**snippets) + snippets["default"] = r"\s*=\s*(?:{ext_value})\s*".format(**snippets) # Optional parameter (has a default value): - snippets['optional'] = '{{' + snippets['key'] + snippets['default'] + '}}' + snippets["optional"] = "{{" + snippets["key"] + snippets["default"] + "}}" # Required parameter (no default value): - snippets['required'] = '{{' + snippets['key'] + '}}' + snippets["required"] = "{{" + snippets["key"] + "}}" return snippets @@ -105,11 +109,13 @@ def match_kv_pairs_at_end(self): # 1. Matching the arbitrary key-value pairs at the end of the command # to support extra parameters (not specified in the format string), # and cutting them from the command string afterwards. - ending_pairs = re.match(self._snippets['ending'], param_stream, re.DOTALL) + ending_pairs = re.match(self._snippets["ending"], param_stream, re.DOTALL) has_ending_pairs = ending_pairs and ending_pairs.group(1) if has_ending_pairs: - kv_pairs = re.findall(self._snippets['pairs'], ending_pairs.group(1), re.DOTALL) - param_stream = param_stream.replace(ending_pairs.group(1), '') + kv_pairs = re.findall( + self._snippets["pairs"], ending_pairs.group(1), re.DOTALL + ) + param_stream = param_stream.replace(ending_pairs.group(1), "") else: kv_pairs = [] param_stream = " %s " % (param_stream) @@ -118,27 +124,36 @@ def match_kv_pairs_at_end(self): def generate_optional_params_regex(self): # 2. Matching optional parameters (with default values). - return re.findall(self._snippets['optional'], self._format, re.DOTALL) + return re.findall(self._snippets["optional"], self._format, re.DOTALL) def transform_format_string_into_regex(self): # 3. Convert the mangled format string into a regex object # Transforming our format string into a regular expression, # substituting {{ ... }} with regex named groups, so that param_stream # matched against this expression yields a dict of params with values. - param_match = r'\1["\']?(?P<\2>(?:(?<=\').+?(?=\')|(?<=").+?(?=")|{.+?}|.+?))["\']?' - reg = re.sub(r'(\s*)' + self._snippets['optional'], r'(?:' + param_match + r')?', - self._format) - reg = re.sub(r'(\s*)' + self._snippets['required'], param_match, reg) + param_match = ( + r'\1["\']?(?P<\2>(?:(?<=\').+?(?=\')|(?<=").+?(?=")|{.+?}|.+?))["\']?' + ) + reg = re.sub( + r"(\s*)" + self._snippets["optional"], + r"(?:" + param_match + r")?", + self._format, + ) + reg = re.sub(r"(\s*)" + self._snippets["required"], param_match, reg) reg_tokens = parse(reg, flags=re.DOTALL) # Add a beginning anchor if none exists - if not search_regex_tokens(((AT, AT_BEGINNING), (AT, AT_BEGINNING_STRING)), reg_tokens): - reg = r'^\s*' + reg + if not search_regex_tokens( + ((AT, AT_BEGINNING), (AT, AT_BEGINNING_STRING)), reg_tokens + ): + reg = r"^\s*" + reg # Add an ending anchor if none exists - if not search_regex_tokens(((AT, AT_END), (AT, AT_END_STRING)), reg_tokens, backwards=True): - reg = reg + r'\s*$' + if not search_regex_tokens( + ((AT, AT_END), (AT, AT_END_STRING)), reg_tokens, backwards=True + ): + reg = reg + r"\s*$" return re.compile(reg, re.DOTALL) @@ -147,8 +162,10 @@ def match_params_in_stream(self, matched_stream): if not matched_stream: # If no match is found we throw since this indicates provided user string (command) # didn't match the provided format string - raise ParseException('Command "%s" doesn\'t match format string "%s"' % - (self._original_param_stream, self._format)) + raise ParseException( + 'Command "%s" doesn\'t match format string "%s"' + % (self._original_param_stream, self._format) + ) # Compiling results from the steps 1-3. if matched_stream: @@ -157,16 +174,16 @@ def match_params_in_stream(self, matched_stream): # Apply optional parameters/add the default parameters for param in self._optional: matched_value = result[param[0]] if matched_stream else None - matched_result = matched_value or ''.join(param[1:]) + matched_result = matched_value or "".join(param[1:]) if matched_result is not None: result[param[0]] = matched_result # Apply given parameters for pair in self._kv_pairs: - result[pair[0]] = ''.join(pair[2:]) + result[pair[0]] = "".join(pair[2:]) if self._format and not (self._param_stream.strip() or any(result.values())): - raise ParseException('No value supplied and no default value found.') + raise ParseException("No value supplied and no default value found.") return result @@ -196,8 +213,9 @@ def get_multiple_extracted_param_value(self): return results -def extract_parameters_for_action_alias_db(action_alias_db, format_str, param_stream, - match_multiple=False): +def extract_parameters_for_action_alias_db( + action_alias_db, format_str, param_stream, match_multiple=False +): """ Extract parameters from the user input based on the provided format string. @@ -208,13 +226,14 @@ def extract_parameters_for_action_alias_db(action_alias_db, format_str, param_st formats = action_alias_db.get_format_strings() if format_str not in formats: - raise ValueError('Format string "%s" is not available on the alias "%s"' % - (format_str, action_alias_db.name)) + raise ValueError( + 'Format string "%s" is not available on the alias "%s"' + % (format_str, action_alias_db.name) + ) result = extract_parameters( - format_str=format_str, - param_stream=param_stream, - match_multiple=match_multiple) + format_str=format_str, param_stream=param_stream, match_multiple=match_multiple + ) return result @@ -226,7 +245,9 @@ def extract_parameters(format_str, param_stream, match_multiple=False): return parser.get_extracted_param_value() -def inject_immutable_parameters(action_alias_db, multiple_execution_parameters, action_context): +def inject_immutable_parameters( + action_alias_db, multiple_execution_parameters, action_context +): """ Inject immutable parameters from the alias definiton on the execution parameters. Jinja expressions will be resolved. @@ -235,26 +256,34 @@ def inject_immutable_parameters(action_alias_db, multiple_execution_parameters, if not immutable_parameters: return multiple_execution_parameters - user = action_context.get('user', None) + user = action_context.get("user", None) context = {} - context.update({ - kv_constants.DATASTORE_PARENT_SCOPE: { - kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( - scope=kv_constants.FULL_SYSTEM_SCOPE), - kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup( - scope=kv_constants.FULL_USER_SCOPE, user=user) + context.update( + { + kv_constants.DATASTORE_PARENT_SCOPE: { + kv_constants.SYSTEM_SCOPE: kv_service.KeyValueLookup( + scope=kv_constants.FULL_SYSTEM_SCOPE + ), + kv_constants.USER_SCOPE: kv_service.UserKeyValueLookup( + scope=kv_constants.FULL_USER_SCOPE, user=user + ), + } } - }) + ) context.update(action_context) rendered_params = render_values(immutable_parameters, context) for exec_params in multiple_execution_parameters: - overriden = [param for param in immutable_parameters.keys() if param in exec_params] + overriden = [ + param for param in immutable_parameters.keys() if param in exec_params + ] if overriden: raise ValueError( "Immutable arguments cannot be overriden: {}".format( - ','.join(overriden))) + ",".join(overriden) + ) + ) exec_params.update(rendered_params) diff --git a/st2common/st2common/models/utils/action_param_utils.py b/st2common/st2common/models/utils/action_param_utils.py index 1ecf6dbbe88..3edbeae6edb 100644 --- a/st2common/st2common/models/utils/action_param_utils.py +++ b/st2common/st2common/models/utils/action_param_utils.py @@ -33,7 +33,7 @@ def _merge_param_meta_values(action_meta=None, runner_meta=None): merged_meta = {} # ?? Runner immutable param's meta shouldn't be allowed to be modified by action whatsoever. - if runner_meta and runner_meta.get('immutable', False): + if runner_meta and runner_meta.get("immutable", False): merged_meta = runner_meta for key in all_keys: @@ -42,8 +42,10 @@ def _merge_param_meta_values(action_meta=None, runner_meta=None): elif key in runner_meta_keys and key not in action_meta_keys: merged_meta[key] = runner_meta[key] else: - if key in ['immutable']: - merged_meta[key] = runner_meta.get(key, False) or action_meta.get(key, False) + if key in ["immutable"]: + merged_meta[key] = runner_meta.get(key, False) or action_meta.get( + key, False + ) else: merged_meta[key] = action_meta.get(key) return merged_meta @@ -51,12 +53,12 @@ def _merge_param_meta_values(action_meta=None, runner_meta=None): def get_params_view(action_db=None, runner_db=None, merged_only=False): if runner_db: - runner_params = fast_deepcopy(getattr(runner_db, 'runner_parameters', {})) or {} + runner_params = fast_deepcopy(getattr(runner_db, "runner_parameters", {})) or {} else: runner_params = {} if action_db: - action_params = fast_deepcopy(getattr(action_db, 'parameters', {})) or {} + action_params = fast_deepcopy(getattr(action_db, "parameters", {})) or {} else: action_params = {} @@ -64,19 +66,22 @@ def get_params_view(action_db=None, runner_db=None, merged_only=False): merged_params = {} for param in parameters: - merged_params[param] = _merge_param_meta_values(action_meta=action_params.get(param), - runner_meta=runner_params.get(param)) + merged_params[param] = _merge_param_meta_values( + action_meta=action_params.get(param), runner_meta=runner_params.get(param) + ) if merged_only: return merged_params def is_required(param_meta): - return param_meta.get('required', False) + return param_meta.get("required", False) def is_immutable(param_meta): - return param_meta.get('immutable', False) + return param_meta.get("immutable", False) - immutable = {param for param in parameters if is_immutable(merged_params.get(param))} + immutable = { + param for param in parameters if is_immutable(merged_params.get(param)) + } required = {param for param in parameters if is_required(merged_params.get(param))} required = required - immutable optional = parameters - required - immutable @@ -89,8 +94,7 @@ def is_immutable(param_meta): def cast_params(action_ref, params, cast_overrides=None): - """ - """ + """""" params = params or {} action_db = action_db_util.get_action_by_ref(action_ref) @@ -98,7 +102,7 @@ def cast_params(action_ref, params, cast_overrides=None): raise ValueError('Action with ref "%s" doesn\'t exist' % (action_ref)) action_parameters_schema = action_db.parameters - runnertype_db = action_db_util.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_db_util.get_runnertype_by_name(action_db.runner_type["name"]) runner_parameters_schema = runnertype_db.runner_parameters # combine into 1 list of parameter schemas parameters_schema = {} @@ -110,29 +114,37 @@ def cast_params(action_ref, params, cast_overrides=None): for k, v in six.iteritems(params): parameter_schema = parameters_schema.get(k, None) if not parameter_schema: - LOG.debug('Will skip cast of param[name: %s, value: %s]. No schema.', k, v) + LOG.debug("Will skip cast of param[name: %s, value: %s]. No schema.", k, v) continue - parameter_type = parameter_schema.get('type', None) + parameter_type = parameter_schema.get("type", None) if not parameter_type: - LOG.debug('Will skip cast of param[name: %s, value: %s]. No type.', k, v) + LOG.debug("Will skip cast of param[name: %s, value: %s]. No type.", k, v) continue # Pick up cast from teh override and then from the system suppied ones. cast = cast_overrides.get(parameter_type, None) if cast_overrides else None if not cast: cast = get_cast(cast_type=parameter_type) if not cast: - LOG.debug('Will skip cast of param[name: %s, value: %s]. No cast for %s.', k, v, - parameter_type) + LOG.debug( + "Will skip cast of param[name: %s, value: %s]. No cast for %s.", + k, + v, + parameter_type, + ) continue - LOG.debug('Casting param: %s of type %s to type: %s', v, type(v), parameter_type) + LOG.debug( + "Casting param: %s of type %s to type: %s", v, type(v), parameter_type + ) try: params[k] = cast(v) except Exception as e: v_type = type(v).__name__ - msg = ('Failed to cast value "%s" (type: %s) for parameter "%s" of type "%s": %s. ' - 'Perhaps the value is of an invalid type?' % - (v, v_type, k, parameter_type, six.text_type(e))) + msg = ( + 'Failed to cast value "%s" (type: %s) for parameter "%s" of type "%s": %s. ' + "Perhaps the value is of an invalid type?" + % (v, v_type, k, parameter_type, six.text_type(e)) + ) raise ValueError(msg) return params @@ -145,8 +157,13 @@ def validate_action_parameters(action_ref, inputs): parameters = action_db_util.get_action_parameters_specs(action_ref) # Check required parameters that have no default defined. - required = set([param for param, meta in six.iteritems(parameters) - if meta.get('required', False) and 'default' not in meta]) + required = set( + [ + param + for param, meta in six.iteritems(parameters) + if meta.get("required", False) and "default" not in meta + ] + ) requires = sorted(required.difference(input_set)) diff --git a/st2common/st2common/models/utils/profiling.py b/st2common/st2common/models/utils/profiling.py index c9d26636b0b..47add2adc3e 100644 --- a/st2common/st2common/models/utils/profiling.py +++ b/st2common/st2common/models/utils/profiling.py @@ -23,10 +23,10 @@ from st2common import log as logging __all__ = [ - 'enable_profiling', - 'disable_profiling', - 'is_enabled', - 'log_query_and_profile_data_for_queryset' + "enable_profiling", + "disable_profiling", + "is_enabled", + "log_query_and_profile_data_for_queryset", ] LOG = logging.getLogger(__name__) @@ -72,13 +72,13 @@ def log_query_and_profile_data_for_queryset(queryset): # Note: Some mongoengine methods don't return queryset (e.g. count) return queryset - query = getattr(queryset, '_query', None) - mongo_query = getattr(queryset, '_mongo_query', query) - ordering = getattr(queryset, '_ordering', None) - limit = getattr(queryset, '_limit', None) - collection = getattr(queryset, '_collection', None) - collection_name = getattr(collection, 'name', None) - only_fields = getattr(queryset, 'only_fields', None) + query = getattr(queryset, "_query", None) + mongo_query = getattr(queryset, "_mongo_query", query) + ordering = getattr(queryset, "_ordering", None) + limit = getattr(queryset, "_limit", None) + collection = getattr(queryset, "_collection", None) + collection_name = getattr(collection, "name", None) + only_fields = getattr(queryset, "only_fields", None) # Note: We need to clone the queryset when using explain because explain advances the cursor # internally which changes the function result @@ -86,42 +86,46 @@ def log_query_and_profile_data_for_queryset(queryset): explain_info = cloned_queryset.explain(format=True) if mongo_query is not None and collection_name is not None: - mongo_shell_query = construct_mongo_shell_query(mongo_query=mongo_query, - collection_name=collection_name, - ordering=ordering, - limit=limit, - only_fields=only_fields) - extra = {'mongo_query': mongo_query, 'mongo_shell_query': mongo_shell_query} - LOG.debug('MongoDB query: %s' % (mongo_shell_query), extra=extra) - LOG.debug('MongoDB explain data: %s' % (explain_info)) + mongo_shell_query = construct_mongo_shell_query( + mongo_query=mongo_query, + collection_name=collection_name, + ordering=ordering, + limit=limit, + only_fields=only_fields, + ) + extra = {"mongo_query": mongo_query, "mongo_shell_query": mongo_shell_query} + LOG.debug("MongoDB query: %s" % (mongo_shell_query), extra=extra) + LOG.debug("MongoDB explain data: %s" % (explain_info)) return queryset -def construct_mongo_shell_query(mongo_query, collection_name, ordering, limit, - only_fields=None): +def construct_mongo_shell_query( + mongo_query, collection_name, ordering, limit, only_fields=None +): result = [] # Select collection - part = 'db.{collection}'.format(collection=collection_name) + part = "db.{collection}".format(collection=collection_name) result.append(part) # Include filters (if any) if mongo_query: filter_predicate = mongo_query else: - filter_predicate = '' + filter_predicate = "" - part = 'find({filter_predicate})'.format(filter_predicate=filter_predicate) + part = "find({filter_predicate})".format(filter_predicate=filter_predicate) # Include only fields (projection) if only_fields: - projection_items = ['\'%s\': 1' % (field) for field in only_fields] - projection = ', '.join(projection_items) - part = 'find({filter_predicate}, {{{projection}}})'.format( - filter_predicate=filter_predicate, projection=projection) + projection_items = ["'%s': 1" % (field) for field in only_fields] + projection = ", ".join(projection_items) + part = "find({filter_predicate}, {{{projection}}})".format( + filter_predicate=filter_predicate, projection=projection + ) else: - part = 'find({filter_predicate})'.format(filter_predicate=filter_predicate) + part = "find({filter_predicate})".format(filter_predicate=filter_predicate) result.append(part) @@ -129,17 +133,18 @@ def construct_mongo_shell_query(mongo_query, collection_name, ordering, limit, if ordering: sort_predicate = [] for field_name, direction in ordering: - sort_predicate.append('{name}: {direction}'.format(name=field_name, - direction=direction)) + sort_predicate.append( + "{name}: {direction}".format(name=field_name, direction=direction) + ) - sort_predicate = ', '.join(sort_predicate) - part = 'sort({{{sort_predicate}}})'.format(sort_predicate=sort_predicate) + sort_predicate = ", ".join(sort_predicate) + part = "sort({{{sort_predicate}}})".format(sort_predicate=sort_predicate) result.append(part) # Include limit info (if any) if limit is not None: - part = 'limit({limit})'.format(limit=limit) + part = "limit({limit})".format(limit=limit) result.append(part) - result = '.'.join(result) + ';' + result = ".".join(result) + ";" return result diff --git a/st2common/st2common/models/utils/sensor_type_utils.py b/st2common/st2common/models/utils/sensor_type_utils.py index f67a65e530c..cd4b068db9b 100644 --- a/st2common/st2common/models/utils/sensor_type_utils.py +++ b/st2common/st2common/models/utils/sensor_type_utils.py @@ -21,11 +21,7 @@ from st2common.models.db.sensor import SensorTypeDB from st2common.services import triggers as trigger_service -__all__ = [ - 'to_sensor_db_model', - 'get_sensor_entry_point', - 'create_trigger_types' -] +__all__ = ["to_sensor_db_model", "get_sensor_entry_point", "create_trigger_types"] def to_sensor_db_model(sensor_api_model=None): @@ -38,37 +34,40 @@ def to_sensor_db_model(sensor_api_model=None): :rtype: :class:`SensorTypeDB` """ - class_name = getattr(sensor_api_model, 'class_name', None) - pack = getattr(sensor_api_model, 'pack', None) + class_name = getattr(sensor_api_model, "class_name", None) + pack = getattr(sensor_api_model, "pack", None) entry_point = get_sensor_entry_point(sensor_api_model) - artifact_uri = getattr(sensor_api_model, 'artifact_uri', None) - description = getattr(sensor_api_model, 'description', None) - trigger_types = getattr(sensor_api_model, 'trigger_types', []) - poll_interval = getattr(sensor_api_model, 'poll_interval', None) - enabled = getattr(sensor_api_model, 'enabled', True) - metadata_file = getattr(sensor_api_model, 'metadata_file', None) - - poll_interval = getattr(sensor_api_model, 'poll_interval', None) + artifact_uri = getattr(sensor_api_model, "artifact_uri", None) + description = getattr(sensor_api_model, "description", None) + trigger_types = getattr(sensor_api_model, "trigger_types", []) + poll_interval = getattr(sensor_api_model, "poll_interval", None) + enabled = getattr(sensor_api_model, "enabled", True) + metadata_file = getattr(sensor_api_model, "metadata_file", None) + + poll_interval = getattr(sensor_api_model, "poll_interval", None) if poll_interval and (poll_interval < MINIMUM_POLL_INTERVAL): - raise ValueError('Minimum possible poll_interval is %s seconds' % - (MINIMUM_POLL_INTERVAL)) + raise ValueError( + "Minimum possible poll_interval is %s seconds" % (MINIMUM_POLL_INTERVAL) + ) # Add pack and metadata fileto each trigger type item for trigger_type in trigger_types: - trigger_type['pack'] = pack - trigger_type['metadata_file'] = metadata_file + trigger_type["pack"] = pack + trigger_type["metadata_file"] = metadata_file trigger_type_refs = create_trigger_types(trigger_types) - return _create_sensor_type(pack=pack, - name=class_name, - description=description, - artifact_uri=artifact_uri, - entry_point=entry_point, - trigger_types=trigger_type_refs, - poll_interval=poll_interval, - enabled=enabled, - metadata_file=metadata_file) + return _create_sensor_type( + pack=pack, + name=class_name, + description=description, + artifact_uri=artifact_uri, + entry_point=entry_point, + trigger_types=trigger_type_refs, + poll_interval=poll_interval, + enabled=enabled, + metadata_file=metadata_file, + ) def create_trigger_types(trigger_types, metadata_file=None): @@ -87,29 +86,44 @@ def create_trigger_types(trigger_types, metadata_file=None): return trigger_type_refs -def _create_sensor_type(pack=None, name=None, description=None, artifact_uri=None, - entry_point=None, trigger_types=None, poll_interval=10, - enabled=True, metadata_file=None): - - sensor_type = SensorTypeDB(pack=pack, name=name, description=description, - artifact_uri=artifact_uri, entry_point=entry_point, - poll_interval=poll_interval, enabled=enabled, - trigger_types=trigger_types, metadata_file=metadata_file) +def _create_sensor_type( + pack=None, + name=None, + description=None, + artifact_uri=None, + entry_point=None, + trigger_types=None, + poll_interval=10, + enabled=True, + metadata_file=None, +): + + sensor_type = SensorTypeDB( + pack=pack, + name=name, + description=description, + artifact_uri=artifact_uri, + entry_point=entry_point, + poll_interval=poll_interval, + enabled=enabled, + trigger_types=trigger_types, + metadata_file=metadata_file, + ) return sensor_type def get_sensor_entry_point(sensor_api_model): - file_path = getattr(sensor_api_model, 'artifact_uri', None) - class_name = getattr(sensor_api_model, 'class_name', None) - pack = getattr(sensor_api_model, 'pack', None) + file_path = getattr(sensor_api_model, "artifact_uri", None) + class_name = getattr(sensor_api_model, "class_name", None) + pack = getattr(sensor_api_model, "pack", None) if pack == SYSTEM_PACK_NAME: # Special case for sensors which come included with the default installation entry_point = class_name else: - module_path = file_path.split('/%s/' % (pack))[1] - module_path = module_path.replace(os.path.sep, '.') - module_path = module_path.replace('.py', '') - entry_point = '%s.%s' % (module_path, class_name) + module_path = file_path.split("/%s/" % (pack))[1] + module_path = module_path.replace(os.path.sep, ".") + module_path = module_path.replace(".py", "") + entry_point = "%s.%s" % (module_path, class_name) return entry_point diff --git a/st2common/st2common/operators.py b/st2common/st2common/operators.py index fc38d632155..6896e876586 100644 --- a/st2common/st2common/operators.py +++ b/st2common/st2common/operators.py @@ -24,10 +24,10 @@ from st2common.util.payload import PayloadLookup __all__ = [ - 'SEARCH', - 'get_operator', - 'get_allowed_operators', - 'UnrecognizedConditionError', + "SEARCH", + "get_operator", + "get_allowed_operators", + "UnrecognizedConditionError", ] @@ -40,7 +40,7 @@ def get_operator(op): if op in operators: return operators[op] else: - raise Exception('Invalid operator: ' + op) + raise Exception("Invalid operator: " + op) class UnrecognizedConditionError(Exception): @@ -106,35 +106,57 @@ def search(value, criteria_pattern, criteria_condition, check_function): type: "equals" pattern: "Approved" """ - if criteria_condition == 'any': + if criteria_condition == "any": # Any item of the list can match all patterns - rtn = any([ - # Any payload item can match - all([ - # Match all patterns - check_function( - child_criterion_k, child_criterion_v, - PayloadLookup(child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX)) - for child_criterion_k, child_criterion_v in six.iteritems(criteria_pattern) - ]) - for child_payload in value - ]) - elif criteria_condition == 'all': + rtn = any( + [ + # Any payload item can match + all( + [ + # Match all patterns + check_function( + child_criterion_k, + child_criterion_v, + PayloadLookup( + child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX + ), + ) + for child_criterion_k, child_criterion_v in six.iteritems( + criteria_pattern + ) + ] + ) + for child_payload in value + ] + ) + elif criteria_condition == "all": # Every item of the list must match all patterns - rtn = all([ - # All payload items must match - all([ - # Match all patterns - check_function( - child_criterion_k, child_criterion_v, - PayloadLookup(child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX)) - for child_criterion_k, child_criterion_v in six.iteritems(criteria_pattern) - ]) - for child_payload in value - ]) + rtn = all( + [ + # All payload items must match + all( + [ + # Match all patterns + check_function( + child_criterion_k, + child_criterion_v, + PayloadLookup( + child_payload, prefix=TRIGGER_ITEM_PAYLOAD_PREFIX + ), + ) + for child_criterion_k, child_criterion_v in six.iteritems( + criteria_pattern + ) + ] + ) + for child_payload in value + ] + ) else: - raise UnrecognizedConditionError("The '%s' search condition is not recognized, only 'any' " - "and 'all' are allowed" % criteria_condition) + raise UnrecognizedConditionError( + "The '%s' search condition is not recognized, only 'any' " + "and 'all' are allowed" % criteria_condition + ) return rtn @@ -298,13 +320,17 @@ def _timediff(diff_target, period_seconds, operator): def timediff_lt(value, criteria_pattern): if criteria_pattern is None: return False - return _timediff(diff_target=value, period_seconds=criteria_pattern, operator=less_than) + return _timediff( + diff_target=value, period_seconds=criteria_pattern, operator=less_than + ) def timediff_gt(value, criteria_pattern): if criteria_pattern is None: return False - return _timediff(diff_target=value, period_seconds=criteria_pattern, operator=greater_than) + return _timediff( + diff_target=value, period_seconds=criteria_pattern, operator=greater_than + ) def exists(value, criteria_pattern): @@ -344,48 +370,48 @@ def ensure_operators_are_strings(value, criteria_pattern): :return: tuple(value, criteria_pattern) """ if isinstance(value, bytes): - value = value.decode('utf-8') + value = value.decode("utf-8") if isinstance(criteria_pattern, bytes): - criteria_pattern = criteria_pattern.decode('utf-8') + criteria_pattern = criteria_pattern.decode("utf-8") return value, criteria_pattern # operator match strings -MATCH_WILDCARD = 'matchwildcard' -MATCH_REGEX = 'matchregex' -REGEX = 'regex' -IREGEX = 'iregex' -EQUALS_SHORT = 'eq' -EQUALS_LONG = 'equals' -NEQUALS_LONG = 'nequals' -NEQUALS_SHORT = 'neq' -IEQUALS_SHORT = 'ieq' -IEQUALS_LONG = 'iequals' -CONTAINS_LONG = 'contains' -ICONTAINS_LONG = 'icontains' -NCONTAINS_LONG = 'ncontains' -INCONTAINS_LONG = 'incontains' -STARTSWITH_LONG = 'startswith' -ISTARTSWITH_LONG = 'istartswith' -ENDSWITH_LONG = 'endswith' -IENDSWITH_LONG = 'iendswith' -LESS_THAN_SHORT = 'lt' -LESS_THAN_LONG = 'lessthan' -GREATER_THAN_SHORT = 'gt' -GREATER_THAN_LONG = 'greaterthan' -TIMEDIFF_LT_SHORT = 'td_lt' -TIMEDIFF_LT_LONG = 'timediff_lt' -TIMEDIFF_GT_SHORT = 'td_gt' -TIMEDIFF_GT_LONG = 'timediff_gt' -KEY_EXISTS = 'exists' -KEY_NOT_EXISTS = 'nexists' -INSIDE_LONG = 'inside' -INSIDE_SHORT = 'in' -NINSIDE_LONG = 'ninside' -NINSIDE_SHORT = 'nin' -SEARCH = 'search' +MATCH_WILDCARD = "matchwildcard" +MATCH_REGEX = "matchregex" +REGEX = "regex" +IREGEX = "iregex" +EQUALS_SHORT = "eq" +EQUALS_LONG = "equals" +NEQUALS_LONG = "nequals" +NEQUALS_SHORT = "neq" +IEQUALS_SHORT = "ieq" +IEQUALS_LONG = "iequals" +CONTAINS_LONG = "contains" +ICONTAINS_LONG = "icontains" +NCONTAINS_LONG = "ncontains" +INCONTAINS_LONG = "incontains" +STARTSWITH_LONG = "startswith" +ISTARTSWITH_LONG = "istartswith" +ENDSWITH_LONG = "endswith" +IENDSWITH_LONG = "iendswith" +LESS_THAN_SHORT = "lt" +LESS_THAN_LONG = "lessthan" +GREATER_THAN_SHORT = "gt" +GREATER_THAN_LONG = "greaterthan" +TIMEDIFF_LT_SHORT = "td_lt" +TIMEDIFF_LT_LONG = "timediff_lt" +TIMEDIFF_GT_SHORT = "td_gt" +TIMEDIFF_GT_LONG = "timediff_gt" +KEY_EXISTS = "exists" +KEY_NOT_EXISTS = "nexists" +INSIDE_LONG = "inside" +INSIDE_SHORT = "in" +NINSIDE_LONG = "ninside" +NINSIDE_SHORT = "nin" +SEARCH = "search" # operator lookups operators = { diff --git a/st2common/st2common/persistence/action.py b/st2common/st2common/persistence/action.py index 0a91fc5cefc..1f3d17ee011 100644 --- a/st2common/st2common/persistence/action.py +++ b/st2common/st2common/persistence/action.py @@ -23,12 +23,12 @@ from st2common.persistence.runner import RunnerType __all__ = [ - 'Action', - 'ActionAlias', - 'ActionExecution', - 'ActionExecutionState', - 'LiveAction', - 'RunnerType' + "Action", + "ActionAlias", + "ActionExecution", + "ActionExecutionState", + "LiveAction", + "RunnerType", ] diff --git a/st2common/st2common/persistence/auth.py b/st2common/st2common/persistence/auth.py index f03e3ab4e17..51f0a59ea13 100644 --- a/st2common/st2common/persistence/auth.py +++ b/st2common/st2common/persistence/auth.py @@ -14,9 +14,13 @@ # limitations under the License. from __future__ import absolute_import -from st2common.exceptions.auth import (TokenNotFoundError, ApiKeyNotFoundError, - UserNotFoundError, AmbiguousUserError, - NoNicknameOriginProvidedError) +from st2common.exceptions.auth import ( + TokenNotFoundError, + ApiKeyNotFoundError, + UserNotFoundError, + AmbiguousUserError, + NoNicknameOriginProvidedError, +) from st2common.models.db import MongoDBAccess from st2common.models.db.auth import UserDB, TokenDB, ApiKeyDB from st2common.persistence.base import Access @@ -35,7 +39,7 @@ def get_by_nickname(cls, nickname, origin): if not origin: raise NoNicknameOriginProvidedError() - result = cls.query(**{('nicknames__%s' % origin): nickname}) + result = cls.query(**{("nicknames__%s" % origin): nickname}) if not result.first(): raise UserNotFoundError() @@ -51,7 +55,7 @@ def _get_impl(cls): @classmethod def _get_by_object(cls, object): # For User name is unique. - name = getattr(object, 'name', '') + name = getattr(object, "name", "") return cls.get_by_name(name) @@ -64,13 +68,15 @@ def _get_impl(cls): @classmethod def add_or_update(cls, model_object, publish=True, validate=True): - if not getattr(model_object, 'user', None): - raise ValueError('User is not provided in the token.') - if not getattr(model_object, 'token', None): - raise ValueError('Token value is not set.') - if not getattr(model_object, 'expiry', None): - raise ValueError('Token expiry is not provided in the token.') - return super(Token, cls).add_or_update(model_object, publish=publish, validate=validate) + if not getattr(model_object, "user", None): + raise ValueError("User is not provided in the token.") + if not getattr(model_object, "token", None): + raise ValueError("Token value is not set.") + if not getattr(model_object, "expiry", None): + raise ValueError("Token expiry is not provided in the token.") + return super(Token, cls).add_or_update( + model_object, publish=publish, validate=validate + ) @classmethod def get(cls, value): @@ -96,7 +102,7 @@ def get(cls, value): result = cls.query(key_hash=value_hash).first() if not result: - raise ApiKeyNotFoundError('ApiKey with key_hash=%s not found.' % value_hash) + raise ApiKeyNotFoundError("ApiKey with key_hash=%s not found." % value_hash) return result @@ -109,4 +115,4 @@ def get_by_key_or_id(cls, value): try: return cls.get_by_id(value) except: - raise ApiKeyNotFoundError('ApiKey with key or id=%s not found.' % value) + raise ApiKeyNotFoundError("ApiKey with key or id=%s not found." % value) diff --git a/st2common/st2common/persistence/base.py b/st2common/st2common/persistence/base.py index ea1325762f3..a477defe494 100644 --- a/st2common/st2common/persistence/base.py +++ b/st2common/st2common/persistence/base.py @@ -23,12 +23,7 @@ from st2common.models.system.common import ResourceReference -__all__ = [ - 'Access', - - 'ContentPackResource', - 'StatusBasedResource' -] +__all__ = ["Access", "ContentPackResource", "StatusBasedResource"] LOG = logging.getLogger(__name__) @@ -123,48 +118,60 @@ def aggregate(cls, *args, **kwargs): return cls._get_impl().aggregate(*args, **kwargs) @classmethod - def insert(cls, model_object, publish=True, dispatch_trigger=True, - log_not_unique_error_as_debug=False): + def insert( + cls, + model_object, + publish=True, + dispatch_trigger=True, + log_not_unique_error_as_debug=False, + ): # Late import to avoid very expensive in-direct import (~1 second) when this function # is not called / used from mongoengine import NotUniqueError if model_object.id: - raise ValueError('id for object %s was unexpected.' % model_object) + raise ValueError("id for object %s was unexpected." % model_object) try: model_object = cls._get_impl().insert(model_object) except NotUniqueError as e: if log_not_unique_error_as_debug: - LOG.debug('Conflict while trying to save in DB: %s.', six.text_type(e)) + LOG.debug("Conflict while trying to save in DB: %s.", six.text_type(e)) else: - LOG.exception('Conflict while trying to save in DB.') + LOG.exception("Conflict while trying to save in DB.") # On a conflict determine the conflicting object and return its id in # the raised exception. conflict_object = cls._get_by_object(model_object) conflict_id = str(conflict_object.id) if conflict_object else None message = six.text_type(e) - raise StackStormDBObjectConflictError(message=message, conflict_id=conflict_id, - model_object=model_object) + raise StackStormDBObjectConflictError( + message=message, conflict_id=conflict_id, model_object=model_object + ) # Publish internal event on the message bus if publish: try: cls.publish_create(model_object) except: - LOG.exception('Publish failed.') + LOG.exception("Publish failed.") # Dispatch trigger if dispatch_trigger: try: cls.dispatch_create_trigger(model_object) except: - LOG.exception('Trigger dispatch failed.') + LOG.exception("Trigger dispatch failed.") return model_object @classmethod - def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, validate=True, - log_not_unique_error_as_debug=False): + def add_or_update( + cls, + model_object, + publish=True, + dispatch_trigger=True, + validate=True, + log_not_unique_error_as_debug=False, + ): # Late import to avoid very expensive in-direct import (~1 second) when this function # is not called / used from mongoengine import NotUniqueError @@ -174,16 +181,17 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida model_object = cls._get_impl().add_or_update(model_object, validate=True) except NotUniqueError as e: if log_not_unique_error_as_debug: - LOG.debug('Conflict while trying to save in DB: %s.', six.text_type(e)) + LOG.debug("Conflict while trying to save in DB: %s.", six.text_type(e)) else: - LOG.exception('Conflict while trying to save in DB.') + LOG.exception("Conflict while trying to save in DB.") # On a conflict determine the conflicting object and return its id in # the raised exception. conflict_object = cls._get_by_object(model_object) conflict_id = str(conflict_object.id) if conflict_object else None message = six.text_type(e) - raise StackStormDBObjectConflictError(message=message, conflict_id=conflict_id, - model_object=model_object) + raise StackStormDBObjectConflictError( + message=message, conflict_id=conflict_id, model_object=model_object + ) is_update = str(pre_persist_id) == str(model_object.id) @@ -195,7 +203,7 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida else: cls.publish_create(model_object) except: - LOG.exception('Publish failed.') + LOG.exception("Publish failed.") # Dispatch trigger if dispatch_trigger: @@ -205,7 +213,7 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida else: cls.dispatch_create_trigger(model_object) except: - LOG.exception('Trigger dispatch failed.') + LOG.exception("Trigger dispatch failed.") return model_object @@ -227,14 +235,14 @@ def update(cls, model_object, publish=True, dispatch_trigger=True, **kwargs): try: cls.publish_update(model_object) except: - LOG.exception('Publish failed.') + LOG.exception("Publish failed.") # Dispatch trigger if dispatch_trigger: try: cls.dispatch_update_trigger(model_object) except: - LOG.exception('Trigger dispatch failed.') + LOG.exception("Trigger dispatch failed.") return model_object @@ -247,14 +255,14 @@ def delete(cls, model_object, publish=True, dispatch_trigger=True): try: cls.publish_delete(model_object) except Exception: - LOG.exception('Publish failed.') + LOG.exception("Publish failed.") # Dispatch trigger if dispatch_trigger: try: cls.dispatch_delete_trigger(model_object) except Exception: - LOG.exception('Trigger dispatch failed.') + LOG.exception("Trigger dispatch failed.") return persisted_object @@ -289,14 +297,18 @@ def dispatch_create_trigger(cls, model_object): """ Dispatch a resource-specific trigger which indicates a new resource has been created. """ - return cls._dispatch_operation_trigger(operation='create', model_object=model_object) + return cls._dispatch_operation_trigger( + operation="create", model_object=model_object + ) @classmethod def dispatch_update_trigger(cls, model_object): """ Dispatch a resource-specific trigger which indicates an existing resource has been updated. """ - return cls._dispatch_operation_trigger(operation='update', model_object=model_object) + return cls._dispatch_operation_trigger( + operation="update", model_object=model_object + ) @classmethod def dispatch_delete_trigger(cls, model_object): @@ -304,14 +316,18 @@ def dispatch_delete_trigger(cls, model_object): Dispatch a resource-specific trigger which indicates an existing resource has been deleted. """ - return cls._dispatch_operation_trigger(operation='delete', model_object=model_object) + return cls._dispatch_operation_trigger( + operation="delete", model_object=model_object + ) @classmethod def _get_trigger_ref_for_operation(cls, operation): trigger_ref = cls.operation_to_trigger_ref_map.get(operation, None) if not trigger_ref: - raise ValueError('Trigger ref not specified for operation: %s' % (operation)) + raise ValueError( + "Trigger ref not specified for operation: %s" % (operation) + ) return trigger_ref @@ -322,11 +338,13 @@ def _dispatch_operation_trigger(cls, operation, model_object): trigger = cls._get_trigger_ref_for_operation(operation=operation) - object_payload = cls.api_model_cls.from_model(model_object, mask_secrets=True).__json__() - payload = { - 'object': object_payload - } - return cls._dispatch_trigger(operation=operation, trigger=trigger, payload=payload) + object_payload = cls.api_model_cls.from_model( + model_object, mask_secrets=True + ).__json__() + payload = {"object": object_payload} + return cls._dispatch_trigger( + operation=operation, trigger=trigger, payload=payload + ) @classmethod def _dispatch_trigger(cls, operation, trigger, payload): @@ -338,23 +356,23 @@ def _dispatch_trigger(cls, operation, trigger, payload): class ContentPackResource(Access): - @classmethod def get_by_ref(cls, ref): if not ref: return None ref_obj = ResourceReference.from_string_reference(ref=ref) - result = cls.query(name=ref_obj.name, - pack=ref_obj.pack).first() + result = cls.query(name=ref_obj.name, pack=ref_obj.pack).first() return result @classmethod def _get_by_object(cls, object): # For an object with a resourcepack pack.name is unique. - name = getattr(object, 'name', '') - pack = getattr(object, 'pack', '') - return cls.get_by_ref(ResourceReference.to_string_reference(pack=pack, name=name)) + name = getattr(object, "name", "") + pack = getattr(object, "pack", "") + return cls.get_by_ref( + ResourceReference.to_string_reference(pack=pack, name=name) + ) class StatusBasedResource(Access): @@ -372,4 +390,4 @@ def publish_status(cls, model_object): """ publisher = cls._get_publisher() if publisher: - publisher.publish_state(model_object, getattr(model_object, 'status', None)) + publisher.publish_state(model_object, getattr(model_object, "status", None)) diff --git a/st2common/st2common/persistence/cleanup.py b/st2common/st2common/persistence/cleanup.py index 5831a47cca6..06c48dec867 100644 --- a/st2common/st2common/persistence/cleanup.py +++ b/st2common/st2common/persistence/cleanup.py @@ -24,11 +24,7 @@ from st2common.script_setup import setup as common_setup from st2common.script_setup import teardown as common_teardown -__all__ = [ - 'db_cleanup', - 'db_cleanup_with_retry', - 'main' -] +__all__ = ["db_cleanup", "db_cleanup_with_retry", "main"] LOG = logging.getLogger(__name__) @@ -42,26 +38,47 @@ def db_cleanup(): return connection -def db_cleanup_with_retry(db_name, db_host, db_port, username=None, password=None, - ssl=False, ssl_keyfile=None, - ssl_certfile=None, ssl_cert_reqs=None, ssl_ca_certs=None, - authentication_mechanism=None, ssl_match_hostname=True): +def db_cleanup_with_retry( + db_name, + db_host, + db_port, + username=None, + password=None, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): """ This method is a retry version of db_cleanup. """ - return db_func_with_retry(db_cleanup_func, - db_name, db_host, db_port, - username=username, password=password, - ssl=ssl, ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, ssl_cert_reqs=ssl_cert_reqs, - ssl_ca_certs=ssl_ca_certs, - authentication_mechanism=authentication_mechanism, - ssl_match_hostname=ssl_match_hostname) + return db_func_with_retry( + db_cleanup_func, + db_name, + db_host, + db_port, + username=username, + password=password, + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + authentication_mechanism=authentication_mechanism, + ssl_match_hostname=ssl_match_hostname, + ) def setup(argv): - common_setup(config=config, setup_db=False, register_mq_exchanges=False, - register_internal_trigger_types=False) + common_setup( + config=config, + setup_db=False, + register_mq_exchanges=False, + register_internal_trigger_types=False, + ) def teardown(): @@ -75,5 +92,5 @@ def main(argv): # This script registers actions and rules from content-packs. -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv[1:]) diff --git a/st2common/st2common/persistence/db_init.py b/st2common/st2common/persistence/db_init.py index 04a2a3a753e..678ca71ccdd 100644 --- a/st2common/st2common/persistence/db_init.py +++ b/st2common/st2common/persistence/db_init.py @@ -22,9 +22,7 @@ from st2common import log as logging from st2common.models.db import db_setup -__all__ = [ - 'db_setup_with_retry' -] +__all__ = ["db_setup_with_retry"] LOG = logging.getLogger(__name__) @@ -36,9 +34,11 @@ def _retry_if_connection_error(error): # Ideally, a special execption or atleast some exception code. # If this does become an issue look for "Cannot connect to database" at the # start of error msg. - is_connection_error = isinstance(error, mongoengine.connection.MongoEngineConnectionError) + is_connection_error = isinstance( + error, mongoengine.connection.MongoEngineConnectionError + ) if is_connection_error: - LOG.warn('Retry on ConnectionError - %s', error) + LOG.warn("Retry on ConnectionError - %s", error) return is_connection_error @@ -52,25 +52,45 @@ def db_func_with_retry(db_func, *args, **kwargs): # reading of config values however this is lesser code. retrying_obj = retrying.Retrying( retry_on_exception=_retry_if_connection_error, - wait_exponential_multiplier=cfg.CONF.database.connection_retry_backoff_mul * 1000, + wait_exponential_multiplier=cfg.CONF.database.connection_retry_backoff_mul + * 1000, wait_exponential_max=cfg.CONF.database.connection_retry_backoff_max_s * 1000, - stop_max_delay=cfg.CONF.database.connection_retry_max_delay_m * 60 * 1000 + stop_max_delay=cfg.CONF.database.connection_retry_max_delay_m * 60 * 1000, ) return retrying_obj.call(db_func, *args, **kwargs) -def db_setup_with_retry(db_name, db_host, db_port, username=None, password=None, - ensure_indexes=True, ssl=False, ssl_keyfile=None, - ssl_certfile=None, ssl_cert_reqs=None, ssl_ca_certs=None, - authentication_mechanism=None, ssl_match_hostname=True): +def db_setup_with_retry( + db_name, + db_host, + db_port, + username=None, + password=None, + ensure_indexes=True, + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + authentication_mechanism=None, + ssl_match_hostname=True, +): """ This method is a retry version of db_setup. """ - return db_func_with_retry(db_setup, db_name, db_host, db_port, - username=username, password=password, - ensure_indexes=ensure_indexes, - ssl=ssl, ssl_keyfile=ssl_keyfile, - ssl_certfile=ssl_certfile, ssl_cert_reqs=ssl_cert_reqs, - ssl_ca_certs=ssl_ca_certs, - authentication_mechanism=authentication_mechanism, - ssl_match_hostname=ssl_match_hostname) + return db_func_with_retry( + db_setup, + db_name, + db_host, + db_port, + username=username, + password=password, + ensure_indexes=ensure_indexes, + ssl=ssl, + ssl_keyfile=ssl_keyfile, + ssl_certfile=ssl_certfile, + ssl_cert_reqs=ssl_cert_reqs, + ssl_ca_certs=ssl_ca_certs, + authentication_mechanism=authentication_mechanism, + ssl_match_hostname=ssl_match_hostname, + ) diff --git a/st2common/st2common/persistence/execution.py b/st2common/st2common/persistence/execution.py index 6af949786db..2073dda17b6 100644 --- a/st2common/st2common/persistence/execution.py +++ b/st2common/st2common/persistence/execution.py @@ -21,8 +21,8 @@ from st2common.persistence.base import Access __all__ = [ - 'ActionExecution', - 'ActionExecutionOutput', + "ActionExecution", + "ActionExecutionOutput", ] diff --git a/st2common/st2common/persistence/execution_queue.py b/st2common/st2common/persistence/execution_queue.py index 2ec5f05924b..eaedc22f4ce 100644 --- a/st2common/st2common/persistence/execution_queue.py +++ b/st2common/st2common/persistence/execution_queue.py @@ -18,9 +18,7 @@ from st2common.models.db.execution_queue import EXECUTION_QUEUE_ACCESS from st2common.persistence import base as persistence -__all__ = [ - 'ActionExecutionSchedulingQueue' -] +__all__ = ["ActionExecutionSchedulingQueue"] class ActionExecutionSchedulingQueue(persistence.Access): diff --git a/st2common/st2common/persistence/executionstate.py b/st2common/st2common/persistence/executionstate.py index 8e94a714aa1..7e2debd1384 100644 --- a/st2common/st2common/persistence/executionstate.py +++ b/st2common/st2common/persistence/executionstate.py @@ -19,9 +19,7 @@ from st2common.models.db.executionstate import actionexecstate_access from st2common.persistence import base as persistence -__all__ = [ - 'ActionExecutionState' -] +__all__ = ["ActionExecutionState"] class ActionExecutionState(persistence.Access): @@ -35,5 +33,7 @@ def _get_impl(cls): @classmethod def _get_publisher(cls): if not cls.publisher: - cls.publisher = transport.actionexecutionstate.ActionExecutionStatePublisher() + cls.publisher = ( + transport.actionexecutionstate.ActionExecutionStatePublisher() + ) return cls.publisher diff --git a/st2common/st2common/persistence/keyvalue.py b/st2common/st2common/persistence/keyvalue.py index 634bd723027..10676998f56 100644 --- a/st2common/st2common/persistence/keyvalue.py +++ b/st2common/st2common/persistence/keyvalue.py @@ -34,24 +34,30 @@ class KeyValuePair(Access): publisher = None api_model_cls = KeyValuePairAPI - dispatch_trigger_for_operations = ['create', 'update', 'value_change', 'delete'] + dispatch_trigger_for_operations = ["create", "update", "value_change", "delete"] operation_to_trigger_ref_map = { - 'create': ResourceReference.to_string_reference( - name=KEY_VALUE_PAIR_CREATE_TRIGGER['name'], - pack=KEY_VALUE_PAIR_CREATE_TRIGGER['pack']), - 'update': ResourceReference.to_string_reference( - name=KEY_VALUE_PAIR_UPDATE_TRIGGER['name'], - pack=KEY_VALUE_PAIR_UPDATE_TRIGGER['pack']), - 'value_change': ResourceReference.to_string_reference( - name=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER['name'], - pack=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER['pack']), - 'delete': ResourceReference.to_string_reference( - name=KEY_VALUE_PAIR_DELETE_TRIGGER['name'], - pack=KEY_VALUE_PAIR_DELETE_TRIGGER['pack']), + "create": ResourceReference.to_string_reference( + name=KEY_VALUE_PAIR_CREATE_TRIGGER["name"], + pack=KEY_VALUE_PAIR_CREATE_TRIGGER["pack"], + ), + "update": ResourceReference.to_string_reference( + name=KEY_VALUE_PAIR_UPDATE_TRIGGER["name"], + pack=KEY_VALUE_PAIR_UPDATE_TRIGGER["pack"], + ), + "value_change": ResourceReference.to_string_reference( + name=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER["name"], + pack=KEY_VALUE_PAIR_VALUE_CHANGE_TRIGGER["pack"], + ), + "delete": ResourceReference.to_string_reference( + name=KEY_VALUE_PAIR_DELETE_TRIGGER["name"], + pack=KEY_VALUE_PAIR_DELETE_TRIGGER["pack"], + ), } @classmethod - def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, validate=True): + def add_or_update( + cls, model_object, publish=True, dispatch_trigger=True, validate=True + ): """ Note: We override add_or_update because we also want to publish high level "value_change" event for this resource. @@ -62,32 +68,36 @@ def add_or_update(cls, model_object, publish=True, dispatch_trigger=True, valida # Not an update existing_model_object = None - model_object = super(KeyValuePair, cls).add_or_update(model_object=model_object, - publish=publish, - dispatch_trigger=dispatch_trigger) + model_object = super(KeyValuePair, cls).add_or_update( + model_object=model_object, + publish=publish, + dispatch_trigger=dispatch_trigger, + ) # Dispatch a value_change event which is specific to this resource if existing_model_object and existing_model_object.value != model_object.value: - cls.dispatch_value_change_trigger(old_model_object=existing_model_object, - new_model_object=model_object) + cls.dispatch_value_change_trigger( + old_model_object=existing_model_object, new_model_object=model_object + ) return model_object @classmethod def dispatch_value_change_trigger(cls, old_model_object, new_model_object): - operation = 'value_change' + operation = "value_change" trigger = cls._get_trigger_ref_for_operation(operation=operation) - old_object_payload = cls.api_model_cls.from_model(old_model_object, - mask_secrets=True).__json__() - new_object_payload = cls.api_model_cls.from_model(new_model_object, - mask_secrets=True).__json__() - payload = { - 'old_object': old_object_payload, - 'new_object': new_object_payload - } + old_object_payload = cls.api_model_cls.from_model( + old_model_object, mask_secrets=True + ).__json__() + new_object_payload = cls.api_model_cls.from_model( + new_model_object, mask_secrets=True + ).__json__() + payload = {"old_object": old_object_payload, "new_object": new_object_payload} - return cls._dispatch_trigger(operation=operation, trigger=trigger, payload=payload) + return cls._dispatch_trigger( + operation=operation, trigger=trigger, payload=payload + ) @classmethod def get_by_names(cls, names): @@ -124,5 +134,5 @@ def _get_impl(cls): @classmethod def _get_by_object(cls, object): # For KeyValuePair name is unique. - name = getattr(object, 'name', '') + name = getattr(object, "name", "") return cls.get_by_name(name) diff --git a/st2common/st2common/persistence/liveaction.py b/st2common/st2common/persistence/liveaction.py index 61b16b18782..aa7551592a6 100644 --- a/st2common/st2common/persistence/liveaction.py +++ b/st2common/st2common/persistence/liveaction.py @@ -19,9 +19,7 @@ from st2common.models.db.liveaction import liveaction_access from st2common.persistence import base as persistence -__all__ = [ - 'LiveAction' -] +__all__ = ["LiveAction"] class LiveAction(persistence.StatusBasedResource): diff --git a/st2common/st2common/persistence/marker.py b/st2common/st2common/persistence/marker.py index 1f35bbcdf2f..6be08a25ec1 100644 --- a/st2common/st2common/persistence/marker.py +++ b/st2common/st2common/persistence/marker.py @@ -19,9 +19,7 @@ from st2common.models.db.marker import DumperMarkerDB from st2common.persistence.base import Access -__all__ = [ - 'Marker' -] +__all__ = ["Marker"] class Marker(Access): diff --git a/st2common/st2common/persistence/pack.py b/st2common/st2common/persistence/pack.py index 5b2ff39102a..01ca6b20cb1 100644 --- a/st2common/st2common/persistence/pack.py +++ b/st2common/st2common/persistence/pack.py @@ -19,11 +19,7 @@ from st2common.models.db.pack import config_schema_access from st2common.models.db.pack import config_access -__all__ = [ - 'Pack', - 'ConfigSchema', - 'Config' -] +__all__ = ["Pack", "ConfigSchema", "Config"] class Pack(base.Access): diff --git a/st2common/st2common/persistence/policy.py b/st2common/st2common/persistence/policy.py index 468ce07f696..8b6700c194e 100644 --- a/st2common/st2common/persistence/policy.py +++ b/st2common/st2common/persistence/policy.py @@ -30,16 +30,20 @@ def _get_impl(cls): def get_by_ref(cls, ref): if ref: ref_obj = PolicyTypeReference.from_string_reference(ref=ref) - result = cls.query(name=ref_obj.name, resource_type=ref_obj.resource_type).first() + result = cls.query( + name=ref_obj.name, resource_type=ref_obj.resource_type + ).first() return result else: return None @classmethod def _get_by_object(cls, object): - name = getattr(object, 'name', '') - resource_type = getattr(object, 'resource_type', '') - ref = PolicyTypeReference.to_string_reference(resource_type=resource_type, name=name) + name = getattr(object, "name", "") + resource_type = getattr(object, "resource_type", "") + ref = PolicyTypeReference.to_string_reference( + resource_type=resource_type, name=name + ) return cls.get_by_ref(ref) diff --git a/st2common/st2common/persistence/rbac.py b/st2common/st2common/persistence/rbac.py index bdac61d8883..e14b973aeb7 100644 --- a/st2common/st2common/persistence/rbac.py +++ b/st2common/st2common/persistence/rbac.py @@ -20,12 +20,7 @@ from st2common.models.db.rbac import permission_grant_access from st2common.models.db.rbac import group_to_role_mapping_access -__all__ = [ - 'Role', - 'UserRoleAssignment', - 'PermissionGrant', - 'GroupToRoleMapping' -] +__all__ = ["Role", "UserRoleAssignment", "PermissionGrant", "GroupToRoleMapping"] class Role(base.Access): diff --git a/st2common/st2common/persistence/reactor.py b/st2common/st2common/persistence/reactor.py index c0608775136..0fa35c6bdff 100644 --- a/st2common/st2common/persistence/reactor.py +++ b/st2common/st2common/persistence/reactor.py @@ -16,12 +16,6 @@ from __future__ import absolute_import from st2common.persistence.rule import Rule from st2common.persistence.sensor import SensorType -from st2common.persistence.trigger import (Trigger, TriggerInstance, TriggerType) +from st2common.persistence.trigger import Trigger, TriggerInstance, TriggerType -__all__ = [ - 'Rule', - 'SensorType', - 'Trigger', - 'TriggerInstance', - 'TriggerType' -] +__all__ = ["Rule", "SensorType", "Trigger", "TriggerInstance", "TriggerType"] diff --git a/st2common/st2common/persistence/rule.py b/st2common/st2common/persistence/rule.py index 741b9d49675..0a64e4bb1f2 100644 --- a/st2common/st2common/persistence/rule.py +++ b/st2common/st2common/persistence/rule.py @@ -36,5 +36,5 @@ def _get_impl(cls): @classmethod def _get_by_object(cls, object): # For RuleType name is unique. - name = getattr(object, 'name', '') + name = getattr(object, "name", "") return cls.get_by_name(name) diff --git a/st2common/st2common/persistence/runner.py b/st2common/st2common/persistence/runner.py index 77440707f24..63cfa36d9e9 100644 --- a/st2common/st2common/persistence/runner.py +++ b/st2common/st2common/persistence/runner.py @@ -28,5 +28,5 @@ def _get_impl(cls): @classmethod def _get_by_object(cls, object): # For RunnerType name is unique. - name = getattr(object, 'name', '') + name = getattr(object, "name", "") return cls.get_by_name(name) diff --git a/st2common/st2common/persistence/sensor.py b/st2common/st2common/persistence/sensor.py index 67367c7fc48..1a3a3679dac 100644 --- a/st2common/st2common/persistence/sensor.py +++ b/st2common/st2common/persistence/sensor.py @@ -19,9 +19,7 @@ from st2common.models.db.sensor import sensor_type_access from st2common.persistence.base import ContentPackResource -__all__ = [ - 'SensorType' -] +__all__ = ["SensorType"] class SensorType(ContentPackResource): diff --git a/st2common/st2common/persistence/trace.py b/st2common/st2common/persistence/trace.py index 5e7276a1f0d..ce5472f2aaa 100644 --- a/st2common/st2common/persistence/trace.py +++ b/st2common/st2common/persistence/trace.py @@ -26,14 +26,16 @@ def _get_impl(cls): return cls.impl @classmethod - def push_components(cls, instance, action_executions=None, rules=None, trigger_instances=None): + def push_components( + cls, instance, action_executions=None, rules=None, trigger_instances=None + ): update_kwargs = {} if action_executions: - update_kwargs['push_all__action_executions'] = action_executions + update_kwargs["push_all__action_executions"] = action_executions if rules: - update_kwargs['push_all__rules'] = rules + update_kwargs["push_all__rules"] = rules if trigger_instances: - update_kwargs['push_all__trigger_instances'] = trigger_instances + update_kwargs["push_all__trigger_instances"] = trigger_instances if update_kwargs: return cls.update(instance, **update_kwargs) return instance diff --git a/st2common/st2common/persistence/trigger.py b/st2common/st2common/persistence/trigger.py index 1cdc4ef4ac9..3567a15829a 100644 --- a/st2common/st2common/persistence/trigger.py +++ b/st2common/st2common/persistence/trigger.py @@ -18,14 +18,14 @@ from st2common import log as logging from st2common import transport from st2common.exceptions.db import StackStormDBObjectNotFoundError -from st2common.models.db.trigger import triggertype_access, trigger_access, triggerinstance_access -from st2common.persistence.base import (Access, ContentPackResource) +from st2common.models.db.trigger import ( + triggertype_access, + trigger_access, + triggerinstance_access, +) +from st2common.persistence.base import Access, ContentPackResource -__all__ = [ - 'TriggerType', - 'Trigger', - 'TriggerInstance' -] +__all__ = ["TriggerType", "Trigger", "TriggerInstance"] LOG = logging.getLogger(__name__) @@ -57,7 +57,7 @@ def delete_if_unreferenced(cls, model_object, publish=True, dispatch_trigger=Tru # Found in the innards of mongoengine. # e.g. {'pk': ObjectId('5609e91832ed356d04a93cc0')} delete_query = model_object._object_key - delete_query['ref_count__lte'] = 0 + delete_query["ref_count__lte"] = 0 cls._get_impl().delete_by_query(**delete_query) # Since delete_by_query cannot tell if teh delete actually happened check with a get call @@ -73,14 +73,14 @@ def delete_if_unreferenced(cls, model_object, publish=True, dispatch_trigger=Tru try: cls.publish_delete(model_object) except Exception: - LOG.exception('Publish failed.') + LOG.exception("Publish failed.") # Dispatch trigger if confirmed_delete and dispatch_trigger: try: cls.dispatch_delete_trigger(model_object) except Exception: - LOG.exception('Trigger dispatch failed.') + LOG.exception("Trigger dispatch failed.") return model_object diff --git a/st2common/st2common/persistence/workflow.py b/st2common/st2common/persistence/workflow.py index aa02c320e1a..8d993ef4fe3 100644 --- a/st2common/st2common/persistence/workflow.py +++ b/st2common/st2common/persistence/workflow.py @@ -21,10 +21,7 @@ from st2common.persistence import base as persistence -__all__ = [ - 'WorkflowExecution', - 'TaskExecution' -] +__all__ = ["WorkflowExecution", "TaskExecution"] class WorkflowExecution(persistence.StatusBasedResource): diff --git a/st2common/st2common/policies/__init__.py b/st2common/st2common/policies/__init__.py index df49fa1f149..ef39e129c93 100644 --- a/st2common/st2common/policies/__init__.py +++ b/st2common/st2common/policies/__init__.py @@ -18,7 +18,4 @@ from st2common.policies.base import ResourcePolicyApplicator -__all__ = [ - 'get_driver', - 'ResourcePolicyApplicator' -] +__all__ = ["get_driver", "ResourcePolicyApplicator"] diff --git a/st2common/st2common/policies/base.py b/st2common/st2common/policies/base.py index 5bfc3fa58e6..a22fa2fb424 100644 --- a/st2common/st2common/policies/base.py +++ b/st2common/st2common/policies/base.py @@ -24,10 +24,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'ResourcePolicyApplicator', - 'get_driver' -] +__all__ = ["ResourcePolicyApplicator", "get_driver"] @six.add_metaclass(abc.ABCMeta) @@ -72,9 +69,9 @@ def _get_lock_name(self, values): lock_uid = [] for key, value in six.iteritems(values): - lock_uid.append('%s=%s' % (key, value)) + lock_uid.append("%s=%s" % (key, value)) - lock_uid = ','.join(lock_uid) + lock_uid = ",".join(lock_uid) return lock_uid @@ -88,5 +85,7 @@ def get_driver(policy_ref, policy_type, **parameters): # interested in continue - if (issubclass(obj, ResourcePolicyApplicator) and not obj.__name__.startswith('Base')): + if issubclass(obj, ResourcePolicyApplicator) and not obj.__name__.startswith( + "Base" + ): return obj(policy_ref, policy_type, **parameters) diff --git a/st2common/st2common/policies/concurrency.py b/st2common/st2common/policies/concurrency.py index a453214b722..fcf96467c3e 100644 --- a/st2common/st2common/policies/concurrency.py +++ b/st2common/st2common/policies/concurrency.py @@ -18,24 +18,23 @@ from st2common.policies import base from st2common.services import coordination -__all__ = [ - 'BaseConcurrencyApplicator' -] +__all__ = ["BaseConcurrencyApplicator"] class BaseConcurrencyApplicator(base.ResourcePolicyApplicator): - def __init__(self, policy_ref, policy_type, threshold=0, action='delay'): - super(BaseConcurrencyApplicator, self).__init__(policy_ref=policy_ref, - policy_type=policy_type) + def __init__(self, policy_ref, policy_type, threshold=0, action="delay"): + super(BaseConcurrencyApplicator, self).__init__( + policy_ref=policy_ref, policy_type=policy_type + ) self.threshold = threshold self.policy_action = action self.coordinator = coordination.get_coordinator(start_heart=True) def _get_status_for_policy_action(self, action): - if action == 'delay': + if action == "delay": status = action_constants.LIVEACTION_STATUS_DELAYED - elif action == 'cancel': + elif action == "cancel": status = action_constants.LIVEACTION_STATUS_CANCELING return status diff --git a/st2common/st2common/rbac/backends/__init__.py b/st2common/st2common/rbac/backends/__init__.py index cf6429c1248..bb7ad3d58f7 100644 --- a/st2common/st2common/rbac/backends/__init__.py +++ b/st2common/st2common/rbac/backends/__init__.py @@ -22,15 +22,11 @@ from st2common.util import driver_loader -__all__ = [ - 'get_available_backends', - 'get_backend_instance', - 'get_rbac_backend' -] +__all__ = ["get_available_backends", "get_backend_instance", "get_rbac_backend"] LOG = logging.getLogger(__name__) -BACKENDS_NAMESPACE = 'st2common.rbac.backend' +BACKENDS_NAMESPACE = "st2common.rbac.backend" # Cache which maps backed name -> backend class instance # NOTE: We use cache to avoid slow stevedore dynamic filesystem instrospection on every @@ -44,7 +40,9 @@ def get_available_backends(): def get_backend_instance(name, use_cache=True): if name not in BACKENDS_CACHE or not use_cache: - rbac_backend = driver_loader.get_backend_instance(namespace=BACKENDS_NAMESPACE, name=name) + rbac_backend = driver_loader.get_backend_instance( + namespace=BACKENDS_NAMESPACE, name=name + ) BACKENDS_CACHE[name] = rbac_backend rbac_backend = BACKENDS_CACHE[name] diff --git a/st2common/st2common/rbac/backends/base.py b/st2common/st2common/rbac/backends/base.py index 8e2c54c4fdb..f9661d0b4bc 100644 --- a/st2common/st2common/rbac/backends/base.py +++ b/st2common/st2common/rbac/backends/base.py @@ -23,17 +23,16 @@ from st2common.exceptions.rbac import AccessDeniedError __all__ = [ - 'BaseRBACBackend', - 'BaseRBACPermissionResolver', - 'BaseRBACService', - 'BaseRBACUtils', - 'BaseRBACRemoteGroupToRoleSyncer' + "BaseRBACBackend", + "BaseRBACPermissionResolver", + "BaseRBACService", + "BaseRBACUtils", + "BaseRBACRemoteGroupToRoleSyncer", ] @six.add_metaclass(abc.ABCMeta) class BaseRBACBackend(object): - def get_resolver_for_resource_type(self, resource_type): """ Method which returns PermissionResolver class for the provided resource type. @@ -67,7 +66,6 @@ def get_utils_class(self): @six.add_metaclass(abc.ABCMeta) class BaseRBACPermissionResolver(object): - def user_has_permission(self, user_db, permission_type): """ Method for checking user permissions which are not tied to a particular resource. @@ -177,7 +175,9 @@ def assert_user_has_rule_trigger_and_action_permission(user_db, rule_api): raise NotImplementedError() @staticmethod - def assert_user_is_admin_if_user_query_param_is_provided(user_db, user, require_rbac=False): + def assert_user_is_admin_if_user_query_param_is_provided( + user_db, user, require_rbac=False + ): """ Function which asserts that the request user is administator if "user" query parameter is provided and doesn't match the current user. @@ -273,12 +273,12 @@ def get_user_db_from_request(request): """ Retrieve UserDB object from the provided request. """ - auth_context = request.context.get('auth', {}) + auth_context = request.context.get("auth", {}) if not auth_context: return None - user_db = auth_context.get('user', None) + user_db = auth_context.get("user", None) return user_db diff --git a/st2common/st2common/rbac/backends/noop.py b/st2common/st2common/rbac/backends/noop.py index 15ca5a3a75c..4d3b8fb127f 100644 --- a/st2common/st2common/rbac/backends/noop.py +++ b/st2common/st2common/rbac/backends/noop.py @@ -25,11 +25,11 @@ from st2common.exceptions.rbac import AccessDeniedError __all__ = [ - 'NoOpRBACBackend', - 'NoOpRBACPermissionResolver', - 'NoOpRBACService', - 'NoOpRBACUtils', - 'NoOpRBACRemoteGroupToRoleSyncer' + "NoOpRBACBackend", + "NoOpRBACPermissionResolver", + "NoOpRBACService", + "NoOpRBACUtils", + "NoOpRBACRemoteGroupToRoleSyncer", ] @@ -37,6 +37,7 @@ class NoOpRBACBackend(BaseRBACBackend): """ NoOp RBAC backend. """ + def get_resolver_for_resource_type(self, resource_type): return NoOpRBACPermissionResolver() @@ -79,7 +80,6 @@ def validate_roles_exists(role_names): class NoOpRBACUtils(BaseRBACUtils): - @staticmethod def assert_user_is_admin(user_db): """ @@ -141,7 +141,9 @@ def assert_user_has_rule_trigger_and_action_permission(user_db, rule_api): return True @staticmethod - def assert_user_is_admin_if_user_query_param_is_provided(user_db, user, require_rbac=False): + def assert_user_is_admin_if_user_query_param_is_provided( + user_db, user, require_rbac=False + ): """ Function which asserts that the request user is administator if "user" query parameter is provided and doesn't match the current user. diff --git a/st2common/st2common/rbac/migrations.py b/st2common/st2common/rbac/migrations.py index 9e9fc9db18c..951bbddf194 100644 --- a/st2common/st2common/rbac/migrations.py +++ b/st2common/st2common/rbac/migrations.py @@ -23,11 +23,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'run_all', - - 'insert_system_roles' -] +__all__ = ["run_all", "insert_system_roles"] def run_all(): @@ -40,7 +36,7 @@ def insert_system_roles(): """ system_roles = SystemRole.get_valid_values() - LOG.debug('Inserting system roles (%s)' % (str(system_roles))) + LOG.debug("Inserting system roles (%s)" % (str(system_roles))) for role_name in system_roles: description = role_name diff --git a/st2common/st2common/rbac/types.py b/st2common/st2common/rbac/types.py index 1c6b0ea3528..cceb819d7bc 100644 --- a/st2common/st2common/rbac/types.py +++ b/st2common/st2common/rbac/types.py @@ -21,19 +21,16 @@ from st2common.constants.types import ResourceType as SystemResourceType __all__ = [ - 'SystemRole', - 'PermissionType', - 'ResourceType', - - 'RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP', - 'PERMISION_TYPE_TO_DESCRIPTION_MAP', - - 'ALL_PERMISSION_TYPES', - 'GLOBAL_PERMISSION_TYPES', - 'GLOBAL_PACK_PERMISSION_TYPES', - 'LIST_PERMISSION_TYPES', - - 'get_resource_permission_types_with_descriptions' + "SystemRole", + "PermissionType", + "ResourceType", + "RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP", + "PERMISION_TYPE_TO_DESCRIPTION_MAP", + "ALL_PERMISSION_TYPES", + "GLOBAL_PERMISSION_TYPES", + "GLOBAL_PACK_PERMISSION_TYPES", + "LIST_PERMISSION_TYPES", + "get_resource_permission_types_with_descriptions", ] @@ -43,120 +40,120 @@ class PermissionType(Enum): """ # Note: There is no create endpoint for runner types right now - RUNNER_LIST = 'runner_type_list' - RUNNER_VIEW = 'runner_type_view' - RUNNER_MODIFY = 'runner_type_modify' - RUNNER_ALL = 'runner_type_all' + RUNNER_LIST = "runner_type_list" + RUNNER_VIEW = "runner_type_view" + RUNNER_MODIFY = "runner_type_modify" + RUNNER_ALL = "runner_type_all" - PACK_LIST = 'pack_list' - PACK_VIEW = 'pack_view' - PACK_CREATE = 'pack_create' - PACK_MODIFY = 'pack_modify' - PACK_DELETE = 'pack_delete' + PACK_LIST = "pack_list" + PACK_VIEW = "pack_view" + PACK_CREATE = "pack_create" + PACK_MODIFY = "pack_modify" + PACK_DELETE = "pack_delete" # Pack-management specific permissions # Note: Right now those permissions are global and apply to all the packs. # In the future we plan to support globs. - PACK_INSTALL = 'pack_install' - PACK_UNINSTALL = 'pack_uninstall' - PACK_REGISTER = 'pack_register' - PACK_CONFIG = 'pack_config' - PACK_SEARCH = 'pack_search' - PACK_VIEWS_INDEX_HEALTH = 'pack_views_index_health' + PACK_INSTALL = "pack_install" + PACK_UNINSTALL = "pack_uninstall" + PACK_REGISTER = "pack_register" + PACK_CONFIG = "pack_config" + PACK_SEARCH = "pack_search" + PACK_VIEWS_INDEX_HEALTH = "pack_views_index_health" - PACK_ALL = 'pack_all' + PACK_ALL = "pack_all" # Note: Right now we only have read endpoints + update for sensors types - SENSOR_LIST = 'sensor_type_list' - SENSOR_VIEW = 'sensor_type_view' - SENSOR_MODIFY = 'sensor_type_modify' - SENSOR_ALL = 'sensor_type_all' - - ACTION_LIST = 'action_list' - ACTION_VIEW = 'action_view' - ACTION_CREATE = 'action_create' - ACTION_MODIFY = 'action_modify' - ACTION_DELETE = 'action_delete' - ACTION_EXECUTE = 'action_execute' - ACTION_ALL = 'action_all' - - ACTION_ALIAS_LIST = 'action_alias_list' - ACTION_ALIAS_VIEW = 'action_alias_view' - ACTION_ALIAS_CREATE = 'action_alias_create' - ACTION_ALIAS_MODIFY = 'action_alias_modify' - ACTION_ALIAS_MATCH = 'action_alias_match' - ACTION_ALIAS_HELP = 'action_alias_help' - ACTION_ALIAS_DELETE = 'action_alias_delete' - ACTION_ALIAS_ALL = 'action_alias_all' + SENSOR_LIST = "sensor_type_list" + SENSOR_VIEW = "sensor_type_view" + SENSOR_MODIFY = "sensor_type_modify" + SENSOR_ALL = "sensor_type_all" + + ACTION_LIST = "action_list" + ACTION_VIEW = "action_view" + ACTION_CREATE = "action_create" + ACTION_MODIFY = "action_modify" + ACTION_DELETE = "action_delete" + ACTION_EXECUTE = "action_execute" + ACTION_ALL = "action_all" + + ACTION_ALIAS_LIST = "action_alias_list" + ACTION_ALIAS_VIEW = "action_alias_view" + ACTION_ALIAS_CREATE = "action_alias_create" + ACTION_ALIAS_MODIFY = "action_alias_modify" + ACTION_ALIAS_MATCH = "action_alias_match" + ACTION_ALIAS_HELP = "action_alias_help" + ACTION_ALIAS_DELETE = "action_alias_delete" + ACTION_ALIAS_ALL = "action_alias_all" # Note: Execution create is granted with "action_execute" - EXECUTION_LIST = 'execution_list' - EXECUTION_VIEW = 'execution_view' - EXECUTION_RE_RUN = 'execution_rerun' - EXECUTION_STOP = 'execution_stop' - EXECUTION_ALL = 'execution_all' - EXECUTION_VIEWS_FILTERS_LIST = 'execution_views_filters_list' - - RULE_LIST = 'rule_list' - RULE_VIEW = 'rule_view' - RULE_CREATE = 'rule_create' - RULE_MODIFY = 'rule_modify' - RULE_DELETE = 'rule_delete' - RULE_ALL = 'rule_all' - - RULE_ENFORCEMENT_LIST = 'rule_enforcement_list' - RULE_ENFORCEMENT_VIEW = 'rule_enforcement_view' + EXECUTION_LIST = "execution_list" + EXECUTION_VIEW = "execution_view" + EXECUTION_RE_RUN = "execution_rerun" + EXECUTION_STOP = "execution_stop" + EXECUTION_ALL = "execution_all" + EXECUTION_VIEWS_FILTERS_LIST = "execution_views_filters_list" + + RULE_LIST = "rule_list" + RULE_VIEW = "rule_view" + RULE_CREATE = "rule_create" + RULE_MODIFY = "rule_modify" + RULE_DELETE = "rule_delete" + RULE_ALL = "rule_all" + + RULE_ENFORCEMENT_LIST = "rule_enforcement_list" + RULE_ENFORCEMENT_VIEW = "rule_enforcement_view" # TODO - Maybe "datastore_item" / key_value_item ? - KEY_VALUE_VIEW = 'key_value_pair_view' - KEY_VALUE_SET = 'key_value_pair_set' - KEY_VALUE_DELETE = 'key_value_pair_delete' - - WEBHOOK_LIST = 'webhook_list' - WEBHOOK_VIEW = 'webhook_view' - WEBHOOK_CREATE = 'webhook_create' - WEBHOOK_SEND = 'webhook_send' - WEBHOOK_DELETE = 'webhook_delete' - WEBHOOK_ALL = 'webhook_all' - - TIMER_LIST = 'timer_list' - TIMER_VIEW = 'timer_view' - TIMER_ALL = 'timer_all' - - API_KEY_LIST = 'api_key_list' - API_KEY_VIEW = 'api_key_view' - API_KEY_CREATE = 'api_key_create' - API_KEY_MODIFY = 'api_key_modify' - API_KEY_DELETE = 'api_key_delete' - API_KEY_ALL = 'api_key_all' - - TRACE_LIST = 'trace_list' - TRACE_VIEW = 'trace_view' - TRACE_ALL = 'trace_all' + KEY_VALUE_VIEW = "key_value_pair_view" + KEY_VALUE_SET = "key_value_pair_set" + KEY_VALUE_DELETE = "key_value_pair_delete" + + WEBHOOK_LIST = "webhook_list" + WEBHOOK_VIEW = "webhook_view" + WEBHOOK_CREATE = "webhook_create" + WEBHOOK_SEND = "webhook_send" + WEBHOOK_DELETE = "webhook_delete" + WEBHOOK_ALL = "webhook_all" + + TIMER_LIST = "timer_list" + TIMER_VIEW = "timer_view" + TIMER_ALL = "timer_all" + + API_KEY_LIST = "api_key_list" + API_KEY_VIEW = "api_key_view" + API_KEY_CREATE = "api_key_create" + API_KEY_MODIFY = "api_key_modify" + API_KEY_DELETE = "api_key_delete" + API_KEY_ALL = "api_key_all" + + TRACE_LIST = "trace_list" + TRACE_VIEW = "trace_view" + TRACE_ALL = "trace_all" # Note: Trigger permissions types are also used for Timer API endpoint since timer is just # a special type of a trigger - TRIGGER_LIST = 'trigger_list' - TRIGGER_VIEW = 'trigger_view' - TRIGGER_ALL = 'trigger_all' + TRIGGER_LIST = "trigger_list" + TRIGGER_VIEW = "trigger_view" + TRIGGER_ALL = "trigger_all" - POLICY_TYPE_LIST = 'policy_type_list' - POLICY_TYPE_VIEW = 'policy_type_view' - POLICY_TYPE_ALL = 'policy_type_all' + POLICY_TYPE_LIST = "policy_type_list" + POLICY_TYPE_VIEW = "policy_type_view" + POLICY_TYPE_ALL = "policy_type_all" - POLICY_LIST = 'policy_list' - POLICY_VIEW = 'policy_view' - POLICY_CREATE = 'policy_create' - POLICY_MODIFY = 'policy_modify' - POLICY_DELETE = 'policy_delete' - POLICY_ALL = 'policy_all' + POLICY_LIST = "policy_list" + POLICY_VIEW = "policy_view" + POLICY_CREATE = "policy_create" + POLICY_MODIFY = "policy_modify" + POLICY_DELETE = "policy_delete" + POLICY_ALL = "policy_all" - STREAM_VIEW = 'stream_view' + STREAM_VIEW = "stream_view" - INQUIRY_LIST = 'inquiry_list' - INQUIRY_VIEW = 'inquiry_view' - INQUIRY_RESPOND = 'inquiry_respond' - INQUIRY_ALL = 'inquiry_all' + INQUIRY_LIST = "inquiry_list" + INQUIRY_VIEW = "inquiry_view" + INQUIRY_RESPOND = "inquiry_respond" + INQUIRY_ALL = "inquiry_all" @classmethod def get_valid_permissions_for_resource_type(cls, resource_type): @@ -183,10 +180,10 @@ def get_resource_type(cls, permission_type): elif permission_type == PermissionType.EXECUTION_VIEWS_FILTERS_LIST: return ResourceType.EXECUTION - split = permission_type.split('_') + split = permission_type.split("_") assert len(split) >= 2 - return '_'.join(split[:-1]) + return "_".join(split[:-1]) @classmethod def get_permission_name(cls, permission_type): @@ -195,12 +192,12 @@ def get_permission_name(cls, permission_type): :rtype: ``str`` """ - split = permission_type.split('_') + split = permission_type.split("_") assert len(split) >= 2 # Special case for PACK_VIEWS_INDEX_HEALTH if permission_type == PermissionType.PACK_VIEWS_INDEX_HEALTH: - split = permission_type.split('_', 1) + split = permission_type.split("_", 1) return split[1] return split[-1] @@ -224,14 +221,16 @@ def get_permission_type(cls, resource_type, permission_name): """ # Special case for sensor type (sensor_type -> sensor) if resource_type == ResourceType.SENSOR: - resource_type = 'sensor' + resource_type = "sensor" - permission_enum = '%s_%s' % (resource_type.upper(), permission_name.upper()) + permission_enum = "%s_%s" % (resource_type.upper(), permission_name.upper()) result = getattr(cls, permission_enum, None) if not result: - raise ValueError('Unsupported permission type for type "%s" and name "%s"' % - (resource_type, permission_name)) + raise ValueError( + 'Unsupported permission type for type "%s" and name "%s"' + % (resource_type, permission_name) + ) return result @@ -240,6 +239,7 @@ class ResourceType(Enum): """ Resource types on which permissions can be granted. """ + RUNNER = SystemResourceType.RUNNER_TYPE PACK = SystemResourceType.PACK @@ -266,9 +266,10 @@ class SystemRole(Enum): """ Default system roles which can't be manipulated (modified or removed). """ - SYSTEM_ADMIN = 'system_admin' # Special role which can't be revoked. - ADMIN = 'admin' - OBSERVER = 'observer' + + SYSTEM_ADMIN = "system_admin" # Special role which can't be revoked. + ADMIN = "admin" + OBSERVER = "observer" # Maps a list of available permission types for each resource @@ -292,35 +293,31 @@ class SystemRole(Enum): PermissionType.PACK_SEARCH, PermissionType.PACK_VIEWS_INDEX_HEALTH, PermissionType.PACK_ALL, - PermissionType.SENSOR_VIEW, PermissionType.SENSOR_MODIFY, PermissionType.SENSOR_ALL, - PermissionType.ACTION_VIEW, PermissionType.ACTION_CREATE, PermissionType.ACTION_MODIFY, PermissionType.ACTION_DELETE, PermissionType.ACTION_EXECUTE, PermissionType.ACTION_ALL, - PermissionType.ACTION_ALIAS_VIEW, PermissionType.ACTION_ALIAS_CREATE, PermissionType.ACTION_ALIAS_MODIFY, PermissionType.ACTION_ALIAS_DELETE, PermissionType.ACTION_ALIAS_ALL, - PermissionType.RULE_VIEW, PermissionType.RULE_CREATE, PermissionType.RULE_MODIFY, PermissionType.RULE_DELETE, - PermissionType.RULE_ALL + PermissionType.RULE_ALL, ], ResourceType.SENSOR: [ PermissionType.SENSOR_LIST, PermissionType.SENSOR_VIEW, PermissionType.SENSOR_MODIFY, - PermissionType.SENSOR_ALL + PermissionType.SENSOR_ALL, ], ResourceType.ACTION: [ PermissionType.ACTION_LIST, @@ -329,7 +326,7 @@ class SystemRole(Enum): PermissionType.ACTION_MODIFY, PermissionType.ACTION_DELETE, PermissionType.ACTION_EXECUTE, - PermissionType.ACTION_ALL + PermissionType.ACTION_ALL, ], ResourceType.ACTION_ALIAS: [ PermissionType.ACTION_ALIAS_LIST, @@ -339,7 +336,7 @@ class SystemRole(Enum): PermissionType.ACTION_ALIAS_MATCH, PermissionType.ACTION_ALIAS_HELP, PermissionType.ACTION_ALIAS_DELETE, - PermissionType.ACTION_ALIAS_ALL + PermissionType.ACTION_ALIAS_ALL, ], ResourceType.RULE: [ PermissionType.RULE_LIST, @@ -347,7 +344,7 @@ class SystemRole(Enum): PermissionType.RULE_CREATE, PermissionType.RULE_MODIFY, PermissionType.RULE_DELETE, - PermissionType.RULE_ALL + PermissionType.RULE_ALL, ], ResourceType.RULE_ENFORCEMENT: [ PermissionType.RULE_ENFORCEMENT_LIST, @@ -364,7 +361,7 @@ class SystemRole(Enum): ResourceType.KEY_VALUE_PAIR: [ PermissionType.KEY_VALUE_VIEW, PermissionType.KEY_VALUE_SET, - PermissionType.KEY_VALUE_DELETE + PermissionType.KEY_VALUE_DELETE, ], ResourceType.WEBHOOK: [ PermissionType.WEBHOOK_LIST, @@ -372,12 +369,12 @@ class SystemRole(Enum): PermissionType.WEBHOOK_CREATE, PermissionType.WEBHOOK_SEND, PermissionType.WEBHOOK_DELETE, - PermissionType.WEBHOOK_ALL + PermissionType.WEBHOOK_ALL, ], ResourceType.TIMER: [ PermissionType.TIMER_LIST, PermissionType.TIMER_VIEW, - PermissionType.TIMER_ALL + PermissionType.TIMER_ALL, ], ResourceType.API_KEY: [ PermissionType.API_KEY_LIST, @@ -385,17 +382,17 @@ class SystemRole(Enum): PermissionType.API_KEY_CREATE, PermissionType.API_KEY_MODIFY, PermissionType.API_KEY_DELETE, - PermissionType.API_KEY_ALL + PermissionType.API_KEY_ALL, ], ResourceType.TRACE: [ PermissionType.TRACE_LIST, PermissionType.TRACE_VIEW, - PermissionType.TRACE_ALL + PermissionType.TRACE_ALL, ], ResourceType.TRIGGER: [ PermissionType.TRIGGER_LIST, PermissionType.TRIGGER_VIEW, - PermissionType.TRIGGER_ALL + PermissionType.TRIGGER_ALL, ], ResourceType.POLICY_TYPE: [ PermissionType.POLICY_TYPE_LIST, @@ -415,13 +412,16 @@ class SystemRole(Enum): PermissionType.INQUIRY_VIEW, PermissionType.INQUIRY_RESPOND, PermissionType.INQUIRY_ALL, - ] + ], } ALL_PERMISSION_TYPES = list(RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP.values()) ALL_PERMISSION_TYPES = list(itertools.chain(*ALL_PERMISSION_TYPES)) -LIST_PERMISSION_TYPES = [permission_type for permission_type in ALL_PERMISSION_TYPES if - permission_type.endswith('_list')] +LIST_PERMISSION_TYPES = [ + permission_type + for permission_type in ALL_PERMISSION_TYPES + if permission_type.endswith("_list") +] # List of global permissions (ones which don't apply to a specific resource) GLOBAL_PERMISSION_TYPES = [ @@ -433,169 +433,198 @@ class SystemRole(Enum): PermissionType.PACK_CONFIG, PermissionType.PACK_SEARCH, PermissionType.PACK_VIEWS_INDEX_HEALTH, - # Action alias global permission types PermissionType.ACTION_ALIAS_MATCH, PermissionType.ACTION_ALIAS_HELP, - # API key global permission types PermissionType.API_KEY_CREATE, - # Policy global permission types PermissionType.POLICY_CREATE, - # Execution PermissionType.EXECUTION_VIEWS_FILTERS_LIST, - # Stream PermissionType.STREAM_VIEW, - # Inquiry PermissionType.INQUIRY_LIST, PermissionType.INQUIRY_RESPOND, - PermissionType.INQUIRY_VIEW - + PermissionType.INQUIRY_VIEW, ] + LIST_PERMISSION_TYPES -GLOBAL_PACK_PERMISSION_TYPES = [permission_type for permission_type in GLOBAL_PERMISSION_TYPES if - permission_type.startswith('pack_')] +GLOBAL_PACK_PERMISSION_TYPES = [ + permission_type + for permission_type in GLOBAL_PERMISSION_TYPES + if permission_type.startswith("pack_") +] # Maps a permission type to the corresponding description PERMISION_TYPE_TO_DESCRIPTION_MAP = { - PermissionType.PACK_LIST: 'Ability to list (view all) packs.', - PermissionType.PACK_VIEW: 'Ability to view a pack.', - PermissionType.PACK_CREATE: 'Ability to create a new pack.', - PermissionType.PACK_MODIFY: 'Ability to modify (update) an existing pack.', - PermissionType.PACK_DELETE: 'Ability to delete an existing pack.', - PermissionType.PACK_INSTALL: 'Ability to install packs.', - PermissionType.PACK_UNINSTALL: 'Ability to uninstall packs.', - PermissionType.PACK_REGISTER: 'Ability to register packs and corresponding resources.', - PermissionType.PACK_CONFIG: 'Ability to configure a pack.', - PermissionType.PACK_SEARCH: 'Ability to query registry and search packs.', - PermissionType.PACK_VIEWS_INDEX_HEALTH: 'Ability to query health of pack registries.', - PermissionType.PACK_ALL: ('Ability to perform all the supported operations on a particular ' - 'pack.'), - - PermissionType.SENSOR_LIST: 'Ability to list (view all) sensors.', - PermissionType.SENSOR_VIEW: 'Ability to view a sensor', - PermissionType.SENSOR_MODIFY: ('Ability to modify (update) an existing sensor. Also implies ' - '"sensor_type_view" permission.'), - PermissionType.SENSOR_ALL: ('Ability to perform all the supported operations on a particular ' - 'sensor.'), - - PermissionType.ACTION_LIST: 'Ability to list (view all) actions.', - PermissionType.ACTION_VIEW: 'Ability to view an action.', - PermissionType.ACTION_CREATE: ('Ability to create a new action. Also implies "action_view" ' - 'permission.'), - PermissionType.ACTION_MODIFY: ('Ability to modify (update) an existing action. Also implies ' - '"action_view" permission.'), - PermissionType.ACTION_DELETE: ('Ability to delete an existing action. Also implies ' - '"action_view" permission.'), - PermissionType.ACTION_EXECUTE: ('Ability to execute (run) an action. Also implies ' - '"action_view" permission.'), - PermissionType.ACTION_ALL: ('Ability to perform all the supported operations on a particular ' - 'action.'), - - PermissionType.ACTION_ALIAS_LIST: 'Ability to list (view all) action aliases.', - PermissionType.ACTION_ALIAS_VIEW: 'Ability to view an action alias.', - PermissionType.ACTION_ALIAS_CREATE: ('Ability to create a new action alias. Also implies' - ' "action_alias_view" permission.'), - PermissionType.ACTION_ALIAS_MODIFY: ('Ability to modify (update) an existing action alias. ' - 'Also implies "action_alias_view" permission.'), - PermissionType.ACTION_ALIAS_MATCH: ('Ability to use action alias match API endpoint.'), - PermissionType.ACTION_ALIAS_HELP: ('Ability to use action alias help API endpoint.'), - PermissionType.ACTION_ALIAS_DELETE: ('Ability to delete an existing action alias. Also ' - 'implies "action_alias_view" permission.'), - PermissionType.ACTION_ALIAS_ALL: ('Ability to perform all the supported operations on a ' - 'particular action alias.'), - - PermissionType.EXECUTION_LIST: 'Ability to list (view all) executions.', - PermissionType.EXECUTION_VIEW: 'Ability to view an execution.', - PermissionType.EXECUTION_RE_RUN: 'Ability to create a new action.', - PermissionType.EXECUTION_STOP: 'Ability to stop (cancel) a running execution.', - PermissionType.EXECUTION_ALL: ('Ability to perform all the supported operations on a ' - 'particular execution.'), - PermissionType.EXECUTION_VIEWS_FILTERS_LIST: ('Ability view all the distinct execution ' - 'filters.'), - - PermissionType.RULE_LIST: 'Ability to list (view all) rules.', - PermissionType.RULE_VIEW: 'Ability to view a rule.', - PermissionType.RULE_CREATE: ('Ability to create a new rule. Also implies "rule_view" ' - 'permission'), - PermissionType.RULE_MODIFY: ('Ability to modify (update) an existing rule. Also implies ' - '"rule_view" permission.'), - PermissionType.RULE_DELETE: ('Ability to delete an existing rule. Also implies "rule_view" ' - 'permission.'), - PermissionType.RULE_ALL: ('Ability to perform all the supported operations on a particular ' - 'rule.'), - - PermissionType.RULE_ENFORCEMENT_LIST: 'Ability to list (view all) rule enforcements.', - PermissionType.RULE_ENFORCEMENT_VIEW: 'Ability to view a rule enforcement.', - - PermissionType.RUNNER_LIST: 'Ability to list (view all) runners.', - PermissionType.RUNNER_VIEW: 'Ability to view a runner.', - PermissionType.RUNNER_MODIFY: ('Ability to modify (update) an existing runner. Also implies ' - '"runner_type_view" permission.'), - PermissionType.RUNNER_ALL: ('Ability to perform all the supported operations on a particular ' - 'runner.'), - - PermissionType.WEBHOOK_LIST: 'Ability to list (view all) webhooks.', - PermissionType.WEBHOOK_VIEW: ('Ability to view a webhook.'), - PermissionType.WEBHOOK_CREATE: ('Ability to create a new webhook.'), - PermissionType.WEBHOOK_SEND: ('Ability to send / POST data to an existing webhook.'), - PermissionType.WEBHOOK_DELETE: ('Ability to delete an existing webhook.'), - PermissionType.WEBHOOK_ALL: ('Ability to perform all the supported operations on a particular ' - 'webhook.'), - - PermissionType.TIMER_LIST: 'Ability to list (view all) timers.', - PermissionType.TIMER_VIEW: ('Ability to view a timer.'), - PermissionType.TIMER_ALL: ('Ability to perform all the supported operations on timers'), - - PermissionType.API_KEY_LIST: 'Ability to list (view all) API keys.', - PermissionType.API_KEY_VIEW: ('Ability to view an API Key.'), - PermissionType.API_KEY_CREATE: ('Ability to create a new API Key.'), - PermissionType.API_KEY_MODIFY: ('Ability to modify (update) an existing API key. Also implies ' - '"api_key_view" permission.'), - PermissionType.API_KEY_DELETE: ('Ability to delete an existing API Keys.'), - PermissionType.API_KEY_ALL: ('Ability to perform all the supported operations on an API Key.'), - - PermissionType.KEY_VALUE_VIEW: ('Ability to view Key-Value Pairs.'), - PermissionType.KEY_VALUE_SET: ('Ability to set a Key-Value Pair.'), - PermissionType.KEY_VALUE_DELETE: ('Ability to delete an existing Key-Value Pair.'), - - PermissionType.TRACE_LIST: ('Ability to list (view all) traces.'), - PermissionType.TRACE_VIEW: ('Ability to view a trace.'), - PermissionType.TRACE_ALL: ('Ability to perform all the supported operations on traces.'), - - PermissionType.TRIGGER_LIST: ('Ability to list (view all) triggers.'), - PermissionType.TRIGGER_VIEW: ('Ability to view a trigger.'), - PermissionType.TRIGGER_ALL: ('Ability to perform all the supported operations on triggers.'), - - PermissionType.POLICY_TYPE_LIST: ('Ability to list (view all) policy types.'), - PermissionType.POLICY_TYPE_VIEW: ('Ability to view a policy types.'), - PermissionType.POLICY_TYPE_ALL: ('Ability to perform all the supported operations on policy' - ' types.'), - - PermissionType.POLICY_LIST: 'Ability to list (view all) policies.', - PermissionType.POLICY_VIEW: ('Ability to view a policy.'), - PermissionType.POLICY_CREATE: ('Ability to create a new policy.'), - PermissionType.POLICY_MODIFY: ('Ability to modify an existing policy.'), - PermissionType.POLICY_DELETE: ('Ability to delete an existing policy.'), - PermissionType.POLICY_ALL: ('Ability to perform all the supported operations on a particular ' - 'policy.'), - - PermissionType.STREAM_VIEW: ('Ability to view / listen to the events on the stream API ' - 'endpoint.'), - - PermissionType.INQUIRY_LIST: 'Ability to list existing Inquiries', - PermissionType.INQUIRY_VIEW: 'Ability to view an existing Inquiry. Also implies ' - '"inquiry_respond" permission.', - PermissionType.INQUIRY_RESPOND: 'Ability to respond to an existing Inquiry (in general - user ' - 'still needs access per specific inquiry parameters). Also ' - 'implies "inquiry_view" permission.', - PermissionType.INQUIRY_ALL: ('Ability to perform all supported operations on a particular ' - 'Inquiry.') + PermissionType.PACK_LIST: "Ability to list (view all) packs.", + PermissionType.PACK_VIEW: "Ability to view a pack.", + PermissionType.PACK_CREATE: "Ability to create a new pack.", + PermissionType.PACK_MODIFY: "Ability to modify (update) an existing pack.", + PermissionType.PACK_DELETE: "Ability to delete an existing pack.", + PermissionType.PACK_INSTALL: "Ability to install packs.", + PermissionType.PACK_UNINSTALL: "Ability to uninstall packs.", + PermissionType.PACK_REGISTER: "Ability to register packs and corresponding resources.", + PermissionType.PACK_CONFIG: "Ability to configure a pack.", + PermissionType.PACK_SEARCH: "Ability to query registry and search packs.", + PermissionType.PACK_VIEWS_INDEX_HEALTH: "Ability to query health of pack registries.", + PermissionType.PACK_ALL: ( + "Ability to perform all the supported operations on a particular " "pack." + ), + PermissionType.SENSOR_LIST: "Ability to list (view all) sensors.", + PermissionType.SENSOR_VIEW: "Ability to view a sensor", + PermissionType.SENSOR_MODIFY: ( + "Ability to modify (update) an existing sensor. Also implies " + '"sensor_type_view" permission.' + ), + PermissionType.SENSOR_ALL: ( + "Ability to perform all the supported operations on a particular " "sensor." + ), + PermissionType.ACTION_LIST: "Ability to list (view all) actions.", + PermissionType.ACTION_VIEW: "Ability to view an action.", + PermissionType.ACTION_CREATE: ( + 'Ability to create a new action. Also implies "action_view" ' "permission." + ), + PermissionType.ACTION_MODIFY: ( + "Ability to modify (update) an existing action. Also implies " + '"action_view" permission.' + ), + PermissionType.ACTION_DELETE: ( + "Ability to delete an existing action. Also implies " + '"action_view" permission.' + ), + PermissionType.ACTION_EXECUTE: ( + "Ability to execute (run) an action. Also implies " '"action_view" permission.' + ), + PermissionType.ACTION_ALL: ( + "Ability to perform all the supported operations on a particular " "action." + ), + PermissionType.ACTION_ALIAS_LIST: "Ability to list (view all) action aliases.", + PermissionType.ACTION_ALIAS_VIEW: "Ability to view an action alias.", + PermissionType.ACTION_ALIAS_CREATE: ( + "Ability to create a new action alias. Also implies" + ' "action_alias_view" permission.' + ), + PermissionType.ACTION_ALIAS_MODIFY: ( + "Ability to modify (update) an existing action alias. " + 'Also implies "action_alias_view" permission.' + ), + PermissionType.ACTION_ALIAS_MATCH: ( + "Ability to use action alias match API endpoint." + ), + PermissionType.ACTION_ALIAS_HELP: ( + "Ability to use action alias help API endpoint." + ), + PermissionType.ACTION_ALIAS_DELETE: ( + "Ability to delete an existing action alias. Also " + 'implies "action_alias_view" permission.' + ), + PermissionType.ACTION_ALIAS_ALL: ( + "Ability to perform all the supported operations on a " + "particular action alias." + ), + PermissionType.EXECUTION_LIST: "Ability to list (view all) executions.", + PermissionType.EXECUTION_VIEW: "Ability to view an execution.", + PermissionType.EXECUTION_RE_RUN: "Ability to create a new action.", + PermissionType.EXECUTION_STOP: "Ability to stop (cancel) a running execution.", + PermissionType.EXECUTION_ALL: ( + "Ability to perform all the supported operations on a " "particular execution." + ), + PermissionType.EXECUTION_VIEWS_FILTERS_LIST: ( + "Ability view all the distinct execution " "filters." + ), + PermissionType.RULE_LIST: "Ability to list (view all) rules.", + PermissionType.RULE_VIEW: "Ability to view a rule.", + PermissionType.RULE_CREATE: ( + 'Ability to create a new rule. Also implies "rule_view" ' "permission" + ), + PermissionType.RULE_MODIFY: ( + "Ability to modify (update) an existing rule. Also implies " + '"rule_view" permission.' + ), + PermissionType.RULE_DELETE: ( + 'Ability to delete an existing rule. Also implies "rule_view" ' "permission." + ), + PermissionType.RULE_ALL: ( + "Ability to perform all the supported operations on a particular " "rule." + ), + PermissionType.RULE_ENFORCEMENT_LIST: "Ability to list (view all) rule enforcements.", + PermissionType.RULE_ENFORCEMENT_VIEW: "Ability to view a rule enforcement.", + PermissionType.RUNNER_LIST: "Ability to list (view all) runners.", + PermissionType.RUNNER_VIEW: "Ability to view a runner.", + PermissionType.RUNNER_MODIFY: ( + "Ability to modify (update) an existing runner. Also implies " + '"runner_type_view" permission.' + ), + PermissionType.RUNNER_ALL: ( + "Ability to perform all the supported operations on a particular " "runner." + ), + PermissionType.WEBHOOK_LIST: "Ability to list (view all) webhooks.", + PermissionType.WEBHOOK_VIEW: ("Ability to view a webhook."), + PermissionType.WEBHOOK_CREATE: ("Ability to create a new webhook."), + PermissionType.WEBHOOK_SEND: ( + "Ability to send / POST data to an existing webhook." + ), + PermissionType.WEBHOOK_DELETE: ("Ability to delete an existing webhook."), + PermissionType.WEBHOOK_ALL: ( + "Ability to perform all the supported operations on a particular " "webhook." + ), + PermissionType.TIMER_LIST: "Ability to list (view all) timers.", + PermissionType.TIMER_VIEW: ("Ability to view a timer."), + PermissionType.TIMER_ALL: ( + "Ability to perform all the supported operations on timers" + ), + PermissionType.API_KEY_LIST: "Ability to list (view all) API keys.", + PermissionType.API_KEY_VIEW: ("Ability to view an API Key."), + PermissionType.API_KEY_CREATE: ("Ability to create a new API Key."), + PermissionType.API_KEY_MODIFY: ( + "Ability to modify (update) an existing API key. Also implies " + '"api_key_view" permission.' + ), + PermissionType.API_KEY_DELETE: ("Ability to delete an existing API Keys."), + PermissionType.API_KEY_ALL: ( + "Ability to perform all the supported operations on an API Key." + ), + PermissionType.KEY_VALUE_VIEW: ("Ability to view Key-Value Pairs."), + PermissionType.KEY_VALUE_SET: ("Ability to set a Key-Value Pair."), + PermissionType.KEY_VALUE_DELETE: ("Ability to delete an existing Key-Value Pair."), + PermissionType.TRACE_LIST: ("Ability to list (view all) traces."), + PermissionType.TRACE_VIEW: ("Ability to view a trace."), + PermissionType.TRACE_ALL: ( + "Ability to perform all the supported operations on traces." + ), + PermissionType.TRIGGER_LIST: ("Ability to list (view all) triggers."), + PermissionType.TRIGGER_VIEW: ("Ability to view a trigger."), + PermissionType.TRIGGER_ALL: ( + "Ability to perform all the supported operations on triggers." + ), + PermissionType.POLICY_TYPE_LIST: ("Ability to list (view all) policy types."), + PermissionType.POLICY_TYPE_VIEW: ("Ability to view a policy types."), + PermissionType.POLICY_TYPE_ALL: ( + "Ability to perform all the supported operations on policy" " types." + ), + PermissionType.POLICY_LIST: "Ability to list (view all) policies.", + PermissionType.POLICY_VIEW: ("Ability to view a policy."), + PermissionType.POLICY_CREATE: ("Ability to create a new policy."), + PermissionType.POLICY_MODIFY: ("Ability to modify an existing policy."), + PermissionType.POLICY_DELETE: ("Ability to delete an existing policy."), + PermissionType.POLICY_ALL: ( + "Ability to perform all the supported operations on a particular " "policy." + ), + PermissionType.STREAM_VIEW: ( + "Ability to view / listen to the events on the stream API " "endpoint." + ), + PermissionType.INQUIRY_LIST: "Ability to list existing Inquiries", + PermissionType.INQUIRY_VIEW: "Ability to view an existing Inquiry. Also implies " + '"inquiry_respond" permission.', + PermissionType.INQUIRY_RESPOND: "Ability to respond to an existing Inquiry (in general - user " + "still needs access per specific inquiry parameters). Also " + 'implies "inquiry_view" permission.', + PermissionType.INQUIRY_ALL: ( + "Ability to perform all supported operations on a particular " "Inquiry." + ), } @@ -607,10 +636,13 @@ def get_resource_permission_types_with_descriptions(): """ result = {} - for resource_type, permission_types in six.iteritems(RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP): + for resource_type, permission_types in six.iteritems( + RESOURCE_TYPE_TO_PERMISSION_TYPES_MAP + ): result[resource_type] = {} for permission_type in permission_types: - result[resource_type][permission_type] = \ - PERMISION_TYPE_TO_DESCRIPTION_MAP[permission_type] + result[resource_type][permission_type] = PERMISION_TYPE_TO_DESCRIPTION_MAP[ + permission_type + ] return result diff --git a/st2common/st2common/router.py b/st2common/st2common/router.py index 29b34031b40..47ef009b987 100644 --- a/st2common/st2common/router.py +++ b/st2common/st2common/router.py @@ -43,15 +43,12 @@ from st2common.util.http import parse_content_type_header __all__ = [ - 'Router', - - 'Response', - - 'NotFoundException', - - 'abort', - 'abort_unauthorized', - 'exc' + "Router", + "Response", + "NotFoundException", + "abort", + "abort_unauthorized", + "exc", ] LOG = logging.getLogger(__name__) @@ -63,24 +60,24 @@ def op_resolver(op_id): :rtype: ``tuple`` """ - module_name, func_name = op_id.split(':', 1) - controller_name = func_name.split('.')[0] + module_name, func_name = op_id.split(":", 1) + controller_name = func_name.split(".")[0] __import__(module_name) module = sys.modules[module_name] controller_instance = getattr(module, controller_name) - method_callable = functools.reduce(getattr, func_name.split('.'), module) + method_callable = functools.reduce(getattr, func_name.split("."), module) return controller_instance, method_callable -def abort(status_code=exc.HTTPInternalServerError.code, message='Unhandled exception'): +def abort(status_code=exc.HTTPInternalServerError.code, message="Unhandled exception"): raise exc.status_map[status_code](message) def abort_unauthorized(msg=None): - raise exc.HTTPUnauthorized('Unauthorized - %s' % msg if msg else 'Unauthorized') + raise exc.HTTPUnauthorized("Unauthorized - %s" % msg if msg else "Unauthorized") def extend_with_default(validator_class): @@ -92,12 +89,16 @@ def set_defaults(validator, properties, instance, schema): instance.setdefault(property, subschema["default"]) for error in validate_properties( - validator, properties, instance, schema, + validator, + properties, + instance, + schema, ): yield error return jsonschema.validators.extend( - validator_class, {"properties": set_defaults}, + validator_class, + {"properties": set_defaults}, ) @@ -109,7 +110,8 @@ def set_additional_check(validator, properties, instance, schema): yield error return jsonschema.validators.extend( - validator_class, {"x-additional-check": set_additional_check}, + validator_class, + {"x-additional-check": set_additional_check}, ) @@ -126,7 +128,8 @@ def set_type_draft4(validator, types, instance, schema): yield error return jsonschema.validators.extend( - validator_class, {"type": set_type_draft4}, + validator_class, + {"type": set_type_draft4}, ) @@ -141,27 +144,40 @@ class NotFoundException(Exception): class Response(webob.Response): - def __init__(self, body=None, status=None, headerlist=None, app_iter=None, content_type=None, - *args, **kwargs): + def __init__( + self, + body=None, + status=None, + headerlist=None, + app_iter=None, + content_type=None, + *args, + **kwargs, + ): # Do some sanity checking, and turn json_body into an actual body - if app_iter is None and body is None and ('json_body' in kwargs or 'json' in kwargs): - if 'json_body' in kwargs: - json_body = kwargs.pop('json_body') + if ( + app_iter is None + and body is None + and ("json_body" in kwargs or "json" in kwargs) + ): + if "json_body" in kwargs: + json_body = kwargs.pop("json_body") else: - json_body = kwargs.pop('json') - body = json_encode(json_body).encode('UTF-8') + json_body = kwargs.pop("json") + body = json_encode(json_body).encode("UTF-8") if content_type is None: - content_type = 'application/json' + content_type = "application/json" - super(Response, self).__init__(body, status, headerlist, app_iter, content_type, - *args, **kwargs) + super(Response, self).__init__( + body, status, headerlist, app_iter, content_type, *args, **kwargs + ) def _json_body__get(self): return super(Response, self)._json_body__get() def _json_body__set(self, value): - self.body = json_encode(value).encode('UTF-8') + self.body = json_encode(value).encode("UTF-8") def _json_body__del(self): return super(Response, self)._json_body__del() @@ -182,44 +198,51 @@ def __init__(self, arguments=None, debug=False, auth=True, is_gunicorn=True): self.routes = routes.Mapper() def add_spec(self, spec, transforms): - info = spec.get('info', {}) - LOG.debug('Adding API: %s %s', info.get('title', 'untitled'), info.get('version', '0.0.0')) + info = spec.get("info", {}) + LOG.debug( + "Adding API: %s %s", + info.get("title", "untitled"), + info.get("version", "0.0.0"), + ) self.spec = spec - self.spec_resolver = jsonschema.RefResolver('', self.spec) + self.spec_resolver = jsonschema.RefResolver("", self.spec) validate(copy.deepcopy(self.spec)) for filter in transforms: - for (path, methods) in six.iteritems(spec['paths']): + for (path, methods) in six.iteritems(spec["paths"]): if not re.search(filter, path): continue for (method, endpoint) in six.iteritems(methods): - conditions = { - 'method': [method.upper()] - } + conditions = {"method": [method.upper()]} connect_kw = {} - if 'x-requirements' in endpoint: - connect_kw['requirements'] = endpoint['x-requirements'] + if "x-requirements" in endpoint: + connect_kw["requirements"] = endpoint["x-requirements"] - m = self.routes.submapper(_api_path=path, _api_method=method, - conditions=conditions) + m = self.routes.submapper( + _api_path=path, _api_method=method, conditions=conditions + ) for transform in transforms[filter]: m.connect(None, re.sub(filter, transform, path), **connect_kw) - module_name = endpoint['operationId'].split(':', 1)[0] + module_name = endpoint["operationId"].split(":", 1)[0] __import__(module_name) for route in sorted(self.routes.matchlist, key=lambda r: r.routepath): - LOG.debug('Route registered: %+6s %s', route.conditions['method'][0], route.routepath) + LOG.debug( + "Route registered: %+6s %s", + route.conditions["method"][0], + route.routepath, + ) def match(self, req): path = url_unquote(req.path) LOG.debug("Match path: %s", path) - if len(path) > 1 and path.endswith('/'): + if len(path) > 1 and path.endswith("/"): path = path[:-1] match = self.routes.match(path, req.environ) @@ -235,9 +258,9 @@ def match(self, req): path_vars = dict(path_vars) - path = path_vars.pop('_api_path') - method = path_vars.pop('_api_method') - endpoint = self.spec['paths'][path][method] + path = path_vars.pop("_api_path") + method = path_vars.pop("_api_method") + endpoint = self.spec["paths"][path][method] return endpoint, path_vars @@ -256,127 +279,140 @@ def __call__(self, req): LOG.debug("Parsed endpoint: %s", endpoint) LOG.debug("Parsed path_vars: %s", path_vars) - context = copy.copy(getattr(self, 'mock_context', {})) + context = copy.copy(getattr(self, "mock_context", {})) cookie_token = None # Handle security - if 'security' in endpoint: - security = endpoint.get('security') + if "security" in endpoint: + security = endpoint.get("security") else: - security = self.spec.get('security', []) + security = self.spec.get("security", []) if self.auth and security: try: - security_definitions = self.spec.get('securityDefinitions', {}) + security_definitions = self.spec.get("securityDefinitions", {}) for statement in security: declaration, options = statement.copy().popitem() definition = security_definitions[declaration] - if definition['type'] == 'apiKey': - if definition['in'] == 'header': - token = req.headers.get(definition['name']) - elif definition['in'] == 'query': - token = req.GET.get(definition['name']) - elif definition['in'] == 'cookie': - token = req.cookies.get(definition['name']) + if definition["type"] == "apiKey": + if definition["in"] == "header": + token = req.headers.get(definition["name"]) + elif definition["in"] == "query": + token = req.GET.get(definition["name"]) + elif definition["in"] == "cookie": + token = req.cookies.get(definition["name"]) else: token = None if token: - _, auth_func = op_resolver(definition['x-operationId']) + _, auth_func = op_resolver(definition["x-operationId"]) auth_resp = auth_func(token) # Include information on how user authenticated inside the context - if 'auth-token' in definition['name'].lower(): - auth_method = 'authentication token' - elif 'api-key' in definition['name'].lower(): - auth_method = 'API key' - - context['user'] = User.get_by_name(auth_resp.user) - context['auth_info'] = { - 'method': auth_method, - 'location': definition['in'] + if "auth-token" in definition["name"].lower(): + auth_method = "authentication token" + elif "api-key" in definition["name"].lower(): + auth_method = "API key" + + context["user"] = User.get_by_name(auth_resp.user) + context["auth_info"] = { + "method": auth_method, + "location": definition["in"], } # Also include token expiration time when authenticated via auth token - if 'auth-token' in definition['name'].lower(): - context['auth_info']['token_expire'] = auth_resp.expiry - - if 'x-set-cookie' in definition: - max_age = auth_resp.expiry - date_utils.get_datetime_utc_now() - cookie_token = cookies.make_cookie(definition['x-set-cookie'], - token, - max_age=max_age, - httponly=True) + if "auth-token" in definition["name"].lower(): + context["auth_info"]["token_expire"] = auth_resp.expiry + + if "x-set-cookie" in definition: + max_age = ( + auth_resp.expiry - date_utils.get_datetime_utc_now() + ) + cookie_token = cookies.make_cookie( + definition["x-set-cookie"], + token, + max_age=max_age, + httponly=True, + ) break - if 'user' not in context: - raise auth_exc.NoAuthSourceProvidedError('One of Token or API key required.') - except (auth_exc.NoAuthSourceProvidedError, - auth_exc.MultipleAuthSourcesError) as e: + if "user" not in context: + raise auth_exc.NoAuthSourceProvidedError( + "One of Token or API key required." + ) + except ( + auth_exc.NoAuthSourceProvidedError, + auth_exc.MultipleAuthSourcesError, + ) as e: LOG.error(six.text_type(e)) return abort_unauthorized(six.text_type(e)) except auth_exc.TokenNotProvidedError as e: - LOG.exception('Token is not provided.') + LOG.exception("Token is not provided.") return abort_unauthorized(six.text_type(e)) except auth_exc.TokenNotFoundError as e: - LOG.exception('Token is not found.') + LOG.exception("Token is not found.") return abort_unauthorized(six.text_type(e)) except auth_exc.TokenExpiredError as e: - LOG.exception('Token has expired.') + LOG.exception("Token has expired.") return abort_unauthorized(six.text_type(e)) except auth_exc.ApiKeyNotProvidedError as e: - LOG.exception('API key is not provided.') + LOG.exception("API key is not provided.") return abort_unauthorized(six.text_type(e)) except auth_exc.ApiKeyNotFoundError as e: - LOG.exception('API key is not found.') + LOG.exception("API key is not found.") return abort_unauthorized(six.text_type(e)) except auth_exc.ApiKeyDisabledError as e: - LOG.exception('API key is disabled.') + LOG.exception("API key is disabled.") return abort_unauthorized(six.text_type(e)) if cfg.CONF.rbac.enable: - user_db = context['user'] + user_db = context["user"] - permission_type = endpoint.get('x-permissions', None) + permission_type = endpoint.get("x-permissions", None) if permission_type: rbac_backend = get_rbac_backend() - resolver = rbac_backend.get_resolver_for_permission_type(permission_type) - has_permission = resolver.user_has_permission(user_db, permission_type) + resolver = rbac_backend.get_resolver_for_permission_type( + permission_type + ) + has_permission = resolver.user_has_permission( + user_db, permission_type + ) if not has_permission: - raise rbac_exc.ResourceTypeAccessDeniedError(user_db, - permission_type) + raise rbac_exc.ResourceTypeAccessDeniedError( + user_db, permission_type + ) # Collect parameters kw = {} - for param in endpoint.get('parameters', []) + endpoint.get('x-parameters', []): - name = param['name'] - argument_name = param.get('x-as', None) or name - source = param['in'] - default = param.get('default', None) + for param in endpoint.get("parameters", []) + endpoint.get("x-parameters", []): + name = param["name"] + argument_name = param.get("x-as", None) or name + source = param["in"] + default = param.get("default", None) # Collecting params from different sources - if source == 'query': + if source == "query": kw[argument_name] = req.GET.get(name, default) - elif source == 'path': + elif source == "path": kw[argument_name] = path_vars[name] - elif source == 'header': + elif source == "header": kw[argument_name] = req.headers.get(name, default) - elif source == 'formData': + elif source == "formData": kw[argument_name] = req.POST.get(name, default) - elif source == 'environ': + elif source == "environ": kw[argument_name] = req.environ.get(name.upper(), default) - elif source == 'context': + elif source == "context": kw[argument_name] = context.get(name, default) - elif source == 'request': + elif source == "request": kw[argument_name] = getattr(req, name) - elif source == 'body': - content_type = req.headers.get('Content-Type', 'application/json') + elif source == "body": + content_type = req.headers.get("Content-Type", "application/json") content_type = parse_content_type_header(content_type=content_type)[0] - schema = param['schema'] + schema = param["schema"] # NOTE: HACK: Workaround for eventlet wsgi server which sets Content-Type to # text/plain if Content-Type is not provided in the request. @@ -384,65 +420,76 @@ def __call__(self, req): # expect application/json so we explicitly set it to that # if not provided (set to text/plain by the base http server) and if it's not # /v1/workflows/inspection API endpoints. - if not self.is_gunicorn and content_type == 'text/plain': - operation_id = endpoint['operationId'] + if not self.is_gunicorn and content_type == "text/plain": + operation_id = endpoint["operationId"] - if ('workflow_inspection_controller' not in operation_id): - content_type = 'application/json' + if "workflow_inspection_controller" not in operation_id: + content_type = "application/json" # Note: We also want to perform validation if no body is explicitly provided - in a # lot of POST, PUT scenarios, body is mandatory - if not req.body and content_type == 'application/json': - req.body = b'{}' + if not req.body and content_type == "application/json": + req.body = b"{}" try: - if content_type == 'application/json': + if content_type == "application/json": data = req.json - elif content_type == 'text/plain': + elif content_type == "text/plain": data = req.body - elif content_type in ['application/x-www-form-urlencoded', - 'multipart/form-data']: + elif content_type in [ + "application/x-www-form-urlencoded", + "multipart/form-data", + ]: data = urlparse.parse_qs(req.body) else: - raise ValueError('Unsupported Content-Type: "%s"' % (content_type)) + raise ValueError( + 'Unsupported Content-Type: "%s"' % (content_type) + ) except Exception as e: - detail = 'Failed to parse request body: %s' % six.text_type(e) + detail = "Failed to parse request body: %s" % six.text_type(e) raise exc.HTTPBadRequest(detail=detail) # Special case for Python 3 - if six.PY3 and content_type == 'text/plain' and isinstance(data, six.binary_type): + if ( + six.PY3 + and content_type == "text/plain" + and isinstance(data, six.binary_type) + ): # Convert bytes to text type (string / unicode) - data = data.decode('utf-8') + data = data.decode("utf-8") try: CustomValidator(schema, resolver=self.spec_resolver).validate(data) except (jsonschema.ValidationError, ValueError) as e: - raise exc.HTTPBadRequest(detail=getattr(e, 'message', six.text_type(e)), - comment=traceback.format_exc()) + raise exc.HTTPBadRequest( + detail=getattr(e, "message", six.text_type(e)), + comment=traceback.format_exc(), + ) - if content_type == 'text/plain': + if content_type == "text/plain": kw[argument_name] = data else: + class Body(object): def __init__(self, **entries): self.__dict__.update(entries) - ref = schema.get('$ref', None) + ref = schema.get("$ref", None) if ref: with self.spec_resolver.resolving(ref) as resolved: schema = resolved - if 'x-api-model' in schema: - input_type = schema.get('type', []) - _, Model = op_resolver(schema['x-api-model']) + if "x-api-model" in schema: + input_type = schema.get("type", []) + _, Model = op_resolver(schema["x-api-model"]) if input_type and not isinstance(input_type, (list, tuple)): input_type = [input_type] # root attribute is not an object, we need to use wrapper attribute to # make it work with **kwarg expansion - if input_type and 'array' in input_type: - data = {'data': data} + if input_type and "array" in input_type: + data = {"data": data} instance = self._get_model_instance(model_cls=Model, data=data) @@ -451,143 +498,178 @@ def __init__(self, **entries): try: instance = instance.validate() except (jsonschema.ValidationError, ValueError) as e: - raise exc.HTTPBadRequest(detail=getattr(e, 'message', six.text_type(e)), - comment=traceback.format_exc()) + raise exc.HTTPBadRequest( + detail=getattr(e, "message", six.text_type(e)), + comment=traceback.format_exc(), + ) else: - LOG.debug('Missing x-api-model definition for %s, using generic Body ' - 'model.' % (endpoint['operationId'])) + LOG.debug( + "Missing x-api-model definition for %s, using generic Body " + "model." % (endpoint["operationId"]) + ) model = Body instance = self._get_model_instance(model_cls=model, data=data) kw[argument_name] = instance # Making sure all required params are present - required = param.get('required', False) + required = param.get("required", False) if required and kw[argument_name] is None: detail = 'Required parameter "%s" is missing' % name raise exc.HTTPBadRequest(detail=detail) # Validating and casting param types - param_type = param.get('type', None) + param_type = param.get("type", None) if kw[argument_name] is not None: - if param_type == 'boolean': - positive = ('true', '1', 'yes', 'y') - negative = ('false', '0', 'no', 'n') + if param_type == "boolean": + positive = ("true", "1", "yes", "y") + negative = ("false", "0", "no", "n") if str(kw[argument_name]).lower() not in positive + negative: detail = 'Parameter "%s" is not of type boolean' % argument_name raise exc.HTTPBadRequest(detail=detail) kw[argument_name] = str(kw[argument_name]).lower() in positive - elif param_type == 'integer': - regex = r'^-?[0-9]+$' + elif param_type == "integer": + regex = r"^-?[0-9]+$" if not re.search(regex, str(kw[argument_name])): detail = 'Parameter "%s" is not of type integer' % argument_name raise exc.HTTPBadRequest(detail=detail) kw[argument_name] = int(kw[argument_name]) - elif param_type == 'number': - regex = r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$' + elif param_type == "number": + regex = r"^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?$" if not re.search(regex, str(kw[argument_name])): detail = 'Parameter "%s" is not of type float' % argument_name raise exc.HTTPBadRequest(detail=detail) kw[argument_name] = float(kw[argument_name]) - elif param_type == 'array' and param.get('items', {}).get('type', None) == 'string': + elif ( + param_type == "array" + and param.get("items", {}).get("type", None) == "string" + ): if kw[argument_name] is None: kw[argument_name] = [] elif isinstance(kw[argument_name], (list, tuple)): # argument is already an array pass else: - kw[argument_name] = kw[argument_name].split(',') + kw[argument_name] = kw[argument_name].split(",") # Call the controller try: - controller_instance, func = op_resolver(endpoint['operationId']) + controller_instance, func = op_resolver(endpoint["operationId"]) except Exception as e: - LOG.exception('Failed to load controller for operation "%s": %s' % - (endpoint['operationId'], six.text_type(e))) + LOG.exception( + 'Failed to load controller for operation "%s": %s' + % (endpoint["operationId"], six.text_type(e)) + ) raise e try: resp = func(**kw) except DataStoreKeyNotFoundError as e: - LOG.warning('Failed to call controller function "%s" for operation "%s": %s' % - (func.__name__, endpoint['operationId'], six.text_type(e))) + LOG.warning( + 'Failed to call controller function "%s" for operation "%s": %s' + % (func.__name__, endpoint["operationId"], six.text_type(e)) + ) raise e except Exception as e: - LOG.exception('Failed to call controller function "%s" for operation "%s": %s' % - (func.__name__, endpoint['operationId'], six.text_type(e))) + LOG.exception( + 'Failed to call controller function "%s" for operation "%s": %s' + % (func.__name__, endpoint["operationId"], six.text_type(e)) + ) raise e # Handle response if resp is None: resp = Response() - if not hasattr(resp, '__call__'): + if not hasattr(resp, "__call__"): resp = Response(json=resp) - operation_id = endpoint['operationId'] + operation_id = endpoint["operationId"] # Process the response removing attributes based on the exclude_attribute and # include_attributes query param filter values (if specified) - include_attributes = kw.get('include_attributes', None) - exclude_attributes = kw.get('exclude_attributes', None) - has_include_or_exclude_attributes = bool(include_attributes) or bool(exclude_attributes) + include_attributes = kw.get("include_attributes", None) + exclude_attributes = kw.get("exclude_attributes", None) + has_include_or_exclude_attributes = bool(include_attributes) or bool( + exclude_attributes + ) # NOTE: We do NOT want to process stream controller response - is_streamming_controller = endpoint.get('x-is-streaming-endpoint', - bool('st2stream' in operation_id)) - - if not is_streamming_controller and resp.body and has_include_or_exclude_attributes: + is_streamming_controller = endpoint.get( + "x-is-streaming-endpoint", bool("st2stream" in operation_id) + ) + + if ( + not is_streamming_controller + and resp.body + and has_include_or_exclude_attributes + ): # NOTE: We need to check for response.body attribute since resp.json throws if JSON # response is not available - mandatory_include_fields = getattr(controller_instance, - 'mandatory_include_fields_response', []) - data = self._process_response(data=resp.json, - mandatory_include_fields=mandatory_include_fields, - include_attributes=include_attributes, - exclude_attributes=exclude_attributes) + mandatory_include_fields = getattr( + controller_instance, "mandatory_include_fields_response", [] + ) + data = self._process_response( + data=resp.json, + mandatory_include_fields=mandatory_include_fields, + include_attributes=include_attributes, + exclude_attributes=exclude_attributes, + ) resp.json = data - responses = endpoint.get('responses', {}) + responses = endpoint.get("responses", {}) response_spec = responses.get(str(resp.status_code), None) - default_response_spec = responses.get('default', None) + default_response_spec = responses.get("default", None) if not response_spec and default_response_spec: - LOG.debug('No custom response spec found for endpoint "%s", using a default one' % - (endpoint['operationId'])) - response_spec_name = 'default' + LOG.debug( + 'No custom response spec found for endpoint "%s", using a default one' + % (endpoint["operationId"]) + ) + response_spec_name = "default" else: response_spec_name = str(resp.status_code) response_spec = response_spec or default_response_spec - if response_spec and 'schema' in response_spec and not has_include_or_exclude_attributes: + if ( + response_spec + and "schema" in response_spec + and not has_include_or_exclude_attributes + ): # NOTE: We don't perform response validation when include or exclude attributes are # provided because this means partial response which likely won't pass the validation - LOG.debug('Using response spec "%s" for endpoint %s and status code %s' % - (response_spec_name, endpoint['operationId'], resp.status_code)) + LOG.debug( + 'Using response spec "%s" for endpoint %s and status code %s' + % (response_spec_name, endpoint["operationId"], resp.status_code) + ) try: - validator = CustomValidator(response_spec['schema'], resolver=self.spec_resolver) + validator = CustomValidator( + response_spec["schema"], resolver=self.spec_resolver + ) - response_type = response_spec['schema'].get('type', 'json') - if response_type == 'string': + response_type = response_spec["schema"].get("type", "json") + if response_type == "string": validator.validate(resp.text) else: validator.validate(resp.json) except (jsonschema.ValidationError, ValueError): - LOG.exception('Response validation failed.') - resp.headers.add('Warning', '199 OpenAPI "Response validation failed"') + LOG.exception("Response validation failed.") + resp.headers.add("Warning", '199 OpenAPI "Response validation failed"') else: - LOG.debug('No response spec found for endpoint "%s"' % (endpoint['operationId'])) + LOG.debug( + 'No response spec found for endpoint "%s"' % (endpoint["operationId"]) + ) if cookie_token: - resp.headerlist.append(('Set-Cookie', cookie_token)) + resp.headerlist.append(("Set-Cookie", cookie_token)) return resp @@ -604,17 +686,24 @@ def _get_model_instance(self, model_cls, data): instance = model_cls(**data) except TypeError as e: # Throw a more user-friendly exception when input data is not an object - if 'type object argument after ** must be a mapping, not' in six.text_type(e): + if "type object argument after ** must be a mapping, not" in six.text_type( + e + ): type_string = get_json_type_for_python_value(data) - msg = ('Input body needs to be an object, got: %s' % (type_string)) + msg = "Input body needs to be an object, got: %s" % (type_string) raise ValueError(msg) raise e return instance - def _process_response(self, data, mandatory_include_fields=None, include_attributes=None, - exclude_attributes=None): + def _process_response( + self, + data, + mandatory_include_fields=None, + include_attributes=None, + exclude_attributes=None, + ): """ Process controller response data such as removing attributes based on the values of exclude_attributes and include_attributes query param filters and similar. @@ -628,8 +717,10 @@ def _process_response(self, data, mandatory_include_fields=None, include_attribu # NOTE: include_attributes and exclude_attributes are mutually exclusive if include_attributes and exclude_attributes: - msg = ('exclude_attributes and include_attributes arguments are mutually exclusive. ' - 'You need to provide either one or another, but not both.') + msg = ( + "exclude_attributes and include_attributes arguments are mutually exclusive. " + "You need to provide either one or another, but not both." + ) raise ValueError(msg) # Common case - filters are not provided @@ -637,16 +728,20 @@ def _process_response(self, data, mandatory_include_fields=None, include_attribu return data # Skip processing of error responses - if isinstance(data, dict) and data.get('faultstring', None): + if isinstance(data, dict) and data.get("faultstring", None): return data # We only care about the first part of the field name since deep filtering happens inside # MongoDB. Deep filtering here would also be quite expensive and waste of CPU cycles. - cleaned_include_attributes = [attribute.split('.')[0] for attribute in include_attributes] + cleaned_include_attributes = [ + attribute.split(".")[0] for attribute in include_attributes + ] # Add in mandatory fields which always need to be present in the response (primary keys) cleaned_include_attributes += mandatory_include_fields - cleaned_exclude_attributes = [attribute.split('.')[0] for attribute in exclude_attributes] + cleaned_exclude_attributes = [ + attribute.split(".")[0] for attribute in exclude_attributes + ] # NOTE: Since those parameters are mutually exclusive we could perform more efficient # filtering when just exclude_attributes is provided. Instead of creating a new dict, we @@ -675,6 +770,6 @@ def process_item(item): # get_one response result = process_item(data) else: - raise ValueError('Unsupported type: %s' % (type(data))) + raise ValueError("Unsupported type: %s" % (type(data))) return result diff --git a/st2common/st2common/runners/__init__.py b/st2common/st2common/runners/__init__.py index bcccaaf48d8..d6468f78e2e 100644 --- a/st2common/st2common/runners/__init__.py +++ b/st2common/st2common/runners/__init__.py @@ -19,14 +19,9 @@ from st2common.util import driver_loader -__all__ = [ - 'BACKENDS_NAMESPACE', +__all__ = ["BACKENDS_NAMESPACE", "get_available_backends", "get_backend_driver"] - 'get_available_backends', - 'get_backend_driver' -] - -BACKENDS_NAMESPACE = 'st2common.runners.runner' +BACKENDS_NAMESPACE = "st2common.runners.runner" def get_available_backends(): diff --git a/st2common/st2common/runners/base.py b/st2common/st2common/runners/base.py index 6a9656b9a12..a5692b66e63 100644 --- a/st2common/st2common/runners/base.py +++ b/st2common/st2common/runners/base.py @@ -42,45 +42,43 @@ subprocess = concurrency.get_subprocess_module() __all__ = [ - 'ActionRunner', - 'AsyncActionRunner', - 'PollingAsyncActionRunner', - 'GitWorktreeActionRunner', - 'PollingAsyncActionRunner', - 'ShellRunnerMixin', - - 'get_runner_module', - - 'get_runner', - 'get_metadata', + "ActionRunner", + "AsyncActionRunner", + "PollingAsyncActionRunner", + "GitWorktreeActionRunner", + "PollingAsyncActionRunner", + "ShellRunnerMixin", + "get_runner_module", + "get_runner", + "get_metadata", ] LOG = logging.getLogger(__name__) # constants to lookup in runner_parameters -RUNNER_COMMAND = 'cmd' -RUNNER_CONTENT_VERSION = 'content_version' -RUNNER_DEBUG = 'debug' +RUNNER_COMMAND = "cmd" +RUNNER_CONTENT_VERSION = "content_version" +RUNNER_DEBUG = "debug" def get_runner(name, config=None): """ Load the module and return an instance of the runner. """ - LOG.debug('Runner loading Python module: %s', name) + LOG.debug("Runner loading Python module: %s", name) module = get_runner_module(name=name) - LOG.debug('Instance of runner module: %s', module) + LOG.debug("Instance of runner module: %s", module) if config: - runner_kwargs = {'config': config} + runner_kwargs = {"config": config} else: runner_kwargs = {} runner = module.get_runner(**runner_kwargs) - LOG.debug('Instance of runner: %s', runner) + LOG.debug("Instance of runner: %s", runner) return runner @@ -95,19 +93,21 @@ def get_runner_module(name): try: module = get_plugin_instance(RUNNERS_NAMESPACE, name, invoke_on_load=False) except NoMatches: - name = name.replace('_', '-') + name = name.replace("_", "-") try: module = get_plugin_instance(RUNNERS_NAMESPACE, name, invoke_on_load=False) except Exception as e: available_runners = get_available_plugins(namespace=RUNNERS_NAMESPACE) - available_runners = ', '.join(available_runners) - msg = ('Failed to find runner %s. Make sure that the runner is available and installed ' - 'in StackStorm virtual environment. Available runners are: %s' % - (name, available_runners)) + available_runners = ", ".join(available_runners) + msg = ( + "Failed to find runner %s. Make sure that the runner is available and installed " + "in StackStorm virtual environment. Available runners are: %s" + % (name, available_runners) + ) LOG.exception(msg) - raise exc.ActionRunnerCreateError('%s\n\n%s' % (msg, six.text_type(e))) + raise exc.ActionRunnerCreateError("%s\n\n%s" % (msg, six.text_type(e))) return module @@ -120,9 +120,9 @@ def get_metadata(package_name): """ import pkg_resources - file_path = pkg_resources.resource_filename(package_name, 'runner.yaml') + file_path = pkg_resources.resource_filename(package_name, "runner.yaml") - with open(file_path, 'r') as fp: + with open(file_path, "r") as fp: content = fp.read() metadata = yaml.safe_load(content) @@ -158,14 +158,14 @@ def __init__(self, runner_id): def pre_run(self): # Handle runner "enabled" attribute - runner_enabled = getattr(self.runner_type, 'enabled', True) - runner_name = getattr(self.runner_type, 'name', 'unknown') + runner_enabled = getattr(self.runner_type, "enabled", True) + runner_name = getattr(self.runner_type, "name", "unknown") if not runner_enabled: msg = 'Runner "%s" has been disabled by the administrator.' % runner_name raise ValueError(msg) - runner_parameters = getattr(self, 'runner_parameters', {}) or {} + runner_parameters = getattr(self, "runner_parameters", {}) or {} self._debug = runner_parameters.get(RUNNER_DEBUG, False) # Run will need to take an action argument @@ -175,18 +175,20 @@ def run(self, action_parameters): raise NotImplementedError() def pause(self): - runner_name = getattr(self.runner_type, 'name', 'unknown') - raise NotImplementedError('Pause is not supported for runner %s.' % runner_name) + runner_name = getattr(self.runner_type, "name", "unknown") + raise NotImplementedError("Pause is not supported for runner %s." % runner_name) def resume(self): - runner_name = getattr(self.runner_type, 'name', 'unknown') - raise NotImplementedError('Resume is not supported for runner %s.' % runner_name) + runner_name = getattr(self.runner_type, "name", "unknown") + raise NotImplementedError( + "Resume is not supported for runner %s." % runner_name + ) def cancel(self): return ( action_constants.LIVEACTION_STATUS_CANCELED, self.liveaction.result, - self.liveaction.context + self.liveaction.context, ) def post_run(self, status, result): @@ -213,8 +215,8 @@ def get_user(self): :rtype: ``str`` """ - context = getattr(self, 'context', {}) or {} - user = context.get('user', cfg.CONF.system_user.user) + context = getattr(self, "context", {}) or {} + user = context.get("user", cfg.CONF.system_user.user) return user @@ -228,18 +230,18 @@ def _get_common_action_env_variables(self): :rtype: ``dict`` """ result = {} - result['ST2_ACTION_PACK_NAME'] = self.get_pack_ref() - result['ST2_ACTION_EXECUTION_ID'] = str(self.execution_id) - result['ST2_ACTION_API_URL'] = get_full_public_api_url() + result["ST2_ACTION_PACK_NAME"] = self.get_pack_ref() + result["ST2_ACTION_EXECUTION_ID"] = str(self.execution_id) + result["ST2_ACTION_API_URL"] = get_full_public_api_url() if self.auth_token: - result['ST2_ACTION_AUTH_TOKEN'] = self.auth_token.token + result["ST2_ACTION_AUTH_TOKEN"] = self.auth_token.token return result def __str__(self): - attrs = ', '.join(['%s=%s' % (k, v) for k, v in six.iteritems(self.__dict__)]) - return '%s@%s(%s)' % (self.__class__.__name__, str(id(self)), attrs) + attrs = ", ".join(["%s=%s" % (k, v) for k, v in six.iteritems(self.__dict__)]) + return "%s@%s(%s)" % (self.__class__.__name__, str(id(self)), attrs) @six.add_metaclass(abc.ABCMeta) @@ -248,7 +250,6 @@ class AsyncActionRunner(ActionRunner): class PollingAsyncActionRunner(AsyncActionRunner): - @classmethod def is_polling_enabled(cls): return True @@ -264,7 +265,7 @@ class GitWorktreeActionRunner(ActionRunner): This revision is specified using "content_version" runner parameter. """ - WORKTREE_DIRECTORY_PREFIX = 'st2-git-worktree-' + WORKTREE_DIRECTORY_PREFIX = "st2-git-worktree-" def __init__(self, runner_id): super(GitWorktreeActionRunner, self).__init__(runner_id=runner_id) @@ -284,11 +285,13 @@ def pre_run(self): # Override entry_point so it points to git worktree directory pack_name = self.get_pack_name() - entry_point = self._get_entry_point_for_worktree_path(pack_name=pack_name, - entry_point=self.entry_point, - worktree_path=self.git_worktree_path) + entry_point = self._get_entry_point_for_worktree_path( + pack_name=pack_name, + entry_point=self.entry_point, + worktree_path=self.git_worktree_path, + ) - assert(entry_point.startswith(self.git_worktree_path)) + assert entry_point.startswith(self.git_worktree_path) self.entry_point = entry_point @@ -298,9 +301,11 @@ def post_run(self, status, result): # Remove git worktree directories (if used and available) if self.git_worktree_path and self.git_worktree_revision: pack_name = self.get_pack_name() - self.cleanup_git_worktree(worktree_path=self.git_worktree_path, - content_version=self.git_worktree_revision, - pack_name=pack_name) + self.cleanup_git_worktree( + worktree_path=self.git_worktree_path, + content_version=self.git_worktree_revision, + pack_name=pack_name, + ) def create_git_worktree(self, content_version): """ @@ -318,51 +323,59 @@ def create_git_worktree(self, content_version): self.git_worktree_path = worktree_path extra = { - 'pack_name': pack_name, - 'pack_directory': pack_directory, - 'content_version': content_version, - 'worktree_path': worktree_path + "pack_name": pack_name, + "pack_directory": pack_directory, + "content_version": content_version, + "worktree_path": worktree_path, } if not os.path.isdir(pack_directory): - msg = ('Failed to create git worktree for pack "%s". Pack directory "%s" doesn\'t ' - 'exist.' % (pack_name, pack_directory)) + msg = ( + 'Failed to create git worktree for pack "%s". Pack directory "%s" doesn\'t ' + "exist." % (pack_name, pack_directory) + ) raise ValueError(msg) args = [ - 'git', - '-C', + "git", + "-C", pack_directory, - 'worktree', - 'add', + "worktree", + "add", worktree_path, - content_version + content_version, ] cmd = list2cmdline(args) - LOG.debug('Creating git worktree for pack "%s", content version "%s" and execution ' - 'id "%s" in "%s"' % (pack_name, content_version, self.execution_id, - worktree_path), extra=extra) - LOG.debug('Command: %s' % (cmd)) - exit_code, stdout, stderr, timed_out = run_command(cmd=cmd, - cwd=pack_directory, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True) + LOG.debug( + 'Creating git worktree for pack "%s", content version "%s" and execution ' + 'id "%s" in "%s"' + % (pack_name, content_version, self.execution_id, worktree_path), + extra=extra, + ) + LOG.debug("Command: %s" % (cmd)) + exit_code, stdout, stderr, timed_out = run_command( + cmd=cmd, + cwd=pack_directory, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + ) if exit_code != 0: - self._handle_git_worktree_error(pack_name=pack_name, pack_directory=pack_directory, - content_version=content_version, - exit_code=exit_code, stdout=stdout, stderr=stderr) + self._handle_git_worktree_error( + pack_name=pack_name, + pack_directory=pack_directory, + content_version=content_version, + exit_code=exit_code, + stdout=stdout, + stderr=stderr, + ) else: LOG.debug('Git worktree created in "%s"' % (worktree_path), extra=extra) # Make sure system / action runner user can access that directory - args = [ - 'chmod', - '777', - worktree_path - ] + args = ["chmod", "777", worktree_path] cmd = list2cmdline(args) run_command(cmd=cmd, shell=True) @@ -375,15 +388,19 @@ def cleanup_git_worktree(self, worktree_path, pack_name, content_version): :rtype: ``bool`` """ # Safety check to make sure we don't remove something outside /tmp - assert(worktree_path.startswith('/tmp')) - assert(worktree_path.startswith('/tmp/%s' % (self.WORKTREE_DIRECTORY_PREFIX))) + assert worktree_path.startswith("/tmp") + assert worktree_path.startswith("/tmp/%s" % (self.WORKTREE_DIRECTORY_PREFIX)) if self._debug: - LOG.debug('Not removing git worktree "%s" because debug mode is enabled' % - (worktree_path)) + LOG.debug( + 'Not removing git worktree "%s" because debug mode is enabled' + % (worktree_path) + ) else: - LOG.debug('Removing git worktree "%s" for pack "%s" and content version "%s"' % - (worktree_path, pack_name, content_version)) + LOG.debug( + 'Removing git worktree "%s" for pack "%s" and content version "%s"' + % (worktree_path, pack_name, content_version) + ) try: shutil.rmtree(worktree_path, ignore_errors=True) @@ -392,36 +409,43 @@ def cleanup_git_worktree(self, worktree_path, pack_name, content_version): return True - def _handle_git_worktree_error(self, pack_name, pack_directory, content_version, exit_code, - stdout, stderr): + def _handle_git_worktree_error( + self, pack_name, pack_directory, content_version, exit_code, stdout, stderr + ): """ Handle "git worktree" related errors and throw a more user-friendly exception. """ error_prefix = 'Failed to create git worktree for pack "%s": ' % (pack_name) if isinstance(stdout, six.binary_type): - stdout = stdout.decode('utf-8') + stdout = stdout.decode("utf-8") if isinstance(stderr, six.binary_type): - stderr = stderr.decode('utf-8') + stderr = stderr.decode("utf-8") # 1. Installed version of git which doesn't support worktree command if "git: 'worktree' is not a git command." in stderr: - msg = ('Installed git version doesn\'t support git worktree command. ' - 'To be able to utilize this functionality you need to use git ' - '>= 2.5.0.') + msg = ( + "Installed git version doesn't support git worktree command. " + "To be able to utilize this functionality you need to use git " + ">= 2.5.0." + ) raise ValueError(error_prefix + msg) # 2. Provided pack directory is not a git repository if "Not a git repository" in stderr: - msg = ('Pack directory "%s" is not a git repository. To utilize this functionality, ' - 'pack directory needs to be a git repository.' % (pack_directory)) + msg = ( + 'Pack directory "%s" is not a git repository. To utilize this functionality, ' + "pack directory needs to be a git repository." % (pack_directory) + ) raise ValueError(error_prefix + msg) # 3. Invalid revision provided if "invalid reference" in stderr: - msg = ('Invalid content_version "%s" provided. Make sure that git repository is up ' - 'to date and contains that revision.' % (content_version)) + msg = ( + 'Invalid content_version "%s" provided. Make sure that git repository is up ' + "to date and contains that revision." % (content_version) + ) raise ValueError(error_prefix + msg) def _get_entry_point_for_worktree_path(self, pack_name, entry_point, worktree_path): @@ -433,10 +457,10 @@ def _get_entry_point_for_worktree_path(self, pack_name, entry_point, worktree_pa """ pack_base_path = get_pack_base_path(pack_name=pack_name) - new_entry_point = entry_point.replace(pack_base_path, '') + new_entry_point = entry_point.replace(pack_base_path, "") # Remove leading slash (if any) - if new_entry_point.startswith('/'): + if new_entry_point.startswith("/"): new_entry_point = new_entry_point[1:] new_entry_point = os.path.join(worktree_path, new_entry_point) @@ -444,7 +468,7 @@ def _get_entry_point_for_worktree_path(self, pack_name, entry_point, worktree_pa # Check to prevent directory traversal common_prefix = os.path.commonprefix([worktree_path, new_entry_point]) if common_prefix != worktree_path: - raise ValueError('entry_point is not located inside the pack directory') + raise ValueError("entry_point is not located inside the pack directory") return new_entry_point @@ -483,11 +507,11 @@ def _get_script_args(self, action_parameters): is_script_run_as_cmd = self.runner_parameters.get(RUNNER_COMMAND, None) - pos_args = '' + pos_args = "" named_args = {} if is_script_run_as_cmd: - pos_args = self.runner_parameters.get(RUNNER_COMMAND, '') + pos_args = self.runner_parameters.get(RUNNER_COMMAND, "") named_args = action_parameters else: pos_args, named_args = action_utils.get_args(action_parameters, self.action) diff --git a/st2common/st2common/runners/base_action.py b/st2common/st2common/runners/base_action.py index bc915d2b4ff..244a4235c9c 100644 --- a/st2common/st2common/runners/base_action.py +++ b/st2common/st2common/runners/base_action.py @@ -21,9 +21,7 @@ from st2common.runners.utils import get_logger_for_python_runner_action from st2common.runners.utils import PackConfigDict -__all__ = [ - 'Action' -] +__all__ = ["Action"] @six.add_metaclass(abc.ABCMeta) @@ -45,16 +43,17 @@ def __init__(self, config=None, action_service=None): self.config = config or {} self.action_service = action_service - if action_service and getattr(action_service, '_action_wrapper', None): - log_level = getattr(action_service._action_wrapper, '_log_level', 'debug') - pack_name = getattr(action_service._action_wrapper, '_pack', 'unknown') + if action_service and getattr(action_service, "_action_wrapper", None): + log_level = getattr(action_service._action_wrapper, "_log_level", "debug") + pack_name = getattr(action_service._action_wrapper, "_pack", "unknown") else: - log_level = 'debug' - pack_name = 'unknown' + log_level = "debug" + pack_name = "unknown" self.config = PackConfigDict(pack_name, self.config) - self.logger = get_logger_for_python_runner_action(action_name=self.__class__.__name__, - log_level=log_level) + self.logger = get_logger_for_python_runner_action( + action_name=self.__class__.__name__, log_level=log_level + ) @abc.abstractmethod def run(self, **kwargs): diff --git a/st2common/st2common/runners/parallel_ssh.py b/st2common/st2common/runners/parallel_ssh.py index 28f87564153..c41175c02c9 100644 --- a/st2common/st2common/runners/parallel_ssh.py +++ b/st2common/st2common/runners/parallel_ssh.py @@ -35,13 +35,26 @@ class ParallelSSHClient(object): - KEYS_TO_TRANSFORM = ['stdout', 'stderr'] - CONNECT_ERROR = 'Cannot connect to host.' - - def __init__(self, hosts, user=None, password=None, pkey_file=None, pkey_material=None, port=22, - bastion_host=None, concurrency=10, raise_on_any_error=False, connect=True, - passphrase=None, handle_stdout_line_func=None, handle_stderr_line_func=None, - sudo_password=False): + KEYS_TO_TRANSFORM = ["stdout", "stderr"] + CONNECT_ERROR = "Cannot connect to host." + + def __init__( + self, + hosts, + user=None, + password=None, + pkey_file=None, + pkey_material=None, + port=22, + bastion_host=None, + concurrency=10, + raise_on_any_error=False, + connect=True, + passphrase=None, + handle_stdout_line_func=None, + handle_stderr_line_func=None, + sudo_password=False, + ): """ :param handle_stdout_line_func: Callback function which is called dynamically each time a new stdout line is received. @@ -65,7 +78,7 @@ def __init__(self, hosts, user=None, password=None, pkey_file=None, pkey_materia self._sudo_password = sudo_password if not hosts: - raise Exception('Need an non-empty list of hosts to talk to.') + raise Exception("Need an non-empty list of hosts to talk to.") self._pool = concurrency_lib.get_green_pool_class()(concurrency) self._hosts_client = {} @@ -74,8 +87,8 @@ def __init__(self, hosts, user=None, password=None, pkey_file=None, pkey_materia if connect: connect_results = self.connect(raise_on_any_error=raise_on_any_error) - extra = {'_connect_results': connect_results} - LOG.debug('Connect to hosts complete.', extra=extra) + extra = {"_connect_results": connect_results} + LOG.debug("Connect to hosts complete.", extra=extra) def connect(self, raise_on_any_error=False): """ @@ -92,17 +105,28 @@ def connect(self, raise_on_any_error=False): for host in self._hosts: while not concurrency_lib.is_green_pool_free(self._pool): concurrency_lib.sleep(self._scan_interval) - self._pool.spawn(self._connect, host=host, results=results, - raise_on_any_error=raise_on_any_error) + self._pool.spawn( + self._connect, + host=host, + results=results, + raise_on_any_error=raise_on_any_error, + ) concurrency_lib.green_pool_wait_all(self._pool) if self._successful_connects < 1: # We definitely have to raise an exception in this case. - LOG.error('Unable to connect to any of the hosts.', - extra={'connect_results': results}) - msg = ('Unable to connect to any one of the hosts: %s.\n\n connect_errors=%s' % - (self._hosts, json.dumps(results, indent=2))) + LOG.error( + "Unable to connect to any of the hosts.", + extra={"connect_results": results}, + ) + msg = ( + "Unable to connect to any one of the hosts: %s.\n\n connect_errors=%s" + % ( + self._hosts, + json.dumps(results, indent=2), + ) + ) raise NoHostsConnectedToException(msg) return results @@ -124,10 +148,7 @@ def run(self, cmd, timeout=None): :rtype: ``dict`` of ``str`` to ``dict`` """ - options = { - 'cmd': cmd, - 'timeout': timeout - } + options = {"cmd": cmd, "timeout": timeout} results = self._execute_in_pool(self._run_command, **options) return results @@ -152,13 +173,13 @@ def put(self, local_path, remote_path, mode=None, mirror_local_mode=False): """ if not os.path.exists(local_path): - raise Exception('Local path %s does not exist.' % local_path) + raise Exception("Local path %s does not exist." % local_path) options = { - 'local_path': local_path, - 'remote_path': remote_path, - 'mode': mode, - 'mirror_local_mode': mirror_local_mode + "local_path": local_path, + "remote_path": remote_path, + "mode": mode, + "mirror_local_mode": mirror_local_mode, } return self._execute_in_pool(self._put_files, **options) @@ -173,9 +194,7 @@ def mkdir(self, path): :rtype path: ``dict`` of ``str`` to ``dict`` """ - options = { - 'path': path - } + options = {"path": path} return self._execute_in_pool(self._mkdir, **options) def delete_file(self, path): @@ -188,9 +207,7 @@ def delete_file(self, path): :rtype path: ``dict`` of ``str`` to ``dict`` """ - options = { - 'path': path - } + options = {"path": path} return self._execute_in_pool(self._delete_file, **options) def delete_dir(self, path, force=False, timeout=None): @@ -203,10 +220,7 @@ def delete_dir(self, path, force=False, timeout=None): :rtype path: ``dict`` of ``str`` to ``dict`` """ - options = { - 'path': path, - 'force': force - } + options = {"path": path, "force": force} return self._execute_in_pool(self._delete_dir, **options) def close(self): @@ -218,7 +232,7 @@ def close(self): try: self._hosts_client[host].close() except: - LOG.exception('Failed shutting down SSH connection to host: %s', host) + LOG.exception("Failed shutting down SSH connection to host: %s", host) def _execute_in_pool(self, execute_method, **kwargs): results = {} @@ -237,36 +251,41 @@ def _execute_in_pool(self, execute_method, **kwargs): def _connect(self, host, results, raise_on_any_error=False): (hostname, port) = self._get_host_port_info(host) - extra = {'host': host, 'port': port, 'user': self._ssh_user} + extra = {"host": host, "port": port, "user": self._ssh_user} if self._ssh_password: - extra['password'] = '' + extra["password"] = "" elif self._ssh_key_file: - extra['key_file_path'] = self._ssh_key_file + extra["key_file_path"] = self._ssh_key_file else: - extra['private_key'] = '' - - LOG.debug('Connecting to host.', extra=extra) - - client = ParamikoSSHClient(hostname=hostname, port=port, - username=self._ssh_user, - password=self._ssh_password, - bastion_host=self._bastion_host, - key_files=self._ssh_key_file, - key_material=self._ssh_key_material, - passphrase=self._passphrase, - handle_stdout_line_func=self._handle_stdout_line_func, - handle_stderr_line_func=self._handle_stderr_line_func) + extra["private_key"] = "" + + LOG.debug("Connecting to host.", extra=extra) + + client = ParamikoSSHClient( + hostname=hostname, + port=port, + username=self._ssh_user, + password=self._ssh_password, + bastion_host=self._bastion_host, + key_files=self._ssh_key_file, + key_material=self._ssh_key_material, + passphrase=self._passphrase, + handle_stdout_line_func=self._handle_stdout_line_func, + handle_stderr_line_func=self._handle_stderr_line_func, + ) try: client.connect() except SSHException as ex: LOG.exception(ex) if raise_on_any_error: raise - error_dict = self._generate_error_result(exc=ex, message='Connection error.') + error_dict = self._generate_error_result( + exc=ex, message="Connection error." + ) self._bad_hosts[hostname] = error_dict results[hostname] = error_dict except Exception as ex: - error = 'Failed connecting to host %s.' % hostname + error = "Failed connecting to host %s." % hostname LOG.exception(error) if raise_on_any_error: raise @@ -276,16 +295,19 @@ def _connect(self, host, results, raise_on_any_error=False): else: self._successful_connects += 1 self._hosts_client[hostname] = client - results[hostname] = {'message': 'Connected to host.'} + results[hostname] = {"message": "Connected to host."} def _run_command(self, host, cmd, results, timeout=None): try: - LOG.debug('Running command: %s on host: %s.', cmd, host) + LOG.debug("Running command: %s on host: %s.", cmd, host) client = self._hosts_client[host] - (stdout, stderr, exit_code) = client.run(cmd, timeout=timeout, - call_line_handler_func=True) + (stdout, stderr, exit_code) = client.run( + cmd, timeout=timeout, call_line_handler_func=True + ) - result = self._handle_command_result(stdout=stdout, stderr=stderr, exit_code=exit_code) + result = self._handle_command_result( + stdout=stdout, stderr=stderr, exit_code=exit_code + ) results[host] = result except Exception as ex: cmd = self._sanitize_command_string(cmd=cmd) @@ -293,20 +315,24 @@ def _run_command(self, host, cmd, results, timeout=None): LOG.exception(error) results[host] = self._generate_error_result(exc=ex, message=error) - def _put_files(self, local_path, remote_path, host, results, mode=None, - mirror_local_mode=False): + def _put_files( + self, local_path, remote_path, host, results, mode=None, mirror_local_mode=False + ): try: - LOG.debug('Copying file to host: %s' % host) + LOG.debug("Copying file to host: %s" % host) if os.path.isdir(local_path): result = self._hosts_client[host].put_dir(local_path, remote_path) else: - result = self._hosts_client[host].put(local_path, remote_path, - mirror_local_mode=mirror_local_mode, - mode=mode) - LOG.debug('Result of copy: %s' % result) + result = self._hosts_client[host].put( + local_path, + remote_path, + mirror_local_mode=mirror_local_mode, + mode=mode, + ) + LOG.debug("Result of copy: %s" % result) results[host] = result except Exception as ex: - error = 'Failed sending file(s) in path %s to host %s' % (local_path, host) + error = "Failed sending file(s) in path %s to host %s" % (local_path, host) LOG.exception(error) results[host] = self._generate_error_result(exc=ex, message=error) @@ -324,16 +350,18 @@ def _delete_file(self, host, path, results): result = self._hosts_client[host].delete_file(path) results[host] = result except Exception as ex: - error = 'Failed deleting file %s on host %s.' % (path, host) + error = "Failed deleting file %s on host %s." % (path, host) LOG.exception(error) results[host] = self._generate_error_result(exc=ex, message=error) def _delete_dir(self, host, path, results, force=False, timeout=None): try: - result = self._hosts_client[host].delete_dir(path, force=force, timeout=timeout) + result = self._hosts_client[host].delete_dir( + path, force=force, timeout=timeout + ) results[host] = result except Exception as ex: - error = 'Failed deleting dir %s on host %s.' % (path, host) + error = "Failed deleting dir %s on host %s." % (path, host) LOG.exception(error) results[host] = self._generate_error_result(exc=ex, message=error) @@ -347,20 +375,27 @@ def _get_host_port_info(self, host_str): def _handle_command_result(self, stdout, stderr, exit_code): # Detect if user provided an invalid sudo password or sudo is not configured for that user if self._sudo_password: - if re.search(r'sudo: \d+ incorrect password attempts', stderr): - match = re.search(r'\[sudo\] password for (.+?)\:', stderr) + if re.search(r"sudo: \d+ incorrect password attempts", stderr): + match = re.search(r"\[sudo\] password for (.+?)\:", stderr) if match: username = match.groups()[0] else: - username = 'unknown' + username = "unknown" - error = ('Invalid sudo password provided or sudo is not configured for this user ' - '(%s)' % (username)) + error = ( + "Invalid sudo password provided or sudo is not configured for this user " + "(%s)" % (username) + ) raise ValueError(error) - is_succeeded = (exit_code == 0) - result_dict = {'stdout': stdout, 'stderr': stderr, 'return_code': exit_code, - 'succeeded': is_succeeded, 'failed': not is_succeeded} + is_succeeded = exit_code == 0 + result_dict = { + "stdout": stdout, + "stderr": stderr, + "return_code": exit_code, + "succeeded": is_succeeded, + "failed": not is_succeeded, + } result = jsonify.json_loads(result_dict, ParallelSSHClient.KEYS_TO_TRANSFORM) return result @@ -375,8 +410,11 @@ def _sanitize_command_string(cmd): if not cmd: return cmd - result = re.sub(r'ST2_ACTION_AUTH_TOKEN=(.+?)\s+?', 'ST2_ACTION_AUTH_TOKEN=%s ' % - (MASKED_ATTRIBUTE_VALUE), cmd) + result = re.sub( + r"ST2_ACTION_AUTH_TOKEN=(.+?)\s+?", + "ST2_ACTION_AUTH_TOKEN=%s " % (MASKED_ATTRIBUTE_VALUE), + cmd, + ) return result @staticmethod @@ -388,8 +426,8 @@ def _generate_error_result(exc, message): :param message: Error message which will be prefixed to the exception exception message. :type message: ``str`` """ - exc_message = getattr(exc, 'message', str(exc)) - error_message = '%s %s' % (message, exc_message) + exc_message = getattr(exc, "message", str(exc)) + error_message = "%s %s" % (message, exc_message) traceback_message = traceback.format_exc() if isinstance(exc, SSHCommandTimeoutError): @@ -399,21 +437,24 @@ def _generate_error_result(exc, message): timeout = False return_code = 255 - stdout = getattr(exc, 'stdout', None) or '' - stderr = getattr(exc, 'stderr', None) or '' + stdout = getattr(exc, "stdout", None) or "" + stderr = getattr(exc, "stderr", None) or "" error_dict = { - 'failed': True, - 'succeeded': False, - 'timeout': timeout, - 'return_code': return_code, - 'stdout': stdout, - 'stderr': stderr, - 'error': error_message, - 'traceback': traceback_message, + "failed": True, + "succeeded": False, + "timeout": timeout, + "return_code": return_code, + "stdout": stdout, + "stderr": stderr, + "error": error_message, + "traceback": traceback_message, } return error_dict def __repr__(self): - return ('' % - (repr(self._hosts), self._ssh_user, id(self))) + return "" % ( + repr(self._hosts), + self._ssh_user, + id(self), + ) diff --git a/st2common/st2common/runners/paramiko_ssh.py b/st2common/st2common/runners/paramiko_ssh.py index c42c4eb89f3..7530a532d95 100644 --- a/st2common/st2common/runners/paramiko_ssh.py +++ b/st2common/st2common/runners/paramiko_ssh.py @@ -35,14 +35,13 @@ from st2common.util.misc import strip_shell_chars from st2common.util.misc import sanitize_output from st2common.util.shell import quote_unix -from st2common.constants.runners import DEFAULT_SSH_PORT, REMOTE_RUNNER_PRIVATE_KEY_HEADER +from st2common.constants.runners import ( + DEFAULT_SSH_PORT, + REMOTE_RUNNER_PRIVATE_KEY_HEADER, +) from st2common.util import concurrency -__all__ = [ - 'ParamikoSSHClient', - - 'SSHCommandTimeoutError' -] +__all__ = ["ParamikoSSHClient", "SSHCommandTimeoutError"] class SSHCommandTimeoutError(Exception): @@ -63,13 +62,21 @@ def __init__(self, cmd, timeout, ssh_connect_timeout, stdout=None, stderr=None): self.ssh_connect_timeout = ssh_connect_timeout self.stdout = stdout self.stderr = stderr - self.message = ('Command didn\'t finish in %s seconds or the SSH connection ' - 'did not succeed in %s seconds' % (timeout, ssh_connect_timeout)) + self.message = ( + "Command didn't finish in %s seconds or the SSH connection " + "did not succeed in %s seconds" % (timeout, ssh_connect_timeout) + ) super(SSHCommandTimeoutError, self).__init__(self.message) def __repr__(self): - return ('' % - (self.cmd, self.timeout, self.ssh_connect_timeout)) + return ( + '' + % ( + self.cmd, + self.timeout, + self.ssh_connect_timeout, + ) + ) def __str__(self): return self.message @@ -86,9 +93,20 @@ class ParamikoSSHClient(object): # How long to sleep while waiting for command to finish to prevent busy waiting SLEEP_DELAY = 0.2 - def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None, - bastion_host=None, key_files=None, key_material=None, timeout=None, - passphrase=None, handle_stdout_line_func=None, handle_stderr_line_func=None): + def __init__( + self, + hostname, + port=DEFAULT_SSH_PORT, + username=None, + password=None, + bastion_host=None, + key_files=None, + key_material=None, + timeout=None, + passphrase=None, + handle_stdout_line_func=None, + handle_stderr_line_func=None, + ): """ Authentication is always attempted in the following order: @@ -114,8 +132,7 @@ def __init__(self, hostname, port=DEFAULT_SSH_PORT, username=None, password=None self._handle_stderr_line_func = handle_stderr_line_func self.ssh_config_file = os.path.expanduser( - cfg.CONF.ssh_runner.ssh_config_file_path or - '~/.ssh/config' + cfg.CONF.ssh_runner.ssh_config_file_path or "~/.ssh/config" ) if self.timeout and int(self.ssh_connect_timeout) > int(self.timeout) - 2: @@ -140,14 +157,16 @@ def connect(self): :rtype: ``bool`` """ if self.bastion_host: - self.logger.debug('Bastion host specified, connecting') + self.logger.debug("Bastion host specified, connecting") self.bastion_client = self._connect(host=self.bastion_host) transport = self.bastion_client.get_transport() real_addr = (self.hostname, self.port) # fabric uses ('', 0) for direct-tcpip, this duplicates that behaviour # see https://github.com/fabric/fabric/commit/c2a9bbfd50f560df6c6f9675603fb405c4071cad - local_addr = ('', 0) - self.bastion_socket = transport.open_channel('direct-tcpip', real_addr, local_addr) + local_addr = ("", 0) + self.bastion_socket = transport.open_channel( + "direct-tcpip", real_addr, local_addr + ) self.client = self._connect(host=self.hostname, socket=self.bastion_socket) return True @@ -173,17 +192,24 @@ def put(self, local_path, remote_path, mode=None, mirror_local_mode=False): """ if not local_path or not remote_path: - raise Exception('Need both local_path and remote_path. local: %s, remote: %s' % - local_path, remote_path) + raise Exception( + "Need both local_path and remote_path. local: %s, remote: %s" + % local_path, + remote_path, + ) local_path = quote_unix(local_path) remote_path = quote_unix(remote_path) - extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode, - '_mirror_local_mode': mirror_local_mode} - self.logger.debug('Uploading file', extra=extra) + extra = { + "_local_path": local_path, + "_remote_path": remote_path, + "_mode": mode, + "_mirror_local_mode": mirror_local_mode, + } + self.logger.debug("Uploading file", extra=extra) if not os.path.exists(local_path): - raise Exception('Path %s does not exist locally.' % local_path) + raise Exception("Path %s does not exist locally." % local_path) rattrs = self.sftp.put(local_path, remote_path) @@ -199,7 +225,7 @@ def put(self, local_path, remote_path, mode=None, mirror_local_mode=False): remote_mode = rattrs.st_mode # Only bitshift if we actually got an remote_mode if remote_mode is not None: - remote_mode = (remote_mode & 0o7777) + remote_mode = remote_mode & 0o7777 if local_mode != remote_mode: self.sftp.chmod(remote_path, local_mode) @@ -225,9 +251,13 @@ def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False): :rtype: ``list`` of ``str`` """ - extra = {'_local_path': local_path, '_remote_path': remote_path, '_mode': mode, - '_mirror_local_mode': mirror_local_mode} - self.logger.debug('Uploading dir', extra=extra) + extra = { + "_local_path": local_path, + "_remote_path": remote_path, + "_mode": mode, + "_mirror_local_mode": mirror_local_mode, + } + self.logger.debug("Uploading dir", extra=extra) if os.path.basename(local_path): strip = os.path.dirname(local_path) @@ -237,10 +267,10 @@ def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False): remote_paths = [] for context, dirs, files in os.walk(local_path): - rcontext = context.replace(strip, '', 1) + rcontext = context.replace(strip, "", 1) # normalize pathname separators with POSIX separator - rcontext = rcontext.replace(os.sep, '/') - rcontext = rcontext.lstrip('/') + rcontext = rcontext.replace(os.sep, "/") + rcontext = rcontext.lstrip("/") rcontext = posixpath.join(remote_path, rcontext) if not self.exists(rcontext): @@ -255,8 +285,12 @@ def put_dir(self, local_path, remote_path, mode=None, mirror_local_mode=False): local_path = os.path.join(context, f) n = posixpath.join(rcontext, f) # Note that quote_unix is done by put anyways. - p = self.put(local_path=local_path, remote_path=n, - mirror_local_mode=mirror_local_mode, mode=mode) + p = self.put( + local_path=local_path, + remote_path=n, + mirror_local_mode=mirror_local_mode, + mode=mode, + ) remote_paths.append(p) return remote_paths @@ -290,8 +324,8 @@ def mkdir(self, dir_path): """ dir_path = quote_unix(dir_path) - extra = {'_dir_path': dir_path} - self.logger.debug('mkdir', extra=extra) + extra = {"_dir_path": dir_path} + self.logger.debug("mkdir", extra=extra) return self.sftp.mkdir(dir_path) def delete_file(self, path): @@ -307,8 +341,8 @@ def delete_file(self, path): """ path = quote_unix(path) - extra = {'_path': path} - self.logger.debug('Deleting file', extra=extra) + extra = {"_path": path} + self.logger.debug("Deleting file", extra=extra) self.sftp.unlink(path) return True @@ -331,15 +365,15 @@ def delete_dir(self, path, force=False, timeout=None): """ path = quote_unix(path) - extra = {'_path': path} + extra = {"_path": path} if force: - command = 'rm -rf %s' % path - extra['_command'] = command - extra['_force'] = force - self.logger.debug('Deleting dir', extra=extra) + command = "rm -rf %s" % path + extra["_command"] = command + extra["_force"] = force + self.logger.debug("Deleting dir", extra=extra) return self.run(command, timeout=timeout) - self.logger.debug('Deleting dir', extra=extra) + self.logger.debug("Deleting dir", extra=extra) return self.sftp.rmdir(path) def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): @@ -359,8 +393,8 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): if quote: cmd = quote_unix(cmd) - extra = {'_cmd': cmd} - self.logger.info('Executing command', extra=extra) + extra = {"_cmd": cmd} + self.logger.info("Executing command", extra=extra) # Use the system default buffer size bufsize = -1 @@ -369,7 +403,7 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): chan = transport.open_session() start_time = time.time() - if cmd.startswith('sudo'): + if cmd.startswith("sudo"): # Note that fabric does this as well. If you set pty, stdout and stderr # streams will be combined into one. # NOTE: If pty is used, every new line character \n will be converted to \r\n which @@ -386,7 +420,7 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): # Create a stdin file and immediately close it to prevent any # interactive script from hanging the process. - stdin = chan.makefile('wb', bufsize) + stdin = chan.makefile("wb", bufsize) stdin.close() # Receive all the output @@ -400,12 +434,14 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): exit_status_ready = chan.exit_status_ready() if exit_status_ready: - stdout_data = self._consume_stdout(chan=chan, - call_line_handler_func=call_line_handler_func) + stdout_data = self._consume_stdout( + chan=chan, call_line_handler_func=call_line_handler_func + ) stdout_data = stdout_data.getvalue() - stderr_data = self._consume_stderr(chan=chan, - call_line_handler_func=call_line_handler_func) + stderr_data = self._consume_stderr( + chan=chan, call_line_handler_func=call_line_handler_func + ) stderr_data = stderr_data.getvalue() stdout.write(stdout_data) @@ -413,7 +449,7 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): while not exit_status_ready: current_time = time.time() - elapsed_time = (current_time - start_time) + elapsed_time = current_time - start_time if timeout and (elapsed_time > timeout): # TODO: Is this the right way to clean up? @@ -421,16 +457,22 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): stdout = sanitize_output(stdout.getvalue(), uses_pty=uses_pty) stderr = sanitize_output(stderr.getvalue(), uses_pty=uses_pty) - raise SSHCommandTimeoutError(cmd=cmd, timeout=timeout, - ssh_connect_timeout=self.ssh_connect_timeout, - stdout=stdout, stderr=stderr) - - stdout_data = self._consume_stdout(chan=chan, - call_line_handler_func=call_line_handler_func) + raise SSHCommandTimeoutError( + cmd=cmd, + timeout=timeout, + ssh_connect_timeout=self.ssh_connect_timeout, + stdout=stdout, + stderr=stderr, + ) + + stdout_data = self._consume_stdout( + chan=chan, call_line_handler_func=call_line_handler_func + ) stdout_data = stdout_data.getvalue() - stderr_data = self._consume_stderr(chan=chan, - call_line_handler_func=call_line_handler_func) + stderr_data = self._consume_stderr( + chan=chan, call_line_handler_func=call_line_handler_func + ) stderr_data = stderr_data.getvalue() stdout.write(stdout_data) @@ -453,8 +495,8 @@ def run(self, cmd, timeout=None, quote=False, call_line_handler_func=False): stdout = sanitize_output(stdout.getvalue(), uses_pty=uses_pty) stderr = sanitize_output(stderr.getvalue(), uses_pty=uses_pty) - extra = {'_status': status, '_stdout': stdout, '_stderr': stderr} - self.logger.debug('Command finished', extra=extra) + extra = {"_status": status, "_stdout": stdout, "_stderr": stderr} + self.logger.debug("Command finished", extra=extra) return [stdout, stderr, status] @@ -499,7 +541,7 @@ def _consume_stdout(self, chan, call_line_handler_func=False): data = chan.recv(self.CHUNK_SIZE) if six.PY3 and isinstance(data, six.text_type): - data = data.encode('utf-8') + data = data.encode("utf-8") out += data @@ -512,7 +554,7 @@ def _consume_stdout(self, chan, call_line_handler_func=False): data = chan.recv(self.CHUNK_SIZE) if six.PY3 and isinstance(data, six.text_type): - data = data.encode('utf-8') + data = data.encode("utf-8") out += data @@ -520,14 +562,14 @@ def _consume_stdout(self, chan, call_line_handler_func=False): if self._handle_stdout_line_func and call_line_handler_func: data = strip_shell_chars(stdout.getvalue()) - lines = data.split('\n') + lines = data.split("\n") lines = [line for line in lines if line] for line in lines: # Note: If this function performs network operating no sleep is # needed, otherwise if a long blocking operating is performed, # sleep is recommended to yield and prevent from busy looping - self._handle_stdout_line_func(line=line + '\n') + self._handle_stdout_line_func(line=line + "\n") stdout.seek(0) @@ -545,7 +587,7 @@ def _consume_stderr(self, chan, call_line_handler_func=False): data = chan.recv_stderr(self.CHUNK_SIZE) if six.PY3 and isinstance(data, six.text_type): - data = data.encode('utf-8') + data = data.encode("utf-8") out += data @@ -558,7 +600,7 @@ def _consume_stderr(self, chan, call_line_handler_func=False): data = chan.recv_stderr(self.CHUNK_SIZE) if six.PY3 and isinstance(data, six.text_type): - data = data.encode('utf-8') + data = data.encode("utf-8") out += data @@ -566,14 +608,14 @@ def _consume_stderr(self, chan, call_line_handler_func=False): if self._handle_stderr_line_func and call_line_handler_func: data = strip_shell_chars(stderr.getvalue()) - lines = data.split('\n') + lines = data.split("\n") lines = [line for line in lines if line] for line in lines: # Note: If this function performs network operating no sleep is # needed, otherwise if a long blocking operating is performed, # sleep is recommended to yield and prevent from busy looping - self._handle_stderr_line_func(line=line + '\n') + self._handle_stderr_line_func(line=line + "\n") stderr.seek(0) @@ -581,9 +623,9 @@ def _consume_stderr(self, chan, call_line_handler_func=False): def _get_decoded_data(self, data): try: - return data.decode('utf-8') + return data.decode("utf-8") except: - self.logger.exception('Non UTF-8 character found in data: %s', data) + self.logger.exception("Non UTF-8 character found in data: %s", data) raise def _get_pkey_object(self, key_material, passphrase): @@ -604,13 +646,17 @@ def _get_pkey_object(self, key_material, passphrase): # exception letting the user know we expect the contents a not a path. # Note: We do it here and not up the stack to avoid false positives. contains_header = REMOTE_RUNNER_PRIVATE_KEY_HEADER in key_material.lower() - if not contains_header and (key_material.count('/') >= 1 or key_material.count('\\') >= 1): - msg = ('"private_key" parameter needs to contain private key data / content and not ' - 'a path') + if not contains_header and ( + key_material.count("/") >= 1 or key_material.count("\\") >= 1 + ): + msg = ( + '"private_key" parameter needs to contain private key data / content and not ' + "a path" + ) elif passphrase: - msg = 'Invalid passphrase or invalid/unsupported key type' + msg = "Invalid passphrase or invalid/unsupported key type" else: - msg = 'Invalid or unsupported key type' + msg = "Invalid or unsupported key type" raise paramiko.ssh_exception.SSHException(msg) @@ -636,19 +682,23 @@ def _connect(self, host, socket=None): :rtype: :class:`paramiko.SSHClient` """ - conninfo = {'hostname': host, - 'allow_agent': False, - 'look_for_keys': False, - 'timeout': self.ssh_connect_timeout} + conninfo = { + "hostname": host, + "allow_agent": False, + "look_for_keys": False, + "timeout": self.ssh_connect_timeout, + } ssh_config_file_info = {} if cfg.CONF.ssh_runner.use_ssh_config: ssh_config_file_info = self._get_ssh_config_for_host(host) - ssh_config_username = ssh_config_file_info.get('user', None) - ssh_config_port = ssh_config_file_info.get('port', None) + ssh_config_username = ssh_config_file_info.get("user", None) + ssh_config_port = ssh_config_file_info.get("port", None) - self.username = (self.username or ssh_config_username or cfg.CONF.system_user.user) + self.username = ( + self.username or ssh_config_username or cfg.CONF.system_user.user + ) # If a custom non-default port is provided in the SSH config file we use that over the # default port value provided via runner parameter @@ -660,78 +710,92 @@ def _connect(self, host, socket=None): # If both key file and key material are provided as action parameters, # throw an error informing user only one is required. if self.key_files and self.key_material: - msg = ('key_files and key_material arguments are mutually exclusive. Supply only one.') + msg = "key_files and key_material arguments are mutually exclusive. Supply only one." raise ValueError(msg) # If neither key material nor password is provided, only then we look at key file and decide # if we want to use the user supplied one or the one in SSH config. if not self.key_material and not self.password: - self.key_files = (self.key_files or ssh_config_file_info.get('identityfile', None) or - cfg.CONF.system_user.ssh_key_file) + self.key_files = ( + self.key_files + or ssh_config_file_info.get("identityfile", None) + or cfg.CONF.system_user.ssh_key_file + ) if self.passphrase and not (self.key_files or self.key_material): - raise ValueError('passphrase should accompany private key material') + raise ValueError("passphrase should accompany private key material") credentials_provided = self.password or self.key_files or self.key_material if not credentials_provided: - msg = ('Either password or key file location or key material should be supplied ' + - 'for action. You can also add an entry for host %s in SSH config file %s.' % - (host, self.ssh_config_file)) + msg = ( + "Either password or key file location or key material should be supplied " + + "for action. You can also add an entry for host %s in SSH config file %s." + % (host, self.ssh_config_file) + ) raise ValueError(msg) - conninfo['username'] = self.username - conninfo['port'] = self.port + conninfo["username"] = self.username + conninfo["port"] = self.port if self.password: - conninfo['password'] = self.password + conninfo["password"] = self.password if self.key_files: - conninfo['key_filename'] = self.key_files + conninfo["key_filename"] = self.key_files passphrase_reqd = self._is_key_file_needs_passphrase(self.key_files) if passphrase_reqd and not self.passphrase: - msg = ('Private key file %s is passphrase protected. Supply a passphrase.' % - self.key_files) + msg = ( + "Private key file %s is passphrase protected. Supply a passphrase." + % self.key_files + ) raise paramiko.ssh_exception.PasswordRequiredException(msg) if self.passphrase: # Optional passphrase for unlocking the private key - conninfo['password'] = self.passphrase + conninfo["password"] = self.passphrase if self.key_material: - conninfo['pkey'] = self._get_pkey_object(key_material=self.key_material, - passphrase=self.passphrase) + conninfo["pkey"] = self._get_pkey_object( + key_material=self.key_material, passphrase=self.passphrase + ) if not self.password and not (self.key_files or self.key_material): - conninfo['allow_agent'] = True - conninfo['look_for_keys'] = True - - extra = {'_hostname': host, '_port': self.port, - '_username': self.username, '_timeout': self.ssh_connect_timeout} - self.logger.debug('Connecting to server', extra=extra) - - self.socket = socket or ssh_config_file_info.get('sock', None) + conninfo["allow_agent"] = True + conninfo["look_for_keys"] = True + + extra = { + "_hostname": host, + "_port": self.port, + "_username": self.username, + "_timeout": self.ssh_connect_timeout, + } + self.logger.debug("Connecting to server", extra=extra) + + self.socket = socket or ssh_config_file_info.get("sock", None) if self.socket: - conninfo['sock'] = socket + conninfo["sock"] = socket client = paramiko.SSHClient() client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - extra = {'_conninfo': conninfo} - self.logger.debug('Connection info', extra=extra) + extra = {"_conninfo": conninfo} + self.logger.debug("Connection info", extra=extra) try: client.connect(**conninfo) except SSHException as e: paramiko_msg = six.text_type(e) - if conninfo.get('password', None): - conninfo['password'] = '' + if conninfo.get("password", None): + conninfo["password"] = "" - msg = ('Error connecting to host %s ' % host + - 'with connection parameters %s.' % conninfo + - 'Paramiko error: %s.' % paramiko_msg) + msg = ( + "Error connecting to host %s " % host + + "with connection parameters %s." % conninfo + + "Paramiko error: %s." % paramiko_msg + ) raise SSHException(msg) return client @@ -744,25 +808,29 @@ def _get_ssh_config_for_host(self, host): with open(self.ssh_config_file) as f: ssh_config_parser.parse(f) except IOError as e: - raise Exception('Error accessing ssh config file %s. Code: %s Reason %s' % - (self.ssh_config_file, e.errno, e.strerror)) + raise Exception( + "Error accessing ssh config file %s. Code: %s Reason %s" + % (self.ssh_config_file, e.errno, e.strerror) + ) ssh_config = ssh_config_parser.lookup(host) - self.logger.info('Parsed SSH config file contents: %s', ssh_config) + self.logger.info("Parsed SSH config file contents: %s", ssh_config) if ssh_config: - for k in ('hostname', 'user', 'port'): + for k in ("hostname", "user", "port"): if k in ssh_config: ssh_config_info[k] = ssh_config[k] - if 'identityfile' in ssh_config: - key_file = ssh_config['identityfile'] + if "identityfile" in ssh_config: + key_file = ssh_config["identityfile"] if type(key_file) is list: key_file = key_file[0] - ssh_config_info['identityfile'] = key_file + ssh_config_info["identityfile"] = key_file - if 'proxycommand' in ssh_config: - ssh_config_info['sock'] = paramiko.ProxyCommand(ssh_config['proxycommand']) + if "proxycommand" in ssh_config: + ssh_config_info["sock"] = paramiko.ProxyCommand( + ssh_config["proxycommand"] + ) return ssh_config_info @@ -779,5 +847,9 @@ def _is_key_file_needs_passphrase(file): return False def __repr__(self): - return ('' % - (self.hostname, self.port, self.username, id(self))) + return "" % ( + self.hostname, + self.port, + self.username, + id(self), + ) diff --git a/st2common/st2common/runners/paramiko_ssh_runner.py b/st2common/st2common/runners/paramiko_ssh_runner.py index 6c1ab053e91..f41882935f5 100644 --- a/st2common/st2common/runners/paramiko_ssh_runner.py +++ b/st2common/st2common/runners/paramiko_ssh_runner.py @@ -29,34 +29,31 @@ from st2common.exceptions.actionrunner import ActionRunnerPreRunError from st2common.services.action import store_execution_output_data -__all__ = [ - 'BaseParallelSSHRunner' -] +__all__ = ["BaseParallelSSHRunner"] LOG = logging.getLogger(__name__) # constants to lookup in runner_parameters. -RUNNER_HOSTS = 'hosts' -RUNNER_USERNAME = 'username' -RUNNER_PASSWORD = 'password' -RUNNER_PRIVATE_KEY = 'private_key' -RUNNER_PARALLEL = 'parallel' -RUNNER_SUDO = 'sudo' -RUNNER_SUDO_PASSWORD = 'sudo_password' -RUNNER_ON_BEHALF_USER = 'user' -RUNNER_REMOTE_DIR = 'dir' -RUNNER_COMMAND = 'cmd' -RUNNER_CWD = 'cwd' -RUNNER_ENV = 'env' -RUNNER_KWARG_OP = 'kwarg_op' -RUNNER_TIMEOUT = 'timeout' -RUNNER_SSH_PORT = 'port' -RUNNER_BASTION_HOST = 'bastion_host' -RUNNER_PASSPHRASE = 'passphrase' +RUNNER_HOSTS = "hosts" +RUNNER_USERNAME = "username" +RUNNER_PASSWORD = "password" +RUNNER_PRIVATE_KEY = "private_key" +RUNNER_PARALLEL = "parallel" +RUNNER_SUDO = "sudo" +RUNNER_SUDO_PASSWORD = "sudo_password" +RUNNER_ON_BEHALF_USER = "user" +RUNNER_REMOTE_DIR = "dir" +RUNNER_COMMAND = "cmd" +RUNNER_CWD = "cwd" +RUNNER_ENV = "env" +RUNNER_KWARG_OP = "kwarg_op" +RUNNER_TIMEOUT = "timeout" +RUNNER_SSH_PORT = "port" +RUNNER_BASTION_HOST = "bastion_host" +RUNNER_PASSPHRASE = "passphrase" class BaseParallelSSHRunner(ActionRunner, ShellRunnerMixin): - def __init__(self, runner_id): super(BaseParallelSSHRunner, self).__init__(runner_id=runner_id) self._hosts = None @@ -68,7 +65,7 @@ def __init__(self, runner_id): self._password = None self._private_key = None self._passphrase = None - self._kwarg_op = '--' + self._kwarg_op = "--" self._cwd = None self._env = None self._ssh_port = None @@ -83,13 +80,16 @@ def __init__(self, runner_id): def pre_run(self): super(BaseParallelSSHRunner, self).pre_run() - LOG.debug('Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"', - self.liveaction_id) - hosts = self.runner_parameters.get(RUNNER_HOSTS, '').split(',') + LOG.debug( + 'Entering BaseParallelSSHRunner.pre_run() for liveaction_id="%s"', + self.liveaction_id, + ) + hosts = self.runner_parameters.get(RUNNER_HOSTS, "").split(",") self._hosts = [h.strip() for h in hosts if len(h) > 0] if len(self._hosts) < 1: - raise ActionRunnerPreRunError('No hosts specified to run action for action %s.' - % self.liveaction_id) + raise ActionRunnerPreRunError( + "No hosts specified to run action for action %s." % self.liveaction_id + ) self._username = self.runner_parameters.get(RUNNER_USERNAME, None) self._password = self.runner_parameters.get(RUNNER_PASSWORD, None) self._private_key = self.runner_parameters.get(RUNNER_PRIVATE_KEY, None) @@ -103,85 +103,105 @@ def pre_run(self): self._sudo_password = self.runner_parameters.get(RUNNER_SUDO_PASSWORD, None) if self.context: - self._on_behalf_user = self.context.get(RUNNER_ON_BEHALF_USER, self._on_behalf_user) + self._on_behalf_user = self.context.get( + RUNNER_ON_BEHALF_USER, self._on_behalf_user + ) self._cwd = self.runner_parameters.get(RUNNER_CWD, None) self._env = self.runner_parameters.get(RUNNER_ENV, {}) - self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, '--') - self._timeout = self.runner_parameters.get(RUNNER_TIMEOUT, - REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT) + self._kwarg_op = self.runner_parameters.get(RUNNER_KWARG_OP, "--") + self._timeout = self.runner_parameters.get( + RUNNER_TIMEOUT, REMOTE_RUNNER_DEFAULT_ACTION_TIMEOUT + ) self._bastion_host = self.runner_parameters.get(RUNNER_BASTION_HOST, None) - LOG.info('[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.', - self.runner_id, self.liveaction_id) + LOG.info( + '[BaseParallelSSHRunner="%s", liveaction_id="%s"] Finished pre_run.', + self.runner_id, + self.liveaction_id, + ) concurrency = int(len(self._hosts) / 3) + 1 if self._parallel else 1 if concurrency > self._max_concurrency: - LOG.debug('Limiting parallel SSH concurrency to %d.', concurrency) + LOG.debug("Limiting parallel SSH concurrency to %d.", concurrency) concurrency = self._max_concurrency client_kwargs = { - 'hosts': self._hosts, - 'user': self._username, - 'port': self._ssh_port, - 'concurrency': concurrency, - 'bastion_host': self._bastion_host, - 'raise_on_any_error': False, - 'connect': True + "hosts": self._hosts, + "user": self._username, + "port": self._ssh_port, + "concurrency": concurrency, + "bastion_host": self._bastion_host, + "raise_on_any_error": False, + "connect": True, } def make_store_stdout_line_func(execution_db, action_db): def store_stdout_line(line): if cfg.CONF.actionrunner.stream_output: - store_execution_output_data(execution_db=execution_db, action_db=action_db, - data=line, output_type='stdout') + store_execution_output_data( + execution_db=execution_db, + action_db=action_db, + data=line, + output_type="stdout", + ) return store_stdout_line def make_store_stderr_line_func(execution_db, action_db): def store_stderr_line(line): if cfg.CONF.actionrunner.stream_output: - store_execution_output_data(execution_db=execution_db, action_db=action_db, - data=line, output_type='stderr') + store_execution_output_data( + execution_db=execution_db, + action_db=action_db, + data=line, + output_type="stderr", + ) return store_stderr_line - handle_stdout_line_func = make_store_stdout_line_func(execution_db=self.execution, - action_db=self.action) - handle_stderr_line_func = make_store_stderr_line_func(execution_db=self.execution, - action_db=self.action) + handle_stdout_line_func = make_store_stdout_line_func( + execution_db=self.execution, action_db=self.action + ) + handle_stderr_line_func = make_store_stderr_line_func( + execution_db=self.execution, action_db=self.action + ) if len(self._hosts) == 1: # We only support streaming output when running action on one host. That is because # the action output is tied to a particulat execution. User can still achieve output # streaming for multiple hosts by running one execution per host. - client_kwargs['handle_stdout_line_func'] = handle_stdout_line_func - client_kwargs['handle_stderr_line_func'] = handle_stderr_line_func + client_kwargs["handle_stdout_line_func"] = handle_stdout_line_func + client_kwargs["handle_stderr_line_func"] = handle_stderr_line_func else: - LOG.debug('Real-time action output streaming is disabled, because action is running ' - 'on more than one host') + LOG.debug( + "Real-time action output streaming is disabled, because action is running " + "on more than one host" + ) if self._password: - client_kwargs['password'] = self._password + client_kwargs["password"] = self._password elif self._private_key: # Determine if the private_key is a path to the key file or the raw key material - is_key_material = self._is_private_key_material(private_key=self._private_key) + is_key_material = self._is_private_key_material( + private_key=self._private_key + ) if is_key_material: # Raw key material - client_kwargs['pkey_material'] = self._private_key + client_kwargs["pkey_material"] = self._private_key else: # Assume it's a path to the key file, verify the file exists - client_kwargs['pkey_file'] = self._private_key + client_kwargs["pkey_file"] = self._private_key if self._passphrase: - client_kwargs['passphrase'] = self._passphrase + client_kwargs["passphrase"] = self._passphrase else: # Default to stanley key file specified in the config - client_kwargs['pkey_file'] = self._ssh_key_file + client_kwargs["pkey_file"] = self._ssh_key_file if self._sudo_password: - client_kwargs['sudo_password'] = True + client_kwargs["sudo_password"] = True self._parallel_ssh_client = ParallelSSHClient(**client_kwargs) @@ -213,21 +233,22 @@ def _get_env_vars(self): @staticmethod def _get_result_status(result, allow_partial_failure): - if 'error' in result and 'traceback' in result: + if "error" in result and "traceback" in result: # Assume this is a global failure where the result dictionary doesn't contain entry # per host timeout = False - success = result.get('succeeded', False) - status = BaseParallelSSHRunner._get_status_for_success_and_timeout(success=success, - timeout=timeout) + success = result.get("succeeded", False) + status = BaseParallelSSHRunner._get_status_for_success_and_timeout( + success=success, timeout=timeout + ) return status success = not allow_partial_failure timeout = True for r in six.itervalues(result): - r_succeess = r.get('succeeded', False) if r else False - r_timeout = r.get('timeout', False) if r else False + r_succeess = r.get("succeeded", False) if r else False + r_timeout = r.get("timeout", False) if r else False timeout &= r_timeout @@ -240,8 +261,9 @@ def _get_result_status(result, allow_partial_failure): if not success: break - status = BaseParallelSSHRunner._get_status_for_success_and_timeout(success=success, - timeout=timeout) + status = BaseParallelSSHRunner._get_status_for_success_and_timeout( + success=success, timeout=timeout + ) return status diff --git a/st2common/st2common/runners/utils.py b/st2common/st2common/runners/utils.py index 82f1a3477cf..70f7139f3fa 100644 --- a/st2common/st2common/runners/utils.py +++ b/st2common/st2common/runners/utils.py @@ -27,14 +27,11 @@ __all__ = [ - 'PackConfigDict', - - 'get_logger_for_python_runner_action', - 'get_action_class_instance', - - 'make_read_and_store_stream_func', - - 'invoke_post_run', + "PackConfigDict", + "get_logger_for_python_runner_action", + "get_action_class_instance", + "make_read_and_store_stream_func", + "invoke_post_run", ] LOG = logging.getLogger(__name__) @@ -61,6 +58,7 @@ class PackConfigDict(dict): This class throws a user-friendly exception in case user tries to access config item which doesn't exist in the dict. """ + def __init__(self, pack_name, *args): super(PackConfigDict, self).__init__(*args) self._pack_name = pack_name @@ -72,8 +70,8 @@ def __getitem__(self, key): # Note: We use late import to avoid performance overhead from oslo_config import cfg - configs_path = os.path.join(cfg.CONF.system.base_path, 'configs/') - config_path = os.path.join(configs_path, self._pack_name + '.yaml') + configs_path = os.path.join(cfg.CONF.system.base_path, "configs/") + config_path = os.path.join(configs_path, self._pack_name + ".yaml") msg = CONFIG_MISSING_ITEM_ERROR % (self._pack_name, key, config_path) raise ValueError(msg) @@ -83,11 +81,11 @@ def __setitem__(self, key, value): super(PackConfigDict, self).__setitem__(key, value) -def get_logger_for_python_runner_action(action_name, log_level='debug'): +def get_logger_for_python_runner_action(action_name, log_level="debug"): """ Set up a logger which logs all the messages with level DEBUG and above to stderr. """ - logger_name = 'actions.python.%s' % (action_name) + logger_name = "actions.python.%s" % (action_name) if logger_name not in LOGGERS: level_name = log_level.upper() @@ -97,7 +95,7 @@ def get_logger_for_python_runner_action(action_name, log_level='debug'): console = stdlib_logging.StreamHandler() console.setLevel(log_level_constant) - formatter = stdlib_logging.Formatter('%(name)-12s: %(levelname)-8s %(message)s') + formatter = stdlib_logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s") console.setFormatter(formatter) logger.addHandler(console) logger.setLevel(log_level_constant) @@ -123,8 +121,8 @@ def get_action_class_instance(action_cls, config=None, action_service=None): :type action_service: :class:`ActionService` """ kwargs = {} - kwargs['config'] = config - kwargs['action_service'] = action_service + kwargs["config"] = config + kwargs["action_service"] = action_service # Note: This is done for backward compatibility reasons. We first try to pass # "action_service" argument to the action class constructor, but if that doesn't work (e.g. old @@ -133,13 +131,15 @@ def get_action_class_instance(action_cls, config=None, action_service=None): try: action_instance = action_cls(**kwargs) except TypeError as e: - if 'unexpected keyword argument \'action_service\'' not in six.text_type(e): + if "unexpected keyword argument 'action_service'" not in six.text_type(e): raise e - LOG.debug('Action class (%s) constructor doesn\'t take "action_service" argument, ' - 'falling back to late assignment...' % (action_cls.__class__.__name__)) + LOG.debug( + 'Action class (%s) constructor doesn\'t take "action_service" argument, ' + "falling back to late assignment..." % (action_cls.__class__.__name__) + ) - action_service = kwargs.pop('action_service', None) + action_service = kwargs.pop("action_service", None) action_instance = action_cls(**kwargs) action_instance.action_service = action_service @@ -166,7 +166,7 @@ def read_and_store_stream(stream, buff): break if isinstance(line, six.binary_type): - line = line.decode('utf-8') + line = line.decode("utf-8") buff.write(line) @@ -175,7 +175,9 @@ def read_and_store_stream(stream, buff): continue if cfg.CONF.actionrunner.stream_output: - store_data_func(execution_db=execution_db, action_db=action_db, data=line) + store_data_func( + execution_db=execution_db, action_db=action_db, data=line + ) except RuntimeError: # process was terminated abruptly pass @@ -193,31 +195,40 @@ def invoke_post_run(liveaction_db, action_db=None): from st2common.util import action_db as action_db_utils from st2common.content import utils as content_utils - LOG.info('Invoking post run for action execution %s.', liveaction_db.id) + LOG.info("Invoking post run for action execution %s.", liveaction_db.id) # Identify action and runner. if not action_db: action_db = action_db_utils.get_action_by_ref(liveaction_db.action) if not action_db: - LOG.error('Unable to invoke post run. Action %s no longer exists.', liveaction_db.action) + LOG.error( + "Unable to invoke post run. Action %s no longer exists.", + liveaction_db.action, + ) return - LOG.info('Action execution %s runs %s of runner type %s.', - liveaction_db.id, action_db.name, action_db.runner_type['name']) + LOG.info( + "Action execution %s runs %s of runner type %s.", + liveaction_db.id, + action_db.name, + action_db.runner_type["name"], + ) # Get instance of the action runner and related configuration. - runner_type_db = action_db_utils.get_runnertype_by_name(action_db.runner_type['name']) + runner_type_db = action_db_utils.get_runnertype_by_name( + action_db.runner_type["name"] + ) runner = runners.get_runner(name=runner_type_db.name) entry_point = content_utils.get_entry_point_abs_path( - pack=action_db.pack, - entry_point=action_db.entry_point) + pack=action_db.pack, entry_point=action_db.entry_point + ) libs_dir_path = content_utils.get_action_libs_abs_path( - pack=action_db.pack, - entry_point=action_db.entry_point) + pack=action_db.pack, entry_point=action_db.entry_point + ) # Configure the action runner. runner.runner_type_db = runner_type_db @@ -226,8 +237,8 @@ def invoke_post_run(liveaction_db, action_db=None): runner.liveaction = liveaction_db runner.liveaction_id = str(liveaction_db.id) runner.entry_point = entry_point - runner.context = getattr(liveaction_db, 'context', dict()) - runner.callback = getattr(liveaction_db, 'callback', dict()) + runner.context = getattr(liveaction_db, "context", dict()) + runner.callback = getattr(liveaction_db, "callback", dict()) runner.libs_dir_path = libs_dir_path # Invoke the post_run method. diff --git a/st2common/st2common/script_setup.py b/st2common/st2common/script_setup.py index 03be4b4427d..0abb7e8269c 100644 --- a/st2common/st2common/script_setup.py +++ b/st2common/st2common/script_setup.py @@ -32,13 +32,7 @@ from st2common.logging.filters import LogLevelFilter from st2common.transport.bootstrap_utils import register_exchanges_with_retry -__all__ = [ - 'setup', - 'teardown', - - 'db_setup', - 'db_teardown' -] +__all__ = ["setup", "teardown", "db_setup", "db_teardown"] LOG = logging.getLogger(__name__) @@ -47,11 +41,15 @@ def register_common_cli_options(): """ Register common CLI options. """ - cfg.CONF.register_cli_opt(cfg.BoolOpt('verbose', short='v', default=False)) + cfg.CONF.register_cli_opt(cfg.BoolOpt("verbose", short="v", default=False)) -def setup(config, setup_db=True, register_mq_exchanges=True, - register_internal_trigger_types=False): +def setup( + config, + setup_db=True, + register_mq_exchanges=True, + register_internal_trigger_types=False, +): """ Common setup function. @@ -76,7 +74,9 @@ def setup(config, setup_db=True, register_mq_exchanges=True, # Set up logging log_level = stdlib_logging.DEBUG - stdlib_logging.basicConfig(format='%(asctime)s %(levelname)s [-] %(message)s', level=log_level) + stdlib_logging.basicConfig( + format="%(asctime)s %(levelname)s [-] %(message)s", level=log_level + ) if not cfg.CONF.verbose: # Note: We still want to print things at the following log levels: INFO, ERROR, CRITICAL diff --git a/st2common/st2common/service_setup.py b/st2common/st2common/service_setup.py index 14cd708cca2..bd01d205e7a 100644 --- a/st2common/st2common/service_setup.py +++ b/st2common/st2common/service_setup.py @@ -53,22 +53,29 @@ __all__ = [ - 'setup', - 'teardown', - - 'db_setup', - 'db_teardown', - - 'register_service_in_service_registry' + "setup", + "teardown", + "db_setup", + "db_teardown", + "register_service_in_service_registry", ] LOG = logging.getLogger(__name__) -def setup(service, config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, register_internal_trigger_types=False, - run_migrations=True, register_runners=True, service_registry=False, - capabilities=None, config_args=None): +def setup( + service, + config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=False, + run_migrations=True, + register_runners=True, + service_registry=False, + capabilities=None, + config_args=None, +): """ Common setup function. @@ -99,29 +106,38 @@ def setup(service, config, setup_db=True, register_mq_exchanges=True, else: config.parse_args() - version = '%s.%s.%s' % (sys.version_info[0], sys.version_info[1], sys.version_info[2]) - LOG.debug('Using Python: %s (%s)' % (version, sys.executable)) + version = "%s.%s.%s" % ( + sys.version_info[0], + sys.version_info[1], + sys.version_info[2], + ) + LOG.debug("Using Python: %s (%s)" % (version, sys.executable)) config_file_paths = cfg.CONF.config_file config_file_paths = [os.path.abspath(path) for path in config_file_paths] - LOG.debug('Using config files: %s', ','.join(config_file_paths)) + LOG.debug("Using config files: %s", ",".join(config_file_paths)) # Setup logging. logging_config_path = config.get_logging_config_path() logging_config_path = os.path.abspath(logging_config_path) - LOG.debug('Using logging config: %s', logging_config_path) + LOG.debug("Using logging config: %s", logging_config_path) - is_debug_enabled = (cfg.CONF.debug or cfg.CONF.system.debug) + is_debug_enabled = cfg.CONF.debug or cfg.CONF.system.debug try: - logging.setup(logging_config_path, redirect_stderr=cfg.CONF.log.redirect_stderr, - excludes=cfg.CONF.log.excludes) + logging.setup( + logging_config_path, + redirect_stderr=cfg.CONF.log.redirect_stderr, + excludes=cfg.CONF.log.excludes, + ) except KeyError as e: tb_msg = traceback.format_exc() - if 'log.setLevel' in tb_msg: - msg = 'Invalid log level selected. Log level names need to be all uppercase.' - msg += '\n\n' + getattr(e, 'message', six.text_type(e)) + if "log.setLevel" in tb_msg: + msg = ( + "Invalid log level selected. Log level names need to be all uppercase." + ) + msg += "\n\n" + getattr(e, "message", six.text_type(e)) raise KeyError(msg) else: raise e @@ -134,10 +150,14 @@ def setup(service, config, setup_db=True, register_mq_exchanges=True, # duplicate "AUDIT" messages in production deployments where default service log level is # set to "INFO" and we already log messages with level AUDIT to a special dedicated log # file. - ignore_audit_log_messages = (handler.level >= stdlib_logging.INFO and - handler.level < stdlib_logging.AUDIT) + ignore_audit_log_messages = ( + handler.level >= stdlib_logging.INFO + and handler.level < stdlib_logging.AUDIT + ) if not is_debug_enabled and ignore_audit_log_messages: - LOG.debug('Excluding log messages with level "AUDIT" for handler "%s"' % (handler)) + LOG.debug( + 'Excluding log messages with level "AUDIT" for handler "%s"' % (handler) + ) handler.addFilter(LogLevelFilter(log_levels=exclude_log_levels)) if not is_debug_enabled: @@ -184,8 +204,9 @@ def setup(service, config, setup_db=True, register_mq_exchanges=True, # Register service in the service registry if cfg.CONF.coordination.service_registry and service_registry: # NOTE: It's important that we pass start_heart=True to start the hearbeat process - register_service_in_service_registry(service=service, capabilities=capabilities, - start_heart=True) + register_service_in_service_registry( + service=service, capabilities=capabilities, start_heart=True + ) if sys.version_info[0] == 2: LOG.warning(PYTHON2_DEPRECATION) @@ -220,7 +241,7 @@ def register_service_in_service_registry(service, capabilities=None, start_heart # 1. Create a group with the name of the service if not isinstance(service, six.binary_type): - group_id = service.encode('utf-8') + group_id = service.encode("utf-8") else: group_id = service @@ -231,10 +252,12 @@ def register_service_in_service_registry(service, capabilities=None, start_heart # Include common capabilities such as hostname and process ID proc_info = system_info.get_process_info() - capabilities['hostname'] = proc_info['hostname'] - capabilities['pid'] = proc_info['pid'] + capabilities["hostname"] = proc_info["hostname"] + capabilities["pid"] = proc_info["pid"] # 1. Join the group as a member - LOG.debug('Joining service registry group "%s" as member_id "%s" with capabilities "%s"' % - (group_id, member_id, capabilities)) + LOG.debug( + 'Joining service registry group "%s" as member_id "%s" with capabilities "%s"' + % (group_id, member_id, capabilities) + ) return coordinator.join_group(group_id, capabilities=capabilities).get() diff --git a/st2common/st2common/services/access.py b/st2common/st2common/services/access.py index 72f7f192bb7..9d88c39c424 100644 --- a/st2common/st2common/services/access.py +++ b/st2common/st2common/services/access.py @@ -27,15 +27,14 @@ from st2common.persistence.auth import Token, User from st2common import log as logging -__all__ = [ - 'create_token', - 'delete_token' -] +__all__ = ["create_token", "delete_token"] LOG = logging.getLogger(__name__) -def create_token(username, ttl=None, metadata=None, add_missing_user=True, service=False): +def create_token( + username, ttl=None, metadata=None, add_missing_user=True, service=False +): """ :param username: Username of the user to create the token for. If the account for this user doesn't exist yet it will be created. @@ -57,8 +56,10 @@ def create_token(username, ttl=None, metadata=None, add_missing_user=True, servi if ttl: # Note: We allow arbitrary large TTLs for service tokens. if not service and ttl > cfg.CONF.auth.token_ttl: - msg = ('TTL specified %s is greater than max allowed %s.' % (ttl, - cfg.CONF.auth.token_ttl)) + msg = "TTL specified %s is greater than max allowed %s." % ( + ttl, + cfg.CONF.auth.token_ttl, + ) raise TTLTooLargeException(msg) else: ttl = cfg.CONF.auth.token_ttl @@ -71,22 +72,27 @@ def create_token(username, ttl=None, metadata=None, add_missing_user=True, servi user_db = UserDB(name=username) User.add_or_update(user_db) - extra = {'username': username, 'user': user_db} + extra = {"username": username, "user": user_db} LOG.audit('Registered new user "%s".' % (username), extra=extra) else: raise UserNotFoundError() token = uuid.uuid4().hex expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl) - token = TokenDB(user=username, token=token, expiry=expiry, metadata=metadata, service=service) + token = TokenDB( + user=username, token=token, expiry=expiry, metadata=metadata, service=service + ) Token.add_or_update(token) - username_string = username if username else 'an anonymous user' + username_string = username if username else "an anonymous user" token_expire_string = isotime.format(expiry, offset=False) - extra = {'username': username, 'token_expiration': token_expire_string} + extra = {"username": username, "token_expiration": token_expire_string} - LOG.audit('Access granted to "%s" with the token set to expire at "%s".' % - (username_string, token_expire_string), extra=extra) + LOG.audit( + 'Access granted to "%s" with the token set to expire at "%s".' + % (username_string, token_expire_string), + extra=extra, + ) return token diff --git a/st2common/st2common/services/action.py b/st2common/st2common/services/action.py index c7e7495d692..46e44800cc3 100644 --- a/st2common/st2common/services/action.py +++ b/st2common/st2common/services/action.py @@ -34,15 +34,13 @@ __all__ = [ - 'request', - 'create_request', - 'publish_request', - 'is_action_canceled_or_canceling', - - 'request_pause', - 'request_resume', - - 'store_execution_output_data', + "request", + "create_request", + "publish_request", + "is_action_canceled_or_canceling", + "request_pause", + "request_resume", + "store_execution_output_data", ] LOG = logging.getLogger(__name__) @@ -51,7 +49,7 @@ def _get_immutable_params(parameters): if not parameters: return [] - return [k for k, v in six.iteritems(parameters) if v.get('immutable', False)] + return [k for k, v in six.iteritems(parameters) if v.get("immutable", False)] def create_request(liveaction, action_db=None, runnertype_db=None): @@ -77,10 +75,10 @@ def create_request(liveaction, action_db=None, runnertype_db=None): # action can be invoked by a system user and so we want to use the user context # from the original workflow action. parent_context = executions.get_parent_context(liveaction) or {} - parent_user = parent_context.get('user', None) + parent_user = parent_context.get("user", None) if parent_user: - liveaction.context['user'] = parent_user + liveaction.context["user"] = parent_user # Validate action if not action_db: @@ -89,31 +87,44 @@ def create_request(liveaction, action_db=None, runnertype_db=None): if not action_db: raise ValueError('Action "%s" cannot be found.' % liveaction.action) if not action_db.enabled: - raise ValueError('Unable to execute. Action "%s" is disabled.' % liveaction.action) + raise ValueError( + 'Unable to execute. Action "%s" is disabled.' % liveaction.action + ) if not runnertype_db: - runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_utils.get_runnertype_by_name( + action_db.runner_type["name"] + ) - if not hasattr(liveaction, 'parameters'): + if not hasattr(liveaction, "parameters"): liveaction.parameters = dict() # For consistency add pack to the context here in addition to RunnerContainer.dispatch() method - liveaction.context['pack'] = action_db.pack + liveaction.context["pack"] = action_db.pack # Validate action parameters. schema = util_schema.get_schema_for_action_parameters(action_db, runnertype_db) validator = util_schema.get_validator() - util_schema.validate(liveaction.parameters, schema, validator, use_default=True, - allow_default_none=True) + util_schema.validate( + liveaction.parameters, + schema, + validator, + use_default=True, + allow_default_none=True, + ) # validate that no immutable params are being overriden. Although possible to # ignore the override it is safer to inform the user to avoid surprises. immutables = _get_immutable_params(action_db.parameters) immutables.extend(_get_immutable_params(runnertype_db.runner_parameters)) - overridden_immutables = [p for p in six.iterkeys(liveaction.parameters) if p in immutables] + overridden_immutables = [ + p for p in six.iterkeys(liveaction.parameters) if p in immutables + ] if len(overridden_immutables) > 0: - raise ValueError('Override of immutable parameter(s) %s is unsupported.' - % str(overridden_immutables)) + raise ValueError( + "Override of immutable parameter(s) %s is unsupported." + % str(overridden_immutables) + ) # Set notification settings for action. # XXX: There are cases when we don't want notifications to be sent for a particular @@ -140,17 +151,24 @@ def create_request(liveaction, action_db=None, runnertype_db=None): _cleanup_liveaction(liveaction) raise trace_exc.TraceNotFoundException(six.text_type(e)) - execution = executions.create_execution_object(liveaction=liveaction, action_db=action_db, - runnertype_db=runnertype_db, publish=False) + execution = executions.create_execution_object( + liveaction=liveaction, + action_db=action_db, + runnertype_db=runnertype_db, + publish=False, + ) if trace_db: trace_service.add_or_update_given_trace_db( trace_db=trace_db, action_executions=[ - trace_service.get_trace_component_for_action_execution(execution, liveaction) - ]) + trace_service.get_trace_component_for_action_execution( + execution, liveaction + ) + ], + ) - get_driver().inc_counter('action.executions.%s' % (liveaction.status)) + get_driver().inc_counter("action.executions.%s" % (liveaction.status)) return liveaction, execution @@ -170,8 +188,11 @@ def publish_request(liveaction, execution): # TODO: This results in two queries, optimize it # extra = {'liveaction_db': liveaction, 'execution_db': execution} extra = {} - LOG.audit('Action execution requested. LiveAction.id=%s, ActionExecution.id=%s' % - (liveaction.id, execution.id), extra=extra) + LOG.audit( + "Action execution requested. LiveAction.id=%s, ActionExecution.id=%s" + % (liveaction.id, execution.id), + extra=extra, + ) return liveaction, execution @@ -190,33 +211,34 @@ def update_status(liveaction, new_status, result=None, publish=True): old_status = liveaction.status updates = { - 'liveaction_id': liveaction.id, - 'status': new_status, - 'result': result, - 'publish': False + "liveaction_id": liveaction.id, + "status": new_status, + "result": result, + "publish": False, } if new_status in action_constants.LIVEACTION_COMPLETED_STATES: - updates['end_timestamp'] = date_utils.get_datetime_utc_now() + updates["end_timestamp"] = date_utils.get_datetime_utc_now() liveaction = action_utils.update_liveaction_status(**updates) action_execution = executions.update_execution(liveaction) - msg = ('The status of action execution is changed from %s to %s. ' - '' % (old_status, - new_status, liveaction.id, action_execution.id)) + msg = ( + "The status of action execution is changed from %s to %s. " + "" + % (old_status, new_status, liveaction.id, action_execution.id) + ) - extra = { - 'action_execution_db': action_execution, - 'liveaction_db': liveaction - } + extra = {"action_execution_db": action_execution, "liveaction_db": liveaction} LOG.audit(msg, extra=extra) LOG.info(msg) # Invoke post run if liveaction status is completed or paused. - if (new_status in action_constants.LIVEACTION_COMPLETED_STATES or - new_status == action_constants.LIVEACTION_STATUS_PAUSED): + if ( + new_status in action_constants.LIVEACTION_COMPLETED_STATES + or new_status == action_constants.LIVEACTION_STATUS_PAUSED + ): runners_utils.invoke_post_run(liveaction) if publish: @@ -227,14 +249,18 @@ def update_status(liveaction, new_status, result=None, publish=True): def is_action_canceled_or_canceling(liveaction_id): liveaction_db = action_utils.get_liveaction_by_id(liveaction_id) - return liveaction_db.status in [action_constants.LIVEACTION_STATUS_CANCELED, - action_constants.LIVEACTION_STATUS_CANCELING] + return liveaction_db.status in [ + action_constants.LIVEACTION_STATUS_CANCELED, + action_constants.LIVEACTION_STATUS_CANCELING, + ] def is_action_paused_or_pausing(liveaction_id): liveaction_db = action_utils.get_liveaction_by_id(liveaction_id) - return liveaction_db.status in [action_constants.LIVEACTION_STATUS_PAUSED, - action_constants.LIVEACTION_STATUS_PAUSING] + return liveaction_db.status in [ + action_constants.LIVEACTION_STATUS_PAUSED, + action_constants.LIVEACTION_STATUS_PAUSING, + ] def request_cancellation(liveaction, requester): @@ -250,18 +276,17 @@ def request_cancellation(liveaction, requester): if liveaction.status not in action_constants.LIVEACTION_CANCELABLE_STATES: raise Exception( 'Unable to cancel liveaction "%s" because it is already in a ' - 'completed state.' % liveaction.id + "completed state." % liveaction.id ) - result = { - 'message': 'Action canceled by user.', - 'user': requester - } + result = {"message": "Action canceled by user.", "user": requester} # Run cancelation sequence for liveaction that is in running state or # if the liveaction is operating under a workflow. - if ('parent' in liveaction.context or - liveaction.status in action_constants.LIVEACTION_STATUS_RUNNING): + if ( + "parent" in liveaction.context + or liveaction.status in action_constants.LIVEACTION_STATUS_RUNNING + ): status = action_constants.LIVEACTION_STATUS_CANCELING else: status = action_constants.LIVEACTION_STATUS_CANCELED @@ -286,17 +311,19 @@ def request_pause(liveaction, requester): if not action_db: raise ValueError( 'Unable to pause liveaction "%s" because the action "%s" ' - 'is not found.' % (liveaction.id, liveaction.action) + "is not found." % (liveaction.id, liveaction.action) ) - if action_db.runner_type['name'] not in action_constants.WORKFLOW_RUNNER_TYPES: + if action_db.runner_type["name"] not in action_constants.WORKFLOW_RUNNER_TYPES: raise runner_exc.InvalidActionRunnerOperationError( 'Unable to pause liveaction "%s" because it is not supported by the ' - '"%s" runner.' % (liveaction.id, action_db.runner_type['name']) + '"%s" runner.' % (liveaction.id, action_db.runner_type["name"]) ) - if (liveaction.status == action_constants.LIVEACTION_STATUS_PAUSING or - liveaction.status == action_constants.LIVEACTION_STATUS_PAUSED): + if ( + liveaction.status == action_constants.LIVEACTION_STATUS_PAUSING + or liveaction.status == action_constants.LIVEACTION_STATUS_PAUSED + ): execution = ActionExecution.get(liveaction__id=str(liveaction.id)) return (liveaction, execution) @@ -326,18 +353,18 @@ def request_resume(liveaction, requester): if not action_db: raise ValueError( 'Unable to resume liveaction "%s" because the action "%s" ' - 'is not found.' % (liveaction.id, liveaction.action) + "is not found." % (liveaction.id, liveaction.action) ) - if action_db.runner_type['name'] not in action_constants.WORKFLOW_RUNNER_TYPES: + if action_db.runner_type["name"] not in action_constants.WORKFLOW_RUNNER_TYPES: raise runner_exc.InvalidActionRunnerOperationError( 'Unable to resume liveaction "%s" because it is not supported by the ' - '"%s" runner.' % (liveaction.id, action_db.runner_type['name']) + '"%s" runner.' % (liveaction.id, action_db.runner_type["name"]) ) running_states = [ action_constants.LIVEACTION_STATUS_RUNNING, - action_constants.LIVEACTION_STATUS_RESUMING + action_constants.LIVEACTION_STATUS_RESUMING, ] if liveaction.status in running_states: @@ -367,13 +394,13 @@ def get_parent_liveaction(liveaction_db): :rtype: LiveActionDB """ - parent = liveaction_db.context.get('parent') + parent = liveaction_db.context.get("parent") if not parent: return None - parent_execution_db = ActionExecution.get(id=parent['execution_id']) - parent_liveaction_db = LiveAction.get(id=parent_execution_db.liveaction['id']) + parent_execution_db = ActionExecution.get(id=parent["execution_id"]) + parent_liveaction_db = LiveAction.get(id=parent_execution_db.liveaction["id"]) return parent_liveaction_db @@ -409,7 +436,11 @@ def get_root_liveaction(liveaction_db): parent_liveaction_db = get_parent_liveaction(liveaction_db) - return get_root_liveaction(parent_liveaction_db) if parent_liveaction_db else liveaction_db + return ( + get_root_liveaction(parent_liveaction_db) + if parent_liveaction_db + else liveaction_db + ) def get_root_execution(execution_db): @@ -425,36 +456,48 @@ def get_root_execution(execution_db): parent_execution_db = get_parent_execution(execution_db) - return get_root_execution(parent_execution_db) if parent_execution_db else execution_db + return ( + get_root_execution(parent_execution_db) if parent_execution_db else execution_db + ) -def store_execution_output_data(execution_db, action_db, data, output_type='output', - timestamp=None): +def store_execution_output_data( + execution_db, action_db, data, output_type="output", timestamp=None +): """ Store output from an execution as a new document in the collection. """ execution_id = str(execution_db.id) if action_db is None: - action_ref = execution_db.action.get('ref', 'unknown') - runner_ref = execution_db.action.get('runner_type', 'unknown') + action_ref = execution_db.action.get("ref", "unknown") + runner_ref = execution_db.action.get("runner_type", "unknown") else: action_ref = action_db.ref - runner_ref = getattr(action_db, 'runner_type', {}).get('name', 'unknown') + runner_ref = getattr(action_db, "runner_type", {}).get("name", "unknown") return store_execution_output_data_ex( - execution_id, action_ref, runner_ref, data, - output_type=output_type, timestamp=timestamp + execution_id, + action_ref, + runner_ref, + data, + output_type=output_type, + timestamp=timestamp, ) -def store_execution_output_data_ex(execution_id, action_ref, runner_ref, data, output_type='output', - timestamp=None): +def store_execution_output_data_ex( + execution_id, action_ref, runner_ref, data, output_type="output", timestamp=None +): timestamp = timestamp or date_utils.get_datetime_utc_now() output_db = ActionExecutionOutputDB( - execution_id=execution_id, action_ref=action_ref, runner_ref=runner_ref, - timestamp=timestamp, output_type=output_type, data=data + execution_id=execution_id, + action_ref=action_ref, + runner_ref=runner_ref, + timestamp=timestamp, + output_type=output_type, + data=data, ) output_db = ActionExecutionOutput.add_or_update( @@ -467,29 +510,29 @@ def store_execution_output_data_ex(execution_id, action_ref, runner_ref, data, o def is_children_active(liveaction_id): execution_db = ActionExecution.get(liveaction__id=str(liveaction_id)) - if execution_db.runner['name'] not in action_constants.WORKFLOW_RUNNER_TYPES: + if execution_db.runner["name"] not in action_constants.WORKFLOW_RUNNER_TYPES: return False children_execution_dbs = ActionExecution.query(parent=str(execution_db.id)) - inactive_statuses = ( - action_constants.LIVEACTION_COMPLETED_STATES + - [action_constants.LIVEACTION_STATUS_PAUSED, action_constants.LIVEACTION_STATUS_PENDING] - ) + inactive_statuses = action_constants.LIVEACTION_COMPLETED_STATES + [ + action_constants.LIVEACTION_STATUS_PAUSED, + action_constants.LIVEACTION_STATUS_PENDING, + ] completed = [ child_exec_db.status in inactive_statuses for child_exec_db in children_execution_dbs ] - return (not all(completed)) + return not all(completed) def _cleanup_liveaction(liveaction): try: LiveAction.delete(liveaction) except: - LOG.exception('Failed cleaning up LiveAction: %s.', liveaction) + LOG.exception("Failed cleaning up LiveAction: %s.", liveaction) pass diff --git a/st2common/st2common/services/config.py b/st2common/st2common/services/config.py index f23d91ee9c2..bef8f483dd8 100644 --- a/st2common/st2common/services/config.py +++ b/st2common/st2common/services/config.py @@ -28,13 +28,15 @@ from st2common.exceptions.db import StackStormDBObjectNotFoundError __all__ = [ - 'set_datastore_value_for_config_key', + "set_datastore_value_for_config_key", ] LOG = logging.getLogger(__name__) -def set_datastore_value_for_config_key(pack_name, key_name, value, secret=False, user=None): +def set_datastore_value_for_config_key( + pack_name, key_name, value, secret=False, user=None +): """ Set config value in the datastore. diff --git a/st2common/st2common/services/coordination.py b/st2common/st2common/services/coordination.py index 42556ea0842..1a068632abb 100644 --- a/st2common/st2common/services/coordination.py +++ b/st2common/st2common/services/coordination.py @@ -31,19 +31,17 @@ COORDINATOR = None __all__ = [ - 'configured', - - 'get_coordinator', - 'get_coordinator_if_set', - 'get_member_id', - - 'coordinator_setup', - 'coordinator_teardown' + "configured", + "get_coordinator", + "get_coordinator_if_set", + "get_member_id", + "coordinator_setup", + "coordinator_teardown", ] class NoOpLock(locking.Lock): - def __init__(self, name='noop'): + def __init__(self, name="noop"): super(NoOpLock, self).__init__(name=name) def acquire(self, blocking=True): @@ -61,6 +59,7 @@ class NoOpAsyncResult(object): In most scenarios, tooz library returns an async result, a future and this class wrapper is here to correctly mimic tooz API and behavior. """ + def __init__(self, result=None): self._result = result @@ -108,7 +107,7 @@ def stand_down_group_leader(group_id): @classmethod def create_group(cls, group_id): - cls.groups[group_id] = {'members': {}} + cls.groups[group_id] = {"members": {}} return NoOpAsyncResult() @classmethod @@ -116,17 +115,17 @@ def get_groups(cls): return NoOpAsyncResult(result=cls.groups.keys()) @classmethod - def join_group(cls, group_id, capabilities=''): + def join_group(cls, group_id, capabilities=""): member_id = get_member_id() - cls.groups[group_id]['members'][member_id] = {'capabilities': capabilities} + cls.groups[group_id]["members"][member_id] = {"capabilities": capabilities} return NoOpAsyncResult() @classmethod def leave_group(cls, group_id): member_id = get_member_id() - del cls.groups[group_id]['members'][member_id] + del cls.groups[group_id]["members"][member_id] return NoOpAsyncResult() @classmethod @@ -137,15 +136,15 @@ def delete_group(cls, group_id): @classmethod def get_members(cls, group_id): try: - member_ids = cls.groups[group_id]['members'].keys() + member_ids = cls.groups[group_id]["members"].keys() except KeyError: - raise GroupNotCreated('Group doesnt exist') + raise GroupNotCreated("Group doesnt exist") return NoOpAsyncResult(result=member_ids) @classmethod def get_member_capabilities(cls, group_id, member_id): - member_capabiliteis = cls.groups[group_id]['members'][member_id]['capabilities'] + member_capabiliteis = cls.groups[group_id]["members"][member_id]["capabilities"] return NoOpAsyncResult(result=member_capabiliteis) @staticmethod @@ -158,7 +157,7 @@ def get_leader(group_id): @staticmethod def get_lock(name): - return NoOpLock(name='noop') + return NoOpLock(name="noop") def configured(): @@ -168,8 +167,10 @@ def configured(): :rtype: ``bool`` """ backend_configured = cfg.CONF.coordination.url is not None - mock_backend = backend_configured and (cfg.CONF.coordination.url.startswith('zake') or - cfg.CONF.coordination.url.startswith('file')) + mock_backend = backend_configured and ( + cfg.CONF.coordination.url.startswith("zake") + or cfg.CONF.coordination.url.startswith("file") + ) return backend_configured and not mock_backend @@ -189,7 +190,9 @@ def coordinator_setup(start_heart=True): member_id = get_member_id() if url: - coordinator = coordination.get_coordinator(url, member_id, lock_timeout=lock_timeout) + coordinator = coordination.get_coordinator( + url, member_id, lock_timeout=lock_timeout + ) else: # Use a no-op backend # Note: We don't use tooz to obtain a reference since for this to work we would need to @@ -217,17 +220,21 @@ def get_coordinator(start_heart=True, use_cache=True): global COORDINATOR if not configured(): - LOG.warn('Coordination backend is not configured. Code paths which use coordination ' - 'service will use best effort approach and race conditions are possible.') + LOG.warn( + "Coordination backend is not configured. Code paths which use coordination " + "service will use best effort approach and race conditions are possible." + ) if not use_cache: return coordinator_setup(start_heart=start_heart) if not COORDINATOR: COORDINATOR = coordinator_setup(start_heart=start_heart) - LOG.debug('Initializing and caching new coordinator instance: %s' % (str(COORDINATOR))) + LOG.debug( + "Initializing and caching new coordinator instance: %s" % (str(COORDINATOR)) + ) else: - LOG.debug('Using cached coordinator instance: %s' % (str(COORDINATOR))) + LOG.debug("Using cached coordinator instance: %s" % (str(COORDINATOR))) return COORDINATOR @@ -247,5 +254,5 @@ def get_member_id(): :rtype: ``bytes`` """ proc_info = system_info.get_process_info() - member_id = six.b('%s_%d' % (proc_info['hostname'], proc_info['pid'])) + member_id = six.b("%s_%d" % (proc_info["hostname"], proc_info["pid"])) return member_id diff --git a/st2common/st2common/services/datastore.py b/st2common/st2common/services/datastore.py index 986ffd0d03a..9655499e49a 100644 --- a/st2common/st2common/services/datastore.py +++ b/st2common/st2common/services/datastore.py @@ -24,11 +24,7 @@ from st2common.util.date import get_datetime_utc_now from st2common.constants.keyvalue import DATASTORE_KEY_SEPARATOR, SYSTEM_SCOPE -__all__ = [ - 'BaseDatastoreService', - 'ActionDatastoreService', - 'SensorDatastoreService' -] +__all__ = ["BaseDatastoreService", "ActionDatastoreService", "SensorDatastoreService"] class BaseDatastoreService(object): @@ -63,7 +59,7 @@ def get_user_info(self): """ client = self.get_api_client() - self._logger.debug('Retrieving user information') + self._logger.debug("Retrieving user information") result = client.get_user_info() return result @@ -85,7 +81,7 @@ def list_values(self, local=True, prefix=None): :rtype: ``list`` of :class:`KeyValuePair` """ client = self.get_api_client() - self._logger.debug('Retrieving all the values from the datastore') + self._logger.debug("Retrieving all the values from the datastore") key_prefix = self._get_full_key_prefix(local=local, prefix=prefix) kvps = client.keys.get_all(prefix=key_prefix) @@ -113,21 +109,19 @@ def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False): :rtype: ``str`` or ``None`` """ if scope != SYSTEM_SCOPE: - raise ValueError('Scope %s is unsupported.' % scope) + raise ValueError("Scope %s is unsupported." % scope) name = self._get_full_key_name(name=name, local=local) client = self.get_api_client() - self._logger.debug('Retrieving value from the datastore (name=%s)', name) + self._logger.debug("Retrieving value from the datastore (name=%s)", name) try: - params = {'decrypt': str(decrypt).lower(), 'scope': scope} + params = {"decrypt": str(decrypt).lower(), "scope": scope} kvp = client.keys.get_by_id(id=name, params=params) except Exception as e: self._logger.exception( - 'Exception retrieving value from datastore (name=%s): %s', - name, - e + "Exception retrieving value from datastore (name=%s): %s", name, e ) return None @@ -136,7 +130,9 @@ def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False): return None - def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False): + def set_value( + self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False + ): """ Set a value for the provided key. @@ -165,14 +161,14 @@ def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encry :rtype: ``bool`` """ if scope != SYSTEM_SCOPE: - raise ValueError('Scope %s is unsupported.' % scope) + raise ValueError("Scope %s is unsupported." % scope) name = self._get_full_key_name(name=name, local=local) value = str(value) client = self.get_api_client() - self._logger.debug('Setting value in the datastore (name=%s)', name) + self._logger.debug("Setting value in the datastore (name=%s)", name) instance = KeyValuePair() instance.id = name @@ -208,7 +204,7 @@ def delete_value(self, name, local=True, scope=SYSTEM_SCOPE): :rtype: ``bool`` """ if scope != SYSTEM_SCOPE: - raise ValueError('Scope %s is unsupported.' % scope) + raise ValueError("Scope %s is unsupported." % scope) name = self._get_full_key_name(name=name, local=local) @@ -218,16 +214,14 @@ def delete_value(self, name, local=True, scope=SYSTEM_SCOPE): instance.id = name instance.name = name - self._logger.debug('Deleting value from the datastore (name=%s)', name) + self._logger.debug("Deleting value from the datastore (name=%s)", name) try: - params = {'scope': scope} + params = {"scope": scope} client.keys.delete(instance=instance, params=params) except Exception as e: self._logger.exception( - 'Exception deleting value from datastore (name=%s): %s', - name, - e + "Exception deleting value from datastore (name=%s): %s", name, e ) return False @@ -237,7 +231,7 @@ def get_api_client(self): """ Retrieve API client instance. """ - raise NotImplementedError('get_api_client() not implemented') + raise NotImplementedError("get_api_client() not implemented") def _get_full_key_name(self, name, local): """ @@ -282,7 +276,7 @@ def _get_key_name_with_prefix(self, name): return full_name def _get_datastore_key_prefix(self): - prefix = '%s.%s' % (self._pack_name, self._class_name) + prefix = "%s.%s" % (self._pack_name, self._class_name) return prefix @@ -299,8 +293,9 @@ def __init__(self, logger, pack_name, class_name, auth_token): :param auth_token: Auth token used to authenticate with StackStorm API. :type auth_token: ``str`` """ - super(ActionDatastoreService, self).__init__(logger=logger, pack_name=pack_name, - class_name=class_name) + super(ActionDatastoreService, self).__init__( + logger=logger, pack_name=pack_name, class_name=class_name + ) self._auth_token = auth_token self._client = None @@ -310,7 +305,7 @@ def get_api_client(self): Retrieve API client instance. """ if not self._client: - self._logger.debug('Creating new Client object.') + self._logger.debug("Creating new Client object.") api_url = get_full_public_api_url() client = Client(api_url=api_url, token=self._auth_token) @@ -330,8 +325,9 @@ class SensorDatastoreService(BaseDatastoreService): """ def __init__(self, logger, pack_name, class_name, api_username): - super(SensorDatastoreService, self).__init__(logger=logger, pack_name=pack_name, - class_name=class_name) + super(SensorDatastoreService, self).__init__( + logger=logger, pack_name=pack_name, class_name=class_name + ) self._api_username = api_username self._token_expire = get_datetime_utc_now() @@ -344,12 +340,15 @@ def get_api_client(self): if not self._client or token_expire: # Note: Late import to avoid high import cost (time wise) from st2common.services.access import create_token - self._logger.debug('Creating new Client object.') + + self._logger.debug("Creating new Client object.") ttl = cfg.CONF.auth.service_token_ttl api_url = get_full_public_api_url() - temporary_token = create_token(username=self._api_username, ttl=ttl, service=True) + temporary_token = create_token( + username=self._api_username, ttl=ttl, service=True + ) self._client = Client(api_url=api_url, token=temporary_token.token) self._token_expire = get_datetime_utc_now() + timedelta(seconds=ttl) diff --git a/st2common/st2common/services/executions.py b/st2common/st2common/services/executions.py index e259977bdc1..51447796b0e 100644 --- a/st2common/st2common/services/executions.py +++ b/st2common/st2common/services/executions.py @@ -51,13 +51,13 @@ __all__ = [ - 'create_execution_object', - 'update_execution', - 'abandon_execution_if_incomplete', - 'is_execution_canceled', - 'AscendingSortedDescendantView', - 'DFSDescendantView', - 'get_descendants' + "create_execution_object", + "update_execution", + "abandon_execution_if_incomplete", + "is_execution_canceled", + "AscendingSortedDescendantView", + "DFSDescendantView", + "get_descendants", ] LOG = logging.getLogger(__name__) @@ -66,13 +66,13 @@ # into a ActionExecution compatible dictionary. # Those attributes are LiveAction specific and are therefore stored in a "liveaction" key LIVEACTION_ATTRIBUTES = [ - 'id', - 'callback', - 'action', - 'action_is_workflow', - 'runner_info', - 'parameters', - 'notify' + "id", + "callback", + "action", + "action_is_workflow", + "runner_info", + "parameters", + "notify", ] @@ -80,11 +80,11 @@ def _decompose_liveaction(liveaction_db): """ Splits the liveaction into an ActionExecution compatible dict. """ - decomposed = {'liveaction': {}} + decomposed = {"liveaction": {}} liveaction_api = vars(LiveActionAPI.from_model(liveaction_db)) for k in liveaction_api.keys(): if k in LIVEACTION_ATTRIBUTES: - decomposed['liveaction'][k] = liveaction_api[k] + decomposed["liveaction"][k] = liveaction_api[k] else: decomposed[k] = getattr(liveaction_db, k) return decomposed @@ -94,49 +94,53 @@ def _create_execution_log_entry(status): """ Create execution log entry object for the provided execution status. """ - return { - 'timestamp': date_utils.get_datetime_utc_now(), - 'status': status - } + return {"timestamp": date_utils.get_datetime_utc_now(), "status": status} -def create_execution_object(liveaction, action_db=None, runnertype_db=None, publish=True): +def create_execution_object( + liveaction, action_db=None, runnertype_db=None, publish=True +): if not action_db: action_db = action_utils.get_action_by_ref(liveaction.action) if not runnertype_db: - runnertype_db = RunnerType.get_by_name(action_db.runner_type['name']) + runnertype_db = RunnerType.get_by_name(action_db.runner_type["name"]) attrs = { - 'action': vars(ActionAPI.from_model(action_db)), - 'parameters': liveaction['parameters'], - 'runner': vars(RunnerTypeAPI.from_model(runnertype_db)) + "action": vars(ActionAPI.from_model(action_db)), + "parameters": liveaction["parameters"], + "runner": vars(RunnerTypeAPI.from_model(runnertype_db)), } attrs.update(_decompose_liveaction(liveaction)) - if 'rule' in liveaction.context: - rule = reference.get_model_from_ref(Rule, liveaction.context.get('rule', {})) - attrs['rule'] = vars(RuleAPI.from_model(rule)) + if "rule" in liveaction.context: + rule = reference.get_model_from_ref(Rule, liveaction.context.get("rule", {})) + attrs["rule"] = vars(RuleAPI.from_model(rule)) - if 'trigger_instance' in liveaction.context: - trigger_instance_id = liveaction.context.get('trigger_instance', {}) - trigger_instance_id = trigger_instance_id.get('id', None) + if "trigger_instance" in liveaction.context: + trigger_instance_id = liveaction.context.get("trigger_instance", {}) + trigger_instance_id = trigger_instance_id.get("id", None) trigger_instance = TriggerInstance.get_by_id(trigger_instance_id) - trigger = reference.get_model_by_resource_ref(db_api=Trigger, - ref=trigger_instance.trigger) - trigger_type = reference.get_model_by_resource_ref(db_api=TriggerType, - ref=trigger.type) + trigger = reference.get_model_by_resource_ref( + db_api=Trigger, ref=trigger_instance.trigger + ) + trigger_type = reference.get_model_by_resource_ref( + db_api=TriggerType, ref=trigger.type + ) trigger_instance = reference.get_model_from_ref( - TriggerInstance, liveaction.context.get('trigger_instance', {})) - attrs['trigger_instance'] = vars(TriggerInstanceAPI.from_model(trigger_instance)) - attrs['trigger'] = vars(TriggerAPI.from_model(trigger)) - attrs['trigger_type'] = vars(TriggerTypeAPI.from_model(trigger_type)) + TriggerInstance, liveaction.context.get("trigger_instance", {}) + ) + attrs["trigger_instance"] = vars( + TriggerInstanceAPI.from_model(trigger_instance) + ) + attrs["trigger"] = vars(TriggerAPI.from_model(trigger)) + attrs["trigger_type"] = vars(TriggerTypeAPI.from_model(trigger_type)) parent = _get_parent_execution(liveaction) if parent: - attrs['parent'] = str(parent.id) + attrs["parent"] = str(parent.id) - attrs['log'] = [_create_execution_log_entry(liveaction['status'])] + attrs["log"] = [_create_execution_log_entry(liveaction["status"])] # TODO: This object initialization takes 20-30or so ms execution = ActionExecutionDB(**attrs) @@ -146,24 +150,30 @@ def create_execution_object(liveaction, action_db=None, runnertype_db=None, publ # NOTE: User input data is already validate as part of the API request, # other data is set by us. Skipping validation here makes operation 10%-30% faster - execution = ActionExecution.add_or_update(execution, publish=publish, validate=False) + execution = ActionExecution.add_or_update( + execution, publish=publish, validate=False + ) if parent and str(execution.id) not in parent.children: values = {} - values['push__children'] = str(execution.id) + values["push__children"] = str(execution.id) ActionExecution.update(parent, **values) return execution def _get_parent_execution(child_liveaction_db): - parent_execution_id = child_liveaction_db.context.get('parent', {}).get('execution_id', None) + parent_execution_id = child_liveaction_db.context.get("parent", {}).get( + "execution_id", None + ) if parent_execution_id: try: return ActionExecution.get_by_id(parent_execution_id) except: - LOG.exception('No valid execution object found in db for id: %s' % parent_execution_id) + LOG.exception( + "No valid execution object found in db for id: %s" % parent_execution_id + ) return None return None @@ -180,12 +190,12 @@ def update_execution(liveaction_db, publish=True): kw = {} for k, v in six.iteritems(decomposed): - kw['set__' + k] = v + kw["set__" + k] = v if liveaction_db.status != execution.status: # Note: If the status changes we store this transition in the "log" attribute of action # execution - kw['push__log'] = _create_execution_log_entry(liveaction_db.status) + kw["push__log"] = _create_execution_log_entry(liveaction_db.status) execution = ActionExecution.update(execution, publish=publish, **kw) return execution @@ -201,19 +211,25 @@ def abandon_execution_if_incomplete(liveaction_id, publish=True): # No need to abandon and already complete action if liveaction_db.status in action_constants.LIVEACTION_COMPLETED_STATES: - raise ValueError('LiveAction %s already in a completed state %s.' % - (liveaction_id, liveaction_db.status)) + raise ValueError( + "LiveAction %s already in a completed state %s." + % (liveaction_id, liveaction_db.status) + ) # Update status to reflect execution being abandoned. liveaction_db = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_ABANDONED, liveaction_db=liveaction_db, - result={}) + result={}, + ) execution_db = update_execution(liveaction_db, publish=publish) - LOG.info('Marked execution %s as %s.', execution_db.id, - action_constants.LIVEACTION_STATUS_ABANDONED) + LOG.info( + "Marked execution %s as %s.", + execution_db.id, + action_constants.LIVEACTION_STATUS_ABANDONED, + ) # Invoke post run on the action to execute post run operations such as callback. runners_utils.invoke_post_run(liveaction_db) @@ -236,10 +252,10 @@ def get_parent_context(liveaction_db): :return: If found the parent context else None. :rtype: dict """ - context = getattr(liveaction_db, 'context', None) + context = getattr(liveaction_db, "context", None) if not context: return None - return context.get('parent', None) + return context.get("parent", None) class AscendingSortedDescendantView(object): @@ -267,8 +283,8 @@ def result(self): DESCENDANT_VIEWS = { - 'sorted': AscendingSortedDescendantView, - 'default': DFSDescendantView + "sorted": AscendingSortedDescendantView, + "default": DFSDescendantView, } @@ -278,9 +294,10 @@ def get_descendants(actionexecution_id, descendant_depth=-1, result_fmt=None): the supplied actionexecution_id. """ descendants = DESCENDANT_VIEWS.get(result_fmt, DFSDescendantView)() - children = ActionExecution.query(parent=actionexecution_id, - **{'order_by': ['start_timestamp']}) - LOG.debug('Found %s children for id %s.', len(children), actionexecution_id) + children = ActionExecution.query( + parent=actionexecution_id, **{"order_by": ["start_timestamp"]} + ) + LOG.debug("Found %s children for id %s.", len(children), actionexecution_id) current_level = [(child, 1) for child in children] while current_level: @@ -291,8 +308,10 @@ def get_descendants(actionexecution_id, descendant_depth=-1, result_fmt=None): continue if level != -1 and level == descendant_depth: continue - children = ActionExecution.query(parent=parent_id, **{'order_by': ['start_timestamp']}) - LOG.debug('Found %s children for id %s.', len(children), parent_id) + children = ActionExecution.query( + parent=parent_id, **{"order_by": ["start_timestamp"]} + ) + LOG.debug("Found %s children for id %s.", len(children), parent_id) # prepend for DFS for idx in range(len(children)): current_level.insert(idx, (children[idx], level + 1)) diff --git a/st2common/st2common/services/inquiry.py b/st2common/st2common/services/inquiry.py index 5b511b3a97e..09be3cc8f11 100644 --- a/st2common/st2common/services/inquiry.py +++ b/st2common/st2common/services/inquiry.py @@ -40,9 +40,11 @@ def check_inquiry(inquiry): - LOG.debug('Checking action execution "%s" to see if is an inquiry.' % str(inquiry.id)) + LOG.debug( + 'Checking action execution "%s" to see if is an inquiry.' % str(inquiry.id) + ) - if inquiry.runner.get('name') != 'inquirer': + if inquiry.runner.get("name") != "inquirer": raise inquiry_exceptions.InvalidInquiryInstance(str(inquiry.id)) LOG.debug('Checking if the inquiry "%s" has timed out.' % str(inquiry.id)) @@ -69,7 +71,7 @@ def check_permission(inquiry, requester): users_passed = False # Determine role-level permissions - roles = getattr(inquiry, 'roles', []) + roles = getattr(inquiry, "roles", []) if not roles: # No roles definition so we treat it as a pass @@ -79,14 +81,16 @@ def check_permission(inquiry, requester): rbac_utils = get_rbac_backend().get_utils_class() user_has_role = rbac_utils.user_has_role(user_db, role) - LOG.debug('Checking user %s is in role %s - %s' % (user_db, role, user_has_role)) + LOG.debug( + "Checking user %s is in role %s - %s" % (user_db, role, user_has_role) + ) if user_has_role: roles_passed = True break # Determine user-level permissions - users = getattr(inquiry, 'users', []) + users = getattr(inquiry, "users", []) if not users or user_db.name in users: users_passed = True @@ -98,7 +102,7 @@ def check_permission(inquiry, requester): def validate_response(inquiry, response): schema = inquiry.schema - LOG.debug('Validating inquiry response: %s against schema: %s' % (response, schema)) + LOG.debug("Validating inquiry response: %s against schema: %s" % (response, schema)) try: schema_utils.validate( @@ -106,12 +110,14 @@ def validate_response(inquiry, response): schema=schema, cls=schema_utils.CustomValidator, use_default=True, - allow_default_none=True + allow_default_none=True, ) except Exception as e: msg = 'Response for inquiry "%s" did not pass schema validation.' LOG.exception(msg % str(inquiry.id)) - raise inquiry_exceptions.InvalidInquiryResponse(str(inquiry.id), six.text_type(e)) + raise inquiry_exceptions.InvalidInquiryResponse( + str(inquiry.id), six.text_type(e) + ) def respond(inquiry, response, requester=None): @@ -120,14 +126,14 @@ def respond(inquiry, response, requester=None): requester = cfg.CONF.system_user.user # Retrieve the liveaction from the database. - liveaction_db = lv_db_access.LiveAction.get_by_id(inquiry.liveaction.get('id')) + liveaction_db = lv_db_access.LiveAction.get_by_id(inquiry.liveaction.get("id")) # Resume the parent workflow first. If the action execution for the inquiry is updated first, # it triggers handling of the action execution completion which will interact with the paused # parent workflow. The resuming logic that is executed here will then race with the completion # of the inquiry action execution, which will randomly result in the parent workflow stuck in # paused state. - if liveaction_db.context.get('parent'): + if liveaction_db.context.get("parent"): LOG.debug('Resuming workflow parent(s) for inquiry "%s".' % str(inquiry.id)) # For action execution under Action Chain workflows, request the entire @@ -136,7 +142,9 @@ def respond(inquiry, response, requester=None): # there is no other paused branches, the conductor will resume the rest of the workflow. resume_target = ( action_service.get_parent_liveaction(liveaction_db) - if workflow_service.is_action_execution_under_workflow_context(liveaction_db) + if workflow_service.is_action_execution_under_workflow_context( + liveaction_db + ) else action_service.get_root_liveaction(liveaction_db) ) @@ -147,14 +155,14 @@ def respond(inquiry, response, requester=None): LOG.debug('Updating response for inquiry "%s".' % str(inquiry.id)) result = copy.deepcopy(inquiry.result) - result['response'] = response + result["response"] = response liveaction_db = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_SUCCEEDED, end_timestamp=date_utils.get_datetime_utc_now(), runner_info=sys_info_utils.get_process_info(), result=result, - liveaction_id=str(liveaction_db.id) + liveaction_id=str(liveaction_db.id), ) # Sync the liveaction with the corresponding action execution. @@ -164,7 +172,7 @@ def respond(inquiry, response, requester=None): LOG.debug('Invoking post run for inquiry "%s".' % str(inquiry.id)) runner_container = container.get_runner_container() action_db = action_utils.get_action_by_ref(liveaction_db.action) - runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type["name"]) runner = runner_container._get_runner(runnertype_db, action_db, liveaction_db) runner.post_run(status=action_constants.LIVEACTION_STATUS_SUCCEEDED, result=result) diff --git a/st2common/st2common/services/keyvalues.py b/st2common/st2common/services/keyvalues.py index 722603eee5f..d38f28ca939 100644 --- a/st2common/st2common/services/keyvalues.py +++ b/st2common/st2common/services/keyvalues.py @@ -28,11 +28,10 @@ from st2common.persistence.keyvalue import KeyValuePair __all__ = [ - 'get_kvp_for_name', - 'get_values_for_names', - - 'KeyValueLookup', - 'UserKeyValueLookup' + "get_kvp_for_name", + "get_values_for_names", + "KeyValueLookup", + "UserKeyValueLookup", ] LOG = logging.getLogger(__name__) @@ -81,17 +80,17 @@ def get_key_name(self): :rtype: ``str`` """ key_name_parts = [DATASTORE_PARENT_SCOPE, self.scope] - key_name = self._key_prefix.split(':', 1) + key_name = self._key_prefix.split(":", 1) if len(key_name) == 1: key_name = key_name[0] elif len(key_name) >= 2: key_name = key_name[1] else: - key_name = '' + key_name = "" key_name_parts.append(key_name) - key_name = '.'.join(key_name_parts) + key_name = ".".join(key_name_parts) return key_name @@ -99,7 +98,9 @@ class KeyValueLookup(BaseKeyValueLookup): scope = SYSTEM_SCOPE - def __init__(self, prefix=None, key_prefix=None, cache=None, scope=FULL_SYSTEM_SCOPE): + def __init__( + self, prefix=None, key_prefix=None, cache=None, scope=FULL_SYSTEM_SCOPE + ): if not scope: scope = FULL_SYSTEM_SCOPE @@ -107,7 +108,7 @@ def __init__(self, prefix=None, key_prefix=None, cache=None, scope=FULL_SYSTEM_S scope = FULL_SYSTEM_SCOPE self._prefix = prefix - self._key_prefix = key_prefix or '' + self._key_prefix = key_prefix or "" self._value_cache = cache or {} self._scope = scope @@ -129,7 +130,7 @@ def __getattr__(self, name): def _get(self, name): # get the value for this key and save in value_cache if self._key_prefix: - key = '%s.%s' % (self._key_prefix, name) + key = "%s.%s" % (self._key_prefix, name) else: key = name @@ -144,12 +145,16 @@ def _get(self, name): # the lookup is for 'key_base.key_value' it is likely that the calling code, e.g. Jinja, # will expect to do a dictionary style lookup for key_base and key_value as subsequent # calls. Saving the value in cache avoids extra DB calls. - return KeyValueLookup(prefix=self._prefix, key_prefix=key, cache=self._value_cache, - scope=self._scope) + return KeyValueLookup( + prefix=self._prefix, + key_prefix=key, + cache=self._value_cache, + scope=self._scope, + ) def _get_kv(self, key): scope = self._scope - LOG.debug('Lookup system kv: scope: %s and key: %s', scope, key) + LOG.debug("Lookup system kv: scope: %s and key: %s", scope, key) try: kvp = KeyValuePair.get_by_scope_and_name(scope=scope, name=key) @@ -157,15 +162,17 @@ def _get_kv(self, key): kvp = None if kvp: - LOG.debug('Got value %s from datastore.', kvp.value) - return kvp.value if kvp else '' + LOG.debug("Got value %s from datastore.", kvp.value) + return kvp.value if kvp else "" class UserKeyValueLookup(BaseKeyValueLookup): scope = USER_SCOPE - def __init__(self, user, prefix=None, key_prefix=None, cache=None, scope=FULL_USER_SCOPE): + def __init__( + self, user, prefix=None, key_prefix=None, cache=None, scope=FULL_USER_SCOPE + ): if not scope: scope = FULL_USER_SCOPE @@ -173,7 +180,7 @@ def __init__(self, user, prefix=None, key_prefix=None, cache=None, scope=FULL_US scope = FULL_USER_SCOPE self._prefix = prefix - self._key_prefix = key_prefix or '' + self._key_prefix = key_prefix or "" self._value_cache = cache or {} self._user = user self._scope = scope @@ -190,7 +197,7 @@ def __getattr__(self, name): def _get(self, name): # get the value for this key and save in value_cache if self._key_prefix: - key = '%s.%s' % (self._key_prefix, name) + key = "%s.%s" % (self._key_prefix, name) else: key = UserKeyReference(name=name, user=self._user).ref @@ -205,8 +212,13 @@ def _get(self, name): # the lookup is for 'key_base.key_value' it is likely that the calling code, e.g. Jinja, # will expect to do a dictionary style lookup for key_base and key_value as subsequent # calls. Saving the value in cache avoids extra DB calls. - return UserKeyValueLookup(prefix=self._prefix, user=self._user, key_prefix=key, - cache=self._value_cache, scope=self._scope) + return UserKeyValueLookup( + prefix=self._prefix, + user=self._user, + key_prefix=key, + cache=self._value_cache, + scope=self._scope, + ) def _get_kv(self, key): scope = self._scope @@ -216,7 +228,7 @@ def _get_kv(self, key): except StackStormDBObjectNotFoundError: kvp = None - return kvp.value if kvp else '' + return kvp.value if kvp else "" def get_key_reference(scope, name, user=None): @@ -232,12 +244,15 @@ def get_key_reference(scope, name, user=None): :rtype: ``str`` """ - if (scope == SYSTEM_SCOPE or scope == FULL_SYSTEM_SCOPE): + if scope == SYSTEM_SCOPE or scope == FULL_SYSTEM_SCOPE: return name - elif (scope == USER_SCOPE or scope == FULL_USER_SCOPE): + elif scope == USER_SCOPE or scope == FULL_USER_SCOPE: if not user: - raise InvalidUserException('A valid user must be specified for user key ref.') + raise InvalidUserException( + "A valid user must be specified for user key ref." + ) return UserKeyReference(name=name, user=user).ref else: - raise InvalidScopeException('Scope "%s" is not valid. Allowed scopes are %s.' % - (scope, ALLOWED_SCOPES)) + raise InvalidScopeException( + 'Scope "%s" is not valid. Allowed scopes are %s.' % (scope, ALLOWED_SCOPES) + ) diff --git a/st2common/st2common/services/packs.py b/st2common/st2common/services/packs.py index 7088b5f3682..9f2794ed78a 100644 --- a/st2common/st2common/services/packs.py +++ b/st2common/st2common/services/packs.py @@ -27,21 +27,15 @@ from six.moves import range __all__ = [ - 'get_pack_by_ref', - 'fetch_pack_index', - 'get_pack_from_index', - 'search_pack_index' + "get_pack_by_ref", + "fetch_pack_index", + "get_pack_from_index", + "search_pack_index", ] -EXCLUDE_FIELDS = [ - "repo_url", - "email" -] +EXCLUDE_FIELDS = ["repo_url", "email"] -SEARCH_PRIORITY = [ - "name", - "keywords" -] +SEARCH_PRIORITY = ["name", "keywords"] LOG = logging.getLogger(__name__) @@ -55,7 +49,7 @@ def _build_index_list(index_url): index_urls = cfg.CONF.content.index_url[::-1] elif isinstance(index_url, str): index_urls = [index_url] - elif hasattr(index_url, '__iter__'): + elif hasattr(index_url, "__iter__"): index_urls = index_url else: raise TypeError('"index_url" should either be a string or an iterable object.') @@ -73,23 +67,23 @@ def _fetch_and_compile_index(index_urls, logger=None, proxy_config=None): verify = True if proxy_config: - https_proxy = proxy_config.get('https_proxy', None) - http_proxy = proxy_config.get('http_proxy', None) - ca_bundle_path = proxy_config.get('proxy_ca_bundle_path', None) + https_proxy = proxy_config.get("https_proxy", None) + http_proxy = proxy_config.get("http_proxy", None) + ca_bundle_path = proxy_config.get("proxy_ca_bundle_path", None) if https_proxy: - proxies_dict['https'] = https_proxy + proxies_dict["https"] = https_proxy verify = ca_bundle_path or True if http_proxy: - proxies_dict['http'] = http_proxy + proxies_dict["http"] = http_proxy for index_url in index_urls: index_status = { - 'url': index_url, - 'packs': 0, - 'message': None, - 'error': None, + "url": index_url, + "packs": 0, + "message": None, + "error": None, } index_json = None @@ -98,32 +92,32 @@ def _fetch_and_compile_index(index_urls, logger=None, proxy_config=None): request.raise_for_status() index_json = request.json() except ValueError as e: - index_status['error'] = 'malformed' - index_status['message'] = repr(e) + index_status["error"] = "malformed" + index_status["message"] = repr(e) except requests.exceptions.RequestException as e: - index_status['error'] = 'unresponsive' - index_status['message'] = repr(e) + index_status["error"] = "unresponsive" + index_status["message"] = repr(e) except Exception as e: - index_status['error'] = 'other errors' - index_status['message'] = repr(e) + index_status["error"] = "other errors" + index_status["message"] = repr(e) if index_json == {}: - index_status['error'] = 'empty' - index_status['message'] = 'The index URL returned an empty object.' + index_status["error"] = "empty" + index_status["message"] = "The index URL returned an empty object." elif type(index_json) is list: - index_status['error'] = 'malformed' - index_status['message'] = 'Expected an index object, got a list instead.' - elif index_json and 'packs' not in index_json: - index_status['error'] = 'malformed' - index_status['message'] = 'Index object is missing "packs" attribute.' + index_status["error"] = "malformed" + index_status["message"] = "Expected an index object, got a list instead." + elif index_json and "packs" not in index_json: + index_status["error"] = "malformed" + index_status["message"] = 'Index object is missing "packs" attribute.' - if index_status['error']: + if index_status["error"]: logger.error("Index parsing error: %s" % json.dumps(index_status, indent=4)) else: # TODO: Notify on a duplicate pack aka pack being overwritten from a different index - packs_data = index_json['packs'] - index_status['message'] = 'Success.' - index_status['packs'] = len(packs_data) + packs_data = index_json["packs"] + index_status["message"] = "Success." + index_status["packs"] = len(packs_data) index.update(packs_data) status.append(index_status) @@ -147,8 +141,9 @@ def fetch_pack_index(index_url=None, logger=None, allow_empty=False, proxy_confi logger = logger or LOG index_urls = _build_index_list(index_url) - index, status = _fetch_and_compile_index(index_urls=index_urls, logger=logger, - proxy_config=proxy_config) + index, status = _fetch_and_compile_index( + index_urls=index_urls, logger=logger, proxy_config=proxy_config + ) # If one of the indexes on the list is unresponsive, we do not throw # immediately. The only case where an exception is raised is when no @@ -156,11 +151,14 @@ def fetch_pack_index(index_url=None, logger=None, allow_empty=False, proxy_confi # This behavior allows for mirrors / backups and handling connection # or network issues in one of the indexes. if not index and not allow_empty: - raise ValueError("No results from the %s: tried %s.\nStatus: %s" % ( - ("index" if len(index_urls) == 1 else "indexes"), - ", ".join(index_urls), - json.dumps(status, indent=4) - )) + raise ValueError( + "No results from the %s: tried %s.\nStatus: %s" + % ( + ("index" if len(index_urls) == 1 else "indexes"), + ", ".join(index_urls), + json.dumps(status, indent=4), + ) + ) return (index, status) @@ -177,13 +175,15 @@ def get_pack_from_index(pack, proxy_config=None): return index.get(pack) -def search_pack_index(query, exclude=None, priority=None, case_sensitive=True, proxy_config=None): +def search_pack_index( + query, exclude=None, priority=None, case_sensitive=True, proxy_config=None +): """ Search the pack index by query. Returns a list of matches for a query. """ if not query: - raise ValueError('Query must be specified.') + raise ValueError("Query must be specified.") if not exclude: exclude = EXCLUDE_FIELDS @@ -198,7 +198,7 @@ def search_pack_index(query, exclude=None, priority=None, case_sensitive=True, p matches = [[] for i in range(len(priority) + 1)] for pack in six.itervalues(index): for key, value in six.iteritems(pack): - if not hasattr(value, '__contains__'): + if not hasattr(value, "__contains__"): value = str(value) if not case_sensitive: diff --git a/st2common/st2common/services/policies.py b/st2common/st2common/services/policies.py index 50ba28f3040..46e24ce2900 100644 --- a/st2common/st2common/services/policies.py +++ b/st2common/st2common/services/policies.py @@ -25,13 +25,10 @@ def has_policies(lv_ac_db, policy_types=None): - query_params = { - 'resource_ref': lv_ac_db.action, - 'enabled': True - } + query_params = {"resource_ref": lv_ac_db.action, "enabled": True} if policy_types: - query_params['policy_type__in'] = policy_types + query_params["policy_type__in"] = policy_types policy_dbs = pc_db_access.Policy.query(**query_params) @@ -42,11 +39,19 @@ def apply_pre_run_policies(lv_ac_db): LOG.debug('Applying pre-run policies for liveaction "%s".' % str(lv_ac_db.id)) policy_dbs = pc_db_access.Policy.query(resource_ref=lv_ac_db.action, enabled=True) - LOG.debug('Identified %s policies for the action "%s".' % (len(policy_dbs), lv_ac_db.action)) + LOG.debug( + 'Identified %s policies for the action "%s".' + % (len(policy_dbs), lv_ac_db.action) + ) for policy_db in policy_dbs: - LOG.debug('Getting driver for policy "%s" (%s).' % (policy_db.ref, policy_db.policy_type)) - driver = engine.get_driver(policy_db.ref, policy_db.policy_type, **policy_db.parameters) + LOG.debug( + 'Getting driver for policy "%s" (%s).' + % (policy_db.ref, policy_db.policy_type) + ) + driver = engine.get_driver( + policy_db.ref, policy_db.policy_type, **policy_db.parameters + ) try: message = 'Applying policy "%s" (%s) for liveaction "%s".' @@ -54,7 +59,9 @@ def apply_pre_run_policies(lv_ac_db): lv_ac_db = driver.apply_before(lv_ac_db) except: message = 'An exception occurred while applying policy "%s" (%s) for liveaction "%s".' - LOG.exception(message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id))) + LOG.exception( + message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id)) + ) if lv_ac_db.status == ac_const.LIVEACTION_STATUS_DELAYED: break @@ -66,11 +73,19 @@ def apply_post_run_policies(lv_ac_db): LOG.debug('Applying post run policies for liveaction "%s".' % str(lv_ac_db.id)) policy_dbs = pc_db_access.Policy.query(resource_ref=lv_ac_db.action, enabled=True) - LOG.debug('Identified %s policies for the action "%s".' % (len(policy_dbs), lv_ac_db.action)) + LOG.debug( + 'Identified %s policies for the action "%s".' + % (len(policy_dbs), lv_ac_db.action) + ) for policy_db in policy_dbs: - LOG.debug('Getting driver for policy "%s" (%s).' % (policy_db.ref, policy_db.policy_type)) - driver = engine.get_driver(policy_db.ref, policy_db.policy_type, **policy_db.parameters) + LOG.debug( + 'Getting driver for policy "%s" (%s).' + % (policy_db.ref, policy_db.policy_type) + ) + driver = engine.get_driver( + policy_db.ref, policy_db.policy_type, **policy_db.parameters + ) try: message = 'Applying policy "%s" (%s) for liveaction "%s".' @@ -78,6 +93,8 @@ def apply_post_run_policies(lv_ac_db): lv_ac_db = driver.apply_after(lv_ac_db) except: message = 'An exception occurred while applying policy "%s" (%s) for liveaction "%s".' - LOG.exception(message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id))) + LOG.exception( + message % (policy_db.ref, policy_db.policy_type, str(lv_ac_db.id)) + ) return lv_ac_db diff --git a/st2common/st2common/services/queries.py b/st2common/st2common/services/queries.py index e6d769e3652..20c7a0c9906 100644 --- a/st2common/st2common/services/queries.py +++ b/st2common/st2common/services/queries.py @@ -25,13 +25,15 @@ def setup_query(liveaction_id, runnertype_db, query_context): - if not getattr(runnertype_db, 'query_module', None): - raise Exception('The runner "%s" does not have a query module.' % runnertype_db.name) + if not getattr(runnertype_db, "query_module", None): + raise Exception( + 'The runner "%s" does not have a query module.' % runnertype_db.name + ) state_db = ActionExecutionStateDB( execution_id=liveaction_id, query_module=runnertype_db.query_module, - query_context=query_context + query_context=query_context, ) ActionExecutionState.add_or_update(state_db) diff --git a/st2common/st2common/services/rules.py b/st2common/st2common/services/rules.py index d9be718e274..ebb80834333 100644 --- a/st2common/st2common/services/rules.py +++ b/st2common/st2common/services/rules.py @@ -22,10 +22,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'get_rules_given_trigger', - 'get_rules_with_trigger_ref' -] +__all__ = ["get_rules_given_trigger", "get_rules_with_trigger_ref"] def get_rules_given_trigger(trigger): @@ -34,13 +31,15 @@ def get_rules_given_trigger(trigger): return get_rules_with_trigger_ref(trigger_ref=trigger) if isinstance(trigger, dict): - trigger_ref = trigger.get('ref', None) + trigger_ref = trigger.get("ref", None) if trigger_ref: return get_rules_with_trigger_ref(trigger_ref=trigger_ref) else: - raise ValueError('Trigger dict %s is missing ``ref``.' % trigger) + raise ValueError("Trigger dict %s is missing ``ref``." % trigger) - raise ValueError('Unknown type %s for trigger. Cannot do rule lookups.' % type(trigger)) + raise ValueError( + "Unknown type %s for trigger. Cannot do rule lookups." % type(trigger) + ) def get_rules_with_trigger_ref(trigger_ref=None, enabled=True): @@ -56,5 +55,5 @@ def get_rules_with_trigger_ref(trigger_ref=None, enabled=True): if not trigger_ref: return None - LOG.debug('Querying rules with trigger %s', trigger_ref) + LOG.debug("Querying rules with trigger %s", trigger_ref) return Rule.query(trigger=trigger_ref, enabled=enabled) diff --git a/st2common/st2common/services/sensor_watcher.py b/st2common/st2common/services/sensor_watcher.py index 0105ba46d6d..1c54881663a 100644 --- a/st2common/st2common/services/sensor_watcher.py +++ b/st2common/st2common/services/sensor_watcher.py @@ -32,9 +32,9 @@ class SensorWatcher(ConsumerMixin): - - def __init__(self, create_handler, update_handler, delete_handler, - queue_suffix=None): + def __init__( + self, create_handler, update_handler, delete_handler, queue_suffix=None + ): """ :param create_handler: Function which is called on SensorDB create event. :type create_handler: ``callable`` @@ -57,34 +57,41 @@ def __init__(self, create_handler, update_handler, delete_handler, self._handlers = { publishers.CREATE_RK: create_handler, publishers.UPDATE_RK: update_handler, - publishers.DELETE_RK: delete_handler + publishers.DELETE_RK: delete_handler, } def get_consumers(self, Consumer, channel): - consumers = [Consumer(queues=[self._sensor_watcher_q], - accept=['pickle'], - callbacks=[self.process_task])] + consumers = [ + Consumer( + queues=[self._sensor_watcher_q], + accept=["pickle"], + callbacks=[self.process_task], + ) + ] return consumers def process_task(self, body, message): - LOG.debug('process_task') - LOG.debug(' body: %s', body) - LOG.debug(' message.properties: %s', message.properties) - LOG.debug(' message.delivery_info: %s', message.delivery_info) + LOG.debug("process_task") + LOG.debug(" body: %s", body) + LOG.debug(" message.properties: %s", message.properties) + LOG.debug(" message.delivery_info: %s", message.delivery_info) - routing_key = message.delivery_info.get('routing_key', '') + routing_key = message.delivery_info.get("routing_key", "") handler = self._handlers.get(routing_key, None) try: if not handler: - LOG.info('Skipping message %s as no handler was found.', message) + LOG.info("Skipping message %s as no handler was found.", message) return try: handler(body) except Exception as e: - LOG.exception('Handling failed. Message body: %s. Exception: %s', - body, six.text_type(e)) + LOG.exception( + "Handling failed. Message body: %s. Exception: %s", + body, + six.text_type(e), + ) finally: message.ack() @@ -93,11 +100,11 @@ def start(self): self.connection = transport_utils.get_connection() self._updates_thread = concurrency.spawn(self.run) except: - LOG.exception('Failed to start sensor_watcher.') + LOG.exception("Failed to start sensor_watcher.") self.connection.release() def stop(self): - LOG.debug('Shutting down sensor watcher.') + LOG.debug("Shutting down sensor watcher.") try: if self._updates_thread: self._updates_thread = concurrency.kill(self._updates_thread) @@ -108,15 +115,19 @@ def stop(self): try: bound_sensor_watch_q.delete() except: - LOG.error('Unable to delete sensor watcher queue: %s', self._sensor_watcher_q) + LOG.error( + "Unable to delete sensor watcher queue: %s", + self._sensor_watcher_q, + ) finally: if self.connection: self.connection.release() @staticmethod def _get_queue(queue_suffix): - queue_name = queue_utils.get_queue_name(queue_name_base='st2.sensor.watch', - queue_name_suffix=queue_suffix, - add_random_uuid_to_suffix=True - ) - return reactor.get_sensor_cud_queue(queue_name, routing_key='#') + queue_name = queue_utils.get_queue_name( + queue_name_base="st2.sensor.watch", + queue_name_suffix=queue_suffix, + add_random_uuid_to_suffix=True, + ) + return reactor.get_sensor_cud_queue(queue_name, routing_key="#") diff --git a/st2common/st2common/services/trace.py b/st2common/st2common/services/trace.py index 3eb92bd2f10..4dadef09643 100644 --- a/st2common/st2common/services/trace.py +++ b/st2common/st2common/services/trace.py @@ -32,22 +32,24 @@ LOG = logging.getLogger(__name__) __all__ = [ - 'get_trace_db_by_action_execution', - 'get_trace_db_by_rule', - 'get_trace_db_by_trigger_instance', - 'get_trace', - 'add_or_update_given_trace_context', - 'add_or_update_given_trace_db', - 'get_trace_component_for_action_execution', - 'get_trace_component_for_rule', - 'get_trace_component_for_trigger_instance' + "get_trace_db_by_action_execution", + "get_trace_db_by_rule", + "get_trace_db_by_trigger_instance", + "get_trace", + "add_or_update_given_trace_context", + "add_or_update_given_trace_db", + "get_trace_component_for_action_execution", + "get_trace_component_for_rule", + "get_trace_component_for_trigger_instance", ] ACTION_SENSOR_TRIGGER_REF = ResourceReference.to_string_reference( - pack=ACTION_SENSOR_TRIGGER['pack'], name=ACTION_SENSOR_TRIGGER['name']) + pack=ACTION_SENSOR_TRIGGER["pack"], name=ACTION_SENSOR_TRIGGER["name"] +) NOTIFY_TRIGGER_REF = ResourceReference.to_string_reference( - pack=NOTIFY_TRIGGER['pack'], name=NOTIFY_TRIGGER['name']) + pack=NOTIFY_TRIGGER["pack"], name=NOTIFY_TRIGGER["name"] +) def _get_valid_trace_context(trace_context): @@ -74,14 +76,17 @@ def _get_single_trace_by_component(**component_filter): return None elif len(traces) > 1: raise UniqueTraceNotFoundException( - 'More than 1 trace matching %s found.' % component_filter) + "More than 1 trace matching %s found." % component_filter + ) return traces[0] def get_trace_db_by_action_execution(action_execution=None, action_execution_id=None): if action_execution: action_execution_id = str(action_execution.id) - return _get_single_trace_by_component(action_executions__object_id=action_execution_id) + return _get_single_trace_by_component( + action_executions__object_id=action_execution_id + ) def get_trace_db_by_rule(rule=None, rule_id=None): @@ -94,7 +99,9 @@ def get_trace_db_by_rule(rule=None, rule_id=None): def get_trace_db_by_trigger_instance(trigger_instance=None, trigger_instance_id=None): if trigger_instance: trigger_instance_id = str(trigger_instance.id) - return _get_single_trace_by_component(trigger_instances__object_id=trigger_instance_id) + return _get_single_trace_by_component( + trigger_instances__object_id=trigger_instance_id + ) def get_trace(trace_context, ignore_trace_tag=False): @@ -111,16 +118,20 @@ def get_trace(trace_context, ignore_trace_tag=False): trace_context = _get_valid_trace_context(trace_context) if not trace_context.id_ and not trace_context.trace_tag: - raise ValueError('Atleast one of id_ or trace_tag should be specified.') + raise ValueError("Atleast one of id_ or trace_tag should be specified.") if trace_context.id_: try: return Trace.get_by_id(trace_context.id_) except (ValidationError, ValueError): - LOG.warning('Database lookup for Trace with id="%s" failed.', - trace_context.id_, exc_info=True) + LOG.warning( + 'Database lookup for Trace with id="%s" failed.', + trace_context.id_, + exc_info=True, + ) raise StackStormDBObjectNotFoundError( - 'Unable to find Trace with id="%s"' % trace_context.id_) + 'Unable to find Trace with id="%s"' % trace_context.id_ + ) if ignore_trace_tag: return None @@ -130,7 +141,8 @@ def get_trace(trace_context, ignore_trace_tag=False): # Assume this method only handles 1 trace. if len(traces) > 1: raise UniqueTraceNotFoundException( - 'More than 1 Trace matching %s found.' % trace_context.trace_tag) + "More than 1 Trace matching %s found." % trace_context.trace_tag + ) return traces[0] @@ -168,14 +180,17 @@ def get_trace_db_by_live_action(liveaction): # This cover case for child execution of a workflow. parent_context = executions.get_parent_context(liveaction_db=liveaction) if not trace_context and parent_context: - parent_execution_id = parent_context.get('execution_id', None) + parent_execution_id = parent_context.get("execution_id", None) if parent_execution_id: # go straight to a trace_db. If there is a parent execution then that must # be associated with a Trace. - trace_db = get_trace_db_by_action_execution(action_execution_id=parent_execution_id) + trace_db = get_trace_db_by_action_execution( + action_execution_id=parent_execution_id + ) if not trace_db: - raise StackStormDBObjectNotFoundError('No trace found for execution %s' % - parent_execution_id) + raise StackStormDBObjectNotFoundError( + "No trace found for execution %s" % parent_execution_id + ) return (created, trace_db) # 3. Check if the action_execution associated with liveaction leads to a trace_db execution = ActionExecution.get(liveaction__id=str(liveaction.id)) @@ -184,13 +199,14 @@ def get_trace_db_by_live_action(liveaction): # 4. No trace_db found, therefore create one. This typically happens # when execution is run by hand. if not trace_db: - trace_db = TraceDB(trace_tag='execution-%s' % str(liveaction.id)) + trace_db = TraceDB(trace_tag="execution-%s" % str(liveaction.id)) created = True return (created, trace_db) -def add_or_update_given_trace_context(trace_context, action_executions=None, rules=None, - trigger_instances=None): +def add_or_update_given_trace_context( + trace_context, action_executions=None, rules=None, trigger_instances=None +): """ Will update an existing Trace or add a new Trace. This method will only look for exact Trace as identified by the trace_context. Even if the trace_context contain a trace_tag @@ -222,14 +238,17 @@ def add_or_update_given_trace_context(trace_context, action_executions=None, rul # since trace_db is None need to end up with a valid trace_context trace_context = _get_valid_trace_context(trace_context) trace_db = TraceDB(trace_tag=trace_context.trace_tag) - return add_or_update_given_trace_db(trace_db=trace_db, - action_executions=action_executions, - rules=rules, - trigger_instances=trigger_instances) + return add_or_update_given_trace_db( + trace_db=trace_db, + action_executions=action_executions, + rules=rules, + trigger_instances=trigger_instances, + ) -def add_or_update_given_trace_db(trace_db, action_executions=None, rules=None, - trigger_instances=None): +def add_or_update_given_trace_db( + trace_db, action_executions=None, rules=None, trigger_instances=None +): """ Will update an existing Trace. @@ -251,12 +270,14 @@ def add_or_update_given_trace_db(trace_db, action_executions=None, rules=None, :rtype: ``TraceDB`` """ if trace_db is None: - raise ValueError('trace_db should be non-None.') + raise ValueError("trace_db should be non-None.") if not action_executions: action_executions = [] - action_executions = [_to_trace_component_db(component=action_execution) - for action_execution in action_executions] + action_executions = [ + _to_trace_component_db(component=action_execution) + for action_execution in action_executions + ] if not rules: rules = [] @@ -264,16 +285,20 @@ def add_or_update_given_trace_db(trace_db, action_executions=None, rules=None, if not trigger_instances: trigger_instances = [] - trigger_instances = [_to_trace_component_db(component=trigger_instance) - for trigger_instance in trigger_instances] + trigger_instances = [ + _to_trace_component_db(component=trigger_instance) + for trigger_instance in trigger_instances + ] # If an id exists then this is an update and we do not want to perform # an upsert so use push_components which will use the push operator. if trace_db.id: - return Trace.push_components(trace_db, - action_executions=action_executions, - rules=rules, - trigger_instances=trigger_instances) + return Trace.push_components( + trace_db, + action_executions=action_executions, + rules=rules, + trigger_instances=trigger_instances, + ) trace_db.action_executions = action_executions trace_db.rules = rules @@ -295,23 +320,25 @@ def get_trace_component_for_action_execution(action_execution_db, liveaction_db) :rtype: ``dict`` """ if not action_execution_db: - raise ValueError('action_execution_db expected.') + raise ValueError("action_execution_db expected.") trace_component = { - 'id': str(action_execution_db.id), - 'ref': str(action_execution_db.action.get('ref', '')) + "id": str(action_execution_db.id), + "ref": str(action_execution_db.action.get("ref", "")), } caused_by = {} parent_context = executions.get_parent_context(liveaction_db=liveaction_db) if liveaction_db and parent_context: - caused_by['type'] = 'action_execution' - caused_by['id'] = liveaction_db.context['parent'].get('execution_id', None) + caused_by["type"] = "action_execution" + caused_by["id"] = liveaction_db.context["parent"].get("execution_id", None) elif action_execution_db.rule and action_execution_db.trigger_instance: # Once RuleEnforcement is available that can be used instead. - caused_by['type'] = 'rule' - caused_by['id'] = '%s:%s' % (action_execution_db.rule['id'], - action_execution_db.trigger_instance['id']) + caused_by["type"] = "rule" + caused_by["id"] = "%s:%s" % ( + action_execution_db.rule["id"], + action_execution_db.trigger_instance["id"], + ) - trace_component['caused_by'] = caused_by + trace_component["caused_by"] = caused_by return trace_component @@ -328,13 +355,13 @@ def get_trace_component_for_rule(rule_db, trigger_instance_db): :rtype: ``dict`` """ trace_component = {} - trace_component = {'id': str(rule_db.id), 'ref': rule_db.ref} + trace_component = {"id": str(rule_db.id), "ref": rule_db.ref} caused_by = {} if trigger_instance_db: # Once RuleEnforcement is available that can be used instead. - caused_by['type'] = 'trigger_instance' - caused_by['id'] = str(trigger_instance_db.id) - trace_component['caused_by'] = caused_by + caused_by["type"] = "trigger_instance" + caused_by["id"] = str(trigger_instance_db.id) + trace_component["caused_by"] = caused_by return trace_component @@ -349,18 +376,20 @@ def get_trace_component_for_trigger_instance(trigger_instance_db): """ trace_component = {} trace_component = { - 'id': str(trigger_instance_db.id), - 'ref': trigger_instance_db.trigger + "id": str(trigger_instance_db.id), + "ref": trigger_instance_db.trigger, } caused_by = {} # Special handling for ACTION_SENSOR_TRIGGER and NOTIFY_TRIGGER where we # know how to maintain the links. - if trigger_instance_db.trigger == ACTION_SENSOR_TRIGGER_REF or \ - trigger_instance_db.trigger == NOTIFY_TRIGGER_REF: - caused_by['type'] = 'action_execution' + if ( + trigger_instance_db.trigger == ACTION_SENSOR_TRIGGER_REF + or trigger_instance_db.trigger == NOTIFY_TRIGGER_REF + ): + caused_by["type"] = "action_execution" # For both action trigger and notidy trigger execution_id is stored in the payload. - caused_by['id'] = trigger_instance_db.payload['execution_id'] - trace_component['caused_by'] = caused_by + caused_by["id"] = trigger_instance_db.payload["execution_id"] + trace_component["caused_by"] = caused_by return trace_component @@ -376,10 +405,12 @@ def _to_trace_component_db(component): """ if not isinstance(component, (six.string_types, dict)): print(type(component)) - raise ValueError('Expected component to be str or dict') + raise ValueError("Expected component to be str or dict") - object_id = component if isinstance(component, six.string_types) else component['id'] - ref = component.get('ref', '') if isinstance(component, dict) else '' - caused_by = component.get('caused_by', {}) if isinstance(component, dict) else {} + object_id = ( + component if isinstance(component, six.string_types) else component["id"] + ) + ref = component.get("ref", "") if isinstance(component, dict) else "" + caused_by = component.get("caused_by", {}) if isinstance(component, dict) else {} return TraceComponentDB(object_id=object_id, ref=ref, caused_by=caused_by) diff --git a/st2common/st2common/services/trigger_dispatcher.py b/st2common/st2common/services/trigger_dispatcher.py index 6843a1eb74c..6343a555b9c 100644 --- a/st2common/st2common/services/trigger_dispatcher.py +++ b/st2common/st2common/services/trigger_dispatcher.py @@ -23,9 +23,7 @@ from st2common.transport.reactor import TriggerDispatcher from st2common.validators.api.reactor import validate_trigger_payload -__all__ = [ - 'TriggerDispatcherService' -] +__all__ = ["TriggerDispatcherService"] class TriggerDispatcherService(object): @@ -37,7 +35,9 @@ def __init__(self, logger): self._logger = logger self._dispatcher = TriggerDispatcher(self._logger) - def dispatch(self, trigger, payload=None, trace_tag=None, throw_on_validation_error=False): + def dispatch( + self, trigger, payload=None, trace_tag=None, throw_on_validation_error=False + ): """ Method which dispatches the trigger. @@ -56,12 +56,19 @@ def dispatch(self, trigger, payload=None, trace_tag=None, throw_on_validation_er """ # empty strings trace_context = TraceContext(trace_tag=trace_tag) if trace_tag else None - self._logger.debug('Added trace_context %s to trigger %s.', trace_context, trigger) - return self.dispatch_with_context(trigger, payload=payload, trace_context=trace_context, - throw_on_validation_error=throw_on_validation_error) - - def dispatch_with_context(self, trigger, payload=None, trace_context=None, - throw_on_validation_error=False): + self._logger.debug( + "Added trace_context %s to trigger %s.", trace_context, trigger + ) + return self.dispatch_with_context( + trigger, + payload=payload, + trace_context=trace_context, + throw_on_validation_error=throw_on_validation_error, + ) + + def dispatch_with_context( + self, trigger, payload=None, trace_context=None, throw_on_validation_error=False + ): """ Method which dispatches the trigger. @@ -81,18 +88,25 @@ def dispatch_with_context(self, trigger, payload=None, trace_context=None, # Note: We perform validation even if it's disabled in the config so we can at least warn # the user if validation fals (but not throw if it's disabled) try: - validate_trigger_payload(trigger_type_ref=trigger, payload=payload, - throw_on_inexistent_trigger=True) + validate_trigger_payload( + trigger_type_ref=trigger, + payload=payload, + throw_on_inexistent_trigger=True, + ) except (ValidationError, ValueError, Exception) as e: - self._logger.warn('Failed to validate payload (%s) for trigger "%s": %s' % - (str(payload), trigger, six.text_type(e))) + self._logger.warn( + 'Failed to validate payload (%s) for trigger "%s": %s' + % (str(payload), trigger, six.text_type(e)) + ) # If validation is disabled, still dispatch a trigger even if it failed validation # This condition prevents unexpected restriction. if cfg.CONF.system.validate_trigger_payload: - msg = ('Trigger payload validation failed and validation is enabled, not ' - 'dispatching a trigger "%s" (%s): %s' % (trigger, str(payload), - six.text_type(e))) + msg = ( + "Trigger payload validation failed and validation is enabled, not " + 'dispatching a trigger "%s" (%s): %s' + % (trigger, str(payload), six.text_type(e)) + ) if throw_on_validation_error: raise ValueError(msg) @@ -100,5 +114,7 @@ def dispatch_with_context(self, trigger, payload=None, trace_context=None, self._logger.warn(msg) return None - self._logger.debug('Dispatching trigger %s with payload %s.', trigger, payload) - return self._dispatcher.dispatch(trigger, payload=payload, trace_context=trace_context) + self._logger.debug("Dispatching trigger %s with payload %s.", trigger, payload) + return self._dispatcher.dispatch( + trigger, payload=payload, trace_context=trace_context + ) diff --git a/st2common/st2common/services/triggers.py b/st2common/st2common/services/triggers.py index 6448aa25334..bbdce26b815 100644 --- a/st2common/st2common/services/triggers.py +++ b/st2common/st2common/services/triggers.py @@ -23,25 +23,22 @@ from st2common.exceptions.triggers import TriggerDoesNotExistException from st2common.exceptions.db import StackStormDBObjectNotFoundError from st2common.exceptions.db import StackStormDBObjectConflictError -from st2common.models.api.trigger import (TriggerAPI, TriggerTypeAPI) +from st2common.models.api.trigger import TriggerAPI, TriggerTypeAPI from st2common.models.system.common import ResourceReference -from st2common.persistence.trigger import (Trigger, TriggerType) +from st2common.persistence.trigger import Trigger, TriggerType __all__ = [ - 'add_trigger_models', - - 'get_trigger_db_by_ref', - 'get_trigger_db_by_id', - 'get_trigger_db_by_uid', - 'get_trigger_db_by_ref_or_dict', - 'get_trigger_db_given_type_and_params', - 'get_trigger_type_db', - - 'create_trigger_db', - 'create_trigger_type_db', - - 'create_or_update_trigger_db', - 'create_or_update_trigger_type_db' + "add_trigger_models", + "get_trigger_db_by_ref", + "get_trigger_db_by_id", + "get_trigger_db_by_uid", + "get_trigger_db_by_ref_or_dict", + "get_trigger_db_given_type_and_params", + "get_trigger_type_db", + "create_trigger_db", + "create_trigger_type_db", + "create_or_update_trigger_db", + "create_or_update_trigger_type_db", ] LOG = logging.getLogger(__name__) @@ -50,8 +47,7 @@ def get_trigger_db_given_type_and_params(type=None, parameters=None): try: parameters = parameters or {} - trigger_dbs = Trigger.query(type=type, - parameters=parameters) + trigger_dbs = Trigger.query(type=type, parameters=parameters) trigger_db = trigger_dbs[0] if len(trigger_dbs) > 0 else None @@ -59,23 +55,24 @@ def get_trigger_db_given_type_and_params(type=None, parameters=None): # pymongo and mongoengine # Work around for cron-timer when in some scenarios finding an object fails when Python # value types are unicode :/ - is_cron_trigger = (type == CRON_TIMER_TRIGGER_REF) + is_cron_trigger = type == CRON_TIMER_TRIGGER_REF has_parameters = bool(parameters) if not trigger_db and six.PY2 and is_cron_trigger and has_parameters: non_unicode_literal_parameters = {} for key, value in six.iteritems(parameters): - key = key.encode('utf-8') + key = key.encode("utf-8") if isinstance(value, six.text_type): # We only encode unicode to str - value = value.encode('utf-8') + value = value.encode("utf-8") non_unicode_literal_parameters[key] = value parameters = non_unicode_literal_parameters - trigger_dbs = Trigger.query(type=type, - parameters=non_unicode_literal_parameters).no_cache() + trigger_dbs = Trigger.query( + type=type, parameters=non_unicode_literal_parameters + ).no_cache() # Note: We need to directly access the object, using len or accessing the query set # twice won't work - there seems to bug a bug with cursor where accessing it twice @@ -93,8 +90,14 @@ def get_trigger_db_given_type_and_params(type=None, parameters=None): return trigger_db except StackStormDBObjectNotFoundError as e: - LOG.debug('Database lookup for type="%s" parameters="%s" resulted ' + - 'in exception : %s.', type, parameters, e, exc_info=True) + LOG.debug( + 'Database lookup for type="%s" parameters="%s" resulted ' + + "in exception : %s.", + type, + parameters, + e, + exc_info=True, + ) return None @@ -109,26 +112,30 @@ def get_trigger_db_by_ref_or_dict(trigger): else: # If id / uid is available we try to look up Trigger by id. This way we can avoid bug in # pymongo / mongoengine related to "parameters" dictionary lookups - trigger_id = trigger.get('id', None) - trigger_uid = trigger.get('uid', None) + trigger_id = trigger.get("id", None) + trigger_uid = trigger.get("uid", None) # TODO: Remove parameters dictionary look up when we can confirm each trigger dictionary # passed to this method always contains id or uid if trigger_id: - LOG.debug('Looking up TriggerDB by id: %s', trigger_id) + LOG.debug("Looking up TriggerDB by id: %s", trigger_id) trigger_db = get_trigger_db_by_id(id=trigger_id) elif trigger_uid: - LOG.debug('Looking up TriggerDB by uid: %s', trigger_uid) + LOG.debug("Looking up TriggerDB by uid: %s", trigger_uid) trigger_db = get_trigger_db_by_uid(uid=trigger_uid) else: # Last resort - look it up by parameters - trigger_type = trigger.get('type', None) - parameters = trigger.get('parameters', {}) - - LOG.debug('Looking up TriggerDB by type and parameters: type=%s, parameters=%s', - trigger_type, parameters) - trigger_db = get_trigger_db_given_type_and_params(type=trigger_type, - parameters=parameters) + trigger_type = trigger.get("type", None) + parameters = trigger.get("parameters", {}) + + LOG.debug( + "Looking up TriggerDB by type and parameters: type=%s, parameters=%s", + trigger_type, + parameters, + ) + trigger_db = get_trigger_db_given_type_and_params( + type=trigger_type, parameters=parameters + ) return trigger_db @@ -145,8 +152,12 @@ def get_trigger_db_by_id(id): try: return Trigger.get_by_id(id) except StackStormDBObjectNotFoundError as e: - LOG.debug('Database lookup for id="%s" resulted in exception : %s.', - id, e, exc_info=True) + LOG.debug( + 'Database lookup for id="%s" resulted in exception : %s.', + id, + e, + exc_info=True, + ) return None @@ -163,8 +174,12 @@ def get_trigger_db_by_uid(uid): try: return Trigger.get_by_uid(uid) except StackStormDBObjectNotFoundError as e: - LOG.debug('Database lookup for uid="%s" resulted in exception : %s.', - uid, e, exc_info=True) + LOG.debug( + 'Database lookup for uid="%s" resulted in exception : %s.', + uid, + e, + exc_info=True, + ) return None @@ -181,8 +196,12 @@ def get_trigger_db_by_ref(ref): try: return Trigger.get_by_ref(ref) except StackStormDBObjectNotFoundError as e: - LOG.debug('Database lookup for ref="%s" resulted ' + - 'in exception : %s.', ref, e, exc_info=True) + LOG.debug( + 'Database lookup for ref="%s" resulted ' + "in exception : %s.", + ref, + e, + exc_info=True, + ) return None @@ -192,16 +211,17 @@ def _get_trigger_db(trigger): # XXX: Do not make this method public. if isinstance(trigger, dict): - name = trigger.get('name', None) - pack = trigger.get('pack', None) + name = trigger.get("name", None) + pack = trigger.get("pack", None) if name and pack: ref = ResourceReference.to_string_reference(name=name, pack=pack) return get_trigger_db_by_ref(ref) - return get_trigger_db_given_type_and_params(type=trigger['type'], - parameters=trigger.get('parameters', {})) + return get_trigger_db_given_type_and_params( + type=trigger["type"], parameters=trigger.get("parameters", {}) + ) else: - raise Exception('Unrecognized object') + raise Exception("Unrecognized object") def get_trigger_type_db(ref): @@ -216,8 +236,12 @@ def get_trigger_type_db(ref): try: return TriggerType.get_by_ref(ref) except StackStormDBObjectNotFoundError as e: - LOG.debug('Database lookup for ref="%s" resulted ' + - 'in exception : %s.', ref, e, exc_info=True) + LOG.debug( + 'Database lookup for ref="%s" resulted ' + "in exception : %s.", + ref, + e, + exc_info=True, + ) return None @@ -225,22 +249,23 @@ def get_trigger_type_db(ref): def _get_trigger_dict_given_rule(rule): trigger = rule.trigger trigger_dict = {} - triggertype_ref = ResourceReference.from_string_reference(trigger.get('type')) - trigger_dict['pack'] = trigger_dict.get('pack', triggertype_ref.pack) - trigger_dict['type'] = triggertype_ref.ref - trigger_dict['parameters'] = rule.trigger.get('parameters', {}) + triggertype_ref = ResourceReference.from_string_reference(trigger.get("type")) + trigger_dict["pack"] = trigger_dict.get("pack", triggertype_ref.pack) + trigger_dict["type"] = triggertype_ref.ref + trigger_dict["parameters"] = rule.trigger.get("parameters", {}) return trigger_dict def create_trigger_db(trigger_api): # TODO: This is used only in trigger API controller. We should get rid of this. - trigger_ref = ResourceReference.to_string_reference(name=trigger_api.name, - pack=trigger_api.pack) + trigger_ref = ResourceReference.to_string_reference( + name=trigger_api.name, pack=trigger_api.pack + ) trigger_db = get_trigger_db_by_ref(trigger_ref) if not trigger_db: trigger_db = TriggerAPI.to_model(trigger_api) - LOG.debug('Verified trigger and formulated TriggerDB=%s', trigger_db) + LOG.debug("Verified trigger and formulated TriggerDB=%s", trigger_db) trigger_db = Trigger.add_or_update(trigger_db) return trigger_db @@ -269,15 +294,16 @@ def create_or_update_trigger_db(trigger, log_not_unique_error_as_debug=False): if is_update: trigger_db.id = existing_trigger_db.id - trigger_db = Trigger.add_or_update(trigger_db, - log_not_unique_error_as_debug=log_not_unique_error_as_debug) + trigger_db = Trigger.add_or_update( + trigger_db, log_not_unique_error_as_debug=log_not_unique_error_as_debug + ) - extra = {'trigger_db': trigger_db} + extra = {"trigger_db": trigger_db} if is_update: - LOG.audit('Trigger updated. Trigger.id=%s' % (trigger_db.id), extra=extra) + LOG.audit("Trigger updated. Trigger.id=%s" % (trigger_db.id), extra=extra) else: - LOG.audit('Trigger created. Trigger.id=%s' % (trigger_db.id), extra=extra) + LOG.audit("Trigger created. Trigger.id=%s" % (trigger_db.id), extra=extra) return trigger_db @@ -288,10 +314,11 @@ def create_trigger_db_from_rule(rule): # For simple triggertypes (triggertype with no parameters), we create a trigger when # registering triggertype. So if we hit the case that there is no trigger in db but # parameters is empty, then this case is a run time error. - if not trigger_dict.get('parameters', {}) and not existing_trigger_db: + if not trigger_dict.get("parameters", {}) and not existing_trigger_db: raise TriggerDoesNotExistException( - 'A simple trigger should have been created when registering ' - 'triggertype. Cannot create trigger: %s.' % (trigger_dict)) + "A simple trigger should have been created when registering " + "triggertype. Cannot create trigger: %s." % (trigger_dict) + ) if not existing_trigger_db: trigger_db = create_or_update_trigger_db(trigger_dict) @@ -316,7 +343,7 @@ def increment_trigger_ref_count(rule_api): trigger_dict = _get_trigger_dict_given_rule(rule_api) # Special reference counting for trigger with parameters. - if trigger_dict.get('parameters', None): + if trigger_dict.get("parameters", None): trigger_db = _get_trigger_db(trigger_dict) Trigger.update(trigger_db, inc__ref_count=1) @@ -326,7 +353,7 @@ def cleanup_trigger_db_for_rule(rule_db): existing_trigger_db = get_trigger_db_by_ref(rule_db.trigger) if not existing_trigger_db or not existing_trigger_db.parameters: # nothing to be done here so moving on. - LOG.debug('ref_count decrement for %s not required.', existing_trigger_db) + LOG.debug("ref_count decrement for %s not required.", existing_trigger_db) return Trigger.update(existing_trigger_db, dec__ref_count=1) Trigger.delete_if_unreferenced(existing_trigger_db) @@ -350,15 +377,17 @@ def create_trigger_type_db(trigger_type, log_not_unique_error_as_debug=False): """ trigger_type_api = TriggerTypeAPI(**trigger_type) trigger_type_api.validate() - ref = ResourceReference.to_string_reference(name=trigger_type_api.name, - pack=trigger_type_api.pack) + ref = ResourceReference.to_string_reference( + name=trigger_type_api.name, pack=trigger_type_api.pack + ) trigger_type_db = get_trigger_type_db(ref) if not trigger_type_db: trigger_type_db = TriggerTypeAPI.to_model(trigger_type_api) - LOG.debug('verified trigger and formulated TriggerDB=%s', trigger_type_db) - trigger_type_db = TriggerType.add_or_update(trigger_type_db, - log_not_unique_error_as_debug=log_not_unique_error_as_debug) + LOG.debug("verified trigger and formulated TriggerDB=%s", trigger_type_db) + trigger_type_db = TriggerType.add_or_update( + trigger_type_db, log_not_unique_error_as_debug=log_not_unique_error_as_debug + ) return trigger_type_db @@ -378,16 +407,21 @@ def create_shadow_trigger(trigger_type_db, log_not_unique_error_as_debug=False): trigger_type_ref = trigger_type_db.get_reference().ref if trigger_type_db.parameters_schema: - LOG.debug('Skip shadow trigger for TriggerType with parameters %s.', trigger_type_ref) + LOG.debug( + "Skip shadow trigger for TriggerType with parameters %s.", trigger_type_ref + ) return None - trigger = {'name': trigger_type_db.name, - 'pack': trigger_type_db.pack, - 'type': trigger_type_ref, - 'parameters': {}} + trigger = { + "name": trigger_type_db.name, + "pack": trigger_type_db.pack, + "type": trigger_type_ref, + "parameters": {}, + } - return create_or_update_trigger_db(trigger, - log_not_unique_error_as_debug=log_not_unique_error_as_debug) + return create_or_update_trigger_db( + trigger, log_not_unique_error_as_debug=log_not_unique_error_as_debug + ) def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug=False): @@ -412,8 +446,9 @@ def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug trigger_type_api.validate() trigger_type_api = TriggerTypeAPI.to_model(trigger_type_api) - ref = ResourceReference.to_string_reference(name=trigger_type_api.name, - pack=trigger_type_api.pack) + ref = ResourceReference.to_string_reference( + name=trigger_type_api.name, pack=trigger_type_api.pack + ) existing_trigger_type_db = get_trigger_type_db(ref) if existing_trigger_type_db: @@ -425,8 +460,10 @@ def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug trigger_type_api.id = existing_trigger_type_db.id try: - trigger_type_db = TriggerType.add_or_update(trigger_type_api, - log_not_unique_error_as_debug=log_not_unique_error_as_debug) + trigger_type_db = TriggerType.add_or_update( + trigger_type_api, + log_not_unique_error_as_debug=log_not_unique_error_as_debug, + ) except StackStormDBObjectConflictError: # Operation is idempotent and trigger could have already been created by # another process. Ignore object already exists because it simply means @@ -434,26 +471,37 @@ def create_or_update_trigger_type_db(trigger_type, log_not_unique_error_as_debug trigger_type_db = get_trigger_type_db(ref) is_update = True - extra = {'trigger_type_db': trigger_type_db} + extra = {"trigger_type_db": trigger_type_db} if is_update: - LOG.audit('TriggerType updated. TriggerType.id=%s' % (trigger_type_db.id), extra=extra) + LOG.audit( + "TriggerType updated. TriggerType.id=%s" % (trigger_type_db.id), extra=extra + ) else: - LOG.audit('TriggerType created. TriggerType.id=%s' % (trigger_type_db.id), extra=extra) + LOG.audit( + "TriggerType created. TriggerType.id=%s" % (trigger_type_db.id), extra=extra + ) return trigger_type_db -def _create_trigger_type(pack, name, description=None, payload_schema=None, - parameters_schema=None, tags=None, metadata_file=None): +def _create_trigger_type( + pack, + name, + description=None, + payload_schema=None, + parameters_schema=None, + tags=None, + metadata_file=None, +): trigger_type = { - 'name': name, - 'pack': pack, - 'description': description, - 'payload_schema': payload_schema, - 'parameters_schema': parameters_schema, - 'tags': tags, - 'metadata_file': metadata_file + "name": name, + "pack": pack, + "description": description, + "payload_schema": payload_schema, + "parameters_schema": parameters_schema, + "tags": tags, + "metadata_file": metadata_file, } return create_or_update_trigger_type_db(trigger_type=trigger_type) @@ -464,11 +512,12 @@ def _validate_trigger_type(trigger_type): XXX: We need validator objects that define the required and optional fields. For now, manually check them. """ - required_fields = ['name'] + required_fields = ["name"] for field in required_fields: if field not in trigger_type: - raise TriggerTypeRegistrationException('Invalid trigger type. Missing field "%s"' % - (field)) + raise TriggerTypeRegistrationException( + 'Invalid trigger type. Missing field "%s"' % (field) + ) def _create_trigger(trigger_type): @@ -476,37 +525,46 @@ def _create_trigger(trigger_type): :param trigger_type: TriggerType db object. :type trigger_type: :class:`TriggerTypeDB` """ - if hasattr(trigger_type, 'parameters_schema') and not trigger_type['parameters_schema']: + if ( + hasattr(trigger_type, "parameters_schema") + and not trigger_type["parameters_schema"] + ): trigger_dict = { - 'name': trigger_type.name, - 'pack': trigger_type.pack, - 'type': trigger_type.get_reference().ref + "name": trigger_type.name, + "pack": trigger_type.pack, + "type": trigger_type.get_reference().ref, } try: return create_or_update_trigger_db(trigger=trigger_dict) except: - LOG.exception('Validation failed for Trigger=%s.', trigger_dict) + LOG.exception("Validation failed for Trigger=%s.", trigger_dict) raise TriggerTypeRegistrationException( - 'Unable to create Trigger for TriggerType=%s.' % trigger_type.name) + "Unable to create Trigger for TriggerType=%s." % trigger_type.name + ) else: - LOG.debug('Won\'t create Trigger object as TriggerType %s expects ' + - 'parameters.', trigger_type) + LOG.debug( + "Won't create Trigger object as TriggerType %s expects " + "parameters.", + trigger_type, + ) return None def _add_trigger_models(trigger_type): - pack = trigger_type['pack'] - description = trigger_type['description'] if 'description' in trigger_type else '' - payload_schema = trigger_type['payload_schema'] if 'payload_schema' in trigger_type else {} - parameters_schema = trigger_type['parameters_schema'] \ - if 'parameters_schema' in trigger_type else {} - tags = trigger_type.get('tags', []) - metadata_file = trigger_type.get('metadata_file', None) + pack = trigger_type["pack"] + description = trigger_type["description"] if "description" in trigger_type else "" + payload_schema = ( + trigger_type["payload_schema"] if "payload_schema" in trigger_type else {} + ) + parameters_schema = ( + trigger_type["parameters_schema"] if "parameters_schema" in trigger_type else {} + ) + tags = trigger_type.get("tags", []) + metadata_file = trigger_type.get("metadata_file", None) trigger_type = _create_trigger_type( pack=pack, - name=trigger_type['name'], + name=trigger_type["name"], description=description, payload_schema=payload_schema, parameters_schema=parameters_schema, @@ -526,8 +584,13 @@ def add_trigger_models(trigger_types): :rtype: ``list`` of ``tuple`` (trigger_type, trigger) """ - [r for r in (_validate_trigger_type(trigger_type) - for trigger_type in trigger_types) if r is not None] + [ + r + for r in ( + _validate_trigger_type(trigger_type) for trigger_type in trigger_types + ) + if r is not None + ] result = [] for trigger_type in trigger_types: diff --git a/st2common/st2common/services/triggerwatcher.py b/st2common/st2common/services/triggerwatcher.py index 4830c349f4f..b82a46043ad 100644 --- a/st2common/st2common/services/triggerwatcher.py +++ b/st2common/st2common/services/triggerwatcher.py @@ -33,8 +33,15 @@ class TriggerWatcher(ConsumerMixin): sleep_interval = 0 # sleep to co-operatively yield after processing each message - def __init__(self, create_handler, update_handler, delete_handler, - trigger_types=None, queue_suffix=None, exclusive=False): + def __init__( + self, + create_handler, + update_handler, + delete_handler, + trigger_types=None, + queue_suffix=None, + exclusive=False, + ): """ :param create_handler: Function which is called on TriggerDB create event. :type create_handler: ``callable`` @@ -69,39 +76,49 @@ def __init__(self, create_handler, update_handler, delete_handler, self._handlers = { publishers.CREATE_RK: create_handler, publishers.UPDATE_RK: update_handler, - publishers.DELETE_RK: delete_handler + publishers.DELETE_RK: delete_handler, } def get_consumers(self, Consumer, channel): - return [Consumer(queues=[self._trigger_watch_q], - accept=['pickle'], - callbacks=[self.process_task])] + return [ + Consumer( + queues=[self._trigger_watch_q], + accept=["pickle"], + callbacks=[self.process_task], + ) + ] def process_task(self, body, message): - LOG.debug('process_task') - LOG.debug(' body: %s', body) - LOG.debug(' message.properties: %s', message.properties) - LOG.debug(' message.delivery_info: %s', message.delivery_info) + LOG.debug("process_task") + LOG.debug(" body: %s", body) + LOG.debug(" message.properties: %s", message.properties) + LOG.debug(" message.delivery_info: %s", message.delivery_info) - routing_key = message.delivery_info.get('routing_key', '') + routing_key = message.delivery_info.get("routing_key", "") handler = self._handlers.get(routing_key, None) try: if not handler: - LOG.debug('Skipping message %s as no handler was found.', message) + LOG.debug("Skipping message %s as no handler was found.", message) return - trigger_type = getattr(body, 'type', None) + trigger_type = getattr(body, "type", None) if self._trigger_types and trigger_type not in self._trigger_types: - LOG.debug('Skipping message %s since trigger_type doesn\'t match (type=%s)', - message, trigger_type) + LOG.debug( + "Skipping message %s since trigger_type doesn't match (type=%s)", + message, + trigger_type, + ) return try: handler(body) except Exception as e: - LOG.exception('Handling failed. Message body: %s. Exception: %s', - body, six.text_type(e)) + LOG.exception( + "Handling failed. Message body: %s. Exception: %s", + body, + six.text_type(e), + ) finally: message.ack() @@ -113,7 +130,7 @@ def start(self): self._updates_thread = concurrency.spawn(self.run) self._load_thread = concurrency.spawn(self._load_triggers_from_db) except: - LOG.exception('Failed to start watcher.') + LOG.exception("Failed to start watcher.") self.connection.release() def stop(self): @@ -128,8 +145,9 @@ def stop(self): # waiting for a message on the queue. def on_consume_end(self, connection, channel): - super(TriggerWatcher, self).on_consume_end(connection=connection, - channel=channel) + super(TriggerWatcher, self).on_consume_end( + connection=connection, channel=channel + ) concurrency.sleep(seconds=self.sleep_interval) def on_iteration(self): @@ -139,13 +157,16 @@ def on_iteration(self): def _load_triggers_from_db(self): for trigger_type in self._trigger_types: for trigger in Trigger.query(type=trigger_type): - LOG.debug('Found existing trigger: %s in db.' % trigger) + LOG.debug("Found existing trigger: %s in db." % trigger) self._handlers[publishers.CREATE_RK](trigger) @staticmethod def _get_queue(queue_suffix, exclusive): - queue_name = queue_utils.get_queue_name(queue_name_base='st2.trigger.watch', - queue_name_suffix=queue_suffix, - add_random_uuid_to_suffix=True - ) - return reactor.get_trigger_cud_queue(queue_name, routing_key='#', exclusive=exclusive) + queue_name = queue_utils.get_queue_name( + queue_name_base="st2.trigger.watch", + queue_name_suffix=queue_suffix, + add_random_uuid_to_suffix=True, + ) + return reactor.get_trigger_cud_queue( + queue_name, routing_key="#", exclusive=exclusive + ) diff --git a/st2common/st2common/services/workflows.py b/st2common/st2common/services/workflows.py index db64c681b19..16ddcb5a024 100644 --- a/st2common/st2common/services/workflows.py +++ b/st2common/st2common/services/workflows.py @@ -54,59 +54,61 @@ LOG = logging.getLogger(__name__) LOG_FUNCTIONS = { - 'audit': LOG.audit, - 'debug': LOG.debug, - 'info': LOG.info, - 'warning': LOG.warning, - 'error': LOG.error, - 'critical': LOG.critical, + "audit": LOG.audit, + "debug": LOG.debug, + "info": LOG.info, + "warning": LOG.warning, + "error": LOG.error, + "critical": LOG.critical, } -def update_progress(wf_ex_db, message, severity='info', log=True, stream=True): +def update_progress(wf_ex_db, message, severity="info", log=True, stream=True): if not wf_ex_db: return if log and severity in LOG_FUNCTIONS: - LOG_FUNCTIONS[severity]('[%s] %s', wf_ex_db.context['st2']['action_execution_id'], message) + LOG_FUNCTIONS[severity]( + "[%s] %s", wf_ex_db.context["st2"]["action_execution_id"], message + ) if stream: ac_svc.store_execution_output_data_ex( - wf_ex_db.context['st2']['action_execution_id'], - wf_ex_db.context['st2']['action'], - wf_ex_db.context['st2']['runner'], - '%s\n' % message, + wf_ex_db.context["st2"]["action_execution_id"], + wf_ex_db.context["st2"]["action"], + wf_ex_db.context["st2"]["runner"], + "%s\n" % message, ) def is_action_execution_under_workflow_context(ac_ex_db): # The action execution is executed under the context of a workflow # if it contains the orquesta key in its context dictionary. - return ac_ex_db.context and 'orquesta' in ac_ex_db.context + return ac_ex_db.context and "orquesta" in ac_ex_db.context def format_inspection_result(result): errors = [] categories = { - 'contents': 'content', - 'context': 'context', - 'expressions': 'expression', - 'semantics': 'semantic', - 'syntax': 'syntax' + "contents": "content", + "context": "context", + "expressions": "expression", + "semantics": "semantic", + "syntax": "syntax", } # For context and expression errors, rename the attribute from type to language. - for category in ['context', 'expressions']: + for category in ["context", "expressions"]: for entry in result.get(category, []): - if 'language' not in entry: - entry['language'] = entry['type'] - del entry['type'] + if "language" not in entry: + entry["language"] = entry["type"] + del entry["type"] # For all categories, put the category value in the type attribute. for category, entries in six.iteritems(result): for entry in entries: - entry['type'] = categories[category] + entry["type"] = categories[category] errors.append(entry) return errors @@ -121,7 +123,7 @@ def inspect(wf_spec, st2_ctx, raise_exception=True): errors += inspect_task_contents(wf_spec) # Sort the list of errors by type and path. - errors = sorted(errors, key=lambda e: (e['type'], e['schema_path'])) + errors = sorted(errors, key=lambda e: (e["type"], e["schema_path"])) if errors and raise_exception: raise orquesta_exc.WorkflowInspectionError(errors) @@ -131,10 +133,10 @@ def inspect(wf_spec, st2_ctx, raise_exception=True): def inspect_task_contents(wf_spec): result = [] - spec_path = 'tasks' - schema_path = 'properties.tasks.patternProperties.^\\w+$' - action_schema_path = schema_path + '.properties.action' - action_input_schema_path = schema_path + '.properties.input' + spec_path = "tasks" + schema_path = "properties.tasks.patternProperties.^\\w+$" + action_schema_path = schema_path + ".properties.action" + action_input_schema_path = schema_path + ".properties.input" def is_action_an_expression(action): if isinstance(action, six.string_types): @@ -143,9 +145,9 @@ def is_action_an_expression(action): return True for task_name, task_spec in six.iteritems(wf_spec.tasks): - action_ref = getattr(task_spec, 'action', None) - action_spec_path = spec_path + '.' + task_name + '.action' - action_input_spec_path = spec_path + '.' + task_name + '.input' + action_ref = getattr(task_spec, "action", None) + action_spec_path = spec_path + "." + task_name + ".action" + action_input_spec_path = spec_path + "." + task_name + ".input" # Move on if action is empty or an expression. if not action_ref or is_action_an_expression(action_ref): @@ -154,10 +156,11 @@ def is_action_an_expression(action): # Check that the format of the action is a valid resource reference. if not sys_models.ResourceReference.is_resource_reference(action_ref): entry = { - 'type': 'content', - 'message': 'The action reference "%s" is not formatted correctly.' % action_ref, - 'spec_path': action_spec_path, - 'schema_path': action_schema_path + "type": "content", + "message": 'The action reference "%s" is not formatted correctly.' + % action_ref, + "spec_path": action_spec_path, + "schema_path": action_schema_path, } result.append(entry) @@ -166,31 +169,37 @@ def is_action_an_expression(action): # Check that the action is registered in the database. if not action_utils.get_action_by_ref(ref=action_ref): entry = { - 'type': 'content', - 'message': 'The action "%s" is not registered in the database.' % action_ref, - 'spec_path': action_spec_path, - 'schema_path': action_schema_path + "type": "content", + "message": 'The action "%s" is not registered in the database.' + % action_ref, + "spec_path": action_spec_path, + "schema_path": action_schema_path, } result.append(entry) continue # Check the action parameters. - params = getattr(task_spec, 'input', None) or {} + params = getattr(task_spec, "input", None) or {} if params and not isinstance(params, dict): continue - requires, unexpected = action_param_utils.validate_action_parameters(action_ref, params) + requires, unexpected = action_param_utils.validate_action_parameters( + action_ref, params + ) for param in requires: - message = 'Action "%s" is missing required input "%s".' % (action_ref, param) + message = 'Action "%s" is missing required input "%s".' % ( + action_ref, + param, + ) entry = { - 'type': 'content', - 'message': message, - 'spec_path': action_input_spec_path, - 'schema_path': action_input_schema_path + "type": "content", + "message": message, + "spec_path": action_input_spec_path, + "schema_path": action_input_schema_path, } result.append(entry) @@ -199,10 +208,10 @@ def is_action_an_expression(action): message = 'Action "%s" has unexpected input "%s".' % (action_ref, param) entry = { - 'type': 'content', - 'message': message, - 'spec_path': action_input_spec_path + '.' + param, - 'schema_path': action_input_schema_path + '.patternProperties.^\\w+$' + "type": "content", + "message": message, + "spec_path": action_input_spec_path + "." + param, + "schema_path": action_input_schema_path + ".patternProperties.^\\w+$", } result.append(entry) @@ -211,35 +220,35 @@ def is_action_an_expression(action): def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None): - LOG.info('[%s] Processing action execution request for workflow.', str(ac_ex_db.id)) + LOG.info("[%s] Processing action execution request for workflow.", str(ac_ex_db.id)) # Load workflow definition into workflow spec model. - spec_module = specs_loader.get_spec_module('native') + spec_module = specs_loader.get_spec_module("native") wf_spec = spec_module.instantiate(wf_def) # Inspect the workflow spec. inspect(wf_spec, st2_ctx, raise_exception=True) # Identify the action to execute. - action_db = action_utils.get_action_by_ref(ref=ac_ex_db.action['ref']) + action_db = action_utils.get_action_by_ref(ref=ac_ex_db.action["ref"]) if not action_db: - error = 'Unable to find action "%s".' % ac_ex_db.action['ref'] + error = 'Unable to find action "%s".' % ac_ex_db.action["ref"] raise ac_exc.InvalidActionReferencedException(error) # Identify the runner for the action. - runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type["name"]) # Render action execution parameters. runner_params, action_params = param_utils.render_final_params( runner_type_db.runner_parameters, action_db.parameters, ac_ex_db.parameters, - ac_ex_db.context + ac_ex_db.context, ) # Instantiate the workflow conductor. - conductor_params = {'inputs': action_params, 'context': st2_ctx} + conductor_params = {"inputs": action_params, "context": st2_ctx} conductor = conducting.WorkflowConductor(wf_spec, **conductor_params) # Serialize the conductor which initializes some internal values. @@ -248,33 +257,32 @@ def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None): # Create a record for workflow execution. wf_ex_db = wf_db_models.WorkflowExecutionDB( action_execution=str(ac_ex_db.id), - spec=data['spec'], - graph=data['graph'], - input=data['input'], - context=data['context'], - state=data['state'], - status=data['state']['status'], - output=data['output'], - errors=data['errors'] + spec=data["spec"], + graph=data["graph"], + input=data["input"], + context=data["context"], + state=data["state"], + status=data["state"]["status"], + output=data["output"], + errors=data["errors"], ) # Inspect that the list of tasks in the notify parameter exist in the workflow spec. - if runner_params.get('notify'): - invalid_tasks = list(set(runner_params.get('notify')) - set(wf_spec.tasks.keys())) + if runner_params.get("notify"): + invalid_tasks = list( + set(runner_params.get("notify")) - set(wf_spec.tasks.keys()) + ) if invalid_tasks: raise wf_exc.WorkflowExecutionException( - 'The following tasks in the notify parameter do not exist ' - 'in the workflow definition: %s.' % ', '.join(invalid_tasks) + "The following tasks in the notify parameter do not exist " + "in the workflow definition: %s." % ", ".join(invalid_tasks) ) # Write notify instruction to record. if notify_cfg: # Set up the notify instruction in the workflow execution record. - wf_ex_db.notify = { - 'config': notify_cfg, - 'tasks': runner_params.get('notify') - } + wf_ex_db.notify = {"config": notify_cfg, "tasks": runner_params.get("notify")} # Insert new record into the database and do not publish to the message bus yet. wf_ex_db = wf_db_access.WorkflowExecution.insert(wf_ex_db, publish=False) @@ -286,12 +294,12 @@ def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None): # Set the initial workflow status to requested. conductor.request_workflow_status(statuses.REQUESTED) data = conductor.serialize() - wf_ex_db.state = data['state'] - wf_ex_db.status = data['state']['status'] + wf_ex_db.state = data["state"] + wf_ex_db.status = data["state"]["status"] # Put the ID of the workflow execution record in the context. - wf_ex_db.context['st2']['workflow_execution_id'] = str(wf_ex_db.id) - wf_ex_db.state['contexts'][0]['st2']['workflow_execution_id'] = str(wf_ex_db.id) + wf_ex_db.context["st2"]["workflow_execution_id"] = str(wf_ex_db.id) + wf_ex_db.state["contexts"][0]["st2"]["workflow_execution_id"] = str(wf_ex_db.id) # Update the workflow execution record. wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False) @@ -308,15 +316,17 @@ def request(wf_def, ac_ex_db, st2_ctx, notify_cfg=None): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_pause(ac_ex_db): wf_ac_ex_id = str(ac_ex_db.id) - LOG.info('[%s] Processing pause request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Processing pause request for workflow.", wf_ac_ex_id) wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) @@ -343,7 +353,7 @@ def request_pause(ac_ex_db): wf_ex_db.state = conductor.workflow_state.serialize() wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False) - LOG.info('[%s] Completed processing pause request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Completed processing pause request for workflow.", wf_ac_ex_id) return wf_ex_db @@ -351,15 +361,17 @@ def request_pause(ac_ex_db): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_resume(ac_ex_db): wf_ac_ex_id = str(ac_ex_db.id) - LOG.info('[%s] Processing resume request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Processing resume request for workflow.", wf_ac_ex_id) wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) @@ -375,7 +387,9 @@ def request_resume(ac_ex_db): raise wf_exc.WorkflowExecutionIsCompletedException(str(wf_ex_db.id)) if wf_ex_db.status in statuses.RUNNING_STATUSES: - msg = '[%s] Workflow execution "%s" is not resumed because it is already active.' + msg = ( + '[%s] Workflow execution "%s" is not resumed because it is already active.' + ) LOG.info(msg, wf_ac_ex_id, str(wf_ex_db.id)) return @@ -385,7 +399,9 @@ def request_resume(ac_ex_db): raise wf_exc.WorkflowExecutionIsCompletedException(str(wf_ex_db.id)) if conductor.get_workflow_status() in statuses.RUNNING_STATUSES: - msg = '[%s] Workflow execution "%s" is not resumed because it is already active.' + msg = ( + '[%s] Workflow execution "%s" is not resumed because it is already active.' + ) LOG.info(msg, wf_ac_ex_id, str(wf_ex_db.id)) return @@ -400,7 +416,7 @@ def request_resume(ac_ex_db): # Publish status change. wf_db_access.WorkflowExecution.publish_status(wf_ex_db) - LOG.info('[%s] Completed processing resume request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Completed processing resume request for workflow.", wf_ac_ex_id) return wf_ex_db @@ -408,15 +424,17 @@ def request_resume(ac_ex_db): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_cancellation(ac_ex_db): wf_ac_ex_id = str(ac_ex_db.id) - LOG.info('[%s] Processing cancelation request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Processing cancelation request for workflow.", wf_ac_ex_id) wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) @@ -446,13 +464,16 @@ def request_cancellation(ac_ex_db): # Cascade the cancellation up to the root of the workflow. root_ac_ex_db = ac_svc.get_root_execution(ac_ex_db) - if root_ac_ex_db != ac_ex_db and root_ac_ex_db.status not in ac_const.LIVEACTION_CANCEL_STATES: - LOG.info('[%s] Cascading cancelation request to parent workflow.', wf_ac_ex_id) - root_lv_ac_db = lv_db_access.LiveAction.get(id=root_ac_ex_db.liveaction['id']) + if ( + root_ac_ex_db != ac_ex_db + and root_ac_ex_db.status not in ac_const.LIVEACTION_CANCEL_STATES + ): + LOG.info("[%s] Cascading cancelation request to parent workflow.", wf_ac_ex_id) + root_lv_ac_db = lv_db_access.LiveAction.get(id=root_ac_ex_db.liveaction["id"]) ac_svc.request_cancellation(root_lv_ac_db, None) - LOG.debug('[%s] %s', wf_ac_ex_id, conductor.serialize()) - LOG.info('[%s] Completed processing cancelation request for workflow.', wf_ac_ex_id) + LOG.debug("[%s] %s", wf_ac_ex_id, conductor.serialize()) + LOG.info("[%s] Completed processing cancelation request for workflow.", wf_ac_ex_id) return wf_ex_db @@ -460,20 +481,22 @@ def request_cancellation(ac_ex_db): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_rerun(ac_ex_db, st2_ctx, options=None): wf_ac_ex_id = str(ac_ex_db.id) - LOG.info('[%s] Processing rerun request for workflow.', wf_ac_ex_id) + LOG.info("[%s] Processing rerun request for workflow.", wf_ac_ex_id) - wf_ex_id = st2_ctx.get('workflow_execution_id') + wf_ex_id = st2_ctx.get("workflow_execution_id") if not wf_ex_id: - msg = 'Unable to rerun workflow execution because workflow_execution_id is not provided.' + msg = "Unable to rerun workflow execution because workflow_execution_id is not provided." raise wf_exc.WorkflowExecutionRerunException(msg) try: @@ -487,8 +510,8 @@ def request_rerun(ac_ex_db, st2_ctx, options=None): raise wf_exc.WorkflowExecutionRerunException(msg % wf_ex_id) wf_ex_db.action_execution = wf_ac_ex_id - wf_ex_db.context['st2'] = st2_ctx['st2'] - wf_ex_db.context['parent'] = st2_ctx['parent'] + wf_ex_db.context["st2"] = st2_ctx["st2"] + wf_ex_db.context["parent"] = st2_ctx["parent"] conductor = deserialize_conductor(wf_ex_db) try: @@ -497,26 +520,29 @@ def request_rerun(ac_ex_db, st2_ctx, options=None): if options: task_requests = [] - task_names = options.get('tasks', []) - task_resets = options.get('reset', []) + task_names = options.get("tasks", []) + task_resets = options.get("reset", []) for task_name in task_names: reset_items = task_name in task_resets - task_state_entries = conductor.workflow_state.get_tasks(task_id=task_name) + task_state_entries = conductor.workflow_state.get_tasks( + task_id=task_name + ) if not task_state_entries: problems.append(task_name) continue for _, task_state_entry in task_state_entries: - route = task_state_entry['route'] + route = task_state_entry["route"] req = orquesta_reqs.TaskRerunRequest.new( - task_name, route, reset_items=reset_items) + task_name, route, reset_items=reset_items + ) task_requests.append(req) if problems: - msg = 'Unable to rerun workflow because one or more tasks is not found: %s' - raise Exception(msg % ','.join(problems)) + msg = "Unable to rerun workflow because one or more tasks is not found: %s" + raise Exception(msg % ",".join(problems)) conductor.request_workflow_rerun(task_requests=task_requests) except Exception as e: @@ -527,10 +553,10 @@ def request_rerun(ac_ex_db, st2_ctx, options=None): raise wf_exc.WorkflowExecutionRerunException(msg % wf_ex_id) data = conductor.serialize() - wf_ex_db.status = data['state']['status'] - wf_ex_db.spec = data['spec'] - wf_ex_db.graph = data['graph'] - wf_ex_db.state = data['state'] + wf_ex_db.status = data["state"]["status"] + wf_ex_db.spec = data["spec"] + wf_ex_db.graph = data["graph"] + wf_ex_db.state = data["state"] wf_db_access.WorkflowExecution.update(wf_ex_db, publish=False) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(str(wf_ex_db.id)) @@ -542,12 +568,12 @@ def request_rerun(ac_ex_db, st2_ctx, options=None): def request_task_execution(wf_ex_db, st2_ctx, task_ex_req): - task_id = task_ex_req['id'] - task_route = task_ex_req['route'] - task_spec = task_ex_req['spec'] - task_ctx = task_ex_req['ctx'] - task_actions = task_ex_req['actions'] - task_delay = task_ex_req.get('delay') + task_id = task_ex_req["id"] + task_route = task_ex_req["route"] + task_spec = task_ex_req["spec"] + task_ctx = task_ex_req["ctx"] + task_actions = task_ex_req["actions"] + task_delay = task_ex_req.get("delay") msg = 'Processing task execution request for task "%s", route "%s".' update_progress(wf_ex_db, msg % (task_id, str(task_route)), stream=False) @@ -557,11 +583,14 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req): workflow_execution=str(wf_ex_db.id), task_id=task_id, task_route=task_route, - order_by=['-start_timestamp'] + order_by=["-start_timestamp"], ) - if (len(task_ex_dbs) > 0 and task_ex_dbs[0].itemized and - task_ex_dbs[0].status == ac_const.LIVEACTION_STATUS_RUNNING): + if ( + len(task_ex_dbs) > 0 + and task_ex_dbs[0].itemized + and task_ex_dbs[0].status == ac_const.LIVEACTION_STATUS_RUNNING + ): task_ex_db = task_ex_dbs[0] task_ex_id = str(task_ex_db.id) msg = 'Task execution "%s" retrieved for task "%s", route "%s".' @@ -576,15 +605,15 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req): task_spec=task_spec.serialize(), delay=task_delay, itemized=task_spec.has_items(), - items_count=task_ex_req.get('items_count'), - items_concurrency=task_ex_req.get('concurrency'), + items_count=task_ex_req.get("items_count"), + items_concurrency=task_ex_req.get("concurrency"), context=task_ctx, - status=statuses.REQUESTED + status=statuses.REQUESTED, ) # Prepare the result format for itemized task execution. if task_ex_db.itemized: - task_ex_db.result = {'items': [None] * task_ex_db.items_count} + task_ex_db.result = {"items": [None] * task_ex_db.items_count} # Insert new record into the database. task_ex_db = wf_db_access.TaskExecution.insert(task_ex_db, publish=False) @@ -627,26 +656,35 @@ def request_task_execution(wf_ex_db, st2_ctx, task_ex_req): # Request action execution for each actions in the task request. for ac_ex_req in task_actions: - ac_ex_delay = eval_action_execution_delay(task_ex_req, ac_ex_req, task_ex_db.itemized) - request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=ac_ex_delay) + ac_ex_delay = eval_action_execution_delay( + task_ex_req, ac_ex_req, task_ex_db.itemized + ) + request_action_execution( + wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=ac_ex_delay + ) task_ex_db = wf_db_access.TaskExecution.get_by_id(str(task_ex_db.id)) except Exception as e: msg = 'Failed action execution(s) for task "%s", route "%s".' msg = msg % (task_id, str(task_route)) LOG.exception(msg) - msg = '%s %s: %s' % (msg, type(e).__name__, six.text_type(e)) - update_progress(wf_ex_db, msg, severity='error', log=False) - msg = '%s: %s' % (type(e).__name__, six.text_type(e)) - error = {'type': 'error', 'message': msg, 'task_id': task_id, 'route': task_route} - update_task_execution(str(task_ex_db.id), statuses.FAILED, {'errors': [error]}) + msg = "%s %s: %s" % (msg, type(e).__name__, six.text_type(e)) + update_progress(wf_ex_db, msg, severity="error", log=False) + msg = "%s: %s" % (type(e).__name__, six.text_type(e)) + error = { + "type": "error", + "message": msg, + "task_id": task_id, + "route": task_route, + } + update_task_execution(str(task_ex_db.id), statuses.FAILED, {"errors": [error]}) raise e return task_ex_db def eval_action_execution_delay(task_ex_req, ac_ex_req, itemized=False): - task_ex_delay = task_ex_req.get('delay') - items_concurrency = task_ex_req.get('concurrency') + task_ex_delay = task_ex_req.get("delay") + items_concurrency = task_ex_req.get("concurrency") # If there is a task delay and not with items, return the delay value. if task_ex_delay and not itemized: @@ -658,7 +696,7 @@ def eval_action_execution_delay(task_ex_req, ac_ex_req, itemized=False): # If there is a task delay and task has items with concurrency, # return the delay value up if item id is less than the concurrency value. - if task_ex_delay and itemized and ac_ex_req['item_id'] < items_concurrency: + if task_ex_delay and itemized and ac_ex_req["item_id"] < items_concurrency: return task_ex_delay return None @@ -667,20 +705,22 @@ def eval_action_execution_delay(task_ex_req, ac_ex_req, itemized=False): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=None): - action_ref = ac_ex_req['action'] - action_input = ac_ex_req['input'] - item_id = ac_ex_req.get('item_id') + action_ref = ac_ex_req["action"] + action_input = ac_ex_req["input"] + item_id = ac_ex_req.get("item_id") # If the task is with items and item_id is not provided, raise exception. if task_ex_db.itemized and item_id is None: - msg = 'Unable to request action execution. Identifier for the item is not provided.' + msg = "Unable to request action execution. Identifier for the item is not provided." raise Exception(msg) # Identify the action to execute. @@ -691,40 +731,40 @@ def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=Non raise ac_exc.InvalidActionReferencedException(error) # Identify the runner for the action. - runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runner_type_db = action_utils.get_runnertype_by_name(action_db.runner_type["name"]) # Identify action pack name - pack_name = action_ref.split('.')[0] if action_ref else st2_ctx.get('pack') + pack_name = action_ref.split(".")[0] if action_ref else st2_ctx.get("pack") # Set context for the action execution. ac_ex_ctx = { - 'pack': pack_name, - 'user': st2_ctx.get('user'), - 'parent': st2_ctx, - 'orquesta': { - 'workflow_execution_id': str(wf_ex_db.id), - 'task_execution_id': str(task_ex_db.id), - 'task_name': task_ex_db.task_name, - 'task_id': task_ex_db.task_id, - 'task_route': task_ex_db.task_route - } + "pack": pack_name, + "user": st2_ctx.get("user"), + "parent": st2_ctx, + "orquesta": { + "workflow_execution_id": str(wf_ex_db.id), + "task_execution_id": str(task_ex_db.id), + "task_name": task_ex_db.task_name, + "task_id": task_ex_db.task_id, + "task_route": task_ex_db.task_route, + }, } - if st2_ctx.get('api_user'): - ac_ex_ctx['api_user'] = st2_ctx.get('api_user') + if st2_ctx.get("api_user"): + ac_ex_ctx["api_user"] = st2_ctx.get("api_user") - if st2_ctx.get('source_channel'): - ac_ex_ctx['source_channel'] = st2_ctx.get('source_channel') + if st2_ctx.get("source_channel"): + ac_ex_ctx["source_channel"] = st2_ctx.get("source_channel") if item_id is not None: - ac_ex_ctx['orquesta']['item_id'] = item_id + ac_ex_ctx["orquesta"]["item_id"] = item_id # Render action execution parameters and setup action execution object. ac_ex_params = param_utils.render_live_params( runner_type_db.runner_parameters or {}, action_db.parameters or {}, action_input or {}, - ac_ex_ctx + ac_ex_ctx, ) # The delay spec is in seconds and scheduler expects milliseconds. @@ -738,13 +778,19 @@ def request_action_execution(wf_ex_db, task_ex_db, st2_ctx, ac_ex_req, delay=Non task_execution=str(task_ex_db.id), delay=delay, context=ac_ex_ctx, - parameters=ac_ex_params + parameters=ac_ex_params, ) # Set notification if instructed. - if (wf_ex_db.notify and wf_ex_db.notify.get('config') and - wf_ex_db.notify.get('tasks') and task_ex_db.task_name in wf_ex_db.notify['tasks']): - lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model(wf_ex_db.notify['config']) + if ( + wf_ex_db.notify + and wf_ex_db.notify.get("config") + and wf_ex_db.notify.get("tasks") + and task_ex_db.task_name in wf_ex_db.notify["tasks"] + ): + lv_ac_db.notify = notify_api_models.NotificationsHelper.to_model( + wf_ex_db.notify["config"] + ) # Set the task execution to running first otherwise a race can occur # where the action execution finishes first and the completion handler @@ -765,13 +811,13 @@ def handle_action_execution_pending(ac_ex_db): # Check that the action execution is paused. if ac_ex_db.status != ac_const.LIVEACTION_STATUS_PENDING: raise Exception( - 'Unable to handle pending of action execution. The action execution ' + "Unable to handle pending of action execution. The action execution " '"%s" is in "%s" status.' % (str(ac_ex_db.id), ac_ex_db.status) ) # Get related record identifiers. - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] # Get execution records for logging purposes. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) @@ -780,14 +826,14 @@ def handle_action_execution_pending(ac_ex_db): msg = 'Handling pending of action execution "%s" for task "%s", route "%s".' update_progress( wf_ex_db, - msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)) + msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)), ) # Updat task execution update_task_execution(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_db.context) # Update task flow in the workflow execution. - ac_ex_ctx = ac_ex_db.context.get('orquesta') + ac_ex_ctx = ac_ex_db.context.get("orquesta") update_task_state(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_ctx, publish=True) @@ -795,13 +841,13 @@ def handle_action_execution_pause(ac_ex_db): # Check that the action execution is paused. if ac_ex_db.status != ac_const.LIVEACTION_STATUS_PAUSED: raise Exception( - 'Unable to handle pause of action execution. The action execution ' + "Unable to handle pause of action execution. The action execution " '"%s" is in "%s" status.' % (str(ac_ex_db.id), ac_ex_db.status) ) # Get related record identifiers. - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] # Get execution records for logging purposes. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) @@ -814,27 +860,27 @@ def handle_action_execution_pause(ac_ex_db): msg = 'Handling pause of action execution "%s" for task "%s", route "%s".' update_progress( wf_ex_db, - msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)) + msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)), ) # Updat task execution update_task_execution(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_db.context) # Update task flow in the workflow execution. - ac_ex_ctx = ac_ex_db.context.get('orquesta') + ac_ex_ctx = ac_ex_db.context.get("orquesta") update_task_state(task_ex_id, ac_ex_db.status, ac_ex_ctx=ac_ex_ctx, publish=True) def handle_action_execution_resume(ac_ex_db): - if 'orquesta' not in ac_ex_db.context: + if "orquesta" not in ac_ex_db.context: raise Exception( - 'Unable to handle resume of action execution. The action execution ' - '%s is not an orquesta workflow task.' % str(ac_ex_db.id) + "Unable to handle resume of action execution. The action execution " + "%s is not an orquesta workflow task." % str(ac_ex_db.id) ) # Get related record identifiers. - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] # Get execution records for logging purposes. wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) @@ -843,7 +889,7 @@ def handle_action_execution_resume(ac_ex_db): msg = 'Handling resume of action execution "%s" for task "%s", route "%s".' update_progress( wf_ex_db, - msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)) + msg % (str(ac_ex_db.id), task_ex_db.task_id, str(task_ex_db.task_route)), ) # Updat task execution to running. @@ -854,18 +900,22 @@ def handle_action_execution_resume(ac_ex_db): # If action execution has a parent, cascade status change upstream and do not publish # the status change because we do not want to trigger resume of other peer subworkflows. - if 'parent' in ac_ex_db.context: - parent_ac_ex_id = ac_ex_db.context['parent']['execution_id'] + if "parent" in ac_ex_db.context: + parent_ac_ex_id = ac_ex_db.context["parent"]["execution_id"] parent_ac_ex_db = ex_db_access.ActionExecution.get_by_id(parent_ac_ex_id) if parent_ac_ex_db.status == ac_const.LIVEACTION_STATUS_PAUSED: action_utils.update_liveaction_status( - liveaction_id=parent_ac_ex_db.liveaction['id'], + liveaction_id=parent_ac_ex_db.liveaction["id"], status=ac_const.LIVEACTION_STATUS_RUNNING, - publish=False) + publish=False, + ) # If there are grand parents, handle the resume of the parent action execution. - if 'orquesta' in parent_ac_ex_db.context and 'parent' in parent_ac_ex_db.context: + if ( + "orquesta" in parent_ac_ex_db.context + and "parent" in parent_ac_ex_db.context + ): handle_action_execution_resume(parent_ac_ex_db) @@ -873,18 +923,19 @@ def handle_action_execution_resume(ac_ex_db): retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def handle_action_execution_completion(ac_ex_db): # Check that the action execution is completed. if ac_ex_db.status not in ac_const.LIVEACTION_COMPLETED_STATES: raise Exception( - 'Unable to handle completion of action execution. The action execution ' + "Unable to handle completion of action execution. The action execution " '"%s" is in "%s" status.' % (str(ac_ex_db.id), ac_ex_db.status) ) # Get related record identifiers. - wf_ex_id = ac_ex_db.context['orquesta']['workflow_execution_id'] - task_ex_id = ac_ex_db.context['orquesta']['task_execution_id'] + wf_ex_id = ac_ex_db.context["orquesta"]["workflow_execution_id"] + task_ex_id = ac_ex_db.context["orquesta"]["task_execution_id"] # Acquire lock before write operations. with coord_svc.get_coordinator(start_heart=True).get_lock(wf_ex_id): @@ -894,9 +945,12 @@ def handle_action_execution_completion(ac_ex_db): msg = ( 'Handling completion of action execution "%s" ' - 'in status "%s" for task "%s", route "%s".' % ( - str(ac_ex_db.id), ac_ex_db.status, task_ex_db.task_id, - str(task_ex_db.task_route) + 'in status "%s" for task "%s", route "%s".' + % ( + str(ac_ex_db.id), + ac_ex_db.status, + task_ex_db.task_id, + str(task_ex_db.task_route), ) ) update_progress(wf_ex_db, msg) @@ -907,14 +961,16 @@ def handle_action_execution_completion(ac_ex_db): resume_task_execution(task_ex_id) # Update task execution if completed. - update_task_execution(task_ex_id, ac_ex_db.status, ac_ex_db.result, ac_ex_db.context) + update_task_execution( + task_ex_id, ac_ex_db.status, ac_ex_db.result, ac_ex_db.context + ) # Update task flow in the workflow execution. update_task_state( task_ex_id, ac_ex_db.status, ac_ex_result=ac_ex_db.result, - ac_ex_ctx=ac_ex_db.context.get('orquesta') + ac_ex_ctx=ac_ex_db.context.get("orquesta"), ) # Request the next set of tasks if workflow execution is not complete. @@ -926,13 +982,13 @@ def handle_action_execution_completion(ac_ex_db): def deserialize_conductor(wf_ex_db): data = { - 'spec': wf_ex_db.spec, - 'graph': wf_ex_db.graph, - 'input': wf_ex_db.input, - 'context': wf_ex_db.context, - 'state': wf_ex_db.state, - 'output': wf_ex_db.output, - 'errors': wf_ex_db.errors + "spec": wf_ex_db.spec, + "graph": wf_ex_db.graph, + "input": wf_ex_db.input, + "context": wf_ex_db.context, + "state": wf_ex_db.state, + "output": wf_ex_db.output, + "errors": wf_ex_db.errors, } return conducting.WorkflowConductor.deserialize(data) @@ -948,18 +1004,22 @@ def refresh_conductor(wf_ex_id): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) -def update_task_state(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=None, publish=True): + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) +def update_task_state( + task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=None, publish=True +): # Return if action execution status is not in the list of statuses to process. - statuses_to_process = ( - copy.copy(ac_const.LIVEACTION_COMPLETED_STATES) + - [ac_const.LIVEACTION_STATUS_PAUSED, ac_const.LIVEACTION_STATUS_PENDING] - ) + statuses_to_process = copy.copy(ac_const.LIVEACTION_COMPLETED_STATES) + [ + ac_const.LIVEACTION_STATUS_PAUSED, + ac_const.LIVEACTION_STATUS_PENDING, + ] if ac_ex_status not in statuses_to_process: return @@ -973,22 +1033,21 @@ def update_task_state(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=Non msg = msg % (task_ex_db.task_id, str(task_ex_db.task_route), task_ex_db.status) update_progress(wf_ex_db, msg, stream=False) - if not ac_ex_ctx or 'item_id' not in ac_ex_ctx or ac_ex_ctx['item_id'] < 0: + if not ac_ex_ctx or "item_id" not in ac_ex_ctx or ac_ex_ctx["item_id"] < 0: ac_ex_event = events.ActionExecutionEvent(ac_ex_status, result=ac_ex_result) else: accumulated_result = [ - item.get('result') if item else None - for item in task_ex_db.result['items'] + item.get("result") if item else None for item in task_ex_db.result["items"] ] ac_ex_event = events.TaskItemActionExecutionEvent( - ac_ex_ctx['item_id'], + ac_ex_ctx["item_id"], ac_ex_status, result=ac_ex_result, - accumulated_result=accumulated_result + accumulated_result=accumulated_result, ) - update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False) + update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False) conductor.update_task_state(task_ex_db.task_id, task_ex_db.task_route, ac_ex_event) # Update workflow execution and related liveaction and action execution. @@ -997,19 +1056,21 @@ def update_task_state(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=Non conductor, update_lv_ac_on_statuses=statuses_to_process, pub_lv_ac=publish, - pub_ac_ex=publish + pub_ac_ex=publish, ) @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def request_next_tasks(wf_ex_db, task_ex_id=None): iteration = 0 @@ -1018,7 +1079,9 @@ def request_next_tasks(wf_ex_db, task_ex_id=None): # If workflow is in requested status, set it to running. if conductor.get_workflow_status() in [statuses.REQUESTED, statuses.SCHEDULED]: - update_progress(wf_ex_db, 'Requesting conductor to start running workflow execution.') + update_progress( + wf_ex_db, "Requesting conductor to start running workflow execution." + ) conductor.request_workflow_status(statuses.RUNNING) # Identify the list of next set of tasks. Don't pass the task id to the conductor @@ -1028,93 +1091,104 @@ def request_next_tasks(wf_ex_db, task_ex_id=None): msg = 'Identifying next set (iter %s) of tasks after completion of task "%s", route "%s".' msg = msg % (str(iteration), task_ex_db.task_id, str(task_ex_db.task_route)) update_progress(wf_ex_db, msg) - update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False) + update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False) next_tasks = conductor.get_next_tasks() else: msg = 'Identifying next set (iter %s) of tasks for workflow execution in status "%s".' msg = msg % (str(iteration), conductor.get_workflow_status()) update_progress(wf_ex_db, msg) - update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False) + update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False) next_tasks = conductor.get_next_tasks() # If there is no new tasks, update execution records to handle possible completion. if not next_tasks: # Update workflow execution and related liveaction and action execution. - update_progress(wf_ex_db, 'No tasks identified to execute next.') - update_progress(wf_ex_db, '\n', log=False) + update_progress(wf_ex_db, "No tasks identified to execute next.") + update_progress(wf_ex_db, "\n", log=False) update_execution_records(wf_ex_db, conductor) if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES: msg = 'The workflow execution is completed with status "%s".' update_progress(wf_ex_db, msg % conductor.get_workflow_status()) - update_progress(wf_ex_db, '\n', log=False) + update_progress(wf_ex_db, "\n", log=False) # Iterate while there are next tasks identified for processing. In the case for # task with no action execution defined, the task execution will complete # immediately with a new set of tasks available. while next_tasks: - msg = 'Identified the following set of tasks to execute next: %s' - tasks_list = ', '.join(["%s (route %s)" % (t['id'], str(t['route'])) for t in next_tasks]) + msg = "Identified the following set of tasks to execute next: %s" + tasks_list = ", ".join( + ["%s (route %s)" % (t["id"], str(t["route"])) for t in next_tasks] + ) update_progress(wf_ex_db, msg % tasks_list) # Mark the tasks as running in the task flow before actual task execution. for task in next_tasks: msg = 'Mark task "%s", route "%s", in conductor as running.' - update_progress(wf_ex_db, msg % (task['id'], str(task['route'])), stream=False) + update_progress( + wf_ex_db, msg % (task["id"], str(task["route"])), stream=False + ) # If task has items and items list is empty, then actions under the next task is empty # and will not be processed in the for loop below. Handle this use case separately and # mark it as running in the conductor. The task will be completed automatically when # it is requested for task execution. - if task['spec'].has_items() and 'items_count' in task and task['items_count'] == 0: + if ( + task["spec"].has_items() + and "items_count" in task + and task["items_count"] == 0 + ): ac_ex_event = events.ActionExecutionEvent(statuses.RUNNING) - conductor.update_task_state(task['id'], task['route'], ac_ex_event) + conductor.update_task_state(task["id"], task["route"], ac_ex_event) # If task contains multiple action execution (i.e. with items), # then mark each item individually. - for action in task['actions']: - if 'item_id' not in action or action['item_id'] is None: + for action in task["actions"]: + if "item_id" not in action or action["item_id"] is None: ac_ex_event = events.ActionExecutionEvent(statuses.RUNNING) else: - msg = 'Mark task "%s", route "%s", item "%s" in conductor as running.' - msg = msg % (task['id'], str(task['route']), action['item_id']) + msg = ( + 'Mark task "%s", route "%s", item "%s" in conductor as running.' + ) + msg = msg % (task["id"], str(task["route"]), action["item_id"]) update_progress(wf_ex_db, msg) ac_ex_event = events.TaskItemActionExecutionEvent( - action['item_id'], - statuses.RUNNING + action["item_id"], statuses.RUNNING ) - conductor.update_task_state(task['id'], task['route'], ac_ex_event) + conductor.update_task_state(task["id"], task["route"], ac_ex_event) # Update workflow execution and related liveaction and action execution. - update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False) + update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False) update_execution_records(wf_ex_db, conductor) # Request task execution for the tasks. for task in next_tasks: try: msg = 'Requesting execution for task "%s", route "%s".' - update_progress(wf_ex_db, msg % (task['id'], str(task['route']))) + update_progress(wf_ex_db, msg % (task["id"], str(task["route"]))) # Pass down appropriate st2 context to the task and action execution(s). - root_st2_ctx = wf_ex_db.context.get('st2', {}) + root_st2_ctx = wf_ex_db.context.get("st2", {}) st2_ctx = { - 'execution_id': wf_ex_db.action_execution, - 'user': root_st2_ctx.get('user'), - 'pack': root_st2_ctx.get('pack') + "execution_id": wf_ex_db.action_execution, + "user": root_st2_ctx.get("user"), + "pack": root_st2_ctx.get("pack"), } - if root_st2_ctx.get('api_user'): - st2_ctx['api_user'] = root_st2_ctx.get('api_user') + if root_st2_ctx.get("api_user"): + st2_ctx["api_user"] = root_st2_ctx.get("api_user") - if root_st2_ctx.get('source_channel'): - st2_ctx['source_channel'] = root_st2_ctx.get('source_channel') + if root_st2_ctx.get("source_channel"): + st2_ctx["source_channel"] = root_st2_ctx.get("source_channel") # Request the task execution. request_task_execution(wf_ex_db, st2_ctx, task) except Exception as e: msg = 'Failed task execution for task "%s", route "%s".' - msg = msg % (task['id'], str(task['route'])) - update_progress(wf_ex_db, '%s %s' % (msg, str(e)), severity='error', log=False) + msg = msg % (task["id"], str(task["route"])) + update_progress( + wf_ex_db, "%s %s" % (msg, str(e)), severity="error", log=False + ) LOG.exception(msg) fail_workflow_execution(str(wf_ex_db.id), e, task=task) return @@ -1125,25 +1199,30 @@ def request_next_tasks(wf_ex_db, task_ex_id=None): msg = 'Identifying next set (iter %s) of tasks for workflow execution in status "%s".' msg = msg % (str(iteration), conductor.get_workflow_status()) update_progress(wf_ex_db, msg) - update_progress(wf_ex_db, conductor.serialize(), severity='debug', stream=False) + update_progress(wf_ex_db, conductor.serialize(), severity="debug", stream=False) next_tasks = conductor.get_next_tasks() if not next_tasks: - update_progress(wf_ex_db, 'No tasks identified to execute next.') - update_progress(wf_ex_db, '\n', log=False) + update_progress(wf_ex_db, "No tasks identified to execute next.") + update_progress(wf_ex_db, "\n", log=False) @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx=None): - if ac_ex_status not in statuses.COMPLETED_STATUSES + [statuses.PAUSED, statuses.PENDING]: + if ac_ex_status not in statuses.COMPLETED_STATUSES + [ + statuses.PAUSED, + statuses.PENDING, + ]: return task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id) @@ -1153,31 +1232,43 @@ def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx if not task_ex_db.itemized or (task_ex_db.itemized and task_ex_db.items_count == 0): if ac_ex_status != task_ex_db.status: msg = 'Updating task execution "%s" for task "%s" from status "%s" to "%s".' - msg = msg % (task_ex_id, task_ex_db.task_id, task_ex_db.status, ac_ex_status) + msg = msg % ( + task_ex_id, + task_ex_db.task_id, + task_ex_db.status, + ac_ex_status, + ) update_progress(wf_ex_db, msg) task_ex_db.status = ac_ex_status task_ex_db.result = ac_ex_result if ac_ex_result else task_ex_db.result elif task_ex_db.itemized and ac_ex_ctx: - if 'orquesta' not in ac_ex_ctx or 'item_id' not in ac_ex_ctx['orquesta']: - msg = 'Context information for the item is not provided. %s' % str(ac_ex_ctx) - update_progress(wf_ex_db, msg, severity='error', log=False) + if "orquesta" not in ac_ex_ctx or "item_id" not in ac_ex_ctx["orquesta"]: + msg = "Context information for the item is not provided. %s" % str( + ac_ex_ctx + ) + update_progress(wf_ex_db, msg, severity="error", log=False) raise Exception(msg) - item_id = ac_ex_ctx['orquesta']['item_id'] + item_id = ac_ex_ctx["orquesta"]["item_id"] msg = 'Processing action execution for task "%s", route "%s", item "%s".' msg = msg % (task_ex_db.task_id, str(task_ex_db.task_route), item_id) - update_progress(wf_ex_db, msg, severity='debug') + update_progress(wf_ex_db, msg, severity="debug") - task_ex_db.result['items'][item_id] = {'status': ac_ex_status, 'result': ac_ex_result} + task_ex_db.result["items"][item_id] = { + "status": ac_ex_status, + "result": ac_ex_result, + } item_statuses = [ - item.get('status', statuses.UNSET) if item else statuses.UNSET - for item in task_ex_db.result['items'] + item.get("status", statuses.UNSET) if item else statuses.UNSET + for item in task_ex_db.result["items"] ] - task_completed = all([status in statuses.COMPLETED_STATUSES for status in item_statuses]) + task_completed = all( + [status in statuses.COMPLETED_STATUSES for status in item_statuses] + ) if task_completed: new_task_status = ( @@ -1187,11 +1278,15 @@ def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx ) msg = 'Updating task execution from status "%s" to "%s".' - update_progress(wf_ex_db, msg % (task_ex_db.status, new_task_status), severity='debug') + update_progress( + wf_ex_db, msg % (task_ex_db.status, new_task_status), severity="debug" + ) task_ex_db.status = new_task_status else: - msg = 'Task execution is not complete because not all items are complete: %s' - update_progress(wf_ex_db, msg % ', '.join(item_statuses), severity='debug') + msg = ( + "Task execution is not complete because not all items are complete: %s" + ) + update_progress(wf_ex_db, msg % ", ".join(item_statuses), severity="debug") if task_ex_db.status in statuses.COMPLETED_STATUSES: task_ex_db.end_timestamp = date_utils.get_datetime_utc_now() @@ -1202,19 +1297,23 @@ def update_task_execution(task_ex_id, ac_ex_status, ac_ex_result=None, ac_ex_ctx @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def resume_task_execution(task_ex_id): # Update task execution to running. task_ex_db = wf_db_access.TaskExecution.get_by_id(task_ex_id) wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(task_ex_db.workflow_execution) msg = 'Updating task execution from status "%s" to "%s".' - update_progress(wf_ex_db, msg % (task_ex_db.status, statuses.RUNNING), severity='debug') + update_progress( + wf_ex_db, msg % (task_ex_db.status, statuses.RUNNING), severity="debug" + ) task_ex_db.status = statuses.RUNNING # Write update to the database. @@ -1224,17 +1323,21 @@ def resume_task_execution(task_ex_id): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def update_workflow_execution(wf_ex_id): conductor, wf_ex_db = refresh_conductor(wf_ex_id) # There is nothing to update if workflow execution is not completed or paused. - if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES + [statuses.PAUSED]: + if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES + [ + statuses.PAUSED + ]: # Update workflow execution and related liveaction and action execution. update_execution_records(wf_ex_db, conductor) @@ -1242,12 +1345,14 @@ def update_workflow_execution(wf_ex_id): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def resume_workflow_execution(wf_ex_id, task_ex_id): # Update workflow execution to running. conductor, wf_ex_db = refresh_conductor(wf_ex_id) @@ -1265,12 +1370,14 @@ def resume_workflow_execution(wf_ex_id, task_ex_id): @retrying.retry( retry_on_exception=wf_exc.retry_on_transient_db_errors, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) @retrying.retry( retry_on_exception=wf_exc.retry_on_connection_errors, stop_max_delay=cfg.CONF.workflow_engine.retry_stop_max_msec, wait_fixed=cfg.CONF.workflow_engine.retry_wait_fixed_msec, - wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec) + wait_jitter_max=cfg.CONF.workflow_engine.retry_max_jitter_msec, +) def fail_workflow_execution(wf_ex_id, exception, task=None): conductor, wf_ex_db = refresh_conductor(wf_ex_id) @@ -1278,7 +1385,7 @@ def fail_workflow_execution(wf_ex_id, exception, task=None): conductor.request_workflow_status(statuses.FAILED) if task is not None and isinstance(task, dict): - conductor.log_error(exception, task_id=task.get('id'), route=task.get('route')) + conductor.log_error(exception, task_id=task.get("id"), route=task.get("route")) else: conductor.log_error(exception) @@ -1286,8 +1393,14 @@ def fail_workflow_execution(wf_ex_id, exception, task=None): update_execution_records(wf_ex_db, conductor) -def update_execution_records(wf_ex_db, conductor, update_lv_ac_on_statuses=None, - pub_wf_ex=False, pub_lv_ac=True, pub_ac_ex=True): +def update_execution_records( + wf_ex_db, + conductor, + update_lv_ac_on_statuses=None, + pub_wf_ex=False, + pub_lv_ac=True, + pub_ac_ex=True, +): # If the workflow execution is completed, then render the workflow output. if conductor.get_workflow_status() in statuses.COMPLETED_STATUSES: conductor.render_workflow_output() @@ -1295,7 +1408,7 @@ def update_execution_records(wf_ex_db, conductor, update_lv_ac_on_statuses=None, # Determine if workflow status has changed. wf_old_status = wf_ex_db.status wf_ex_db.status = conductor.get_workflow_status() - status_changed = (wf_old_status != wf_ex_db.status) + status_changed = wf_old_status != wf_ex_db.status if status_changed: msg = 'Updating workflow execution from status "%s" to "%s".' @@ -1314,53 +1427,58 @@ def update_execution_records(wf_ex_db, conductor, update_lv_ac_on_statuses=None, wf_ex_db = wf_db_access.WorkflowExecution.update(wf_ex_db, publish=pub_wf_ex) # Return if workflow execution status is not specified in update_lv_ac_on_statuses. - if (isinstance(update_lv_ac_on_statuses, list) and - wf_ex_db.status not in update_lv_ac_on_statuses): + if ( + isinstance(update_lv_ac_on_statuses, list) + and wf_ex_db.status not in update_lv_ac_on_statuses + ): return # Update the corresponding liveaction and action execution for the workflow. wf_ac_ex_db = ex_db_access.ActionExecution.get_by_id(wf_ex_db.action_execution) - wf_lv_ac_db = action_utils.get_liveaction_by_id(wf_ac_ex_db.liveaction['id']) + wf_lv_ac_db = action_utils.get_liveaction_by_id(wf_ac_ex_db.liveaction["id"]) # Gather result for liveaction and action execution. - result = {'output': wf_ex_db.output or None} + result = {"output": wf_ex_db.output or None} if wf_ex_db.status in statuses.ABENDED_STATUSES: - result['errors'] = wf_ex_db.errors + result["errors"] = wf_ex_db.errors if wf_ex_db.errors: - msg = 'Workflow execution completed with errors.' - update_progress(wf_ex_db, msg, severity='error') + msg = "Workflow execution completed with errors." + update_progress(wf_ex_db, msg, severity="error") for wf_ex_error in wf_ex_db.errors: - update_progress(wf_ex_db, wf_ex_error, severity='error') + update_progress(wf_ex_db, wf_ex_error, severity="error") # Sync update with corresponding liveaction and action execution. if pub_lv_ac or pub_ac_ex: - pub_lv_ac = (wf_lv_ac_db.status != wf_ex_db.status) + pub_lv_ac = wf_lv_ac_db.status != wf_ex_db.status pub_ac_ex = pub_lv_ac if wf_lv_ac_db.status != wf_ex_db.status: - kwargs = {'severity': 'debug', 'stream': False} + kwargs = {"severity": "debug", "stream": False} msg = 'Updating workflow liveaction from status "%s" to "%s".' update_progress(wf_ex_db, msg % (wf_lv_ac_db.status, wf_ex_db.status), **kwargs) - msg = 'Workflow liveaction status change %s be published.' - update_progress(wf_ex_db, msg % 'will' if pub_lv_ac else 'will not', **kwargs) - msg = 'Workflow action execution status change %s be published.' - update_progress(wf_ex_db, msg % 'will' if pub_ac_ex else 'will not', **kwargs) + msg = "Workflow liveaction status change %s be published." + update_progress(wf_ex_db, msg % "will" if pub_lv_ac else "will not", **kwargs) + msg = "Workflow action execution status change %s be published." + update_progress(wf_ex_db, msg % "will" if pub_ac_ex else "will not", **kwargs) wf_lv_ac_db = action_utils.update_liveaction_status( status=wf_ex_db.status, result=result, end_timestamp=wf_ex_db.end_timestamp, liveaction_db=wf_lv_ac_db, - publish=pub_lv_ac) + publish=pub_lv_ac, + ) ex_svc.update_execution(wf_lv_ac_db, publish=pub_ac_ex) # Invoke post run on the liveaction for the workflow execution. if status_changed and wf_lv_ac_db.status in ac_const.LIVEACTION_COMPLETED_STATES: - update_progress(wf_ex_db, 'Workflow action execution is completed and invoking post run.') + update_progress( + wf_ex_db, "Workflow action execution is completed and invoking post run." + ) runners_utils.invoke_post_run(wf_lv_ac_db) @@ -1376,36 +1494,40 @@ def identify_orphaned_workflows(): # does not necessary means it is the max idle time. The use of workflow_executions_idled_ttl # to filter is to reduce the number of action executions that need to be evaluated. query_filters = { - 'runner__name': 'orquesta', - 'status': ac_const.LIVEACTION_STATUS_RUNNING, - 'start_timestamp__lte': expiry_dt + "runner__name": "orquesta", + "status": ac_const.LIVEACTION_STATUS_RUNNING, + "start_timestamp__lte": expiry_dt, } ac_ex_dbs = ex_db_access.ActionExecution.query(**query_filters) for ac_ex_db in ac_ex_dbs: # Figure out the runtime for the action execution. status_change_logs = sorted( - [log for log in ac_ex_db.log if log['status'] == ac_const.LIVEACTION_STATUS_RUNNING], - key=lambda x: x['timestamp'], - reverse=True + [ + log + for log in ac_ex_db.log + if log["status"] == ac_const.LIVEACTION_STATUS_RUNNING + ], + key=lambda x: x["timestamp"], + reverse=True, ) if len(status_change_logs) <= 0: continue - runtime = (utc_now_dt - status_change_logs[0]['timestamp']).total_seconds() + runtime = (utc_now_dt - status_change_logs[0]["timestamp"]).total_seconds() # Fetch the task executions for the workflow execution. # Ensure that the root action execution is not being selected. - wf_ex_id = ac_ex_db.context['workflow_execution'] + wf_ex_id = ac_ex_db.context["workflow_execution"] wf_ex_db = wf_db_access.WorkflowExecution.get_by_id(wf_ex_id) - query_filters = {'workflow_execution': wf_ex_id, 'id__ne': ac_ex_db.id} + query_filters = {"workflow_execution": wf_ex_id, "id__ne": ac_ex_db.id} tk_ac_ex_dbs = ex_db_access.ActionExecution.query(**query_filters) # The workflow execution is orphaned if there are # no task executions and runtime passed expiry. if len(tk_ac_ex_dbs) <= 0 and runtime > gc_max_idle: - msg = 'The action execution is orphaned and will be canceled by the garbage collector.' + msg = "The action execution is orphaned and will be canceled by the garbage collector." update_progress(wf_ex_db, msg) orphaned.append(ac_ex_db) continue @@ -1415,7 +1537,8 @@ def identify_orphaned_workflows(): has_active_tasks = len([t for t in tk_ac_ex_dbs if t.end_timestamp is None]) > 0 completed_tasks = [ - t for t in tk_ac_ex_dbs + t + for t in tk_ac_ex_dbs if t.end_timestamp is not None and t.end_timestamp <= expiry_dt ] @@ -1423,11 +1546,16 @@ def identify_orphaned_workflows(): most_recent_completed_task_expired = ( completed_tasks[-1].end_timestamp <= expiry_dt - if len(completed_tasks) > 0 else False + if len(completed_tasks) > 0 + else False ) - if len(tk_ac_ex_dbs) > 0 and not has_active_tasks and most_recent_completed_task_expired: - msg = 'The action execution is orphaned and will be canceled by the garbage collector.' + if ( + len(tk_ac_ex_dbs) > 0 + and not has_active_tasks + and most_recent_completed_task_expired + ): + msg = "The action execution is orphaned and will be canceled by the garbage collector." update_progress(wf_ex_db, msg) orphaned.append(ac_ex_db) continue diff --git a/st2common/st2common/signal_handlers.py b/st2common/st2common/signal_handlers.py index 0fc2766175f..bd785403f4e 100644 --- a/st2common/st2common/signal_handlers.py +++ b/st2common/st2common/signal_handlers.py @@ -26,7 +26,7 @@ from st2common.logging.misc import reopen_log_files __all__ = [ - 'register_common_signal_handlers', + "register_common_signal_handlers", ] diff --git a/st2common/st2common/stream/listener.py b/st2common/st2common/stream/listener.py index 6edbef1750f..347c4cfc754 100644 --- a/st2common/st2common/stream/listener.py +++ b/st2common/st2common/stream/listener.py @@ -33,11 +33,10 @@ from st2common import log as logging __all__ = [ - 'StreamListener', - 'ExecutionOutputListener', - - 'get_listener', - 'get_listener_if_set' + "StreamListener", + "ExecutionOutputListener", + "get_listener", + "get_listener_if_set", ] LOG = logging.getLogger(__name__) @@ -49,23 +48,24 @@ class BaseListener(ConsumerMixin): - def __init__(self, connection): self.connection = connection self.queues = [] self._stopped = False def get_consumers(self, consumer, channel): - raise NotImplementedError('get_consumers() is not implemented') + raise NotImplementedError("get_consumers() is not implemented") def processor(self, model=None): def process(body, message): meta = message.delivery_info - event_name = '%s__%s' % (meta.get('exchange'), meta.get('routing_key')) + event_name = "%s__%s" % (meta.get("exchange"), meta.get("routing_key")) try: if model: - body = model.from_model(body, mask_secrets=cfg.CONF.api.mask_secrets) + body = model.from_model( + body, mask_secrets=cfg.CONF.api.mask_secrets + ) self.emit(event_name, body) finally: @@ -78,10 +78,17 @@ def emit(self, event, body): for queue in self.queues: queue.put(pack) - def generator(self, events=None, action_refs=None, execution_ids=None, - end_event=None, end_statuses=None, end_execution_id=None): + def generator( + self, + events=None, + action_refs=None, + execution_ids=None, + end_event=None, + end_statuses=None, + end_execution_id=None, + ): queue = eventlet.Queue() - queue.put('') + queue.put("") self.queues.append(queue) try: stop = False @@ -95,16 +102,19 @@ def generator(self, events=None, action_refs=None, execution_ids=None, event_name, body = message # check to see if this is the last message to send. if event_name == end_event: - if body is not None and \ - body.status in end_statuses and \ - end_execution_id is not None and \ - body.id == end_execution_id: + if ( + body is not None + and body.status in end_statuses + and end_execution_id is not None + and body.id == end_execution_id + ): stop = True # TODO: We now do late filtering, but this could also be performed on the # message bus level if we modified our exchange layout and utilize routing keys # Filter on event name - include_event = self._should_include_event(event_names_whitelist=events, - event_name=event_name) + include_event = self._should_include_event( + event_names_whitelist=events, event_name=event_name + ) if not include_event: LOG.debug('Skipping event "%s"' % (event_name)) continue @@ -112,14 +122,18 @@ def generator(self, events=None, action_refs=None, execution_ids=None, # Filter on action ref action_ref = self._get_action_ref_for_body(body=body) if action_refs and action_ref not in action_refs: - LOG.debug('Skipping event "%s" with action_ref "%s"' % (event_name, - action_ref)) + LOG.debug( + 'Skipping event "%s" with action_ref "%s"' + % (event_name, action_ref) + ) continue # Filter on execution id execution_id = self._get_execution_id_for_body(body=body) if execution_ids and execution_id not in execution_ids: - LOG.debug('Skipping event "%s" with execution_id "%s"' % (event_name, - execution_id)) + LOG.debug( + 'Skipping event "%s" with execution_id "%s"' + % (event_name, execution_id) + ) continue yield message @@ -154,7 +168,7 @@ def _get_action_ref_for_body(self, body): action_ref = None if isinstance(body, ActionExecutionAPI): - action_ref = body.action.get('ref', None) if body.action else None + action_ref = body.action.get("ref", None) if body.action else None elif isinstance(body, LiveActionAPI): action_ref = body.action elif isinstance(body, (ActionExecutionOutputAPI)): @@ -187,21 +201,26 @@ class StreamListener(BaseListener): def get_consumers(self, consumer, channel): return [ - consumer(queues=[STREAM_ANNOUNCEMENT_WORK_QUEUE], - accept=['pickle'], - callbacks=[self.processor()]), - - consumer(queues=[STREAM_EXECUTION_ALL_WORK_QUEUE], - accept=['pickle'], - callbacks=[self.processor(ActionExecutionAPI)]), - - consumer(queues=[STREAM_LIVEACTION_WORK_QUEUE], - accept=['pickle'], - callbacks=[self.processor(LiveActionAPI)]), - - consumer(queues=[STREAM_EXECUTION_OUTPUT_QUEUE], - accept=['pickle'], - callbacks=[self.processor(ActionExecutionOutputAPI)]) + consumer( + queues=[STREAM_ANNOUNCEMENT_WORK_QUEUE], + accept=["pickle"], + callbacks=[self.processor()], + ), + consumer( + queues=[STREAM_EXECUTION_ALL_WORK_QUEUE], + accept=["pickle"], + callbacks=[self.processor(ActionExecutionAPI)], + ), + consumer( + queues=[STREAM_LIVEACTION_WORK_QUEUE], + accept=["pickle"], + callbacks=[self.processor(LiveActionAPI)], + ), + consumer( + queues=[STREAM_EXECUTION_OUTPUT_QUEUE], + accept=["pickle"], + callbacks=[self.processor(ActionExecutionOutputAPI)], + ), ] @@ -214,13 +233,16 @@ class ExecutionOutputListener(BaseListener): def get_consumers(self, consumer, channel): return [ - consumer(queues=[STREAM_EXECUTION_UPDATE_WORK_QUEUE], - accept=['pickle'], - callbacks=[self.processor(ActionExecutionAPI)]), - - consumer(queues=[STREAM_EXECUTION_OUTPUT_QUEUE], - accept=['pickle'], - callbacks=[self.processor(ActionExecutionOutputAPI)]) + consumer( + queues=[STREAM_EXECUTION_UPDATE_WORK_QUEUE], + accept=["pickle"], + callbacks=[self.processor(ActionExecutionAPI)], + ), + consumer( + queues=[STREAM_EXECUTION_OUTPUT_QUEUE], + accept=["pickle"], + callbacks=[self.processor(ActionExecutionOutputAPI)], + ), ] @@ -235,29 +257,29 @@ def get_listener(name): global _stream_listener global _execution_output_listener - if name == 'stream': + if name == "stream": if not _stream_listener: with transport_utils.get_connection() as conn: _stream_listener = StreamListener(conn) eventlet.spawn_n(listen, _stream_listener) return _stream_listener - elif name == 'execution_output': + elif name == "execution_output": if not _execution_output_listener: with transport_utils.get_connection() as conn: _execution_output_listener = ExecutionOutputListener(conn) eventlet.spawn_n(listen, _execution_output_listener) return _execution_output_listener else: - raise ValueError('Invalid listener name: %s' % (name)) + raise ValueError("Invalid listener name: %s" % (name)) def get_listener_if_set(name): global _stream_listener global _execution_output_listener - if name == 'stream': + if name == "stream": return _stream_listener - elif name == 'execution_output': + elif name == "execution_output": return _execution_output_listener else: - raise ValueError('Invalid listener name: %s' % (name)) + raise ValueError("Invalid listener name: %s" % (name)) diff --git a/st2common/st2common/transport/__init__.py b/st2common/st2common/transport/__init__.py index cc384c878ef..632c08dc0ec 100644 --- a/st2common/st2common/transport/__init__.py +++ b/st2common/st2common/transport/__init__.py @@ -21,12 +21,12 @@ # TODO(manas) : Exchanges, Queues and RoutingKey design discussion pending. __all__ = [ - 'liveaction', - 'actionexecutionstate', - 'execution', - 'workflow', - 'publishers', - 'reactor', - 'utils', - 'connection_retry_wrapper' + "liveaction", + "actionexecutionstate", + "execution", + "workflow", + "publishers", + "reactor", + "utils", + "connection_retry_wrapper", ] diff --git a/st2common/st2common/transport/actionexecutionstate.py b/st2common/st2common/transport/actionexecutionstate.py index 268bffe0fc7..46fe095fbf9 100644 --- a/st2common/st2common/transport/actionexecutionstate.py +++ b/st2common/st2common/transport/actionexecutionstate.py @@ -21,18 +21,16 @@ from st2common.transport import publishers -__all__ = [ - 'ActionExecutionStatePublisher' -] +__all__ = ["ActionExecutionStatePublisher"] -ACTIONEXECUTIONSTATE_XCHG = Exchange('st2.actionexecutionstate', - type='topic') +ACTIONEXECUTIONSTATE_XCHG = Exchange("st2.actionexecutionstate", type="topic") class ActionExecutionStatePublisher(publishers.CUDPublisher): - def __init__(self): - super(ActionExecutionStatePublisher, self).__init__(exchange=ACTIONEXECUTIONSTATE_XCHG) + super(ActionExecutionStatePublisher, self).__init__( + exchange=ACTIONEXECUTIONSTATE_XCHG + ) def get_queue(name, routing_key): diff --git a/st2common/st2common/transport/announcement.py b/st2common/st2common/transport/announcement.py index 84c8bf27a75..e79506c608d 100644 --- a/st2common/st2common/transport/announcement.py +++ b/st2common/st2common/transport/announcement.py @@ -22,17 +22,12 @@ from st2common.models.api.trace import TraceContext from st2common.transport import publishers -__all__ = [ - 'AnnouncementPublisher', - 'AnnouncementDispatcher', - - 'get_queue' -] +__all__ = ["AnnouncementPublisher", "AnnouncementDispatcher", "get_queue"] LOG = logging.getLogger(__name__) # Exchange for Announcements -ANNOUNCEMENT_XCHG = Exchange('st2.announcement', type='topic') +ANNOUNCEMENT_XCHG = Exchange("st2.announcement", type="topic") class AnnouncementPublisher(object): @@ -68,16 +63,19 @@ def dispatch(self, routing_key, payload, trace_context=None): assert isinstance(payload, (type(None), dict)) assert isinstance(trace_context, (type(None), dict, TraceContext)) - payload = { - 'payload': payload, - TRACE_CONTEXT: trace_context - } + payload = {"payload": payload, TRACE_CONTEXT: trace_context} - self._logger.debug('Dispatching announcement (routing_key=%s,payload=%s)', - routing_key, payload) + self._logger.debug( + "Dispatching announcement (routing_key=%s,payload=%s)", routing_key, payload + ) self._publisher.publish(payload=payload, routing_key=routing_key) -def get_queue(name=None, routing_key='#', exclusive=False, auto_delete=False): - return Queue(name, ANNOUNCEMENT_XCHG, routing_key=routing_key, exclusive=exclusive, - auto_delete=auto_delete) +def get_queue(name=None, routing_key="#", exclusive=False, auto_delete=False): + return Queue( + name, + ANNOUNCEMENT_XCHG, + routing_key=routing_key, + exclusive=exclusive, + auto_delete=auto_delete, + ) diff --git a/st2common/st2common/transport/bootstrap.py b/st2common/st2common/transport/bootstrap.py index 4c75072fe9f..20d9277fae0 100644 --- a/st2common/st2common/transport/bootstrap.py +++ b/st2common/st2common/transport/bootstrap.py @@ -24,8 +24,9 @@ def _setup(): config.parse_args() # 2. setup logging. - logging.basicConfig(format='%(asctime)s %(levelname)s [-] %(message)s', - level=logging.DEBUG) + logging.basicConfig( + format="%(asctime)s %(levelname)s [-] %(message)s", level=logging.DEBUG + ) def main(): @@ -34,5 +35,5 @@ def main(): # The scripts sets up Exchanges in RabbitMQ. -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/st2common/st2common/transport/bootstrap_utils.py b/st2common/st2common/transport/bootstrap_utils.py index d787adc4938..2eea9ad64b1 100644 --- a/st2common/st2common/transport/bootstrap_utils.py +++ b/st2common/st2common/transport/bootstrap_utils.py @@ -50,15 +50,14 @@ from st2common.transport.queues import WORKFLOW_EXECUTION_WORK_QUEUE from st2common.transport.queues import WORKFLOW_EXECUTION_RESUME_QUEUE -LOG = logging.getLogger('st2common.transport.bootstrap') +LOG = logging.getLogger("st2common.transport.bootstrap") __all__ = [ - 'register_exchanges', - 'register_exchanges_with_retry', - 'register_kombu_serializers', - - 'EXCHANGES', - 'QUEUES' + "register_exchanges", + "register_exchanges_with_retry", + "register_kombu_serializers", + "EXCHANGES", + "QUEUES", ] # List of exchanges which are pre-declared on service set up. @@ -72,7 +71,7 @@ TRIGGER_INSTANCE_XCHG, SENSOR_CUD_XCHG, WORKFLOW_EXECUTION_XCHG, - WORKFLOW_EXECUTION_STATUS_MGMT_XCHG + WORKFLOW_EXECUTION_STATUS_MGMT_XCHG, ] # List of queues which are pre-declared on service startup. @@ -85,41 +84,40 @@ NOTIFIER_ACTIONUPDATE_WORK_QUEUE, RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE, RULESENGINE_WORK_QUEUE, - STREAM_ANNOUNCEMENT_WORK_QUEUE, STREAM_EXECUTION_ALL_WORK_QUEUE, STREAM_LIVEACTION_WORK_QUEUE, STREAM_EXECUTION_OUTPUT_QUEUE, - WORKFLOW_EXECUTION_WORK_QUEUE, WORKFLOW_EXECUTION_RESUME_QUEUE, - # Those queues are dynamically / late created on some class init but we still need to # pre-declare them for redis Kombu backend to work. - reactor.get_trigger_cud_queue(name='st2.preinit', routing_key='init'), - reactor.get_sensor_cud_queue(name='st2.preinit', routing_key='init') + reactor.get_trigger_cud_queue(name="st2.preinit", routing_key="init"), + reactor.get_sensor_cud_queue(name="st2.preinit", routing_key="init"), ] def _do_register_exchange(exchange, connection, channel, retry_wrapper): try: kwargs = { - 'exchange': exchange.name, - 'type': exchange.type, - 'durable': exchange.durable, - 'auto_delete': exchange.auto_delete, - 'arguments': exchange.arguments, - 'nowait': False, - 'passive': False + "exchange": exchange.name, + "type": exchange.type, + "durable": exchange.durable, + "auto_delete": exchange.auto_delete, + "arguments": exchange.arguments, + "nowait": False, + "passive": False, } # Use the retry wrapper to increase resiliency in recoverable errors. - retry_wrapper.ensured(connection=connection, - obj=channel, - to_ensure_func=channel.exchange_declare, - **kwargs) - LOG.debug('Registered exchange %s (%s).' % (exchange.name, str(kwargs))) + retry_wrapper.ensured( + connection=connection, + obj=channel, + to_ensure_func=channel.exchange_declare, + **kwargs, + ) + LOG.debug("Registered exchange %s (%s)." % (exchange.name, str(kwargs))) except Exception: - LOG.exception('Failed to register exchange: %s.', exchange.name) + LOG.exception("Failed to register exchange: %s.", exchange.name) def _do_predeclare_queue(channel, queue): @@ -132,23 +130,31 @@ def _do_predeclare_queue(channel, queue): bound_queue.declare(nowait=False) LOG.debug('Predeclared queue for exchange "%s"' % (queue.exchange.name)) except Exception: - LOG.exception('Failed to predeclare queue for exchange "%s"' % (queue.exchange.name)) + LOG.exception( + 'Failed to predeclare queue for exchange "%s"' % (queue.exchange.name) + ) return bound_queue def register_exchanges(): - LOG.debug('Registering exchanges...') + LOG.debug("Registering exchanges...") connection_urls = transport_utils.get_messaging_urls() with transport_utils.get_connection() as conn: # Use ConnectionRetryWrapper to deal with rmq clustering etc. - retry_wrapper = ConnectionRetryWrapper(cluster_size=len(connection_urls), logger=LOG) + retry_wrapper = ConnectionRetryWrapper( + cluster_size=len(connection_urls), logger=LOG + ) def wrapped_register_exchanges(connection, channel): for exchange in EXCHANGES: - _do_register_exchange(exchange=exchange, connection=connection, channel=channel, - retry_wrapper=retry_wrapper) + _do_register_exchange( + exchange=exchange, + connection=connection, + channel=channel, + retry_wrapper=retry_wrapper, + ) retry_wrapper.run(connection=conn, wrapped_callback=wrapped_register_exchanges) @@ -166,7 +172,7 @@ def retry_if_io_error(exception): retrying_obj = retrying.Retrying( retry_on_exception=retry_if_io_error, wait_fixed=cfg.CONF.messaging.connection_retry_wait, - stop_max_attempt_number=cfg.CONF.messaging.connection_retries + stop_max_attempt_number=cfg.CONF.messaging.connection_retries, ) return retrying_obj.call(register_exchanges) @@ -181,24 +187,33 @@ def register_kombu_serializers(): https://github.com/celery/kombu/blob/3.0/kombu/utils/encoding.py#L47 """ + def pickle_dumps(obj, dumper=pickle.dumps): return dumper(obj, protocol=pickle_protocol) if six.PY3: + def str_to_bytes(s): if isinstance(s, str): - return s.encode('utf-8') + return s.encode("utf-8") return s def unpickle(s): return pickle_loads(str_to_bytes(s)) + else: - def str_to_bytes(s): # noqa - if isinstance(s, unicode): # noqa # pylint: disable=E0602 - return s.encode('utf-8') + + def str_to_bytes(s): # noqa + if isinstance(s, unicode): # noqa # pylint: disable=E0602 + return s.encode("utf-8") return s + unpickle = pickle_loads # noqa - register('pickle', pickle_dumps, unpickle, - content_type='application/x-python-serialize', - content_encoding='binary') + register( + "pickle", + pickle_dumps, + unpickle, + content_type="application/x-python-serialize", + content_encoding="binary", + ) diff --git a/st2common/st2common/transport/connection_retry_wrapper.py b/st2common/st2common/transport/connection_retry_wrapper.py index d0c906fff67..492aa24f32e 100644 --- a/st2common/st2common/transport/connection_retry_wrapper.py +++ b/st2common/st2common/transport/connection_retry_wrapper.py @@ -19,7 +19,7 @@ from st2common.util import concurrency -__all__ = ['ConnectionRetryWrapper', 'ClusterRetryContext'] +__all__ = ["ConnectionRetryWrapper", "ClusterRetryContext"] class ClusterRetryContext(object): @@ -27,6 +27,7 @@ class ClusterRetryContext(object): Stores retry context for cluster retries. It makes certain assumptions on how cluster_size and retry should be determined. """ + def __init__(self, cluster_size): # No of nodes in a cluster self.cluster_size = cluster_size @@ -101,6 +102,7 @@ def wrapped_callback(connection, channel): retry_wrapper.run(connection=connection, wrapped_callback=wrapped_callback) """ + def __init__(self, cluster_size, logger, ensure_max_retries=3): self._retry_context = ClusterRetryContext(cluster_size=cluster_size) self._logger = logger @@ -109,7 +111,7 @@ def __init__(self, cluster_size, logger, ensure_max_retries=3): self._ensure_max_retries = ensure_max_retries def errback(self, exc, interval): - self._logger.error('Rabbitmq connection error: %s', exc.message) + self._logger.error("Rabbitmq connection error: %s", exc.message) def run(self, connection, wrapped_callback): """ @@ -141,8 +143,10 @@ def run(self, connection, wrapped_callback): raise # -1, 0 and 1+ are handled properly by eventlet.sleep - self._logger.debug('Received RabbitMQ server error, sleeping for %s seconds ' - 'before retrying: %s' % (wait, six.text_type(e))) + self._logger.debug( + "Received RabbitMQ server error, sleeping for %s seconds " + "before retrying: %s" % (wait, six.text_type(e)) + ) concurrency.sleep(wait) connection.close() @@ -154,22 +158,28 @@ def run(self, connection, wrapped_callback): def log_error_on_conn_failure(exc, interval): self._logger.debug( - 'Failed to re-establish connection to RabbitMQ server, ' - 'retrying in %s seconds: %s' % (interval, six.text_type(exc)) + "Failed to re-establish connection to RabbitMQ server, " + "retrying in %s seconds: %s" % (interval, six.text_type(exc)) ) try: # NOTE: This function blocks and tries to restablish a connection for # indefinetly if "max_retries" argument is not specified - connection.ensure_connection(max_retries=self._ensure_max_retries, - errback=log_error_on_conn_failure) + connection.ensure_connection( + max_retries=self._ensure_max_retries, + errback=log_error_on_conn_failure, + ) except Exception: - self._logger.exception('Connections to RabbitMQ cannot be re-established: %s', - six.text_type(e)) + self._logger.exception( + "Connections to RabbitMQ cannot be re-established: %s", + six.text_type(e), + ) raise except Exception as e: - self._logger.exception('Connections to RabbitMQ cannot be re-established: %s', - six.text_type(e)) + self._logger.exception( + "Connections to RabbitMQ cannot be re-established: %s", + six.text_type(e), + ) # Not being able to publish a message could be a significant issue for an app. raise finally: @@ -177,7 +187,7 @@ def log_error_on_conn_failure(exc, interval): try: channel.close() except Exception: - self._logger.warning('Error closing channel.', exc_info=True) + self._logger.warning("Error closing channel.", exc_info=True) def ensured(self, connection, obj, to_ensure_func, **kwargs): """ @@ -191,7 +201,6 @@ def ensured(self, connection, obj, to_ensure_func, **kwargs): :type obj: Must support mixin kombu.abstract.MaybeChannelBound """ ensuring_func = connection.ensure( - obj, to_ensure_func, - errback=self.errback, - max_retries=3) + obj, to_ensure_func, errback=self.errback, max_retries=3 + ) ensuring_func(**kwargs) diff --git a/st2common/st2common/transport/consumers.py b/st2common/st2common/transport/consumers.py index 7f626f72a45..dd2f47cb55c 100644 --- a/st2common/st2common/transport/consumers.py +++ b/st2common/st2common/transport/consumers.py @@ -25,12 +25,11 @@ from st2common.util import concurrency __all__ = [ - 'QueueConsumer', - 'StagedQueueConsumer', - 'ActionsQueueConsumer', - - 'MessageHandler', - 'StagedMessageHandler' + "QueueConsumer", + "StagedQueueConsumer", + "ActionsQueueConsumer", + "MessageHandler", + "StagedMessageHandler", ] LOG = logging.getLogger(__name__) @@ -47,7 +46,9 @@ def shutdown(self): self._dispatcher.shutdown() def get_consumers(self, Consumer, channel): - consumer = Consumer(queues=self._queues, accept=['pickle'], callbacks=[self.process]) + consumer = Consumer( + queues=self._queues, accept=["pickle"], callbacks=[self.process] + ) # use prefetch_count=1 for fair dispatch. This way workers that finish an item get the next # task and the work does not get queued behind any single large item. @@ -58,11 +59,15 @@ def get_consumers(self, Consumer, channel): def process(self, body, message): try: if not isinstance(body, self._handler.message_type): - raise TypeError('Received an unexpected type "%s" for payload.' % type(body)) + raise TypeError( + 'Received an unexpected type "%s" for payload.' % type(body) + ) self._dispatcher.dispatch(self._process_message, body) except: - LOG.exception('%s failed to process message: %s', self.__class__.__name__, body) + LOG.exception( + "%s failed to process message: %s", self.__class__.__name__, body + ) finally: # At this point we will always ack a message. message.ack() @@ -71,7 +76,9 @@ def _process_message(self, body): try: self._handler.process(body) except: - LOG.exception('%s failed to process message: %s', self.__class__.__name__, body) + LOG.exception( + "%s failed to process message: %s", self.__class__.__name__, body + ) class StagedQueueConsumer(QueueConsumer): @@ -82,11 +89,15 @@ class StagedQueueConsumer(QueueConsumer): def process(self, body, message): try: if not isinstance(body, self._handler.message_type): - raise TypeError('Received an unexpected type "%s" for payload.' % type(body)) + raise TypeError( + 'Received an unexpected type "%s" for payload.' % type(body) + ) response = self._handler.pre_ack_process(body) self._dispatcher.dispatch(self._process_message, response) except: - LOG.exception('%s failed to process message: %s', self.__class__.__name__, body) + LOG.exception( + "%s failed to process message: %s", self.__class__.__name__, body + ) finally: # At this point we will always ack a message. message.ack() @@ -110,17 +121,21 @@ def __init__(self, connection, queues, handler): workflows_pool_size = cfg.CONF.actionrunner.workflows_pool_size actions_pool_size = cfg.CONF.actionrunner.actions_pool_size - self._workflows_dispatcher = BufferedDispatcher(dispatch_pool_size=workflows_pool_size, - name='workflows-dispatcher') - self._actions_dispatcher = BufferedDispatcher(dispatch_pool_size=actions_pool_size, - name='actions-dispatcher') + self._workflows_dispatcher = BufferedDispatcher( + dispatch_pool_size=workflows_pool_size, name="workflows-dispatcher" + ) + self._actions_dispatcher = BufferedDispatcher( + dispatch_pool_size=actions_pool_size, name="actions-dispatcher" + ) def process(self, body, message): try: if not isinstance(body, self._handler.message_type): - raise TypeError('Received an unexpected type "%s" for payload.' % type(body)) + raise TypeError( + 'Received an unexpected type "%s" for payload.' % type(body) + ) - action_is_workflow = getattr(body, 'action_is_workflow', False) + action_is_workflow = getattr(body, "action_is_workflow", False) if action_is_workflow: # Use workflow dispatcher queue dispatcher = self._workflows_dispatcher @@ -131,7 +146,9 @@ def process(self, body, message): LOG.debug('Using BufferedDispatcher pool: "%s"', str(dispatcher)) dispatcher.dispatch(self._process_message, body) except: - LOG.exception('%s failed to process message: %s', self.__class__.__name__, body) + LOG.exception( + "%s failed to process message: %s", self.__class__.__name__, body + ) finally: # At this point we will always ack a message. message.ack() @@ -149,11 +166,15 @@ class VariableMessageQueueConsumer(QueueConsumer): def process(self, body, message): try: if not self._handler.message_types.get(type(body)): - raise TypeError('Received an unexpected type "%s" for payload.' % type(body)) + raise TypeError( + 'Received an unexpected type "%s" for payload.' % type(body) + ) self._dispatcher.dispatch(self._process_message, body) except: - LOG.exception('%s failed to process message: %s', self.__class__.__name__, body) + LOG.exception( + "%s failed to process message: %s", self.__class__.__name__, body + ) finally: # At this point we will always ack a message. message.ack() @@ -164,12 +185,13 @@ class MessageHandler(object): message_type = None def __init__(self, connection, queues): - self._queue_consumer = self.get_queue_consumer(connection=connection, - queues=queues) + self._queue_consumer = self.get_queue_consumer( + connection=connection, queues=queues + ) self._consumer_thread = None def start(self, wait=False): - LOG.info('Starting %s...', self.__class__.__name__) + LOG.info("Starting %s...", self.__class__.__name__) self._consumer_thread = concurrency.spawn(self._queue_consumer.run) if wait: @@ -179,7 +201,7 @@ def wait(self): self._consumer_thread.wait() def shutdown(self): - LOG.info('Shutting down %s...', self.__class__.__name__) + LOG.info("Shutting down %s...", self.__class__.__name__) self._queue_consumer.shutdown() @abc.abstractmethod @@ -224,4 +246,6 @@ class VariableMessageHandler(MessageHandler): """ def get_queue_consumer(self, connection, queues): - return VariableMessageQueueConsumer(connection=connection, queues=queues, handler=self) + return VariableMessageQueueConsumer( + connection=connection, queues=queues, handler=self + ) diff --git a/st2common/st2common/transport/execution.py b/st2common/st2common/transport/execution.py index e35279ac71f..5d2880fd6f6 100644 --- a/st2common/st2common/transport/execution.py +++ b/st2common/st2common/transport/execution.py @@ -20,15 +20,14 @@ from st2common.transport import publishers __all__ = [ - 'ActionExecutionPublisher', - 'ActionExecutionOutputPublisher', - - 'get_queue', - 'get_output_queue' + "ActionExecutionPublisher", + "ActionExecutionOutputPublisher", + "get_queue", + "get_output_queue", ] -EXECUTION_XCHG = Exchange('st2.execution', type='topic') -EXECUTION_OUTPUT_XCHG = Exchange('st2.execution.output', type='topic') +EXECUTION_XCHG = Exchange("st2.execution", type="topic") +EXECUTION_OUTPUT_XCHG = Exchange("st2.execution.output", type="topic") class ActionExecutionPublisher(publishers.CUDPublisher): @@ -38,14 +37,26 @@ def __init__(self): class ActionExecutionOutputPublisher(publishers.CUDPublisher): def __init__(self): - super(ActionExecutionOutputPublisher, self).__init__(exchange=EXECUTION_OUTPUT_XCHG) + super(ActionExecutionOutputPublisher, self).__init__( + exchange=EXECUTION_OUTPUT_XCHG + ) def get_queue(name=None, routing_key=None, exclusive=False, auto_delete=False): - return Queue(name, EXECUTION_XCHG, routing_key=routing_key, exclusive=exclusive, - auto_delete=auto_delete) + return Queue( + name, + EXECUTION_XCHG, + routing_key=routing_key, + exclusive=exclusive, + auto_delete=auto_delete, + ) def get_output_queue(name=None, routing_key=None, exclusive=False, auto_delete=False): - return Queue(name, EXECUTION_OUTPUT_XCHG, routing_key=routing_key, exclusive=exclusive, - auto_delete=auto_delete) + return Queue( + name, + EXECUTION_OUTPUT_XCHG, + routing_key=routing_key, + exclusive=exclusive, + auto_delete=auto_delete, + ) diff --git a/st2common/st2common/transport/liveaction.py b/st2common/st2common/transport/liveaction.py index 97dd08400b3..670c5ebb2eb 100644 --- a/st2common/st2common/transport/liveaction.py +++ b/st2common/st2common/transport/liveaction.py @@ -21,23 +21,19 @@ from st2common.transport import publishers -__all__ = [ - 'LiveActionPublisher', +__all__ = ["LiveActionPublisher", "get_queue", "get_status_management_queue"] - 'get_queue', - 'get_status_management_queue' -] - -LIVEACTION_XCHG = Exchange('st2.liveaction', type='topic') -LIVEACTION_STATUS_MGMT_XCHG = Exchange('st2.liveaction.status', type='topic') +LIVEACTION_XCHG = Exchange("st2.liveaction", type="topic") +LIVEACTION_STATUS_MGMT_XCHG = Exchange("st2.liveaction.status", type="topic") class LiveActionPublisher(publishers.CUDPublisher, publishers.StatePublisherMixin): - def __init__(self): publishers.CUDPublisher.__init__(self, exchange=LIVEACTION_XCHG) - publishers.StatePublisherMixin.__init__(self, exchange=LIVEACTION_STATUS_MGMT_XCHG) + publishers.StatePublisherMixin.__init__( + self, exchange=LIVEACTION_STATUS_MGMT_XCHG + ) def get_queue(name, routing_key): diff --git a/st2common/st2common/transport/publishers.py b/st2common/st2common/transport/publishers.py index 7942fdfffed..202220acb13 100644 --- a/st2common/st2common/transport/publishers.py +++ b/st2common/st2common/transport/publishers.py @@ -25,16 +25,16 @@ from st2common.transport.connection_retry_wrapper import ConnectionRetryWrapper __all__ = [ - 'PoolPublisher', - 'SharedPoolPublishers', - 'CUDPublisher', - 'StatePublisherMixin' + "PoolPublisher", + "SharedPoolPublishers", + "CUDPublisher", + "StatePublisherMixin", ] -ANY_RK = '*' -CREATE_RK = 'create' -UPDATE_RK = 'update' -DELETE_RK = 'delete' +ANY_RK = "*" +CREATE_RK = "create" +UPDATE_RK = "update" +DELETE_RK = "delete" LOG = logging.getLogger(__name__) @@ -47,19 +47,21 @@ def __init__(self, urls=None): :type urls: ``list`` """ urls = urls or transport_utils.get_messaging_urls() - connection = transport_utils.get_connection(urls=urls, - connection_kwargs={'failover_strategy': - 'round-robin'}) + connection = transport_utils.get_connection( + urls=urls, connection_kwargs={"failover_strategy": "round-robin"} + ) self.pool = connection.Pool(limit=10) self.cluster_size = len(urls) def errback(self, exc, interval): - LOG.error('Rabbitmq connection error: %s', exc.message, exc_info=False) + LOG.error("Rabbitmq connection error: %s", exc.message, exc_info=False) - def publish(self, payload, exchange, routing_key=''): - with Timer(key='amqp.pool_publisher.publish_with_retries.' + exchange.name): + def publish(self, payload, exchange, routing_key=""): + with Timer(key="amqp.pool_publisher.publish_with_retries." + exchange.name): with self.pool.acquire(block=True) as connection: - retry_wrapper = ConnectionRetryWrapper(cluster_size=self.cluster_size, logger=LOG) + retry_wrapper = ConnectionRetryWrapper( + cluster_size=self.cluster_size, logger=LOG + ) def do_publish(connection, channel): # ProducerPool ends up creating it own ConnectionPool which ends up @@ -68,18 +70,18 @@ def do_publish(connection, channel): # Producer for each publish. producer = Producer(channel) kwargs = { - 'body': payload, - 'exchange': exchange, - 'routing_key': routing_key, - 'serializer': 'pickle', - 'content_encoding': 'utf-8' + "body": payload, + "exchange": exchange, + "routing_key": routing_key, + "serializer": "pickle", + "content_encoding": "utf-8", } retry_wrapper.ensured( connection=connection, obj=producer, to_ensure_func=producer.publish, - **kwargs + **kwargs, ) retry_wrapper.run(connection=connection, wrapped_callback=do_publish) @@ -91,6 +93,7 @@ class SharedPoolPublishers(object): server is usually the same. This sharing allows from the same PoolPublisher to be reused for publishing purposes. Sharing publishers leads to shared connections. """ + shared_publishers = {} def get_publisher(self, urls): @@ -99,7 +102,7 @@ def get_publisher(self, urls): # ordering in supplied list. urls_copy = copy.copy(urls) urls_copy.sort() - publisher_key = ''.join(urls_copy) + publisher_key = "".join(urls_copy) publisher = self.shared_publishers.get(publisher_key, None) if not publisher: # Use original urls here to preserve order. @@ -115,15 +118,15 @@ def __init__(self, exchange): self._exchange = exchange def publish_create(self, payload): - with Timer(key='amqp.publish.create'): + with Timer(key="amqp.publish.create"): self._publisher.publish(payload, self._exchange, CREATE_RK) def publish_update(self, payload): - with Timer(key='amqp.publish.update'): + with Timer(key="amqp.publish.update"): self._publisher.publish(payload, self._exchange, UPDATE_RK) def publish_delete(self, payload): - with Timer(key='amqp.publish.delete'): + with Timer(key="amqp.publish.delete"): self._publisher.publish(payload, self._exchange, DELETE_RK) @@ -135,6 +138,6 @@ def __init__(self, exchange): def publish_state(self, payload, state): if not state: - raise Exception('Unable to publish unassigned state.') - with Timer(key='amqp.publish.state'): + raise Exception("Unable to publish unassigned state.") + with Timer(key="amqp.publish.state"): self._state_publisher.publish(payload, self._state_exchange, state) diff --git a/st2common/st2common/transport/queues.py b/st2common/st2common/transport/queues.py index faf6d27fbf6..f6f9bcb4ef8 100644 --- a/st2common/st2common/transport/queues.py +++ b/st2common/st2common/transport/queues.py @@ -34,120 +34,109 @@ from st2common.transport import workflow __all__ = [ - 'ACTIONSCHEDULER_REQUEST_QUEUE', - - 'ACTIONRUNNER_WORK_QUEUE', - 'ACTIONRUNNER_CANCEL_QUEUE', - 'ACTIONRUNNER_PAUSE_QUEUE', - 'ACTIONRUNNER_RESUME_QUEUE', - - 'EXPORTER_WORK_QUEUE', - - 'NOTIFIER_ACTIONUPDATE_WORK_QUEUE', - - 'RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE', - - 'RULESENGINE_WORK_QUEUE', - - 'STREAM_ANNOUNCEMENT_WORK_QUEUE', - 'STREAM_EXECUTION_ALL_WORK_QUEUE', - 'STREAM_EXECUTION_UPDATE_WORK_QUEUE', - 'STREAM_LIVEACTION_WORK_QUEUE', - - 'WORKFLOW_EXECUTION_WORK_QUEUE', - 'WORKFLOW_EXECUTION_RESUME_QUEUE' + "ACTIONSCHEDULER_REQUEST_QUEUE", + "ACTIONRUNNER_WORK_QUEUE", + "ACTIONRUNNER_CANCEL_QUEUE", + "ACTIONRUNNER_PAUSE_QUEUE", + "ACTIONRUNNER_RESUME_QUEUE", + "EXPORTER_WORK_QUEUE", + "NOTIFIER_ACTIONUPDATE_WORK_QUEUE", + "RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE", + "RULESENGINE_WORK_QUEUE", + "STREAM_ANNOUNCEMENT_WORK_QUEUE", + "STREAM_EXECUTION_ALL_WORK_QUEUE", + "STREAM_EXECUTION_UPDATE_WORK_QUEUE", + "STREAM_LIVEACTION_WORK_QUEUE", + "WORKFLOW_EXECUTION_WORK_QUEUE", + "WORKFLOW_EXECUTION_RESUME_QUEUE", ] # Used by the action scheduler service ACTIONSCHEDULER_REQUEST_QUEUE = liveaction.get_status_management_queue( - 'st2.actionrunner.req', - routing_key=action_constants.LIVEACTION_STATUS_REQUESTED) + "st2.actionrunner.req", routing_key=action_constants.LIVEACTION_STATUS_REQUESTED +) # Used by the action runner service ACTIONRUNNER_WORK_QUEUE = liveaction.get_status_management_queue( - 'st2.actionrunner.work', - routing_key=action_constants.LIVEACTION_STATUS_SCHEDULED) + "st2.actionrunner.work", routing_key=action_constants.LIVEACTION_STATUS_SCHEDULED +) ACTIONRUNNER_CANCEL_QUEUE = liveaction.get_status_management_queue( - 'st2.actionrunner.cancel', - routing_key=action_constants.LIVEACTION_STATUS_CANCELING) + "st2.actionrunner.cancel", routing_key=action_constants.LIVEACTION_STATUS_CANCELING +) ACTIONRUNNER_PAUSE_QUEUE = liveaction.get_status_management_queue( - 'st2.actionrunner.pause', - routing_key=action_constants.LIVEACTION_STATUS_PAUSING) + "st2.actionrunner.pause", routing_key=action_constants.LIVEACTION_STATUS_PAUSING +) ACTIONRUNNER_RESUME_QUEUE = liveaction.get_status_management_queue( - 'st2.actionrunner.resume', - routing_key=action_constants.LIVEACTION_STATUS_RESUMING) + "st2.actionrunner.resume", routing_key=action_constants.LIVEACTION_STATUS_RESUMING +) # Used by the exporter service EXPORTER_WORK_QUEUE = execution.get_queue( - 'st2.exporter.work', - routing_key=publishers.UPDATE_RK) + "st2.exporter.work", routing_key=publishers.UPDATE_RK +) # Used by the notifier service NOTIFIER_ACTIONUPDATE_WORK_QUEUE = execution.get_queue( - 'st2.notifiers.execution.work', - routing_key=publishers.UPDATE_RK) + "st2.notifiers.execution.work", routing_key=publishers.UPDATE_RK +) # Used by the results tracker service RESULTSTRACKER_ACTIONSTATE_WORK_QUEUE = actionexecutionstate.get_queue( - 'st2.resultstracker.work', - routing_key=publishers.CREATE_RK) + "st2.resultstracker.work", routing_key=publishers.CREATE_RK +) # Used by the rules engine service RULESENGINE_WORK_QUEUE = reactor.get_trigger_instances_queue( - name='st2.trigger_instances_dispatch.rules_engine', - routing_key='#') + name="st2.trigger_instances_dispatch.rules_engine", routing_key="#" +) # Used by the stream service STREAM_ANNOUNCEMENT_WORK_QUEUE = announcement.get_queue( - routing_key=publishers.ANY_RK, - exclusive=True, - auto_delete=True) + routing_key=publishers.ANY_RK, exclusive=True, auto_delete=True +) STREAM_EXECUTION_ALL_WORK_QUEUE = execution.get_queue( - routing_key=publishers.ANY_RK, - exclusive=True, - auto_delete=True) + routing_key=publishers.ANY_RK, exclusive=True, auto_delete=True +) STREAM_EXECUTION_UPDATE_WORK_QUEUE = execution.get_queue( - routing_key=publishers.UPDATE_RK, - exclusive=True, - auto_delete=True) + routing_key=publishers.UPDATE_RK, exclusive=True, auto_delete=True +) STREAM_LIVEACTION_WORK_QUEUE = Queue( None, liveaction.LIVEACTION_XCHG, routing_key=publishers.ANY_RK, exclusive=True, - auto_delete=True) + auto_delete=True, +) # TODO: Perhaps we should use pack.action name as routing key # so we can do more efficient filtering later, if needed STREAM_EXECUTION_OUTPUT_QUEUE = execution.get_output_queue( - name=None, - routing_key=publishers.CREATE_RK, - exclusive=True, - auto_delete=True) + name=None, routing_key=publishers.CREATE_RK, exclusive=True, auto_delete=True +) # Used by the workflow engine service WORKFLOW_EXECUTION_WORK_QUEUE = workflow.get_status_management_queue( - name='st2.workflow.work', - routing_key=action_constants.LIVEACTION_STATUS_REQUESTED) + name="st2.workflow.work", routing_key=action_constants.LIVEACTION_STATUS_REQUESTED +) WORKFLOW_EXECUTION_RESUME_QUEUE = workflow.get_status_management_queue( - name='st2.workflow.resume', - routing_key=action_constants.LIVEACTION_STATUS_RESUMING) + name="st2.workflow.resume", routing_key=action_constants.LIVEACTION_STATUS_RESUMING +) WORKFLOW_ACTION_EXECUTION_UPDATE_QUEUE = execution.get_queue( - 'st2.workflow.action.update', - routing_key=publishers.UPDATE_RK) + "st2.workflow.action.update", routing_key=publishers.UPDATE_RK +) diff --git a/st2common/st2common/transport/reactor.py b/st2common/st2common/transport/reactor.py index 613a1d08ed8..c9dc84725c0 100644 --- a/st2common/st2common/transport/reactor.py +++ b/st2common/st2common/transport/reactor.py @@ -22,26 +22,24 @@ from st2common.transport import publishers __all__ = [ - 'TriggerCUDPublisher', - 'TriggerInstancePublisher', - - 'TriggerDispatcher', - - 'get_sensor_cud_queue', - 'get_trigger_cud_queue', - 'get_trigger_instances_queue' + "TriggerCUDPublisher", + "TriggerInstancePublisher", + "TriggerDispatcher", + "get_sensor_cud_queue", + "get_trigger_cud_queue", + "get_trigger_instances_queue", ] LOG = logging.getLogger(__name__) # Exchange for Trigger CUD events -TRIGGER_CUD_XCHG = Exchange('st2.trigger', type='topic') +TRIGGER_CUD_XCHG = Exchange("st2.trigger", type="topic") # Exchange for TriggerInstance events -TRIGGER_INSTANCE_XCHG = Exchange('st2.trigger_instances_dispatch', type='topic') +TRIGGER_INSTANCE_XCHG = Exchange("st2.trigger_instances_dispatch", type="topic") # Exchane for Sensor CUD events -SENSOR_CUD_XCHG = Exchange('st2.sensor', type='topic') +SENSOR_CUD_XCHG = Exchange("st2.sensor", type="topic") class SensorCUDPublisher(publishers.CUDPublisher): @@ -96,14 +94,12 @@ def dispatch(self, trigger, payload=None, trace_context=None): assert isinstance(payload, (type(None), dict)) assert isinstance(trace_context, (type(None), TraceContext)) - payload = { - 'trigger': trigger, - 'payload': payload, - TRACE_CONTEXT: trace_context - } - routing_key = 'trigger_instance' + payload = {"trigger": trigger, "payload": payload, TRACE_CONTEXT: trace_context} + routing_key = "trigger_instance" - self._logger.debug('Dispatching trigger (trigger=%s,payload=%s)', trigger, payload) + self._logger.debug( + "Dispatching trigger (trigger=%s,payload=%s)", trigger, payload + ) self._publisher.publish_trigger(payload=payload, routing_key=routing_key) diff --git a/st2common/st2common/transport/utils.py b/st2common/st2common/transport/utils.py index bea2df1e573..e479713ddcd 100644 --- a/st2common/st2common/transport/utils.py +++ b/st2common/st2common/transport/utils.py @@ -22,22 +22,18 @@ from st2common import log as logging -__all__ = [ - 'get_connection', - - 'get_messaging_urls' -] +__all__ = ["get_connection", "get_messaging_urls"] LOG = logging.getLogger(__name__) def get_messaging_urls(): - ''' + """ Determines the right messaging urls to supply. In case the `cluster_urls` config is specified then that is used. Else the single `url` property is used. :rtype: ``list`` - ''' + """ if cfg.CONF.messaging.cluster_urls: return cfg.CONF.messaging.cluster_urls return [cfg.CONF.messaging.url] @@ -57,33 +53,41 @@ def get_connection(urls=None, connection_kwargs=None): kwargs = {} - ssl_kwargs = _get_ssl_kwargs(ssl=cfg.CONF.messaging.ssl, - ssl_keyfile=cfg.CONF.messaging.ssl_keyfile, - ssl_certfile=cfg.CONF.messaging.ssl_certfile, - ssl_cert_reqs=cfg.CONF.messaging.ssl_cert_reqs, - ssl_ca_certs=cfg.CONF.messaging.ssl_ca_certs, - login_method=cfg.CONF.messaging.login_method) + ssl_kwargs = _get_ssl_kwargs( + ssl=cfg.CONF.messaging.ssl, + ssl_keyfile=cfg.CONF.messaging.ssl_keyfile, + ssl_certfile=cfg.CONF.messaging.ssl_certfile, + ssl_cert_reqs=cfg.CONF.messaging.ssl_cert_reqs, + ssl_ca_certs=cfg.CONF.messaging.ssl_ca_certs, + login_method=cfg.CONF.messaging.login_method, + ) # NOTE: "connection_kwargs" argument passed to this function has precedence over config values - if len(ssl_kwargs) == 1 and ssl_kwargs['ssl'] is True: - kwargs.update({'ssl': True}) + if len(ssl_kwargs) == 1 and ssl_kwargs["ssl"] is True: + kwargs.update({"ssl": True}) elif len(ssl_kwargs) >= 2: - ssl_kwargs.pop('ssl') - kwargs.update({'ssl': ssl_kwargs}) + ssl_kwargs.pop("ssl") + kwargs.update({"ssl": ssl_kwargs}) - kwargs['login_method'] = cfg.CONF.messaging.login_method + kwargs["login_method"] = cfg.CONF.messaging.login_method kwargs.update(connection_kwargs) # NOTE: This line contains no secret values so it's OK to log it - LOG.debug('Using SSL context for RabbitMQ connection: %s' % (ssl_kwargs)) + LOG.debug("Using SSL context for RabbitMQ connection: %s" % (ssl_kwargs)) connection = Connection(urls, **kwargs) return connection -def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_reqs=None, - ssl_ca_certs=None, login_method=None): +def _get_ssl_kwargs( + ssl=False, + ssl_keyfile=None, + ssl_certfile=None, + ssl_cert_reqs=None, + ssl_ca_certs=None, + login_method=None, +): """ Return SSL keyword arguments to be used with the kombu.Connection class. """ @@ -93,27 +97,27 @@ def _get_ssl_kwargs(ssl=False, ssl_keyfile=None, ssl_certfile=None, ssl_cert_req # because user could still specify to use SSL by including "?ssl=true" query param at the # end of the connection URL string if ssl is True: - ssl_kwargs['ssl'] = True + ssl_kwargs["ssl"] = True if ssl_keyfile: - ssl_kwargs['ssl'] = True - ssl_kwargs['keyfile'] = ssl_keyfile + ssl_kwargs["ssl"] = True + ssl_kwargs["keyfile"] = ssl_keyfile if ssl_certfile: - ssl_kwargs['ssl'] = True - ssl_kwargs['certfile'] = ssl_certfile + ssl_kwargs["ssl"] = True + ssl_kwargs["certfile"] = ssl_certfile if ssl_cert_reqs: - if ssl_cert_reqs == 'none': + if ssl_cert_reqs == "none": ssl_cert_reqs = ssl_lib.CERT_NONE - elif ssl_cert_reqs == 'optional': + elif ssl_cert_reqs == "optional": ssl_cert_reqs = ssl_lib.CERT_OPTIONAL - elif ssl_cert_reqs == 'required': + elif ssl_cert_reqs == "required": ssl_cert_reqs = ssl_lib.CERT_REQUIRED - ssl_kwargs['cert_reqs'] = ssl_cert_reqs + ssl_kwargs["cert_reqs"] = ssl_cert_reqs if ssl_ca_certs: - ssl_kwargs['ssl'] = True - ssl_kwargs['ca_certs'] = ssl_ca_certs + ssl_kwargs["ssl"] = True + ssl_kwargs["ca_certs"] = ssl_ca_certs return ssl_kwargs diff --git a/st2common/st2common/transport/workflow.py b/st2common/st2common/transport/workflow.py index 2b9815fcb7c..0302611a366 100644 --- a/st2common/st2common/transport/workflow.py +++ b/st2common/st2common/transport/workflow.py @@ -21,22 +21,22 @@ from st2common.transport import publishers -__all__ = [ - 'WorkflowExecutionPublisher', +__all__ = ["WorkflowExecutionPublisher", "get_queue", "get_status_management_queue"] - 'get_queue', - 'get_status_management_queue' -] +WORKFLOW_EXECUTION_XCHG = kombu.Exchange("st2.workflow", type="topic") +WORKFLOW_EXECUTION_STATUS_MGMT_XCHG = kombu.Exchange( + "st2.workflow.status", type="topic" +) -WORKFLOW_EXECUTION_XCHG = kombu.Exchange('st2.workflow', type='topic') -WORKFLOW_EXECUTION_STATUS_MGMT_XCHG = kombu.Exchange('st2.workflow.status', type='topic') - - -class WorkflowExecutionPublisher(publishers.CUDPublisher, publishers.StatePublisherMixin): +class WorkflowExecutionPublisher( + publishers.CUDPublisher, publishers.StatePublisherMixin +): def __init__(self): publishers.CUDPublisher.__init__(self, exchange=WORKFLOW_EXECUTION_XCHG) - publishers.StatePublisherMixin.__init__(self, exchange=WORKFLOW_EXECUTION_STATUS_MGMT_XCHG) + publishers.StatePublisherMixin.__init__( + self, exchange=WORKFLOW_EXECUTION_STATUS_MGMT_XCHG + ) def get_queue(name, routing_key): @@ -44,4 +44,6 @@ def get_queue(name, routing_key): def get_status_management_queue(name, routing_key): - return kombu.Queue(name, WORKFLOW_EXECUTION_STATUS_MGMT_XCHG, routing_key=routing_key) + return kombu.Queue( + name, WORKFLOW_EXECUTION_STATUS_MGMT_XCHG, routing_key=routing_key + ) diff --git a/st2common/st2common/triggers.py b/st2common/st2common/triggers.py index a18dadedb90..ec0dba378e8 100644 --- a/st2common/st2common/triggers.py +++ b/st2common/st2common/triggers.py @@ -22,52 +22,63 @@ from oslo_config import cfg from st2common import log as logging -from st2common.constants.triggers import (INTERNAL_TRIGGER_TYPES, ACTION_SENSOR_TRIGGER) +from st2common.constants.triggers import INTERNAL_TRIGGER_TYPES, ACTION_SENSOR_TRIGGER from st2common.exceptions.db import StackStormDBObjectConflictError from st2common.services.triggers import create_trigger_type_db from st2common.services.triggers import create_shadow_trigger from st2common.services.triggers import get_trigger_type_db from st2common.models.system.common import ResourceReference -__all__ = [ - 'register_internal_trigger_types' -] +__all__ = ["register_internal_trigger_types"] LOG = logging.getLogger(__name__) def _register_internal_trigger_type(trigger_definition): try: - trigger_type_db = create_trigger_type_db(trigger_type=trigger_definition, - log_not_unique_error_as_debug=True) + trigger_type_db = create_trigger_type_db( + trigger_type=trigger_definition, log_not_unique_error_as_debug=True + ) except (NotUniqueError, StackStormDBObjectConflictError): # We ignore conflict error since this operation is idempotent and race is not an issue - LOG.debug('Internal trigger type "%s" already exists, ignoring error...' % - (trigger_definition['name'])) - - ref = ResourceReference.to_string_reference(name=trigger_definition['name'], - pack=trigger_definition['pack']) + LOG.debug( + 'Internal trigger type "%s" already exists, ignoring error...' + % (trigger_definition["name"]) + ) + + ref = ResourceReference.to_string_reference( + name=trigger_definition["name"], pack=trigger_definition["pack"] + ) trigger_type_db = get_trigger_type_db(ref) if trigger_type_db: - LOG.debug('Registered internal trigger: %s.', trigger_definition['name']) + LOG.debug("Registered internal trigger: %s.", trigger_definition["name"]) # trigger types with parameters do no require a shadow trigger. if trigger_type_db and not trigger_type_db.parameters_schema: try: - trigger_db = create_shadow_trigger(trigger_type_db, - log_not_unique_error_as_debug=True) - - extra = {'trigger_db': trigger_db} - LOG.audit('Trigger created for parameter-less internal TriggerType. Trigger.id=%s' % - (trigger_db.id), extra=extra) + trigger_db = create_shadow_trigger( + trigger_type_db, log_not_unique_error_as_debug=True + ) + + extra = {"trigger_db": trigger_db} + LOG.audit( + "Trigger created for parameter-less internal TriggerType. Trigger.id=%s" + % (trigger_db.id), + extra=extra, + ) except (NotUniqueError, StackStormDBObjectConflictError): - LOG.debug('Shadow trigger "%s" already exists. Ignoring.', - trigger_type_db.get_reference().ref, exc_info=True) + LOG.debug( + 'Shadow trigger "%s" already exists. Ignoring.', + trigger_type_db.get_reference().ref, + exc_info=True, + ) except (ValidationError, ValueError): - LOG.exception('Validation failed in shadow trigger. TriggerType=%s.', - trigger_type_db.get_reference().ref) + LOG.exception( + "Validation failed in shadow trigger. TriggerType=%s.", + trigger_type_db.get_reference().ref, + ) raise return trigger_type_db @@ -89,16 +100,21 @@ def register_internal_trigger_types(): for _, trigger_definitions in six.iteritems(INTERNAL_TRIGGER_TYPES): for trigger_definition in trigger_definitions: - LOG.debug('Registering internal trigger: %s', trigger_definition['name']) + LOG.debug("Registering internal trigger: %s", trigger_definition["name"]) - is_action_trigger = trigger_definition['name'] == ACTION_SENSOR_TRIGGER['name'] + is_action_trigger = ( + trigger_definition["name"] == ACTION_SENSOR_TRIGGER["name"] + ) if is_action_trigger and not action_sensor_enabled: continue try: trigger_type_db = _register_internal_trigger_type( - trigger_definition=trigger_definition) + trigger_definition=trigger_definition + ) except Exception: - LOG.exception('Failed registering internal trigger: %s.', trigger_definition) + LOG.exception( + "Failed registering internal trigger: %s.", trigger_definition + ) raise else: registered_trigger_types_db.append(trigger_type_db) diff --git a/st2common/st2common/util/action_db.py b/st2common/st2common/util/action_db.py index 610b698c18c..48806933483 100644 --- a/st2common/st2common/util/action_db.py +++ b/st2common/st2common/util/action_db.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except ImportError: @@ -42,15 +43,15 @@ __all__ = [ - 'get_action_parameters_specs', - 'get_runnertype_by_id', - 'get_runnertype_by_name', - 'get_action_by_id', - 'get_action_by_ref', - 'get_liveaction_by_id', - 'update_liveaction_status', - 'serialize_positional_argument', - 'get_args' + "get_action_parameters_specs", + "get_runnertype_by_id", + "get_runnertype_by_name", + "get_action_by_id", + "get_action_by_ref", + "get_liveaction_by_id", + "update_liveaction_status", + "serialize_positional_argument", + "get_args", ] @@ -71,11 +72,11 @@ def get_action_parameters_specs(action_ref): if not action_db: return parameters - runner_type_name = action_db.runner_type['name'] + runner_type_name = action_db.runner_type["name"] runner_type_db = get_runnertype_by_name(runnertype_name=runner_type_name) # Runner type parameters should be added first before the action parameters. - parameters.update(runner_type_db['runner_parameters']) + parameters.update(runner_type_db["runner_parameters"]) parameters.update(action_db.parameters) return parameters @@ -83,60 +84,76 @@ def get_action_parameters_specs(action_ref): def get_runnertype_by_id(runnertype_id): """ - Get RunnerType by id. + Get RunnerType by id. - On error, raise StackStormDBObjectNotFoundError + On error, raise StackStormDBObjectNotFoundError """ try: runnertype = RunnerType.get_by_id(runnertype_id) except (ValueError, ValidationError) as e: - LOG.warning('Database lookup for runnertype with id="%s" resulted in ' - 'exception: %s', runnertype_id, e) - raise StackStormDBObjectNotFoundError('Unable to find runnertype with ' - 'id="%s"' % runnertype_id) + LOG.warning( + 'Database lookup for runnertype with id="%s" resulted in ' "exception: %s", + runnertype_id, + e, + ) + raise StackStormDBObjectNotFoundError( + "Unable to find runnertype with " 'id="%s"' % runnertype_id + ) return runnertype def get_runnertype_by_name(runnertype_name): """ - Get an runnertype by name. - On error, raise ST2ObjectNotFoundError. + Get an runnertype by name. + On error, raise ST2ObjectNotFoundError. """ try: runnertypes = RunnerType.query(name=runnertype_name) except (ValueError, ValidationError) as e: - LOG.error('Database lookup for name="%s" resulted in exception: %s', - runnertype_name, e) - raise StackStormDBObjectNotFoundError('Unable to find runnertype with name="%s"' - % runnertype_name) + LOG.error( + 'Database lookup for name="%s" resulted in exception: %s', + runnertype_name, + e, + ) + raise StackStormDBObjectNotFoundError( + 'Unable to find runnertype with name="%s"' % runnertype_name + ) if not runnertypes: - raise StackStormDBObjectNotFoundError('Unable to find RunnerType with name="%s"' - % runnertype_name) + raise StackStormDBObjectNotFoundError( + 'Unable to find RunnerType with name="%s"' % runnertype_name + ) if len(runnertypes) > 1: - LOG.warning('More than one RunnerType returned from DB lookup by name. ' - 'Result list is: %s', runnertypes) + LOG.warning( + "More than one RunnerType returned from DB lookup by name. " + "Result list is: %s", + runnertypes, + ) return runnertypes[0] def get_action_by_id(action_id): """ - Get Action by id. + Get Action by id. - On error, raise StackStormDBObjectNotFoundError + On error, raise StackStormDBObjectNotFoundError """ action = None try: action = Action.get_by_id(action_id) except (ValueError, ValidationError) as e: - LOG.warning('Database lookup for action with id="%s" resulted in ' - 'exception: %s', action_id, e) - raise StackStormDBObjectNotFoundError('Unable to find action with ' - 'id="%s"' % action_id) + LOG.warning( + 'Database lookup for action with id="%s" resulted in ' "exception: %s", + action_id, + e, + ) + raise StackStormDBObjectNotFoundError( + "Unable to find action with " 'id="%s"' % action_id + ) return action @@ -153,56 +170,78 @@ def get_action_by_ref(ref): try: return Action.get_by_ref(ref) except ValueError as e: - LOG.debug('Database lookup for ref="%s" resulted ' + - 'in exception : %s.', ref, e, exc_info=True) + LOG.debug( + 'Database lookup for ref="%s" resulted ' + "in exception : %s.", + ref, + e, + exc_info=True, + ) return None def get_liveaction_by_id(liveaction_id): """ - Get LiveAction by id. + Get LiveAction by id. - On error, raise ST2DBObjectNotFoundError. + On error, raise ST2DBObjectNotFoundError. """ liveaction = None try: liveaction = LiveAction.get_by_id(liveaction_id) except (ValidationError, ValueError) as e: - LOG.error('Database lookup for LiveAction with id="%s" resulted in ' - 'exception: %s', liveaction_id, e) - raise StackStormDBObjectNotFoundError('Unable to find LiveAction with ' - 'id="%s"' % liveaction_id) + LOG.error( + 'Database lookup for LiveAction with id="%s" resulted in ' "exception: %s", + liveaction_id, + e, + ) + raise StackStormDBObjectNotFoundError( + "Unable to find LiveAction with " 'id="%s"' % liveaction_id + ) return liveaction -def update_liveaction_status(status=None, result=None, context=None, end_timestamp=None, - liveaction_id=None, runner_info=None, liveaction_db=None, - publish=True): +def update_liveaction_status( + status=None, + result=None, + context=None, + end_timestamp=None, + liveaction_id=None, + runner_info=None, + liveaction_db=None, + publish=True, +): """ - Update the status of the specified LiveAction to the value provided in - new_status. + Update the status of the specified LiveAction to the value provided in + new_status. - The LiveAction may be specified using either liveaction_id, or as an - liveaction_db instance. + The LiveAction may be specified using either liveaction_id, or as an + liveaction_db instance. """ if (liveaction_id is None) and (liveaction_db is None): - raise ValueError('Must specify an liveaction_id or an liveaction_db when ' - 'calling update_LiveAction_status') + raise ValueError( + "Must specify an liveaction_id or an liveaction_db when " + "calling update_LiveAction_status" + ) if liveaction_db is None: liveaction_db = get_liveaction_by_id(liveaction_id) if status not in LIVEACTION_STATUSES: - raise ValueError('Attempting to set status for LiveAction "%s" ' - 'to unknown status string. Unknown status is "%s"' - % (liveaction_db, status)) + raise ValueError( + 'Attempting to set status for LiveAction "%s" ' + 'to unknown status string. Unknown status is "%s"' % (liveaction_db, status) + ) - if result and cfg.CONF.system.validate_output_schema and status == LIVEACTION_STATUS_SUCCEEDED: + if ( + result + and cfg.CONF.system.validate_output_schema + and status == LIVEACTION_STATUS_SUCCEEDED + ): action_db = get_action_by_ref(liveaction_db.action) - runner_db = get_runnertype_by_name(action_db.runner_type['name']) + runner_db = get_runnertype_by_name(action_db.runner_type["name"]) result, status = output_schema.validate_output( runner_db.output_schema, action_db.output_schema, @@ -214,21 +253,33 @@ def update_liveaction_status(status=None, result=None, context=None, end_timesta # If liveaction_db status is set then we need to decrement the counter # because it is transitioning to a new state if liveaction_db.status: - get_driver().dec_counter('action.executions.%s' % (liveaction_db.status)) + get_driver().dec_counter("action.executions.%s" % (liveaction_db.status)) # If status is provided then we need to increment the timer because the action # is transitioning into this new state if status: - get_driver().inc_counter('action.executions.%s' % (status)) + get_driver().inc_counter("action.executions.%s" % (status)) - extra = {'liveaction_db': liveaction_db} - LOG.debug('Updating ActionExection: "%s" with status="%s"', liveaction_db.id, status, - extra=extra) + extra = {"liveaction_db": liveaction_db} + LOG.debug( + 'Updating ActionExection: "%s" with status="%s"', + liveaction_db.id, + status, + extra=extra, + ) # If liveaction is already canceled, then do not allow status to be updated. - if liveaction_db.status == LIVEACTION_STATUS_CANCELED and status != LIVEACTION_STATUS_CANCELED: - LOG.info('Unable to update ActionExecution "%s" with status="%s". ' - 'ActionExecution is already canceled.', liveaction_db.id, status, extra=extra) + if ( + liveaction_db.status == LIVEACTION_STATUS_CANCELED + and status != LIVEACTION_STATUS_CANCELED + ): + LOG.info( + 'Unable to update ActionExecution "%s" with status="%s". ' + "ActionExecution is already canceled.", + liveaction_db.id, + status, + extra=extra, + ) return liveaction_db old_status = liveaction_db.status @@ -250,11 +301,11 @@ def update_liveaction_status(status=None, result=None, context=None, end_timesta # manipulated fields liveaction_db = LiveAction.add_or_update(liveaction_db) - LOG.debug('Updated status for LiveAction object.', extra=extra) + LOG.debug("Updated status for LiveAction object.", extra=extra) if publish and status != old_status: LiveAction.publish_status(liveaction_db) - LOG.debug('Published status for LiveAction object.', extra=extra) + LOG.debug("Published status for LiveAction object.", extra=extra) return liveaction_db @@ -267,9 +318,9 @@ def serialize_positional_argument(argument_type, argument_value): sense for shell script actions (only the outter / top level value is serialized). """ - if argument_type in ['string', 'number', 'float']: + if argument_type in ["string", "number", "float"]: if argument_value is None: - argument_value = six.text_type('') + argument_value = six.text_type("") return argument_value if isinstance(argument_value, (int, float)): @@ -277,25 +328,25 @@ def serialize_positional_argument(argument_type, argument_value): if not isinstance(argument_value, six.text_type): # cast string non-unicode values to unicode - argument_value = argument_value.decode('utf-8') - elif argument_type == 'boolean': + argument_value = argument_value.decode("utf-8") + elif argument_type == "boolean": # Booleans are serialized as string "1" and "0" if argument_value is not None: - argument_value = '1' if bool(argument_value) else '0' + argument_value = "1" if bool(argument_value) else "0" else: - argument_value = '' - elif argument_type in ['array', 'list']: + argument_value = "" + elif argument_type in ["array", "list"]: # Lists are serialized a comma delimited string (foo,bar,baz) - argument_value = ','.join(map(str, argument_value)) if argument_value else '' - elif argument_type == 'object': + argument_value = ",".join(map(str, argument_value)) if argument_value else "" + elif argument_type == "object": # Objects are serialized as JSON - argument_value = json.dumps(argument_value) if argument_value else '' - elif argument_type == 'null': + argument_value = json.dumps(argument_value) if argument_value else "" + elif argument_type == "null": # None / null is serialized as en empty string - argument_value = '' + argument_value = "" else: # Other values are simply cast to unicode string - argument_value = six.text_type(argument_value) if argument_value else '' + argument_value = six.text_type(argument_value) if argument_value else "" return argument_value @@ -315,12 +366,13 @@ def get_args(action_parameters, action_db): positional_args = [] positional_args_keys = set() for _, arg in six.iteritems(position_args_dict): - arg_type = action_db_parameters.get(arg, {}).get('type', None) + arg_type = action_db_parameters.get(arg, {}).get("type", None) # Perform serialization for positional arguments arg_value = action_parameters.get(arg, None) - arg_value = serialize_positional_argument(argument_type=arg_type, - argument_value=arg_value) + arg_value = serialize_positional_argument( + argument_type=arg_type, argument_value=arg_value + ) positional_args.append(arg_value) positional_args_keys.add(arg) @@ -340,7 +392,7 @@ def _get_position_arg_dict(action_parameters, action_db): for param in action_db_params: param_meta = action_db_params.get(param, None) if param_meta is not None: - pos = param_meta.get('position') + pos = param_meta.get("position") if pos is not None: args_dict[pos] = param args_dict = OrderedDict(sorted(args_dict.items())) diff --git a/st2common/st2common/util/actionalias_helpstring.py b/st2common/st2common/util/actionalias_helpstring.py index ddee088c8cd..109328f9264 100644 --- a/st2common/st2common/util/actionalias_helpstring.py +++ b/st2common/st2common/util/actionalias_helpstring.py @@ -18,9 +18,7 @@ from st2common.util.actionalias_matching import normalise_alias_format_string -__all__ = [ - 'generate_helpstring_result' -] +__all__ = ["generate_helpstring_result"] def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset=0): @@ -44,7 +42,7 @@ def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset= matches = [] count = 0 if not (isinstance(limit, int) and isinstance(offset, int)): - raise TypeError('limit or offset argument is not an integer') + raise TypeError("limit or offset argument is not an integer") for alias in aliases: # Skip disable aliases. if not alias.enabled: @@ -56,7 +54,7 @@ def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset= display, _, _ = normalise_alias_format_string(format_) if display: # Skip help strings not containing keyword. - if not re.search(filter or '', display, flags=re.IGNORECASE): + if not re.search(filter or "", display, flags=re.IGNORECASE): continue # Skip over help strings not within the requested offset/limit range. if (offset == 0 and limit > 0) and count >= limit: @@ -65,14 +63,18 @@ def generate_helpstring_result(aliases, filter=None, pack=None, limit=0, offset= elif (offset > 0 and limit == 0) and count < offset: count += 1 continue - elif (offset > 0 and limit > 0) and (count < offset or count >= offset + limit): + elif (offset > 0 and limit > 0) and ( + count < offset or count >= offset + limit + ): count += 1 continue - matches.append({ - "pack": alias.pack, - "display": display, - "description": alias.description - }) + matches.append( + { + "pack": alias.pack, + "display": display, + "description": alias.description, + } + ) count += 1 return {"available": count, "helpstrings": matches} diff --git a/st2common/st2common/util/actionalias_matching.py b/st2common/st2common/util/actionalias_matching.py index 3827b12d934..1b20fad4144 100644 --- a/st2common/st2common/util/actionalias_matching.py +++ b/st2common/st2common/util/actionalias_matching.py @@ -24,15 +24,15 @@ from st2common.models.utils.action_alias_utils import extract_parameters __all__ = [ - 'list_format_strings_from_aliases', - 'normalise_alias_format_string', - 'match_command_to_alias', - 'get_matching_alias', + "list_format_strings_from_aliases", + "normalise_alias_format_string", + "match_command_to_alias", + "get_matching_alias", ] def list_format_strings_from_aliases(aliases, match_multiple=False): - ''' + """ List patterns from a collection of alias objects :param aliases: The list of aliases @@ -40,34 +40,40 @@ def list_format_strings_from_aliases(aliases, match_multiple=False): :return: A description of potential execution patterns in a list of aliases. :rtype: ``list`` of ``list`` - ''' + """ patterns = [] for alias in aliases: for format_ in alias.formats: - display, representations, _match_multiple = normalise_alias_format_string(format_) + display, representations, _match_multiple = normalise_alias_format_string( + format_ + ) if display and len(representations) == 0: - patterns.append({ - 'alias': alias, - 'format': format_, - 'display': display, - 'representation': '', - }) - else: - patterns.extend([ + patterns.append( { - 'alias': alias, - 'format': format_, - 'display': display, - 'representation': representation, - 'match_multiple': _match_multiple, + "alias": alias, + "format": format_, + "display": display, + "representation": "", } - for representation in representations - ]) + ) + else: + patterns.extend( + [ + { + "alias": alias, + "format": format_, + "display": display, + "representation": representation, + "match_multiple": _match_multiple, + } + for representation in representations + ] + ) return patterns def normalise_alias_format_string(alias_format): - ''' + """ StackStorm action aliases come in two forms; 1. A string holding the format, which is also used as the help string. 2. A dictionary containing "display" and/or "representation" keys. @@ -80,7 +86,7 @@ def normalise_alias_format_string(alias_format): :return: The representation of the alias :rtype: ``tuple`` of (``str``, ``str``) - ''' + """ display = None representation = [] match_multiple = False @@ -89,14 +95,16 @@ def normalise_alias_format_string(alias_format): display = alias_format representation.append(alias_format) elif isinstance(alias_format, dict): - display = alias_format.get('display') - representation = alias_format.get('representation') or [] + display = alias_format.get("display") + representation = alias_format.get("representation") or [] if isinstance(representation, six.string_types): representation = [representation] - match_multiple = alias_format.get('match_multiple', match_multiple) + match_multiple = alias_format.get("match_multiple", match_multiple) else: - raise TypeError("alias_format '%s' is neither a dictionary or string type." - % repr(alias_format)) + raise TypeError( + "alias_format '%s' is neither a dictionary or string type." + % repr(alias_format) + ) return (display, representation, match_multiple) @@ -110,8 +118,9 @@ def match_command_to_alias(command, aliases, match_multiple=False): formats = list_format_strings_from_aliases([alias], match_multiple) for format_ in formats: try: - extract_parameters(format_str=format_['representation'], - param_stream=command) + extract_parameters( + format_str=format_["representation"], param_stream=command + ) except ParseException: continue @@ -125,35 +134,41 @@ def get_matching_alias(command): """ # 1. Get aliases action_alias_dbs = ActionAlias.query( - Q(formats__match_multiple=None) | Q(formats__match_multiple=False), - enabled=True) + Q(formats__match_multiple=None) | Q(formats__match_multiple=False), enabled=True + ) # 2. Match alias(es) to command matches = match_command_to_alias(command=command, aliases=action_alias_dbs) if len(matches) > 1: - raise ActionAliasAmbiguityException("Command '%s' matched more than 1 pattern" % - command, - matches=matches, - command=command) + raise ActionAliasAmbiguityException( + "Command '%s' matched more than 1 pattern" % command, + matches=matches, + command=command, + ) elif len(matches) == 0: match_multiple_action_alias_dbs = ActionAlias.query( - formats__match_multiple=True, - enabled=True) + formats__match_multiple=True, enabled=True + ) - matches = match_command_to_alias(command=command, aliases=match_multiple_action_alias_dbs, - match_multiple=True) + matches = match_command_to_alias( + command=command, + aliases=match_multiple_action_alias_dbs, + match_multiple=True, + ) if len(matches) > 1: - raise ActionAliasAmbiguityException("Command '%s' matched more than 1 (multi) pattern" % - command, - matches=matches, - command=command) + raise ActionAliasAmbiguityException( + "Command '%s' matched more than 1 (multi) pattern" % command, + matches=matches, + command=command, + ) if len(matches) == 0: - raise ActionAliasAmbiguityException("Command '%s' matched no patterns" % - command, - matches=[], - command=command) + raise ActionAliasAmbiguityException( + "Command '%s' matched no patterns" % command, + matches=[], + command=command, + ) return matches[0] diff --git a/st2common/st2common/util/api.py b/st2common/st2common/util/api.py index 4e0e3f49380..2c378ad7269 100644 --- a/st2common/st2common/util/api.py +++ b/st2common/st2common/util/api.py @@ -21,8 +21,8 @@ from st2common.util.url import get_url_without_trailing_slash __all__ = [ - 'get_base_public_api_url', - 'get_full_public_api_url', + "get_base_public_api_url", + "get_full_public_api_url", ] LOG = logging.getLogger(__name__) @@ -40,7 +40,7 @@ def get_base_public_api_url(): api_url = get_url_without_trailing_slash(cfg.CONF.auth.api_url) else: LOG.warn('"auth.api_url" configuration option is not configured') - api_url = 'http://%s:%s' % (cfg.CONF.api.host, cfg.CONF.api.port) + api_url = "http://%s:%s" % (cfg.CONF.api.host, cfg.CONF.api.port) return api_url @@ -52,5 +52,5 @@ def get_full_public_api_url(api_version=DEFAULT_API_VERSION): :rtype: ``str`` """ api_url = get_base_public_api_url() - api_url = '%s/%s' % (api_url, api_version) + api_url = "%s/%s" % (api_url, api_version) return api_url diff --git a/st2common/st2common/util/argument_parser.py b/st2common/st2common/util/argument_parser.py index 28645ad15f2..757f171661d 100644 --- a/st2common/st2common/util/argument_parser.py +++ b/st2common/st2common/util/argument_parser.py @@ -16,9 +16,7 @@ from __future__ import absolute_import import argparse -__all__ = [ - 'generate_argument_parser_for_metadata' -] +__all__ = ["generate_argument_parser_for_metadata"] def generate_argument_parser_for_metadata(metadata): @@ -32,37 +30,37 @@ def generate_argument_parser_for_metadata(metadata): :return: Generated argument parser instance. :rtype: :class:`argparse.ArgumentParser` """ - parameters = metadata['parameters'] + parameters = metadata["parameters"] - parser = argparse.ArgumentParser(description=metadata['description']) + parser = argparse.ArgumentParser(description=metadata["description"]) for parameter_name, parameter_options in parameters.items(): - name = parameter_name.replace('_', '-') - description = parameter_options['description'] - _type = parameter_options['type'] - required = parameter_options.get('required', False) - default_value = parameter_options.get('default', None) - immutable = parameter_options.get('immutable', False) + name = parameter_name.replace("_", "-") + description = parameter_options["description"] + _type = parameter_options["type"] + required = parameter_options.get("required", False) + default_value = parameter_options.get("default", None) + immutable = parameter_options.get("immutable", False) # Immutable arguments can't be controlled by the user if immutable: continue - args = ['--%s' % (name)] - kwargs = {'help': description, 'required': required} + args = ["--%s" % (name)] + kwargs = {"help": description, "required": required} if default_value is not None: - kwargs['default'] = default_value + kwargs["default"] = default_value - if _type == 'string': - kwargs['type'] = str - elif _type == 'integer': - kwargs['type'] = int - elif _type == 'boolean': + if _type == "string": + kwargs["type"] = str + elif _type == "integer": + kwargs["type"] = int + elif _type == "boolean": if default_value is False: - kwargs['action'] = 'store_false' + kwargs["action"] = "store_false" else: - kwargs['action'] = 'store_true' + kwargs["action"] = "store_true" parser.add_argument(*args, **kwargs) diff --git a/st2common/st2common/util/auth.py b/st2common/st2common/util/auth.py index 38294c92a70..90e81d938ed 100644 --- a/st2common/st2common/util/auth.py +++ b/st2common/st2common/util/auth.py @@ -28,11 +28,11 @@ from st2common.util import hash as hash_utils __all__ = [ - 'validate_token', - 'validate_token_and_source', - 'generate_api_key', - 'validate_api_key', - 'validate_api_key_and_source' + "validate_token", + "validate_token_and_source", + "generate_api_key", + "validate_api_key", + "validate_api_key_and_source", ] LOG = logging.getLogger(__name__) @@ -53,7 +53,7 @@ def validate_token(token_string): if token.expiry <= date_utils.get_datetime_utc_now(): # TODO: purge expired tokens LOG.audit('Token with id "%s" has expired.' % (token.id)) - raise exceptions.TokenExpiredError('Token has expired.') + raise exceptions.TokenExpiredError("Token has expired.") LOG.audit('Token with id "%s" is validated.' % (token.id)) @@ -74,14 +74,14 @@ def validate_token_and_source(token_in_headers, token_in_query_params): :rtype: :class:`.TokenDB` """ if not token_in_headers and not token_in_query_params: - LOG.audit('Token is not found in header or query parameters.') - raise exceptions.TokenNotProvidedError('Token is not provided.') + LOG.audit("Token is not found in header or query parameters.") + raise exceptions.TokenNotProvidedError("Token is not provided.") if token_in_headers: - LOG.audit('Token provided in headers') + LOG.audit("Token provided in headers") if token_in_query_params: - LOG.audit('Token provided in query parameters') + LOG.audit("Token provided in query parameters") return validate_token(token_in_headers or token_in_query_params) @@ -103,7 +103,8 @@ def generate_api_key(): base64_encoded = base64.b64encode( six.b(hashed_seed), - six.b(random.choice(['rA', 'aZ', 'gQ', 'hH', 'hG', 'aR', 'DD']))).rstrip(b'==') + six.b(random.choice(["rA", "aZ", "gQ", "hH", "hG", "aR", "DD"])), + ).rstrip(b"==") base64_encoded = base64_encoded.decode() return base64_encoded @@ -127,7 +128,7 @@ def validate_api_key(api_key): api_key_db = ApiKey.get(api_key) if not api_key_db.enabled: - raise exceptions.ApiKeyDisabledError('API key is disabled.') + raise exceptions.ApiKeyDisabledError("API key is disabled.") LOG.audit('API key with id "%s" is validated.' % (api_key_db.id)) @@ -148,13 +149,13 @@ def validate_api_key_and_source(api_key_in_headers, api_key_query_params): :rtype: :class:`.ApiKeyDB` """ if not api_key_in_headers and not api_key_query_params: - LOG.audit('API key is not found in header or query parameters.') - raise exceptions.ApiKeyNotProvidedError('API key is not provided.') + LOG.audit("API key is not found in header or query parameters.") + raise exceptions.ApiKeyNotProvidedError("API key is not provided.") if api_key_in_headers: - LOG.audit('API key provided in headers') + LOG.audit("API key provided in headers") if api_key_query_params: - LOG.audit('API key provided in query parameters') + LOG.audit("API key provided in query parameters") return validate_api_key(api_key_in_headers or api_key_query_params) diff --git a/st2common/st2common/util/casts.py b/st2common/st2common/util/casts.py index fa94272e47e..aadad8a4a14 100644 --- a/st2common/st2common/util/casts.py +++ b/st2common/st2common/util/casts.py @@ -89,12 +89,12 @@ def _cast_none(x): # These types as they appear in json schema. CASTS = { - 'array': _cast_object, - 'boolean': _cast_boolean, - 'integer': _cast_integer, - 'number': _cast_number, - 'object': _cast_object, - 'string': _cast_string + "array": _cast_object, + "boolean": _cast_boolean, + "integer": _cast_integer, + "number": _cast_number, + "object": _cast_object, + "string": _cast_string, } diff --git a/st2common/st2common/util/compat.py b/st2common/st2common/util/compat.py index 9288f5f3a0d..1926f97dbac 100644 --- a/st2common/st2common/util/compat.py +++ b/st2common/st2common/util/compat.py @@ -24,16 +24,15 @@ __all__ = [ - 'mock_open_name', - - 'to_unicode', - 'to_ascii', + "mock_open_name", + "to_unicode", + "to_ascii", ] if six.PY3: - mock_open_name = 'builtins.open' + mock_open_name = "builtins.open" else: - mock_open_name = '__builtin__.open' + mock_open_name = "__builtin__.open" def to_unicode(value): @@ -63,4 +62,4 @@ def to_ascii(value): if six.PY3: value = value.encode() - return value.decode('ascii', errors='ignore') + return value.decode("ascii", errors="ignore") diff --git a/st2common/st2common/util/concurrency.py b/st2common/st2common/util/concurrency.py index 50312fa78f9..239407ade05 100644 --- a/st2common/st2common/util/concurrency.py +++ b/st2common/st2common/util/concurrency.py @@ -31,34 +31,30 @@ except ImportError: gevent = None -CONCURRENCY_LIBRARY = 'eventlet' +CONCURRENCY_LIBRARY = "eventlet" __all__ = [ - 'set_concurrency_library', - 'get_concurrency_library', - - 'get_subprocess_module', - 'subprocess_popen', - - 'spawn', - 'wait', - 'cancel', - 'kill', - 'sleep', - - 'get_greenlet_exit_exception_class', - - 'get_green_pool_class', - 'is_green_pool_free', - 'green_pool_wait_all' + "set_concurrency_library", + "get_concurrency_library", + "get_subprocess_module", + "subprocess_popen", + "spawn", + "wait", + "cancel", + "kill", + "sleep", + "get_greenlet_exit_exception_class", + "get_green_pool_class", + "is_green_pool_free", + "green_pool_wait_all", ] def set_concurrency_library(library): global CONCURRENCY_LIBRARY - if library not in ['eventlet', 'gevent']: - raise ValueError('Unsupported concurrency library: %s' % (library)) + if library not in ["eventlet", "gevent"]: + raise ValueError("Unsupported concurrency library: %s" % (library)) CONCURRENCY_LIBRARY = library @@ -69,107 +65,111 @@ def get_concurrency_library(): def get_subprocess_module(): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": from eventlet.green import subprocess # pylint: disable=import-error + return subprocess - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": from gevent import subprocess # pylint: disable=import-error + return subprocess def subprocess_popen(*args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": from eventlet.green import subprocess # pylint: disable=import-error + return subprocess.Popen(*args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": from gevent import subprocess # pylint: disable=import-error + return subprocess.Popen(*args, **kwargs) def spawn(func, *args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return eventlet.spawn(func, *args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return gevent.spawn(func, *args, **kwargs) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def wait(green_thread, *args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return green_thread.wait(*args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return green_thread.join(*args, **kwargs) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def cancel(green_thread, *args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return green_thread.cancel(*args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return green_thread.kill(*args, **kwargs) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def kill(green_thread, *args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return green_thread.kill(*args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return green_thread.kill(*args, **kwargs) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def sleep(*args, **kwargs): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return eventlet.sleep(*args, **kwargs) - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return gevent.sleep(*args, **kwargs) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def get_greenlet_exit_exception_class(): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return eventlet.support.greenlets.GreenletExit - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return gevent.GreenletExit else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def get_green_pool_class(): - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return eventlet.GreenPool - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return gevent.pool.Pool else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def is_green_pool_free(pool): """ Return True if the provided green pool is free, False otherwise. """ - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return pool.free() - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": return not pool.full() else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") def green_pool_wait_all(pool): """ Wait for all the green threads in the pool to finish. """ - if CONCURRENCY_LIBRARY == 'eventlet': + if CONCURRENCY_LIBRARY == "eventlet": return pool.waitall() - elif CONCURRENCY_LIBRARY == 'gevent': + elif CONCURRENCY_LIBRARY == "gevent": # NOTE: This mimicks eventlet.waitall() functionallity better than # pool.join() return all(gl.ready() for gl in pool.greenlets) else: - raise ValueError('Unsupported concurrency library') + raise ValueError("Unsupported concurrency library") diff --git a/st2common/st2common/util/config_loader.py b/st2common/st2common/util/config_loader.py index 620707e6439..30db039bdca 100644 --- a/st2common/st2common/util/config_loader.py +++ b/st2common/st2common/util/config_loader.py @@ -30,9 +30,7 @@ from st2common.util.config_parser import ContentPackConfigParser from st2common.exceptions.db import StackStormDBObjectNotFoundError -__all__ = [ - 'ContentPackConfigLoader' -] +__all__ = ["ContentPackConfigLoader"] LOG = logging.getLogger(__name__) @@ -79,15 +77,16 @@ def get_config(self): # 2. Retrieve values from "global" pack config file (if available) and resolve them if # necessary - config = self._get_values_for_config(config_schema_db=config_schema_db, - config_db=config_db) + config = self._get_values_for_config( + config_schema_db=config_schema_db, config_db=config_db + ) result.update(config) return result def _get_values_for_config(self, config_schema_db, config_db): - schema_values = getattr(config_schema_db, 'attributes', {}) - config_values = getattr(config_db, 'values', {}) + schema_values = getattr(config_schema_db, "attributes", {}) + config_values = getattr(config_db, "values", {}) config = copy.deepcopy(config_values or {}) @@ -131,24 +130,34 @@ def _assign_dynamic_config_values(self, schema, config, parent_keys=None): # Inspect nested object properties if is_dictionary: parent_keys += [str(config_item_key)] - self._assign_dynamic_config_values(schema=schema_item.get('properties', {}), - config=config[config_item_key], - parent_keys=parent_keys) + self._assign_dynamic_config_values( + schema=schema_item.get("properties", {}), + config=config[config_item_key], + parent_keys=parent_keys, + ) # Inspect nested list items elif is_list: parent_keys += [str(config_item_key)] - self._assign_dynamic_config_values(schema=schema_item.get('items', {}), - config=config[config_item_key], - parent_keys=parent_keys) + self._assign_dynamic_config_values( + schema=schema_item.get("items", {}), + config=config[config_item_key], + parent_keys=parent_keys, + ) else: - is_jinja_expression = jinja_utils.is_jinja_expression(value=config_item_value) + is_jinja_expression = jinja_utils.is_jinja_expression( + value=config_item_value + ) if is_jinja_expression: # Resolve / render the Jinja template expression - full_config_item_key = '.'.join(parent_keys + [str(config_item_key)]) - value = self._get_datastore_value_for_expression(key=full_config_item_key, + full_config_item_key = ".".join( + parent_keys + [str(config_item_key)] + ) + value = self._get_datastore_value_for_expression( + key=full_config_item_key, value=config_item_value, - config_schema_item=schema_item) + config_schema_item=schema_item, + ) config[config_item_key] = value else: @@ -167,12 +176,12 @@ def _assign_default_values(self, schema, config): :rtype: ``dict`` """ for schema_item_key, schema_item in six.iteritems(schema): - has_default_value = 'default' in schema_item + has_default_value = "default" in schema_item has_config_value = schema_item_key in config - default_value = schema_item.get('default', None) - is_object = schema_item.get('type', None) == 'object' - has_properties = schema_item.get('properties', None) + default_value = schema_item.get("default", None) + is_object = schema_item.get("type", None) == "object" + has_properties = schema_item.get("properties", None) if has_default_value and not has_config_value: # Config value is not provided, but default value is, use a default value @@ -183,8 +192,9 @@ def _assign_default_values(self, schema, config): if not config.get(schema_item_key, None): config[schema_item_key] = {} - self._assign_default_values(schema=schema_item['properties'], - config=config[schema_item_key]) + self._assign_default_values( + schema=schema_item["properties"], config=config[schema_item_key] + ) return config @@ -198,18 +208,21 @@ def _get_datastore_value_for_expression(self, key, value, config_schema_item=Non from st2common.services.config import deserialize_key_value config_schema_item = config_schema_item or {} - secret = config_schema_item.get('secret', False) + secret = config_schema_item.get("secret", False) try: - value = render_template_with_system_and_user_context(value=value, - user=self.user) + value = render_template_with_system_and_user_context( + value=value, user=self.user + ) except Exception as e: # Throw a more user-friendly exception on failed render exc_class = type(e) original_msg = six.text_type(e) - msg = ('Failed to render dynamic configuration value for key "%s" with value ' - '"%s" for pack "%s" config: %s %s ' % (key, value, self.pack_name, - exc_class, original_msg)) + msg = ( + 'Failed to render dynamic configuration value for key "%s" with value ' + '"%s" for pack "%s" config: %s %s ' + % (key, value, self.pack_name, exc_class, original_msg) + ) raise RuntimeError(msg) if value: @@ -222,21 +235,17 @@ def _get_datastore_value_for_expression(self, key, value, config_schema_item=Non def get_config(pack, user): - """Returns config for given pack and user. - """ + """Returns config for given pack and user.""" LOG.debug('Attempting to get config for pack "%s" and user "%s"' % (pack, user)) if pack and user: - LOG.debug('Pack and user found. Loading config.') - config_loader = ContentPackConfigLoader( - pack_name=pack, - user=user - ) + LOG.debug("Pack and user found. Loading config.") + config_loader = ContentPackConfigLoader(pack_name=pack, user=user) config = config_loader.get_config() else: config = {} - LOG.debug('Config: %s', config) + LOG.debug("Config: %s", config) return config diff --git a/st2common/st2common/util/config_parser.py b/st2common/st2common/util/config_parser.py index 247dca88fa0..40c9e303135 100644 --- a/st2common/st2common/util/config_parser.py +++ b/st2common/st2common/util/config_parser.py @@ -21,10 +21,7 @@ from st2common.content import utils -__all__ = [ - 'ContentPackConfigParser', - 'ContentPackConfig' -] +__all__ = ["ContentPackConfigParser", "ContentPackConfig"] class ContentPackConfigParser(object): @@ -32,8 +29,8 @@ class ContentPackConfigParser(object): Class responsible for obtaining and parsing content pack configs. """ - GLOBAL_CONFIG_NAME = 'config.yaml' - LOCAL_CONFIG_SUFFIX = '_config.yaml' + GLOBAL_CONFIG_NAME = "config.yaml" + LOCAL_CONFIG_SUFFIX = "_config.yaml" def __init__(self, pack_name): self.pack_name = pack_name @@ -85,8 +82,7 @@ def get_global_config_path(self): if not self.pack_path: return None - global_config_path = os.path.join(self.pack_path, - self.GLOBAL_CONFIG_NAME) + global_config_path = os.path.join(self.pack_path, self.GLOBAL_CONFIG_NAME) return global_config_path @classmethod @@ -95,7 +91,7 @@ def get_and_parse_config(cls, config_path): return None if os.path.exists(config_path) and os.path.isfile(config_path): - with io.open(config_path, 'r', encoding='utf8') as fp: + with io.open(config_path, "r", encoding="utf8") as fp: config = yaml.safe_load(fp.read()) return ContentPackConfig(file_path=config_path, config=config) diff --git a/st2common/st2common/util/crypto.py b/st2common/st2common/util/crypto.py index 230c4ada8ef..d01e20557bd 100644 --- a/st2common/st2common/util/crypto.py +++ b/st2common/st2common/util/crypto.py @@ -51,23 +51,18 @@ from cryptography.hazmat.backends import default_backend __all__ = [ - 'KEYCZAR_HEADER_SIZE', - 'KEYCZAR_AES_BLOCK_SIZE', - 'KEYCZAR_HLEN', - - 'read_crypto_key', - - 'symmetric_encrypt', - 'symmetric_decrypt', - - 'cryptography_symmetric_encrypt', - 'cryptography_symmetric_decrypt', - + "KEYCZAR_HEADER_SIZE", + "KEYCZAR_AES_BLOCK_SIZE", + "KEYCZAR_HLEN", + "read_crypto_key", + "symmetric_encrypt", + "symmetric_decrypt", + "cryptography_symmetric_encrypt", + "cryptography_symmetric_decrypt", # NOTE: Keyczar functions are here for testing reasons - they are only used by tests - 'keyczar_symmetric_encrypt', - 'keyczar_symmetric_decrypt', - - 'AESKey' + "keyczar_symmetric_encrypt", + "keyczar_symmetric_decrypt", + "AESKey", ] # Keyczar related constants @@ -94,13 +89,19 @@ class AESKey(object): mode = None size = None - def __init__(self, aes_key_string, hmac_key_string, hmac_key_size, mode='CBC', - size=DEFAULT_AES_KEY_SIZE): - if mode not in ['CBC']: - raise ValueError('Unsupported mode: %s' % (mode)) + def __init__( + self, + aes_key_string, + hmac_key_string, + hmac_key_size, + mode="CBC", + size=DEFAULT_AES_KEY_SIZE, + ): + if mode not in ["CBC"]: + raise ValueError("Unsupported mode: %s" % (mode)) if size < MINIMUM_AES_KEY_SIZE: - raise ValueError('Unsafe key size: %s' % (size)) + raise ValueError("Unsafe key size: %s" % (size)) self.aes_key_string = aes_key_string self.hmac_key_string = hmac_key_string @@ -121,7 +122,7 @@ def generate(self, key_size=DEFAULT_AES_KEY_SIZE): :rtype: :class:`AESKey` """ if key_size < MINIMUM_AES_KEY_SIZE: - raise ValueError('Unsafe key size: %s' % (key_size)) + raise ValueError("Unsafe key size: %s" % (key_size)) aes_key_bytes = os.urandom(int(key_size / 8)) aes_key_string = Base64WSEncode(aes_key_bytes) @@ -129,8 +130,13 @@ def generate(self, key_size=DEFAULT_AES_KEY_SIZE): hmac_key_bytes = os.urandom(int(key_size / 8)) hmac_key_string = Base64WSEncode(hmac_key_bytes) - return AESKey(aes_key_string=aes_key_string, hmac_key_string=hmac_key_string, - hmac_key_size=key_size, mode='CBC', size=key_size) + return AESKey( + aes_key_string=aes_key_string, + hmac_key_string=hmac_key_string, + hmac_key_size=key_size, + mode="CBC", + size=key_size, + ) def to_json(self): """ @@ -140,19 +146,22 @@ def to_json(self): :rtype: ``str`` """ data = { - 'hmacKey': { - 'hmacKeyString': self.hmac_key_string, - 'size': self.hmac_key_size + "hmacKey": { + "hmacKeyString": self.hmac_key_string, + "size": self.hmac_key_size, }, - 'aesKeyString': self.aes_key_string, - 'mode': self.mode.upper(), - 'size': int(self.size) + "aesKeyString": self.aes_key_string, + "mode": self.mode.upper(), + "size": int(self.size), } return json.dumps(data) def __repr__(self): - return ('' % (self.hmac_key_size, self.mode, - self.size)) + return "" % ( + self.hmac_key_size, + self.mode, + self.size, + ) def read_crypto_key(key_path): @@ -164,17 +173,19 @@ def read_crypto_key(key_path): :rtype: :class:`AESKey` """ - with open(key_path, 'r') as fp: + with open(key_path, "r") as fp: content = fp.read() content = json.loads(content) try: - aes_key = AESKey(aes_key_string=content['aesKeyString'], - hmac_key_string=content['hmacKey']['hmacKeyString'], - hmac_key_size=content['hmacKey']['size'], - mode=content['mode'].upper(), - size=content['size']) + aes_key = AESKey( + aes_key_string=content["aesKeyString"], + hmac_key_string=content["hmacKey"]["hmacKeyString"], + hmac_key_size=content["hmacKey"]["size"], + mode=content["mode"].upper(), + size=content["size"], + ) except KeyError as e: msg = 'Invalid or malformed key file "%s": %s' % (key_path, six.text_type(e)) raise KeyError(msg) @@ -187,7 +198,9 @@ def symmetric_encrypt(encrypt_key, plaintext): def symmetric_decrypt(decrypt_key, ciphertext): - return cryptography_symmetric_decrypt(decrypt_key=decrypt_key, ciphertext=ciphertext) + return cryptography_symmetric_decrypt( + decrypt_key=decrypt_key, ciphertext=ciphertext + ) def cryptography_symmetric_encrypt(encrypt_key, plaintext): @@ -206,9 +219,12 @@ def cryptography_symmetric_encrypt(encrypt_key, plaintext): NOTE: Header itself is unused, but it's added so the format is compatible with keyczar format. """ - assert isinstance(encrypt_key, AESKey), 'encrypt_key needs to be AESKey class instance' - assert isinstance(plaintext, (six.text_type, six.string_types, six.binary_type)), \ - 'plaintext needs to either be a string/unicode or bytes' + assert isinstance( + encrypt_key, AESKey + ), "encrypt_key needs to be AESKey class instance" + assert isinstance( + plaintext, (six.text_type, six.string_types, six.binary_type) + ), "plaintext needs to either be a string/unicode or bytes" aes_key_bytes = encrypt_key.aes_key_bytes hmac_key_bytes = encrypt_key.hmac_key_bytes @@ -218,7 +234,7 @@ def cryptography_symmetric_encrypt(encrypt_key, plaintext): if isinstance(plaintext, (six.text_type, six.string_types)): # Convert data to bytes - data = plaintext.encode('utf-8') + data = plaintext.encode("utf-8") else: data = plaintext @@ -234,7 +250,7 @@ def cryptography_symmetric_encrypt(encrypt_key, plaintext): # NOTE: We don't care about actual Keyczar header value, we only care about the length (5 # bytes) so we simply add 5 0's - header_bytes = b'00000' + header_bytes = b"00000" ciphertext_bytes = encryptor.update(data) + encryptor.finalize() msg_bytes = header_bytes + iv_bytes + ciphertext_bytes @@ -263,9 +279,12 @@ def cryptography_symmetric_decrypt(decrypt_key, ciphertext): NOTE 2: This function is loosely based on keyczar AESKey.Decrypt() (Apache 2.0 license). """ - assert isinstance(decrypt_key, AESKey), 'decrypt_key needs to be AESKey class instance' - assert isinstance(ciphertext, (six.text_type, six.string_types, six.binary_type)), \ - 'ciphertext needs to either be a string/unicode or bytes' + assert isinstance( + decrypt_key, AESKey + ), "decrypt_key needs to be AESKey class instance" + assert isinstance( + ciphertext, (six.text_type, six.string_types, six.binary_type) + ), "ciphertext needs to either be a string/unicode or bytes" aes_key_bytes = decrypt_key.aes_key_bytes hmac_key_bytes = decrypt_key.hmac_key_bytes @@ -280,10 +299,12 @@ def cryptography_symmetric_decrypt(decrypt_key, ciphertext): # Verify ciphertext contains IV + HMAC signature if len(data_bytes) < (KEYCZAR_AES_BLOCK_SIZE + KEYCZAR_HLEN): - raise ValueError('Invalid or malformed ciphertext (too short)') + raise ValueError("Invalid or malformed ciphertext (too short)") iv_bytes = data_bytes[:KEYCZAR_AES_BLOCK_SIZE] # first block is IV - ciphertext_bytes = data_bytes[KEYCZAR_AES_BLOCK_SIZE:-KEYCZAR_HLEN] # strip IV and signature + ciphertext_bytes = data_bytes[ + KEYCZAR_AES_BLOCK_SIZE:-KEYCZAR_HLEN + ] # strip IV and signature signature_bytes = data_bytes[-KEYCZAR_HLEN:] # last 20 bytes are signature # Verify HMAC signature @@ -302,6 +323,7 @@ def cryptography_symmetric_decrypt(decrypt_key, ciphertext): decrypted = pkcs5_unpad(decrypted) return decrypted + ### # NOTE: Those methods below are deprecated and only used for testing purposes ## @@ -329,11 +351,12 @@ def keyczar_symmetric_encrypt(encrypt_key, plaintext): from keyczar.keys import HmacKey as KeyczarHmacKey # pylint: disable=import-error from keyczar.keyinfo import GetMode # pylint: disable=import-error - encrypt_key = KeyczarAesKey(encrypt_key.aes_key_string, - KeyczarHmacKey(encrypt_key.hmac_key_string, - encrypt_key.hmac_key_size), - encrypt_key.size, - GetMode(encrypt_key.mode)) + encrypt_key = KeyczarAesKey( + encrypt_key.aes_key_string, + KeyczarHmacKey(encrypt_key.hmac_key_string, encrypt_key.hmac_key_size), + encrypt_key.size, + GetMode(encrypt_key.mode), + ) return binascii.hexlify(encrypt_key.Encrypt(plaintext)).upper() @@ -356,11 +379,12 @@ def keyczar_symmetric_decrypt(decrypt_key, ciphertext): from keyczar.keys import HmacKey as KeyczarHmacKey # pylint: disable=import-error from keyczar.keyinfo import GetMode # pylint: disable=import-error - decrypt_key = KeyczarAesKey(decrypt_key.aes_key_string, - KeyczarHmacKey(decrypt_key.hmac_key_string, - decrypt_key.hmac_key_size), - decrypt_key.size, - GetMode(decrypt_key.mode)) + decrypt_key = KeyczarAesKey( + decrypt_key.aes_key_string, + KeyczarHmacKey(decrypt_key.hmac_key_string, decrypt_key.hmac_key_size), + decrypt_key.size, + GetMode(decrypt_key.mode), + ) return decrypt_key.Decrypt(binascii.unhexlify(ciphertext)) @@ -370,7 +394,7 @@ def pkcs5_pad(data): Pad data using PKCS5 """ pad = KEYCZAR_AES_BLOCK_SIZE - len(data) % KEYCZAR_AES_BLOCK_SIZE - data = data + pad * chr(pad).encode('utf-8') + data = data + pad * chr(pad).encode("utf-8") return data @@ -380,7 +404,7 @@ def pkcs5_unpad(data): """ if isinstance(data, six.binary_type): # Make sure we are operating with a string type - data = data.decode('utf-8') + data = data.decode("utf-8") pad = ord(data[-1]) data = data[:-pad] @@ -404,9 +428,9 @@ def Base64WSEncode(s): """ if isinstance(s, six.text_type): # Make sure input string is always converted to bytes (if not already) - s = s.encode('utf-8') + s = s.encode("utf-8") - return base64.urlsafe_b64encode(s).decode('utf-8').replace("=", "") + return base64.urlsafe_b64encode(s).decode("utf-8").replace("=", "") def Base64WSDecode(s): @@ -427,12 +451,12 @@ def Base64WSDecode(s): NOTE: Taken from keyczar (Apache 2.0 license) """ - s = ''.join(s.splitlines()) + s = "".join(s.splitlines()) s = str(s.replace(" ", "")) # kill whitespace, make string (not unicode) d = len(s) % 4 if d == 1: - raise ValueError('Base64 decoding errors') + raise ValueError("Base64 decoding errors") elif d == 2: s += "==" elif d == 3: @@ -442,4 +466,4 @@ def Base64WSDecode(s): return base64.urlsafe_b64decode(s) except TypeError as e: # Decoding raises TypeError if s contains invalid characters. - raise ValueError('Base64 decoding error: %s' % (six.text_type(e))) + raise ValueError("Base64 decoding error: %s" % (six.text_type(e))) diff --git a/st2common/st2common/util/date.py b/st2common/st2common/util/date.py index 979c3e8eb3f..8df0e4659f6 100644 --- a/st2common/st2common/util/date.py +++ b/st2common/st2common/util/date.py @@ -24,12 +24,7 @@ import dateutil.parser -__all__ = [ - 'get_datetime_utc_now', - 'add_utc_tz', - 'convert_to_utc', - 'parse' -] +__all__ = ["get_datetime_utc_now", "add_utc_tz", "convert_to_utc", "parse"] def get_datetime_utc_now(): @@ -45,14 +40,14 @@ def get_datetime_utc_now(): def append_milliseconds_to_time(date, millis): """ - Return time UTC datetime object offset by provided milliseconds. + Return time UTC datetime object offset by provided milliseconds. """ return convert_to_utc(date + datetime.timedelta(milliseconds=millis)) def add_utc_tz(dt): if dt.tzinfo and dt.tzinfo.utcoffset(dt) != datetime.timedelta(0): - raise ValueError('datetime already contains a non UTC timezone') + raise ValueError("datetime already contains a non UTC timezone") return dt.replace(tzinfo=dateutil.tz.tzutc()) diff --git a/st2common/st2common/util/debugging.py b/st2common/st2common/util/debugging.py index dd5d74d2a2a..66abbbe1ada 100644 --- a/st2common/st2common/util/debugging.py +++ b/st2common/st2common/util/debugging.py @@ -25,11 +25,7 @@ from st2common.logging.misc import set_log_level_for_all_loggers -__all__ = [ - 'enable_debugging', - 'disable_debugging', - 'is_enabled' -] +__all__ = ["enable_debugging", "disable_debugging", "is_enabled"] ENABLE_DEBUGGING = False diff --git a/st2common/st2common/util/deprecation.py b/st2common/st2common/util/deprecation.py index 160423a5e28..a178a9473d6 100644 --- a/st2common/st2common/util/deprecation.py +++ b/st2common/st2common/util/deprecation.py @@ -23,10 +23,14 @@ def deprecated(func): as deprecated. It will result in a warning being emitted when the function is used. """ + def new_func(*args, **kwargs): - warnings.warn("Call to deprecated function {}.".format(func.__name__), - category=DeprecationWarning) + warnings.warn( + "Call to deprecated function {}.".format(func.__name__), + category=DeprecationWarning, + ) return func(*args, **kwargs) + new_func.__name__ = func.__name__ new_func.__doc__ = func.__doc__ new_func.__dict__.update(func.__dict__) diff --git a/st2common/st2common/util/driver_loader.py b/st2common/st2common/util/driver_loader.py index 285c22ed796..50f5044c41b 100644 --- a/st2common/st2common/util/driver_loader.py +++ b/st2common/st2common/util/driver_loader.py @@ -21,15 +21,11 @@ from st2common import log as logging -__all__ = [ - 'get_available_backends', - 'get_backend_driver', - 'get_backend_instance' -] +__all__ = ["get_available_backends", "get_backend_driver", "get_backend_instance"] LOG = logging.getLogger(__name__) -BACKENDS_NAMESPACE = 'st2common.rbac.backend' +BACKENDS_NAMESPACE = "st2common.rbac.backend" def get_available_backends(namespace, invoke_on_load=False): @@ -62,8 +58,9 @@ def get_backend_driver(namespace, name, invoke_on_load=False): LOG.debug('Retrieving driver for backend "%s"' % (name)) try: - manager = DriverManager(namespace=namespace, name=name, - invoke_on_load=invoke_on_load) + manager = DriverManager( + namespace=namespace, name=name, invoke_on_load=invoke_on_load + ) except RuntimeError: message = 'Invalid "%s" backend specified: %s' % (namespace, name) LOG.exception(message) @@ -79,7 +76,9 @@ def get_backend_instance(namespace, name, invoke_on_load=False): :param name: Backend name. :type name: ``str`` """ - cls = get_backend_driver(namespace=namespace, name=name, invoke_on_load=invoke_on_load) + cls = get_backend_driver( + namespace=namespace, name=name, invoke_on_load=invoke_on_load + ) cls_instance = cls() return cls_instance diff --git a/st2common/st2common/util/enum.py b/st2common/st2common/util/enum.py index ddcc138ea52..84a6e968f53 100644 --- a/st2common/st2common/util/enum.py +++ b/st2common/st2common/util/enum.py @@ -16,15 +16,16 @@ from __future__ import absolute_import import inspect -__all__ = [ - 'Enum' -] +__all__ = ["Enum"] class Enum(object): @classmethod def get_valid_values(cls): keys = list(cls.__dict__.keys()) - values = [getattr(cls, key) for key in keys if (not key.startswith('_') and - not inspect.ismethod(getattr(cls, key)))] + values = [ + getattr(cls, key) + for key in keys + if (not key.startswith("_") and not inspect.ismethod(getattr(cls, key))) + ] return values diff --git a/st2common/st2common/util/file_system.py b/st2common/st2common/util/file_system.py index d6d2458aec9..e26adaedfdf 100644 --- a/st2common/st2common/util/file_system.py +++ b/st2common/st2common/util/file_system.py @@ -26,10 +26,7 @@ import six -__all__ = [ - 'get_file_list', - 'recursive_chown' -] +__all__ = ["get_file_list", "recursive_chown"] def get_file_list(directory, exclude_patterns=None): @@ -48,9 +45,9 @@ def get_file_list(directory, exclude_patterns=None): :rtype: ``list`` """ result = [] - if not directory.endswith('/'): + if not directory.endswith("/"): # Make sure trailing slash is present - directory = directory + '/' + directory = directory + "/" def include_file(file_path): if not exclude_patterns: @@ -63,7 +60,7 @@ def include_file(file_path): return True for (dirpath, dirnames, filenames) in os.walk(directory): - base_path = dirpath.replace(directory, '') + base_path = dirpath.replace(directory, "") for filename in filenames: if base_path: diff --git a/st2common/st2common/util/green/shell.py b/st2common/st2common/util/green/shell.py index 4fd71ef7cf8..4b6d79935b1 100644 --- a/st2common/st2common/util/green/shell.py +++ b/st2common/st2common/util/green/shell.py @@ -27,20 +27,31 @@ from st2common import log as logging from st2common.util import concurrency -__all__ = [ - 'run_command' -] +__all__ = ["run_command"] TIMEOUT_EXIT_CODE = -9 LOG = logging.getLogger(__name__) -def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False, - cwd=None, env=None, timeout=60, preexec_func=None, kill_func=None, - read_stdout_func=None, read_stderr_func=None, - read_stdout_buffer=None, read_stderr_buffer=None, stdin_value=None, - bufsize=0): +def run_command( + cmd, + stdin=None, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + cwd=None, + env=None, + timeout=60, + preexec_func=None, + kill_func=None, + read_stdout_func=None, + read_stderr_func=None, + read_stdout_buffer=None, + read_stderr_buffer=None, + stdin_value=None, + bufsize=0, +): """ Run the provided command in a subprocess and wait until it completes. @@ -89,59 +100,77 @@ def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, :rtype: ``tuple`` (exit_code, stdout, stderr, timed_out) """ - LOG.debug('Entering st2common.util.green.run_command.') + LOG.debug("Entering st2common.util.green.run_command.") assert isinstance(cmd, (list, tuple) + six.string_types) - if (read_stdout_func and not read_stderr_func) or (read_stderr_func and not read_stdout_func): - raise ValueError('Both read_stdout_func and read_stderr_func arguments need ' - 'to be provided.') + if (read_stdout_func and not read_stderr_func) or ( + read_stderr_func and not read_stdout_func + ): + raise ValueError( + "Both read_stdout_func and read_stderr_func arguments need " + "to be provided." + ) if read_stdout_func and not (read_stdout_buffer or read_stderr_buffer): - raise ValueError('read_stdout_buffer and read_stderr_buffer arguments need to be provided ' - 'when read_stdout_func is provided') + raise ValueError( + "read_stdout_buffer and read_stderr_buffer arguments need to be provided " + "when read_stdout_func is provided" + ) if not env: - LOG.debug('env argument not provided. using process env (os.environ).') + LOG.debug("env argument not provided. using process env (os.environ).") env = os.environ.copy() subprocess = concurrency.get_subprocess_module() # Note: We are using eventlet / gevent friendly implementation of subprocess which uses # GreenPipe so it doesn't block - LOG.debug('Creating subprocess.') - process = concurrency.subprocess_popen(args=cmd, stdin=stdin, stdout=stdout, stderr=stderr, - env=env, cwd=cwd, shell=shell, preexec_fn=preexec_func, - bufsize=bufsize) + LOG.debug("Creating subprocess.") + process = concurrency.subprocess_popen( + args=cmd, + stdin=stdin, + stdout=stdout, + stderr=stderr, + env=env, + cwd=cwd, + shell=shell, + preexec_fn=preexec_func, + bufsize=bufsize, + ) if read_stdout_func: - LOG.debug('Spawning read_stdout_func function') - read_stdout_thread = concurrency.spawn(read_stdout_func, process.stdout, read_stdout_buffer) + LOG.debug("Spawning read_stdout_func function") + read_stdout_thread = concurrency.spawn( + read_stdout_func, process.stdout, read_stdout_buffer + ) if read_stderr_func: - LOG.debug('Spawning read_stderr_func function') - read_stderr_thread = concurrency.spawn(read_stderr_func, process.stderr, read_stderr_buffer) + LOG.debug("Spawning read_stderr_func function") + read_stderr_thread = concurrency.spawn( + read_stderr_func, process.stderr, read_stderr_buffer + ) def on_timeout_expired(timeout): global timed_out try: - LOG.debug('Starting process wait inside timeout handler.') + LOG.debug("Starting process wait inside timeout handler.") process.wait(timeout=timeout) except subprocess.TimeoutExpired: # Command has timed out, kill the process and propagate the error. # Note: We explicitly set the returncode to indicate the timeout. - LOG.debug('Command execution timeout reached.') + LOG.debug("Command execution timeout reached.") # NOTE: It's important we set returncode twice - here and below to avoid race in this # function because "kill_func()" is async and "process.kill()" is not. process.returncode = TIMEOUT_EXIT_CODE if kill_func: - LOG.debug('Calling kill_func.') + LOG.debug("Calling kill_func.") kill_func(process=process) else: - LOG.debug('Killing process.') + LOG.debug("Killing process.") process.kill() # NOTE: It's imporant to set returncode here as well, since call to process.kill() sets @@ -149,25 +178,27 @@ def on_timeout_expired(timeout): process.returncode = TIMEOUT_EXIT_CODE if read_stdout_func and read_stderr_func: - LOG.debug('Killing read_stdout_thread and read_stderr_thread') + LOG.debug("Killing read_stdout_thread and read_stderr_thread") concurrency.kill(read_stdout_thread) concurrency.kill(read_stderr_thread) - LOG.debug('Spawning timeout handler thread.') + LOG.debug("Spawning timeout handler thread.") timeout_thread = concurrency.spawn(on_timeout_expired, timeout) - LOG.debug('Attaching to process.') + LOG.debug("Attaching to process.") if stdin_value: if six.PY3: - stdin_value = stdin_value.encode('utf-8') + stdin_value = stdin_value.encode("utf-8") process.stdin.write(stdin_value) if read_stdout_func and read_stderr_func: - LOG.debug('Using real-time stdout and stderr read mode, calling process.wait()') + LOG.debug("Using real-time stdout and stderr read mode, calling process.wait()") process.wait() else: - LOG.debug('Using delayed stdout and stderr read mode, calling process.communicate()') + LOG.debug( + "Using delayed stdout and stderr read mode, calling process.communicate()" + ) stdout, stderr = process.communicate() concurrency.cancel(timeout_thread) @@ -182,11 +213,11 @@ def on_timeout_expired(timeout): stderr = read_stderr_buffer.getvalue() if exit_code == TIMEOUT_EXIT_CODE: - LOG.debug('Timeout.') + LOG.debug("Timeout.") timed_out = True else: - LOG.debug('No timeout.') + LOG.debug("No timeout.") timed_out = False - LOG.debug('Returning.') + LOG.debug("Returning.") return (exit_code, stdout, stderr, timed_out) diff --git a/st2common/st2common/util/greenpooldispatch.py b/st2common/st2common/util/greenpooldispatch.py index d85ebfbf5de..156d530116d 100644 --- a/st2common/st2common/util/greenpooldispatch.py +++ b/st2common/st2common/util/greenpooldispatch.py @@ -21,9 +21,7 @@ from st2common import log as logging -__all__ = [ - 'BufferedDispatcher' -] +__all__ = ["BufferedDispatcher"] # If the thread pool has been occupied with no empty threads for more than this number of seconds # a message will be logged @@ -38,14 +36,20 @@ class BufferedDispatcher(object): - - def __init__(self, dispatch_pool_size=50, monitor_thread_empty_q_sleep_time=5, - monitor_thread_no_workers_sleep_time=1, name=None): + def __init__( + self, + dispatch_pool_size=50, + monitor_thread_empty_q_sleep_time=5, + monitor_thread_no_workers_sleep_time=1, + name=None, + ): self._pool_limit = dispatch_pool_size self._dispatcher_pool = eventlet.GreenPool(dispatch_pool_size) self._dispatch_monitor_thread = eventlet.greenthread.spawn(self._flush) self._monitor_thread_empty_q_sleep_time = monitor_thread_empty_q_sleep_time - self._monitor_thread_no_workers_sleep_time = monitor_thread_no_workers_sleep_time + self._monitor_thread_no_workers_sleep_time = ( + monitor_thread_no_workers_sleep_time + ) self._name = name self._work_buffer = six.moves.queue.Queue() @@ -77,7 +81,9 @@ def _flush_now(self): now = time.time() if (now - self._pool_last_free_ts) >= POOL_BUSY_THRESHOLD_SECONDS: - LOG.info(POOL_BUSY_LOG_MESSAGE % (self.name, POOL_BUSY_THRESHOLD_SECONDS)) + LOG.info( + POOL_BUSY_LOG_MESSAGE % (self.name, POOL_BUSY_THRESHOLD_SECONDS) + ) return @@ -90,8 +96,15 @@ def _flush_now(self): def __repr__(self): free_count = self._dispatcher_pool.free() - values = (self.name, self._pool_limit, free_count, self._monitor_thread_empty_q_sleep_time, - self._monitor_thread_no_workers_sleep_time) - return ('' % - values) + values = ( + self.name, + self._pool_limit, + free_count, + self._monitor_thread_empty_q_sleep_time, + self._monitor_thread_no_workers_sleep_time, + ) + return ( + "" + % values + ) diff --git a/st2common/st2common/util/gunicorn_workers.py b/st2common/st2common/util/gunicorn_workers.py index 61eebe84e4e..69942ac309f 100644 --- a/st2common/st2common/util/gunicorn_workers.py +++ b/st2common/st2common/util/gunicorn_workers.py @@ -20,9 +20,7 @@ import six from gunicorn.workers.sync import SyncWorker -__all__ = [ - 'EventletSyncWorker' -] +__all__ = ["EventletSyncWorker"] class EventletSyncWorker(SyncWorker): @@ -44,7 +42,7 @@ def handle_quit(self, sig, frame): except AssertionError as e: msg = six.text_type(e) - if 'do not call blocking functions from the mainloop' in msg: + if "do not call blocking functions from the mainloop" in msg: # Workaround for "do not call blocking functions from the mainloop" issue sys.exit(0) diff --git a/st2common/st2common/util/hash.py b/st2common/st2common/util/hash.py index f0a55963798..3d7c83328c5 100644 --- a/st2common/st2common/util/hash.py +++ b/st2common/st2common/util/hash.py @@ -19,12 +19,10 @@ import hashlib -__all__ = [ - 'hash' -] +__all__ = ["hash"] -FIXED_SALT = 'saltnpepper' +FIXED_SALT = "saltnpepper" def hash(value, salt=FIXED_SALT): diff --git a/st2common/st2common/util/http.py b/st2common/st2common/util/http.py index e11a277be65..26aa6d445dd 100644 --- a/st2common/st2common/util/http.py +++ b/st2common/st2common/util/http.py @@ -18,17 +18,20 @@ http_client = six.moves.http_client -__all__ = [ - 'HTTP_SUCCESS', - 'parse_content_type_header' +__all__ = ["HTTP_SUCCESS", "parse_content_type_header"] + +HTTP_SUCCESS = [ + http_client.OK, + http_client.CREATED, + http_client.ACCEPTED, + http_client.NON_AUTHORITATIVE_INFORMATION, + http_client.NO_CONTENT, + http_client.RESET_CONTENT, + http_client.PARTIAL_CONTENT, + http_client.MULTI_STATUS, + http_client.IM_USED, ] -HTTP_SUCCESS = [http_client.OK, http_client.CREATED, http_client.ACCEPTED, - http_client.NON_AUTHORITATIVE_INFORMATION, http_client.NO_CONTENT, - http_client.RESET_CONTENT, http_client.PARTIAL_CONTENT, - http_client.MULTI_STATUS, http_client.IM_USED, - ] - def parse_content_type_header(content_type): """ @@ -37,13 +40,13 @@ def parse_content_type_header(content_type): :rype: ``tuple`` """ - if ';' in content_type: - split = content_type.split(';') + if ";" in content_type: + split = content_type.split(";") media = split[0] options = {} for pair in split[1:]: - split_pair = pair.split('=', 1) + split_pair = pair.split("=", 1) if len(split_pair) != 2: continue diff --git a/st2common/st2common/util/ip_utils.py b/st2common/st2common/util/ip_utils.py index 4e2a00357a2..53253432d86 100644 --- a/st2common/st2common/util/ip_utils.py +++ b/st2common/st2common/util/ip_utils.py @@ -21,11 +21,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'is_ipv4', - 'is_ipv6', - 'split_host_port' -] +__all__ = ["is_ipv4", "is_ipv6", "split_host_port"] BRACKET_PATTERN = r"^\[.*\]" # IPv6 bracket pattern to specify port COMPILED_BRACKET_PATTERN = re.compile(BRACKET_PATTERN) @@ -91,30 +87,32 @@ def split_host_port(host_str): # Check if it's square bracket style. match = COMPILED_BRACKET_PATTERN.match(host_str) if match: - LOG.debug('Square bracket style.') + LOG.debug("Square bracket style.") # Check if square bracket style no port. match = COMPILED_HOST_ONLY_IN_BRACKET_PATTERN.match(host_str) if match: - hostname = match.group().strip('[]') + hostname = match.group().strip("[]") return (hostname, port) - hostname, separator, port = hostname.rpartition(':') + hostname, separator, port = hostname.rpartition(":") try: - LOG.debug('host_str: %s, hostname: %s port: %s' % (host_str, hostname, port)) + LOG.debug( + "host_str: %s, hostname: %s port: %s" % (host_str, hostname, port) + ) port = int(port) - hostname = hostname.strip('[]') + hostname = hostname.strip("[]") return (hostname, port) except: - raise Exception('Invalid port %s specified.' % port) + raise Exception("Invalid port %s specified." % port) else: - LOG.debug('Non-bracket address. host_str: %s' % host_str) - if ':' in host_str: - LOG.debug('Non-bracket with port.') - hostname, separator, port = hostname.rpartition(':') + LOG.debug("Non-bracket address. host_str: %s" % host_str) + if ":" in host_str: + LOG.debug("Non-bracket with port.") + hostname, separator, port = hostname.rpartition(":") try: port = int(port) return (hostname, port) except: - raise Exception('Invalid port %s specified.' % port) + raise Exception("Invalid port %s specified." % port) return (hostname, port) diff --git a/st2common/st2common/util/isotime.py b/st2common/st2common/util/isotime.py index 0830393bf86..0c6ca1c4d43 100644 --- a/st2common/st2common/util/isotime.py +++ b/st2common/st2common/util/isotime.py @@ -25,17 +25,14 @@ from st2common.util import date as date_utils import six -__all__ = [ - 'format', - 'validate', - 'parse' -] +__all__ = ["format", "validate", "parse"] -ISO8601_FORMAT = '%Y-%m-%dT%H:%M:%S' -ISO8601_FORMAT_MICROSECOND = '%Y-%m-%dT%H:%M:%S.%f' -ISO8601_UTC_REGEX = \ - r'^\d{4}\-\d{2}\-\d{2}(\s|T)\d{2}:\d{2}:\d{2}(\.\d{3,6})?(Z|\+00|\+0000|\+00:00)$' +ISO8601_FORMAT = "%Y-%m-%dT%H:%M:%S" +ISO8601_FORMAT_MICROSECOND = "%Y-%m-%dT%H:%M:%S.%f" +ISO8601_UTC_REGEX = ( + r"^\d{4}\-\d{2}\-\d{2}(\s|T)\d{2}:\d{2}:\d{2}(\.\d{3,6})?(Z|\+00|\+0000|\+00:00)$" +) def format(dt, usec=True, offset=True): @@ -53,20 +50,21 @@ def format(dt, usec=True, offset=True): fmt = ISO8601_FORMAT_MICROSECOND if usec else ISO8601_FORMAT if offset: - ost = dt.strftime('%z') - ost = (ost[:3] + ':' + ost[3:]) if ost else '+00:00' + ost = dt.strftime("%z") + ost = (ost[:3] + ":" + ost[3:]) if ost else "+00:00" else: - tz = dt.tzinfo.tzname(dt) if dt.tzinfo else 'UTC' - ost = 'Z' if tz == 'UTC' else tz + tz = dt.tzinfo.tzname(dt) if dt.tzinfo else "UTC" + ost = "Z" if tz == "UTC" else tz return dt.strftime(fmt) + ost def validate(value, raise_exception=True): - if (isinstance(value, datetime.datetime) or - (type(value) in [str, six.text_type] and re.match(ISO8601_UTC_REGEX, value))): + if isinstance(value, datetime.datetime) or ( + type(value) in [str, six.text_type] and re.match(ISO8601_UTC_REGEX, value) + ): return True if raise_exception: - raise ValueError('Datetime value does not match expected format.') + raise ValueError("Datetime value does not match expected format.") return False diff --git a/st2common/st2common/util/jinja.py b/st2common/st2common/util/jinja.py index 9234986f9f4..44722469082 100644 --- a/st2common/st2common/util/jinja.py +++ b/st2common/st2common/util/jinja.py @@ -22,21 +22,14 @@ from st2common.util.compat import to_unicode -__all__ = [ - 'get_jinja_environment', - 'render_values', - 'is_jinja_expression' -] +__all__ = ["get_jinja_environment", "render_values", "is_jinja_expression"] -JINJA_EXPRESSIONS_START_MARKERS = [ - '{{', - '{%' -] +JINJA_EXPRESSIONS_START_MARKERS = ["{{", "{%"] -JINJA_REGEX = '({{(.*)}})' +JINJA_REGEX = "({{(.*)}})" JINJA_REGEX_PTRN = re.compile(JINJA_REGEX) -JINJA_BLOCK_REGEX = '({%(.*)%})' +JINJA_BLOCK_REGEX = "({%(.*)%})" JINJA_BLOCK_REGEX_PTRN = re.compile(JINJA_BLOCK_REGEX) @@ -53,59 +46,52 @@ def get_filters(): from st2common.expressions.functions import path return { - 'decrypt_kv': datastore.decrypt_kv, - - 'from_json_string': data.from_json_string, - 'from_yaml_string': data.from_yaml_string, - 'json_escape': data.json_escape, - 'jsonpath_query': data.jsonpath_query, - 'to_complex': data.to_complex, - 'to_json_string': data.to_json_string, - 'to_yaml_string': data.to_yaml_string, - - 'regex_match': regex.regex_match, - 'regex_replace': regex.regex_replace, - 'regex_search': regex.regex_search, - 'regex_substring': regex.regex_substring, - - 'to_human_time_from_seconds': time.to_human_time_from_seconds, - - 'version_compare': version.version_compare, - 'version_more_than': version.version_more_than, - 'version_less_than': version.version_less_than, - 'version_equal': version.version_equal, - 'version_match': version.version_match, - 'version_bump_major': version.version_bump_major, - 'version_bump_minor': version.version_bump_minor, - 'version_bump_patch': version.version_bump_patch, - 'version_strip_patch': version.version_strip_patch, - 'use_none': data.use_none, - - 'basename': path.basename, - 'dirname': path.dirname + "decrypt_kv": datastore.decrypt_kv, + "from_json_string": data.from_json_string, + "from_yaml_string": data.from_yaml_string, + "json_escape": data.json_escape, + "jsonpath_query": data.jsonpath_query, + "to_complex": data.to_complex, + "to_json_string": data.to_json_string, + "to_yaml_string": data.to_yaml_string, + "regex_match": regex.regex_match, + "regex_replace": regex.regex_replace, + "regex_search": regex.regex_search, + "regex_substring": regex.regex_substring, + "to_human_time_from_seconds": time.to_human_time_from_seconds, + "version_compare": version.version_compare, + "version_more_than": version.version_more_than, + "version_less_than": version.version_less_than, + "version_equal": version.version_equal, + "version_match": version.version_match, + "version_bump_major": version.version_bump_major, + "version_bump_minor": version.version_bump_minor, + "version_bump_patch": version.version_bump_patch, + "version_strip_patch": version.version_strip_patch, + "use_none": data.use_none, + "basename": path.basename, + "dirname": path.dirname, } def get_jinja_environment(allow_undefined=False, trim_blocks=True, lstrip_blocks=True): - ''' + """ jinja2.Environment object that is setup with right behaviors and custom filters. :param strict_undefined: If should allow undefined variables in templates :type strict_undefined: ``bool`` - ''' + """ # Late import to avoid very expensive in-direct import (~1 second) when this function # is not called / used import jinja2 undefined = jinja2.Undefined if allow_undefined else jinja2.StrictUndefined env = jinja2.Environment( # nosec - undefined=undefined, - trim_blocks=trim_blocks, - lstrip_blocks=lstrip_blocks + undefined=undefined, trim_blocks=trim_blocks, lstrip_blocks=lstrip_blocks ) env.filters.update(get_filters()) - env.tests['in'] = lambda item, list: item in list + env.tests["in"] = lambda item, list: item in list return env @@ -130,7 +116,7 @@ def render_values(mapping=None, context=None, allow_undefined=False): # This mean __context is a reserve key word although backwards compat is preserved by making # sure that real context is updated later and therefore will override the __context value. super_context = {} - super_context['__context'] = context + super_context["__context"] = context super_context.update(context) env = get_jinja_environment(allow_undefined=allow_undefined) @@ -150,7 +136,7 @@ def render_values(mapping=None, context=None, allow_undefined=False): v = str(v) try: - LOG.info('Rendering string %s. Super context=%s', v, super_context) + LOG.info("Rendering string %s. Super context=%s", v, super_context) rendered_v = env.from_string(v).render(super_context) except Exception as e: # Attach key and value which failed the rendering @@ -166,7 +152,12 @@ def render_values(mapping=None, context=None, allow_undefined=False): if reverse_json_dumps: rendered_v = json.loads(rendered_v) rendered_mapping[k] = rendered_v - LOG.info('Mapping: %s, rendered_mapping: %s, context: %s', mapping, rendered_mapping, context) + LOG.info( + "Mapping: %s, rendered_mapping: %s, context: %s", + mapping, + rendered_mapping, + context, + ) return rendered_mapping @@ -194,6 +185,6 @@ def convert_jinja_to_raw_block(value): if isinstance(value, six.string_types): if JINJA_REGEX_PTRN.findall(value) or JINJA_BLOCK_REGEX_PTRN.findall(value): - return '{% raw %}' + value + '{% endraw %}' + return "{% raw %}" + value + "{% endraw %}" return value diff --git a/st2common/st2common/util/jsonify.py b/st2common/st2common/util/jsonify.py index 1f47cec1b04..16a95dde99c 100644 --- a/st2common/st2common/util/jsonify.py +++ b/st2common/st2common/util/jsonify.py @@ -25,18 +25,12 @@ import six -__all__ = [ - 'json_encode', - 'json_loads', - 'try_loads', - - 'get_json_type_for_python_value' -] +__all__ = ["json_encode", "json_loads", "try_loads", "get_json_type_for_python_value"] class GenericJSON(JSONEncoder): def default(self, obj): # pylint: disable=method-hidden - if hasattr(obj, '__json__') and six.callable(obj.__json__): + if hasattr(obj, "__json__") and six.callable(obj.__json__): return obj.__json__() else: return JSONEncoder.default(self, obj) @@ -47,7 +41,7 @@ def json_encode(obj, indent=4): def load_file(path): - with open(path, 'r') as fd: + with open(path, "r") as fd: return json.load(fd) @@ -92,16 +86,16 @@ def get_json_type_for_python_value(value): :rtype: ``str`` """ if isinstance(value, six.text_type): - return 'string' + return "string" elif isinstance(value, (int, float)): - return 'number' + return "number" elif isinstance(value, dict): - return 'object' + return "object" elif isinstance(value, (list, tuple)): - return 'array' + return "array" elif isinstance(value, bool): - return 'boolean' + return "boolean" elif value is None: - return 'null' + return "null" else: - return 'unknown' + return "unknown" diff --git a/st2common/st2common/util/keyvalue.py b/st2common/st2common/util/keyvalue.py index 05246d2d320..cad32250a83 100644 --- a/st2common/st2common/util/keyvalue.py +++ b/st2common/st2common/util/keyvalue.py @@ -24,22 +24,23 @@ from st2common.rbac.backends import get_rbac_backend from st2common.persistence.keyvalue import KeyValuePair from st2common.services.config import deserialize_key_value -from st2common.constants.keyvalue import (FULL_SYSTEM_SCOPE, FULL_USER_SCOPE, USER_SCOPE, - ALLOWED_SCOPES) +from st2common.constants.keyvalue import ( + FULL_SYSTEM_SCOPE, + FULL_USER_SCOPE, + USER_SCOPE, + ALLOWED_SCOPES, +) from st2common.models.db.auth import UserDB from st2common.exceptions.rbac import AccessDeniedError -__all__ = [ - 'get_datastore_full_scope', - 'get_key' -] +__all__ = ["get_datastore_full_scope", "get_key"] LOG = logging.getLogger(__name__) def _validate_scope(scope): if scope not in ALLOWED_SCOPES: - msg = 'Scope %s is not in allowed scopes list: %s.' % (scope, ALLOWED_SCOPES) + msg = "Scope %s is not in allowed scopes list: %s." % (scope, ALLOWED_SCOPES) raise ValueError(msg) @@ -48,9 +49,9 @@ def _validate_decrypt_query_parameter(decrypt, scope, is_admin, user_db): Validate that the provider user is either admin or requesting to decrypt value for themselves. """ - is_user_scope = (scope == USER_SCOPE or scope == FULL_USER_SCOPE) + is_user_scope = scope == USER_SCOPE or scope == FULL_USER_SCOPE if decrypt and (not is_user_scope and not is_admin): - msg = 'Decrypt option requires administrator access' + msg = "Decrypt option requires administrator access" raise AccessDeniedError(message=msg, user_db=user_db) @@ -61,7 +62,7 @@ def get_datastore_full_scope(scope): if DATASTORE_PARENT_SCOPE in scope: return scope - return '%s%s%s' % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, scope) + return "%s%s%s" % (DATASTORE_PARENT_SCOPE, DATASTORE_SCOPE_SEPARATOR, scope) def _derive_scope_and_key(key, user, scope=None): @@ -75,10 +76,10 @@ def _derive_scope_and_key(key, user, scope=None): if scope is not None: return scope, key - if key.startswith('system.'): - return FULL_SYSTEM_SCOPE, key[key.index('.') + 1:] + if key.startswith("system."): + return FULL_SYSTEM_SCOPE, key[key.index(".") + 1 :] - return FULL_USER_SCOPE, '%s:%s' % (user, key) + return FULL_USER_SCOPE, "%s:%s" % (user, key) def get_key(key=None, user_db=None, scope=None, decrypt=False): @@ -86,10 +87,10 @@ def get_key(key=None, user_db=None, scope=None, decrypt=False): Retrieve key from KVP store """ if not isinstance(key, six.string_types): - raise TypeError('Given key is not typeof string.') + raise TypeError("Given key is not typeof string.") if not isinstance(decrypt, bool): - raise TypeError('Decrypt parameter is not typeof bool.') + raise TypeError("Decrypt parameter is not typeof bool.") if not user_db: # Use system user @@ -98,9 +99,10 @@ def get_key(key=None, user_db=None, scope=None, decrypt=False): scope, key_id = _derive_scope_and_key(key=key, user=user_db.name, scope=scope) scope = get_datastore_full_scope(scope) - LOG.debug('get_key key_id: %s, scope: %s, user: %s, decrypt: %s' % (key_id, scope, - str(user_db.name), - decrypt)) + LOG.debug( + "get_key key_id: %s, scope: %s, user: %s, decrypt: %s" + % (key_id, scope, str(user_db.name), decrypt) + ) _validate_scope(scope=scope) @@ -108,8 +110,9 @@ def get_key(key=None, user_db=None, scope=None, decrypt=False): is_admin = rbac_utils.user_is_admin(user_db=user_db) # User needs to be either admin or requesting item for itself - _validate_decrypt_query_parameter(decrypt=decrypt, scope=scope, is_admin=is_admin, - user_db=user_db) + _validate_decrypt_query_parameter( + decrypt=decrypt, scope=scope, is_admin=is_admin, user_db=user_db + ) # Get the key value pair by scope and name. kvp = KeyValuePair.get_by_scope_and_name(scope, key_id) diff --git a/st2common/st2common/util/loader.py b/st2common/st2common/util/loader.py index 0e27a0da32f..1c5a5a4b542 100644 --- a/st2common/st2common/util/loader.py +++ b/st2common/st2common/util/loader.py @@ -28,19 +28,14 @@ from st2common.exceptions.plugins import IncompatiblePluginException from st2common import log as logging -__all__ = [ - 'register_plugin', - 'register_plugin_class', - - 'load_meta_file' -] +__all__ = ["register_plugin", "register_plugin_class", "load_meta_file"] LOG = logging.getLogger(__name__) -PYTHON_EXTENSION = '.py' -ALLOWED_EXTS = ['.json', '.yaml', '.yml'] -PARSER_FUNCS = {'.json': json.load, '.yml': yaml.safe_load, '.yaml': yaml.safe_load} +PYTHON_EXTENSION = ".py" +ALLOWED_EXTS = [".json", ".yaml", ".yml"] +PARSER_FUNCS = {".json": json.load, ".yml": yaml.safe_load, ".yaml": yaml.safe_load} # Cache for dynamically loaded runner modules RUNNER_MODULES_CACHE = defaultdict(dict) @@ -48,7 +43,9 @@ def _register_plugin_path(plugin_dir_abs_path): if not os.path.isdir(plugin_dir_abs_path): - raise Exception('Directory "%s" with plugins doesn\'t exist' % (plugin_dir_abs_path)) + raise Exception( + 'Directory "%s" with plugins doesn\'t exist' % (plugin_dir_abs_path) + ) for x in sys.path: if plugin_dir_abs_path in (x, x + os.sep): @@ -59,15 +56,21 @@ def _register_plugin_path(plugin_dir_abs_path): def _get_plugin_module(plugin_file_path): plugin_module = os.path.basename(plugin_file_path) if plugin_module.endswith(PYTHON_EXTENSION): - plugin_module = plugin_module[:plugin_module.rfind('.py')] + plugin_module = plugin_module[: plugin_module.rfind(".py")] else: plugin_module = None return plugin_module def _get_classes_in_module(module): - return [kls for name, kls in inspect.getmembers(module, - lambda member: inspect.isclass(member) and member.__module__ == module.__name__)] + return [ + kls + for name, kls in inspect.getmembers( + module, + lambda member: inspect.isclass(member) + and member.__module__ == module.__name__, + ) + ] def _get_plugin_classes(module_name): @@ -92,7 +95,7 @@ def _get_plugin_methods(plugin_klass): method_names = [] for name, method in methods: method_properties = method.__dict__ - is_abstract = method_properties.get('__isabstractmethod__', False) + is_abstract = method_properties.get("__isabstractmethod__", False) if is_abstract: continue @@ -102,16 +105,18 @@ def _get_plugin_methods(plugin_klass): def _validate_methods(plugin_base_class, plugin_klass): - ''' + """ XXX: This is hacky but we'd like to validate the methods in plugin_impl at least has all the *abstract* methods in plugin_base_class. - ''' + """ expected_methods = plugin_base_class.__abstractmethods__ plugin_methods = _get_plugin_methods(plugin_klass) for method in expected_methods: if method not in plugin_methods: - message = 'Class "%s" doesn\'t implement required "%s" method from the base class' + message = ( + 'Class "%s" doesn\'t implement required "%s" method from the base class' + ) raise IncompatiblePluginException(message % (plugin_klass.__name__, method)) @@ -147,8 +152,10 @@ def register_plugin_class(base_class, file_path, class_name): klass = getattr(module, class_name, None) if not klass: - raise Exception('Plugin file "%s" doesn\'t expose class named "%s"' % - (file_path, class_name)) + raise Exception( + 'Plugin file "%s" doesn\'t expose class named "%s"' + % (file_path, class_name) + ) _register_plugin(base_class, klass) return klass @@ -173,12 +180,14 @@ def register_plugin(plugin_base_class, plugin_abs_file_path): registered_plugins.append(klass) except Exception as e: LOG.exception(e) - LOG.debug('Skipping class %s as it doesn\'t match specs.', klass) + LOG.debug("Skipping class %s as it doesn't match specs.", klass) continue if len(registered_plugins) == 0: - raise Exception('Found no classes in plugin file "%s" matching requirements.' % - (plugin_abs_file_path)) + raise Exception( + 'Found no classes in plugin file "%s" matching requirements.' + % (plugin_abs_file_path) + ) return registered_plugins @@ -189,16 +198,17 @@ def load_meta_file(file_path): file_name, file_ext = os.path.splitext(file_path) if file_ext not in ALLOWED_EXTS: - raise Exception('Unsupported meta type %s, file %s. Allowed: %s' % - (file_ext, file_path, ALLOWED_EXTS)) + raise Exception( + "Unsupported meta type %s, file %s. Allowed: %s" + % (file_ext, file_path, ALLOWED_EXTS) + ) - with open(file_path, 'r') as f: + with open(file_path, "r") as f: return PARSER_FUNCS[file_ext](f) def get_available_plugins(namespace): - """Return names of the available / installed plugins for a given namespace. - """ + """Return names of the available / installed plugins for a given namespace.""" from stevedore.extension import ExtensionManager manager = ExtensionManager(namespace=namespace, invoke_on_load=False) @@ -206,9 +216,10 @@ def get_available_plugins(namespace): def get_plugin_instance(namespace, name, invoke_on_load=True): - """Return class instance for the provided plugin name and namespace. - """ + """Return class instance for the provided plugin name and namespace.""" from stevedore.driver import DriverManager - manager = DriverManager(namespace=namespace, name=name, invoke_on_load=invoke_on_load) + manager = DriverManager( + namespace=namespace, name=name, invoke_on_load=invoke_on_load + ) return manager.driver diff --git a/st2common/st2common/util/misc.py b/st2common/st2common/util/misc.py index 28773abedb9..6a1027e9fec 100644 --- a/st2common/st2common/util/misc.py +++ b/st2common/st2common/util/misc.py @@ -26,18 +26,17 @@ import six __all__ = [ - 'prefix_dict_keys', - 'compare_path_file_name', - 'get_field_name_from_mongoengine_error', - - 'sanitize_output', - 'strip_shell_chars', - 'rstrip_last_char', - 'lowercase_value' + "prefix_dict_keys", + "compare_path_file_name", + "get_field_name_from_mongoengine_error", + "sanitize_output", + "strip_shell_chars", + "rstrip_last_char", + "lowercase_value", ] -def prefix_dict_keys(dictionary, prefix='_'): +def prefix_dict_keys(dictionary, prefix="_"): """ Prefix dictionary keys with a provided prefix. @@ -52,7 +51,7 @@ def prefix_dict_keys(dictionary, prefix='_'): result = {} for key, value in six.iteritems(dictionary): - result['%s%s' % (prefix, key)] = value + result["%s%s" % (prefix, key)] = value return result @@ -89,7 +88,7 @@ def sanitize_output(input_str, uses_pty=False): output = strip_shell_chars(input_str) if uses_pty: - output = output.replace('\r\n', '\n') + output = output.replace("\r\n", "\n") return output @@ -105,8 +104,8 @@ def strip_shell_chars(input_str): :rtype: ``str`` """ - stripped_str = rstrip_last_char(input_str, '\n') - stripped_str = rstrip_last_char(stripped_str, '\r') + stripped_str = rstrip_last_char(input_str, "\n") + stripped_str = rstrip_last_char(stripped_str, "\r") return stripped_str @@ -127,7 +126,7 @@ def rstrip_last_char(input_str, char_to_strip): return input_str if input_str.endswith(char_to_strip): - return input_str[:-len(char_to_strip)] + return input_str[: -len(char_to_strip)] return input_str @@ -153,10 +152,10 @@ def get_normalized_file_path(file_path): :rtype: ``str`` """ - if hasattr(sys, 'frozen'): # support for py2exe - file_path = 'logging%s__init__%s' % (os.sep, file_path[-4:]) - elif file_path[-4:].lower() in ['.pyc', '.pyo']: - file_path = file_path[:-4] + '.py' + if hasattr(sys, "frozen"): # support for py2exe + file_path = "logging%s__init__%s" % (os.sep, file_path[-4:]) + elif file_path[-4:].lower() in [".pyc", ".pyo"]: + file_path = file_path[:-4] + ".py" else: file_path = file_path @@ -193,7 +192,7 @@ def get_field_name_from_mongoengine_error(exc): """ msg = str(exc) - match = re.match("Cannot resolve field \"(.+?)\"", msg) + match = re.match('Cannot resolve field "(.+?)"', msg) if match: return match.groups()[0] @@ -201,7 +200,9 @@ def get_field_name_from_mongoengine_error(exc): return msg -def ignore_and_log_exception(exc_classes=(Exception,), logger=None, level=logging.WARNING): +def ignore_and_log_exception( + exc_classes=(Exception,), logger=None, level=logging.WARNING +): """ Decorator which catches the provided exception classes and logs them instead of letting them bubble all the way up. @@ -214,13 +215,14 @@ def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except exc_classes as e: - if len(args) >= 1 and getattr(args[0], '__class__', None): - func_name = '%s.%s' % (args[0].__class__.__name__, func.__name__) + if len(args) >= 1 and getattr(args[0], "__class__", None): + func_name = "%s.%s" % (args[0].__class__.__name__, func.__name__) else: func_name = func.__name__ - message = ('Exception in fuction "%s": %s' % (func_name, str(e))) + message = 'Exception in fuction "%s": %s' % (func_name, str(e)) logger.log(level, message) return wrapper + return decorator diff --git a/st2common/st2common/util/mongoescape.py b/st2common/st2common/util/mongoescape.py index 6d42b4972cc..d75d9502f48 100644 --- a/st2common/st2common/util/mongoescape.py +++ b/st2common/st2common/util/mongoescape.py @@ -21,17 +21,22 @@ from st2common.util.ujson import fast_deepcopy # Note: Because of old rule escaping code, two different characters can be translated back to dot -RULE_CRITERIA_UNESCAPED = ['.'] -RULE_CRITERIA_ESCAPED = [u'\u2024'] -RULE_CRITERIA_ESCAPE_TRANSLATION = dict(list(zip(RULE_CRITERIA_UNESCAPED, RULE_CRITERIA_ESCAPED))) -RULE_CRITERIA_UNESCAPE_TRANSLATION = dict(list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED))) +RULE_CRITERIA_UNESCAPED = ["."] +RULE_CRITERIA_ESCAPED = ["\u2024"] +RULE_CRITERIA_ESCAPE_TRANSLATION = dict( + list(zip(RULE_CRITERIA_UNESCAPED, RULE_CRITERIA_ESCAPED)) +) +RULE_CRITERIA_UNESCAPE_TRANSLATION = dict( + list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED)) +) # http://docs.mongodb.org/manual/faq/developers/#faq-dollar-sign-escaping -UNESCAPED = ['.', '$'] -ESCAPED = [u'\uFF0E', u'\uFF04'] +UNESCAPED = [".", "$"] +ESCAPED = ["\uFF0E", "\uFF04"] ESCAPE_TRANSLATION = dict(list(zip(UNESCAPED, ESCAPED))) UNESCAPE_TRANSLATION = dict( - list(zip(ESCAPED, UNESCAPED)) + list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED)) + list(zip(ESCAPED, UNESCAPED)) + + list(zip(RULE_CRITERIA_ESCAPED, RULE_CRITERIA_UNESCAPED)) ) diff --git a/st2common/st2common/util/monkey_patch.py b/st2common/st2common/util/monkey_patch.py index 5a042fd656b..76b4a191dee 100644 --- a/st2common/st2common/util/monkey_patch.py +++ b/st2common/st2common/util/monkey_patch.py @@ -22,13 +22,13 @@ import sys __all__ = [ - 'monkey_patch', - 'use_select_poll_workaround', - 'is_use_debugger_flag_provided' + "monkey_patch", + "use_select_poll_workaround", + "is_use_debugger_flag_provided", ] -USE_DEBUGGER_FLAG = '--use-debugger' -PARENT_ARGS_FLAG = '--parent-args=' +USE_DEBUGGER_FLAG = "--use-debugger" +PARENT_ARGS_FLAG = "--parent-args=" def monkey_patch(patch_thread=None): @@ -48,7 +48,9 @@ def monkey_patch(patch_thread=None): if patch_thread is None: patch_thread = not is_use_debugger_flag_provided() - eventlet.monkey_patch(os=True, select=True, socket=True, thread=patch_thread, time=True) + eventlet.monkey_patch( + os=True, select=True, socket=True, thread=patch_thread, time=True + ) def use_select_poll_workaround(nose_only=True): @@ -80,20 +82,20 @@ def use_select_poll_workaround(nose_only=True): import eventlet # Work around to get tests to pass with eventlet >= 0.20.0 - if not nose_only or (nose_only and 'nose' in sys.modules.keys()): + if not nose_only or (nose_only and "nose" in sys.modules.keys()): # Add back blocking poll() to eventlet monkeypatched select - original_poll = eventlet.patcher.original('select').poll + original_poll = eventlet.patcher.original("select").poll select.poll = original_poll - sys.modules['select'] = select + sys.modules["select"] = select subprocess.select = select if sys.version_info >= (3, 6, 5): # If we also don't patch selectors.select, it will fail with Python >= 3.6.5 import selectors # pylint: disable=import-error - sys.modules['selectors'] = selectors - selectors.select = sys.modules['select'] + sys.modules["selectors"] = selectors + selectors.select = sys.modules["select"] def is_use_debugger_flag_provided(): diff --git a/st2common/st2common/util/output_schema.py b/st2common/st2common/util/output_schema.py index 607f1af0bbb..2bde19c3c02 100644 --- a/st2common/st2common/util/output_schema.py +++ b/st2common/st2common/util/output_schema.py @@ -26,37 +26,36 @@ def _validate_runner(runner_schema, result): - LOG.debug('Validating runner output: %s', runner_schema) + LOG.debug("Validating runner output: %s", runner_schema) runner_schema = { "type": "object", "properties": runner_schema, - "additionalProperties": False + "additionalProperties": False, } - schema.validate(result, runner_schema, cls=schema.get_validator('custom')) + schema.validate(result, runner_schema, cls=schema.get_validator("custom")) def _validate_action(action_schema, result, output_key): - LOG.debug('Validating action output: %s', action_schema) + LOG.debug("Validating action output: %s", action_schema) final_result = result[output_key] action_schema = { "type": "object", "properties": action_schema, - "additionalProperties": False + "additionalProperties": False, } - schema.validate(final_result, action_schema, cls=schema.get_validator('custom')) + schema.validate(final_result, action_schema, cls=schema.get_validator("custom")) def validate_output(runner_schema, action_schema, result, status, output_key): - """ Validate output of action with runner and action schema. - """ + """Validate output of action with runner and action schema.""" try: - LOG.debug('Validating action output: %s', result) - LOG.debug('Output Key: %s', output_key) + LOG.debug("Validating action output: %s", result) + LOG.debug("Output Key: %s", output_key) if runner_schema: _validate_runner(runner_schema, result) @@ -64,26 +63,26 @@ def validate_output(runner_schema, action_schema, result, status, output_key): _validate_action(action_schema, result, output_key) except jsonschema.ValidationError: - LOG.exception('Failed to validate output.') + LOG.exception("Failed to validate output.") _, ex, _ = sys.exc_info() # mark execution as failed. status = action_constants.LIVEACTION_STATUS_FAILED # include the error message and traceback to try and provide some hints. result = { - 'error': str(ex), - 'message': 'Error validating output. See error output for more details.', + "error": str(ex), + "message": "Error validating output. See error output for more details.", } return (result, status) except: - LOG.exception('Failed to validate output.') + LOG.exception("Failed to validate output.") _, ex, tb = sys.exc_info() # mark execution as failed. status = action_constants.LIVEACTION_STATUS_FAILED # include the error message and traceback to try and provide some hints. result = { - 'traceback': ''.join(traceback.format_tb(tb, 20)), - 'error': str(ex), - 'message': 'Error validating output. See error output for more details.', + "traceback": "".join(traceback.format_tb(tb, 20)), + "error": str(ex), + "message": "Error validating output. See error output for more details.", } return (result, status) diff --git a/st2common/st2common/util/pack.py b/st2common/st2common/util/pack.py index 6ac4e4fc48c..43dde600518 100644 --- a/st2common/st2common/util/pack.py +++ b/st2common/st2common/util/pack.py @@ -30,27 +30,28 @@ from st2common.util import jinja as jinja_utils __all__ = [ - 'get_pack_ref_from_metadata', - 'get_pack_metadata', - 'get_pack_warnings', - - 'get_pack_common_libs_path_for_pack_ref', - 'get_pack_common_libs_path_for_pack_db', - - 'validate_config_against_schema', - - 'normalize_pack_version' + "get_pack_ref_from_metadata", + "get_pack_metadata", + "get_pack_warnings", + "get_pack_common_libs_path_for_pack_ref", + "get_pack_common_libs_path_for_pack_db", + "validate_config_against_schema", + "normalize_pack_version", ] # Common format for python 2.7 warning if six.PY2: - PACK_PYTHON2_WARNING = "DEPRECATION WARNING: Pack %s only supports Python 2.x. " \ - "Python 2 support will be dropped in future releases. " \ - "Please consider updating your packs to work with Python 3.x" + PACK_PYTHON2_WARNING = ( + "DEPRECATION WARNING: Pack %s only supports Python 2.x. " + "Python 2 support will be dropped in future releases. " + "Please consider updating your packs to work with Python 3.x" + ) else: - PACK_PYTHON2_WARNING = "DEPRECATION WARNING: Pack %s only supports Python 2.x. " \ - "Python 2 support has been removed since st2 v3.4.0. " \ - "Please update your packs to work with Python 3.x" + PACK_PYTHON2_WARNING = ( + "DEPRECATION WARNING: Pack %s only supports Python 2.x. " + "Python 2 support has been removed since st2 v3.4.0. " + "Please update your packs to work with Python 3.x" + ) def get_pack_ref_from_metadata(metadata, pack_directory_name=None): @@ -69,19 +70,23 @@ def get_pack_ref_from_metadata(metadata, pack_directory_name=None): # which are in sub-directories) # 2. If attribute is not available, but pack name is and pack name meets the valid name # criteria, we use that - if metadata.get('ref', None): - pack_ref = metadata['ref'] - elif pack_directory_name and re.match(PACK_REF_WHITELIST_REGEX, pack_directory_name): + if metadata.get("ref", None): + pack_ref = metadata["ref"] + elif pack_directory_name and re.match( + PACK_REF_WHITELIST_REGEX, pack_directory_name + ): pack_ref = pack_directory_name else: - if re.match(PACK_REF_WHITELIST_REGEX, metadata['name']): - pack_ref = metadata['name'] + if re.match(PACK_REF_WHITELIST_REGEX, metadata["name"]): + pack_ref = metadata["name"] else: - msg = ('Pack name "%s" contains invalid characters and "ref" attribute is not ' - 'available. You either need to add "ref" attribute which contains only word ' - 'characters to the pack metadata file or update name attribute to contain only' - 'word characters.') - raise ValueError(msg % (metadata['name'])) + msg = ( + 'Pack name "%s" contains invalid characters and "ref" attribute is not ' + 'available. You either need to add "ref" attribute which contains only word ' + "characters to the pack metadata file or update name attribute to contain only" + "word characters." + ) + raise ValueError(msg % (metadata["name"])) return pack_ref @@ -95,7 +100,9 @@ def get_pack_metadata(pack_dir): manifest_path = os.path.join(pack_dir, MANIFEST_FILE_NAME) if not os.path.isfile(manifest_path): - raise ValueError('Pack "%s" is missing %s file' % (pack_dir, MANIFEST_FILE_NAME)) + raise ValueError( + 'Pack "%s" is missing %s file' % (pack_dir, MANIFEST_FILE_NAME) + ) meta_loader = MetaLoader() content = meta_loader.load(manifest_path) @@ -112,15 +119,16 @@ def get_pack_warnings(pack_metadata): :rtype: ``str`` """ warning = None - versions = pack_metadata.get('python_versions', None) - pack_name = pack_metadata.get('name', None) - if versions and set(versions) == set(['2']): + versions = pack_metadata.get("python_versions", None) + pack_name = pack_metadata.get("name", None) + if versions and set(versions) == set(["2"]): warning = PACK_PYTHON2_WARNING % pack_name return warning -def validate_config_against_schema(config_schema, config_object, config_path, - pack_name=None): +def validate_config_against_schema( + config_schema, config_object, config_path, pack_name=None +): """ Validate provided config dictionary against the provided config schema dictionary. @@ -128,35 +136,49 @@ def validate_config_against_schema(config_schema, config_object, config_path, # NOTE: Lazy improt to avoid performance overhead of importing this module when it's not used import jsonschema - pack_name = pack_name or 'unknown' + pack_name = pack_name or "unknown" - schema = util_schema.get_schema_for_resource_parameters(parameters_schema=config_schema, - allow_additional_properties=True) + schema = util_schema.get_schema_for_resource_parameters( + parameters_schema=config_schema, allow_additional_properties=True + ) instance = config_object try: - cleaned = util_schema.validate(instance=instance, schema=schema, - cls=util_schema.CustomValidator, use_default=True, - allow_default_none=True) + cleaned = util_schema.validate( + instance=instance, + schema=schema, + cls=util_schema.CustomValidator, + use_default=True, + allow_default_none=True, + ) for key in cleaned: - if (jinja_utils.is_jinja_expression(value=cleaned.get(key)) and - "decrypt_kv" in cleaned.get(key) and config_schema.get(key).get('secret')): - raise ValueValidationException('Values specified as "secret: True" in config ' - 'schema are automatically decrypted by default. Use ' - 'of "decrypt_kv" jinja filter is not allowed for ' - 'such values. Please check the specified values in ' - 'the config or the default values in the schema.') + if ( + jinja_utils.is_jinja_expression(value=cleaned.get(key)) + and "decrypt_kv" in cleaned.get(key) + and config_schema.get(key).get("secret") + ): + raise ValueValidationException( + 'Values specified as "secret: True" in config ' + "schema are automatically decrypted by default. Use " + 'of "decrypt_kv" jinja filter is not allowed for ' + "such values. Please check the specified values in " + "the config or the default values in the schema." + ) except jsonschema.ValidationError as e: - attribute = getattr(e, 'path', []) + attribute = getattr(e, "path", []) if isinstance(attribute, (tuple, list, collections.Iterable)): attribute = [str(item) for item in attribute] - attribute = '.'.join(attribute) + attribute = ".".join(attribute) else: attribute = str(attribute) - msg = ('Failed validating attribute "%s" in config for pack "%s" (%s): %s' % - (attribute, pack_name, config_path, six.text_type(e))) + msg = 'Failed validating attribute "%s" in config for pack "%s" (%s): %s' % ( + attribute, + pack_name, + config_path, + six.text_type(e), + ) raise jsonschema.ValidationError(msg) return cleaned @@ -183,12 +205,12 @@ def get_pack_common_libs_path_for_pack_db(pack_db): :rtype: ``str`` """ - pack_dir = getattr(pack_db, 'path', None) + pack_dir = getattr(pack_db, "path", None) if not pack_dir: return None - libs_path = os.path.join(pack_dir, 'lib') + libs_path = os.path.join(pack_dir, "lib") return libs_path @@ -202,8 +224,8 @@ def normalize_pack_version(version): """ version = str(version) - version_seperator_count = version.count('.') + version_seperator_count = version.count(".") if version_seperator_count == 1: - version = version + '.0' + version = version + ".0" return version diff --git a/st2common/st2common/util/pack_management.py b/st2common/st2common/util/pack_management.py index 48b94572037..0fde5b1d86b 100644 --- a/st2common/st2common/util/pack_management.py +++ b/st2common/st2common/util/pack_management.py @@ -48,29 +48,33 @@ from st2common.util.versioning import get_python_version __all__ = [ - 'download_pack', - - 'get_repo_url', - 'eval_repo_url', - - 'apply_pack_owner_group', - 'apply_pack_permissions', - - 'get_and_set_proxy_config' + "download_pack", + "get_repo_url", + "eval_repo_url", + "apply_pack_owner_group", + "apply_pack_permissions", + "get_and_set_proxy_config", ] LOG = logging.getLogger(__name__) -CONFIG_FILE = 'config.yaml' +CONFIG_FILE = "config.yaml" CURRENT_STACKSTORM_VERSION = get_stackstorm_version() CURRENT_PYTHON_VERSION = get_python_version() -SUDO_BINARY = find_executable('sudo') +SUDO_BINARY = find_executable("sudo") -def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True, force=False, - proxy_config=None, force_owner_group=True, force_permissions=True, - logger=LOG): +def download_pack( + pack, + abs_repo_base="/opt/stackstorm/packs", + verify_ssl=True, + force=False, + proxy_config=None, + force_owner_group=True, + force_permissions=True, + logger=LOG, +): """ Download the pack and move it to /opt/stackstorm/packs. @@ -105,11 +109,11 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True, result = [pack_url, None, None] temp_dir_name = hashlib.md5(pack_url.encode()).hexdigest() - lock_file = LockFile('/tmp/%s' % (temp_dir_name)) + lock_file = LockFile("/tmp/%s" % (temp_dir_name)) lock_file_path = lock_file.lock_file if force: - logger.debug('Force mode is enabled, deleting lock file...') + logger.debug("Force mode is enabled, deleting lock file...") try: os.unlink(lock_file_path) @@ -119,31 +123,42 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True, with lock_file: try: - user_home = os.path.expanduser('~') + user_home = os.path.expanduser("~") abs_local_path = os.path.join(user_home, temp_dir_name) - if pack_url.startswith('file://'): + if pack_url.startswith("file://"): # Local pack - local_pack_directory = os.path.abspath(os.path.join(pack_url.split('file://')[1])) + local_pack_directory = os.path.abspath( + os.path.join(pack_url.split("file://")[1]) + ) else: local_pack_directory = None # If it's a local pack which is not a git repository, just copy the directory content # over if local_pack_directory and not os.path.isdir( - os.path.join(local_pack_directory, '.git')): + os.path.join(local_pack_directory, ".git") + ): if not os.path.isdir(local_pack_directory): - raise ValueError('Local pack directory "%s" doesn\'t exist' % - (local_pack_directory)) + raise ValueError( + 'Local pack directory "%s" doesn\'t exist' + % (local_pack_directory) + ) - logger.debug('Detected local pack directory which is not a git repository, just ' - 'copying files over...') + logger.debug( + "Detected local pack directory which is not a git repository, just " + "copying files over..." + ) shutil.copytree(local_pack_directory, abs_local_path) else: # 1. Clone / download the repo - clone_repo(temp_dir=abs_local_path, repo_url=pack_url, verify_ssl=verify_ssl, - ref=pack_version) + clone_repo( + temp_dir=abs_local_path, + repo_url=pack_url, + verify_ssl=verify_ssl, + ref=pack_version, + ) pack_metadata = get_pack_metadata(pack_dir=abs_local_path) pack_ref = get_pack_ref(pack_dir=abs_local_path) @@ -154,12 +169,15 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True, verify_pack_version(pack_metadata=pack_metadata) # 3. Move pack to the final location - move_result = move_pack(abs_repo_base=abs_repo_base, pack_name=pack_ref, - abs_local_path=abs_local_path, - pack_metadata=pack_metadata, - force_owner_group=force_owner_group, - force_permissions=force_permissions, - logger=logger) + move_result = move_pack( + abs_repo_base=abs_repo_base, + pack_name=pack_ref, + abs_local_path=abs_local_path, + pack_metadata=pack_metadata, + force_owner_group=force_owner_group, + force_permissions=force_permissions, + logger=logger, + ) result[2] = move_result finally: cleanup_repo(abs_local_path=abs_local_path) @@ -167,21 +185,21 @@ def download_pack(pack, abs_repo_base='/opt/stackstorm/packs', verify_ssl=True, return tuple(result) -def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'): +def clone_repo(temp_dir, repo_url, verify_ssl=True, ref="master"): # Switch to non-interactive mode - os.environ['GIT_TERMINAL_PROMPT'] = '0' - os.environ['GIT_ASKPASS'] = '/bin/echo' + os.environ["GIT_TERMINAL_PROMPT"] = "0" + os.environ["GIT_ASKPASS"] = "/bin/echo" # Disable SSL cert checking if explictly asked if not verify_ssl: - os.environ['GIT_SSL_NO_VERIFY'] = 'true' + os.environ["GIT_SSL_NO_VERIFY"] = "true" # Clone the repo from git; we don't use shallow copying # because we want the user to work with the repo in the # future. repo = Repo.clone_from(repo_url, temp_dir) - is_local_repo = repo_url.startswith('file://') + is_local_repo = repo_url.startswith("file://") try: active_branch = repo.active_branch @@ -194,18 +212,20 @@ def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'): # Special case for local git repos - we allow users to install from repos which are checked out # at a specific commit (aka detached HEAD) if is_local_repo and not active_branch and not ref: - LOG.debug('Installing pack from git repo on disk, skipping branch checkout') + LOG.debug("Installing pack from git repo on disk, skipping branch checkout") return temp_dir use_branch = False # Special case when a default repo branch is not "master" # No ref provided so we just use a default active branch - if (not ref or ref == active_branch.name) and repo.active_branch.object == repo.head.commit: + if ( + not ref or ref == active_branch.name + ) and repo.active_branch.object == repo.head.commit: gitref = repo.active_branch.object else: # Try to match the reference to a branch name (i.e. "master") - gitref = get_gitref(repo, 'origin/%s' % ref) + gitref = get_gitref(repo, "origin/%s" % ref) if gitref: use_branch = True @@ -215,7 +235,7 @@ def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'): # Try to match the reference to a "vX.Y.Z" tag if not gitref and re.match(PACK_VERSION_REGEX, ref): - gitref = get_gitref(repo, 'v%s' % ref) + gitref = get_gitref(repo, "v%s" % ref) # Giving up ¯\_(ツ)_/¯ if not gitref: @@ -224,43 +244,52 @@ def clone_repo(temp_dir, repo_url, verify_ssl=True, ref='master'): valid_versions = get_valid_versions_for_repo(repo=repo) if len(valid_versions) >= 1: - valid_versions_string = ', '.join(valid_versions) + valid_versions_string = ", ".join(valid_versions) - msg += ' Available versions are: %s.' + msg += " Available versions are: %s." format_values.append(valid_versions_string) raise ValueError(msg % tuple(format_values)) # We're trying to figure out which branch the ref is actually on, # since there's no direct way to check for this in git-python. - branches = repo.git.branch('-a', '--contains', gitref.hexsha) # pylint: disable=no-member + branches = repo.git.branch( + "-a", "--contains", gitref.hexsha + ) # pylint: disable=no-member # Git tags aren't necessarily on a branch. # If this is the case, gitref will be the tag name, but branches will be # empty. # We also need to checkout tags slightly differently than branches. if branches: - branches = branches.replace('*', '').split() + branches = branches.replace("*", "").split() if active_branch.name not in branches or use_branch: - branch = 'origin/%s' % ref if use_branch else branches[0] - short_branch = ref if use_branch else branches[0].split('/')[-1] - repo.git.checkout('-b', short_branch, branch) + branch = "origin/%s" % ref if use_branch else branches[0] + short_branch = ref if use_branch else branches[0].split("/")[-1] + repo.git.checkout("-b", short_branch, branch) branch = repo.head.reference else: branch = repo.active_branch.name repo.git.checkout(gitref.hexsha) # pylint: disable=no-member - repo.git.branch('-f', branch, gitref.hexsha) # pylint: disable=no-member + repo.git.branch("-f", branch, gitref.hexsha) # pylint: disable=no-member repo.git.checkout(branch) else: - repo.git.checkout('v%s' % ref) # pylint: disable=no-member + repo.git.checkout("v%s" % ref) # pylint: disable=no-member return temp_dir -def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_owner_group=True, - force_permissions=True, logger=LOG): +def move_pack( + abs_repo_base, + pack_name, + abs_local_path, + pack_metadata, + force_owner_group=True, + force_permissions=True, + logger=LOG, +): """ Move pack directory into the final location. """ @@ -270,8 +299,9 @@ def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_own to = abs_repo_base dest_pack_path = os.path.join(abs_repo_base, pack_name) if os.path.exists(dest_pack_path): - logger.debug('Removing existing pack %s in %s to replace.', pack_name, - dest_pack_path) + logger.debug( + "Removing existing pack %s in %s to replace.", pack_name, dest_pack_path + ) # Ensure to preserve any existing configuration old_config_file = os.path.join(dest_pack_path, CONFIG_FILE) @@ -282,7 +312,7 @@ def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_own shutil.rmtree(dest_pack_path) - logger.debug('Moving pack from %s to %s.', abs_local_path, to) + logger.debug("Moving pack from %s to %s.", abs_local_path, to) shutil.move(abs_local_path, dest_pack_path) # post move fix all permissions @@ -299,9 +329,9 @@ def move_pack(abs_repo_base, pack_name, abs_local_path, pack_metadata, force_own if warning: logger.warning(warning) - message = 'Success.' + message = "Success." elif message: - message = 'Failure : %s' % message + message = "Failure : %s" % message return (desired, message) @@ -316,20 +346,25 @@ def apply_pack_owner_group(pack_path): pack_group = utils.get_pack_group() if pack_group: - LOG.debug('Changing owner group of "{}" directory to {}'.format(pack_path, pack_group)) + LOG.debug( + 'Changing owner group of "{}" directory to {}'.format(pack_path, pack_group) + ) if SUDO_BINARY: - args = ['sudo', 'chgrp', '-R', pack_group, pack_path] + args = ["sudo", "chgrp", "-R", pack_group, pack_path] else: # Environments where sudo is not available (e.g. docker) - args = ['chgrp', '-R', pack_group, pack_path] + args = ["chgrp", "-R", pack_group, pack_path] exit_code, _, stderr, _ = shell.run_command(args) if exit_code != 0: # Non fatal, but we still log it - LOG.debug('Failed to change owner group on directory "{}" to "{}": {}' - .format(pack_path, pack_group, stderr)) + LOG.debug( + 'Failed to change owner group on directory "{}" to "{}": {}'.format( + pack_path, pack_group, stderr + ) + ) return True @@ -370,13 +405,13 @@ def get_repo_url(pack, proxy_config=None): name_or_url = pack_and_version[0] version = pack_and_version[1] if len(pack_and_version) > 1 else None - if len(name_or_url.split('/')) == 1: + if len(name_or_url.split("/")) == 1: pack = get_pack_from_index(name_or_url, proxy_config=proxy_config) if not pack: raise Exception('No record of the "%s" pack in the index.' % (name_or_url)) - return (pack['repo_url'], version or pack['version']) + return (pack["repo_url"], version or pack["version"]) else: return (eval_repo_url(name_or_url), version) @@ -386,12 +421,12 @@ def eval_repo_url(repo_url): Allow passing short GitHub or GitLab SSH style URLs. """ if not repo_url: - raise Exception('No valid repo_url provided or could be inferred.') + raise Exception("No valid repo_url provided or could be inferred.") if repo_url.startswith("gitlab@") or repo_url.startswith("file://"): return repo_url else: - if len(repo_url.split('/')) == 2 and 'git@' not in repo_url: - url = 'https://github.com/{}'.format(repo_url) + if len(repo_url.split("/")) == 2 and "git@" not in repo_url: + url = "https://github.com/{}".format(repo_url) else: url = repo_url return url @@ -400,50 +435,65 @@ def eval_repo_url(repo_url): def is_desired_pack(abs_pack_path, pack_name): # path has to exist. if not os.path.exists(abs_pack_path): - return (False, 'Pack "%s" not found or it\'s missing a "pack.yaml" file.' % - (pack_name)) + return ( + False, + 'Pack "%s" not found or it\'s missing a "pack.yaml" file.' % (pack_name), + ) # should not include reserved characters for character in PACK_RESERVED_CHARACTERS: if character in pack_name: - return (False, 'Pack name "%s" contains reserved character "%s"' % - (pack_name, character)) + return ( + False, + 'Pack name "%s" contains reserved character "%s"' + % (pack_name, character), + ) # must contain a manifest file. Empty file is ok for now. if not os.path.isfile(os.path.join(abs_pack_path, MANIFEST_FILE_NAME)): - return (False, 'Pack is missing a manifest file (%s).' % (MANIFEST_FILE_NAME)) + return (False, "Pack is missing a manifest file (%s)." % (MANIFEST_FILE_NAME)) - return (True, '') + return (True, "") def verify_pack_version(pack_metadata): """ Verify that the pack works with the currently running StackStorm version. """ - pack_name = pack_metadata.get('name', None) - required_stackstorm_version = pack_metadata.get('stackstorm_version', None) - supported_python_versions = pack_metadata.get('python_versions', None) + pack_name = pack_metadata.get("name", None) + required_stackstorm_version = pack_metadata.get("stackstorm_version", None) + supported_python_versions = pack_metadata.get("python_versions", None) # If stackstorm_version attribute is specified, verify that the pack works with currently # running version of StackStorm if required_stackstorm_version: - if not complex_semver_match(CURRENT_STACKSTORM_VERSION, required_stackstorm_version): - msg = ('Pack "%s" requires StackStorm "%s", but current version is "%s". ' - 'You can override this restriction by providing the "force" flag, but ' - 'the pack is not guaranteed to work.' % - (pack_name, required_stackstorm_version, CURRENT_STACKSTORM_VERSION)) + if not complex_semver_match( + CURRENT_STACKSTORM_VERSION, required_stackstorm_version + ): + msg = ( + 'Pack "%s" requires StackStorm "%s", but current version is "%s". ' + 'You can override this restriction by providing the "force" flag, but ' + "the pack is not guaranteed to work." + % (pack_name, required_stackstorm_version, CURRENT_STACKSTORM_VERSION) + ) raise ValueError(msg) if supported_python_versions: - if set(supported_python_versions) == set(['2']) and (not six.PY2): - msg = ('Pack "%s" requires Python 2.x, but current Python version is "%s". ' - 'You can override this restriction by providing the "force" flag, but ' - 'the pack is not guaranteed to work.' % (pack_name, CURRENT_PYTHON_VERSION)) + if set(supported_python_versions) == set(["2"]) and (not six.PY2): + msg = ( + 'Pack "%s" requires Python 2.x, but current Python version is "%s". ' + 'You can override this restriction by providing the "force" flag, but ' + "the pack is not guaranteed to work." + % (pack_name, CURRENT_PYTHON_VERSION) + ) raise ValueError(msg) - elif set(supported_python_versions) == set(['3']) and (not six.PY3): - msg = ('Pack "%s" requires Python 3.x, but current Python version is "%s". ' - 'You can override this restriction by providing the "force" flag, but ' - 'the pack is not guaranteed to work.' % (pack_name, CURRENT_PYTHON_VERSION)) + elif set(supported_python_versions) == set(["3"]) and (not six.PY3): + msg = ( + 'Pack "%s" requires Python 3.x, but current Python version is "%s". ' + 'You can override this restriction by providing the "force" flag, but ' + "the pack is not guaranteed to work." + % (pack_name, CURRENT_PYTHON_VERSION) + ) raise ValueError(msg) else: # Pack support Python 2.x and 3.x so no check is needed, or @@ -474,7 +524,7 @@ def get_valid_versions_for_repo(repo): valid_versions = [] for tag in repo.tags: - if tag.name.startswith('v') and re.match(PACK_VERSION_REGEX, tag.name[1:]): + if tag.name.startswith("v") and re.match(PACK_VERSION_REGEX, tag.name[1:]): # Note: We strip leading "v" from the version number valid_versions.append(tag.name[1:]) @@ -486,39 +536,38 @@ def get_pack_ref(pack_dir): Read pack reference from the metadata file and sanitize it. """ metadata = get_pack_metadata(pack_dir=pack_dir) - pack_ref = get_pack_ref_from_metadata(metadata=metadata, - pack_directory_name=None) + pack_ref = get_pack_ref_from_metadata(metadata=metadata, pack_directory_name=None) return pack_ref def get_and_set_proxy_config(): - https_proxy = os.environ.get('https_proxy', None) - http_proxy = os.environ.get('http_proxy', None) - proxy_ca_bundle_path = os.environ.get('proxy_ca_bundle_path', None) - no_proxy = os.environ.get('no_proxy', None) + https_proxy = os.environ.get("https_proxy", None) + http_proxy = os.environ.get("http_proxy", None) + proxy_ca_bundle_path = os.environ.get("proxy_ca_bundle_path", None) + no_proxy = os.environ.get("no_proxy", None) proxy_config = {} if http_proxy or https_proxy: - LOG.debug('Using proxy %s', http_proxy if http_proxy else https_proxy) + LOG.debug("Using proxy %s", http_proxy if http_proxy else https_proxy) proxy_config = { - 'https_proxy': https_proxy, - 'http_proxy': http_proxy, - 'proxy_ca_bundle_path': proxy_ca_bundle_path, - 'no_proxy': no_proxy + "https_proxy": https_proxy, + "http_proxy": http_proxy, + "proxy_ca_bundle_path": proxy_ca_bundle_path, + "no_proxy": no_proxy, } - if https_proxy and not os.environ.get('https_proxy', None): - os.environ['https_proxy'] = https_proxy + if https_proxy and not os.environ.get("https_proxy", None): + os.environ["https_proxy"] = https_proxy - if http_proxy and not os.environ.get('http_proxy', None): - os.environ['http_proxy'] = http_proxy + if http_proxy and not os.environ.get("http_proxy", None): + os.environ["http_proxy"] = http_proxy - if no_proxy and not os.environ.get('no_proxy', None): - os.environ['no_proxy'] = no_proxy + if no_proxy and not os.environ.get("no_proxy", None): + os.environ["no_proxy"] = no_proxy - if proxy_ca_bundle_path and not os.environ.get('proxy_ca_bundle_path', None): - os.environ['no_proxy'] = no_proxy + if proxy_ca_bundle_path and not os.environ.get("proxy_ca_bundle_path", None): + os.environ["no_proxy"] = no_proxy return proxy_config diff --git a/st2common/st2common/util/param.py b/st2common/st2common/util/param.py index 93507fcd875..270c90e4241 100644 --- a/st2common/st2common/util/param.py +++ b/st2common/st2common/util/param.py @@ -26,7 +26,11 @@ from st2common.util.jinja import is_jinja_expression from st2common.constants.action import ACTION_CONTEXT_KV_PREFIX from st2common.constants.pack import PACK_CONFIG_CONTEXT_KV_PREFIX -from st2common.constants.keyvalue import DATASTORE_PARENT_SCOPE, SYSTEM_SCOPE, FULL_SYSTEM_SCOPE +from st2common.constants.keyvalue import ( + DATASTORE_PARENT_SCOPE, + SYSTEM_SCOPE, + FULL_SYSTEM_SCOPE, +) from st2common.constants.keyvalue import USER_SCOPE, FULL_USER_SCOPE from st2common.exceptions.param import ParamException from st2common.services.keyvalues import KeyValueLookup, UserKeyValueLookup @@ -39,23 +43,27 @@ ENV = jinja_utils.get_jinja_environment() __all__ = [ - 'render_live_params', - 'render_final_params', + "render_live_params", + "render_final_params", ] def _split_params(runner_parameters, action_parameters, mixed_params): def pf(params, skips): - result = {k: v for k, v in six.iteritems(mixed_params) - if k in params and k not in skips} + result = { + k: v + for k, v in six.iteritems(mixed_params) + if k in params and k not in skips + } return result + return (pf(runner_parameters, {}), pf(action_parameters, runner_parameters)) def _cast_params(rendered, parameter_schemas): - ''' + """ It's just here to make tests happy - ''' + """ casted_params = {} for k, v in six.iteritems(rendered): casted_params[k] = _cast(v, parameter_schemas[k] or {}) @@ -66,7 +74,7 @@ def _cast(v, parameter_schema): if v is None or not parameter_schema: return v - parameter_type = parameter_schema.get('type', None) + parameter_type = parameter_schema.get("type", None) if not parameter_type: return v @@ -78,23 +86,27 @@ def _cast(v, parameter_schema): def _create_graph(action_context, config): - ''' + """ Creates a generic directed graph for depencency tree and fills it with basic context variables - ''' + """ G = nx.DiGraph() system_keyvalue_context = {SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE)} # If both 'user' and 'api_user' are specified, this prioritize 'api_user' - user = action_context['user'] if 'user' in action_context else None - user = action_context['api_user'] if 'api_user' in action_context else user + user = action_context["user"] if "user" in action_context else None + user = action_context["api_user"] if "api_user" in action_context else user if not user: # When no user is not specified, this selects system-user's scope by default. user = cfg.CONF.system_user.user - LOG.info('Unable to retrieve user / api_user value from action_context. Falling back ' - 'to and using system_user (%s).' % (user)) + LOG.info( + "Unable to retrieve user / api_user value from action_context. Falling back " + "to and using system_user (%s)." % (user) + ) - system_keyvalue_context[USER_SCOPE] = UserKeyValueLookup(scope=FULL_USER_SCOPE, user=user) + system_keyvalue_context[USER_SCOPE] = UserKeyValueLookup( + scope=FULL_USER_SCOPE, user=user + ) G.add_node(DATASTORE_PARENT_SCOPE, value=system_keyvalue_context) G.add_node(ACTION_CONTEXT_KV_PREFIX, value=action_context) G.add_node(PACK_CONFIG_CONTEXT_KV_PREFIX, value=config) @@ -102,9 +114,9 @@ def _create_graph(action_context, config): def _process(G, name, value): - ''' + """ Determines whether parameter is a template or a value. Adds graph nodes and edges accordingly. - ''' + """ # Jinja defaults to ascii parser in python 2.x unless you set utf-8 support on per module level # Instead we're just assuming every string to be a unicode string if isinstance(value, str): @@ -114,23 +126,21 @@ def _process(G, name, value): if isinstance(value, list) or isinstance(value, dict): complex_value_str = str(value) - is_jinja_expr = ( - jinja_utils.is_jinja_expression(value) or jinja_utils.is_jinja_expression( - complex_value_str - ) - ) + is_jinja_expr = jinja_utils.is_jinja_expression( + value + ) or jinja_utils.is_jinja_expression(complex_value_str) if is_jinja_expr: G.add_node(name, template=value) template_ast = ENV.parse(value) - LOG.debug('Template ast: %s', template_ast) + LOG.debug("Template ast: %s", template_ast) # Dependencies of the node represent jinja variables used in the template # We're connecting nodes with an edge for every depencency to traverse them # in the right order and also make sure that we don't have missing or cyclic # dependencies upfront. dependencies = meta.find_undeclared_variables(template_ast) - LOG.debug('Dependencies: %s', dependencies) + LOG.debug("Dependencies: %s", dependencies) if dependencies: for dependency in dependencies: G.add_edge(dependency, name) @@ -139,24 +149,24 @@ def _process(G, name, value): def _process_defaults(G, schemas): - ''' + """ Process dependencies for parameters default values in the order schemas are defined. - ''' + """ for schema in schemas: for name, value in six.iteritems(schema): absent = name not in G.node - is_none = G.node.get(name, {}).get('value') is None - immutable = value.get('immutable', False) + is_none = G.node.get(name, {}).get("value") is None + immutable = value.get("immutable", False) if absent or is_none or immutable: - _process(G, name, value.get('default')) + _process(G, name, value.get("default")) def _validate(G): - ''' + """ Validates dependency graph to ensure it has no missing or cyclic dependencies - ''' + """ for name in G.nodes(): - if 'value' not in G.node[name] and 'template' not in G.node[name]: + if "value" not in G.node[name] and "template" not in G.node[name]: msg = 'Dependency unsatisfied in variable "%s"' % name raise ParamException(msg) @@ -172,51 +182,52 @@ def _validate(G): variable_names.append(variable_name) - variable_names = ', '.join(sorted(variable_names)) - msg = ('Cyclic dependency found in the following variables: %s. Likely the variable is ' - 'referencing itself' % (variable_names)) + variable_names = ", ".join(sorted(variable_names)) + msg = ( + "Cyclic dependency found in the following variables: %s. Likely the variable is " + "referencing itself" % (variable_names) + ) raise ParamException(msg) def _render(node, render_context): - ''' + """ Render the node depending on its type - ''' - if 'template' in node: + """ + if "template" in node: complex_type = False - if isinstance(node['template'], list) or isinstance(node['template'], dict): - node['template'] = json.dumps(node['template']) + if isinstance(node["template"], list) or isinstance(node["template"], dict): + node["template"] = json.dumps(node["template"]) # Finds occurrences of "{{variable}}" and adds `to_complex` filter # so types are honored. If it doesn't follow that syntax then it's # rendered as a string. - node['template'] = re.sub( - r'"{{([A-z0-9_-]+)}}"', r'{{\1 | to_complex}}', - node['template'] + node["template"] = re.sub( + r'"{{([A-z0-9_-]+)}}"', r"{{\1 | to_complex}}", node["template"] ) - LOG.debug('Rendering complex type: %s', node['template']) + LOG.debug("Rendering complex type: %s", node["template"]) complex_type = True - LOG.debug('Rendering node: %s with context: %s', node, render_context) + LOG.debug("Rendering node: %s with context: %s", node, render_context) - result = ENV.from_string(str(node['template'])).render(render_context) + result = ENV.from_string(str(node["template"])).render(render_context) - LOG.debug('Render complete: %s', result) + LOG.debug("Render complete: %s", result) if complex_type: result = json.loads(result) - LOG.debug('Complex Type Rendered: %s', result) + LOG.debug("Complex Type Rendered: %s", result) return result - if 'value' in node: - return node['value'] + if "value" in node: + return node["value"] def _resolve_dependencies(G): - ''' + """ Traverse the dependency graph starting from resolved nodes - ''' + """ context = {} for name in nx.topological_sort(G): node = G.node[name] @@ -224,7 +235,7 @@ def _resolve_dependencies(G): context[name] = _render(node, context) except Exception as e: - LOG.debug('Failed to render %s: %s', name, e, exc_info=True) + LOG.debug("Failed to render %s: %s", name, e, exc_info=True) msg = 'Failed to render parameter "%s": %s' % (name, six.text_type(e)) raise ParamException(msg) @@ -232,9 +243,9 @@ def _resolve_dependencies(G): def _cast_params_from(params, context, schemas): - ''' + """ Pick a list of parameters from context and cast each of them according to the schemas provided - ''' + """ result = {} # First, cast only explicitly provided live parameters @@ -258,17 +269,19 @@ def _cast_params_from(params, context, schemas): for param_name, param_details in schema.items(): # Skip if the parameter have immutable set to true in schema - if param_details.get('immutable'): + if param_details.get("immutable"): continue # Skip if the parameter doesn't have a default, or if the # value in the context is identical to the default - if 'default' not in param_details or \ - param_details.get('default') == context[param_name]: + if ( + "default" not in param_details + or param_details.get("default") == context[param_name] + ): continue # Skip if the default value isn't a Jinja expression - if not is_jinja_expression(param_details.get('default')): + if not is_jinja_expression(param_details.get("default")): continue # Skip if the parameter is being overridden @@ -280,22 +293,29 @@ def _cast_params_from(params, context, schemas): return result -def render_live_params(runner_parameters, action_parameters, params, action_context, - additional_contexts=None): - ''' +def render_live_params( + runner_parameters, + action_parameters, + params, + action_context, + additional_contexts=None, +): + """ Renders list of parameters. Ensures that there's no cyclic or missing dependencies. Returns a dict of plain rendered parameters. - ''' + """ additional_contexts = additional_contexts or {} - pack = action_context.get('pack') - user = action_context.get('user') + pack = action_context.get("pack") + user = action_context.get("user") try: config = get_config(pack, user) except Exception as e: - LOG.info('Failed to retrieve config for pack %s and user %s: %s' % (pack, user, - six.text_type(e))) + LOG.info( + "Failed to retrieve config for pack %s and user %s: %s" + % (pack, user, six.text_type(e)) + ) config = {} G = _create_graph(action_context, config) @@ -310,18 +330,20 @@ def render_live_params(runner_parameters, action_parameters, params, action_cont _validate(G) context = _resolve_dependencies(G) - live_params = _cast_params_from(params, context, [action_parameters, runner_parameters]) + live_params = _cast_params_from( + params, context, [action_parameters, runner_parameters] + ) return live_params def render_final_params(runner_parameters, action_parameters, params, action_context): - ''' + """ Renders missing parameters required for action to execute. Treats parameters from the dict as plain values instead of trying to render them again. Returns dicts for action and runner parameters. - ''' - config = get_config(action_context.get('pack'), action_context.get('user')) + """ + config = get_config(action_context.get("pack"), action_context.get("user")) G = _create_graph(action_context, config) @@ -331,18 +353,29 @@ def render_final_params(runner_parameters, action_parameters, params, action_con _validate(G) context = _resolve_dependencies(G) - context = _cast_params_from(context, context, [action_parameters, runner_parameters]) + context = _cast_params_from( + context, context, [action_parameters, runner_parameters] + ) return _split_params(runner_parameters, action_parameters, context) -def get_finalized_params(runnertype_parameter_info, action_parameter_info, liveaction_parameters, - action_context): - ''' +def get_finalized_params( + runnertype_parameter_info, + action_parameter_info, + liveaction_parameters, + action_context, +): + """ Left here to keep tests running. Later we would need to split tests so they start testing each function separately. - ''' - params = render_live_params(runnertype_parameter_info, action_parameter_info, - liveaction_parameters, action_context) - return render_final_params(runnertype_parameter_info, action_parameter_info, params, - action_context) + """ + params = render_live_params( + runnertype_parameter_info, + action_parameter_info, + liveaction_parameters, + action_context, + ) + return render_final_params( + runnertype_parameter_info, action_parameter_info, params, action_context + ) diff --git a/st2common/st2common/util/payload.py b/st2common/st2common/util/payload.py index 92b36d55c02..b2dc2a74afb 100644 --- a/st2common/st2common/util/payload.py +++ b/st2common/st2common/util/payload.py @@ -22,11 +22,8 @@ class PayloadLookup(object): - def __init__(self, payload, prefix=TRIGGER_PAYLOAD_PREFIX): - self.context = { - prefix: payload - } + self.context = {prefix: payload} for system_scope in SYSTEM_SCOPES: self.context[system_scope] = KeyValueLookup(scope=system_scope) diff --git a/st2common/st2common/util/queues.py b/st2common/st2common/util/queues.py index 526692155f1..9fce3b20a70 100644 --- a/st2common/st2common/util/queues.py +++ b/st2common/st2common/util/queues.py @@ -36,7 +36,7 @@ def get_queue_name(queue_name_base, queue_name_suffix, add_random_uuid_to_suffix :rtype: ``str`` """ if not queue_name_base: - raise ValueError('Queue name base cannot be empty.') + raise ValueError("Queue name base cannot be empty.") if not queue_name_suffix: return queue_name_base @@ -46,8 +46,8 @@ def get_queue_name(queue_name_base, queue_name_suffix, add_random_uuid_to_suffix # Pick last 10 digits of uuid. Arbitrary but unique enough. Long queue names # might cause issues in RabbitMQ. u_hex = uuid.uuid4().hex - uuid_suffix = uuid.uuid4().hex[len(u_hex) - 10:] - queue_suffix = '%s-%s' % (queue_name_suffix, uuid_suffix) + uuid_suffix = uuid.uuid4().hex[len(u_hex) - 10 :] + queue_suffix = "%s-%s" % (queue_name_suffix, uuid_suffix) - queue_name = '%s.%s' % (queue_name_base, queue_suffix) + queue_name = "%s.%s" % (queue_name_base, queue_suffix) return queue_name diff --git a/st2common/st2common/util/reference.py b/st2common/st2common/util/reference.py index 3262eb603f3..137a014d731 100644 --- a/st2common/st2common/util/reference.py +++ b/st2common/st2common/util/reference.py @@ -20,24 +20,25 @@ def get_ref_from_model(model): if model is None: - raise ValueError('Model has None value.') - model_id = getattr(model, 'id', None) + raise ValueError("Model has None value.") + model_id = getattr(model, "id", None) if model_id is None: - raise db.StackStormDBObjectMalformedError('model %s must contain id.' % str(model)) - reference = {'id': str(model_id), - 'name': getattr(model, 'name', None)} + raise db.StackStormDBObjectMalformedError( + "model %s must contain id." % str(model) + ) + reference = {"id": str(model_id), "name": getattr(model, "name", None)} return reference def get_model_from_ref(db_api, reference): if reference is None: - raise db.StackStormDBObjectNotFoundError('No reference supplied.') - model_id = reference.get('id', None) + raise db.StackStormDBObjectNotFoundError("No reference supplied.") + model_id = reference.get("id", None) if model_id is not None: return db_api.get_by_id(model_id) - model_name = reference.get('name', None) + model_name = reference.get("name", None) if model_name is None: - raise db.StackStormDBObjectNotFoundError('Both name and id are None.') + raise db.StackStormDBObjectNotFoundError("Both name and id are None.") return db_api.get_by_name(model_name) @@ -71,8 +72,10 @@ def get_resource_ref_from_model(model): name = model.name pack = model.pack except AttributeError: - raise Exception('Cannot build ResourceReference for model: %s. Name or pack missing.' - % model) + raise Exception( + "Cannot build ResourceReference for model: %s. Name or pack missing." + % model + ) return ResourceReference(name=name, pack=pack) diff --git a/st2common/st2common/util/sandboxing.py b/st2common/st2common/util/sandboxing.py index 02871f74721..9801f7d112e 100644 --- a/st2common/st2common/util/sandboxing.py +++ b/st2common/st2common/util/sandboxing.py @@ -31,11 +31,11 @@ from st2common.content.utils import get_pack_base_path __all__ = [ - 'get_sandbox_python_binary_path', - 'get_sandbox_python_path', - 'get_sandbox_python_path_for_python_action', - 'get_sandbox_path', - 'get_sandbox_virtualenv_path', + "get_sandbox_python_binary_path", + "get_sandbox_python_path", + "get_sandbox_python_path_for_python_action", + "get_sandbox_path", + "get_sandbox_virtualenv_path", ] @@ -47,13 +47,13 @@ def get_sandbox_python_binary_path(pack=None): :type pack: ``str`` """ system_base_path = cfg.CONF.system.base_path - virtualenv_path = os.path.join(system_base_path, 'virtualenvs', pack) + virtualenv_path = os.path.join(system_base_path, "virtualenvs", pack) if pack in SYSTEM_PACK_NAMES: # Use system python for "packs" and "core" actions python_path = sys.executable else: - python_path = os.path.join(virtualenv_path, 'bin/python') + python_path = os.path.join(virtualenv_path, "bin/python") return python_path @@ -70,19 +70,19 @@ def get_sandbox_path(virtualenv_path): """ sandbox_path = [] - parent_path = os.environ.get('PATH', '') + parent_path = os.environ.get("PATH", "") if not virtualenv_path: return parent_path - parent_path = parent_path.split(':') + parent_path = parent_path.split(":") parent_path = [path for path in parent_path if path] # Add virtualenv bin directory - virtualenv_bin_path = os.path.join(virtualenv_path, 'bin/') + virtualenv_bin_path = os.path.join(virtualenv_path, "bin/") sandbox_path.append(virtualenv_bin_path) sandbox_path.extend(parent_path) - sandbox_path = ':'.join(sandbox_path) + sandbox_path = ":".join(sandbox_path) return sandbox_path @@ -104,9 +104,9 @@ def get_sandbox_python_path(inherit_from_parent=True, inherit_parent_virtualenv= :type inherit_parent_virtualenv: ``str`` """ sandbox_python_path = [] - parent_python_path = os.environ.get('PYTHONPATH', '') + parent_python_path = os.environ.get("PYTHONPATH", "") - parent_python_path = parent_python_path.split(':') + parent_python_path = parent_python_path.split(":") parent_python_path = [path for path in parent_python_path if path] if inherit_from_parent: @@ -121,13 +121,14 @@ def get_sandbox_python_path(inherit_from_parent=True, inherit_parent_virtualenv= sandbox_python_path.append(site_packages_dir) - sandbox_python_path = ':'.join(sandbox_python_path) - sandbox_python_path = ':' + sandbox_python_path + sandbox_python_path = ":".join(sandbox_python_path) + sandbox_python_path = ":" + sandbox_python_path return sandbox_python_path -def get_sandbox_python_path_for_python_action(pack, inherit_from_parent=True, - inherit_parent_virtualenv=True): +def get_sandbox_python_path_for_python_action( + pack, inherit_from_parent=True, inherit_parent_virtualenv=True +): """ Return sandbox PYTHONPATH for a particular Python runner action. @@ -136,30 +137,36 @@ def get_sandbox_python_path_for_python_action(pack, inherit_from_parent=True, """ sandbox_python_path = get_sandbox_python_path( inherit_from_parent=inherit_from_parent, - inherit_parent_virtualenv=inherit_parent_virtualenv) + inherit_parent_virtualenv=inherit_parent_virtualenv, + ) pack_base_path = get_pack_base_path(pack_name=pack) virtualenv_path = get_sandbox_virtualenv_path(pack=pack) if virtualenv_path and os.path.isdir(virtualenv_path): - pack_virtualenv_lib_path = os.path.join(virtualenv_path, 'lib') + pack_virtualenv_lib_path = os.path.join(virtualenv_path, "lib") virtualenv_directories = os.listdir(pack_virtualenv_lib_path) - virtualenv_directories = [dir_name for dir_name in virtualenv_directories if - fnmatch.fnmatch(dir_name, 'python*')] + virtualenv_directories = [ + dir_name + for dir_name in virtualenv_directories + if fnmatch.fnmatch(dir_name, "python*") + ] # Add the pack's lib directory (lib/python3.x) in front of the PYTHONPATH. - pack_actions_lib_paths = os.path.join(pack_base_path, 'actions', 'lib') - pack_virtualenv_lib_path = os.path.join(virtualenv_path, 'lib') - pack_venv_lib_directory = os.path.join(pack_virtualenv_lib_path, virtualenv_directories[0]) + pack_actions_lib_paths = os.path.join(pack_base_path, "actions", "lib") + pack_virtualenv_lib_path = os.path.join(virtualenv_path, "lib") + pack_venv_lib_directory = os.path.join( + pack_virtualenv_lib_path, virtualenv_directories[0] + ) # Add the pack's site-packages directory (lib/python3.x/site-packages) # in front of the Python system site-packages This is important because # we want Python 3 compatible libraries to be used from the pack virtual # environment and not system ones. - pack_venv_site_packages_directory = os.path.join(pack_virtualenv_lib_path, - virtualenv_directories[0], - 'site-packages') + pack_venv_site_packages_directory = os.path.join( + pack_virtualenv_lib_path, virtualenv_directories[0], "site-packages" + ) full_sandbox_python_path = [ # NOTE: Order here is very important for imports to function correctly @@ -169,7 +176,7 @@ def get_sandbox_python_path_for_python_action(pack, inherit_from_parent=True, sandbox_python_path, ] - sandbox_python_path = ':'.join(full_sandbox_python_path) + sandbox_python_path = ":".join(full_sandbox_python_path) return sandbox_python_path @@ -183,7 +190,7 @@ def get_sandbox_virtualenv_path(pack): virtualenv_path = None else: system_base_path = cfg.CONF.system.base_path - virtualenv_path = os.path.join(system_base_path, 'virtualenvs', pack) + virtualenv_path = os.path.join(system_base_path, "virtualenvs", pack) return virtualenv_path @@ -195,8 +202,9 @@ def is_in_virtualenv(): """ # sys.real_prefix is for virtualenv # sys.base_prefix != sys.prefix is for venv - return (hasattr(sys, 'real_prefix') or - (hasattr(sys, 'base_prefix') and sys.base_prefix != sys.prefix)) + return hasattr(sys, "real_prefix") or ( + hasattr(sys, "base_prefix") and sys.base_prefix != sys.prefix + ) def get_virtualenv_prefix(): @@ -205,10 +213,10 @@ def get_virtualenv_prefix(): where we retrieved the virtualenv prefix from. The second element is the virtualenv prefix. """ - if hasattr(sys, 'real_prefix'): - return ('sys.real_prefix', sys.real_prefix) - elif hasattr(sys, 'base_prefix'): - return ('sys.base_prefix', sys.base_prefix) + if hasattr(sys, "real_prefix"): + return ("sys.real_prefix", sys.real_prefix) + elif hasattr(sys, "base_prefix"): + return ("sys.base_prefix", sys.base_prefix) return (None, None) @@ -216,9 +224,9 @@ def set_virtualenv_prefix(prefix_tuple): """ :return: Sets the virtualenv prefix given a tuple returned from get_virtualenv_prefix() """ - if prefix_tuple[0] == 'sys.real_prefix' and hasattr(sys, 'real_prefix'): + if prefix_tuple[0] == "sys.real_prefix" and hasattr(sys, "real_prefix"): sys.real_prefix = prefix_tuple[1] - elif prefix_tuple[0] == 'sys.base_prefix' and hasattr(sys, 'base_prefix'): + elif prefix_tuple[0] == "sys.base_prefix" and hasattr(sys, "base_prefix"): sys.base_prefix = prefix_tuple[1] @@ -226,7 +234,7 @@ def clear_virtualenv_prefix(): """ :return: Unsets / removes / resets the virtualenv prefix """ - if hasattr(sys, 'real_prefix'): + if hasattr(sys, "real_prefix"): del sys.real_prefix - if hasattr(sys, 'base_prefix'): + if hasattr(sys, "base_prefix"): sys.base_prefix = sys.prefix diff --git a/st2common/st2common/util/schema/__init__.py b/st2common/st2common/util/schema/__init__.py index a49f733a5a7..8e18509cd10 100644 --- a/st2common/st2common/util/schema/__init__.py +++ b/st2common/st2common/util/schema/__init__.py @@ -27,19 +27,19 @@ from st2common.util.misc import deep_update __all__ = [ - 'get_validator', - 'get_draft_schema', - 'get_action_parameters_schema', - 'get_schema_for_action_parameters', - 'get_schema_for_resource_parameters', - 'is_property_type_single', - 'is_property_type_list', - 'is_property_type_anyof', - 'is_property_type_oneof', - 'is_property_nullable', - 'is_attribute_type_array', - 'is_attribute_type_object', - 'validate' + "get_validator", + "get_draft_schema", + "get_action_parameters_schema", + "get_schema_for_action_parameters", + "get_schema_for_resource_parameters", + "is_property_type_single", + "is_property_type_list", + "is_property_type_anyof", + "is_property_type_oneof", + "is_property_nullable", + "is_attribute_type_array", + "is_attribute_type_object", + "validate", ] # https://github.com/json-schema/json-schema/blob/master/draft-04/schema @@ -49,12 +49,13 @@ # and draft 3 version of required. PATH = os.path.join(os.path.dirname(os.path.realpath(__file__))) SCHEMAS = { - 'draft4': jsonify.load_file(os.path.join(PATH, 'draft4.json')), - 'custom': jsonify.load_file(os.path.join(PATH, 'custom.json')), - + "draft4": jsonify.load_file(os.path.join(PATH, "draft4.json")), + "custom": jsonify.load_file(os.path.join(PATH, "custom.json")), # Custom schema for action params which doesn't allow parameter "type" attribute to be array - 'action_params': jsonify.load_file(os.path.join(PATH, 'action_params.json')), - 'action_output_schema': jsonify.load_file(os.path.join(PATH, 'action_output_schema.json')) + "action_params": jsonify.load_file(os.path.join(PATH, "action_params.json")), + "action_output_schema": jsonify.load_file( + os.path.join(PATH, "action_output_schema.json") + ), } SCHEMA_ANY_TYPE = { @@ -64,23 +65,23 @@ {"type": "integer"}, {"type": "number"}, {"type": "object"}, - {"type": "string"} + {"type": "string"}, ] } RUNNER_PARAM_OVERRIDABLE_ATTRS = [ - 'default', - 'description', - 'enum', - 'immutable', - 'required' + "default", + "description", + "enum", + "immutable", + "required", ] -def get_draft_schema(version='custom', additional_properties=False): +def get_draft_schema(version="custom", additional_properties=False): schema = copy.deepcopy(SCHEMAS[version]) - if additional_properties and 'additionalProperties' in schema: - del schema['additionalProperties'] + if additional_properties and "additionalProperties" in schema: + del schema["additionalProperties"] return schema @@ -89,8 +90,7 @@ def get_action_output_schema(additional_properties=True): Return a generic schema which is used for validating action output. """ return get_draft_schema( - version='action_output_schema', - additional_properties=additional_properties + version="action_output_schema", additional_properties=additional_properties ) @@ -98,81 +98,100 @@ def get_action_parameters_schema(additional_properties=False): """ Return a generic schema which is used for validating action parameters definition. """ - return get_draft_schema(version='action_params', additional_properties=additional_properties) + return get_draft_schema( + version="action_params", additional_properties=additional_properties + ) CustomValidator = create( - meta_schema=get_draft_schema(version='custom', additional_properties=True), + meta_schema=get_draft_schema(version="custom", additional_properties=True), validators={ - u"$ref": _validators.ref, - u"additionalItems": _validators.additionalItems, - u"additionalProperties": _validators.additionalProperties, - u"allOf": _validators.allOf_draft4, - u"anyOf": _validators.anyOf_draft4, - u"dependencies": _validators.dependencies, - u"enum": _validators.enum, - u"format": _validators.format, - u"items": _validators.items, - u"maxItems": _validators.maxItems, - u"maxLength": _validators.maxLength, - u"maxProperties": _validators.maxProperties_draft4, - u"maximum": _validators.maximum, - u"minItems": _validators.minItems, - u"minLength": _validators.minLength, - u"minProperties": _validators.minProperties_draft4, - u"minimum": _validators.minimum, - u"multipleOf": _validators.multipleOf, - u"not": _validators.not_draft4, - u"oneOf": _validators.oneOf_draft4, - u"pattern": _validators.pattern, - u"patternProperties": _validators.patternProperties, - u"properties": _validators.properties_draft3, - u"type": _validators.type_draft4, - u"uniqueItems": _validators.uniqueItems, + "$ref": _validators.ref, + "additionalItems": _validators.additionalItems, + "additionalProperties": _validators.additionalProperties, + "allOf": _validators.allOf_draft4, + "anyOf": _validators.anyOf_draft4, + "dependencies": _validators.dependencies, + "enum": _validators.enum, + "format": _validators.format, + "items": _validators.items, + "maxItems": _validators.maxItems, + "maxLength": _validators.maxLength, + "maxProperties": _validators.maxProperties_draft4, + "maximum": _validators.maximum, + "minItems": _validators.minItems, + "minLength": _validators.minLength, + "minProperties": _validators.minProperties_draft4, + "minimum": _validators.minimum, + "multipleOf": _validators.multipleOf, + "not": _validators.not_draft4, + "oneOf": _validators.oneOf_draft4, + "pattern": _validators.pattern, + "patternProperties": _validators.patternProperties, + "properties": _validators.properties_draft3, + "type": _validators.type_draft4, + "uniqueItems": _validators.uniqueItems, }, version="custom_validator", ) def is_property_type_single(property_schema): - return (isinstance(property_schema, dict) and - 'anyOf' not in list(property_schema.keys()) and - 'oneOf' not in list(property_schema.keys()) and - not isinstance(property_schema.get('type', 'string'), list)) + return ( + isinstance(property_schema, dict) + and "anyOf" not in list(property_schema.keys()) + and "oneOf" not in list(property_schema.keys()) + and not isinstance(property_schema.get("type", "string"), list) + ) def is_property_type_list(property_schema): - return (isinstance(property_schema, dict) and - isinstance(property_schema.get('type', 'string'), list)) + return isinstance(property_schema, dict) and isinstance( + property_schema.get("type", "string"), list + ) def is_property_type_anyof(property_schema): - return isinstance(property_schema, dict) and 'anyOf' in list(property_schema.keys()) + return isinstance(property_schema, dict) and "anyOf" in list(property_schema.keys()) def is_property_type_oneof(property_schema): - return isinstance(property_schema, dict) and 'oneOf' in list(property_schema.keys()) + return isinstance(property_schema, dict) and "oneOf" in list(property_schema.keys()) def is_property_nullable(property_type_schema): # For anyOf and oneOf, the property_schema is a list of types. if isinstance(property_type_schema, list): - return len([t for t in property_type_schema - if ((isinstance(t, six.string_types) and t == 'null') or - (isinstance(t, dict) and t.get('type', 'string') == 'null'))]) > 0 - - return (isinstance(property_type_schema, dict) and - property_type_schema.get('type', 'string') == 'null') + return ( + len( + [ + t + for t in property_type_schema + if ( + (isinstance(t, six.string_types) and t == "null") + or (isinstance(t, dict) and t.get("type", "string") == "null") + ) + ] + ) + > 0 + ) + + return ( + isinstance(property_type_schema, dict) + and property_type_schema.get("type", "string") == "null" + ) def is_attribute_type_array(attribute_type): - return (attribute_type == 'array' or - (isinstance(attribute_type, list) and 'array' in attribute_type)) + return attribute_type == "array" or ( + isinstance(attribute_type, list) and "array" in attribute_type + ) def is_attribute_type_object(attribute_type): - return (attribute_type == 'object' or - (isinstance(attribute_type, list) and 'object' in attribute_type)) + return attribute_type == "object" or ( + isinstance(attribute_type, list) and "object" in attribute_type + ) def assign_default_values(instance, schema): @@ -186,11 +205,11 @@ def assign_default_values(instance, schema): if not instance_is_dict and not instance_is_array: return instance - properties = schema.get('properties', {}) + properties = schema.get("properties", {}) for property_name, property_data in six.iteritems(properties): - has_default_value = 'default' in property_data - default_value = property_data.get('default', None) + has_default_value = "default" in property_data + default_value = property_data.get("default", None) # Assign default value on the instance so the validation doesn't fail if requires is true # but the value is not provided @@ -203,29 +222,36 @@ def assign_default_values(instance, schema): instance[index][property_name] = default_value # Support for nested properties (array and object) - attribute_type = property_data.get('type', None) - schema_items = property_data.get('items', {}) + attribute_type = property_data.get("type", None) + schema_items = property_data.get("items", {}) # Array - if (is_attribute_type_array(attribute_type) and - schema_items and schema_items.get('properties', {})): + if ( + is_attribute_type_array(attribute_type) + and schema_items + and schema_items.get("properties", {}) + ): array_instance = instance.get(property_name, None) - array_schema = schema['properties'][property_name]['items'] + array_schema = schema["properties"][property_name]["items"] if array_instance is not None: # Note: We don't perform subschema assignment if no value is provided - instance[property_name] = assign_default_values(instance=array_instance, - schema=array_schema) + instance[property_name] = assign_default_values( + instance=array_instance, schema=array_schema + ) # Object - if is_attribute_type_object(attribute_type) and property_data.get('properties', {}): + if is_attribute_type_object(attribute_type) and property_data.get( + "properties", {} + ): object_instance = instance.get(property_name, None) - object_schema = schema['properties'][property_name] + object_schema = schema["properties"][property_name] if object_instance is not None: # Note: We don't perform subschema assignment if no value is provided - instance[property_name] = assign_default_values(instance=object_instance, - schema=object_schema) + instance[property_name] = assign_default_values( + instance=object_instance, schema=object_schema + ) return instance @@ -236,51 +262,70 @@ def modify_schema_allow_default_none(schema): defines a default value of None. """ schema = copy.deepcopy(schema) - properties = schema.get('properties', {}) + properties = schema.get("properties", {}) for property_name, property_data in six.iteritems(properties): - is_optional = not property_data.get('required', False) - has_default_value = 'default' in property_data - default_value = property_data.get('default', None) - property_schema = schema['properties'][property_name] + is_optional = not property_data.get("required", False) + has_default_value = "default" in property_data + default_value = property_data.get("default", None) + property_schema = schema["properties"][property_name] if (has_default_value or is_optional) and default_value is None: # If property is anyOf and oneOf then it has to be process differently. - if (is_property_type_anyof(property_schema) and - not is_property_nullable(property_schema['anyOf'])): - property_schema['anyOf'].append({'type': 'null'}) - elif (is_property_type_oneof(property_schema) and - not is_property_nullable(property_schema['oneOf'])): - property_schema['oneOf'].append({'type': 'null'}) - elif (is_property_type_list(property_schema) and - not is_property_nullable(property_schema.get('type'))): - property_schema['type'].append('null') - elif (is_property_type_single(property_schema) and - not is_property_nullable(property_schema.get('type'))): - property_schema['type'] = [property_schema.get('type', 'string'), 'null'] + if is_property_type_anyof(property_schema) and not is_property_nullable( + property_schema["anyOf"] + ): + property_schema["anyOf"].append({"type": "null"}) + elif is_property_type_oneof(property_schema) and not is_property_nullable( + property_schema["oneOf"] + ): + property_schema["oneOf"].append({"type": "null"}) + elif is_property_type_list(property_schema) and not is_property_nullable( + property_schema.get("type") + ): + property_schema["type"].append("null") + elif is_property_type_single(property_schema) and not is_property_nullable( + property_schema.get("type") + ): + property_schema["type"] = [ + property_schema.get("type", "string"), + "null", + ] # Support for nested properties (array and object) - attribute_type = property_data.get('type', None) - schema_items = property_data.get('items', {}) + attribute_type = property_data.get("type", None) + schema_items = property_data.get("items", {}) # Array - if (is_attribute_type_array(attribute_type) and - schema_items and schema_items.get('properties', {})): + if ( + is_attribute_type_array(attribute_type) + and schema_items + and schema_items.get("properties", {}) + ): array_schema = schema_items array_schema = modify_schema_allow_default_none(schema=array_schema) - schema['properties'][property_name]['items'] = array_schema + schema["properties"][property_name]["items"] = array_schema # Object - if is_attribute_type_object(attribute_type) and property_data.get('properties', {}): + if is_attribute_type_object(attribute_type) and property_data.get( + "properties", {} + ): object_schema = property_data object_schema = modify_schema_allow_default_none(schema=object_schema) - schema['properties'][property_name] = object_schema + schema["properties"][property_name] = object_schema return schema -def validate(instance, schema, cls=None, use_default=True, allow_default_none=False, *args, - **kwargs): +def validate( + instance, + schema, + cls=None, + use_default=True, + allow_default_none=False, + *args, + **kwargs, +): """ Custom validate function which supports default arguments combined with the "required" property. @@ -292,13 +337,13 @@ def validate(instance, schema, cls=None, use_default=True, allow_default_none=Fa """ instance = copy.deepcopy(instance) - schema_type = schema.get('type', None) + schema_type = schema.get("type", None) instance_is_dict = isinstance(instance, dict) if use_default and allow_default_none: schema = modify_schema_allow_default_none(schema=schema) - if use_default and schema_type == 'object' and instance_is_dict: + if use_default and schema_type == "object" and instance_is_dict: instance = assign_default_values(instance=instance, schema=schema) # pylint: disable=assignment-from-no-return @@ -307,28 +352,30 @@ def validate(instance, schema, cls=None, use_default=True, allow_default_none=Fa return instance -VALIDATORS = { - 'draft4': jsonschema.Draft4Validator, - 'custom': CustomValidator -} +VALIDATORS = {"draft4": jsonschema.Draft4Validator, "custom": CustomValidator} -def get_validator(version='custom'): +def get_validator(version="custom"): validator = VALIDATORS[version] return validator -def validate_runner_parameter_attribute_override(action_ref, param_name, attr_name, - runner_param_attr_value, action_param_attr_value): +def validate_runner_parameter_attribute_override( + action_ref, param_name, attr_name, runner_param_attr_value, action_param_attr_value +): """ Validate that the provided parameter from the action schema can override the runner parameter. """ param_values_are_the_same = action_param_attr_value == runner_param_attr_value - if (attr_name not in RUNNER_PARAM_OVERRIDABLE_ATTRS and not param_values_are_the_same): + if ( + attr_name not in RUNNER_PARAM_OVERRIDABLE_ATTRS + and not param_values_are_the_same + ): raise InvalidActionParameterException( 'The attribute "%s" for the runner parameter "%s" in action "%s" ' - 'cannot be overridden.' % (attr_name, param_name, action_ref)) + "cannot be overridden." % (attr_name, param_name, action_ref) + ) return True @@ -341,7 +388,8 @@ def get_schema_for_action_parameters(action_db, runnertype_db=None): """ if not runnertype_db: from st2common.util.action_db import get_runnertype_by_name - runnertype_db = get_runnertype_by_name(action_db.runner_type['name']) + + runnertype_db = get_runnertype_by_name(action_db.runner_type["name"]) # Note: We need to perform a deep merge because user can only specify a single parameter # attribute when overriding it in an action metadata. @@ -359,26 +407,31 @@ def get_schema_for_action_parameters(action_db, runnertype_db=None): for attribute, value in six.iteritems(schema): runner_param_value = runnertype_db.runner_parameters[name].get(attribute) - validate_runner_parameter_attribute_override(action_ref=action_db.ref, - param_name=name, - attr_name=attribute, - runner_param_attr_value=runner_param_value, - action_param_attr_value=value) + validate_runner_parameter_attribute_override( + action_ref=action_db.ref, + param_name=name, + attr_name=attribute, + runner_param_attr_value=runner_param_value, + action_param_attr_value=value, + ) schema = get_schema_for_resource_parameters(parameters_schema=parameters_schema) if parameters_schema: - schema['title'] = action_db.name + schema["title"] = action_db.name if action_db.description: - schema['description'] = action_db.description + schema["description"] = action_db.description return schema -def get_schema_for_resource_parameters(parameters_schema, allow_additional_properties=False): +def get_schema_for_resource_parameters( + parameters_schema, allow_additional_properties=False +): """ Dynamically construct JSON schema for the provided resource from the parameters metadata. """ + def normalize(x): return {k: v if v else SCHEMA_ANY_TYPE for k, v in six.iteritems(x)} @@ -386,8 +439,8 @@ def normalize(x): properties = {} properties.update(normalize(parameters_schema)) if properties: - schema['type'] = 'object' - schema['properties'] = properties - schema['additionalProperties'] = allow_additional_properties + schema["type"] = "object" + schema["properties"] = properties + schema["additionalProperties"] = allow_additional_properties return schema diff --git a/st2common/st2common/util/secrets.py b/st2common/st2common/util/secrets.py index 2945ef0594e..b863a93a61d 100644 --- a/st2common/st2common/util/secrets.py +++ b/st2common/st2common/util/secrets.py @@ -65,7 +65,7 @@ def get_secret_parameters(parameters): """ secret_parameters = {} - parameters_type = parameters.get('type') + parameters_type = parameters.get("type") # If the parameter itself is secret, then skip all processing below it # and return the type of this parameter. # @@ -74,22 +74,22 @@ def get_secret_parameters(parameters): # **Important** that we do this check first, so in case this parameter # is an `object` or `array`, and the user wants the full thing # to be secret, that it is marked as secret. - if parameters.get('secret', False): + if parameters.get("secret", False): return parameters_type iterator = None - if parameters_type == 'object': + if parameters_type == "object": # if this is an object, then iterate over the properties within # the object # result = dict - iterator = six.iteritems(parameters.get('properties', {})) - elif parameters_type == 'array': + iterator = six.iteritems(parameters.get("properties", {})) + elif parameters_type == "array": # if this is an array, then iterate over the items definition as a single # property # result = list - iterator = enumerate([parameters.get('items', {})]) + iterator = enumerate([parameters.get("items", {})]) secret_parameters = [] - elif parameters_type in ['integer', 'number', 'boolean', 'null', 'string']: + elif parameters_type in ["integer", "number", "boolean", "null", "string"]: # if this a "plain old datatype", then iterate over the properties set # of the data type # result = string (property type) @@ -105,8 +105,8 @@ def get_secret_parameters(parameters): if not isinstance(options, dict): continue - parameter_type = options.get('type') - if options.get('secret', False): + parameter_type = options.get("type") + if options.get("secret", False): # If this parameter is secret, then add it our secret parameters # # **This causes the _full_ object / array tree to be secret @@ -121,7 +121,7 @@ def get_secret_parameters(parameters): secret_parameters[parameter] = parameter_type else: return parameter_type - elif parameter_type in ['object', 'array']: + elif parameter_type in ["object", "array"]: # otherwise recursively dive into the `object`/`array` and # find individual parameters marked as secret sub_params = get_secret_parameters(options) @@ -176,15 +176,17 @@ def mask_secret_parameters(parameters, secret_parameters, result=None): for secret_param, secret_sub_params in iterator: if is_dict: if secret_param in result: - result[secret_param] = mask_secret_parameters(parameters[secret_param], - secret_sub_params, - result=result[secret_param]) + result[secret_param] = mask_secret_parameters( + parameters[secret_param], + secret_sub_params, + result=result[secret_param], + ) elif is_list: # we're assuming lists contain the same data type for every element for idx, value in enumerate(result): - result[idx] = mask_secret_parameters(parameters[idx], - secret_sub_params, - result=result[idx]) + result[idx] = mask_secret_parameters( + parameters[idx], secret_sub_params, result=result[idx] + ) else: result[secret_param] = MASKED_ATTRIBUTE_VALUE @@ -204,8 +206,8 @@ def mask_inquiry_response(response, schema): """ result = fast_deepcopy(response) - for prop_name, prop_attrs in schema['properties'].items(): - if prop_attrs.get('secret') is True: + for prop_name, prop_attrs in schema["properties"].items(): + if prop_attrs.get("secret") is True: if prop_name in response: result[prop_name] = MASKED_ATTRIBUTE_VALUE diff --git a/st2common/st2common/util/service.py b/st2common/st2common/util/service.py index 6691e502682..e3c2dcb9f9e 100644 --- a/st2common/st2common/util/service.py +++ b/st2common/st2common/util/service.py @@ -24,13 +24,13 @@ def retry_on_exceptions(exc): - LOG.warning('Evaluating retry on exception %s. %s', type(exc), str(exc)) + LOG.warning("Evaluating retry on exception %s. %s", type(exc), str(exc)) is_mongo_connection_error = isinstance(exc, pymongo.errors.ConnectionFailure) retrying = is_mongo_connection_error if retrying: - LOG.warning('Retrying on exception %s.', type(exc)) + LOG.warning("Retrying on exception %s.", type(exc)) return retrying diff --git a/st2common/st2common/util/shell.py b/st2common/st2common/util/shell.py index 5c4217594a3..945ec39a5a0 100644 --- a/st2common/st2common/util/shell.py +++ b/st2common/st2common/util/shell.py @@ -30,13 +30,7 @@ # subprocess functionality and run_command subprocess = concurrency.get_subprocess_module() -__all__ = [ - 'run_command', - 'kill_process', - - 'quote_unix', - 'quote_windows' -] +__all__ = ["run_command", "kill_process", "quote_unix", "quote_windows"] LOG = logging.getLogger(__name__) @@ -45,8 +39,15 @@ # pylint: disable=too-many-function-args -def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False, - cwd=None, env=None): +def run_command( + cmd, + stdin=None, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + cwd=None, + env=None, +): """ Run the provided command in a subprocess and wait until it completes. @@ -79,8 +80,15 @@ def run_command(cmd, stdin=None, stdout=subprocess.PIPE, stderr=subprocess.PIPE, if not env: env = os.environ.copy() - process = concurrency.subprocess_popen(args=cmd, stdin=stdin, stdout=stdout, stderr=stderr, - env=env, cwd=cwd, shell=shell) + process = concurrency.subprocess_popen( + args=cmd, + stdin=stdin, + stdout=stdout, + stderr=stderr, + env=env, + cwd=cwd, + shell=shell, + ) stdout, stderr = process.communicate() exit_code = process.returncode @@ -100,15 +108,17 @@ def kill_process(process): :param process: Process object as returned by subprocess.Popen. :type process: ``object`` """ - kill_command = shlex.split('sudo pkill -TERM -s %s' % (process.pid)) + kill_command = shlex.split("sudo pkill -TERM -s %s" % (process.pid)) try: if six.PY3: - status = subprocess.call(kill_command, timeout=100) # pylint: disable=not-callable + status = subprocess.call( + kill_command, timeout=100 + ) # pylint: disable=not-callable else: status = subprocess.call(kill_command) # pylint: disable=not-callable except Exception: - LOG.exception('Unable to pkill process.') + LOG.exception("Unable to pkill process.") return status @@ -151,11 +161,12 @@ def on_parent_exit(signame): Based on https://gist.github.com/evansd/2346614 """ + def noop(): pass try: - libc = cdll['libc.so.6'] + libc = cdll["libc.so.6"] except OSError: # libc, can't be found (e.g. running on non-Unix system), we cant ensure signal will be # triggered @@ -173,5 +184,6 @@ def set_parent_exit_signal(): # http://linux.die.net/man/2/prctl result = prctl(PR_SET_PDEATHSIG, signum) if result != 0: - raise Exception('prctl failed with error code %s' % result) + raise Exception("prctl failed with error code %s" % result) + return set_parent_exit_signal diff --git a/st2common/st2common/util/spec_loader.py b/st2common/st2common/util/spec_loader.py index 8ab926330fa..07889fa2d22 100644 --- a/st2common/st2common/util/spec_loader.py +++ b/st2common/st2common/util/spec_loader.py @@ -33,16 +33,13 @@ from st2common.rbac.types import PermissionType from st2common.util import isotime -__all__ = [ - 'load_spec', - 'generate_spec' -] +__all__ = ["load_spec", "generate_spec"] ARGUMENTS = { - 'DEFAULT_PACK_NAME': st2common.constants.pack.DEFAULT_PACK_NAME, - 'LIVEACTION_STATUSES': st2common.constants.action.LIVEACTION_STATUSES, - 'PERMISSION_TYPE': PermissionType, - 'ISO8601_UTC_REGEX': isotime.ISO8601_UTC_REGEX + "DEFAULT_PACK_NAME": st2common.constants.pack.DEFAULT_PACK_NAME, + "LIVEACTION_STATUSES": st2common.constants.action.LIVEACTION_STATUSES, + "PERMISSION_TYPE": PermissionType, + "ISO8601_UTC_REGEX": isotime.ISO8601_UTC_REGEX, } @@ -50,23 +47,35 @@ class UniqueKeyLoader(Loader): """ YAML loader which throws on a duplicate key. """ + def construct_mapping(self, node, deep=False): if not isinstance(node, MappingNode): - raise ConstructorError(None, None, - "expected a mapping node, but found %s" % node.id, - node.start_mark) + raise ConstructorError( + None, + None, + "expected a mapping node, but found %s" % node.id, + node.start_mark, + ) mapping = {} for key_node, value_node in node.value: key = self.construct_object(key_node, deep=deep) try: hash(key) except TypeError as exc: - raise ConstructorError("while constructing a mapping", node.start_mark, - "found unacceptable key (%s)" % exc, key_node.start_mark) + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + "found unacceptable key (%s)" % exc, + key_node.start_mark, + ) # check for duplicate keys if key in mapping: - raise ConstructorError("while constructing a mapping", node.start_mark, - "found duplicate key", key_node.start_mark) + raise ConstructorError( + "while constructing a mapping", + node.start_mark, + "found duplicate key", + key_node.start_mark, + ) value = self.construct_object(value_node, deep=deep) mapping[key] = value return mapping diff --git a/st2common/st2common/util/system_info.py b/st2common/st2common/util/system_info.py index a83bf5169fd..b81d205907f 100644 --- a/st2common/st2common/util/system_info.py +++ b/st2common/st2common/util/system_info.py @@ -17,22 +17,14 @@ import os import socket -__all__ = [ - 'get_host_info', - 'get_process_info' -] +__all__ = ["get_host_info", "get_process_info"] def get_host_info(): - host_info = { - 'hostname': socket.gethostname() - } + host_info = {"hostname": socket.gethostname()} return host_info def get_process_info(): - process_info = { - 'hostname': socket.gethostname(), - 'pid': os.getpid() - } + process_info = {"hostname": socket.gethostname(), "pid": os.getpid()} return process_info diff --git a/st2common/st2common/util/templating.py b/st2common/st2common/util/templating.py index 9dc25d917ca..82e8e1c2461 100644 --- a/st2common/st2common/util/templating.py +++ b/st2common/st2common/util/templating.py @@ -24,9 +24,9 @@ from st2common.services.keyvalues import UserKeyValueLookup __all__ = [ - 'render_template', - 'render_template_with_system_context', - 'render_template_with_system_and_user_context' + "render_template", + "render_template_with_system_context", + "render_template_with_system_and_user_context", ] @@ -74,7 +74,9 @@ def render_template_with_system_context(value, context=None, prefix=None): return rendered -def render_template_with_system_and_user_context(value, user, context=None, prefix=None): +def render_template_with_system_and_user_context( + value, user, context=None, prefix=None +): """ Render provided template with a default system context and user context for the provided user. @@ -95,7 +97,7 @@ def render_template_with_system_and_user_context(value, user, context=None, pref context = context or {} context[DATASTORE_PARENT_SCOPE] = { SYSTEM_SCOPE: KeyValueLookup(prefix=prefix, scope=FULL_SYSTEM_SCOPE), - USER_SCOPE: UserKeyValueLookup(prefix=prefix, user=user, scope=FULL_USER_SCOPE) + USER_SCOPE: UserKeyValueLookup(prefix=prefix, user=user, scope=FULL_USER_SCOPE), } rendered = render_template(value=value, context=context) diff --git a/st2common/st2common/util/types.py b/st2common/st2common/util/types.py index 5c25990a6ed..ad70f078b9b 100644 --- a/st2common/st2common/util/types.py +++ b/st2common/st2common/util/types.py @@ -20,17 +20,14 @@ from __future__ import absolute_import import collections -__all__ = [ - 'OrderedSet' -] +__all__ = ["OrderedSet"] class OrderedSet(collections.MutableSet): - def __init__(self, iterable=None): self.end = end = [] - end += [None, end, end] # sentinel node for doubly linked list - self.map = {} # key --> [key, prev, next] + end += [None, end, end] # sentinel node for doubly linked list + self.map = {} # key --> [key, prev, next] if iterable is not None: self |= iterable @@ -68,15 +65,15 @@ def __reversed__(self): def pop(self, last=True): if not self: - raise KeyError('set is empty') + raise KeyError("set is empty") key = self.end[1][0] if last else self.end[2][0] self.discard(key) return key def __repr__(self): if not self: - return '%s()' % (self.__class__.__name__,) - return '%s(%r)' % (self.__class__.__name__, list(self)) + return "%s()" % (self.__class__.__name__,) + return "%s(%r)" % (self.__class__.__name__, list(self)) def __eq__(self, other): if isinstance(other, OrderedSet): diff --git a/st2common/st2common/util/uid.py b/st2common/st2common/util/uid.py index 07d04d7511e..289184d59e9 100644 --- a/st2common/st2common/util/uid.py +++ b/st2common/st2common/util/uid.py @@ -20,9 +20,7 @@ from __future__ import absolute_import from st2common.models.db.stormbase import UIDFieldMixin -__all__ = [ - 'parse_uid' -] +__all__ = ["parse_uid"] def parse_uid(uid): @@ -33,12 +31,12 @@ def parse_uid(uid): :rtype: ``tuple`` """ if UIDFieldMixin.UID_SEPARATOR not in uid: - raise ValueError('Invalid uid: %s' % (uid)) + raise ValueError("Invalid uid: %s" % (uid)) parsed = uid.split(UIDFieldMixin.UID_SEPARATOR) if len(parsed) < 2: - raise ValueError('Invalid or malformed uid: %s' % (uid)) + raise ValueError("Invalid or malformed uid: %s" % (uid)) resource_type = parsed[0] uid_remainder = parsed[1:] diff --git a/st2common/st2common/util/ujson.py b/st2common/st2common/util/ujson.py index cace2434486..6c533fb30ac 100644 --- a/st2common/st2common/util/ujson.py +++ b/st2common/st2common/util/ujson.py @@ -19,9 +19,7 @@ import ujson -__all__ = [ - 'fast_deepcopy' -] +__all__ = ["fast_deepcopy"] def fast_deepcopy(value, fall_back_to_deepcopy=True): diff --git a/st2common/st2common/util/url.py b/st2common/st2common/util/url.py index 9c3196f8355..b4dd8fc1370 100644 --- a/st2common/st2common/util/url.py +++ b/st2common/st2common/util/url.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = [ - 'get_url_without_trailing_slash' -] +__all__ = ["get_url_without_trailing_slash"] def get_url_without_trailing_slash(value): @@ -27,5 +25,5 @@ def get_url_without_trailing_slash(value): :rtype: ``str`` """ - result = value[:-1] if value.endswith('/') else value + result = value[:-1] if value.endswith("/") else value return result diff --git a/st2common/st2common/util/versioning.py b/st2common/st2common/util/versioning.py index 121a93312ab..89da24f1744 100644 --- a/st2common/st2common/util/versioning.py +++ b/st2common/st2common/util/versioning.py @@ -25,12 +25,7 @@ from st2common import __version__ as stackstorm_version -__all__ = [ - 'get_stackstorm_version', - 'get_python_version', - - 'complex_semver_match' -] +__all__ = ["get_stackstorm_version", "get_python_version", "complex_semver_match"] def get_stackstorm_version(): @@ -38,8 +33,8 @@ def get_stackstorm_version(): Return a valid semver version string for the currently running StackStorm version. """ # Special handling for dev versions which are not valid semver identifiers - if 'dev' in stackstorm_version and stackstorm_version.count('.') == 1: - version = stackstorm_version.replace('dev', '.0') + if "dev" in stackstorm_version and stackstorm_version.count(".") == 1: + version = stackstorm_version.replace("dev", ".0") return version return stackstorm_version @@ -50,7 +45,7 @@ def get_python_version(): Return Python version used by this installation. """ version_info = sys.version_info - return '%s.%s.%s' % (version_info.major, version_info.minor, version_info.micro) + return "%s.%s.%s" % (version_info.major, version_info.minor, version_info.micro) def complex_semver_match(version, version_specifier): @@ -63,10 +58,10 @@ def complex_semver_match(version, version_specifier): :rtype: ``bool`` """ - if version_specifier == 'all': + if version_specifier == "all": return True - split_version_specifier = version_specifier.split(',') + split_version_specifier = version_specifier.split(",") if len(split_version_specifier) == 1: # No comma, we can do a simple comparision diff --git a/st2common/st2common/util/virtualenvs.py b/st2common/st2common/util/virtualenvs.py index db56e6fb20a..7f408c9da32 100644 --- a/st2common/st2common/util/virtualenvs.py +++ b/st2common/st2common/util/virtualenvs.py @@ -36,16 +36,22 @@ from st2common.content.utils import get_packs_base_paths from st2common.content.utils import get_pack_directory -__all__ = [ - 'setup_pack_virtualenv' -] +__all__ = ["setup_pack_virtualenv"] LOG = logging.getLogger(__name__) -def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True, - include_setuptools=True, include_wheel=True, proxy_config=None, - no_download=True, force_owner_group=True): +def setup_pack_virtualenv( + pack_name, + update=False, + logger=None, + include_pip=True, + include_setuptools=True, + include_wheel=True, + proxy_config=None, + no_download=True, + force_owner_group=True, +): """ Setup virtual environment for the provided pack. @@ -68,7 +74,7 @@ def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True if not re.match(PACK_REF_WHITELIST_REGEX, pack_name): raise ValueError('Invalid pack name "%s"' % (pack_name)) - base_virtualenvs_path = os.path.join(cfg.CONF.system.base_path, 'virtualenvs/') + base_virtualenvs_path = os.path.join(cfg.CONF.system.base_path, "virtualenvs/") virtualenv_path = os.path.join(base_virtualenvs_path, quote_unix(pack_name)) # Ensure pack directory exists in one of the search paths @@ -78,7 +84,7 @@ def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True if not pack_path: packs_base_paths = get_packs_base_paths() - search_paths = ', '.join(packs_base_paths) + search_paths = ", ".join(packs_base_paths) msg = 'Pack "%s" is not installed. Looked in: %s' % (pack_name, search_paths) raise Exception(msg) @@ -88,42 +94,64 @@ def setup_pack_virtualenv(pack_name, update=False, logger=None, include_pip=True remove_virtualenv(virtualenv_path=virtualenv_path, logger=logger) # 1. Create virtual environment - logger.debug('Creating virtualenv for pack "%s" in "%s"' % (pack_name, virtualenv_path)) - create_virtualenv(virtualenv_path=virtualenv_path, logger=logger, include_pip=include_pip, - include_setuptools=include_setuptools, include_wheel=include_wheel, - no_download=no_download) + logger.debug( + 'Creating virtualenv for pack "%s" in "%s"' % (pack_name, virtualenv_path) + ) + create_virtualenv( + virtualenv_path=virtualenv_path, + logger=logger, + include_pip=include_pip, + include_setuptools=include_setuptools, + include_wheel=include_wheel, + no_download=no_download, + ) # 2. Install base requirements which are common to all the packs - logger.debug('Installing base requirements') + logger.debug("Installing base requirements") for requirement in BASE_PACK_REQUIREMENTS: - install_requirement(virtualenv_path=virtualenv_path, requirement=requirement, - proxy_config=proxy_config, logger=logger) + install_requirement( + virtualenv_path=virtualenv_path, + requirement=requirement, + proxy_config=proxy_config, + logger=logger, + ) # 3. Install pack-specific requirements - requirements_file_path = os.path.join(pack_path, 'requirements.txt') + requirements_file_path = os.path.join(pack_path, "requirements.txt") has_requirements = os.path.isfile(requirements_file_path) if has_requirements: - logger.debug('Installing pack specific requirements from "%s"' % - (requirements_file_path)) - install_requirements(virtualenv_path=virtualenv_path, - requirements_file_path=requirements_file_path, - proxy_config=proxy_config, - logger=logger) + logger.debug( + 'Installing pack specific requirements from "%s"' % (requirements_file_path) + ) + install_requirements( + virtualenv_path=virtualenv_path, + requirements_file_path=requirements_file_path, + proxy_config=proxy_config, + logger=logger, + ) else: - logger.debug('No pack specific requirements found') + logger.debug("No pack specific requirements found") # 4. Set the owner group if force_owner_group: apply_pack_owner_group(pack_path=virtualenv_path) - action = 'updated' if update else 'created' - logger.debug('Virtualenv for pack "%s" successfully %s in "%s"' % - (pack_name, action, virtualenv_path)) - - -def create_virtualenv(virtualenv_path, logger=None, include_pip=True, include_setuptools=True, - include_wheel=True, no_download=True): + action = "updated" if update else "created" + logger.debug( + 'Virtualenv for pack "%s" successfully %s in "%s"' + % (pack_name, action, virtualenv_path) + ) + + +def create_virtualenv( + virtualenv_path, + logger=None, + include_pip=True, + include_setuptools=True, + include_wheel=True, + no_download=True, +): """ :param include_pip: Include pip binary and package in the newely created virtual environment. :type include_pip: ``bool`` @@ -145,7 +173,7 @@ def create_virtualenv(virtualenv_path, logger=None, include_pip=True, include_se python_binary = cfg.CONF.actionrunner.python_binary virtualenv_binary = cfg.CONF.actionrunner.virtualenv_binary virtualenv_opts = cfg.CONF.actionrunner.virtualenv_opts or [] - virtualenv_opts += ['--verbose'] + virtualenv_opts += ["--verbose"] if not os.path.isfile(python_binary): raise Exception('Python binary "%s" doesn\'t exist' % (python_binary)) @@ -153,39 +181,44 @@ def create_virtualenv(virtualenv_path, logger=None, include_pip=True, include_se if not os.path.isfile(virtualenv_binary): raise Exception('Virtualenv binary "%s" doesn\'t exist.' % (virtualenv_binary)) - logger.debug('Creating virtualenv in "%s" using Python binary "%s"' % - (virtualenv_path, python_binary)) + logger.debug( + 'Creating virtualenv in "%s" using Python binary "%s"' + % (virtualenv_path, python_binary) + ) cmd = [virtualenv_binary] - cmd.extend(['-p', python_binary]) + cmd.extend(["-p", python_binary]) cmd.extend(virtualenv_opts) if not include_pip: - cmd.append('--no-pip') + cmd.append("--no-pip") if not include_setuptools: - cmd.append('--no-setuptools') + cmd.append("--no-setuptools") if not include_wheel: - cmd.append('--no-wheel') + cmd.append("--no-wheel") if no_download: - cmd.append('--no-download') + cmd.append("--no-download") cmd.extend([virtualenv_path]) - logger.debug('Running command "%s" to create virtualenv.', ' '.join(cmd)) + logger.debug('Running command "%s" to create virtualenv.', " ".join(cmd)) try: exit_code, stdout, stderr = run_command(cmd=cmd) except OSError as e: - raise Exception('Error executing command %s. %s.' % (' '.join(cmd), - six.text_type(e))) + raise Exception( + "Error executing command %s. %s." % (" ".join(cmd), six.text_type(e)) + ) if exit_code != 0: - raise Exception('Failed to create virtualenv in "%s":\n stdout=%s\n stderr=%s' % - (virtualenv_path, stdout, stderr)) + raise Exception( + 'Failed to create virtualenv in "%s":\n stdout=%s\n stderr=%s' + % (virtualenv_path, stdout, stderr) + ) return True @@ -204,51 +237,60 @@ def remove_virtualenv(virtualenv_path, logger=None): try: shutil.rmtree(virtualenv_path) except Exception as e: - logger.error('Error while removing virtualenv at "%s": "%s"' % (virtualenv_path, e)) + logger.error( + 'Error while removing virtualenv at "%s": "%s"' % (virtualenv_path, e) + ) raise e return True -def install_requirements(virtualenv_path, requirements_file_path, proxy_config=None, logger=None): +def install_requirements( + virtualenv_path, requirements_file_path, proxy_config=None, logger=None +): """ Install requirements from a file. """ logger = logger or LOG - pip_path = os.path.join(virtualenv_path, 'bin/pip') + pip_path = os.path.join(virtualenv_path, "bin/pip") pip_opts = cfg.CONF.actionrunner.pip_opts or [] cmd = [pip_path] if proxy_config: - cert = proxy_config.get('proxy_ca_bundle_path', None) - https_proxy = proxy_config.get('https_proxy', None) - http_proxy = proxy_config.get('http_proxy', None) + cert = proxy_config.get("proxy_ca_bundle_path", None) + https_proxy = proxy_config.get("https_proxy", None) + http_proxy = proxy_config.get("http_proxy", None) if http_proxy: - cmd.extend(['--proxy', http_proxy]) + cmd.extend(["--proxy", http_proxy]) if https_proxy: - cmd.extend(['--proxy', https_proxy]) + cmd.extend(["--proxy", https_proxy]) if cert: - cmd.extend(['--cert', cert]) + cmd.extend(["--cert", cert]) - cmd.append('install') + cmd.append("install") cmd.extend(pip_opts) - cmd.extend(['-U', '-r', requirements_file_path]) + cmd.extend(["-U", "-r", requirements_file_path]) env = get_env_for_subprocess_command() - logger.debug('Installing requirements from file %s with command %s.', - requirements_file_path, ' '.join(cmd)) + logger.debug( + "Installing requirements from file %s with command %s.", + requirements_file_path, + " ".join(cmd), + ) exit_code, stdout, stderr = run_command(cmd=cmd, env=env) if exit_code != 0: stdout = to_ascii(stdout) stderr = to_ascii(stderr) - raise Exception('Failed to install requirements from "%s": %s (stderr: %s)' % - (requirements_file_path, stdout, stderr)) + raise Exception( + 'Failed to install requirements from "%s": %s (stderr: %s)' + % (requirements_file_path, stdout, stderr) + ) return True @@ -260,35 +302,37 @@ def install_requirement(virtualenv_path, requirement, proxy_config=None, logger= :param requirement: Requirement specifier. """ logger = logger or LOG - pip_path = os.path.join(virtualenv_path, 'bin/pip') + pip_path = os.path.join(virtualenv_path, "bin/pip") pip_opts = cfg.CONF.actionrunner.pip_opts or [] cmd = [pip_path] if proxy_config: - cert = proxy_config.get('proxy_ca_bundle_path', None) - https_proxy = proxy_config.get('https_proxy', None) - http_proxy = proxy_config.get('http_proxy', None) + cert = proxy_config.get("proxy_ca_bundle_path", None) + https_proxy = proxy_config.get("https_proxy", None) + http_proxy = proxy_config.get("http_proxy", None) if http_proxy: - cmd.extend(['--proxy', http_proxy]) + cmd.extend(["--proxy", http_proxy]) if https_proxy: - cmd.extend(['--proxy', https_proxy]) + cmd.extend(["--proxy", https_proxy]) if cert: - cmd.extend(['--cert', cert]) + cmd.extend(["--cert", cert]) - cmd.append('install') + cmd.append("install") cmd.extend(pip_opts) cmd.extend([requirement]) env = get_env_for_subprocess_command() - logger.debug('Installing requirement %s with command %s.', - requirement, ' '.join(cmd)) + logger.debug( + "Installing requirement %s with command %s.", requirement, " ".join(cmd) + ) exit_code, stdout, stderr = run_command(cmd=cmd, env=env) if exit_code != 0: - raise Exception('Failed to install requirement "%s": %s' % - (requirement, stdout)) + raise Exception( + 'Failed to install requirement "%s": %s' % (requirement, stdout) + ) return True @@ -302,7 +346,7 @@ def get_env_for_subprocess_command(): """ env = os.environ.copy() - if 'PYTHONPATH' in env: - del env['PYTHONPATH'] + if "PYTHONPATH" in env: + del env["PYTHONPATH"] return env diff --git a/st2common/st2common/util/wsgi.py b/st2common/st2common/util/wsgi.py index a3441e4bdac..63ec6c62532 100644 --- a/st2common/st2common/util/wsgi.py +++ b/st2common/st2common/util/wsgi.py @@ -24,9 +24,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'shutdown_server_kill_pending_requests' -] +__all__ = ["shutdown_server_kill_pending_requests"] def shutdown_server_kill_pending_requests(sock, worker_pool, wait_time=2): @@ -46,7 +44,7 @@ def shutdown_server_kill_pending_requests(sock, worker_pool, wait_time=2): sock.close() active_requests = worker_pool.running() - LOG.info('Shutting down. Requests left: %s', active_requests) + LOG.info("Shutting down. Requests left: %s", active_requests) # Give active requests some time to finish if active_requests > 0: @@ -57,5 +55,5 @@ def shutdown_server_kill_pending_requests(sock, worker_pool, wait_time=2): for coro in running_corutines: eventlet.greenthread.kill(coro) - LOG.info('Exiting...') + LOG.info("Exiting...") raise SystemExit() diff --git a/st2common/st2common/validators/api/action.py b/st2common/st2common/validators/api/action.py index 1eb5dbfeb9d..973e999fa6a 100644 --- a/st2common/st2common/validators/api/action.py +++ b/st2common/st2common/validators/api/action.py @@ -26,10 +26,7 @@ from st2common.models.system.common import ResourceReference from six.moves import range -__all__ = [ - 'validate_action', - 'get_runner_model' -] +__all__ = ["validate_action", "get_runner_model"] LOG = logging.getLogger(__name__) @@ -49,14 +46,17 @@ def validate_action(action_api, runner_type_db=None): # Check if pack is valid. if not _is_valid_pack(action_api.pack): packs_base_paths = get_packs_base_paths() - packs_base_paths = ','.join(packs_base_paths) - msg = ('Content pack "%s" is not found or doesn\'t contain actions directory. ' - 'Searched in: %s' % - (action_api.pack, packs_base_paths)) + packs_base_paths = ",".join(packs_base_paths) + msg = ( + 'Content pack "%s" is not found or doesn\'t contain actions directory. ' + "Searched in: %s" % (action_api.pack, packs_base_paths) + ) raise ValueValidationException(msg) # Check if parameters defined are valid. - action_ref = ResourceReference.to_string_reference(pack=action_api.pack, name=action_api.name) + action_ref = ResourceReference.to_string_reference( + pack=action_api.pack, name=action_api.name + ) _validate_parameters(action_ref, action_api.parameters, runner_db.runner_parameters) @@ -66,15 +66,18 @@ def get_runner_model(action_api): try: runner_db = get_runnertype_by_name(action_api.runner_type) except StackStormDBObjectNotFoundError: - msg = ('RunnerType %s is not found. If you are using old and deprecated runner name, you ' - 'need to switch to a new one. For more information, please see ' - 'https://docs.stackstorm.com/upgrade_notes.html#st2-v0-9' % (action_api.runner_type)) + msg = ( + "RunnerType %s is not found. If you are using old and deprecated runner name, you " + "need to switch to a new one. For more information, please see " + "https://docs.stackstorm.com/upgrade_notes.html#st2-v0-9" + % (action_api.runner_type) + ) raise ValueValidationException(msg) return runner_db def _is_valid_pack(pack): - return check_pack_content_directory_exists(pack=pack, content_type='actions') + return check_pack_content_directory_exists(pack=pack, content_type="actions") def _validate_parameters(action_ref, action_params=None, runner_params=None): @@ -84,32 +87,44 @@ def _validate_parameters(action_ref, action_params=None, runner_params=None): if action_param in runner_params: for action_param_attr, value in six.iteritems(action_param_meta): util_schema.validate_runner_parameter_attribute_override( - action_ref, action_param, action_param_attr, - value, runner_params[action_param].get(action_param_attr)) - - if 'position' in action_param_meta: - pos = action_param_meta['position'] + action_ref, + action_param, + action_param_attr, + value, + runner_params[action_param].get(action_param_attr), + ) + + if "position" in action_param_meta: + pos = action_param_meta["position"] param = position_params.get(pos, None) if param: - msg = ('Parameters %s and %s have same position %d.' % (action_param, param, pos) + - ' Position values have to be unique.') + msg = ( + "Parameters %s and %s have same position %d." + % (action_param, param, pos) + + " Position values have to be unique." + ) raise ValueValidationException(msg) else: position_params[pos] = action_param - if 'immutable' in action_param_meta: + if "immutable" in action_param_meta: if action_param in runner_params: runner_param_meta = runner_params[action_param] - if 'immutable' in runner_param_meta: - msg = 'Param %s is declared immutable in runner. ' % action_param + \ - 'Cannot override in action.' + if "immutable" in runner_param_meta: + msg = ( + "Param %s is declared immutable in runner. " % action_param + + "Cannot override in action." + ) raise ValueValidationException(msg) - if 'default' not in action_param_meta and 'default' not in runner_param_meta: - msg = 'Immutable param %s requires a default value.' % action_param + if ( + "default" not in action_param_meta + and "default" not in runner_param_meta + ): + msg = "Immutable param %s requires a default value." % action_param raise ValueValidationException(msg) else: - if 'default' not in action_param_meta: - msg = 'Immutable param %s requires a default value.' % action_param + if "default" not in action_param_meta: + msg = "Immutable param %s requires a default value." % action_param raise ValueValidationException(msg) return _validate_position_values_contiguous(position_params) @@ -120,10 +135,10 @@ def _validate_position_values_contiguous(position_params): return True positions = sorted(position_params.keys()) - contiguous = (positions == list(range(min(positions), max(positions) + 1))) + contiguous = positions == list(range(min(positions), max(positions) + 1)) if not contiguous: - msg = 'Positions supplied %s for parameters are not contiguous.' % positions + msg = "Positions supplied %s for parameters are not contiguous." % positions raise ValueValidationException(msg) return True diff --git a/st2common/st2common/validators/api/misc.py b/st2common/st2common/validators/api/misc.py index b18ff05d21a..215afc55016 100644 --- a/st2common/st2common/validators/api/misc.py +++ b/st2common/st2common/validators/api/misc.py @@ -17,9 +17,7 @@ from st2common.constants.pack import SYSTEM_PACK_NAME from st2common.exceptions.apivalidation import ValueValidationException -__all__ = [ - 'validate_not_part_of_system_pack' -] +__all__ = ["validate_not_part_of_system_pack"] def validate_not_part_of_system_pack(resource_db): @@ -32,10 +30,10 @@ def validate_not_part_of_system_pack(resource_db): :param resource_db: Resource database object to check. :type resource_db: ``object`` """ - pack = getattr(resource_db, 'pack', None) + pack = getattr(resource_db, "pack", None) if pack == SYSTEM_PACK_NAME: - msg = 'Resources belonging to system level packs can\'t be manipulated' + msg = "Resources belonging to system level packs can't be manipulated" raise ValueValidationException(msg) return resource_db diff --git a/st2common/st2common/validators/api/reactor.py b/st2common/st2common/validators/api/reactor.py index eb2cf1c8147..0d84a66a996 100644 --- a/st2common/st2common/validators/api/reactor.py +++ b/st2common/st2common/validators/api/reactor.py @@ -29,10 +29,9 @@ from st2common.services import triggers __all__ = [ - 'validate_criteria', - - 'validate_trigger_parameters', - 'validate_trigger_payload' + "validate_criteria", + "validate_trigger_parameters", + "validate_trigger_payload", ] @@ -43,20 +42,30 @@ def validate_criteria(criteria): if not isinstance(criteria, dict): - raise ValueValidationException('Criteria should be a dict.') + raise ValueValidationException("Criteria should be a dict.") for key, value in six.iteritems(criteria): - operator = value.get('type', None) + operator = value.get("type", None) if operator is None: - raise ValueValidationException('Operator not specified for field: ' + key) + raise ValueValidationException("Operator not specified for field: " + key) if operator not in allowed_operators: - raise ValueValidationException('For field: ' + key + ', operator ' + operator + - ' not in list of allowed operators: ' + - str(list(allowed_operators.keys()))) - pattern = value.get('pattern', None) + raise ValueValidationException( + "For field: " + + key + + ", operator " + + operator + + " not in list of allowed operators: " + + str(list(allowed_operators.keys())) + ) + pattern = value.get("pattern", None) if pattern is None: - raise ValueValidationException('For field: ' + key + ', no pattern specified ' + - 'for operator ' + operator) + raise ValueValidationException( + "For field: " + + key + + ", no pattern specified " + + "for operator " + + operator + ) def validate_trigger_parameters(trigger_type_ref, parameters): @@ -77,27 +86,33 @@ def validate_trigger_parameters(trigger_type_ref, parameters): is_system_trigger = trigger_type_ref in SYSTEM_TRIGGER_TYPES if is_system_trigger: # System trigger - parameters_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]['parameters_schema'] + parameters_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]["parameters_schema"] else: trigger_type_db = triggers.get_trigger_type_db(trigger_type_ref) if not trigger_type_db: # Trigger doesn't exist in the database return None - parameters_schema = getattr(trigger_type_db, 'parameters_schema', {}) + parameters_schema = getattr(trigger_type_db, "parameters_schema", {}) if not parameters_schema: # Parameters schema not defined for the this trigger return None # We only validate non-system triggers if config option is set (enabled) if not is_system_trigger and not cfg.CONF.system.validate_trigger_parameters: - LOG.debug('Got non-system trigger "%s", but trigger parameter validation for non-system' - 'triggers is disabled, skipping validation.' % (trigger_type_ref)) + LOG.debug( + 'Got non-system trigger "%s", but trigger parameter validation for non-system' + "triggers is disabled, skipping validation." % (trigger_type_ref) + ) return None - cleaned = util_schema.validate(instance=parameters, schema=parameters_schema, - cls=util_schema.CustomValidator, use_default=True, - allow_default_none=True) + cleaned = util_schema.validate( + instance=parameters, + schema=parameters_schema, + cls=util_schema.CustomValidator, + use_default=True, + allow_default_none=True, + ) # Additional validation for CronTimer trigger # TODO: If we need to add more checks like this we should consider abstracting this out. @@ -110,7 +125,9 @@ def validate_trigger_parameters(trigger_type_ref, parameters): return cleaned -def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trigger=False): +def validate_trigger_payload( + trigger_type_ref, payload, throw_on_inexistent_trigger=False +): """ This function validates trigger payload parameters for system and user-defined triggers. @@ -128,8 +145,8 @@ def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trig # NOTE: Due to the awful code in some other places we also need to support a scenario where # this variable is a dictionary and contains various TriggerDB object attributes. if isinstance(trigger_type_ref, dict): - if trigger_type_ref.get('type', None): - trigger_type_ref = trigger_type_ref['type'] + if trigger_type_ref.get("type", None): + trigger_type_ref = trigger_type_ref["type"] else: trigger_db = triggers.get_trigger_db_by_ref_or_dict(trigger_type_ref) @@ -143,16 +160,16 @@ def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trig is_system_trigger = trigger_type_ref in SYSTEM_TRIGGER_TYPES if is_system_trigger: # System trigger - payload_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]['payload_schema'] + payload_schema = SYSTEM_TRIGGER_TYPES[trigger_type_ref]["payload_schema"] else: # We assume Trigger ref and not TriggerType ref is passed in if second # part (trigger name) is a valid UUID version 4 try: - trigger_uuid = uuid.UUID(trigger_type_ref.split('.')[-1]) + trigger_uuid = uuid.UUID(trigger_type_ref.split(".")[-1]) except ValueError: is_trigger_db = False else: - is_trigger_db = (trigger_uuid.version == 4) + is_trigger_db = trigger_uuid.version == 4 if is_trigger_db: trigger_db = triggers.get_trigger_db_by_ref(trigger_type_ref) @@ -165,25 +182,33 @@ def validate_trigger_payload(trigger_type_ref, payload, throw_on_inexistent_trig if not trigger_type_db: # Trigger doesn't exist in the database if throw_on_inexistent_trigger: - msg = ('Trigger type with reference "%s" doesn\'t exist in the database' % - (trigger_type_ref)) + msg = ( + 'Trigger type with reference "%s" doesn\'t exist in the database' + % (trigger_type_ref) + ) raise ValueError(msg) return None - payload_schema = getattr(trigger_type_db, 'payload_schema', {}) + payload_schema = getattr(trigger_type_db, "payload_schema", {}) if not payload_schema: # Payload schema not defined for the this trigger return None # We only validate non-system triggers if config option is set (enabled) if not is_system_trigger and not cfg.CONF.system.validate_trigger_payload: - LOG.debug('Got non-system trigger "%s", but trigger payload validation for non-system' - 'triggers is disabled, skipping validation.' % (trigger_type_ref)) + LOG.debug( + 'Got non-system trigger "%s", but trigger payload validation for non-system' + "triggers is disabled, skipping validation." % (trigger_type_ref) + ) return None - cleaned = util_schema.validate(instance=payload, schema=payload_schema, - cls=util_schema.CustomValidator, use_default=True, - allow_default_none=True) + cleaned = util_schema.validate( + instance=payload, + schema=payload_schema, + cls=util_schema.CustomValidator, + use_default=True, + allow_default_none=True, + ) return cleaned diff --git a/st2common/st2common/validators/workflow/base.py b/st2common/st2common/validators/workflow/base.py index 226a4668fbd..3bf8e9fbd5a 100644 --- a/st2common/st2common/validators/workflow/base.py +++ b/st2common/st2common/validators/workflow/base.py @@ -20,7 +20,6 @@ @six.add_metaclass(abc.ABCMeta) class WorkflowValidator(object): - @abc.abstractmethod def validate(self, definition): raise NotImplementedError diff --git a/st2common/tests/fixtures/mock_runner/mock_runner.py b/st2common/tests/fixtures/mock_runner/mock_runner.py index 9110e740f4a..66295e84210 100644 --- a/st2common/tests/fixtures/mock_runner/mock_runner.py +++ b/st2common/tests/fixtures/mock_runner/mock_runner.py @@ -23,9 +23,7 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'get_runner' -] +__all__ = ["get_runner"] def get_runner(): @@ -36,6 +34,7 @@ class MockRunner(ActionRunner): """ Runner which does absolutely nothing. """ + KEYS_TO_TRANSFORM = [] def __init__(self, runner_id): @@ -47,9 +46,9 @@ def pre_run(self): def run(self, action_parameters): result = { - 'failed': False, - 'succeeded': True, - 'return_code': 0, + "failed": False, + "succeeded": True, + "return_code": 0, } status = LIVEACTION_STATUS_SUCCEEDED diff --git a/st2common/tests/fixtures/version_file.py b/st2common/tests/fixtures/version_file.py index 882f420538f..b52f01d75c3 100644 --- a/st2common/tests/fixtures/version_file.py +++ b/st2common/tests/fixtures/version_file.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '1.2.3' +__version__ = "1.2.3" diff --git a/st2common/tests/integration/test_rabbitmq_ssl_listener.py b/st2common/tests/integration/test_rabbitmq_ssl_listener.py index 9c1ddeef06f..e64a22995da 100644 --- a/st2common/tests/integration/test_rabbitmq_ssl_listener.py +++ b/st2common/tests/integration/test_rabbitmq_ssl_listener.py @@ -27,12 +27,10 @@ from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'RabbitMQTLSListenerTestCase' -] +__all__ = ["RabbitMQTLSListenerTestCase"] -CERTS_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), 'ssl_certs/') -ST2_CI = (os.environ.get('ST2_CI', 'false').lower() == 'true') +CERTS_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), "ssl_certs/") +ST2_CI = os.environ.get("ST2_CI", "false").lower() == "true" NON_SSL_LISTENER_PORT = 5672 SSL_LISTENER_PORT = 5671 @@ -40,42 +38,49 @@ # NOTE: We only run those tests on the CI provider because at the moment, local # vagrant dev VM doesn't expose RabbitMQ SSL listener by default -@unittest2.skipIf(not ST2_CI, - 'Skipping tests because ST2_CI environment variable is not set to "true"') +@unittest2.skipIf( + not ST2_CI, + 'Skipping tests because ST2_CI environment variable is not set to "true"', +) class RabbitMQTLSListenerTestCase(unittest2.TestCase): - def setUp(self): # Set default values - cfg.CONF.set_override(name='ssl', override=False, group='messaging') - cfg.CONF.set_override(name='ssl_keyfile', override=None, group='messaging') - cfg.CONF.set_override(name='ssl_certfile', override=None, group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=None, group='messaging') - cfg.CONF.set_override(name='ssl_cert_reqs', override=None, group='messaging') + cfg.CONF.set_override(name="ssl", override=False, group="messaging") + cfg.CONF.set_override(name="ssl_keyfile", override=None, group="messaging") + cfg.CONF.set_override(name="ssl_certfile", override=None, group="messaging") + cfg.CONF.set_override(name="ssl_ca_certs", override=None, group="messaging") + cfg.CONF.set_override(name="ssl_cert_reqs", override=None, group="messaging") def test_non_ssl_connection_on_ssl_listener_port_failure(self): - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) - expected_msg_1 = '[Errno 104]' # followed by: ' Connection reset by peer' or ' ECONNRESET' - expected_msg_2 = 'Socket closed' - expected_msg_3 = 'Server unexpectedly closed connection' + expected_msg_1 = ( + "[Errno 104]" # followed by: ' Connection reset by peer' or ' ECONNRESET' + ) + expected_msg_2 = "Socket closed" + expected_msg_3 = "Server unexpectedly closed connection" try: connection.connect() except Exception as e: self.assertFalse(connection.connected) self.assertIsInstance(e, (IOError, socket.error)) - self.assertTrue(expected_msg_1 in six.text_type(e) or - expected_msg_2 in six.text_type(e) or - expected_msg_3 in six.text_type(e)) + self.assertTrue( + expected_msg_1 in six.text_type(e) + or expected_msg_2 in six.text_type(e) + or expected_msg_3 in six.text_type(e) + ) else: - self.fail('Exception was not thrown') + self.fail("Exception was not thrown") if connection: connection.release() def test_ssl_connection_on_ssl_listener_success(self): # Using query param notation - urls = 'amqp://guest:guest@127.0.0.1:5671/?ssl=true' + urls = "amqp://guest:guest@127.0.0.1:5671/?ssl=true" connection = transport_utils.get_connection(urls=urls) try: @@ -86,9 +91,11 @@ def test_ssl_connection_on_ssl_listener_success(self): connection.release() # Using messaging.ssl config option - cfg.CONF.set_override(name='ssl', override=True, group='messaging') + cfg.CONF.set_override(name="ssl", override=True, group="messaging") - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) try: self.assertTrue(connection.connect()) @@ -98,15 +105,21 @@ def test_ssl_connection_on_ssl_listener_success(self): connection.release() def test_ssl_connection_ca_certs_provided(self): - ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, 'ca/ca_certificate_bundle.pem') + ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, "ca/ca_certificate_bundle.pem") - cfg.CONF.set_override(name='ssl', override=True, group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') + cfg.CONF.set_override(name="ssl", override=True, group="messaging") + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) # 1. Validate server cert against a valid CA bundle (success) - cert required - cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging') + cfg.CONF.set_override( + name="ssl_cert_reqs", override="required", group="messaging" + ) - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) try: self.assertTrue(connection.connect()) @@ -117,35 +130,51 @@ def test_ssl_connection_ca_certs_provided(self): # 2. Validate server cert against other CA bundle (failure) # CA bundle which was not used to sign the server cert - ca_cert_path = os.path.join('/etc/ssl/certs/thawte_Primary_Root_CA.pem') + ca_cert_path = os.path.join("/etc/ssl/certs/thawte_Primary_Root_CA.pem") - cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') + cfg.CONF.set_override( + name="ssl_cert_reqs", override="required", group="messaging" + ) + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) - expected_msg = r'\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed' + expected_msg = r"\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed" self.assertRaisesRegexp(ssl.SSLError, expected_msg, connection.connect) # 3. Validate server cert against other CA bundle (failure) - ca_cert_path = os.path.join('/etc/ssl/certs/thawte_Primary_Root_CA.pem') + ca_cert_path = os.path.join("/etc/ssl/certs/thawte_Primary_Root_CA.pem") - cfg.CONF.set_override(name='ssl_cert_reqs', override='optional', group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') + cfg.CONF.set_override( + name="ssl_cert_reqs", override="optional", group="messaging" + ) + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) - expected_msg = r'\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed' + expected_msg = r"\[SSL: CERTIFICATE_VERIFY_FAILED\] certificate verify failed" self.assertRaisesRegexp(ssl.SSLError, expected_msg, connection.connect) # 4. Validate server cert against other CA bundle (failure) # We use invalid bundle but cert_reqs is none - ca_cert_path = os.path.join('/etc/ssl/certs/thawte_Primary_Root_CA.pem') + ca_cert_path = os.path.join("/etc/ssl/certs/thawte_Primary_Root_CA.pem") - cfg.CONF.set_override(name='ssl_cert_reqs', override='none', group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') + cfg.CONF.set_override(name="ssl_cert_reqs", override="none", group="messaging") + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) try: self.assertTrue(connection.connect()) @@ -156,16 +185,28 @@ def test_ssl_connection_ca_certs_provided(self): def test_ssl_connect_client_side_cert_authentication(self): # 1. Success, valid client side cert provided - ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, 'client/private_key.pem') - ssl_certfile = os.path.join(CERTS_FIXTURES_PATH, 'client/client_certificate.pem') - ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, 'ca/ca_certificate_bundle.pem') - - cfg.CONF.set_override(name='ssl_keyfile', override=ssl_keyfile, group='messaging') - cfg.CONF.set_override(name='ssl_certfile', override=ssl_certfile, group='messaging') - cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') - - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') + ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, "client/private_key.pem") + ssl_certfile = os.path.join( + CERTS_FIXTURES_PATH, "client/client_certificate.pem" + ) + ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, "ca/ca_certificate_bundle.pem") + + cfg.CONF.set_override( + name="ssl_keyfile", override=ssl_keyfile, group="messaging" + ) + cfg.CONF.set_override( + name="ssl_certfile", override=ssl_certfile, group="messaging" + ) + cfg.CONF.set_override( + name="ssl_cert_reqs", override="required", group="messaging" + ) + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) + + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) try: self.assertTrue(connection.connect()) @@ -175,16 +216,28 @@ def test_ssl_connect_client_side_cert_authentication(self): connection.release() # 2. Invalid client side cert provided - failure - ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, 'client/private_key.pem') - ssl_certfile = os.path.join(CERTS_FIXTURES_PATH, 'server/server_certificate.pem') - ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, 'ca/ca_certificate_bundle.pem') - - cfg.CONF.set_override(name='ssl_keyfile', override=ssl_keyfile, group='messaging') - cfg.CONF.set_override(name='ssl_certfile', override=ssl_certfile, group='messaging') - cfg.CONF.set_override(name='ssl_cert_reqs', override='required', group='messaging') - cfg.CONF.set_override(name='ssl_ca_certs', override=ca_cert_path, group='messaging') - - connection = transport_utils.get_connection(urls='amqp://guest:guest@127.0.0.1:5671/') - - expected_msg = r'\[X509: KEY_VALUES_MISMATCH\] key values mismatch' + ssl_keyfile = os.path.join(CERTS_FIXTURES_PATH, "client/private_key.pem") + ssl_certfile = os.path.join( + CERTS_FIXTURES_PATH, "server/server_certificate.pem" + ) + ca_cert_path = os.path.join(CERTS_FIXTURES_PATH, "ca/ca_certificate_bundle.pem") + + cfg.CONF.set_override( + name="ssl_keyfile", override=ssl_keyfile, group="messaging" + ) + cfg.CONF.set_override( + name="ssl_certfile", override=ssl_certfile, group="messaging" + ) + cfg.CONF.set_override( + name="ssl_cert_reqs", override="required", group="messaging" + ) + cfg.CONF.set_override( + name="ssl_ca_certs", override=ca_cert_path, group="messaging" + ) + + connection = transport_utils.get_connection( + urls="amqp://guest:guest@127.0.0.1:5671/" + ) + + expected_msg = r"\[X509: KEY_VALUES_MISMATCH\] key values mismatch" self.assertRaisesRegexp(ssl.SSLError, expected_msg, connection.connect) diff --git a/st2common/tests/integration/test_register_content_script.py b/st2common/tests/integration/test_register_content_script.py index 1d7ca955f98..8082a853711 100644 --- a/st2common/tests/integration/test_register_content_script.py +++ b/st2common/tests/integration/test_register_content_script.py @@ -26,15 +26,15 @@ BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -SCRIPT_PATH = os.path.join(BASE_DIR, '../../bin/st2-register-content') +SCRIPT_PATH = os.path.join(BASE_DIR, "../../bin/st2-register-content") SCRIPT_PATH = os.path.abspath(SCRIPT_PATH) -BASE_CMD_ARGS = [sys.executable, SCRIPT_PATH, '--config-file=conf/st2.tests.conf', '-v'] -BASE_REGISTER_ACTIONS_CMD_ARGS = BASE_CMD_ARGS + ['--register-actions'] +BASE_CMD_ARGS = [sys.executable, SCRIPT_PATH, "--config-file=conf/st2.tests.conf", "-v"] +BASE_REGISTER_ACTIONS_CMD_ARGS = BASE_CMD_ARGS + ["--register-actions"] PACKS_PATH = get_fixtures_packs_base_path() -PACKS_COUNT = len(glob.glob('%s/*/pack.yaml' % (PACKS_PATH))) -assert(PACKS_COUNT >= 2) +PACKS_COUNT = len(glob.glob("%s/*/pack.yaml" % (PACKS_PATH))) +assert PACKS_COUNT >= 2 class ContentRegisterScriptTestCase(IntegrationTestCase): @@ -43,27 +43,27 @@ def setUp(self): test_config.parse_args() def test_register_from_pack_success(self): - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1') - runner_dirs = os.path.join(get_fixtures_packs_base_path(), 'runners') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1") + runner_dirs = os.path.join(get_fixtures_packs_base_path(), "runners") opts = [ - '--register-pack=%s' % (pack_dir), - '--register-runner-dir=%s' % (runner_dirs), + "--register-pack=%s" % (pack_dir), + "--register-runner-dir=%s" % (runner_dirs), ] cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts exit_code, _, stderr = run_command(cmd=cmd) - self.assertIn('Registered 1 actions.', stderr) + self.assertIn("Registered 1 actions.", stderr) self.assertEqual(exit_code, 0) def test_register_from_pack_fail_on_failure_pack_dir_doesnt_exist(self): # No fail on failure flag, should succeed - pack_dir = 'doesntexistblah' - runner_dirs = os.path.join(get_fixtures_packs_base_path(), 'runners') + pack_dir = "doesntexistblah" + runner_dirs = os.path.join(get_fixtures_packs_base_path(), "runners") opts = [ - '--register-pack=%s' % (pack_dir), - '--register-runner-dir=%s' % (runner_dirs), - '--register-no-fail-on-failure' + "--register-pack=%s" % (pack_dir), + "--register-runner-dir=%s" % (runner_dirs), + "--register-no-fail-on-failure", ] cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts exit_code, _, _ = run_command(cmd=cmd) @@ -71,9 +71,9 @@ def test_register_from_pack_fail_on_failure_pack_dir_doesnt_exist(self): # Fail on failure, should fail opts = [ - '--register-pack=%s' % (pack_dir), - '--register-runner-dir=%s' % (runner_dirs), - '--register-fail-on-failure' + "--register-pack=%s" % (pack_dir), + "--register-runner-dir=%s" % (runner_dirs), + "--register-fail-on-failure", ] cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts exit_code, _, stderr = run_command(cmd=cmd) @@ -82,30 +82,30 @@ def test_register_from_pack_fail_on_failure_pack_dir_doesnt_exist(self): def test_register_from_pack_action_metadata_fails_validation(self): # No fail on failure flag, should succeed - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_4') - runner_dirs = os.path.join(get_fixtures_packs_base_path(), 'runners') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_4") + runner_dirs = os.path.join(get_fixtures_packs_base_path(), "runners") opts = [ - '--register-pack=%s' % (pack_dir), - '--register-no-fail-on-failure', - '--register-runner-dir=%s' % (runner_dirs), + "--register-pack=%s" % (pack_dir), + "--register-no-fail-on-failure", + "--register-runner-dir=%s" % (runner_dirs), ] cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts exit_code, _, stderr = run_command(cmd=cmd) - self.assertIn('Registered 0 actions.', stderr) + self.assertIn("Registered 0 actions.", stderr) self.assertEqual(exit_code, 0) # Fail on failure, should fail - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_4') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_4") opts = [ - '--register-pack=%s' % (pack_dir), - '--register-fail-on-failure', - '--register-runner-dir=%s' % (runner_dirs), + "--register-pack=%s" % (pack_dir), + "--register-fail-on-failure", + "--register-runner-dir=%s" % (runner_dirs), ] cmd = BASE_REGISTER_ACTIONS_CMD_ARGS + opts exit_code, _, stderr = run_command(cmd=cmd) - self.assertIn('object has no attribute \'get\'', stderr) + self.assertIn("object has no attribute 'get'", stderr) self.assertEqual(exit_code, 1) def test_register_from_packs_doesnt_throw_on_missing_pack_resource_folder(self): @@ -114,44 +114,58 @@ def test_register_from_packs_doesnt_throw_on_missing_pack_resource_folder(self): # Note: We want to use a different config which sets fixtures/packs_1/ # dir as packs_base_paths - cmd = [sys.executable, SCRIPT_PATH, '--config-file=conf/st2.tests1.conf', '-v', - '--register-sensors'] + cmd = [ + sys.executable, + SCRIPT_PATH, + "--config-file=conf/st2.tests1.conf", + "-v", + "--register-sensors", + ] exit_code, _, stderr = run_command(cmd=cmd) - self.assertIn('Registered 0 sensors.', stderr, 'Actual stderr: %s' % (stderr)) + self.assertIn("Registered 0 sensors.", stderr, "Actual stderr: %s" % (stderr)) self.assertEqual(exit_code, 0) - cmd = [sys.executable, SCRIPT_PATH, '--config-file=conf/st2.tests1.conf', '-v', - '--register-all', '--register-no-fail-on-failure'] + cmd = [ + sys.executable, + SCRIPT_PATH, + "--config-file=conf/st2.tests1.conf", + "-v", + "--register-all", + "--register-no-fail-on-failure", + ] exit_code, _, stderr = run_command(cmd=cmd) - self.assertIn('Registered 0 actions.', stderr) - self.assertIn('Registered 0 sensors.', stderr) - self.assertIn('Registered 0 rules.', stderr) + self.assertIn("Registered 0 actions.", stderr) + self.assertIn("Registered 0 sensors.", stderr) + self.assertIn("Registered 0 rules.", stderr) self.assertEqual(exit_code, 0) def test_register_all_and_register_setup_virtualenvs(self): # Verify that --register-all works in combinations with --register-setup-virtualenvs # Single pack - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1") cmd = BASE_CMD_ARGS + [ - '--register-pack=%s' % (pack_dir), - '--register-all', - '--register-setup-virtualenvs', - '--register-no-fail-on-failure' + "--register-pack=%s" % (pack_dir), + "--register-all", + "--register-setup-virtualenvs", + "--register-no-fail-on-failure", ] exit_code, stdout, stderr = run_command(cmd=cmd) - self.assertIn('Registering actions', stderr, 'Actual stderr: %s' % (stderr)) - self.assertIn('Registering rules', stderr) - self.assertIn('Setup virtualenv for %s pack(s)' % ('1'), stderr) + self.assertIn("Registering actions", stderr, "Actual stderr: %s" % (stderr)) + self.assertIn("Registering rules", stderr) + self.assertIn("Setup virtualenv for %s pack(s)" % ("1"), stderr) self.assertEqual(exit_code, 0) def test_register_setup_virtualenvs(self): # Single pack - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1") - cmd = BASE_CMD_ARGS + ['--register-pack=%s' % (pack_dir), '--register-setup-virtualenvs', - '--register-no-fail-on-failure'] + cmd = BASE_CMD_ARGS + [ + "--register-pack=%s" % (pack_dir), + "--register-setup-virtualenvs", + "--register-no-fail-on-failure", + ] exit_code, stdout, stderr = run_command(cmd=cmd) self.assertIn('Setting up virtualenv for pack "dummy_pack_1"', stderr) - self.assertIn('Setup virtualenv for 1 pack(s)', stderr) + self.assertIn("Setup virtualenv for 1 pack(s)", stderr) self.assertEqual(exit_code, 0) diff --git a/st2common/tests/integration/test_service_setup_log_level_filtering.py b/st2common/tests/integration/test_service_setup_log_level_filtering.py index ac3f90deaf1..a03e90688a5 100644 --- a/st2common/tests/integration/test_service_setup_log_level_filtering.py +++ b/st2common/tests/integration/test_service_setup_log_level_filtering.py @@ -25,36 +25,42 @@ from st2tests.base import IntegrationTestCase from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'ServiceSetupLogLevelFilteringTestCase' -] +__all__ = ["ServiceSetupLogLevelFilteringTestCase"] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) FIXTURES_DIR = get_fixtures_base_path() -ST2_CONFIG_INFO_LL_PATH = os.path.join(FIXTURES_DIR, 'conf/st2.tests.api.info_log_level.conf') +ST2_CONFIG_INFO_LL_PATH = os.path.join( + FIXTURES_DIR, "conf/st2.tests.api.info_log_level.conf" +) ST2_CONFIG_INFO_LL_PATH = os.path.abspath(ST2_CONFIG_INFO_LL_PATH) -ST2_CONFIG_DEBUG_LL_PATH = os.path.join(FIXTURES_DIR, 'conf/st2.tests.api.debug_log_level.conf') +ST2_CONFIG_DEBUG_LL_PATH = os.path.join( + FIXTURES_DIR, "conf/st2.tests.api.debug_log_level.conf" +) ST2_CONFIG_DEBUG_LL_PATH = os.path.abspath(ST2_CONFIG_DEBUG_LL_PATH) -ST2_CONFIG_AUDIT_LL_PATH = os.path.join(FIXTURES_DIR, 'conf/st2.tests.api.audit_log_level.conf') +ST2_CONFIG_AUDIT_LL_PATH = os.path.join( + FIXTURES_DIR, "conf/st2.tests.api.audit_log_level.conf" +) ST2_CONFIG_AUDIT_LL_PATH = os.path.abspath(ST2_CONFIG_AUDIT_LL_PATH) -ST2_CONFIG_SYSTEM_DEBUG_PATH = os.path.join(FIXTURES_DIR, - 'conf/st2.tests.api.system_debug_true.conf') +ST2_CONFIG_SYSTEM_DEBUG_PATH = os.path.join( + FIXTURES_DIR, "conf/st2.tests.api.system_debug_true.conf" +) ST2_CONFIG_SYSTEM_DEBUG_PATH = os.path.abspath(ST2_CONFIG_SYSTEM_DEBUG_PATH) -ST2_CONFIG_SYSTEM_LL_DEBUG_PATH = os.path.join(FIXTURES_DIR, - 'conf/st2.tests.api.system_debug_true_logging_debug.conf') +ST2_CONFIG_SYSTEM_LL_DEBUG_PATH = os.path.join( + FIXTURES_DIR, "conf/st2.tests.api.system_debug_true_logging_debug.conf" +) PYTHON_BINARY = sys.executable -ST2API_BINARY = os.path.join(BASE_DIR, '../../../st2api/bin/st2api') +ST2API_BINARY = os.path.join(BASE_DIR, "../../../st2api/bin/st2api") ST2API_BINARY = os.path.abspath(ST2API_BINARY) -CMD = [PYTHON_BINARY, ST2API_BINARY, '--config-file'] +CMD = [PYTHON_BINARY, ST2API_BINARY, "--config-file"] class ServiceSetupLogLevelFilteringTestCase(IntegrationTestCase): @@ -68,11 +74,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self): process.send_signal(signal.SIGKILL) # First 3 log lines are debug messages about the environment which are always logged - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:]) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:]) - self.assertIn('INFO [-]', stdout) - self.assertNotIn('DEBUG [-]', stdout) - self.assertNotIn('AUDIT [-]', stdout) + self.assertIn("INFO [-]", stdout) + self.assertNotIn("DEBUG [-]", stdout) + self.assertNotIn("AUDIT [-]", stdout) # 2. DEBUG log level - audit messages should be included process = self._start_process(config_path=ST2_CONFIG_DEBUG_LL_PATH) @@ -83,11 +89,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self): process.send_signal(signal.SIGKILL) # First 3 log lines are debug messages about the environment which are always logged - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:]) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:]) - self.assertIn('INFO [-]', stdout) - self.assertIn('DEBUG [-]', stdout) - self.assertIn('AUDIT [-]', stdout) + self.assertIn("INFO [-]", stdout) + self.assertIn("DEBUG [-]", stdout) + self.assertIn("AUDIT [-]", stdout) # 3. AUDIT log level - audit messages should be included process = self._start_process(config_path=ST2_CONFIG_AUDIT_LL_PATH) @@ -98,11 +104,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self): process.send_signal(signal.SIGKILL) # First 3 log lines are debug messages about the environment which are always logged - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:]) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:]) - self.assertNotIn('INFO [-]', stdout) - self.assertNotIn('DEBUG [-]', stdout) - self.assertIn('AUDIT [-]', stdout) + self.assertNotIn("INFO [-]", stdout) + self.assertNotIn("DEBUG [-]", stdout) + self.assertIn("AUDIT [-]", stdout) # 2. INFO log level but system.debug set to True process = self._start_process(config_path=ST2_CONFIG_SYSTEM_DEBUG_PATH) @@ -113,11 +119,11 @@ def test_audit_log_level_is_filtered_if_log_level_is_not_debug_or_audit(self): process.send_signal(signal.SIGKILL) # First 3 log lines are debug messages about the environment which are always logged - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')[3:]) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")[3:]) - self.assertIn('INFO [-]', stdout) - self.assertIn('DEBUG [-]', stdout) - self.assertIn('AUDIT [-]', stdout) + self.assertIn("INFO [-]", stdout) + self.assertIn("DEBUG [-]", stdout) + self.assertIn("AUDIT [-]", stdout) def test_kombu_heartbeat_tick_log_messages_are_excluded(self): # 1. system.debug = True config option is set, verify heartbeat_tick message is not logged @@ -128,8 +134,8 @@ def test_kombu_heartbeat_tick_log_messages_are_excluded(self): eventlet.sleep(5) process.send_signal(signal.SIGKILL) - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')) - self.assertNotIn('heartbeat_tick', stdout) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")) + self.assertNotIn("heartbeat_tick", stdout) # 2. system.debug = False, log level is set to debug process = self._start_process(config_path=ST2_CONFIG_DEBUG_LL_PATH) @@ -139,14 +145,19 @@ def test_kombu_heartbeat_tick_log_messages_are_excluded(self): eventlet.sleep(5) process.send_signal(signal.SIGKILL) - stdout = '\n'.join(process.stdout.read().decode('utf-8').split('\n')) - self.assertNotIn('heartbeat_tick', stdout) + stdout = "\n".join(process.stdout.read().decode("utf-8").split("\n")) + self.assertNotIn("heartbeat_tick", stdout) def _start_process(self, config_path): cmd = CMD + [config_path] - cwd = os.path.abspath(os.path.join(BASE_DIR, '../../../')) + cwd = os.path.abspath(os.path.join(BASE_DIR, "../../../")) cwd = os.path.abspath(cwd) - process = subprocess.Popen(cmd, cwd=cwd, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - shell=False, preexec_fn=os.setsid) + process = subprocess.Popen( + cmd, + cwd=cwd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + preexec_fn=os.setsid, + ) return process diff --git a/st2common/tests/unit/base.py b/st2common/tests/unit/base.py index 6a22b139dbe..65948d1d112 100644 --- a/st2common/tests/unit/base.py +++ b/st2common/tests/unit/base.py @@ -24,13 +24,11 @@ from st2common.exceptions.db import StackStormDBObjectNotFoundError __all__ = [ - 'BaseDBModelCRUDTestCase', - - 'FakeModel', - 'FakeModelDB', - - 'ChangeRevFakeModel', - 'ChangeRevFakeModelDB' + "BaseDBModelCRUDTestCase", + "FakeModel", + "FakeModelDB", + "ChangeRevFakeModel", + "ChangeRevFakeModelDB", ] @@ -57,19 +55,26 @@ def test_crud_operations(self): self.assertEqual(getattr(retrieved_db, attribute_name), attribute_value) # 2. Test update - updated_attribute_value = 'updated-%s' % (str(time.time())) + updated_attribute_value = "updated-%s" % (str(time.time())) setattr(model_db, self.update_attribute_name, updated_attribute_value) saved_db = self.persistance_class.add_or_update(model_db) - self.assertEqual(getattr(saved_db, self.update_attribute_name), updated_attribute_value) + self.assertEqual( + getattr(saved_db, self.update_attribute_name), updated_attribute_value + ) retrieved_db = self.persistance_class.get_by_id(saved_db.id) self.assertEqual(saved_db.id, retrieved_db.id) - self.assertEqual(getattr(retrieved_db, self.update_attribute_name), updated_attribute_value) + self.assertEqual( + getattr(retrieved_db, self.update_attribute_name), updated_attribute_value + ) # 3. Test delete self.persistance_class.delete(model_db) - self.assertRaises(StackStormDBObjectNotFoundError, self.persistance_class.get_by_id, - model_db.id) + self.assertRaises( + StackStormDBObjectNotFoundError, + self.persistance_class.get_by_id, + model_db.id, + ) class FakeModelDB(stormbase.StormBaseDB): @@ -79,11 +84,11 @@ class FakeModelDB(stormbase.StormBaseDB): timestamp = mongoengine.DateTimeField() meta = { - 'indexes': [ - {'fields': ['index']}, - {'fields': ['category']}, - {'fields': ['timestamp']}, - {'fields': ['context.user']}, + "indexes": [ + {"fields": ["index"]}, + {"fields": ["category"]}, + {"fields": ["timestamp"]}, + {"fields": ["context.user"]}, ] } diff --git a/st2common/tests/unit/services/test_access.py b/st2common/tests/unit/services/test_access.py index 79e680b30df..4f7d8169b4c 100644 --- a/st2common/tests/unit/services/test_access.py +++ b/st2common/tests/unit/services/test_access.py @@ -28,11 +28,10 @@ import st2tests.config as tests_config -USERNAME = 'manas' +USERNAME = "manas" class AccessServiceTest(DbTestCase): - @classmethod def setUpClass(cls): super(AccessServiceTest, cls).setUpClass() @@ -47,7 +46,7 @@ def test_create_token(self): def test_create_token_fail(self): try: access.create_token(None) - self.assertTrue(False, 'Create succeeded was expected to fail.') + self.assertTrue(False, "Create succeeded was expected to fail.") except ValueError: self.assertTrue(True) @@ -56,7 +55,7 @@ def test_delete_token(self): access.delete_token(token.token) try: token = Token.get(token.token) - self.assertTrue(False, 'Delete failed was expected to pass.') + self.assertTrue(False, "Delete failed was expected to pass.") except TokenNotFoundError: self.assertTrue(True) @@ -71,13 +70,17 @@ def test_create_token_ttl_ok(self): self.assertIsNotNone(token) self.assertIsNotNone(token.token) self.assertEqual(token.user, USERNAME) - expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl) + expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=ttl + ) expected_expiry = date_utils.add_utc_tz(expected_expiry) self.assertLess(isotime.parse(token.expiry), expected_expiry) def test_create_token_ttl_capped(self): ttl = cfg.CONF.auth.token_ttl + 10 - expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl) + expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=ttl + ) expected_expiry = date_utils.add_utc_tz(expected_expiry) token = access.create_token(USERNAME, 10) self.assertIsNotNone(token) @@ -86,11 +89,13 @@ def test_create_token_ttl_capped(self): self.assertLess(isotime.parse(token.expiry), expected_expiry) def test_create_token_service_token_can_use_arbitrary_ttl(self): - ttl = (10000 * 24 * 24) + ttl = 10000 * 24 * 24 # Service token should support arbitrary TTL token = access.create_token(USERNAME, ttl=ttl, service=True) - expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta(seconds=ttl) + expected_expiry = date_utils.get_datetime_utc_now() + datetime.timedelta( + seconds=ttl + ) expected_expiry = date_utils.add_utc_tz(expected_expiry) self.assertIsNotNone(token) @@ -98,5 +103,6 @@ def test_create_token_service_token_can_use_arbitrary_ttl(self): self.assertLess(isotime.parse(token.expiry), expected_expiry) # Non service token should throw on TTL which is too large - self.assertRaises(TTLTooLargeException, access.create_token, USERNAME, ttl=ttl, - service=False) + self.assertRaises( + TTLTooLargeException, access.create_token, USERNAME, ttl=ttl, service=False + ) diff --git a/st2common/tests/unit/services/test_action.py b/st2common/tests/unit/services/test_action.py index 7bda929cc0c..ab8db723298 100644 --- a/st2common/tests/unit/services/test_action.py +++ b/st2common/tests/unit/services/test_action.py @@ -39,145 +39,126 @@ RUNNER = { - 'name': 'local-shell-script', - 'description': 'A runner to execute local command.', - 'enabled': True, - 'runner_parameters': { - 'hosts': {'type': 'string'}, - 'cmd': {'type': 'string'}, - 'sudo': {'type': 'boolean', 'default': False} + "name": "local-shell-script", + "description": "A runner to execute local command.", + "enabled": True, + "runner_parameters": { + "hosts": {"type": "string"}, + "cmd": {"type": "string"}, + "sudo": {"type": "boolean", "default": False}, }, - 'runner_module': 'remoterunner' + "runner_module": "remoterunner", } RUNNER_ACTION_CHAIN = { - 'name': 'action-chain', - 'description': 'AC runner.', - 'enabled': True, - 'runner_parameters': { - }, - 'runner_module': 'remoterunner' + "name": "action-chain", + "description": "AC runner.", + "enabled": True, + "runner_parameters": {}, + "runner_module": "remoterunner", } ACTION = { - 'name': 'my.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'arg_default_value': { - 'type': 'string', - 'default': 'abc' - }, - 'arg_default_type': { - } + "name": "my.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": { + "arg_default_value": {"type": "string", "default": "abc"}, + "arg_default_type": {}, }, - 'notify': { - 'on-complete': { - 'message': 'My awesome action is complete. Party time!!!', - 'routes': ['notify.slack'] + "notify": { + "on-complete": { + "message": "My awesome action is complete. Party time!!!", + "routes": ["notify.slack"], } - } + }, } ACTION_WORKFLOW = { - 'name': 'my.wf_action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'action-chain' + "name": "my.wf_action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "action-chain", } ACTION_OVR_PARAM = { - 'name': 'my.sudo.default.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'sudo': { - 'default': True - } - } + "name": "my.sudo.default.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": {"sudo": {"default": True}}, } ACTION_OVR_PARAM_MUTABLE = { - 'name': 'my.sudo.mutable.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'sudo': { - 'immutable': False - } - } + "name": "my.sudo.mutable.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": {"sudo": {"immutable": False}}, } ACTION_OVR_PARAM_IMMUTABLE = { - 'name': 'my.sudo.immutable.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'sudo': { - 'immutable': True - } - } + "name": "my.sudo.immutable.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": {"sudo": {"immutable": True}}, } ACTION_OVR_PARAM_BAD_ATTR = { - 'name': 'my.sudo.invalid.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'sudo': { - 'type': 'number' - } - } + "name": "my.sudo.invalid.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": {"sudo": {"type": "number"}}, } ACTION_OVR_PARAM_BAD_ATTR_NOOP = { - 'name': 'my.sudo.invalid.noop.action', - 'description': 'my test', - 'enabled': True, - 'entry_point': '/tmp/test/action.sh', - 'pack': 'default', - 'runner_type': 'local-shell-script', - 'parameters': { - 'sudo': { - 'type': 'boolean' - } - } + "name": "my.sudo.invalid.noop.action", + "description": "my test", + "enabled": True, + "entry_point": "/tmp/test/action.sh", + "pack": "default", + "runner_type": "local-shell-script", + "parameters": {"sudo": {"type": "boolean"}}, } -PACK = 'default' -ACTION_REF = ResourceReference(name='my.action', pack=PACK).ref -ACTION_WORKFLOW_REF = ResourceReference(name='my.wf_action', pack=PACK).ref -ACTION_OVR_PARAM_REF = ResourceReference(name='my.sudo.default.action', pack=PACK).ref -ACTION_OVR_PARAM_MUTABLE_REF = ResourceReference(name='my.sudo.mutable.action', pack=PACK).ref -ACTION_OVR_PARAM_IMMUTABLE_REF = ResourceReference(name='my.sudo.immutable.action', pack=PACK).ref -ACTION_OVR_PARAM_BAD_ATTR_REF = ResourceReference(name='my.sudo.invalid.action', pack=PACK).ref +PACK = "default" +ACTION_REF = ResourceReference(name="my.action", pack=PACK).ref +ACTION_WORKFLOW_REF = ResourceReference(name="my.wf_action", pack=PACK).ref +ACTION_OVR_PARAM_REF = ResourceReference(name="my.sudo.default.action", pack=PACK).ref +ACTION_OVR_PARAM_MUTABLE_REF = ResourceReference( + name="my.sudo.mutable.action", pack=PACK +).ref +ACTION_OVR_PARAM_IMMUTABLE_REF = ResourceReference( + name="my.sudo.immutable.action", pack=PACK +).ref +ACTION_OVR_PARAM_BAD_ATTR_REF = ResourceReference( + name="my.sudo.invalid.action", pack=PACK +).ref ACTION_OVR_PARAM_BAD_ATTR_NOOP_REF = ResourceReference( - name='my.sudo.invalid.noop.action', pack=PACK).ref + name="my.sudo.invalid.noop.action", pack=PACK +).ref -USERNAME = 'stanley' +USERNAME = "stanley" -@mock.patch.object(runners_utils, 'invoke_post_run', mock.MagicMock(return_value=None)) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(runners_utils, "invoke_post_run", mock.MagicMock(return_value=None)) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class TestActionExecutionService(DbTestCase): - @classmethod def setUpClass(cls): super(TestActionExecutionService, cls).setUpClass() @@ -188,17 +169,21 @@ def setUpClass(cls): RunnerType.add_or_update(RunnerTypeAPI.to_model(runner_api)) cls.actions = { - ACTION['name']: ActionAPI(**ACTION), - ACTION_WORKFLOW['name']: ActionAPI(**ACTION_WORKFLOW), - ACTION_OVR_PARAM['name']: ActionAPI(**ACTION_OVR_PARAM), - ACTION_OVR_PARAM_MUTABLE['name']: ActionAPI(**ACTION_OVR_PARAM_MUTABLE), - ACTION_OVR_PARAM_IMMUTABLE['name']: ActionAPI(**ACTION_OVR_PARAM_IMMUTABLE), - ACTION_OVR_PARAM_BAD_ATTR['name']: ActionAPI(**ACTION_OVR_PARAM_BAD_ATTR), - ACTION_OVR_PARAM_BAD_ATTR_NOOP['name']: ActionAPI(**ACTION_OVR_PARAM_BAD_ATTR_NOOP) + ACTION["name"]: ActionAPI(**ACTION), + ACTION_WORKFLOW["name"]: ActionAPI(**ACTION_WORKFLOW), + ACTION_OVR_PARAM["name"]: ActionAPI(**ACTION_OVR_PARAM), + ACTION_OVR_PARAM_MUTABLE["name"]: ActionAPI(**ACTION_OVR_PARAM_MUTABLE), + ACTION_OVR_PARAM_IMMUTABLE["name"]: ActionAPI(**ACTION_OVR_PARAM_IMMUTABLE), + ACTION_OVR_PARAM_BAD_ATTR["name"]: ActionAPI(**ACTION_OVR_PARAM_BAD_ATTR), + ACTION_OVR_PARAM_BAD_ATTR_NOOP["name"]: ActionAPI( + **ACTION_OVR_PARAM_BAD_ATTR_NOOP + ), } - cls.actiondbs = {name: Action.add_or_update(ActionAPI.to_model(action)) - for name, action in six.iteritems(cls.actions)} + cls.actiondbs = { + name: Action.add_or_update(ActionAPI.to_model(action)) + for name, action in six.iteritems(cls.actions) + } cls.container = RunnerContainer() @@ -212,8 +197,8 @@ def tearDownClass(cls): super(TestActionExecutionService, cls).tearDownClass() def _submit_request(self, action_ref=ACTION_REF): - context = {'user': USERNAME} - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + context = {"user": USERNAME} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} req = LiveActionDB(action=action_ref, context=context, parameters=parameters) req, _ = action_service.request(req) ex = action_db.get_liveaction_by_id(str(req.id)) @@ -249,7 +234,7 @@ def _create_nested_executions(self, depth=2): root_liveaction_db = LiveAction.add_or_update(root_liveaction_db) root_ex = executions.create_execution_object(root_liveaction_db) - last_id = root_ex['id'] + last_id = root_ex["id"] # Create children to the specified depth for i in range(depth): @@ -264,11 +249,7 @@ def _create_nested_executions(self, depth=2): child_liveaction_db = LiveActionDB() child_liveaction_db.status = action_constants.LIVEACTION_STATUS_PAUSED child_liveaction_db.action = action - child_liveaction_db.context = { - "parent": { - "execution_id": last_id - } - } + child_liveaction_db.context = {"parent": {"execution_id": last_id}} child_liveaction_db = LiveAction.add_or_update(child_liveaction_db) parent_ex = executions.create_execution_object(child_liveaction_db) last_id = parent_ex.id @@ -277,104 +258,116 @@ def _create_nested_executions(self, depth=2): return (child_liveaction_db, root_liveaction_db) def test_req_non_workflow_action(self): - actiondb = self.actiondbs[ACTION['name']] + actiondb = self.actiondbs[ACTION["name"]] req, ex = self._submit_request(action_ref=ACTION_REF) self.assertIsNotNone(ex) self.assertEqual(ex.action_is_workflow, False) self.assertEqual(ex.id, req.id) - self.assertEqual(ex.action, '.'.join([actiondb.pack, actiondb.name])) - self.assertEqual(ex.context['user'], req.context['user']) + self.assertEqual(ex.action, ".".join([actiondb.pack, actiondb.name])) + self.assertEqual(ex.context["user"], req.context["user"]) self.assertDictEqual(ex.parameters, req.parameters) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) self.assertIsNotNone(ex.notify) # mongoengine DateTimeField stores datetime only up to milliseconds - self.assertEqual(isotime.format(ex.start_timestamp, usec=False), - isotime.format(req.start_timestamp, usec=False)) + self.assertEqual( + isotime.format(ex.start_timestamp, usec=False), + isotime.format(req.start_timestamp, usec=False), + ) def test_req_workflow_action(self): - actiondb = self.actiondbs[ACTION_WORKFLOW['name']] + actiondb = self.actiondbs[ACTION_WORKFLOW["name"]] req, ex = self._submit_request(action_ref=ACTION_WORKFLOW_REF) self.assertIsNotNone(ex) self.assertEqual(ex.action_is_workflow, True) self.assertEqual(ex.id, req.id) - self.assertEqual(ex.action, '.'.join([actiondb.pack, actiondb.name])) - self.assertEqual(ex.context['user'], req.context['user']) + self.assertEqual(ex.action, ".".join([actiondb.pack, actiondb.name])) + self.assertEqual(ex.context["user"], req.context["user"]) self.assertDictEqual(ex.parameters, req.parameters) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) def test_req_invalid_parameters(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'arg_default_value': 123} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "arg_default_value": 123} liveaction = LiveActionDB(action=ACTION_REF, parameters=parameters) - self.assertRaises(jsonschema.ValidationError, action_service.request, liveaction) + self.assertRaises( + jsonschema.ValidationError, action_service.request, liveaction + ) def test_req_optional_parameter_none_value(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'arg_default_value': None} + parameters = { + "hosts": "127.0.0.1", + "cmd": "uname -a", + "arg_default_value": None, + } req = LiveActionDB(action=ACTION_REF, parameters=parameters) req, _ = action_service.request(req) def test_req_optional_parameter_none_value_no_default(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'arg_default_type': None} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "arg_default_type": None} req = LiveActionDB(action=ACTION_REF, parameters=parameters) req, _ = action_service.request(req) def test_req_override_runner_parameter(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} req = LiveActionDB(action=ACTION_OVR_PARAM_REF, parameters=parameters) req, _ = action_service.request(req) - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'sudo': False} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "sudo": False} req = LiveActionDB(action=ACTION_OVR_PARAM_REF, parameters=parameters) req, _ = action_service.request(req) def test_req_override_runner_parameter_type_attribute_value_changed(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} req = LiveActionDB(action=ACTION_OVR_PARAM_BAD_ATTR_REF, parameters=parameters) with self.assertRaises(action_exc.InvalidActionParameterException) as ex_ctx: req, _ = action_service.request(req) - expected = ('The attribute "type" for the runner parameter "sudo" in ' - 'action "default.my.sudo.invalid.action" cannot be overridden.') + expected = ( + 'The attribute "type" for the runner parameter "sudo" in ' + 'action "default.my.sudo.invalid.action" cannot be overridden.' + ) self.assertEqual(str(ex_ctx.exception), expected) def test_req_override_runner_parameter_type_attribute_no_value_changed(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} - req = LiveActionDB(action=ACTION_OVR_PARAM_BAD_ATTR_NOOP_REF, parameters=parameters) + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} + req = LiveActionDB( + action=ACTION_OVR_PARAM_BAD_ATTR_NOOP_REF, parameters=parameters + ) req, _ = action_service.request(req) def test_req_override_runner_parameter_mutable(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} req = LiveActionDB(action=ACTION_OVR_PARAM_MUTABLE_REF, parameters=parameters) req, _ = action_service.request(req) - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'sudo': True} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "sudo": True} req = LiveActionDB(action=ACTION_OVR_PARAM_MUTABLE_REF, parameters=parameters) req, _ = action_service.request(req) def test_req_override_runner_parameter_immutable(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} req = LiveActionDB(action=ACTION_OVR_PARAM_IMMUTABLE_REF, parameters=parameters) req, _ = action_service.request(req) - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a', 'sudo': True} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a", "sudo": True} req = LiveActionDB(action=ACTION_OVR_PARAM_IMMUTABLE_REF, parameters=parameters) self.assertRaises(ValueError, action_service.request, req) def test_req_nonexistent_action(self): - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} - action_ref = ResourceReference(name='i.action', pack='default').ref + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} + action_ref = ResourceReference(name="i.action", pack="default").ref ex = LiveActionDB(action=action_ref, parameters=parameters) self.assertRaises(ValueError, action_service.request, ex) def test_req_disabled_action(self): - actiondb = self.actiondbs[ACTION['name']] + actiondb = self.actiondbs[ACTION["name"]] actiondb.enabled = False Action.add_or_update(actiondb) try: - parameters = {'hosts': '127.0.0.1', 'cmd': 'uname -a'} + parameters = {"hosts": "127.0.0.1", "cmd": "uname -a"} ex = LiveActionDB(action=ACTION_REF, parameters=parameters) self.assertRaises(ValueError, action_service.request, ex) except Exception as e: @@ -390,7 +383,9 @@ def test_req_cancellation(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -405,7 +400,9 @@ def test_req_cancellation_uncancelable_state(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to FAILED. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_FAILED, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_FAILED, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_FAILED) @@ -429,20 +426,20 @@ def test_req_pause_unsupported(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request pause. self.assertRaises( - runner_exc.InvalidActionRunnerOperationError, - self._submit_pause, - ex + runner_exc.InvalidActionRunnerOperationError, self._submit_pause, ex ) def test_req_pause(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -451,7 +448,9 @@ def test_req_pause(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -459,11 +458,11 @@ def test_req_pause(self): ex = self._submit_pause(ex) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_req_pause_not_running(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -473,16 +472,14 @@ def test_req_pause_not_running(self): # Request pause. self.assertRaises( - runner_exc.UnexpectedActionExecutionStatusError, - self._submit_pause, - ex + runner_exc.UnexpectedActionExecutionStatusError, self._submit_pause, ex ) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_req_pause_already_pausing(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -491,7 +488,9 @@ def test_req_pause_already_pausing(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -500,12 +499,14 @@ def test_req_pause_already_pausing(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING) # Request pause again. - with mock.patch.object(action_service, 'update_status', return_value=None) as mocked: + with mock.patch.object( + action_service, "update_status", return_value=None + ) as mocked: ex = self._submit_pause(ex) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING) mocked.assert_not_called() finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_req_resume_unsupported(self): req, ex = self._submit_request() @@ -514,20 +515,20 @@ def test_req_resume_unsupported(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request resume. self.assertRaises( - runner_exc.InvalidActionRunnerOperationError, - self._submit_resume, - ex + runner_exc.InvalidActionRunnerOperationError, self._submit_resume, ex ) def test_req_resume(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -536,7 +537,9 @@ def test_req_resume(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -545,7 +548,9 @@ def test_req_resume(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSING) # Update ex status to PAUSED. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_PAUSED, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_PAUSED, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_PAUSED) @@ -553,11 +558,11 @@ def test_req_resume(self): ex = self._submit_resume(ex) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RESUMING) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_req_resume_not_paused(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -566,7 +571,9 @@ def test_req_resume_not_paused(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) @@ -576,16 +583,14 @@ def test_req_resume_not_paused(self): # Request resume. self.assertRaises( - runner_exc.UnexpectedActionExecutionStatusError, - self._submit_resume, - ex + runner_exc.UnexpectedActionExecutionStatusError, self._submit_resume, ex ) finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_req_resume_already_running(self): # Add the runner type to the list of runners that support pause and resume. - action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.append(ACTION["runner_type"]) try: req, ex = self._submit_request() @@ -594,25 +599,28 @@ def test_req_resume_already_running(self): self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_REQUESTED) # Update ex status to RUNNING. - action_service.update_status(ex, action_constants.LIVEACTION_STATUS_RUNNING, False) + action_service.update_status( + ex, action_constants.LIVEACTION_STATUS_RUNNING, False + ) ex = action_db.get_liveaction_by_id(ex.id) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) # Request resume. - with mock.patch.object(action_service, 'update_status', return_value=None) as mocked: + with mock.patch.object( + action_service, "update_status", return_value=None + ) as mocked: ex = self._submit_resume(ex) self.assertEqual(ex.status, action_constants.LIVEACTION_STATUS_RUNNING) mocked.assert_not_called() finally: - action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION['runner_type']) + action_constants.WORKFLOW_RUNNER_TYPES.remove(ACTION["runner_type"]) def test_root_liveaction(self): - """Test that get_root_liveaction correctly retrieves the root liveaction - """ + """Test that get_root_liveaction correctly retrieves the root liveaction""" # Test a variety of depths for i in range(1, 7): child, expected_root = self._create_nested_executions(depth=i) actual_root = action_service.get_root_liveaction(child) - self.assertEqual(expected_root['id'], actual_root['id']) + self.assertEqual(expected_root["id"], actual_root["id"]) diff --git a/st2common/tests/unit/services/test_keyvalue.py b/st2common/tests/unit/services/test_keyvalue.py index a11a3bb11b0..bd080719bb4 100644 --- a/st2common/tests/unit/services/test_keyvalue.py +++ b/st2common/tests/unit/services/test_keyvalue.py @@ -22,17 +22,22 @@ class KeyValueServicesTest(unittest2.TestCase): - def test_get_key_reference_system_scope(self): - ref = get_key_reference(scope=SYSTEM_SCOPE, name='foo') - self.assertEqual(ref, 'foo') + ref = get_key_reference(scope=SYSTEM_SCOPE, name="foo") + self.assertEqual(ref, "foo") def test_get_key_reference_user_scope(self): - ref = get_key_reference(scope=USER_SCOPE, name='foo', user='stanley') - self.assertEqual(ref, 'stanley:foo') - self.assertRaises(InvalidUserException, get_key_reference, - scope=USER_SCOPE, name='foo', user='') + ref = get_key_reference(scope=USER_SCOPE, name="foo", user="stanley") + self.assertEqual(ref, "stanley:foo") + self.assertRaises( + InvalidUserException, + get_key_reference, + scope=USER_SCOPE, + name="foo", + user="", + ) def test_get_key_reference_invalid_scope_raises_exception(self): - self.assertRaises(InvalidScopeException, get_key_reference, - scope='sketchy', name='foo') + self.assertRaises( + InvalidScopeException, get_key_reference, scope="sketchy", name="foo" + ) diff --git a/st2common/tests/unit/services/test_policy.py b/st2common/tests/unit/services/test_policy.py index 69fb0624e6f..128ce1defec 100644 --- a/st2common/tests/unit/services/test_policy.py +++ b/st2common/tests/unit/services/test_policy.py @@ -16,6 +16,7 @@ from __future__ import absolute_import import st2tests.config as tests_config + tests_config.parse_args() import st2common @@ -32,23 +33,22 @@ from st2tests import fixturesloader as fixtures -PACK = 'generic' +PACK = "generic" TEST_FIXTURES = { - 'actions': [ - 'action1.yaml', # wolfpack.action-1 - 'action2.yaml', # wolfpack.action-2 - 'local.yaml' # core.local + "actions": [ + "action1.yaml", # wolfpack.action-1 + "action2.yaml", # wolfpack.action-2 + "local.yaml", # core.local + ], + "policies": [ + "policy_2.yaml", # mock policy on wolfpack.action-1 + "policy_5.yaml", # concurrency policy on wolfpack.action-2 ], - 'policies': [ - 'policy_2.yaml', # mock policy on wolfpack.action-1 - 'policy_5.yaml' # concurrency policy on wolfpack.action-2 - ] } class PolicyServiceTestCase(st2tests.DbTestCase): - @classmethod def setUpClass(cls): super(PolicyServiceTestCase, cls).setUpClass() @@ -60,28 +60,39 @@ def setUpClass(cls): policies_registrar.register_policy_types(st2common) loader = fixtures.FixturesLoader() - loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) def setUp(self): super(PolicyServiceTestCase, self).setUp() - params = {'action': 'wolfpack.action-1', 'parameters': {'actionstr': 'foo-last'}} + params = { + "action": "wolfpack.action-1", + "parameters": {"actionstr": "foo-last"}, + } self.lv_ac_db_1 = action_db_models.LiveActionDB(**params) self.lv_ac_db_1, _ = action_service.request(self.lv_ac_db_1) - params = {'action': 'wolfpack.action-2', 'parameters': {'actionstr': 'foo-last'}} + params = { + "action": "wolfpack.action-2", + "parameters": {"actionstr": "foo-last"}, + } self.lv_ac_db_2 = action_db_models.LiveActionDB(**params) self.lv_ac_db_2, _ = action_service.request(self.lv_ac_db_2) - params = {'action': 'core.local', 'parameters': {'cmd': 'date'}} + params = {"action": "core.local", "parameters": {"cmd": "date"}} self.lv_ac_db_3 = action_db_models.LiveActionDB(**params) self.lv_ac_db_3, _ = action_service.request(self.lv_ac_db_3) def tearDown(self): - action_service.update_status(self.lv_ac_db_1, action_constants.LIVEACTION_STATUS_CANCELED) - action_service.update_status(self.lv_ac_db_2, action_constants.LIVEACTION_STATUS_CANCELED) - action_service.update_status(self.lv_ac_db_3, action_constants.LIVEACTION_STATUS_CANCELED) + action_service.update_status( + self.lv_ac_db_1, action_constants.LIVEACTION_STATUS_CANCELED + ) + action_service.update_status( + self.lv_ac_db_2, action_constants.LIVEACTION_STATUS_CANCELED + ) + action_service.update_status( + self.lv_ac_db_3, action_constants.LIVEACTION_STATUS_CANCELED + ) def test_action_has_policies(self): self.assertTrue(policy_service.has_policies(self.lv_ac_db_1)) @@ -93,7 +104,7 @@ def test_action_has_specific_policies(self): self.assertTrue( policy_service.has_policies( self.lv_ac_db_2, - policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK + policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK, ) ) @@ -101,6 +112,6 @@ def test_action_does_not_have_specific_policies(self): self.assertFalse( policy_service.has_policies( self.lv_ac_db_1, - policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK + policy_types=policy_constants.POLICY_TYPES_REQUIRING_LOCK, ) ) diff --git a/st2common/tests/unit/services/test_synchronization.py b/st2common/tests/unit/services/test_synchronization.py index 86cf36042fe..991e6b9036b 100644 --- a/st2common/tests/unit/services/test_synchronization.py +++ b/st2common/tests/unit/services/test_synchronization.py @@ -39,13 +39,15 @@ def tearDownClass(cls): super(SynchronizationTest, cls).tearDownClass() def test_service_configured(self): - cfg.CONF.set_override(name='url', override='kazoo://127.0.0.1:2181', group='coordination') + cfg.CONF.set_override( + name="url", override="kazoo://127.0.0.1:2181", group="coordination" + ) self.assertTrue(coordination.configured()) - cfg.CONF.set_override(name='url', override='file:///tmp', group='coordination') + cfg.CONF.set_override(name="url", override="file:///tmp", group="coordination") self.assertFalse(coordination.configured()) - cfg.CONF.set_override(name='url', override='zake://', group='coordination') + cfg.CONF.set_override(name="url", override="zake://", group="coordination") self.assertFalse(coordination.configured()) def test_lock(self): diff --git a/st2common/tests/unit/services/test_trace.py b/st2common/tests/unit/services/test_trace.py index 06c9260586b..807dc4251dc 100644 --- a/st2common/tests/unit/services/test_trace.py +++ b/st2common/tests/unit/services/test_trace.py @@ -30,33 +30,37 @@ from st2tests import DbTestCase -FIXTURES_PACK = 'traces' - -TEST_MODELS = OrderedDict(( - ('executions', [ - 'traceable_execution.yaml', - 'rule_fired_execution.yaml', - 'execution_with_parent.yaml' - ]), - ('liveactions', [ - 'traceable_liveaction.yaml', - 'liveaction_with_parent.yaml' - ]), - ('traces', [ - 'trace_empty.yaml', - 'trace_multiple_components.yaml', - 'trace_one_each.yaml', - 'trace_one_each_dup.yaml', - 'trace_execution.yaml' - ]), - ('triggers', ['trigger1.yaml']), - ('triggerinstances', [ - 'action_trigger.yaml', - 'notify_trigger.yaml', - 'non_internal_trigger.yaml' - ]), - ('rules', ['rule1.yaml']), -)) +FIXTURES_PACK = "traces" + +TEST_MODELS = OrderedDict( + ( + ( + "executions", + [ + "traceable_execution.yaml", + "rule_fired_execution.yaml", + "execution_with_parent.yaml", + ], + ), + ("liveactions", ["traceable_liveaction.yaml", "liveaction_with_parent.yaml"]), + ( + "traces", + [ + "trace_empty.yaml", + "trace_multiple_components.yaml", + "trace_one_each.yaml", + "trace_one_each_dup.yaml", + "trace_execution.yaml", + ], + ), + ("triggers", ["trigger1.yaml"]), + ( + "triggerinstances", + ["action_trigger.yaml", "notify_trigger.yaml", "non_internal_trigger.yaml"], + ), + ("rules", ["rule1.yaml"]), + ) +) class DummyComponent(object): @@ -78,139 +82,184 @@ class TestTraceService(DbTestCase): @classmethod def setUpClass(cls): super(TestTraceService, cls).setUpClass() - cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) - cls.trace1 = cls.models['traces']['trace_multiple_components.yaml'] - cls.trace2 = cls.models['traces']['trace_one_each.yaml'] - cls.trace3 = cls.models['traces']['trace_one_each_dup.yaml'] - cls.trace_empty = cls.models['traces']['trace_empty.yaml'] - cls.trace_execution = cls.models['traces']['trace_execution.yaml'] - - cls.action_trigger = cls.models['triggerinstances']['action_trigger.yaml'] - cls.notify_trigger = cls.models['triggerinstances']['notify_trigger.yaml'] - cls.non_internal_trigger = cls.models['triggerinstances']['non_internal_trigger.yaml'] - - cls.rule1 = cls.models['rules']['rule1.yaml'] - - cls.traceable_liveaction = cls.models['liveactions']['traceable_liveaction.yaml'] - cls.liveaction_with_parent = cls.models['liveactions']['liveaction_with_parent.yaml'] - cls.traceable_execution = cls.models['executions']['traceable_execution.yaml'] - cls.rule_fired_execution = cls.models['executions']['rule_fired_execution.yaml'] - cls.execution_with_parent = cls.models['executions']['execution_with_parent.yaml'] + cls.models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) + cls.trace1 = cls.models["traces"]["trace_multiple_components.yaml"] + cls.trace2 = cls.models["traces"]["trace_one_each.yaml"] + cls.trace3 = cls.models["traces"]["trace_one_each_dup.yaml"] + cls.trace_empty = cls.models["traces"]["trace_empty.yaml"] + cls.trace_execution = cls.models["traces"]["trace_execution.yaml"] + + cls.action_trigger = cls.models["triggerinstances"]["action_trigger.yaml"] + cls.notify_trigger = cls.models["triggerinstances"]["notify_trigger.yaml"] + cls.non_internal_trigger = cls.models["triggerinstances"][ + "non_internal_trigger.yaml" + ] + + cls.rule1 = cls.models["rules"]["rule1.yaml"] + + cls.traceable_liveaction = cls.models["liveactions"][ + "traceable_liveaction.yaml" + ] + cls.liveaction_with_parent = cls.models["liveactions"][ + "liveaction_with_parent.yaml" + ] + cls.traceable_execution = cls.models["executions"]["traceable_execution.yaml"] + cls.rule_fired_execution = cls.models["executions"]["rule_fired_execution.yaml"] + cls.execution_with_parent = cls.models["executions"][ + "execution_with_parent.yaml" + ] def test_get_trace_db_by_action_execution(self): - action_execution = DummyComponent(id_=self.trace1.action_executions[0].object_id) - trace_db = trace_service.get_trace_db_by_action_execution(action_execution=action_execution) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + action_execution = DummyComponent( + id_=self.trace1.action_executions[0].object_id + ) + trace_db = trace_service.get_trace_db_by_action_execution( + action_execution=action_execution + ) + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") def test_get_trace_db_by_action_execution_fail(self): - action_execution = DummyComponent(id_=self.trace2.action_executions[0].object_id) - self.assertRaises(UniqueTraceNotFoundException, - trace_service.get_trace_db_by_action_execution, - **{'action_execution': action_execution}) + action_execution = DummyComponent( + id_=self.trace2.action_executions[0].object_id + ) + self.assertRaises( + UniqueTraceNotFoundException, + trace_service.get_trace_db_by_action_execution, + **{"action_execution": action_execution}, + ) def test_get_trace_db_by_rule(self): rule = DummyComponent(id_=self.trace1.rules[0].object_id) trace_dbs = trace_service.get_trace_db_by_rule(rule=rule) - self.assertEqual(len(trace_dbs), 1, 'Expected 1 trace_db.') - self.assertEqual(trace_dbs[0].id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(len(trace_dbs), 1, "Expected 1 trace_db.") + self.assertEqual( + trace_dbs[0].id, self.trace1.id, "Incorrect trace_db returned." + ) def test_get_multiple_trace_db_by_rule(self): rule = DummyComponent(id_=self.trace2.rules[0].object_id) trace_dbs = trace_service.get_trace_db_by_rule(rule=rule) - self.assertEqual(len(trace_dbs), 2, 'Expected 2 trace_db.') + self.assertEqual(len(trace_dbs), 2, "Expected 2 trace_db.") result = [trace_db.id for trace_db in trace_dbs] - self.assertEqual(result, [self.trace2.id, self.trace3.id], 'Incorrect trace_dbs returned.') + self.assertEqual( + result, [self.trace2.id, self.trace3.id], "Incorrect trace_dbs returned." + ) def test_get_trace_db_by_trigger_instance(self): - trigger_instance = DummyComponent(id_=self.trace1.trigger_instances[0].object_id) - trace_db = trace_service.get_trace_db_by_trigger_instance(trigger_instance=trigger_instance) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + trigger_instance = DummyComponent( + id_=self.trace1.trigger_instances[0].object_id + ) + trace_db = trace_service.get_trace_db_by_trigger_instance( + trigger_instance=trigger_instance + ) + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") def test_get_trace_db_by_trigger_instance_fail(self): - trigger_instance = DummyComponent(id_=self.trace2.trigger_instances[0].object_id) - self.assertRaises(UniqueTraceNotFoundException, - trace_service.get_trace_db_by_trigger_instance, - **{'trigger_instance': trigger_instance}) + trigger_instance = DummyComponent( + id_=self.trace2.trigger_instances[0].object_id + ) + self.assertRaises( + UniqueTraceNotFoundException, + trace_service.get_trace_db_by_trigger_instance, + **{"trigger_instance": trigger_instance}, + ) def test_get_trace_by_dict(self): - trace_context = {'id_': str(self.trace1.id)} + trace_context = {"id_": str(self.trace1.id)} trace_db = trace_service.get_trace(trace_context) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") - trace_context = {'id_': str(bson.ObjectId())} - self.assertRaises(StackStormDBObjectNotFoundError, trace_service.get_trace, trace_context) + trace_context = {"id_": str(bson.ObjectId())} + self.assertRaises( + StackStormDBObjectNotFoundError, trace_service.get_trace, trace_context + ) - trace_context = {'trace_tag': self.trace1.trace_tag} + trace_context = {"trace_tag": self.trace1.trace_tag} trace_db = trace_service.get_trace(trace_context) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") def test_get_trace_by_trace_context(self): - trace_context = TraceContext(**{'id_': str(self.trace1.id)}) + trace_context = TraceContext(**{"id_": str(self.trace1.id)}) trace_db = trace_service.get_trace(trace_context) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") - trace_context = TraceContext(**{'trace_tag': self.trace1.trace_tag}) + trace_context = TraceContext(**{"trace_tag": self.trace1.trace_tag}) trace_db = trace_service.get_trace(trace_context) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") def test_get_trace_ignore_trace_tag(self): - trace_context = {'trace_tag': self.trace1.trace_tag} + trace_context = {"trace_tag": self.trace1.trace_tag} trace_db = trace_service.get_trace(trace_context) - self.assertEqual(trace_db.id, self.trace1.id, 'Incorrect trace_db returned.') + self.assertEqual(trace_db.id, self.trace1.id, "Incorrect trace_db returned.") - trace_context = {'trace_tag': self.trace1.trace_tag} + trace_context = {"trace_tag": self.trace1.trace_tag} trace_db = trace_service.get_trace(trace_context, ignore_trace_tag=True) - self.assertEqual(trace_db, None, 'Should be None.') + self.assertEqual(trace_db, None, "Should be None.") def test_get_trace_fail_empty_context(self): trace_context = {} self.assertRaises(ValueError, trace_service.get_trace, trace_context) def test_get_trace_fail_multi_match(self): - trace_context = {'trace_tag': self.trace2.trace_tag} - self.assertRaises(UniqueTraceNotFoundException, trace_service.get_trace, trace_context) + trace_context = {"trace_tag": self.trace2.trace_tag} + self.assertRaises( + UniqueTraceNotFoundException, trace_service.get_trace, trace_context + ) def test_get_trace_db_by_live_action_valid_id_context(self): traceable_liveaction = copy.copy(self.traceable_liveaction) - traceable_liveaction.context['trace_context'] = {'id_': str(self.trace_execution.id)} - created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction) + traceable_liveaction.context["trace_context"] = { + "id_": str(self.trace_execution.id) + } + created, trace_db = trace_service.get_trace_db_by_live_action( + traceable_liveaction + ) self.assertFalse(created) self.assertEqual(trace_db.id, self.trace_execution.id) def test_get_trace_db_by_live_action_trace_tag_context(self): traceable_liveaction = copy.copy(self.traceable_liveaction) - traceable_liveaction.context['trace_context'] = { - 'trace_tag': str(self.trace_execution.trace_tag) + traceable_liveaction.context["trace_context"] = { + "trace_tag": str(self.trace_execution.trace_tag) } - created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction) + created, trace_db = trace_service.get_trace_db_by_live_action( + traceable_liveaction + ) self.assertTrue(created) - self.assertEqual(trace_db.id, None, 'Expected to be None') + self.assertEqual(trace_db.id, None, "Expected to be None") self.assertEqual(trace_db.trace_tag, str(self.trace_execution.trace_tag)) def test_get_trace_db_by_live_action_parent(self): traceable_liveaction = copy.copy(self.traceable_liveaction) - traceable_liveaction.context['parent'] = { - 'execution_id': str(self.trace1.action_executions[0].object_id) + traceable_liveaction.context["parent"] = { + "execution_id": str(self.trace1.action_executions[0].object_id) } - created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction) + created, trace_db = trace_service.get_trace_db_by_live_action( + traceable_liveaction + ) self.assertFalse(created) self.assertEqual(trace_db.id, self.trace1.id) def test_get_trace_db_by_live_action_parent_fail(self): traceable_liveaction = copy.copy(self.traceable_liveaction) - traceable_liveaction.context['parent'] = { - 'execution_id': str(bson.ObjectId()) - } - self.assertRaises(StackStormDBObjectNotFoundError, - trace_service.get_trace_db_by_live_action, - traceable_liveaction) + traceable_liveaction.context["parent"] = {"execution_id": str(bson.ObjectId())} + self.assertRaises( + StackStormDBObjectNotFoundError, + trace_service.get_trace_db_by_live_action, + traceable_liveaction, + ) def test_get_trace_db_by_live_action_from_execution(self): traceable_liveaction = copy.copy(self.traceable_liveaction) # fixtures id value in liveaction is not persisted in DB. - traceable_liveaction.id = bson.ObjectId(self.traceable_execution.liveaction['id']) - created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction) + traceable_liveaction.id = bson.ObjectId( + self.traceable_execution.liveaction["id"] + ) + created, trace_db = trace_service.get_trace_db_by_live_action( + traceable_liveaction + ) self.assertFalse(created) self.assertEqual(trace_db.id, self.trace_execution.id) @@ -218,76 +267,119 @@ def test_get_trace_db_by_live_action_new_trace(self): traceable_liveaction = copy.copy(self.traceable_liveaction) # a liveaction without any associated ActionExecution traceable_liveaction.id = bson.ObjectId() - created, trace_db = trace_service.get_trace_db_by_live_action(traceable_liveaction) + created, trace_db = trace_service.get_trace_db_by_live_action( + traceable_liveaction + ) self.assertTrue(created) - self.assertEqual(trace_db.id, None, 'Should be None.') + self.assertEqual(trace_db.id, None, "Should be None.") def test_add_or_update_given_trace_context(self): - trace_context = {'id_': str(self.trace_empty.id)} - action_execution_id = 'action_execution_1' - rule_id = 'rule_1' - trigger_instance_id = 'trigger_instance_1' + trace_context = {"id_": str(self.trace_empty.id)} + action_execution_id = "action_execution_1" + rule_id = "rule_1" + trigger_instance_id = "trigger_instance_1" trace_service.add_or_update_given_trace_context( trace_context, action_executions=[action_execution_id], rules=[rule_id], - trigger_instances=[trigger_instance_id]) + trigger_instances=[trigger_instance_id], + ) retrieved_trace_db = Trace.get_by_id(self.trace_empty.id) - self.assertEqual(len(retrieved_trace_db.action_executions), 1, - 'Expected updated action_executions.') - self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id, - 'Expected updated action_executions.') - - self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.') - self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.') - - self.assertEqual(len(retrieved_trace_db.trigger_instances), 1, - 'Expected updated trigger_instances.') - self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id, - 'Expected updated trigger_instances.') + self.assertEqual( + len(retrieved_trace_db.action_executions), + 1, + "Expected updated action_executions.", + ) + self.assertEqual( + retrieved_trace_db.action_executions[0].object_id, + action_execution_id, + "Expected updated action_executions.", + ) + + self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.") + self.assertEqual( + retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules." + ) + + self.assertEqual( + len(retrieved_trace_db.trigger_instances), + 1, + "Expected updated trigger_instances.", + ) + self.assertEqual( + retrieved_trace_db.trigger_instances[0].object_id, + trigger_instance_id, + "Expected updated trigger_instances.", + ) Trace.delete(retrieved_trace_db) Trace.add_or_update(self.trace_empty) def test_add_or_update_given_trace_db(self): - action_execution_id = 'action_execution_1' - rule_id = 'rule_1' - trigger_instance_id = 'trigger_instance_1' + action_execution_id = "action_execution_1" + rule_id = "rule_1" + trigger_instance_id = "trigger_instance_1" to_save = copy.copy(self.trace_empty) to_save.id = None saved = trace_service.add_or_update_given_trace_db( to_save, action_executions=[action_execution_id], rules=[rule_id], - trigger_instances=[trigger_instance_id]) + trigger_instances=[trigger_instance_id], + ) retrieved_trace_db = Trace.get_by_id(saved.id) - self.assertEqual(len(retrieved_trace_db.action_executions), 1, - 'Expected updated action_executions.') - self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id, - 'Expected updated action_executions.') - - self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.') - self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.') - - self.assertEqual(len(retrieved_trace_db.trigger_instances), 1, - 'Expected updated trigger_instances.') - self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id, - 'Expected updated trigger_instances.') + self.assertEqual( + len(retrieved_trace_db.action_executions), + 1, + "Expected updated action_executions.", + ) + self.assertEqual( + retrieved_trace_db.action_executions[0].object_id, + action_execution_id, + "Expected updated action_executions.", + ) + + self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.") + self.assertEqual( + retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules." + ) + + self.assertEqual( + len(retrieved_trace_db.trigger_instances), + 1, + "Expected updated trigger_instances.", + ) + self.assertEqual( + retrieved_trace_db.trigger_instances[0].object_id, + trigger_instance_id, + "Expected updated trigger_instances.", + ) # Now add more TraceComponents and validated that they are added properly. saved = trace_service.add_or_update_given_trace_db( retrieved_trace_db, action_executions=[str(bson.ObjectId()), str(bson.ObjectId())], rules=[str(bson.ObjectId())], - trigger_instances=[str(bson.ObjectId()), str(bson.ObjectId()), str(bson.ObjectId())]) + trigger_instances=[ + str(bson.ObjectId()), + str(bson.ObjectId()), + str(bson.ObjectId()), + ], + ) retrieved_trace_db = Trace.get_by_id(saved.id) - self.assertEqual(len(retrieved_trace_db.action_executions), 3, - 'Expected updated action_executions.') - self.assertEqual(len(retrieved_trace_db.rules), 2, 'Expected updated rules.') - self.assertEqual(len(retrieved_trace_db.trigger_instances), 4, - 'Expected updated trigger_instances.') + self.assertEqual( + len(retrieved_trace_db.action_executions), + 3, + "Expected updated action_executions.", + ) + self.assertEqual(len(retrieved_trace_db.rules), 2, "Expected updated rules.") + self.assertEqual( + len(retrieved_trace_db.trigger_instances), + 4, + "Expected updated trigger_instances.", + ) Trace.delete(retrieved_trace_db) @@ -295,179 +387,238 @@ def test_add_or_update_given_trace_db_fail(self): self.assertRaises(ValueError, trace_service.add_or_update_given_trace_db, None) def test_add_or_update_given_trace_context_new(self): - trace_context = {'trace_tag': 'awesome_test_trace'} - action_execution_id = 'action_execution_1' - rule_id = 'rule_1' - trigger_instance_id = 'trigger_instance_1' + trace_context = {"trace_tag": "awesome_test_trace"} + action_execution_id = "action_execution_1" + rule_id = "rule_1" + trigger_instance_id = "trigger_instance_1" pre_add_or_update_traces = len(Trace.get_all()) trace_db = trace_service.add_or_update_given_trace_context( trace_context, action_executions=[action_execution_id], rules=[rule_id], - trigger_instances=[trigger_instance_id]) + trigger_instances=[trigger_instance_id], + ) post_add_or_update_traces = len(Trace.get_all()) - self.assertTrue(post_add_or_update_traces > pre_add_or_update_traces, - 'Expected new Trace to be created.') + self.assertTrue( + post_add_or_update_traces > pre_add_or_update_traces, + "Expected new Trace to be created.", + ) retrieved_trace_db = Trace.get_by_id(trace_db.id) - self.assertEqual(len(retrieved_trace_db.action_executions), 1, - 'Expected updated action_executions.') - self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id, - 'Expected updated action_executions.') - - self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.') - self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.') - - self.assertEqual(len(retrieved_trace_db.trigger_instances), 1, - 'Expected updated trigger_instances.') - self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id, - 'Expected updated trigger_instances.') + self.assertEqual( + len(retrieved_trace_db.action_executions), + 1, + "Expected updated action_executions.", + ) + self.assertEqual( + retrieved_trace_db.action_executions[0].object_id, + action_execution_id, + "Expected updated action_executions.", + ) + + self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.") + self.assertEqual( + retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules." + ) + + self.assertEqual( + len(retrieved_trace_db.trigger_instances), + 1, + "Expected updated trigger_instances.", + ) + self.assertEqual( + retrieved_trace_db.trigger_instances[0].object_id, + trigger_instance_id, + "Expected updated trigger_instances.", + ) Trace.delete(retrieved_trace_db) def test_add_or_update_given_trace_context_new_with_causals(self): - trace_context = {'trace_tag': 'causal_test_trace'} - action_execution_id = 'action_execution_1' - rule_id = 'rule_1' - trigger_instance_id = 'trigger_instance_1' + trace_context = {"trace_tag": "causal_test_trace"} + action_execution_id = "action_execution_1" + rule_id = "rule_1" + trigger_instance_id = "trigger_instance_1" pre_add_or_update_traces = len(Trace.get_all()) trace_db = trace_service.add_or_update_given_trace_context( trace_context, - action_executions=[{'id': action_execution_id, - 'caused_by': {'id': '%s:%s' % (rule_id, trigger_instance_id), - 'type': 'rule'}}], - rules=[{'id': rule_id, - 'caused_by': {'id': trigger_instance_id, 'type': 'trigger-instance'}}], - trigger_instances=[trigger_instance_id]) + action_executions=[ + { + "id": action_execution_id, + "caused_by": { + "id": "%s:%s" % (rule_id, trigger_instance_id), + "type": "rule", + }, + } + ], + rules=[ + { + "id": rule_id, + "caused_by": { + "id": trigger_instance_id, + "type": "trigger-instance", + }, + } + ], + trigger_instances=[trigger_instance_id], + ) post_add_or_update_traces = len(Trace.get_all()) - self.assertTrue(post_add_or_update_traces > pre_add_or_update_traces, - 'Expected new Trace to be created.') + self.assertTrue( + post_add_or_update_traces > pre_add_or_update_traces, + "Expected new Trace to be created.", + ) retrieved_trace_db = Trace.get_by_id(trace_db.id) - self.assertEqual(len(retrieved_trace_db.action_executions), 1, - 'Expected updated action_executions.') - self.assertEqual(retrieved_trace_db.action_executions[0].object_id, action_execution_id, - 'Expected updated action_executions.') - self.assertEqual(retrieved_trace_db.action_executions[0].caused_by, - {'id': '%s:%s' % (rule_id, trigger_instance_id), - 'type': 'rule'}, - 'Expected updated action_executions.') - - self.assertEqual(len(retrieved_trace_db.rules), 1, 'Expected updated rules.') - self.assertEqual(retrieved_trace_db.rules[0].object_id, rule_id, 'Expected updated rules.') - self.assertEqual(retrieved_trace_db.rules[0].caused_by, - {'id': trigger_instance_id, 'type': 'trigger-instance'}, - 'Expected updated rules.') - - self.assertEqual(len(retrieved_trace_db.trigger_instances), 1, - 'Expected updated trigger_instances.') - self.assertEqual(retrieved_trace_db.trigger_instances[0].object_id, trigger_instance_id, - 'Expected updated trigger_instances.') - self.assertEqual(retrieved_trace_db.trigger_instances[0].caused_by, {}, - 'Expected updated rules.') + self.assertEqual( + len(retrieved_trace_db.action_executions), + 1, + "Expected updated action_executions.", + ) + self.assertEqual( + retrieved_trace_db.action_executions[0].object_id, + action_execution_id, + "Expected updated action_executions.", + ) + self.assertEqual( + retrieved_trace_db.action_executions[0].caused_by, + {"id": "%s:%s" % (rule_id, trigger_instance_id), "type": "rule"}, + "Expected updated action_executions.", + ) + + self.assertEqual(len(retrieved_trace_db.rules), 1, "Expected updated rules.") + self.assertEqual( + retrieved_trace_db.rules[0].object_id, rule_id, "Expected updated rules." + ) + self.assertEqual( + retrieved_trace_db.rules[0].caused_by, + {"id": trigger_instance_id, "type": "trigger-instance"}, + "Expected updated rules.", + ) + + self.assertEqual( + len(retrieved_trace_db.trigger_instances), + 1, + "Expected updated trigger_instances.", + ) + self.assertEqual( + retrieved_trace_db.trigger_instances[0].object_id, + trigger_instance_id, + "Expected updated trigger_instances.", + ) + self.assertEqual( + retrieved_trace_db.trigger_instances[0].caused_by, + {}, + "Expected updated rules.", + ) Trace.delete(retrieved_trace_db) def test_trace_component_for_trigger_instance(self): # action_trigger trace_component = trace_service.get_trace_component_for_trigger_instance( - self.action_trigger) + self.action_trigger + ) expected = { - 'id': str(self.action_trigger.id), - 'ref': self.action_trigger.trigger, - 'caused_by': { - 'type': 'action_execution', - 'id': self.action_trigger.payload['execution_id'] - } + "id": str(self.action_trigger.id), + "ref": self.action_trigger.trigger, + "caused_by": { + "type": "action_execution", + "id": self.action_trigger.payload["execution_id"], + }, } self.assertEqual(trace_component, expected) # notify_trigger trace_component = trace_service.get_trace_component_for_trigger_instance( - self.notify_trigger) + self.notify_trigger + ) expected = { - 'id': str(self.notify_trigger.id), - 'ref': self.notify_trigger.trigger, - 'caused_by': { - 'type': 'action_execution', - 'id': self.notify_trigger.payload['execution_id'] - } + "id": str(self.notify_trigger.id), + "ref": self.notify_trigger.trigger, + "caused_by": { + "type": "action_execution", + "id": self.notify_trigger.payload["execution_id"], + }, } self.assertEqual(trace_component, expected) # non_internal_trigger trace_component = trace_service.get_trace_component_for_trigger_instance( - self.non_internal_trigger) + self.non_internal_trigger + ) expected = { - 'id': str(self.non_internal_trigger.id), - 'ref': self.non_internal_trigger.trigger, - 'caused_by': {} + "id": str(self.non_internal_trigger.id), + "ref": self.non_internal_trigger.trigger, + "caused_by": {}, } self.assertEqual(trace_component, expected) def test_trace_component_for_rule(self): - trace_component = trace_service.get_trace_component_for_rule(self.rule1, - self.non_internal_trigger) + trace_component = trace_service.get_trace_component_for_rule( + self.rule1, self.non_internal_trigger + ) expected = { - 'id': str(self.rule1.id), - 'ref': self.rule1.ref, - 'caused_by': { - 'type': 'trigger_instance', - 'id': str(self.non_internal_trigger.id) - } + "id": str(self.rule1.id), + "ref": self.rule1.ref, + "caused_by": { + "type": "trigger_instance", + "id": str(self.non_internal_trigger.id), + }, } self.assertEqual(trace_component, expected) def test_trace_component_for_action_execution(self): # no cause trace_component = trace_service.get_trace_component_for_action_execution( - self.traceable_execution, - self.traceable_liveaction) + self.traceable_execution, self.traceable_liveaction + ) expected = { - 'id': str(self.traceable_execution.id), - 'ref': self.traceable_execution.action['ref'], - 'caused_by': {} + "id": str(self.traceable_execution.id), + "ref": self.traceable_execution.action["ref"], + "caused_by": {}, } self.assertEqual(trace_component, expected) # rule_fired_execution trace_component = trace_service.get_trace_component_for_action_execution( - self.rule_fired_execution, - self.traceable_liveaction) + self.rule_fired_execution, self.traceable_liveaction + ) expected = { - 'id': str(self.rule_fired_execution.id), - 'ref': self.rule_fired_execution.action['ref'], - 'caused_by': { - 'type': 'rule', - 'id': '%s:%s' % (self.rule_fired_execution.rule['id'], - self.rule_fired_execution.trigger_instance['id']) - } + "id": str(self.rule_fired_execution.id), + "ref": self.rule_fired_execution.action["ref"], + "caused_by": { + "type": "rule", + "id": "%s:%s" + % ( + self.rule_fired_execution.rule["id"], + self.rule_fired_execution.trigger_instance["id"], + ), + }, } self.assertEqual(trace_component, expected) # execution_with_parent trace_component = trace_service.get_trace_component_for_action_execution( - self.execution_with_parent, - self.liveaction_with_parent) + self.execution_with_parent, self.liveaction_with_parent + ) expected = { - 'id': str(self.execution_with_parent.id), - 'ref': self.execution_with_parent.action['ref'], - 'caused_by': { - 'type': 'action_execution', - 'id': self.liveaction_with_parent.context['parent']['execution_id'] - } + "id": str(self.execution_with_parent.id), + "ref": self.execution_with_parent.action["ref"], + "caused_by": { + "type": "action_execution", + "id": self.liveaction_with_parent.context["parent"]["execution_id"], + }, } self.assertEqual(trace_component, expected) class TestTraceContext(TestCase): - def test_str_method(self): - trace_context = TraceContext(id_='id', trace_tag='tag') + trace_context = TraceContext(id_="id", trace_tag="tag") self.assertTrue(str(trace_context)) - trace_context = TraceContext(trace_tag='tag') + trace_context = TraceContext(trace_tag="tag") self.assertTrue(str(trace_context)) - trace_context = TraceContext(id_='id') + trace_context = TraceContext(id_="id") self.assertTrue(str(trace_context)) diff --git a/st2common/tests/unit/services/test_trace_injection_action_services.py b/st2common/tests/unit/services/test_trace_injection_action_services.py index 8f9570d0e28..4b4fe0d1775 100644 --- a/st2common/tests/unit/services/test_trace_injection_action_services.py +++ b/st2common/tests/unit/services/test_trace_injection_action_services.py @@ -21,13 +21,13 @@ from st2tests.fixturesloader import FixturesLoader from st2tests import DbTestCase -FIXTURES_PACK = 'traces' +FIXTURES_PACK = "traces" TEST_MODELS = { - 'executions': ['traceable_execution.yaml'], - 'liveactions': ['traceable_liveaction.yaml'], - 'actions': ['chain1.yaml'], - 'runners': ['actionchain.yaml'] + "executions": ["traceable_execution.yaml"], + "liveactions": ["traceable_liveaction.yaml"], + "actions": ["chain1.yaml"], + "runners": ["actionchain.yaml"], } @@ -41,44 +41,52 @@ class TraceInjectionTests(DbTestCase): @classmethod def setUpClass(cls): super(TraceInjectionTests, cls).setUpClass() - cls.models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) + cls.models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS + ) - cls.traceable_liveaction = cls.models['liveactions']['traceable_liveaction.yaml'] - cls.traceable_execution = cls.models['executions']['traceable_execution.yaml'] - cls.action = cls.models['actions']['chain1.yaml'] + cls.traceable_liveaction = cls.models["liveactions"][ + "traceable_liveaction.yaml" + ] + cls.traceable_execution = cls.models["executions"]["traceable_execution.yaml"] + cls.action = cls.models["actions"]["chain1.yaml"] def test_trace_provided(self): - self.traceable_liveaction['context']['trace_context'] = {'trace_tag': 'OohLaLaLa'} + self.traceable_liveaction["context"]["trace_context"] = { + "trace_tag": "OohLaLaLa" + } action_services.request(self.traceable_liveaction) traces = Trace.get_all() self.assertEqual(len(traces), 1) - self.assertEqual(len(traces[0]['action_executions']), 1) + self.assertEqual(len(traces[0]["action_executions"]), 1) # Let's use existing trace id in trace context. # We shouldn't create new trace object. trace_id = str(traces[0].id) - self.traceable_liveaction['context']['trace_context'] = {'id_': trace_id} + self.traceable_liveaction["context"]["trace_context"] = {"id_": trace_id} action_services.request(self.traceable_liveaction) traces = Trace.get_all() self.assertEqual(len(traces), 1) - self.assertEqual(len(traces[0]['action_executions']), 2) + self.assertEqual(len(traces[0]["action_executions"]), 2) def test_trace_tag_resuse(self): - self.traceable_liveaction['context']['trace_context'] = {'trace_tag': 'blank space'} + self.traceable_liveaction["context"]["trace_context"] = { + "trace_tag": "blank space" + } action_services.request(self.traceable_liveaction) # Let's use same trace tag again and we should see two trace objects in db. action_services.request(self.traceable_liveaction) - traces = Trace.query(**{'trace_tag': 'blank space'}) + traces = Trace.query(**{"trace_tag": "blank space"}) self.assertEqual(len(traces), 2) def test_invalid_trace_id_provided(self): liveactions = LiveAction.get_all() self.assertEqual(len(liveactions), 1) # fixtures loads it. - self.traceable_liveaction['context']['trace_context'] = {'id_': 'balleilaka'} + self.traceable_liveaction["context"]["trace_context"] = {"id_": "balleilaka"} - self.assertRaises(TraceNotFoundException, action_services.request, - self.traceable_liveaction) + self.assertRaises( + TraceNotFoundException, action_services.request, self.traceable_liveaction + ) # Make sure no liveactions are left behind liveactions = LiveAction.get_all() diff --git a/st2common/tests/unit/services/test_workflow.py b/st2common/tests/unit/services/test_workflow.py index 71cae679ba1..23bd4aca606 100644 --- a/st2common/tests/unit/services/test_workflow.py +++ b/st2common/tests/unit/services/test_workflow.py @@ -25,6 +25,7 @@ import st2tests import st2tests.config as tests_config + tests_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -43,33 +44,35 @@ from st2tests.mocks import liveaction as mock_lv_ac_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) -PACK_7 = 'dummy_pack_7' -PACK_7_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + PACK_7 +PACK_7 = "dummy_pack_7" +PACK_7_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + PACK_7 PACKS = [ TEST_PACK_PATH, PACK_7_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( publishers.CUDPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) class WorkflowExecutionServiceTest(st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(WorkflowExecutionServiceTest, cls).setUpClass() @@ -79,18 +82,17 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def test_request(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request the workflow execution. @@ -99,7 +101,9 @@ def test_request(self): wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx) # Check workflow execution is saved to the database. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) # Check required attributes. @@ -110,10 +114,12 @@ def test_request(self): self.assertEqual(wf_ex_db.status, wf_statuses.REQUESTED) def test_request_with_input(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name'], parameters={'who': 'stan'}) + lv_ac_db = lv_db_models.LiveActionDB( + action=wf_meta["name"], parameters={"who": "stan"} + ) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request the workflow execution. @@ -122,7 +128,9 @@ def test_request_with_input(self): wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx) # Check workflow execution is saved to the database. - wf_ex_dbs = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id)) + wf_ex_dbs = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + ) self.assertEqual(len(wf_ex_dbs), 1) # Check required attributes. @@ -133,18 +141,16 @@ def test_request_with_input(self): self.assertEqual(wf_ex_db.status, wf_statuses.REQUESTED) # Check input and context. - expected_input = { - 'who': 'stan' - } + expected_input = {"who": "stan"} self.assertDictEqual(wf_ex_db.input, expected_input) def test_request_bad_action(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the action execution object with the bad action. ac_ex_db = ex_db_models.ActionExecutionDB( - action={'ref': 'mock.foobar'}, runner={'name': 'foobar'} + action={"ref": "mock.foobar"}, runner={"name": "foobar"} ) # Request the workflow execution. @@ -153,14 +159,16 @@ def test_request_bad_action(self): workflow_service.request, self.get_wf_def(TEST_PACK_PATH, wf_meta), ac_ex_db, - self.mock_st2_context(ac_ex_db) + self.mock_st2_context(ac_ex_db), ) def test_request_wf_def_with_bad_action_ref(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-inspection-action-ref.yaml') + wf_meta = self.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-inspection-action-ref.yaml" + ) # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Exception is expected on request of workflow execution. @@ -169,14 +177,16 @@ def test_request_wf_def_with_bad_action_ref(self): workflow_service.request, self.get_wf_def(TEST_PACK_PATH, wf_meta), ac_ex_db, - self.mock_st2_context(ac_ex_db) + self.mock_st2_context(ac_ex_db), ) def test_request_wf_def_with_unregistered_action(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'fail-inspection-action-db.yaml') + wf_meta = self.get_wf_fixture_meta_data( + TEST_PACK_PATH, "fail-inspection-action-db.yaml" + ) # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Exception is expected on request of workflow execution. @@ -185,15 +195,15 @@ def test_request_wf_def_with_unregistered_action(self): workflow_service.request, self.get_wf_def(TEST_PACK_PATH, wf_meta), ac_ex_db, - self.mock_st2_context(ac_ex_db) + self.mock_st2_context(ac_ex_db), ) def test_request_wf_def_with_missing_required_action_param(self): - wf_name = 'fail-inspection-missing-required-action-param' - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') + wf_name = "fail-inspection-missing-required-action-param" + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Exception is expected on request of workflow execution. @@ -202,15 +212,15 @@ def test_request_wf_def_with_missing_required_action_param(self): workflow_service.request, self.get_wf_def(TEST_PACK_PATH, wf_meta), ac_ex_db, - self.mock_st2_context(ac_ex_db) + self.mock_st2_context(ac_ex_db), ) def test_request_wf_def_with_unexpected_action_param(self): - wf_name = 'fail-inspection-unexpected-action-param' - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + '.yaml') + wf_name = "fail-inspection-unexpected-action-param" + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, wf_name + ".yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Exception is expected on request of workflow execution. @@ -219,44 +229,46 @@ def test_request_wf_def_with_unexpected_action_param(self): workflow_service.request, self.get_wf_def(TEST_PACK_PATH, wf_meta), ac_ex_db, - self.mock_st2_context(ac_ex_db) + self.mock_st2_context(ac_ex_db), ) def test_request_task_execution(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request the workflow execution. wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta) st2_ctx = self.mock_st2_context(ac_ex_db) wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx) - spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog']) + spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"]) wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec) # Manually request task execution. task_route = 0 - task_id = 'task1' + task_id = "task1" task_spec = wf_spec.tasks.get_task(task_id) - task_ctx = {'foo': 'bar'} - st2_ctx = {'execution_id': wf_ex_db.action_execution} + task_ctx = {"foo": "bar"} + st2_ctx = {"execution_id": wf_ex_db.action_execution} task_ex_req = { - 'id': task_id, - 'route': task_route, - 'spec': task_spec, - 'ctx': task_ctx, - 'actions': [ - {'action': 'core.echo', 'input': {'message': 'Veni, vidi, vici.'}} - ] + "id": task_id, + "route": task_route, + "spec": task_spec, + "ctx": task_ctx, + "actions": [ + {"action": "core.echo", "input": {"message": "Veni, vidi, vici."}} + ], } workflow_service.request_task_execution(wf_ex_db, st2_ctx, task_ex_req) # Check task execution is saved to the database. - task_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + task_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(task_ex_dbs), 1) # Check required attributes. @@ -267,42 +279,46 @@ def test_request_task_execution(self): self.assertEqual(task_ex_db.status, wf_statuses.RUNNING) # Check action execution for the task query with task execution ID. - ac_ex_dbs = ex_db_access.ActionExecution.query(task_execution=str(task_ex_db.id)) + ac_ex_dbs = ex_db_access.ActionExecution.query( + task_execution=str(task_ex_db.id) + ) self.assertEqual(len(ac_ex_dbs), 1) # Check action execution for the task query with workflow execution ID. - ac_ex_dbs = ex_db_access.ActionExecution.query(workflow_execution=str(wf_ex_db.id)) + ac_ex_dbs = ex_db_access.ActionExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(ac_ex_dbs), 1) def test_request_task_execution_bad_action(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request the workflow execution. wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta) st2_ctx = self.mock_st2_context(ac_ex_db) wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx) - spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog']) + spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"]) wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec) # Manually request task execution. task_route = 0 - task_id = 'task1' + task_id = "task1" task_spec = wf_spec.tasks.get_task(task_id) - task_ctx = {'foo': 'bar'} - st2_ctx = {'execution_id': wf_ex_db.action_execution} + task_ctx = {"foo": "bar"} + st2_ctx = {"execution_id": wf_ex_db.action_execution} task_ex_req = { - 'id': task_id, - 'route': task_route, - 'spec': task_spec, - 'ctx': task_ctx, - 'actions': [ - {'action': 'mock.echo', 'input': {'message': 'Veni, vidi, vici.'}} - ] + "id": task_id, + "route": task_route, + "spec": task_spec, + "ctx": task_ctx, + "actions": [ + {"action": "mock.echo", "input": {"message": "Veni, vidi, vici."}} + ], } self.assertRaises( @@ -310,14 +326,14 @@ def test_request_task_execution_bad_action(self): workflow_service.request_task_execution, wf_ex_db, st2_ctx, - task_ex_req + task_ex_req, ) def test_handle_action_execution_completion(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request and pre-process the workflow execution. @@ -327,111 +343,124 @@ def test_handle_action_execution_completion(self): wf_ex_db = self.prep_wf_ex(wf_ex_db) # Manually request task execution. - self.run_workflow_step(wf_ex_db, 'task1', 0, ctx={'foo': 'bar'}) + self.run_workflow_step(wf_ex_db, "task1", 0, ctx={"foo": "bar"}) # Check that a new task is executed. - self.assert_task_running('task2', 0) + self.assert_task_running("task2", 0) def test_evaluate_action_execution_delay(self): - base_task_ex_req = {'task_id': 'task1', 'task_name': 'task1', 'route': 0} + base_task_ex_req = {"task_id": "task1", "task_name": "task1", "route": 0} # No task delay. task_ex_req = copy.deepcopy(base_task_ex_req) - ac_ex_req = {'action': 'core.noop', 'input': None} - actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req) + ac_ex_req = {"action": "core.noop", "input": None} + actual_delay = workflow_service.eval_action_execution_delay( + task_ex_req, ac_ex_req + ) self.assertIsNone(actual_delay) # Simple task delay. task_ex_req = copy.deepcopy(base_task_ex_req) - task_ex_req['delay'] = 180 - ac_ex_req = {'action': 'core.noop', 'input': None} - actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req) + task_ex_req["delay"] = 180 + ac_ex_req = {"action": "core.noop", "input": None} + actual_delay = workflow_service.eval_action_execution_delay( + task_ex_req, ac_ex_req + ) self.assertEqual(actual_delay, 180) # Task delay for with items task and with no concurrency. task_ex_req = copy.deepcopy(base_task_ex_req) - task_ex_req['delay'] = 180 - task_ex_req['concurrency'] = None - ac_ex_req = {'action': 'core.noop', 'input': None, 'items_id': 0} - actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req, True) + task_ex_req["delay"] = 180 + task_ex_req["concurrency"] = None + ac_ex_req = {"action": "core.noop", "input": None, "items_id": 0} + actual_delay = workflow_service.eval_action_execution_delay( + task_ex_req, ac_ex_req, True + ) self.assertEqual(actual_delay, 180) # Task delay for with items task, with concurrency, and evaluate first item. task_ex_req = copy.deepcopy(base_task_ex_req) - task_ex_req['delay'] = 180 - task_ex_req['concurrency'] = 1 - ac_ex_req = {'action': 'core.noop', 'input': None, 'item_id': 0} - actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req, True) + task_ex_req["delay"] = 180 + task_ex_req["concurrency"] = 1 + ac_ex_req = {"action": "core.noop", "input": None, "item_id": 0} + actual_delay = workflow_service.eval_action_execution_delay( + task_ex_req, ac_ex_req, True + ) self.assertEqual(actual_delay, 180) # Task delay for with items task, with concurrency, and evaluate later items. task_ex_req = copy.deepcopy(base_task_ex_req) - task_ex_req['delay'] = 180 - task_ex_req['concurrency'] = 1 - ac_ex_req = {'action': 'core.noop', 'input': None, 'item_id': 1} - actual_delay = workflow_service.eval_action_execution_delay(task_ex_req, ac_ex_req, True) + task_ex_req["delay"] = 180 + task_ex_req["concurrency"] = 1 + ac_ex_req = {"action": "core.noop", "input": None, "item_id": 1} + actual_delay = workflow_service.eval_action_execution_delay( + task_ex_req, ac_ex_req, True + ) self.assertIsNone(actual_delay) def test_request_action_execution_render(self): # Manually create ConfigDB - output = 'Testing' - value = { - "config_item_one": output - } + output = "Testing" + value = {"config_item_one": output} config_db = pk_db_models.ConfigDB(pack=PACK_7, values=value) config = pk_db_access.Config.add_or_update(config_db) self.assertEqual(len(config), 3) - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'render_config_context.yaml') + wf_meta = self.get_wf_fixture_meta_data( + TEST_PACK_PATH, "render_config_context.yaml" + ) # Manually create the liveaction and action execution objects without publishing. - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = action_service.create_request(lv_ac_db) # Request the workflow execution. wf_def = self.get_wf_def(TEST_PACK_PATH, wf_meta) st2_ctx = self.mock_st2_context(ac_ex_db) wf_ex_db = workflow_service.request(wf_def, ac_ex_db, st2_ctx) - spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog']) + spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"]) wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec) # Pass down appropriate st2 context to the task and action execution(s). - root_st2_ctx = wf_ex_db.context.get('st2', {}) + root_st2_ctx = wf_ex_db.context.get("st2", {}) st2_ctx = { - 'execution_id': wf_ex_db.action_execution, - 'user': root_st2_ctx.get('user'), - 'pack': root_st2_ctx.get('pack') + "execution_id": wf_ex_db.action_execution, + "user": root_st2_ctx.get("user"), + "pack": root_st2_ctx.get("pack"), } # Manually request task execution. task_route = 0 - task_id = 'task1' + task_id = "task1" task_spec = wf_spec.tasks.get_task(task_id) - task_ctx = {'foo': 'bar'} + task_ctx = {"foo": "bar"} task_ex_req = { - 'id': task_id, - 'route': task_route, - 'spec': task_spec, - 'ctx': task_ctx, - 'actions': [ - {'action': 'dummy_pack_7.render_config_context', 'input': None} - ] + "id": task_id, + "route": task_route, + "spec": task_spec, + "ctx": task_ctx, + "actions": [ + {"action": "dummy_pack_7.render_config_context", "input": None} + ], } workflow_service.request_task_execution(wf_ex_db, st2_ctx, task_ex_req) # Check task execution is saved to the database. - task_ex_dbs = wf_db_access.TaskExecution.query(workflow_execution=str(wf_ex_db.id)) + task_ex_dbs = wf_db_access.TaskExecution.query( + workflow_execution=str(wf_ex_db.id) + ) self.assertEqual(len(task_ex_dbs), 1) workflow_service.request_task_execution(wf_ex_db, st2_ctx, task_ex_req) # Manually request action execution task_ex_db = task_ex_dbs[0] - action_ex_db = workflow_service.request_action_execution(wf_ex_db, task_ex_db, st2_ctx, - task_ex_req['actions'][0]) + action_ex_db = workflow_service.request_action_execution( + wf_ex_db, task_ex_db, st2_ctx, task_ex_req["actions"][0] + ) # Check required attributes. self.assertIsNotNone(str(action_ex_db.id)) self.assertEqual(task_ex_db.workflow_execution, str(wf_ex_db.id)) - expected_parameters = {'value1': output} + expected_parameters = {"value1": output} self.assertEqual(expected_parameters, action_ex_db.parameters) diff --git a/st2common/tests/unit/services/test_workflow_cancellation.py b/st2common/tests/unit/services/test_workflow_cancellation.py index 26455971f09..22694924a31 100644 --- a/st2common/tests/unit/services/test_workflow_cancellation.py +++ b/st2common/tests/unit/services/test_workflow_cancellation.py @@ -22,6 +22,7 @@ import st2tests import st2tests.config as tests_config + tests_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -35,39 +36,35 @@ TEST_FIXTURES = { - 'workflows': [ - 'sequential.yaml', - 'join.yaml' - ], - 'actions': [ - 'sequential.yaml', - 'join.yaml' - ] + "workflows": ["sequential.yaml", "join.yaml"], + "actions": ["sequential.yaml", "join.yaml"], } -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( publishers.CUDPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) class WorkflowExecutionCancellationTest(st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(WorkflowExecutionCancellationTest, cls).setUpClass() @@ -77,8 +74,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -86,8 +82,10 @@ def setUpClass(cls): def test_cancellation(self): # Manually create the liveaction and action execution objects without publishing. - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, TEST_FIXTURES['workflows'][0]) - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data( + TEST_PACK_PATH, TEST_FIXTURES["workflows"][0] + ) + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.create_request(lv_ac_db) # Request and pre-process the workflow execution. @@ -98,8 +96,8 @@ def test_cancellation(self): # Manually request task executions. task_route = 0 - self.run_workflow_step(wf_ex_db, 'task1', task_route) - self.assert_task_running('task2', task_route) + self.run_workflow_step(wf_ex_db, "task1", task_route) + self.assert_task_running("task2", task_route) # Cancel the workflow when there is still active task(s). wf_ex_db = wf_svc.request_cancellation(ac_ex_db) @@ -108,8 +106,8 @@ def test_cancellation(self): self.assertEqual(wf_ex_db.status, wf_statuses.CANCELING) # Manually complete the task and ensure workflow is canceled. - self.run_workflow_step(wf_ex_db, 'task2', task_route) - self.assert_task_not_started('task3', task_route) + self.run_workflow_step(wf_ex_db, "task2", task_route) + self.assert_task_not_started("task3", task_route) conductor, wf_ex_db = wf_svc.refresh_conductor(str(wf_ex_db.id)) self.assertEqual(conductor.get_workflow_status(), wf_statuses.CANCELED) self.assertEqual(wf_ex_db.status, wf_statuses.CANCELED) diff --git a/st2common/tests/unit/services/test_workflow_identify_orphans.py b/st2common/tests/unit/services/test_workflow_identify_orphans.py index d45ba1527f0..306e22badd3 100644 --- a/st2common/tests/unit/services/test_workflow_identify_orphans.py +++ b/st2common/tests/unit/services/test_workflow_identify_orphans.py @@ -24,6 +24,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -48,42 +49,51 @@ LOG = logging.getLogger(__name__) -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class WorkflowServiceIdentifyOrphansTest(st2tests.WorkflowTestCase): ensure_indexes = True ensure_indexes_models = [ ex_db_models.ActionExecutionDB, lv_db_models.LiveActionDB, wf_db_models.WorkflowExecutionDB, - wf_db_models.TaskExecutionDB + wf_db_models.TaskExecutionDB, ] @classmethod @@ -95,8 +105,7 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: @@ -119,8 +128,9 @@ def tearDown(self): def mock_workflow_records(self, completed=False, expired=True, log=True): status = ( - ac_const.LIVEACTION_STATUS_SUCCEEDED if completed else - ac_const.LIVEACTION_STATUS_RUNNING + ac_const.LIVEACTION_STATUS_SUCCEEDED + if completed + else ac_const.LIVEACTION_STATUS_RUNNING ) # Identify start and end timestamp @@ -131,18 +141,24 @@ def mock_workflow_records(self, completed=False, expired=True, log=True): end_timestamp = utc_now_dt if completed else None # Assign metadata. - action_ref = 'orquesta_tests.sequential' - runner = 'orquesta' - user = 'stanley' + action_ref = "orquesta_tests.sequential" + runner = "orquesta" + user = "stanley" # Create the WorkflowExecutionDB record first since the ID needs to be # included in the LiveActionDB and ActionExecutionDB records. - st2_ctx = {'st2': {'action_execution_id': '123', 'action': 'foobar', 'runner': 'orquesta'}} + st2_ctx = { + "st2": { + "action_execution_id": "123", + "action": "foobar", + "runner": "orquesta", + } + } wf_ex_db = wf_db_models.WorkflowExecutionDB( context=st2_ctx, status=status, start_timestamp=start_timestamp, - end_timestamp=end_timestamp + end_timestamp=end_timestamp, ) wf_ex_db = wf_db_access.WorkflowExecution.insert(wf_ex_db, publish=False) @@ -152,13 +168,10 @@ def mock_workflow_records(self, completed=False, expired=True, log=True): workflow_execution=str(wf_ex_db.id), action=action_ref, action_is_workflow=True, - context={ - 'user': user, - 'workflow_execution': str(wf_ex_db.id) - }, + context={"user": user, "workflow_execution": str(wf_ex_db.id)}, status=status, start_timestamp=start_timestamp, - end_timestamp=end_timestamp + end_timestamp=end_timestamp, ) lv_ac_db = lv_db_access.LiveAction.insert(lv_ac_db, publish=False) @@ -166,30 +179,20 @@ def mock_workflow_records(self, completed=False, expired=True, log=True): # Create the ActionExecutionDB record. ac_ex_db = ex_db_models.ActionExecutionDB( workflow_execution=str(wf_ex_db.id), - action={ - 'runner_type': runner, - 'ref': action_ref - }, - runner={ - 'name': runner - }, - liveaction={ - 'id': str(lv_ac_db.id) - }, - context={ - 'user': user, - 'workflow_execution': str(wf_ex_db.id) - }, + action={"runner_type": runner, "ref": action_ref}, + runner={"name": runner}, + liveaction={"id": str(lv_ac_db.id)}, + context={"user": user, "workflow_execution": str(wf_ex_db.id)}, status=status, start_timestamp=start_timestamp, - end_timestamp=end_timestamp + end_timestamp=end_timestamp, ) if log: - ac_ex_db.log = [{'status': 'running', 'timestamp': start_timestamp}] + ac_ex_db.log = [{"status": "running", "timestamp": start_timestamp}] if log and status in ac_const.LIVEACTION_COMPLETED_STATES: - ac_ex_db.log.append({'status': status, 'timestamp': end_timestamp}) + ac_ex_db.log.append({"status": status, "timestamp": end_timestamp}) ac_ex_db = ex_db_access.ActionExecution.insert(ac_ex_db, publish=False) @@ -199,14 +202,16 @@ def mock_workflow_records(self, completed=False, expired=True, log=True): return wf_ex_db, lv_ac_db, ac_ex_db - def mock_task_records(self, parent, task_id, task_route=0, - completed=True, expired=False, log=True): + def mock_task_records( + self, parent, task_id, task_route=0, completed=True, expired=False, log=True + ): if not completed and expired: - raise ValueError('Task must be set completed=True if expired=True.') + raise ValueError("Task must be set completed=True if expired=True.") status = ( - ac_const.LIVEACTION_STATUS_SUCCEEDED if completed else - ac_const.LIVEACTION_STATUS_RUNNING + ac_const.LIVEACTION_STATUS_SUCCEEDED + if completed + else ac_const.LIVEACTION_STATUS_RUNNING ) parent_wf_ex_db, parent_ac_ex_db = parent[0], parent[2] @@ -218,9 +223,9 @@ def mock_task_records(self, parent, task_id, task_route=0, end_timestamp = expiry_dt if expired else utc_now_dt # Assign metadata. - action_ref = 'core.local' - runner = 'local-shell-cmd' - user = 'stanley' + action_ref = "core.local" + runner = "local-shell-cmd" + user = "stanley" # Create the TaskExecutionDB record first since the ID needs to be # included in the LiveActionDB and ActionExecutionDB records. @@ -229,7 +234,7 @@ def mock_task_records(self, parent, task_id, task_route=0, task_id=task_id, task_route=0, status=status, - start_timestamp=parent_wf_ex_db.start_timestamp + start_timestamp=parent_wf_ex_db.start_timestamp, ) if status in ac_const.LIVEACTION_COMPLETED_STATES: @@ -239,18 +244,15 @@ def mock_task_records(self, parent, task_id, task_route=0, # Build context for LiveActionDB and ActionExecutionDB. context = { - 'user': user, - 'orquesta': { - 'task_id': tk_ex_db.task_id, - 'task_name': tk_ex_db.task_id, - 'workflow_execution_id': str(parent_wf_ex_db.id), - 'task_execution_id': str(tk_ex_db.id), - 'task_route': tk_ex_db.task_route + "user": user, + "orquesta": { + "task_id": tk_ex_db.task_id, + "task_name": tk_ex_db.task_id, + "workflow_execution_id": str(parent_wf_ex_db.id), + "task_execution_id": str(tk_ex_db.id), + "task_route": tk_ex_db.task_route, }, - 'parent': { - 'user': user, - 'execution_id': str(parent_ac_ex_db.id) - } + "parent": {"user": user, "execution_id": str(parent_ac_ex_db.id)}, } # Create the LiveActionDB record. @@ -262,7 +264,7 @@ def mock_task_records(self, parent, task_id, task_route=0, context=context, status=status, start_timestamp=tk_ex_db.start_timestamp, - end_timestamp=tk_ex_db.end_timestamp + end_timestamp=tk_ex_db.end_timestamp, ) lv_ac_db = lv_db_access.LiveAction.insert(lv_ac_db, publish=False) @@ -271,27 +273,22 @@ def mock_task_records(self, parent, task_id, task_route=0, ac_ex_db = ex_db_models.ActionExecutionDB( workflow_execution=str(parent_wf_ex_db.id), task_execution=str(tk_ex_db.id), - action={ - 'runner_type': runner, - 'ref': action_ref - }, - runner={ - 'name': runner - }, - liveaction={ - 'id': str(lv_ac_db.id) - }, + action={"runner_type": runner, "ref": action_ref}, + runner={"name": runner}, + liveaction={"id": str(lv_ac_db.id)}, context=context, status=status, start_timestamp=tk_ex_db.start_timestamp, - end_timestamp=tk_ex_db.end_timestamp + end_timestamp=tk_ex_db.end_timestamp, ) if log: - ac_ex_db.log = [{'status': 'running', 'timestamp': tk_ex_db.start_timestamp}] + ac_ex_db.log = [ + {"status": "running", "timestamp": tk_ex_db.start_timestamp} + ] if log and status in ac_const.LIVEACTION_COMPLETED_STATES: - ac_ex_db.log.append({'status': status, 'timestamp': tk_ex_db.end_timestamp}) + ac_ex_db.log.append({"status": status, "timestamp": tk_ex_db.end_timestamp}) ac_ex_db = ex_db_access.ActionExecution.insert(ac_ex_db, publish=False) @@ -303,18 +300,18 @@ def test_no_orphans(self): # Workflow that is still running with task completed and not expired. wf_ex_set_2 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_2, 'task1', completed=True, expired=False) + self.mock_task_records(wf_ex_set_2, "task1", completed=True, expired=False) # Workflow that is still running with task running and not expired. wf_ex_set_3 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_3, 'task1', completed=False, expired=False) + self.mock_task_records(wf_ex_set_3, "task1", completed=False, expired=False) # Workflow that is completed and not expired. self.mock_workflow_records(completed=True, expired=False) # Workflow that is completed with task completed and not expired. wf_ex_set_5 = self.mock_workflow_records(completed=True, expired=False) - self.mock_task_records(wf_ex_set_5, 'task1', completed=True, expired=False) + self.mock_task_records(wf_ex_set_5, "task1", completed=True, expired=False) orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows() self.assertEqual(len(orphaned_ac_ex_dbs), 0) @@ -339,33 +336,33 @@ def test_identify_orphans_with_no_task_executions(self): def test_identify_orphans_with_task_executions(self): # Workflow that is still running with task completed and expired. wf_ex_set_1 = self.mock_workflow_records(completed=False, expired=True) - self.mock_task_records(wf_ex_set_1, 'task1', completed=True, expired=True) + self.mock_task_records(wf_ex_set_1, "task1", completed=True, expired=True) # Workflow that is still running with task completed and not expired. wf_ex_set_2 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_2, 'task1', completed=True, expired=False) + self.mock_task_records(wf_ex_set_2, "task1", completed=True, expired=False) # Workflow that is still running with task running and not expired. wf_ex_set_3 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_3, 'task1', completed=False, expired=False) + self.mock_task_records(wf_ex_set_3, "task1", completed=False, expired=False) # Workflow that is still running with multiple tasks and not expired. # One of the task completed passed expiry date but another task is still running. wf_ex_set_4 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_4, 'task1', completed=True, expired=True) - self.mock_task_records(wf_ex_set_4, 'task2', completed=False, expired=False) + self.mock_task_records(wf_ex_set_4, "task1", completed=True, expired=True) + self.mock_task_records(wf_ex_set_4, "task2", completed=False, expired=False) # Workflow that is still running with multiple tasks and not expired. # Both of the tasks are completed with one completed only recently. wf_ex_set_5 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_5, 'task1', completed=True, expired=True) - self.mock_task_records(wf_ex_set_5, 'task2', completed=True, expired=False) + self.mock_task_records(wf_ex_set_5, "task1", completed=True, expired=True) + self.mock_task_records(wf_ex_set_5, "task2", completed=True, expired=False) # Workflow that is still running with multiple tasks and not expired. # One of the task completed recently and another task is still running. wf_ex_set_6 = self.mock_workflow_records(completed=False, expired=False) - self.mock_task_records(wf_ex_set_6, 'task1', completed=True, expired=False) - self.mock_task_records(wf_ex_set_6, 'task2', completed=False, expired=False) + self.mock_task_records(wf_ex_set_6, "task1", completed=True, expired=False) + self.mock_task_records(wf_ex_set_6, "task2", completed=False, expired=False) orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows() self.assertEqual(len(orphaned_ac_ex_dbs), 1) @@ -373,8 +370,10 @@ def test_identify_orphans_with_task_executions(self): def test_action_execution_with_missing_log_entries(self): # Workflow that is still running and expired. However the state change logs are missing. - wf_ex_set_1 = self.mock_workflow_records(completed=False, expired=True, log=False) - self.mock_task_records(wf_ex_set_1, 'task1', completed=True, expired=True) + wf_ex_set_1 = self.mock_workflow_records( + completed=False, expired=True, log=False + ) + self.mock_task_records(wf_ex_set_1, "task1", completed=True, expired=True) orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows() self.assertEqual(len(orphaned_ac_ex_dbs), 0) @@ -385,7 +384,7 @@ def test_garbage_collection(self): # Workflow that is still running with task completed and expired. wf_ex_set_2 = self.mock_workflow_records(completed=False, expired=True) - self.mock_task_records(wf_ex_set_2, 'task1', completed=True, expired=True) + self.mock_task_records(wf_ex_set_2, "task1", completed=True, expired=True) # Ensure these workflows are identified as orphans. orphaned_ac_ex_dbs = wf_svc.identify_orphaned_workflows() diff --git a/st2common/tests/unit/services/test_workflow_rerun.py b/st2common/tests/unit/services/test_workflow_rerun.py index 6808991595b..f5ff2bc487a 100644 --- a/st2common/tests/unit/services/test_workflow_rerun.py +++ b/st2common/tests/unit/services/test_workflow_rerun.py @@ -24,6 +24,7 @@ import st2tests import st2tests.config as tests_config + tests_config.parse_args() from local_runner import local_shell_command_runner @@ -42,32 +43,38 @@ from st2tests.mocks import liveaction as mock_lv_ac_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] RUNNER_RESULT_FAILED = (action_constants.LIVEACTION_STATUS_FAILED, {}, {}) -RUNNER_RESULT_SUCCEEDED = (action_constants.LIVEACTION_STATUS_SUCCEEDED, {'stdout': 'foobar'}, {}) +RUNNER_RESULT_SUCCEEDED = ( + action_constants.LIVEACTION_STATUS_SUCCEEDED, + {"stdout": "foobar"}, + {}, +) @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( publishers.CUDPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) class WorkflowExecutionRerunTest(st2tests.WorkflowTestCase): - @classmethod def setUpClass(cls): super(WorkflowExecutionRerunTest, cls).setUpClass() @@ -77,18 +84,17 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) def prep_wf_ex_for_rerun(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db1, ac_ex_db1 = action_service.create_request(lv_ac_db1) # Request the workflow execution. @@ -99,9 +105,12 @@ def prep_wf_ex_for_rerun(self): # Fail workflow execution. self.run_workflow_step( - wf_ex_db, 'task1', 0, + wf_ex_db, + "task1", + 0, expected_ac_ex_db_status=action_constants.LIVEACTION_STATUS_FAILED, - expected_tk_ex_db_status=wf_statuses.FAILED) + expected_tk_ex_db_status=wf_statuses.FAILED, + ) # Check workflow status. conductor, wf_ex_db = workflow_service.refresh_conductor(str(wf_ex_db.id)) @@ -115,20 +124,22 @@ def prep_wf_ex_for_rerun(self): return wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]), + ) def test_request_rerun(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} wf_ex_db = workflow_service.request_rerun(ac_ex_db2, st2_ctx, rerun_options) wf_ex_db = self.prep_wf_ex(wf_ex_db) @@ -138,7 +149,7 @@ def test_request_rerun(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Complete task1. - self.run_workflow_step(wf_ex_db, 'task1', 0) + self.run_workflow_step(wf_ex_db, "task1", 0) # Check workflow status and make sure it is still running. conductor, wf_ex_db = workflow_service.refresh_conductor(str(wf_ex_db.id)) @@ -150,10 +161,10 @@ def test_request_rerun(self): self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) def test_request_rerun_while_original_is_still_running(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") # Manually create the liveaction and action execution objects without publishing. - lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db1 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db1, ac_ex_db1 = action_service.create_request(lv_ac_db1) # Request the workflow execution. @@ -168,16 +179,16 @@ def test_request_rerun_while_original_is_still_running(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - '^Unable to rerun workflow execution \".*\" ' - 'because it is not in a completed state.$' + '^Unable to rerun workflow execution ".*" ' + "because it is not in a completed state.$" ) self.assertRaisesRegexp( @@ -186,24 +197,26 @@ def test_request_rerun_while_original_is_still_running(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED])) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(side_effect=[RUNNER_RESULT_FAILED, RUNNER_RESULT_SUCCEEDED]), + ) def test_request_rerun_again_while_prev_rerun_is_still_running(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} wf_ex_db = workflow_service.request_rerun(ac_ex_db2, st2_ctx, rerun_options) wf_ex_db = self.prep_wf_ex(wf_ex_db) @@ -213,7 +226,7 @@ def test_request_rerun_again_while_prev_rerun_is_still_running(self): self.assertEqual(wf_ex_db.status, wf_statuses.RUNNING) # Complete task1. - self.run_workflow_step(wf_ex_db, 'task1', 0) + self.run_workflow_step(wf_ex_db, "task1", 0) # Check workflow status and make sure it is still running. conductor, wf_ex_db = workflow_service.refresh_conductor(str(wf_ex_db.id)) @@ -225,16 +238,16 @@ def test_request_rerun_again_while_prev_rerun_is_still_running(self): self.assertEqual(ac_ex_db2.status, action_constants.LIVEACTION_STATUS_RUNNING) # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db3 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db3 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db3, ac_ex_db3 = action_service.create_request(lv_ac_db3) # Request workflow execution rerun again. st2_ctx = self.mock_st2_context(ac_ex_db3, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - '^Unable to rerun workflow execution \".*\" ' - 'because it is not in a completed state.$' + '^Unable to rerun workflow execution ".*" ' + "because it is not in a completed state.$" ) self.assertRaisesRegexp( @@ -243,26 +256,28 @@ def test_request_rerun_again_while_prev_rerun_is_still_running(self): workflow_service.request_rerun, ac_ex_db3, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_missing_workflow_execution_id(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun without workflow_execution_id. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - 'Unable to rerun workflow execution because ' - 'workflow_execution_id is not provided.' + "Unable to rerun workflow execution because " + "workflow_execution_id is not provided." ) self.assertRaisesRegexp( @@ -271,27 +286,28 @@ def test_request_rerun_with_missing_workflow_execution_id(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_nonexistent_workflow_execution(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun with bogus workflow_execution_id. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = uuid.uuid4().hex[0:24] - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = uuid.uuid4().hex[0:24] + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - '^Unable to rerun workflow execution \".*\" ' - 'because it does not exist.$' + '^Unable to rerun workflow execution ".*" ' "because it does not exist.$" ) self.assertRaisesRegexp( @@ -300,12 +316,14 @@ def test_request_rerun_with_nonexistent_workflow_execution(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_workflow_execution_not_abended(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() @@ -315,16 +333,16 @@ def test_request_rerun_with_workflow_execution_not_abended(self): wf_ex_db = wf_db_access.WorkflowExecution.add_or_update(wf_ex_db) # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun with bogus workflow_execution_id. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - '^Unable to rerun workflow execution \".*\" ' - 'because it is not in a completed state.$' + '^Unable to rerun workflow execution ".*" ' + "because it is not in a completed state.$" ) self.assertRaisesRegexp( @@ -333,29 +351,33 @@ def test_request_rerun_with_workflow_execution_not_abended(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_conductor_status_not_abended(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually set workflow conductor state to paused. - wf_ex_db.state['status'] = wf_statuses.PAUSED + wf_ex_db.state["status"] = wf_statuses.PAUSED wf_ex_db = wf_db_access.WorkflowExecution.add_or_update(wf_ex_db) # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun with bogus workflow_execution_id. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} - expected_error = 'Unable to rerun workflow because it is not in a completed state.' + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} + expected_error = ( + "Unable to rerun workflow because it is not in a completed state." + ) self.assertRaisesRegexp( wf_exc.WorkflowExecutionRerunException, @@ -363,25 +385,29 @@ def test_request_rerun_with_conductor_status_not_abended(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_bad_task_name(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task5354']} - expected_error = '^Unable to rerun workflow because one or more tasks is not found: .*$' + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task5354"]} + expected_error = ( + "^Unable to rerun workflow because one or more tasks is not found: .*$" + ) self.assertRaisesRegexp( wf_exc.WorkflowExecutionRerunException, @@ -389,36 +415,40 @@ def test_request_rerun_with_bad_task_name(self): workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) @mock.patch.object( - local_shell_command_runner.LocalShellCommandRunner, 'run', - mock.MagicMock(return_value=RUNNER_RESULT_FAILED)) + local_shell_command_runner.LocalShellCommandRunner, + "run", + mock.MagicMock(return_value=RUNNER_RESULT_FAILED), + ) def test_request_rerun_with_conductor_status_not_resuming(self): # Create and return a failed workflow execution. wf_meta, lv_ac_db1, ac_ex_db1, wf_ex_db = self.prep_wf_ex_for_rerun() # Manually create the liveaction and action execution objects for the rerun. - lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta['name']) + lv_ac_db2 = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db2, ac_ex_db2 = action_service.create_request(lv_ac_db2) # Request workflow execution rerun with bogus workflow_execution_id. st2_ctx = self.mock_st2_context(ac_ex_db2, ac_ex_db1.context) - st2_ctx['workflow_execution_id'] = str(wf_ex_db.id) - rerun_options = {'ref': str(ac_ex_db1.id), 'tasks': ['task1']} + st2_ctx["workflow_execution_id"] = str(wf_ex_db.id) + rerun_options = {"ref": str(ac_ex_db1.id), "tasks": ["task1"]} expected_error = ( - '^Unable to rerun workflow execution \".*\" ' - 'due to an unknown cause.' + '^Unable to rerun workflow execution ".*" ' "due to an unknown cause." ) - with mock.patch.object(conducting.WorkflowConductor, 'get_workflow_status', - mock.MagicMock(return_value=wf_statuses.FAILED)): + with mock.patch.object( + conducting.WorkflowConductor, + "get_workflow_status", + mock.MagicMock(return_value=wf_statuses.FAILED), + ): self.assertRaisesRegexp( wf_exc.WorkflowExecutionRerunException, expected_error, workflow_service.request_rerun, ac_ex_db2, st2_ctx, - rerun_options + rerun_options, ) diff --git a/st2common/tests/unit/services/test_workflow_service_retries.py b/st2common/tests/unit/services/test_workflow_service_retries.py index 35fafc12131..baa79c69547 100644 --- a/st2common/tests/unit/services/test_workflow_service_retries.py +++ b/st2common/tests/unit/services/test_workflow_service_retries.py @@ -27,6 +27,7 @@ # XXX: actionsensor import depends on config being setup. import st2tests.config as tests_config + tests_config.parse_args() from st2common.bootstrap import actionsregistrar @@ -49,12 +50,14 @@ from st2tests.mocks import workflow as mock_wf_ex_xport -TEST_PACK = 'orquesta_tests' -TEST_PACK_PATH = st2tests.fixturesloader.get_fixtures_packs_base_path() + '/' + TEST_PACK +TEST_PACK = "orquesta_tests" +TEST_PACK_PATH = ( + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/" + TEST_PACK +) PACKS = [ TEST_PACK_PATH, - st2tests.fixturesloader.get_fixtures_packs_base_path() + '/core' + st2tests.fixturesloader.get_fixtures_packs_base_path() + "/core", ] @@ -63,11 +66,11 @@ def mock_wf_db_update_conflict(wf_ex_db, publish=True, dispatch_trigger=True, **kwargs): - seq_len = len(wf_ex_db.state['sequence']) + seq_len = len(wf_ex_db.state["sequence"]) if seq_len > 0: - current_task_id = wf_ex_db.state['sequence'][seq_len - 1:][0]['id'] - temp_file_path = TEMP_DIR_PATH + '/' + current_task_id + current_task_id = wf_ex_db.state["sequence"][seq_len - 1 :][0]["id"] + temp_file_path = TEMP_DIR_PATH + "/" + current_task_id if os.path.exists(temp_file_path): os.remove(temp_file_path) @@ -77,31 +80,38 @@ def mock_wf_db_update_conflict(wf_ex_db, publish=True, dispatch_trigger=True, ** @mock.patch.object( - publishers.CUDPublisher, - 'publish_update', - mock.MagicMock(return_value=None)) + publishers.CUDPublisher, "publish_update", mock.MagicMock(return_value=None) +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create)) + "publish_create", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_create), +) @mock.patch.object( lv_ac_xport.LiveActionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state)) + "publish_state", + mock.MagicMock(side_effect=mock_lv_ac_xport.MockLiveActionPublisher.publish_state), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_create', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create)) + "publish_create", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_create + ), +) @mock.patch.object( wf_ex_xport.WorkflowExecutionPublisher, - 'publish_state', - mock.MagicMock(side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state)) + "publish_state", + mock.MagicMock( + side_effect=mock_wf_ex_xport.MockWorkflowExecutionPublisher.publish_state + ), +) class OrquestaServiceRetryTest(st2tests.WorkflowTestCase): ensure_indexes = True ensure_indexes_models = [ wf_db_models.WorkflowExecutionDB, wf_db_models.TaskExecutionDB, - ex_q_db_models.ActionExecutionSchedulingQueueItemDB + ex_q_db_models.ActionExecutionSchedulingQueueItemDB, ] @classmethod @@ -113,30 +123,38 @@ def setUpClass(cls): # Register test pack(s). actions_registrar = actionsregistrar.ActionsRegistrar( - use_pack_cache=False, - fail_on_failure=True + use_pack_cache=False, fail_on_failure=True ) for pack in PACKS: actions_registrar.register_from_pack(pack) @mock.patch.object( - coord_svc.NoOpDriver, 'get_lock', - mock.MagicMock(side_effect=[ - coordination.ToozConnectionError('foobar'), - coordination.ToozConnectionError('fubar'), - coord_svc.NoOpLock(name='noop')])) + coord_svc.NoOpDriver, + "get_lock", + mock.MagicMock( + side_effect=[ + coordination.ToozConnectionError("foobar"), + coordination.ToozConnectionError("fubar"), + coord_svc.NoOpLock(name="noop"), + ] + ), + ) def test_recover_from_coordinator_connection_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Process task1 and expect acquiring lock returns a few connection errors. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) @@ -145,45 +163,60 @@ def test_recover_from_coordinator_connection_error(self): self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) @mock.patch.object( - coord_svc.NoOpDriver, 'get_lock', - mock.MagicMock(side_effect=coordination.ToozConnectionError('foobar'))) + coord_svc.NoOpDriver, + "get_lock", + mock.MagicMock(side_effect=coordination.ToozConnectionError("foobar")), + ) def test_retries_exhausted_from_coordinator_connection_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Process task1 but retries exhaused with connection errors. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # The connection error should raise if retries are exhaused. self.assertRaises( coordination.ToozConnectionError, wf_svc.handle_action_execution_completion, - tk1_ac_ex_db + tk1_ac_ex_db, ) @mock.patch.object( - wf_svc, 'update_task_state', - mock.MagicMock(side_effect=[ - mongoengine.connection.MongoEngineConnectionError(), - mongoengine.connection.MongoEngineConnectionError(), - None])) + wf_svc, + "update_task_state", + mock.MagicMock( + side_effect=[ + mongoengine.connection.MongoEngineConnectionError(), + mongoengine.connection.MongoEngineConnectionError(), + None, + ] + ), + ) def test_recover_from_database_connection_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Process task1 and expect acquiring lock returns a few connection errors. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) wf_svc.handle_action_execution_completion(tk1_ac_ex_db) @@ -192,61 +225,71 @@ def test_recover_from_database_connection_error(self): self.assertEqual(tk1_ex_db.status, wf_statuses.SUCCEEDED) @mock.patch.object( - wf_svc, 'update_task_state', - mock.MagicMock(side_effect=mongoengine.connection.MongoEngineConnectionError())) + wf_svc, + "update_task_state", + mock.MagicMock(side_effect=mongoengine.connection.MongoEngineConnectionError()), + ) def test_retries_exhausted_from_database_connection_error(self): - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'sequential.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "sequential.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Process task1 but retries exhaused with connection errors. - query_filters = {'workflow_execution': str(wf_ex_db.id), 'task_id': 'task1'} + query_filters = {"workflow_execution": str(wf_ex_db.id), "task_id": "task1"} tk1_ex_db = wf_db_access.TaskExecution.query(**query_filters)[0] - tk1_ac_ex_db = ex_db_access.ActionExecution.query(task_execution=str(tk1_ex_db.id))[0] - tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction['id']) + tk1_ac_ex_db = ex_db_access.ActionExecution.query( + task_execution=str(tk1_ex_db.id) + )[0] + tk1_lv_ac_db = lv_db_access.LiveAction.get_by_id(tk1_ac_ex_db.liveaction["id"]) self.assertEqual(tk1_lv_ac_db.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) # The connection error should raise if retries are exhaused. self.assertRaises( mongoengine.connection.MongoEngineConnectionError, wf_svc.handle_action_execution_completion, - tk1_ac_ex_db + tk1_ac_ex_db, ) @mock.patch.object( - wf_db_access.WorkflowExecution, 'update', - mock.MagicMock(side_effect=mock_wf_db_update_conflict)) + wf_db_access.WorkflowExecution, + "update", + mock.MagicMock(side_effect=mock_wf_db_update_conflict), + ) def test_recover_from_database_write_conflicts(self): # Create a temporary file which will be used to signal # which task(s) to mock the DB write conflict. - temp_file_path = TEMP_DIR_PATH + '/task4' + temp_file_path = TEMP_DIR_PATH + "/task4" if not os.path.exists(temp_file_path): - with open(temp_file_path, 'w'): + with open(temp_file_path, "w"): pass # Manually create the liveaction and action execution objects without publishing. - wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, 'join.yaml') - lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta['name']) + wf_meta = self.get_wf_fixture_meta_data(TEST_PACK_PATH, "join.yaml") + lv_ac_db = lv_db_models.LiveActionDB(action=wf_meta["name"]) lv_ac_db, ac_ex_db = ac_svc.request(lv_ac_db) - wf_ex_db = wf_db_access.WorkflowExecution.query(action_execution=str(ac_ex_db.id))[0] + wf_ex_db = wf_db_access.WorkflowExecution.query( + action_execution=str(ac_ex_db.id) + )[0] # Manually request task executions. task_route = 0 - self.run_workflow_step(wf_ex_db, 'task1', task_route) - self.assert_task_running('task2', task_route) - self.assert_task_running('task4', task_route) - self.run_workflow_step(wf_ex_db, 'task2', task_route) - self.assert_task_running('task3', task_route) - self.run_workflow_step(wf_ex_db, 'task4', task_route) - self.assert_task_running('task5', task_route) - self.run_workflow_step(wf_ex_db, 'task3', task_route) - self.assert_task_not_started('task6', task_route) - self.run_workflow_step(wf_ex_db, 'task5', task_route) - self.assert_task_running('task6', task_route) - self.run_workflow_step(wf_ex_db, 'task6', task_route) - self.assert_task_running('task7', task_route) - self.run_workflow_step(wf_ex_db, 'task7', task_route) + self.run_workflow_step(wf_ex_db, "task1", task_route) + self.assert_task_running("task2", task_route) + self.assert_task_running("task4", task_route) + self.run_workflow_step(wf_ex_db, "task2", task_route) + self.assert_task_running("task3", task_route) + self.run_workflow_step(wf_ex_db, "task4", task_route) + self.assert_task_running("task5", task_route) + self.run_workflow_step(wf_ex_db, "task3", task_route) + self.assert_task_not_started("task6", task_route) + self.run_workflow_step(wf_ex_db, "task5", task_route) + self.assert_task_running("task6", task_route) + self.run_workflow_step(wf_ex_db, "task6", task_route) + self.assert_task_running("task7", task_route) + self.run_workflow_step(wf_ex_db, "task7", task_route) self.assert_workflow_completed(str(wf_ex_db.id), status=wf_statuses.SUCCEEDED) # Ensure retry happened. diff --git a/st2common/tests/unit/test_action_alias_utils.py b/st2common/tests/unit/test_action_alias_utils.py index 33b78981a5c..daad0fbe1eb 100644 --- a/st2common/tests/unit/test_action_alias_utils.py +++ b/st2common/tests/unit/test_action_alias_utils.py @@ -14,281 +14,312 @@ # limitations under the License. from __future__ import absolute_import -from sre_parse import (parse, AT, AT_BEGINNING, AT_BEGINNING_STRING, AT_END, AT_END_STRING) +from sre_parse import ( + parse, + AT, + AT_BEGINNING, + AT_BEGINNING_STRING, + AT_END, + AT_END_STRING, +) from mock import Mock from unittest2 import TestCase from st2common.exceptions.content import ParseException from st2common.models.utils.action_alias_utils import ( - ActionAliasFormatParser, search_regex_tokens, - inject_immutable_parameters + ActionAliasFormatParser, + search_regex_tokens, + inject_immutable_parameters, ) class TestActionAliasParser(TestCase): def test_empty_string(self): - alias_format = '' - param_stream = '' + alias_format = "" + param_stream = "" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() self.assertEqual(extracted_values, {}) def test_arbitrary_pairs(self): # single-word param - alias_format = '' - param_stream = 'a=foobar1' + alias_format = "" + param_stream = "a=foobar1" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'foobar1'}) + self.assertEqual(extracted_values, {"a": "foobar1"}) # multi-word double-quoted param - alias_format = 'foo' + alias_format = "foo" param_stream = 'foo a="foobar2 poonies bar"' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'foobar2 poonies bar'}) + self.assertEqual(extracted_values, {"a": "foobar2 poonies bar"}) # multi-word single-quoted param - alias_format = 'foo' - param_stream = 'foo a=\'foobar2 poonies bar\'' + alias_format = "foo" + param_stream = "foo a='foobar2 poonies bar'" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'foobar2 poonies bar'}) + self.assertEqual(extracted_values, {"a": "foobar2 poonies bar"}) # JSON param - alias_format = 'foo' + alias_format = "foo" param_stream = 'foo a={"foobar2": "poonies"}' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': '{"foobar2": "poonies"}'}) + self.assertEqual(extracted_values, {"a": '{"foobar2": "poonies"}'}) # Multiple mixed params - alias_format = '' + alias_format = "" param_stream = 'a=foobar1 b="boobar2 3 4" c=\'coobar3 4\' d={"a": "b"}' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'foobar1', - 'b': 'boobar2 3 4', - 'c': 'coobar3 4', - 'd': '{"a": "b"}'}) + self.assertEqual( + extracted_values, + {"a": "foobar1", "b": "boobar2 3 4", "c": "coobar3 4", "d": '{"a": "b"}'}, + ) # Params along with a "normal" alias format - alias_format = '{{ captain }} is my captain' + alias_format = "{{ captain }} is my captain" param_stream = 'Malcolm Reynolds is my captain weirdo="River Tam"' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'captain': 'Malcolm Reynolds', - 'weirdo': 'River Tam'}) + self.assertEqual( + extracted_values, {"captain": "Malcolm Reynolds", "weirdo": "River Tam"} + ) def test_simple_parsing(self): - alias_format = 'skip {{a}} more skip {{b}} and skip more.' - param_stream = 'skip a1 more skip b1 and skip more.' + alias_format = "skip {{a}} more skip {{b}} and skip more." + param_stream = "skip a1 more skip b1 and skip more." parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'a1', 'b': 'b1'}) + self.assertEqual(extracted_values, {"a": "a1", "b": "b1"}) def test_end_string_parsing(self): - alias_format = 'skip {{a}} more skip {{b}}' - param_stream = 'skip a1 more skip b1' + alias_format = "skip {{a}} more skip {{b}}" + param_stream = "skip a1 more skip b1" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'a1', 'b': 'b1'}) + self.assertEqual(extracted_values, {"a": "a1", "b": "b1"}) def test_spaced_parsing(self): - alias_format = 'skip {{a}} more skip {{b}} and skip more.' + alias_format = "skip {{a}} more skip {{b}} and skip more." param_stream = 'skip "a1 a2" more skip b1 and skip more.' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'a1 a2', 'b': 'b1'}) + self.assertEqual(extracted_values, {"a": "a1 a2", "b": "b1"}) def test_default_values(self): - alias_format = 'acl {{a}} {{b}} {{c}} {{d=1}}' + alias_format = "acl {{a}} {{b}} {{c}} {{d=1}}" param_stream = 'acl "a1 a2" "b1" "c1"' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'a1 a2', 'b': 'b1', - 'c': 'c1', 'd': '1'}) + self.assertEqual( + extracted_values, {"a": "a1 a2", "b": "b1", "c": "c1", "d": "1"} + ) def test_spacing(self): - alias_format = 'acl {{a=test}}' - param_stream = 'acl' + alias_format = "acl {{a=test}}" + param_stream = "acl" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'test'}) + self.assertEqual(extracted_values, {"a": "test"}) def test_json_parsing(self): - alias_format = 'skip {{a}} more skip.' + alias_format = "skip {{a}} more skip." param_stream = 'skip {"a": "b", "c": "d"} more skip.' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': '{"a": "b", "c": "d"}'}) + self.assertEqual(extracted_values, {"a": '{"a": "b", "c": "d"}'}) def test_mixed_parsing(self): - alias_format = 'skip {{a}} more skip {{b}}.' + alias_format = "skip {{a}} more skip {{b}}." param_stream = 'skip {"a": "b", "c": "d"} more skip x.' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': '{"a": "b", "c": "d"}', - 'b': 'x'}) + self.assertEqual(extracted_values, {"a": '{"a": "b", "c": "d"}', "b": "x"}) def test_param_spaces(self): - alias_format = 's {{a}} more {{ b }} more {{ c=99 }} more {{ d = 99 }}' - param_stream = 's one more two more three more' + alias_format = "s {{a}} more {{ b }} more {{ c=99 }} more {{ d = 99 }}" + param_stream = "s one more two more three more" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'one', 'b': 'two', - 'c': 'three', 'd': '99'}) + self.assertEqual( + extracted_values, {"a": "one", "b": "two", "c": "three", "d": "99"} + ) def test_enclosed_defaults(self): - alias_format = 'skip {{ a = value }} more' - param_stream = 'skip one more' + alias_format = "skip {{ a = value }} more" + param_stream = "skip one more" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'one'}) + self.assertEqual(extracted_values, {"a": "one"}) - alias_format = 'skip {{ a = value }} more' - param_stream = 'skip more' + alias_format = "skip {{ a = value }} more" + param_stream = "skip more" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'value'}) + self.assertEqual(extracted_values, {"a": "value"}) def test_template_defaults(self): - alias_format = 'two by two hands of {{ color = {{ colors.default_color }} }}' - param_stream = 'two by two hands of' + alias_format = "two by two hands of {{ color = {{ colors.default_color }} }}" + param_stream = "two by two hands of" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'color': '{{ colors.default_color }}'}) + self.assertEqual(extracted_values, {"color": "{{ colors.default_color }}"}) def test_key_value_combinations(self): # one-word value, single extra pair - alias_format = 'testing {{ a }}' - param_stream = 'testing value b=value2' + alias_format = "testing {{ a }}" + param_stream = "testing value b=value2" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'value', - 'b': 'value2'}) + self.assertEqual(extracted_values, {"a": "value", "b": "value2"}) # default value, single extra pair with quotes - alias_format = 'testing {{ a=new }}' + alias_format = "testing {{ a=new }}" param_stream = 'testing b="another value"' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'a': 'new', - 'b': 'another value'}) + self.assertEqual(extracted_values, {"a": "new", "b": "another value"}) # multiple values and multiple extra pairs - alias_format = 'testing {{ b=abc }} {{ c=xyz }}' + alias_format = "testing {{ b=abc }} {{ c=xyz }}" param_stream = 'testing newvalue d={"1": "2"} e="long value"' parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'b': 'newvalue', - 'c': 'xyz', - 'd': '{"1": "2"}', - 'e': 'long value'}) + self.assertEqual( + extracted_values, + {"b": "newvalue", "c": "xyz", "d": '{"1": "2"}', "e": "long value"}, + ) def test_stream_is_none_with_all_default_values(self): - alias_format = 'skip {{d=test1}} more skip {{e=test1}}.' - param_stream = 'skip more skip' + alias_format = "skip {{d=test1}} more skip {{e=test1}}." + param_stream = "skip more skip" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'d': 'test1', 'e': 'test1'}) + self.assertEqual(extracted_values, {"d": "test1", "e": "test1"}) def test_stream_is_not_none_some_default_values(self): - alias_format = 'skip {{d=test}} more skip {{e=test}}' - param_stream = 'skip ponies more skip' + alias_format = "skip {{d=test}} more skip {{e=test}}" + param_stream = "skip ponies more skip" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'d': 'ponies', 'e': 'test'}) + self.assertEqual(extracted_values, {"d": "ponies", "e": "test"}) def test_stream_is_none_no_default_values(self): - alias_format = 'skip {{d}} more skip {{e}}.' + alias_format = "skip {{d}} more skip {{e}}." param_stream = None parser = ActionAliasFormatParser(alias_format, param_stream) - expected_msg = 'Command "" doesn\'t match format string "skip {{d}} more skip {{e}}."' - self.assertRaisesRegexp(ParseException, expected_msg, - parser.get_extracted_param_value) + expected_msg = ( + 'Command "" doesn\'t match format string "skip {{d}} more skip {{e}}."' + ) + self.assertRaisesRegexp( + ParseException, expected_msg, parser.get_extracted_param_value + ) def test_all_the_things(self): # this is the most insane example I could come up with - alias_format = "{{ p0='http' }} g {{ p1=p }} a " + \ - "{{ url }} {{ p2={'a':'b'} }} {{ p3={{ e.i }} }}" - param_stream = "g a http://google.com {{ execution.id }} p4='testing' p5={'a':'c'}" + alias_format = ( + "{{ p0='http' }} g {{ p1=p }} a " + + "{{ url }} {{ p2={'a':'b'} }} {{ p3={{ e.i }} }}" + ) + param_stream = ( + "g a http://google.com {{ execution.id }} p4='testing' p5={'a':'c'}" + ) parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'p0': 'http', 'p1': 'p', - 'url': 'http://google.com', - 'p2': '{{ execution.id }}', - 'p3': '{{ e.i }}', - 'p4': 'testing', 'p5': "{'a':'c'}"}) + self.assertEqual( + extracted_values, + { + "p0": "http", + "p1": "p", + "url": "http://google.com", + "p2": "{{ execution.id }}", + "p3": "{{ e.i }}", + "p4": "testing", + "p5": "{'a':'c'}", + }, + ) def test_command_doesnt_match_format_string(self): - alias_format = 'foo bar ponies' - param_stream = 'foo lulz ponies' + alias_format = "foo bar ponies" + param_stream = "foo lulz ponies" parser = ActionAliasFormatParser(alias_format, param_stream) - expected_msg = 'Command "foo lulz ponies" doesn\'t match format string "foo bar ponies"' - self.assertRaisesRegexp(ParseException, expected_msg, - parser.get_extracted_param_value) + expected_msg = ( + 'Command "foo lulz ponies" doesn\'t match format string "foo bar ponies"' + ) + self.assertRaisesRegexp( + ParseException, expected_msg, parser.get_extracted_param_value + ) def test_ending_parameters_matching(self): - alias_format = 'foo bar' - param_stream = 'foo bar pony1=foo pony2=bar' + alias_format = "foo bar" + param_stream = "foo bar pony1=foo pony2=bar" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'pony1': 'foo', 'pony2': 'bar'}) + self.assertEqual(extracted_values, {"pony1": "foo", "pony2": "bar"}) def test_regex_beginning_anchors(self): - alias_format = r'^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)' - param_stream = 'foo ASDF-1234' + alias_format = r"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)" + param_stream = "foo ASDF-1234" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'issue_key': 'ASDF-1234'}) + self.assertEqual(extracted_values, {"issue_key": "ASDF-1234"}) def test_regex_beginning_anchors_dont_match(self): - alias_format = r'^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)' - param_stream = 'bar foo ASDF-1234' + alias_format = r"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)" + param_stream = "bar foo ASDF-1234" parser = ActionAliasFormatParser(alias_format, param_stream) - expected_msg = r'''Command "bar foo ASDF-1234" doesn't match format string '''\ - r'''"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)"''' + expected_msg = ( + r"""Command "bar foo ASDF-1234" doesn't match format string """ + r'''"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)"''' + ) with self.assertRaises(ParseException) as e: parser.get_extracted_param_value() self.assertEqual(e.msg, expected_msg) def test_regex_ending_anchors(self): - alias_format = r'foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$' - param_stream = 'foo ASDF-1234' + alias_format = r"foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$" + param_stream = "foo ASDF-1234" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'issue_key': 'ASDF-1234'}) + self.assertEqual(extracted_values, {"issue_key": "ASDF-1234"}) def test_regex_ending_anchors_dont_match(self): - alias_format = r'foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$' - param_stream = 'foo ASDF-1234 bar' + alias_format = r"foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$" + param_stream = "foo ASDF-1234 bar" parser = ActionAliasFormatParser(alias_format, param_stream) - expected_msg = r'''Command "foo ASDF-1234 bar" doesn't match format string '''\ - r'''"foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$"''' + expected_msg = ( + r"""Command "foo ASDF-1234 bar" doesn't match format string """ + r'''"foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$"''' + ) with self.assertRaises(ParseException) as e: parser.get_extracted_param_value() self.assertEqual(e.msg, expected_msg) def test_regex_beginning_and_ending_anchors(self): - alias_format = r'^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+) bar\s*$' - param_stream = 'foo ASDF-1234 bar' + alias_format = r"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+) bar\s*$" + param_stream = "foo ASDF-1234 bar" parser = ActionAliasFormatParser(alias_format, param_stream) extracted_values = parser.get_extracted_param_value() - self.assertEqual(extracted_values, {'issue_key': 'ASDF-1234'}) + self.assertEqual(extracted_values, {"issue_key": "ASDF-1234"}) def test_regex_beginning_and_ending_anchors_dont_match(self): - alias_format = r'^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$' - param_stream = 'bar ASDF-1234' + alias_format = r"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$" + param_stream = "bar ASDF-1234" parser = ActionAliasFormatParser(alias_format, param_stream) - expected_msg = r'''Command "bar ASDF-1234" doesn't match format string '''\ - r'''"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$"''' + expected_msg = ( + r"""Command "bar ASDF-1234" doesn't match format string """ + r'''"^\s*foo (?P[A-Z][A-Z0-9]+-[0-9]+)\s*$"''' + ) with self.assertRaises(ParseException) as e: parser.get_extracted_param_value() @@ -332,8 +363,8 @@ def test_immutable_parameters_are_injected(self): exec_params = [{"param1": "value1", "param2": "value2"}] inject_immutable_parameters(action_alias_db, exec_params, {}) self.assertEqual( - exec_params, - [{"param1": "value1", "param2": "value2", "env": "dev"}]) + exec_params, [{"param1": "value1", "param2": "value2", "env": "dev"}] + ) def test_immutable_parameters_with_jinja(self): action_alias_db = Mock() @@ -341,8 +372,8 @@ def test_immutable_parameters_with_jinja(self): exec_params = [{"param1": "value1", "param2": "value2"}] inject_immutable_parameters(action_alias_db, exec_params, {}) self.assertEqual( - exec_params, - [{"param1": "value1", "param2": "value2", "env": "dev1"}]) + exec_params, [{"param1": "value1", "param2": "value2", "env": "dev1"}] + ) def test_override_raises_error(self): action_alias_db = Mock() diff --git a/st2common/tests/unit/test_action_api_validator.py b/st2common/tests/unit/test_action_api_validator.py index 1cf16d3f14b..5be1ca13bac 100644 --- a/st2common/tests/unit/test_action_api_validator.py +++ b/st2common/tests/unit/test_action_api_validator.py @@ -14,6 +14,7 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except ImportError: @@ -29,66 +30,83 @@ from st2tests import DbTestCase from st2tests.fixtures.packs import executions as fixture -__all__ = [ - 'TestActionAPIValidator' -] +__all__ = ["TestActionAPIValidator"] class TestActionAPIValidator(DbTestCase): - @classmethod def setUpClass(cls): super(TestActionAPIValidator, cls).setUpClass() runners_registrar.register_runners() - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_runner_type_happy_case(self): - action_api_dict = fixture.ARTIFACTS['actions']['local'] + action_api_dict = fixture.ARTIFACTS["actions"]["local"] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) except: - self.fail('Exception validating action: %s' % json.dumps(action_api_dict)) + self.fail("Exception validating action: %s" % json.dumps(action_api_dict)) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_runner_type_invalid_runner(self): - action_api_dict = fixture.ARTIFACTS['actions']['action-with-invalid-runner'] + action_api_dict = fixture.ARTIFACTS["actions"]["action-with-invalid-runner"] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) - self.fail('Action validation should not have passed. %s' % json.dumps(action_api_dict)) + self.fail( + "Action validation should not have passed. %s" + % json.dumps(action_api_dict) + ) except ValueValidationException: pass - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_override_immutable_runner_param(self): - action_api_dict = fixture.ARTIFACTS['actions']['remote-override-runner-immutable'] + action_api_dict = fixture.ARTIFACTS["actions"][ + "remote-override-runner-immutable" + ] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) - self.fail('Action validation should not have passed. %s' % json.dumps(action_api_dict)) + self.fail( + "Action validation should not have passed. %s" + % json.dumps(action_api_dict) + ) except ValueValidationException as e: - self.assertIn('Cannot override in action.', six.text_type(e)) + self.assertIn("Cannot override in action.", six.text_type(e)) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_action_param_immutable(self): - action_api_dict = fixture.ARTIFACTS['actions']['action-immutable-param-no-default'] + action_api_dict = fixture.ARTIFACTS["actions"][ + "action-immutable-param-no-default" + ] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) - self.fail('Action validation should not have passed. %s' % json.dumps(action_api_dict)) + self.fail( + "Action validation should not have passed. %s" + % json.dumps(action_api_dict) + ) except ValueValidationException as e: - self.assertIn('requires a default value.', six.text_type(e)) + self.assertIn("requires a default value.", six.text_type(e)) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_action_param_immutable_no_default(self): - action_api_dict = fixture.ARTIFACTS['actions']['action-immutable-runner-param-no-default'] + action_api_dict = fixture.ARTIFACTS["actions"][ + "action-immutable-runner-param-no-default" + ] action_api = ActionAPI(**action_api_dict) # Runner param sudo is decalred immutable in action but no defualt value @@ -97,30 +115,44 @@ def test_validate_action_param_immutable_no_default(self): action_validator.validate_action(action_api) except ValueValidationException as e: print(e) - self.fail('Action validation should have passed. %s' % json.dumps(action_api_dict)) + self.fail( + "Action validation should have passed. %s" % json.dumps(action_api_dict) + ) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_action_param_position_values_unique(self): - action_api_dict = fixture.ARTIFACTS['actions']['action-with-non-unique-positions'] + action_api_dict = fixture.ARTIFACTS["actions"][ + "action-with-non-unique-positions" + ] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) - self.fail('Action validation should have failed ' + - 'because position values are not unique.' % json.dumps(action_api_dict)) + self.fail( + "Action validation should have failed " + + "because position values are not unique." + % json.dumps(action_api_dict) + ) except ValueValidationException as e: - self.assertIn('have same position', six.text_type(e)) + self.assertIn("have same position", six.text_type(e)) - @mock.patch.object(action_validator, '_is_valid_pack', mock.MagicMock( - return_value=True)) + @mock.patch.object( + action_validator, "_is_valid_pack", mock.MagicMock(return_value=True) + ) def test_validate_action_param_position_values_contiguous(self): - action_api_dict = fixture.ARTIFACTS['actions']['action-with-non-contiguous-positions'] + action_api_dict = fixture.ARTIFACTS["actions"][ + "action-with-non-contiguous-positions" + ] action_api = ActionAPI(**action_api_dict) try: action_validator.validate_action(action_api) - self.fail('Action validation should have failed ' + - 'because position values are not contiguous.' % json.dumps(action_api_dict)) + self.fail( + "Action validation should have failed " + + "because position values are not contiguous." + % json.dumps(action_api_dict) + ) except ValueValidationException as e: - self.assertIn('are not contiguous', six.text_type(e)) + self.assertIn("are not contiguous", six.text_type(e)) diff --git a/st2common/tests/unit/test_action_db_utils.py b/st2common/tests/unit/test_action_db_utils.py index ba2dcef0180..f7a114b85b9 100644 --- a/st2common/tests/unit/test_action_db_utils.py +++ b/st2common/tests/unit/test_action_db_utils.py @@ -35,7 +35,7 @@ from st2tests.base import DbTestCase -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ActionDBUtilsTestCase(DbTestCase): runnertype_db = None action_db = None @@ -48,26 +48,39 @@ def setUpClass(cls): def test_get_runnertype_nonexisting(self): # By id. - self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_runnertype_by_id, - 'somedummyrunnerid') + self.assertRaises( + StackStormDBObjectNotFoundError, + action_db_utils.get_runnertype_by_id, + "somedummyrunnerid", + ) # By name. - self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_runnertype_by_name, - 'somedummyrunnername') + self.assertRaises( + StackStormDBObjectNotFoundError, + action_db_utils.get_runnertype_by_name, + "somedummyrunnername", + ) def test_get_runnertype_existing(self): # Lookup by id and verify name equals. - runner = action_db_utils.get_runnertype_by_id(ActionDBUtilsTestCase.runnertype_db.id) + runner = action_db_utils.get_runnertype_by_id( + ActionDBUtilsTestCase.runnertype_db.id + ) self.assertEqual(runner.name, ActionDBUtilsTestCase.runnertype_db.name) # Lookup by name and verify id equals. - runner = action_db_utils.get_runnertype_by_name(ActionDBUtilsTestCase.runnertype_db.name) + runner = action_db_utils.get_runnertype_by_name( + ActionDBUtilsTestCase.runnertype_db.name + ) self.assertEqual(runner.id, ActionDBUtilsTestCase.runnertype_db.id) def test_get_action_nonexisting(self): # By id. - self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_action_by_id, - 'somedummyactionid') + self.assertRaises( + StackStormDBObjectNotFoundError, + action_db_utils.get_action_by_id, + "somedummyactionid", + ) # By ref. - action = action_db_utils.get_action_by_ref('packaintexist.somedummyactionname') + action = action_db_utils.get_action_by_ref("packaintexist.somedummyactionname") self.assertIsNone(action) def test_get_action_existing(self): @@ -77,50 +90,57 @@ def test_get_action_existing(self): # Lookup by reference as string. action_ref = ResourceReference.to_string_reference( pack=ActionDBUtilsTestCase.action_db.pack, - name=ActionDBUtilsTestCase.action_db.name) + name=ActionDBUtilsTestCase.action_db.name, + ) action = action_db_utils.get_action_by_ref(action_ref) self.assertEqual(action.id, ActionDBUtilsTestCase.action_db.id) def test_get_actionexec_nonexisting(self): # By id. - self.assertRaises(StackStormDBObjectNotFoundError, action_db_utils.get_liveaction_by_id, - 'somedummyactionexecid') + self.assertRaises( + StackStormDBObjectNotFoundError, + action_db_utils.get_liveaction_by_id, + "somedummyactionexecid", + ) def test_get_actionexec_existing(self): - liveaction = action_db_utils.get_liveaction_by_id(ActionDBUtilsTestCase.liveaction_db.id) + liveaction = action_db_utils.get_liveaction_by_id( + ActionDBUtilsTestCase.liveaction_db.id + ) self.assertEqual(liveaction, ActionDBUtilsTestCase.liveaction_db) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_liveaction_with_incorrect_output_schema(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params runner = mock.MagicMock() - runner.output_schema = { - "notaparam": { - "type": "boolean" - } - } + runner.output_schema = {"notaparam": {"type": "boolean"}} liveaction_db.runner = runner liveaction_db = LiveAction.add_or_update(liveaction_db) origliveaction_db = copy.copy(liveaction_db) now = get_datetime_utc_now() - status = 'succeeded' - result = 'Work is done.' - context = {'third_party_id': uuid.uuid4().hex} + status = "succeeded" + result = "Work is done." + context = {"third_party_id": uuid.uuid4().hex} newliveaction_db = action_db_utils.update_liveaction_status( - status=status, result=result, context=context, end_timestamp=now, - liveaction_id=liveaction_db.id) + status=status, + result=result, + context=context, + end_timestamp=now, + liveaction_id=liveaction_db.id, + ) self.assertEqual(origliveaction_db.id, newliveaction_db.id) self.assertEqual(newliveaction_db.status, status) @@ -128,18 +148,19 @@ def test_update_liveaction_with_incorrect_output_schema(self): self.assertDictEqual(newliveaction_db.context, context) self.assertEqual(newliveaction_db.end_timestamp, now) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_liveaction_status(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -147,24 +168,31 @@ def test_update_liveaction_status(self): # Update by id. newliveaction_db = action_db_utils.update_liveaction_status( - status='running', liveaction_id=liveaction_db.id) + status="running", liveaction_id=liveaction_db.id + ) # Verify id didn't change. self.assertEqual(origliveaction_db.id, newliveaction_db.id) - self.assertEqual(newliveaction_db.status, 'running') + self.assertEqual(newliveaction_db.status, "running") # Verify that state is published. self.assertTrue(LiveActionPublisher.publish_state.called) - LiveActionPublisher.publish_state.assert_called_once_with(newliveaction_db, 'running') + LiveActionPublisher.publish_state.assert_called_once_with( + newliveaction_db, "running" + ) # Update status, result, context, and end timestamp. now = get_datetime_utc_now() - status = 'succeeded' - result = 'Work is done.' - context = {'third_party_id': uuid.uuid4().hex} + status = "succeeded" + result = "Work is done." + context = {"third_party_id": uuid.uuid4().hex} newliveaction_db = action_db_utils.update_liveaction_status( - status=status, result=result, context=context, end_timestamp=now, - liveaction_id=liveaction_db.id) + status=status, + result=result, + context=context, + end_timestamp=now, + liveaction_id=liveaction_db.id, + ) self.assertEqual(origliveaction_db.id, newliveaction_db.id) self.assertEqual(newliveaction_db.status, status) @@ -172,18 +200,19 @@ def test_update_liveaction_status(self): self.assertDictEqual(newliveaction_db.context, context) self.assertEqual(newliveaction_db.end_timestamp, now) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_canceled_liveaction(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -191,21 +220,25 @@ def test_update_canceled_liveaction(self): # Update by id. newliveaction_db = action_db_utils.update_liveaction_status( - status='running', liveaction_id=liveaction_db.id) + status="running", liveaction_id=liveaction_db.id + ) # Verify id didn't change. self.assertEqual(origliveaction_db.id, newliveaction_db.id) - self.assertEqual(newliveaction_db.status, 'running') + self.assertEqual(newliveaction_db.status, "running") # Verify that state is published. self.assertTrue(LiveActionPublisher.publish_state.called) - LiveActionPublisher.publish_state.assert_called_once_with(newliveaction_db, 'running') + LiveActionPublisher.publish_state.assert_called_once_with( + newliveaction_db, "running" + ) # Cancel liveaction. now = get_datetime_utc_now() - status = 'canceled' + status = "canceled" newliveaction_db = action_db_utils.update_liveaction_status( - status=status, end_timestamp=now, liveaction_id=liveaction_db.id) + status=status, end_timestamp=now, liveaction_id=liveaction_db.id + ) self.assertEqual(origliveaction_db.id, newliveaction_db.id) self.assertEqual(newliveaction_db.status, status) self.assertEqual(newliveaction_db.end_timestamp, now) @@ -213,31 +246,36 @@ def test_update_canceled_liveaction(self): # Since liveaction has already been canceled, check that anymore update of # status, result, context, and end timestamp are not processed. now = get_datetime_utc_now() - status = 'succeeded' - result = 'Work is done.' - context = {'third_party_id': uuid.uuid4().hex} + status = "succeeded" + result = "Work is done." + context = {"third_party_id": uuid.uuid4().hex} newliveaction_db = action_db_utils.update_liveaction_status( - status=status, result=result, context=context, end_timestamp=now, - liveaction_id=liveaction_db.id) + status=status, + result=result, + context=context, + end_timestamp=now, + liveaction_id=liveaction_db.id, + ) self.assertEqual(origliveaction_db.id, newliveaction_db.id) - self.assertEqual(newliveaction_db.status, 'canceled') + self.assertEqual(newliveaction_db.status, "canceled") self.assertNotEqual(newliveaction_db.result, result) self.assertNotEqual(newliveaction_db.context, context) self.assertNotEqual(newliveaction_db.end_timestamp, now) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_liveaction_result_with_dotted_key(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -245,66 +283,79 @@ def test_update_liveaction_result_with_dotted_key(self): # Update by id. newliveaction_db = action_db_utils.update_liveaction_status( - status='running', liveaction_id=liveaction_db.id) + status="running", liveaction_id=liveaction_db.id + ) # Verify id didn't change. self.assertEqual(origliveaction_db.id, newliveaction_db.id) - self.assertEqual(newliveaction_db.status, 'running') + self.assertEqual(newliveaction_db.status, "running") # Verify that state is published. self.assertTrue(LiveActionPublisher.publish_state.called) - LiveActionPublisher.publish_state.assert_called_once_with(newliveaction_db, 'running') + LiveActionPublisher.publish_state.assert_called_once_with( + newliveaction_db, "running" + ) now = get_datetime_utc_now() - status = 'succeeded' - result = {'a': 1, 'b': True, 'a.b.c': 'abc'} - context = {'third_party_id': uuid.uuid4().hex} + status = "succeeded" + result = {"a": 1, "b": True, "a.b.c": "abc"} + context = {"third_party_id": uuid.uuid4().hex} newliveaction_db = action_db_utils.update_liveaction_status( - status=status, result=result, context=context, end_timestamp=now, - liveaction_id=liveaction_db.id) + status=status, + result=result, + context=context, + end_timestamp=now, + liveaction_id=liveaction_db.id, + ) self.assertEqual(origliveaction_db.id, newliveaction_db.id) self.assertEqual(newliveaction_db.status, status) - self.assertIn('a.b.c', list(result.keys())) + self.assertIn("a.b.c", list(result.keys())) self.assertDictEqual(newliveaction_db.result, result) self.assertDictEqual(newliveaction_db.context, context) self.assertEqual(newliveaction_db.end_timestamp, now) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_LiveAction_status_invalid(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params liveaction_db = LiveAction.add_or_update(liveaction_db) # Update by id. - self.assertRaises(ValueError, action_db_utils.update_liveaction_status, - status='mea culpa', liveaction_id=liveaction_db.id) + self.assertRaises( + ValueError, + action_db_utils.update_liveaction_status, + status="mea culpa", + liveaction_id=liveaction_db.id, + ) # Verify that state is not published. self.assertFalse(LiveActionPublisher.publish_state.called) - @mock.patch.object(LiveActionPublisher, 'publish_state', mock.MagicMock()) + @mock.patch.object(LiveActionPublisher, "publish_state", mock.MagicMock()) def test_update_same_liveaction_status(self): liveaction_db = LiveActionDB() - liveaction_db.status = 'requested' + liveaction_db.status = "requested" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ResourceReference( name=ActionDBUtilsTestCase.action_db.name, - pack=ActionDBUtilsTestCase.action_db.pack).ref + pack=ActionDBUtilsTestCase.action_db.pack, + ).ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params liveaction_db = LiveAction.add_or_update(liveaction_db) @@ -312,141 +363,150 @@ def test_update_same_liveaction_status(self): # Update by id. newliveaction_db = action_db_utils.update_liveaction_status( - status='requested', liveaction_id=liveaction_db.id) + status="requested", liveaction_id=liveaction_db.id + ) # Verify id didn't change. self.assertEqual(origliveaction_db.id, newliveaction_db.id) - self.assertEqual(newliveaction_db.status, 'requested') + self.assertEqual(newliveaction_db.status, "requested") # Verify that state is not published. self.assertFalse(LiveActionPublisher.publish_state.called) def test_get_args(self): - params = { - 'actionstr': 'foo', - 'actionint': 20, - 'runnerint': 555 - } - pos_args, named_args = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, ['20', '', 'foo', '', '', '', '', ''], - 'Positional args not parsed correctly.') - self.assertNotIn('actionint', named_args) - self.assertNotIn('actionstr', named_args) - self.assertEqual(named_args.get('runnerint'), 555) + params = {"actionstr": "foo", "actionint": 20, "runnerint": 555} + pos_args, named_args = action_db_utils.get_args( + params, ActionDBUtilsTestCase.action_db + ) + self.assertListEqual( + pos_args, + ["20", "", "foo", "", "", "", "", ""], + "Positional args not parsed correctly.", + ) + self.assertNotIn("actionint", named_args) + self.assertNotIn("actionstr", named_args) + self.assertEqual(named_args.get("runnerint"), 555) # Test serialization for different positional argument types and values # Test all the values provided params = { - 'actionint': 1, - 'actionfloat': 1.5, - 'actionstr': 'string value', - 'actionbool': True, - 'actionarray': ['foo', 'bar', 'baz', 'qux'], - 'actionlist': ['foo', 'bar', 'baz'], - 'actionobject': {'a': 1, 'b': '2'}, + "actionint": 1, + "actionfloat": 1.5, + "actionstr": "string value", + "actionbool": True, + "actionarray": ["foo", "bar", "baz", "qux"], + "actionlist": ["foo", "bar", "baz"], + "actionobject": {"a": 1, "b": "2"}, } expected_pos_args = [ - '1', - '1.5', - 'string value', - '1', - 'foo,bar,baz,qux', - 'foo,bar,baz', + "1", + "1.5", + "string value", + "1", + "foo,bar,baz,qux", + "foo,bar,baz", '{"a": 1, "b": "2"}', - '' + "", ] pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, expected_pos_args, - 'Positional args not parsed / serialized correctly.') + self.assertListEqual( + pos_args, + expected_pos_args, + "Positional args not parsed / serialized correctly.", + ) params = { - 'actionint': 1, - 'actionfloat': 1.5, - 'actionstr': 'string value', - 'actionbool': False, - 'actionarray': [], - 'actionlist': [], - 'actionobject': {'a': 1, 'b': '2'}, + "actionint": 1, + "actionfloat": 1.5, + "actionstr": "string value", + "actionbool": False, + "actionarray": [], + "actionlist": [], + "actionobject": {"a": 1, "b": "2"}, } expected_pos_args = [ - '1', - '1.5', - 'string value', - '0', - '', - '', + "1", + "1.5", + "string value", + "0", + "", + "", '{"a": 1, "b": "2"}', - '' + "", ] pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, expected_pos_args, - 'Positional args not parsed / serialized correctly.') + self.assertListEqual( + pos_args, + expected_pos_args, + "Positional args not parsed / serialized correctly.", + ) # Test none values params = { - 'actionint': None, - 'actionfloat': None, - 'actionstr': None, - 'actionbool': None, - 'actionarray': None, - 'actionlist': None, - 'actionobject': None, + "actionint": None, + "actionfloat": None, + "actionstr": None, + "actionbool": None, + "actionarray": None, + "actionlist": None, + "actionobject": None, } - expected_pos_args = [ - '', - '', - '', - '', - '', - '', - '', - '' - ] + expected_pos_args = ["", "", "", "", "", "", "", ""] pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, expected_pos_args, - 'Positional args not parsed / serialized correctly.') + self.assertListEqual( + pos_args, + expected_pos_args, + "Positional args not parsed / serialized correctly.", + ) # Test unicode values params = { - 'actionstr': 'bar č š hello đ č p ž Ž a 💩😁', - 'actionint': 20, - 'runnerint': 555 + "actionstr": "bar č š hello đ č p ž Ž a 💩😁", + "actionint": 20, + "runnerint": 555, } expected_pos_args = [ - '20', - '', - u'bar č š hello đ č p ž Ž a 💩😁', - '', - '', - '', - '', - '' + "20", + "", + "bar č š hello đ č p ž Ž a 💩😁", + "", + "", + "", + "", + "", ] - pos_args, named_args = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, expected_pos_args, 'Positional args not parsed correctly.') + pos_args, named_args = action_db_utils.get_args( + params, ActionDBUtilsTestCase.action_db + ) + self.assertListEqual( + pos_args, expected_pos_args, "Positional args not parsed correctly." + ) # Test arrays and lists with values of different types params = { - 'actionarray': [None, False, 1, 4.2e1, '1e3', 'foo'], - 'actionlist': [None, False, 1, 73e-2, '1e2', 'bar'] + "actionarray": [None, False, 1, 4.2e1, "1e3", "foo"], + "actionlist": [None, False, 1, 73e-2, "1e2", "bar"], } expected_pos_args = [ - '', - '', - '', - '', - 'None,False,1,42.0,1e3,foo', - 'None,False,1,0.73,1e2,bar', - '', - '' + "", + "", + "", + "", + "None,False,1,42.0,1e3,foo", + "None,False,1,0.73,1e2,bar", + "", + "", ] pos_args, _ = action_db_utils.get_args(params, ActionDBUtilsTestCase.action_db) - self.assertListEqual(pos_args, expected_pos_args, - 'Positional args not parsed / serialized correctly.') + self.assertListEqual( + pos_args, + expected_pos_args, + "Positional args not parsed / serialized correctly.", + ) - self.assertNotIn('actionint', named_args) - self.assertNotIn('actionstr', named_args) - self.assertEqual(named_args.get('runnerint'), 555) + self.assertNotIn("actionint", named_args) + self.assertNotIn("actionstr", named_args) + self.assertEqual(named_args.get("runnerint"), 555) @classmethod def _setup_test_models(cls): @@ -456,63 +516,65 @@ def _setup_test_models(cls): @classmethod def setup_runner(cls): test_runner = { - 'name': 'test-runner', - 'description': 'A test runner.', - 'enabled': True, - 'runner_parameters': { - 'runnerstr': { - 'description': 'Foo str param.', - 'type': 'string', - 'default': 'defaultfoo' + "name": "test-runner", + "description": "A test runner.", + "enabled": True, + "runner_parameters": { + "runnerstr": { + "description": "Foo str param.", + "type": "string", + "default": "defaultfoo", }, - 'runnerint': { - 'description': 'Foo int param.', - 'type': 'number' + "runnerint": {"description": "Foo int param.", "type": "number"}, + "runnerdummy": { + "description": "Dummy param.", + "type": "string", + "default": "runnerdummy", }, - 'runnerdummy': { - 'description': 'Dummy param.', - 'type': 'string', - 'default': 'runnerdummy' - } }, - 'runner_module': 'tests.test_runner' + "runner_module": "tests.test_runner", } runnertype_api = RunnerTypeAPI(**test_runner) ActionDBUtilsTestCase.runnertype_db = RunnerType.add_or_update( - RunnerTypeAPI.to_model(runnertype_api)) + RunnerTypeAPI.to_model(runnertype_api) + ) @classmethod - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def setup_action_models(cls): - pack = 'wolfpack' - name = 'action-1' + pack = "wolfpack" + name = "action-1" parameters = { - 'actionint': {'type': 'number', 'default': 10, 'position': 0}, - 'actionfloat': {'type': 'float', 'required': False, 'position': 1}, - 'actionstr': {'type': 'string', 'required': True, 'position': 2}, - 'actionbool': {'type': 'boolean', 'required': False, 'position': 3}, - 'actionarray': {'type': 'array', 'required': False, 'position': 4}, - 'actionlist': {'type': 'list', 'required': False, 'position': 5}, - 'actionobject': {'type': 'object', 'required': False, 'position': 6}, - 'actionnull': {'type': 'null', 'required': False, 'position': 7}, - - 'runnerdummy': {'type': 'string', 'default': 'actiondummy'} + "actionint": {"type": "number", "default": 10, "position": 0}, + "actionfloat": {"type": "float", "required": False, "position": 1}, + "actionstr": {"type": "string", "required": True, "position": 2}, + "actionbool": {"type": "boolean", "required": False, "position": 3}, + "actionarray": {"type": "array", "required": False, "position": 4}, + "actionlist": {"type": "list", "required": False, "position": 5}, + "actionobject": {"type": "object", "required": False, "position": 6}, + "actionnull": {"type": "null", "required": False, "position": 7}, + "runnerdummy": {"type": "string", "default": "actiondummy"}, } - action_db = ActionDB(pack=pack, name=name, description='awesomeness', - enabled=True, - ref=ResourceReference(name=name, pack=pack).ref, - entry_point='', runner_type={'name': 'test-runner'}, - parameters=parameters) + action_db = ActionDB( + pack=pack, + name=name, + description="awesomeness", + enabled=True, + ref=ResourceReference(name=name, pack=pack).ref, + entry_point="", + runner_type={"name": "test-runner"}, + parameters=parameters, + ) ActionDBUtilsTestCase.action_db = Action.add_or_update(action_db) liveaction_db = LiveActionDB() - liveaction_db.status = 'initializing' + liveaction_db.status = "initializing" liveaction_db.start_timestamp = get_datetime_utc_now() liveaction_db.action = ActionDBUtilsTestCase.action_db.ref params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555 + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, } liveaction_db.parameters = params ActionDBUtilsTestCase.liveaction_db = LiveAction.add_or_update(liveaction_db) diff --git a/st2common/tests/unit/test_action_param_utils.py b/st2common/tests/unit/test_action_param_utils.py index 08a6654f218..5eecf018dcd 100644 --- a/st2common/tests/unit/test_action_param_utils.py +++ b/st2common/tests/unit/test_action_param_utils.py @@ -28,23 +28,16 @@ TEST_FIXTURES = { - 'actions': [ - 'action1.yaml', - 'action3.yaml' - ], - 'runners': [ - 'testrunner1.yaml', - 'testrunner3.yaml' - ] + "actions": ["action1.yaml", "action3.yaml"], + "runners": ["testrunner1.yaml", "testrunner3.yaml"], } -PACK = 'generic' +PACK = "generic" LOADER = FixturesLoader() FIXTURES = LOADER.load_fixtures(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) class ActionParamsUtilsTest(DbTestCase): - @classmethod def setUpClass(cls): super(ActionParamsUtilsTest, cls).setUpClass() @@ -54,86 +47,105 @@ def setUpClass(cls): cls.runnertype_dbs = {} cls.action_dbs = {} - for _, fixture in six.iteritems(FIXTURES['runners']): + for _, fixture in six.iteritems(FIXTURES["runners"]): instance = RunnerTypeAPI(**fixture) runnertype_db = RunnerType.add_or_update(RunnerTypeAPI.to_model(instance)) cls.runnertype_dbs[runnertype_db.name] = runnertype_db - for _, fixture in six.iteritems(FIXTURES['actions']): + for _, fixture in six.iteritems(FIXTURES["actions"]): instance = ActionAPI(**fixture) action_db = Action.add_or_update(ActionAPI.to_model(instance)) cls.action_dbs[action_db.name] = action_db def test_merge_action_runner_params_meta(self): required, optional, immutable = action_param_utils.get_params_view( - action_db=self.action_dbs['action-1'], - runner_db=self.runnertype_dbs['test-runner-1']) + action_db=self.action_dbs["action-1"], + runner_db=self.runnertype_dbs["test-runner-1"], + ) merged = {} merged.update(required) merged.update(optional) merged.update(immutable) consolidated = action_param_utils.get_params_view( - action_db=self.action_dbs['action-1'], - runner_db=self.runnertype_dbs['test-runner-1'], - merged_only=True) + action_db=self.action_dbs["action-1"], + runner_db=self.runnertype_dbs["test-runner-1"], + merged_only=True, + ) # Validate that merged_only view works. self.assertEqual(merged, consolidated) # Validate required params. - self.assertEqual(len(required), 1, 'Required should contain only one param.') - self.assertIn('actionstr', required, 'actionstr param is a required param.') - self.assertNotIn('actionstr', optional, 'actionstr should not be in optional parameters') - self.assertNotIn('actionstr', immutable, 'actionstr should not be in immutable parameters') - self.assertIn('actionstr', merged, 'actionstr should be in action parameters') + self.assertEqual(len(required), 1, "Required should contain only one param.") + self.assertIn("actionstr", required, "actionstr param is a required param.") + self.assertNotIn( + "actionstr", optional, "actionstr should not be in optional parameters" + ) + self.assertNotIn( + "actionstr", immutable, "actionstr should not be in immutable parameters" + ) + self.assertIn("actionstr", merged, "actionstr should be in action parameters") # Validate immutable params. - self.assertIn('runnerimmutable', immutable, 'runnerimmutable should be in immutable.') - self.assertIn('actionimmutable', immutable, 'actionimmutable should be in immutable.') + self.assertIn( + "runnerimmutable", immutable, "runnerimmutable should be in immutable." + ) + self.assertIn( + "actionimmutable", immutable, "actionimmutable should be in immutable." + ) # Validate optional params. for opt in optional: - self.assertIn(opt, merged, 'Optional %s should be in action parameters' % opt) - self.assertNotIn(opt, required, 'Optional %s should not be in required params' % opt) - self.assertNotIn(opt, immutable, 'Optional %s should not be in immutable params' % opt) + self.assertIn( + opt, merged, "Optional %s should be in action parameters" % opt + ) + self.assertNotIn( + opt, required, "Optional %s should not be in required params" % opt + ) + self.assertNotIn( + opt, immutable, "Optional %s should not be in immutable params" % opt + ) def test_merge_param_meta_values(self): runner_meta = copy.deepcopy( - self.runnertype_dbs['test-runner-1'].runner_parameters['runnerdummy']) - action_meta = copy.deepcopy(self.action_dbs['action-1'].parameters['runnerdummy']) - merged_meta = action_param_utils._merge_param_meta_values(action_meta=action_meta, - runner_meta=runner_meta) + self.runnertype_dbs["test-runner-1"].runner_parameters["runnerdummy"] + ) + action_meta = copy.deepcopy( + self.action_dbs["action-1"].parameters["runnerdummy"] + ) + merged_meta = action_param_utils._merge_param_meta_values( + action_meta=action_meta, runner_meta=runner_meta + ) # Description is in runner meta but not in action meta. - self.assertEqual(merged_meta['description'], runner_meta['description']) + self.assertEqual(merged_meta["description"], runner_meta["description"]) # Default value is overridden in action. - self.assertEqual(merged_meta['default'], action_meta['default']) + self.assertEqual(merged_meta["default"], action_meta["default"]) # Immutability is set in action. - self.assertEqual(merged_meta['immutable'], action_meta['immutable']) + self.assertEqual(merged_meta["immutable"], action_meta["immutable"]) def test_merge_param_meta_require_override(self): - action_meta = { - 'required': False - } - runner_meta = { - 'required': True - } - merged_meta = action_param_utils._merge_param_meta_values(action_meta=action_meta, - runner_meta=runner_meta) + action_meta = {"required": False} + runner_meta = {"required": True} + merged_meta = action_param_utils._merge_param_meta_values( + action_meta=action_meta, runner_meta=runner_meta + ) - self.assertEqual(merged_meta['required'], action_meta['required']) + self.assertEqual(merged_meta["required"], action_meta["required"]) def test_validate_action_inputs(self): requires, unexpected = action_param_utils.validate_action_parameters( - self.action_dbs['action-1'].ref, {'foo': 'bar'}) + self.action_dbs["action-1"].ref, {"foo": "bar"} + ) - self.assertListEqual(requires, ['actionstr']) - self.assertListEqual(unexpected, ['foo']) + self.assertListEqual(requires, ["actionstr"]) + self.assertListEqual(unexpected, ["foo"]) def test_validate_overridden_action_inputs(self): requires, unexpected = action_param_utils.validate_action_parameters( - self.action_dbs['action-3'].ref, {'k1': 'foo'}) + self.action_dbs["action-3"].ref, {"k1": "foo"} + ) self.assertListEqual(requires, []) self.assertListEqual(unexpected, []) diff --git a/st2common/tests/unit/test_action_system_models.py b/st2common/tests/unit/test_action_system_models.py index 8098759b564..c8812acf38e 100644 --- a/st2common/tests/unit/test_action_system_models.py +++ b/st2common/tests/unit/test_action_system_models.py @@ -19,24 +19,30 @@ from st2common.models.system.action import RemoteAction from st2common.models.system.action import RemoteScriptAction -__all__ = [ - 'RemoteActionTestCase', - 'RemoteScriptActionTestCase' -] +__all__ = ["RemoteActionTestCase", "RemoteScriptActionTestCase"] class RemoteActionTestCase(unittest2.TestCase): def test_instantiation(self): - action = RemoteAction(name='name', action_exec_id='aeid', command='ls -la', - env_vars={'a': 1}, on_behalf_user='onbehalf', user='user', - hosts=['127.0.0.1'], parallel=False, sudo=True, timeout=10) - self.assertEqual(action.name, 'name') - self.assertEqual(action.action_exec_id, 'aeid') - self.assertEqual(action.command, 'ls -la') - self.assertEqual(action.env_vars, {'a': 1}) - self.assertEqual(action.on_behalf_user, 'onbehalf') - self.assertEqual(action.user, 'user') - self.assertEqual(action.hosts, ['127.0.0.1']) + action = RemoteAction( + name="name", + action_exec_id="aeid", + command="ls -la", + env_vars={"a": 1}, + on_behalf_user="onbehalf", + user="user", + hosts=["127.0.0.1"], + parallel=False, + sudo=True, + timeout=10, + ) + self.assertEqual(action.name, "name") + self.assertEqual(action.action_exec_id, "aeid") + self.assertEqual(action.command, "ls -la") + self.assertEqual(action.env_vars, {"a": 1}) + self.assertEqual(action.on_behalf_user, "onbehalf") + self.assertEqual(action.user, "user") + self.assertEqual(action.hosts, ["127.0.0.1"]) self.assertEqual(action.parallel, False) self.assertEqual(action.sudo, True) self.assertEqual(action.timeout, 10) @@ -44,26 +50,35 @@ def test_instantiation(self): class RemoteScriptActionTestCase(unittest2.TestCase): def test_instantiation(self): - action = RemoteScriptAction(name='name', action_exec_id='aeid', - script_local_path_abs='/tmp/sc/ma_script.sh', - script_local_libs_path_abs='/tmp/sc/libs', named_args=None, - positional_args=None, env_vars={'a': 1}, - on_behalf_user='onbehalf', user='user', - remote_dir='/home/mauser', hosts=['127.0.0.1'], - parallel=False, sudo=True, timeout=10) - self.assertEqual(action.name, 'name') - self.assertEqual(action.action_exec_id, 'aeid') - self.assertEqual(action.script_local_libs_path_abs, '/tmp/sc/libs') - self.assertEqual(action.env_vars, {'a': 1}) - self.assertEqual(action.on_behalf_user, 'onbehalf') - self.assertEqual(action.user, 'user') - self.assertEqual(action.remote_dir, '/home/mauser') - self.assertEqual(action.hosts, ['127.0.0.1']) + action = RemoteScriptAction( + name="name", + action_exec_id="aeid", + script_local_path_abs="/tmp/sc/ma_script.sh", + script_local_libs_path_abs="/tmp/sc/libs", + named_args=None, + positional_args=None, + env_vars={"a": 1}, + on_behalf_user="onbehalf", + user="user", + remote_dir="/home/mauser", + hosts=["127.0.0.1"], + parallel=False, + sudo=True, + timeout=10, + ) + self.assertEqual(action.name, "name") + self.assertEqual(action.action_exec_id, "aeid") + self.assertEqual(action.script_local_libs_path_abs, "/tmp/sc/libs") + self.assertEqual(action.env_vars, {"a": 1}) + self.assertEqual(action.on_behalf_user, "onbehalf") + self.assertEqual(action.user, "user") + self.assertEqual(action.remote_dir, "/home/mauser") + self.assertEqual(action.hosts, ["127.0.0.1"]) self.assertEqual(action.parallel, False) self.assertEqual(action.sudo, True) self.assertEqual(action.timeout, 10) - self.assertEqual(action.script_local_dir, '/tmp/sc') - self.assertEqual(action.script_name, 'ma_script.sh') - self.assertEqual(action.remote_script, '/home/mauser/ma_script.sh') - self.assertEqual(action.command, '/home/mauser/ma_script.sh') + self.assertEqual(action.script_local_dir, "/tmp/sc") + self.assertEqual(action.script_name, "ma_script.sh") + self.assertEqual(action.remote_script, "/home/mauser/ma_script.sh") + self.assertEqual(action.command, "/home/mauser/ma_script.sh") diff --git a/st2common/tests/unit/test_actionchain_schema.py b/st2common/tests/unit/test_actionchain_schema.py index 5c968c9a115..e5bba6c0e20 100644 --- a/st2common/tests/unit/test_actionchain_schema.py +++ b/st2common/tests/unit/test_actionchain_schema.py @@ -20,42 +20,48 @@ from st2common.models.system import actionchain from st2tests.fixturesloader import FixturesLoader -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_FIXTURES = { - 'actionchains': ['chain1.yaml', 'malformedchain.yaml', 'no_default_chain.yaml', - 'chain_with_vars.yaml', 'chain_with_publish.yaml'] + "actionchains": [ + "chain1.yaml", + "malformedchain.yaml", + "no_default_chain.yaml", + "chain_with_vars.yaml", + "chain_with_publish.yaml", + ] } -FIXTURES = FixturesLoader().load_fixtures(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_FIXTURES) -CHAIN_1 = FIXTURES['actionchains']['chain1.yaml'] -MALFORMED_CHAIN = FIXTURES['actionchains']['malformedchain.yaml'] -NO_DEFAULT_CHAIN = FIXTURES['actionchains']['no_default_chain.yaml'] -CHAIN_WITH_VARS = FIXTURES['actionchains']['chain_with_vars.yaml'] -CHAIN_WITH_PUBLISH = FIXTURES['actionchains']['chain_with_publish.yaml'] +FIXTURES = FixturesLoader().load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES +) +CHAIN_1 = FIXTURES["actionchains"]["chain1.yaml"] +MALFORMED_CHAIN = FIXTURES["actionchains"]["malformedchain.yaml"] +NO_DEFAULT_CHAIN = FIXTURES["actionchains"]["no_default_chain.yaml"] +CHAIN_WITH_VARS = FIXTURES["actionchains"]["chain_with_vars.yaml"] +CHAIN_WITH_PUBLISH = FIXTURES["actionchains"]["chain_with_publish.yaml"] class ActionChainSchemaTest(unittest2.TestCase): - def test_actionchain_schema_valid(self): chain = actionchain.ActionChain(**CHAIN_1) - self.assertEqual(len(chain.chain), len(CHAIN_1['chain'])) - self.assertEqual(chain.default, CHAIN_1['default']) + self.assertEqual(len(chain.chain), len(CHAIN_1["chain"])) + self.assertEqual(chain.default, CHAIN_1["default"]) def test_actionchain_no_default(self): chain = actionchain.ActionChain(**NO_DEFAULT_CHAIN) - self.assertEqual(len(chain.chain), len(NO_DEFAULT_CHAIN['chain'])) + self.assertEqual(len(chain.chain), len(NO_DEFAULT_CHAIN["chain"])) self.assertEqual(chain.default, None) def test_actionchain_with_vars(self): chain = actionchain.ActionChain(**CHAIN_WITH_VARS) - self.assertEqual(len(chain.chain), len(CHAIN_WITH_VARS['chain'])) - self.assertEqual(len(chain.vars), len(CHAIN_WITH_VARS['vars'])) + self.assertEqual(len(chain.chain), len(CHAIN_WITH_VARS["chain"])) + self.assertEqual(len(chain.vars), len(CHAIN_WITH_VARS["vars"])) def test_actionchain_with_publish(self): chain = actionchain.ActionChain(**CHAIN_WITH_PUBLISH) - self.assertEqual(len(chain.chain), len(CHAIN_WITH_PUBLISH['chain'])) - self.assertEqual(len(chain.chain[0].publish), - len(CHAIN_WITH_PUBLISH['chain'][0]['publish'])) + self.assertEqual(len(chain.chain), len(CHAIN_WITH_PUBLISH["chain"])) + self.assertEqual( + len(chain.chain[0].publish), len(CHAIN_WITH_PUBLISH["chain"][0]["publish"]) + ) def test_actionchain_schema_invalid(self): with self.assertRaises(ValidationError): diff --git a/st2common/tests/unit/test_aliasesregistrar.py b/st2common/tests/unit/test_aliasesregistrar.py index b8278305943..4f17246dcfd 100644 --- a/st2common/tests/unit/test_aliasesregistrar.py +++ b/st2common/tests/unit/test_aliasesregistrar.py @@ -22,22 +22,20 @@ from st2tests import DbTestCase from st2tests import fixturesloader -__all__ = [ - 'TestAliasRegistrar' -] +__all__ = ["TestAliasRegistrar"] -ALIASES_FIXTURE_PACK_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), - 'dummy_pack_1') -ALIASES_FIXTURE_PATH = os.path.join(ALIASES_FIXTURE_PACK_PATH, 'aliases') +ALIASES_FIXTURE_PACK_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_1" +) +ALIASES_FIXTURE_PATH = os.path.join(ALIASES_FIXTURE_PACK_PATH, "aliases") class TestAliasRegistrar(DbTestCase): - def test_alias_registration(self): count = aliasesregistrar.register_aliases(pack_dir=ALIASES_FIXTURE_PACK_PATH) # expect all files to contain be aliases self.assertEqual(count, len(os.listdir(ALIASES_FIXTURE_PATH))) action_alias_dbs = ActionAlias.get_all() - self.assertEqual(action_alias_dbs[0].metadata_file, 'aliases/alias1.yaml') + self.assertEqual(action_alias_dbs[0].metadata_file, "aliases/alias1.yaml") diff --git a/st2common/tests/unit/test_api_model_validation.py b/st2common/tests/unit/test_api_model_validation.py index d5f250482b9..20eb98ce6c3 100644 --- a/st2common/tests/unit/test_api_model_validation.py +++ b/st2common/tests/unit/test_api_model_validation.py @@ -18,196 +18,197 @@ from st2common.models.api.base import BaseAPI -__all__ = [ - 'APIModelValidationTestCase' -] +__all__ = ["APIModelValidationTestCase"] class MockAPIModel1(BaseAPI): model = None schema = { - 'title': 'MockAPIModel', - 'description': 'Test', - 'type': 'object', - 'properties': { - 'id': { - 'description': 'The unique identifier for the action runner.', - 'type': ['string', 'null'], - 'default': None + "title": "MockAPIModel", + "description": "Test", + "type": "object", + "properties": { + "id": { + "description": "The unique identifier for the action runner.", + "type": ["string", "null"], + "default": None, }, - 'name': { - 'description': 'The name of the action runner.', - 'type': 'string', - 'required': True + "name": { + "description": "The name of the action runner.", + "type": "string", + "required": True, }, - 'description': { - 'description': 'The description of the action runner.', - 'type': 'string' + "description": { + "description": "The description of the action runner.", + "type": "string", }, - 'enabled': { - 'type': 'boolean', - 'default': True - }, - 'parameters': { - 'type': 'object' - }, - 'permission_grants': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'resource_uid': { - 'type': 'string', - 'description': 'UID of a resource to which this grant applies to.', - 'required': False, - 'default': 'unknown' + "enabled": {"type": "boolean", "default": True}, + "parameters": {"type": "object"}, + "permission_grants": { + "type": "array", + "items": { + "type": "object", + "properties": { + "resource_uid": { + "type": "string", + "description": "UID of a resource to which this grant applies to.", + "required": False, + "default": "unknown", }, - 'enabled': { - 'type': 'boolean', - 'default': True + "enabled": {"type": "boolean", "default": True}, + "description": { + "type": "string", + "description": "Description", + "required": False, }, - 'description': { - 'type': 'string', - 'description': 'Description', - 'required': False - } - } + }, }, - 'default': [] - } + "default": [], + }, }, - 'additionalProperties': False + "additionalProperties": False, } class MockAPIModel2(BaseAPI): model = None schema = { - 'title': 'MockAPIModel2', - 'description': 'Test', - 'type': 'object', - 'properties': { - 'id': { - 'description': 'The unique identifier for the action runner.', - 'type': 'string', - 'default': None + "title": "MockAPIModel2", + "description": "Test", + "type": "object", + "properties": { + "id": { + "description": "The unique identifier for the action runner.", + "type": "string", + "default": None, }, - 'permission_grants': { - 'type': 'array', - 'items': { - 'type': 'object', - 'properties': { - 'resource_uid': { - 'type': 'string', - 'description': 'UID of a resource to which this grant applies to.', - 'required': False, - 'default': None + "permission_grants": { + "type": "array", + "items": { + "type": "object", + "properties": { + "resource_uid": { + "type": "string", + "description": "UID of a resource to which this grant applies to.", + "required": False, + "default": None, }, - 'description': { - 'type': 'string', - 'required': True - } - } + "description": {"type": "string", "required": True}, + }, }, - 'default': [] + "default": [], }, - 'parameters': { - 'type': 'object', - 'properties': { - 'id': { - 'type': 'string', - 'default': None - }, - 'name': { - 'type': 'string', - 'required': True - } + "parameters": { + "type": "object", + "properties": { + "id": {"type": "string", "default": None}, + "name": {"type": "string", "required": True}, }, - 'additionalProperties': False, - } + "additionalProperties": False, + }, }, - 'additionalProperties': False + "additionalProperties": False, } class APIModelValidationTestCase(unittest2.TestCase): def test_validate_default_values_are_set(self): # no "permission_grants" attribute - mock_model_api = MockAPIModel1(name='name') - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(mock_model_api.name, 'name') - self.assertEqual(getattr(mock_model_api, 'enabled', None), None) - self.assertEqual(getattr(mock_model_api, 'permission_grants', None), None) + mock_model_api = MockAPIModel1(name="name") + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual(mock_model_api.name, "name") + self.assertEqual(getattr(mock_model_api, "enabled", None), None) + self.assertEqual(getattr(mock_model_api, "permission_grants", None), None) mock_model_api_validated = mock_model_api.validate() # Validate it doesn't modify object in place - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(mock_model_api.name, 'name') - self.assertEqual(getattr(mock_model_api, 'enabled', None), None) + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual(mock_model_api.name, "name") + self.assertEqual(getattr(mock_model_api, "enabled", None), None) # Verify cleaned object self.assertEqual(mock_model_api_validated.id, None) - self.assertEqual(mock_model_api_validated.name, 'name') + self.assertEqual(mock_model_api_validated.name, "name") self.assertEqual(mock_model_api_validated.enabled, True) self.assertEqual(mock_model_api_validated.permission_grants, []) # "permission_grants" attribute present, but child missing - mock_model_api = MockAPIModel1(name='name', enabled=False, - permission_grants=[{}, {'description': 'test'}]) - self.assertEqual(mock_model_api.name, 'name') + mock_model_api = MockAPIModel1( + name="name", enabled=False, permission_grants=[{}, {"description": "test"}] + ) + self.assertEqual(mock_model_api.name, "name") self.assertEqual(mock_model_api.enabled, False) - self.assertEqual(mock_model_api.permission_grants, [{}, {'description': 'test'}]) + self.assertEqual( + mock_model_api.permission_grants, [{}, {"description": "test"}] + ) mock_model_api_validated = mock_model_api.validate() # Validate it doesn't modify object in place - self.assertEqual(mock_model_api.name, 'name') + self.assertEqual(mock_model_api.name, "name") self.assertEqual(mock_model_api.enabled, False) - self.assertEqual(mock_model_api.permission_grants, [{}, {'description': 'test'}]) + self.assertEqual( + mock_model_api.permission_grants, [{}, {"description": "test"}] + ) # Verify cleaned object self.assertEqual(mock_model_api_validated.id, None) - self.assertEqual(mock_model_api_validated.name, 'name') + self.assertEqual(mock_model_api_validated.name, "name") self.assertEqual(mock_model_api_validated.enabled, False) - self.assertEqual(mock_model_api_validated.permission_grants, - [{'resource_uid': 'unknown', 'enabled': True}, - {'resource_uid': 'unknown', 'enabled': True, 'description': 'test'}]) + self.assertEqual( + mock_model_api_validated.permission_grants, + [ + {"resource_uid": "unknown", "enabled": True}, + {"resource_uid": "unknown", "enabled": True, "description": "test"}, + ], + ) def test_validate_nested_attribute_with_default_not_provided(self): mock_model_api = MockAPIModel2() - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(getattr(mock_model_api, 'permission_grants', 'notset'), 'notset') - self.assertEqual(getattr(mock_model_api, 'parameters', 'notset'), 'notset') + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual( + getattr(mock_model_api, "permission_grants", "notset"), "notset" + ) + self.assertEqual(getattr(mock_model_api, "parameters", "notset"), "notset") mock_model_api_validated = mock_model_api.validate() # Validate it doesn't modify object in place - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(getattr(mock_model_api, 'permission_grants', 'notset'), 'notset') - self.assertEqual(getattr(mock_model_api, 'parameters', 'notset'), 'notset') + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual( + getattr(mock_model_api, "permission_grants", "notset"), "notset" + ) + self.assertEqual(getattr(mock_model_api, "parameters", "notset"), "notset") # Verify cleaned object self.assertEqual(mock_model_api_validated.id, None) self.assertEqual(mock_model_api_validated.permission_grants, []) - self.assertEqual(getattr(mock_model_api_validated, 'parameters', 'notset'), 'notset') + self.assertEqual( + getattr(mock_model_api_validated, "parameters", "notset"), "notset" + ) def test_validate_allow_default_none_for_any_type(self): - mock_model_api = MockAPIModel2(permission_grants=[{'description': 'test'}], - parameters={'name': 'test'}) - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(mock_model_api.permission_grants, [{'description': 'test'}]) - self.assertEqual(mock_model_api.parameters, {'name': 'test'}) + mock_model_api = MockAPIModel2( + permission_grants=[{"description": "test"}], parameters={"name": "test"} + ) + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual(mock_model_api.permission_grants, [{"description": "test"}]) + self.assertEqual(mock_model_api.parameters, {"name": "test"}) mock_model_api_validated = mock_model_api.validate() # Validate it doesn't modify object in place - self.assertEqual(getattr(mock_model_api, 'id', 'notset'), 'notset') - self.assertEqual(mock_model_api.permission_grants, [{'description': 'test'}]) - self.assertEqual(mock_model_api.parameters, {'name': 'test'}) + self.assertEqual(getattr(mock_model_api, "id", "notset"), "notset") + self.assertEqual(mock_model_api.permission_grants, [{"description": "test"}]) + self.assertEqual(mock_model_api.parameters, {"name": "test"}) # Verify cleaned object self.assertEqual(mock_model_api_validated.id, None) - self.assertEqual(mock_model_api_validated.permission_grants, - [{'description': 'test', 'resource_uid': None}]) - self.assertEqual(mock_model_api_validated.parameters, {'id': None, 'name': 'test'}) + self.assertEqual( + mock_model_api_validated.permission_grants, + [{"description": "test", "resource_uid": None}], + ) + self.assertEqual( + mock_model_api_validated.parameters, {"id": None, "name": "test"} + ) diff --git a/st2common/tests/unit/test_casts.py b/st2common/tests/unit/test_casts.py index 55e95ca781f..62bf0ac4e82 100644 --- a/st2common/tests/unit/test_casts.py +++ b/st2common/tests/unit/test_casts.py @@ -23,19 +23,19 @@ class CastsTestCase(unittest2.TestCase): def test_cast_string(self): - cast_func = get_cast('string') + cast_func = get_cast("string") - value = 'test1' + value = "test1" result = cast_func(value) - self.assertEqual(result, 'test1') + self.assertEqual(result, "test1") - value = u'test2' + value = "test2" result = cast_func(value) - self.assertEqual(result, u'test2') + self.assertEqual(result, "test2") - value = '' + value = "" result = cast_func(value) - self.assertEqual(result, '') + self.assertEqual(result, "") # None should be preserved value = None @@ -48,7 +48,7 @@ def test_cast_string(self): self.assertRaisesRegexp(ValueError, expected_msg, cast_func, value) def test_cast_array(self): - cast_func = get_cast('array') + cast_func = get_cast("array") # Python literal value = str([1, 2, 3]) diff --git a/st2common/tests/unit/test_config_loader.py b/st2common/tests/unit/test_config_loader.py index f59e3efe4f2..e1849d7868e 100644 --- a/st2common/tests/unit/test_config_loader.py +++ b/st2common/tests/unit/test_config_loader.py @@ -24,9 +24,7 @@ from st2tests.base import CleanDbTestCase -__all__ = [ - 'ContentPackConfigLoaderTestCase' -] +__all__ = ["ContentPackConfigLoaderTestCase"] class ContentPackConfigLoaderTestCase(CleanDbTestCase): @@ -37,7 +35,7 @@ def test_ensure_local_pack_config_feature_removed(self): # Test a scenario where all the values are loaded from pack local # config and pack global config (pack name.yaml) doesn't exist. # Test a scenario where no values are overridden in the datastore - loader = ContentPackConfigLoader(pack_name='dummy_pack_4') + loader = ContentPackConfigLoader(pack_name="dummy_pack_4") config = loader.get_config() expected_config = {} @@ -46,35 +44,39 @@ def test_ensure_local_pack_config_feature_removed(self): def test_get_config_some_values_overriden_in_datastore(self): # Test a scenario where some values are overriden in datastore via pack # flobal config - kvp_db = set_datastore_value_for_config_key(pack_name='dummy_pack_5', - key_name='api_secret', - value='some_api_secret', - secret=True, - user='joe') + kvp_db = set_datastore_value_for_config_key( + pack_name="dummy_pack_5", + key_name="api_secret", + value="some_api_secret", + secret=True, + user="joe", + ) # This is a secret so a value should be encrypted - self.assertTrue(kvp_db.value != 'some_api_secret') - self.assertTrue(len(kvp_db.value) > len('some_api_secret') * 2) + self.assertTrue(kvp_db.value != "some_api_secret") + self.assertTrue(len(kvp_db.value) > len("some_api_secret") * 2) self.assertTrue(kvp_db.secret) - kvp_db = set_datastore_value_for_config_key(pack_name='dummy_pack_5', - key_name='private_key_path', - value='some_private_key') - self.assertEqual(kvp_db.value, 'some_private_key') + kvp_db = set_datastore_value_for_config_key( + pack_name="dummy_pack_5", + key_name="private_key_path", + value="some_private_key", + ) + self.assertEqual(kvp_db.value, "some_private_key") self.assertFalse(kvp_db.secret) - loader = ContentPackConfigLoader(pack_name='dummy_pack_5', user='joe') + loader = ContentPackConfigLoader(pack_name="dummy_pack_5", user="joe") config = loader.get_config() # regions is provided in the pack global config # api_secret is dynamically loaded from the datastore for a particular user expected_config = { - 'api_key': 'some_api_key', - 'api_secret': 'some_api_secret', - 'regions': ['us-west-1'], - 'region': 'default-region-value', - 'private_key_path': 'some_private_key', - 'non_required_with_default_value': 'config value' + "api_key": "some_api_key", + "api_secret": "some_api_secret", + "regions": ["us-west-1"], + "region": "default-region-value", + "private_key_path": "some_private_key", + "non_required_with_default_value": "config value", } self.assertEqual(config, expected_config) @@ -82,26 +84,26 @@ def test_get_config_some_values_overriden_in_datastore(self): def test_get_config_default_value_from_config_schema_is_used(self): # No value is provided for "region" in the config, default value from config schema # should be used - loader = ContentPackConfigLoader(pack_name='dummy_pack_5') + loader = ContentPackConfigLoader(pack_name="dummy_pack_5") config = loader.get_config() - self.assertEqual(config['region'], 'default-region-value') + self.assertEqual(config["region"], "default-region-value") # Here a default value is specified in schema but an explicit value is provided in the # config - loader = ContentPackConfigLoader(pack_name='dummy_pack_1') + loader = ContentPackConfigLoader(pack_name="dummy_pack_1") config = loader.get_config() - self.assertEqual(config['region'], 'us-west-1') + self.assertEqual(config["region"], "us-west-1") # Config item attribute has required: false # Value is provided in the config - it should be used as provided - pack_name = 'dummy_pack_5' + pack_name = "dummy_pack_5" loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() - self.assertEqual(config['non_required_with_default_value'], 'config value') + self.assertEqual(config["non_required_with_default_value"], "config value") config_db = Config.get_by_pack(pack_name) - del config_db['values']['non_required_with_default_value'] + del config_db["values"]["non_required_with_default_value"] Config.add_or_update(config_db) # No value in the config - default value should be used @@ -111,10 +113,12 @@ def test_get_config_default_value_from_config_schema_is_used(self): # No config exists for that pack - default value should be used loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() - self.assertEqual(config['non_required_with_default_value'], 'some default value') + self.assertEqual( + config["non_required_with_default_value"], "some default value" + ) def test_default_values_from_schema_are_used_when_no_config_exists(self): - pack_name = 'dummy_pack_5' + pack_name = "dummy_pack_5" config_db = Config.get_by_pack(pack_name) # Delete the existing config loaded in setUp @@ -122,37 +126,37 @@ def test_default_values_from_schema_are_used_when_no_config_exists(self): config_db.delete() # Verify config has been deleted from the database - self.assertRaises(StackStormDBObjectNotFoundError, Config.get_by_pack, pack_name) + self.assertRaises( + StackStormDBObjectNotFoundError, Config.get_by_pack, pack_name + ) loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() - self.assertEqual(config['region'], 'default-region-value') + self.assertEqual(config["region"], "default-region-value") def test_default_values_are_used_when_default_values_are_falsey(self): - pack_name = 'dummy_pack_17' + pack_name = "dummy_pack_17" loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() # 1. Default values are used - self.assertEqual(config['key_with_default_falsy_value_1'], False) - self.assertEqual(config['key_with_default_falsy_value_2'], None) - self.assertEqual(config['key_with_default_falsy_value_3'], {}) - self.assertEqual(config['key_with_default_falsy_value_4'], '') - self.assertEqual(config['key_with_default_falsy_value_5'], 0) - self.assertEqual(config['key_with_default_falsy_value_6']['key_1'], False) - self.assertEqual(config['key_with_default_falsy_value_6']['key_2'], 0) + self.assertEqual(config["key_with_default_falsy_value_1"], False) + self.assertEqual(config["key_with_default_falsy_value_2"], None) + self.assertEqual(config["key_with_default_falsy_value_3"], {}) + self.assertEqual(config["key_with_default_falsy_value_4"], "") + self.assertEqual(config["key_with_default_falsy_value_5"], 0) + self.assertEqual(config["key_with_default_falsy_value_6"]["key_1"], False) + self.assertEqual(config["key_with_default_falsy_value_6"]["key_2"], 0) # 2. Default values are overwrriten with config values which are also falsey values = { - 'key_with_default_falsy_value_1': 0, - 'key_with_default_falsy_value_2': '', - 'key_with_default_falsy_value_3': False, - 'key_with_default_falsy_value_4': None, - 'key_with_default_falsy_value_5': {}, - 'key_with_default_falsy_value_6': { - 'key_2': False - } + "key_with_default_falsy_value_1": 0, + "key_with_default_falsy_value_2": "", + "key_with_default_falsy_value_3": False, + "key_with_default_falsy_value_4": None, + "key_with_default_falsy_value_5": {}, + "key_with_default_falsy_value_6": {"key_2": False}, } config_db = ConfigDB(pack=pack_name, values=values) config_db = Config.add_or_update(config_db) @@ -160,301 +164,296 @@ def test_default_values_are_used_when_default_values_are_falsey(self): loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() - self.assertEqual(config['key_with_default_falsy_value_1'], 0) - self.assertEqual(config['key_with_default_falsy_value_2'], '') - self.assertEqual(config['key_with_default_falsy_value_3'], False) - self.assertEqual(config['key_with_default_falsy_value_4'], None) - self.assertEqual(config['key_with_default_falsy_value_5'], {}) - self.assertEqual(config['key_with_default_falsy_value_6']['key_1'], False) - self.assertEqual(config['key_with_default_falsy_value_6']['key_2'], False) + self.assertEqual(config["key_with_default_falsy_value_1"], 0) + self.assertEqual(config["key_with_default_falsy_value_2"], "") + self.assertEqual(config["key_with_default_falsy_value_3"], False) + self.assertEqual(config["key_with_default_falsy_value_4"], None) + self.assertEqual(config["key_with_default_falsy_value_5"], {}) + self.assertEqual(config["key_with_default_falsy_value_6"]["key_1"], False) + self.assertEqual(config["key_with_default_falsy_value_6"]["key_2"], False) def test_get_config_nested_schema_default_values_from_config_schema_are_used(self): # Special case for more complex config schemas with attributes ntesting. # Validate that the default values are also used for one level nested object properties. - pack_name = 'dummy_pack_schema_with_nested_object_1' + pack_name = "dummy_pack_schema_with_nested_object_1" # 1. None of the nested object values are provided loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() expected_config = { - 'api_key': '', - 'api_secret': '', - 'regions': ['us-west-1', 'us-east-1'], - 'auth_settings': { - 'host': '127.0.0.3', - 'port': 8080, - 'device_uids': ['a', 'b', 'c'] - } + "api_key": "", + "api_secret": "", + "regions": ["us-west-1", "us-east-1"], + "auth_settings": { + "host": "127.0.0.3", + "port": 8080, + "device_uids": ["a", "b", "c"], + }, } self.assertEqual(config, expected_config) # 2. Some of the nested object values are provided (host, port) - pack_name = 'dummy_pack_schema_with_nested_object_2' + pack_name = "dummy_pack_schema_with_nested_object_2" loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() expected_config = { - 'api_key': '', - 'api_secret': '', - 'regions': ['us-west-1', 'us-east-1'], - 'auth_settings': { - 'host': '127.0.0.6', - 'port': 9090, - 'device_uids': ['a', 'b', 'c'] - } + "api_key": "", + "api_secret": "", + "regions": ["us-west-1", "us-east-1"], + "auth_settings": { + "host": "127.0.0.6", + "port": 9090, + "device_uids": ["a", "b", "c"], + }, } self.assertEqual(config, expected_config) # 3. Nested attribute (auth_settings.token) references a non-secret datastore value - pack_name = 'dummy_pack_schema_with_nested_object_3' - - kvp_db = set_datastore_value_for_config_key(pack_name=pack_name, - key_name='auth_settings_token', - value='some_auth_settings_token') - self.assertEqual(kvp_db.value, 'some_auth_settings_token') + pack_name = "dummy_pack_schema_with_nested_object_3" + + kvp_db = set_datastore_value_for_config_key( + pack_name=pack_name, + key_name="auth_settings_token", + value="some_auth_settings_token", + ) + self.assertEqual(kvp_db.value, "some_auth_settings_token") self.assertFalse(kvp_db.secret) loader = ContentPackConfigLoader(pack_name=pack_name) config = loader.get_config() expected_config = { - 'api_key': '', - 'api_secret': '', - 'regions': ['us-west-1', 'us-east-1'], - 'auth_settings': { - 'host': '127.0.0.10', - 'port': 8080, - 'device_uids': ['a', 'b', 'c'], - 'token': 'some_auth_settings_token' - } + "api_key": "", + "api_secret": "", + "regions": ["us-west-1", "us-east-1"], + "auth_settings": { + "host": "127.0.0.10", + "port": 8080, + "device_uids": ["a", "b", "c"], + "token": "some_auth_settings_token", + }, } self.assertEqual(config, expected_config) # 4. Nested attribute (auth_settings.token) references a secret datastore value - pack_name = 'dummy_pack_schema_with_nested_object_4' - - kvp_db = set_datastore_value_for_config_key(pack_name=pack_name, - key_name='auth_settings_token', - value='joe_token_secret', - secret=True, - user='joe') - self.assertTrue(kvp_db.value != 'joe_token_secret') - self.assertTrue(len(kvp_db.value) > len('joe_token_secret') * 2) + pack_name = "dummy_pack_schema_with_nested_object_4" + + kvp_db = set_datastore_value_for_config_key( + pack_name=pack_name, + key_name="auth_settings_token", + value="joe_token_secret", + secret=True, + user="joe", + ) + self.assertTrue(kvp_db.value != "joe_token_secret") + self.assertTrue(len(kvp_db.value) > len("joe_token_secret") * 2) self.assertTrue(kvp_db.secret) - kvp_db = set_datastore_value_for_config_key(pack_name=pack_name, - key_name='auth_settings_token', - value='alice_token_secret', - secret=True, - user='alice') - self.assertTrue(kvp_db.value != 'alice_token_secret') - self.assertTrue(len(kvp_db.value) > len('alice_token_secret') * 2) + kvp_db = set_datastore_value_for_config_key( + pack_name=pack_name, + key_name="auth_settings_token", + value="alice_token_secret", + secret=True, + user="alice", + ) + self.assertTrue(kvp_db.value != "alice_token_secret") + self.assertTrue(len(kvp_db.value) > len("alice_token_secret") * 2) self.assertTrue(kvp_db.secret) - loader = ContentPackConfigLoader(pack_name=pack_name, user='joe') + loader = ContentPackConfigLoader(pack_name=pack_name, user="joe") config = loader.get_config() expected_config = { - 'api_key': '', - 'api_secret': '', - 'regions': ['us-west-1', 'us-east-1'], - 'auth_settings': { - 'host': '127.0.0.11', - 'port': 8080, - 'device_uids': ['a', 'b', 'c'], - 'token': 'joe_token_secret' - } + "api_key": "", + "api_secret": "", + "regions": ["us-west-1", "us-east-1"], + "auth_settings": { + "host": "127.0.0.11", + "port": 8080, + "device_uids": ["a", "b", "c"], + "token": "joe_token_secret", + }, } self.assertEqual(config, expected_config) - loader = ContentPackConfigLoader(pack_name=pack_name, user='alice') + loader = ContentPackConfigLoader(pack_name=pack_name, user="alice") config = loader.get_config() expected_config = { - 'api_key': '', - 'api_secret': '', - 'regions': ['us-west-1', 'us-east-1'], - 'auth_settings': { - 'host': '127.0.0.11', - 'port': 8080, - 'device_uids': ['a', 'b', 'c'], - 'token': 'alice_token_secret' - } + "api_key": "", + "api_secret": "", + "regions": ["us-west-1", "us-east-1"], + "auth_settings": { + "host": "127.0.0.11", + "port": 8080, + "device_uids": ["a", "b", "c"], + "token": "alice_token_secret", + }, } self.assertEqual(config, expected_config) - def test_get_config_dynamic_config_item_render_fails_user_friendly_exception_is_thrown(self): - pack_name = 'dummy_pack_schema_with_nested_object_5' + def test_get_config_dynamic_config_item_render_fails_user_friendly_exception_is_thrown( + self, + ): + pack_name = "dummy_pack_schema_with_nested_object_5" loader = ContentPackConfigLoader(pack_name=pack_name) # Render fails on top-level item - values = { - 'level0_key': '{{st2kvXX.invalid}}' - } + values = {"level0_key": "{{st2kvXX.invalid}}"} config_db = ConfigDB(pack=pack_name, values=values) config_db = Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key "level0_key" with ' - 'value "{{st2kvXX.invalid}}" for pack ".*?" config: ' - ' ' - '\'st2kvXX\' is undefined') + expected_msg = ( + 'Failed to render dynamic configuration value for key "level0_key" with ' + 'value "{{st2kvXX.invalid}}" for pack ".*?" config: ' + " " + "'st2kvXX' is undefined" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() # Renders fails on fist level item - values = { - 'level0_object': { - 'level1_key': '{{st2kvXX.invalid}}' - } - } + values = {"level0_object": {"level1_key": "{{st2kvXX.invalid}}"}} config_db = ConfigDB(pack=pack_name, values=values) Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key ' - '"level0_object.level1_key" with value "{{st2kvXX.invalid}}"' - ' for pack ".*?" config: ' - ' \'st2kvXX\' is undefined') + expected_msg = ( + "Failed to render dynamic configuration value for key " + '"level0_object.level1_key" with value "{{st2kvXX.invalid}}"' + " for pack \".*?\" config: " + " 'st2kvXX' is undefined" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() # Renders fails on second level item values = { - 'level0_object': { - 'level1_object': { - 'level2_key': '{{st2kvXX.invalid}}' - } - } + "level0_object": {"level1_object": {"level2_key": "{{st2kvXX.invalid}}"}} } config_db = ConfigDB(pack=pack_name, values=values) Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key ' - '"level0_object.level1_object.level2_key" with value "{{st2kvXX.invalid}}"' - ' for pack ".*?" config: ' - ' \'st2kvXX\' is undefined') + expected_msg = ( + "Failed to render dynamic configuration value for key " + '"level0_object.level1_object.level2_key" with value "{{st2kvXX.invalid}}"' + " for pack \".*?\" config: " + " 'st2kvXX' is undefined" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() # Renders fails on list item - values = { - 'level0_object': [ - 'abc', - '{{st2kvXX.invalid}}' - ] - } + values = {"level0_object": ["abc", "{{st2kvXX.invalid}}"]} config_db = ConfigDB(pack=pack_name, values=values) Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key ' - '"level0_object.1" with value "{{st2kvXX.invalid}}"' - ' for pack ".*?" config: ' - ' \'st2kvXX\' is undefined') + expected_msg = ( + "Failed to render dynamic configuration value for key " + '"level0_object.1" with value "{{st2kvXX.invalid}}"' + " for pack \".*?\" config: " + " 'st2kvXX' is undefined" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() # Renders fails on nested object in list item - values = { - 'level0_object': [ - {'level2_key': '{{st2kvXX.invalid}}'} - ] - } + values = {"level0_object": [{"level2_key": "{{st2kvXX.invalid}}"}]} config_db = ConfigDB(pack=pack_name, values=values) Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key ' - '"level0_object.0.level2_key" with value "{{st2kvXX.invalid}}"' - ' for pack ".*?" config: ' - ' \'st2kvXX\' is undefined') + expected_msg = ( + "Failed to render dynamic configuration value for key " + '"level0_object.0.level2_key" with value "{{st2kvXX.invalid}}"' + " for pack \".*?\" config: " + " 'st2kvXX' is undefined" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() # Renders fails on invalid syntax - values = { - 'level0_key': '{{ this is some invalid Jinja }}' - } + values = {"level0_key": "{{ this is some invalid Jinja }}"} config_db = ConfigDB(pack=pack_name, values=values) Config.add_or_update(config_db) - expected_msg = ('Failed to render dynamic configuration value for key ' - '"level0_key" with value "{{ this is some invalid Jinja }}"' - ' for pack ".*?" config: ' - ' expected token \'end of print statement\', got \'Jinja\'') + expected_msg = ( + "Failed to render dynamic configuration value for key " + '"level0_key" with value "{{ this is some invalid Jinja }}"' + " for pack \".*?\" config: " + " expected token 'end of print statement', got 'Jinja'" + ) self.assertRaisesRegexp(RuntimeError, expected_msg, loader.get_config) config_db.delete() def test_get_config_dynamic_config_item(self): - pack_name = 'dummy_pack_schema_with_nested_object_6' + pack_name = "dummy_pack_schema_with_nested_object_6" loader = ContentPackConfigLoader(pack_name=pack_name) #################### # value in top level item - KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1')) - values = { - 'level0_key': '{{st2kv.system.k1}}' - } + KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1")) + values = {"level0_key": "{{st2kv.system.k1}}"} config_db = ConfigDB(pack=pack_name, values=values) config_db = Config.add_or_update(config_db) config_rendered = loader.get_config() - self.assertEqual(config_rendered, {'level0_key': 'v1'}) + self.assertEqual(config_rendered, {"level0_key": "v1"}) config_db.delete() def test_get_config_dynamic_config_item_nested_dict(self): - pack_name = 'dummy_pack_schema_with_nested_object_7' + pack_name = "dummy_pack_schema_with_nested_object_7" loader = ContentPackConfigLoader(pack_name=pack_name) - KeyValuePair.add_or_update(KeyValuePairDB(name='k0', value='v0')) - KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1')) - KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2')) + KeyValuePair.add_or_update(KeyValuePairDB(name="k0", value="v0")) + KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1")) + KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2")) #################### # values nested dictionaries values = { - 'level0_key': '{{st2kv.system.k0}}', - 'level0_object': { - 'level1_key': '{{st2kv.system.k1}}', - 'level1_object': { - 'level2_key': '{{st2kv.system.k2}}' - } - } + "level0_key": "{{st2kv.system.k0}}", + "level0_object": { + "level1_key": "{{st2kv.system.k1}}", + "level1_object": {"level2_key": "{{st2kv.system.k2}}"}, + }, } config_db = ConfigDB(pack=pack_name, values=values) config_db = Config.add_or_update(config_db) config_rendered = loader.get_config() - self.assertEqual(config_rendered, - { - 'level0_key': 'v0', - 'level0_object': { - 'level1_key': 'v1', - 'level1_object': { - 'level2_key': 'v2' - } - } - }) + self.assertEqual( + config_rendered, + { + "level0_key": "v0", + "level0_object": { + "level1_key": "v1", + "level1_object": {"level2_key": "v2"}, + }, + }, + ) config_db.delete() def test_get_config_dynamic_config_item_list(self): - pack_name = 'dummy_pack_schema_with_nested_object_7' + pack_name = "dummy_pack_schema_with_nested_object_7" loader = ContentPackConfigLoader(pack_name=pack_name) - KeyValuePair.add_or_update(KeyValuePairDB(name='k0', value='v0')) - KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1')) + KeyValuePair.add_or_update(KeyValuePairDB(name="k0", value="v0")) + KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1")) #################### # values in list values = { - 'level0_key': [ - 'a', - '{{st2kv.system.k0}}', - 'b', - '{{st2kv.system.k1}}', + "level0_key": [ + "a", + "{{st2kv.system.k0}}", + "b", + "{{st2kv.system.k1}}", ] } config_db = ConfigDB(pack=pack_name, values=values) @@ -462,44 +461,34 @@ def test_get_config_dynamic_config_item_list(self): config_rendered = loader.get_config() - self.assertEqual(config_rendered, - { - 'level0_key': [ - 'a', - 'v0', - 'b', - 'v1' - ] - }) + self.assertEqual(config_rendered, {"level0_key": ["a", "v0", "b", "v1"]}) config_db.delete() def test_get_config_dynamic_config_item_nested_list(self): - pack_name = 'dummy_pack_schema_with_nested_object_8' + pack_name = "dummy_pack_schema_with_nested_object_8" loader = ContentPackConfigLoader(pack_name=pack_name) - KeyValuePair.add_or_update(KeyValuePairDB(name='k0', value='v0')) - KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1')) - KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2')) + KeyValuePair.add_or_update(KeyValuePairDB(name="k0", value="v0")) + KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1")) + KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2")) #################### # values in objects embedded in lists and nested lists values = { - 'level0_key': [ - { - 'level1_key0': '{{st2kv.system.k0}}' - }, - '{{st2kv.system.k1}}', + "level0_key": [ + {"level1_key0": "{{st2kv.system.k0}}"}, + "{{st2kv.system.k1}}", [ - '{{st2kv.system.k0}}', - '{{st2kv.system.k1}}', - '{{st2kv.system.k2}}', + "{{st2kv.system.k0}}", + "{{st2kv.system.k1}}", + "{{st2kv.system.k2}}", ], { - 'level1_key2': [ - '{{st2kv.system.k2}}', + "level1_key2": [ + "{{st2kv.system.k2}}", ] - } + }, ] } config_db = ConfigDB(pack=pack_name, values=values) @@ -507,30 +496,30 @@ def test_get_config_dynamic_config_item_nested_list(self): config_rendered = loader.get_config() - self.assertEqual(config_rendered, - { - 'level0_key': [ - { - 'level1_key0': 'v0' - }, - 'v1', - [ - 'v0', - 'v1', - 'v2', - ], - { - 'level1_key2': [ - 'v2', - ] - } - ] - }) + self.assertEqual( + config_rendered, + { + "level0_key": [ + {"level1_key0": "v0"}, + "v1", + [ + "v0", + "v1", + "v2", + ], + { + "level1_key2": [ + "v2", + ] + }, + ] + }, + ) config_db.delete() def test_empty_config_object_in_the_database(self): - pack_name = 'dummy_pack_empty_config' + pack_name = "dummy_pack_empty_config" config_db = ConfigDB(pack=pack_name) config_db = Config.add_or_update(config_db) diff --git a/st2common/tests/unit/test_config_parser.py b/st2common/tests/unit/test_config_parser.py index 6dc690b746d..fde03853696 100644 --- a/st2common/tests/unit/test_config_parser.py +++ b/st2common/tests/unit/test_config_parser.py @@ -27,27 +27,27 @@ def setUp(self): tests_config.parse_args() def test_get_config_inexistent_pack(self): - parser = ContentPackConfigParser(pack_name='inexistent') + parser = ContentPackConfigParser(pack_name="inexistent") config = parser.get_config() self.assertEqual(config, None) def test_get_config_no_config(self): - pack_name = 'dummy_pack_1' + pack_name = "dummy_pack_1" parser = ContentPackConfigParser(pack_name=pack_name) config = parser.get_config() self.assertEqual(config, None) def test_get_config_existing_config(self): - pack_name = 'dummy_pack_2' + pack_name = "dummy_pack_2" parser = ContentPackConfigParser(pack_name=pack_name) config = parser.get_config() - self.assertEqual(config.config['section1']['key1'], 'value1') - self.assertEqual(config.config['section2']['key10'], 'value10') + self.assertEqual(config.config["section1"]["key1"], "value1") + self.assertEqual(config.config["section2"]["key10"], "value10") def test_get_config_for_unicode_char(self): - pack_name = 'dummy_pack_18' + pack_name = "dummy_pack_18" parser = ContentPackConfigParser(pack_name=pack_name) config = parser.get_config() - self.assertEqual(config.config['section1']['key1'], u'测试') + self.assertEqual(config.config["section1"]["key1"], "测试") diff --git a/st2common/tests/unit/test_configs_registrar.py b/st2common/tests/unit/test_configs_registrar.py index 09d002eb6a9..821cec75fa9 100644 --- a/st2common/tests/unit/test_configs_registrar.py +++ b/st2common/tests/unit/test_configs_registrar.py @@ -30,15 +30,23 @@ from st2tests import fixturesloader -__all__ = [ - 'ConfigsRegistrarTestCase' -] - -PACK_1_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_1') -PACK_6_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_6') -PACK_19_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_19') -PACK_11_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_11') -PACK_22_PATH = os.path.join(fixturesloader.get_fixtures_packs_base_path(), 'dummy_pack_22') +__all__ = ["ConfigsRegistrarTestCase"] + +PACK_1_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_1" +) +PACK_6_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_6" +) +PACK_19_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_19" +) +PACK_11_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_11" +) +PACK_22_PATH = os.path.join( + fixturesloader.get_fixtures_packs_base_path(), "dummy_pack_22" +) class ConfigsRegistrarTestCase(CleanDbTestCase): @@ -52,7 +60,7 @@ def test_register_configs_for_all_packs(self): registrar = ConfigsRegistrar(use_pack_cache=False) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_1': PACK_1_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_1": PACK_1_PATH} packs_base_paths = content_utils.get_packs_base_paths() registrar.register_from_packs(base_dirs=packs_base_paths) @@ -64,9 +72,9 @@ def test_register_configs_for_all_packs(self): self.assertEqual(len(config_dbs), 1) config_db = config_dbs[0] - self.assertEqual(config_db.values['api_key'], '{{st2kv.user.api_key}}') - self.assertEqual(config_db.values['api_secret'], SUPER_SECRET_PARAMETER) - self.assertEqual(config_db.values['region'], 'us-west-1') + self.assertEqual(config_db.values["api_key"], "{{st2kv.user.api_key}}") + self.assertEqual(config_db.values["api_secret"], SUPER_SECRET_PARAMETER) + self.assertEqual(config_db.values["region"], "us-west-1") def test_register_all_configs_invalid_config_no_config_schema(self): # verify_ configs is on, but ConfigSchema for the pack doesn't exist so @@ -81,7 +89,7 @@ def test_register_all_configs_invalid_config_no_config_schema(self): registrar = ConfigsRegistrar(use_pack_cache=False, validate_configs=False) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_6': PACK_6_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_6": PACK_6_PATH} packs_base_paths = content_utils.get_packs_base_paths() registrar.register_from_packs(base_dirs=packs_base_paths) @@ -92,7 +100,9 @@ def test_register_all_configs_invalid_config_no_config_schema(self): self.assertEqual(len(pack_dbs), 1) self.assertEqual(len(config_dbs), 1) - def test_register_all_configs_with_config_schema_validation_validation_failure_1(self): + def test_register_all_configs_with_config_schema_validation_validation_failure_1( + self, + ): # Verify DB is empty pack_dbs = Pack.get_all() config_dbs = Config.get_all() @@ -100,28 +110,38 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_1 self.assertEqual(len(pack_dbs), 0) self.assertEqual(len(config_dbs), 0) - registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True, - validate_configs=True) + registrar = ConfigsRegistrar( + use_pack_cache=False, fail_on_failure=True, validate_configs=True + ) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_6': PACK_6_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_6": PACK_6_PATH} # Register ConfigSchema for pack registrar._register_pack_db = mock.Mock() - registrar._register_pack(pack_name='dummy_pack_5', pack_dir=PACK_6_PATH) + registrar._register_pack(pack_name="dummy_pack_5", pack_dir=PACK_6_PATH) packs_base_paths = content_utils.get_packs_base_paths() if six.PY3: - expected_msg = ('Failed validating attribute "regions" in config for pack ' - '"dummy_pack_6" (.*?): 1000 is not of type \'array\'') + expected_msg = ( + 'Failed validating attribute "regions" in config for pack ' + "\"dummy_pack_6\" (.*?): 1000 is not of type 'array'" + ) else: - expected_msg = ('Failed validating attribute "regions" in config for pack ' - '"dummy_pack_6" (.*?): 1000 is not of type u\'array\'') - - self.assertRaisesRegexp(ValueError, expected_msg, - registrar.register_from_packs, - base_dirs=packs_base_paths) - - def test_register_all_configs_with_config_schema_validation_validation_failure_2(self): + expected_msg = ( + 'Failed validating attribute "regions" in config for pack ' + "\"dummy_pack_6\" (.*?): 1000 is not of type u'array'" + ) + + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_from_packs, + base_dirs=packs_base_paths, + ) + + def test_register_all_configs_with_config_schema_validation_validation_failure_2( + self, + ): # Verify DB is empty pack_dbs = Pack.get_all() config_dbs = Config.get_all() @@ -129,30 +149,40 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_2 self.assertEqual(len(pack_dbs), 0) self.assertEqual(len(config_dbs), 0) - registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True, - validate_configs=True) + registrar = ConfigsRegistrar( + use_pack_cache=False, fail_on_failure=True, validate_configs=True + ) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_19': PACK_19_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_19": PACK_19_PATH} # Register ConfigSchema for pack registrar._register_pack_db = mock.Mock() - registrar._register_pack(pack_name='dummy_pack_19', pack_dir=PACK_19_PATH) + registrar._register_pack(pack_name="dummy_pack_19", pack_dir=PACK_19_PATH) packs_base_paths = content_utils.get_packs_base_paths() if six.PY3: - expected_msg = ('Failed validating attribute "instances.0.alias" in config for pack ' - '"dummy_pack_19" (.*?): {\'not\': \'string\'} is not of type ' - '\'string\'') + expected_msg = ( + 'Failed validating attribute "instances.0.alias" in config for pack ' + "\"dummy_pack_19\" (.*?): {'not': 'string'} is not of type " + "'string'" + ) else: - expected_msg = ('Failed validating attribute "instances.0.alias" in config for pack ' - '"dummy_pack_19" (.*?): {\'not\': \'string\'} is not of type ' - 'u\'string\'') - - self.assertRaisesRegexp(ValueError, expected_msg, - registrar.register_from_packs, - base_dirs=packs_base_paths) - - def test_register_all_configs_with_config_schema_validation_validation_failure_3(self): + expected_msg = ( + 'Failed validating attribute "instances.0.alias" in config for pack ' + "\"dummy_pack_19\" (.*?): {'not': 'string'} is not of type " + "u'string'" + ) + + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_from_packs, + base_dirs=packs_base_paths, + ) + + def test_register_all_configs_with_config_schema_validation_validation_failure_3( + self, + ): # This test checks for values containing "decrypt_kv" jinja filter in the config # object where keys have "secret: True" set in the schema. @@ -163,26 +193,34 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_3 self.assertEqual(len(pack_dbs), 0) self.assertEqual(len(config_dbs), 0) - registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True, - validate_configs=True) + registrar = ConfigsRegistrar( + use_pack_cache=False, fail_on_failure=True, validate_configs=True + ) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_11': PACK_11_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_11": PACK_11_PATH} # Register ConfigSchema for pack registrar._register_pack_db = mock.Mock() - registrar._register_pack(pack_name='dummy_pack_11', pack_dir=PACK_11_PATH) + registrar._register_pack(pack_name="dummy_pack_11", pack_dir=PACK_11_PATH) packs_base_paths = content_utils.get_packs_base_paths() - expected_msg = ('Values specified as "secret: True" in config schema are automatically ' - 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' - 'for such values. Please check the specified values in the config or ' - 'the default values in the schema.') - - self.assertRaisesRegexp(ValueError, expected_msg, - registrar.register_from_packs, - base_dirs=packs_base_paths) - - def test_register_all_configs_with_config_schema_validation_validation_failure_4(self): + expected_msg = ( + 'Values specified as "secret: True" in config schema are automatically ' + 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' + "for such values. Please check the specified values in the config or " + "the default values in the schema." + ) + + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_from_packs, + base_dirs=packs_base_paths, + ) + + def test_register_all_configs_with_config_schema_validation_validation_failure_4( + self, + ): # This test checks for default values containing "decrypt_kv" jinja filter for # keys which have "secret: True" set. @@ -193,21 +231,27 @@ def test_register_all_configs_with_config_schema_validation_validation_failure_4 self.assertEqual(len(pack_dbs), 0) self.assertEqual(len(config_dbs), 0) - registrar = ConfigsRegistrar(use_pack_cache=False, fail_on_failure=True, - validate_configs=True) + registrar = ConfigsRegistrar( + use_pack_cache=False, fail_on_failure=True, validate_configs=True + ) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_22': PACK_22_PATH} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_22": PACK_22_PATH} # Register ConfigSchema for pack registrar._register_pack_db = mock.Mock() - registrar._register_pack(pack_name='dummy_pack_22', pack_dir=PACK_22_PATH) + registrar._register_pack(pack_name="dummy_pack_22", pack_dir=PACK_22_PATH) packs_base_paths = content_utils.get_packs_base_paths() - expected_msg = ('Values specified as "secret: True" in config schema are automatically ' - 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' - 'for such values. Please check the specified values in the config or ' - 'the default values in the schema.') - - self.assertRaisesRegexp(ValueError, expected_msg, - registrar.register_from_packs, - base_dirs=packs_base_paths) + expected_msg = ( + 'Values specified as "secret: True" in config schema are automatically ' + 'decrypted by default. Use of "decrypt_kv" jinja filter is not allowed ' + "for such values. Please check the specified values in the config or " + "the default values in the schema." + ) + + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_from_packs, + base_dirs=packs_base_paths, + ) diff --git a/st2common/tests/unit/test_connection_retry_wrapper.py b/st2common/tests/unit/test_connection_retry_wrapper.py index 8c75ff4955d..831ac8c22e7 100644 --- a/st2common/tests/unit/test_connection_retry_wrapper.py +++ b/st2common/tests/unit/test_connection_retry_wrapper.py @@ -21,19 +21,18 @@ class TestClusterRetryContext(unittest.TestCase): - def test_single_node_cluster_retry(self): retry_context = ClusterRetryContext(cluster_size=1) should_stop, wait = retry_context.test_should_stop() - self.assertFalse(should_stop, 'Not done trying.') + self.assertFalse(should_stop, "Not done trying.") self.assertEqual(wait, 10) should_stop, wait = retry_context.test_should_stop() - self.assertFalse(should_stop, 'Not done trying.') + self.assertFalse(should_stop, "Not done trying.") self.assertEqual(wait, 10) should_stop, wait = retry_context.test_should_stop() - self.assertTrue(should_stop, 'Done trying.') + self.assertTrue(should_stop, "Done trying.") self.assertEqual(wait, -1) def test_should_stop_second_channel_open_error_should_be_non_fatal(self): @@ -58,10 +57,10 @@ def test_multiple_node_cluster_retry(self): for i in range(last_index + 1): should_stop, wait = retry_context.test_should_stop() if i == last_index: - self.assertTrue(should_stop, 'Done trying.') + self.assertTrue(should_stop, "Done trying.") self.assertEqual(wait, -1) else: - self.assertFalse(should_stop, 'Not done trying.') + self.assertFalse(should_stop, "Not done trying.") # on cluster boundaries the wait is longer. Short wait when switching # to a different server within a cluster. if (i + 1) % cluster_size == 0: @@ -72,5 +71,5 @@ def test_multiple_node_cluster_retry(self): def test_zero_node_cluster_retry(self): retry_context = ClusterRetryContext(cluster_size=0) should_stop, wait = retry_context.test_should_stop() - self.assertTrue(should_stop, 'Done trying.') + self.assertTrue(should_stop, "Done trying.") self.assertEqual(wait, -1) diff --git a/st2common/tests/unit/test_content_loader.py b/st2common/tests/unit/test_content_loader.py index c20afda87a7..8b8e650afb7 100644 --- a/st2common/tests/unit/test_content_loader.py +++ b/st2common/tests/unit/test_content_loader.py @@ -23,64 +23,81 @@ from st2common.content.loader import LOG CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) -RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources')) +RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources")) class ContentLoaderTest(unittest2.TestCase): def test_get_sensors(self): - packs_base_path = os.path.join(RESOURCES_DIR, 'packs/') + packs_base_path = os.path.join(RESOURCES_DIR, "packs/") loader = ContentPackLoader() - pack_sensors = loader.get_content(base_dirs=[packs_base_path], content_type='sensors') - self.assertIsNotNone(pack_sensors.get('pack1', None)) + pack_sensors = loader.get_content( + base_dirs=[packs_base_path], content_type="sensors" + ) + self.assertIsNotNone(pack_sensors.get("pack1", None)) def test_get_sensors_pack_missing_sensors(self): loader = ContentPackLoader() - fail_pack_path = os.path.join(RESOURCES_DIR, 'packs/pack2') + fail_pack_path = os.path.join(RESOURCES_DIR, "packs/pack2") self.assertTrue(os.path.exists(fail_pack_path)) self.assertEqual(loader._get_sensors(fail_pack_path), None) def test_invalid_content_type(self): - packs_base_path = os.path.join(RESOURCES_DIR, 'packs/') + packs_base_path = os.path.join(RESOURCES_DIR, "packs/") loader = ContentPackLoader() - self.assertRaises(ValueError, loader.get_content, base_dirs=[packs_base_path], - content_type='stuff') + self.assertRaises( + ValueError, + loader.get_content, + base_dirs=[packs_base_path], + content_type="stuff", + ) def test_get_content_multiple_directories(self): - packs_base_path_1 = os.path.join(RESOURCES_DIR, 'packs/') - packs_base_path_2 = os.path.join(RESOURCES_DIR, 'packs2/') + packs_base_path_1 = os.path.join(RESOURCES_DIR, "packs/") + packs_base_path_2 = os.path.join(RESOURCES_DIR, "packs2/") base_dirs = [packs_base_path_1, packs_base_path_2] LOG.warning = Mock() loader = ContentPackLoader() - sensors = loader.get_content(base_dirs=base_dirs, content_type='sensors') - self.assertIn('pack1', sensors) # from packs/ - self.assertIn('pack3', sensors) # from packs2/ + sensors = loader.get_content(base_dirs=base_dirs, content_type="sensors") + self.assertIn("pack1", sensors) # from packs/ + self.assertIn("pack3", sensors) # from packs2/ # Assert that a warning is emitted when a duplicated pack is found - expected_msg = ('Pack "pack1" already found in ' - '"%s/packs/", ignoring content from ' - '"%s/packs2/"' % (RESOURCES_DIR, RESOURCES_DIR)) + expected_msg = ( + 'Pack "pack1" already found in ' + '"%s/packs/", ignoring content from ' + '"%s/packs2/"' % (RESOURCES_DIR, RESOURCES_DIR) + ) LOG.warning.assert_called_once_with(expected_msg) def test_get_content_from_pack_success(self): loader = ContentPackLoader() - pack_path = os.path.join(RESOURCES_DIR, 'packs/pack1') + pack_path = os.path.join(RESOURCES_DIR, "packs/pack1") - sensors = loader.get_content_from_pack(pack_dir=pack_path, content_type='sensors') - self.assertTrue(sensors.endswith('packs/pack1/sensors')) + sensors = loader.get_content_from_pack( + pack_dir=pack_path, content_type="sensors" + ) + self.assertTrue(sensors.endswith("packs/pack1/sensors")) def test_get_content_from_pack_directory_doesnt_exist(self): loader = ContentPackLoader() - pack_path = os.path.join(RESOURCES_DIR, 'packs/pack100') + pack_path = os.path.join(RESOURCES_DIR, "packs/pack100") - message_regex = 'Directory .*? doesn\'t exist' - self.assertRaisesRegexp(ValueError, message_regex, loader.get_content_from_pack, - pack_dir=pack_path, content_type='sensors') + message_regex = "Directory .*? doesn't exist" + self.assertRaisesRegexp( + ValueError, + message_regex, + loader.get_content_from_pack, + pack_dir=pack_path, + content_type="sensors", + ) def test_get_content_from_pack_no_sensors(self): loader = ContentPackLoader() - pack_path = os.path.join(RESOURCES_DIR, 'packs/pack2') + pack_path = os.path.join(RESOURCES_DIR, "packs/pack2") - result = loader.get_content_from_pack(pack_dir=pack_path, content_type='sensors') + result = loader.get_content_from_pack( + pack_dir=pack_path, content_type="sensors" + ) self.assertEqual(result, None) diff --git a/st2common/tests/unit/test_content_utils.py b/st2common/tests/unit/test_content_utils.py index 703c75aa706..523114a613a 100644 --- a/st2common/tests/unit/test_content_utils.py +++ b/st2common/tests/unit/test_content_utils.py @@ -39,205 +39,260 @@ def setUpClass(cls): tests_config.parse_args() def test_get_pack_base_paths(self): - cfg.CONF.content.system_packs_base_path = '' - cfg.CONF.content.packs_base_paths = '/opt/path1' + cfg.CONF.content.system_packs_base_path = "" + cfg.CONF.content.packs_base_paths = "/opt/path1" result = get_packs_base_paths() - self.assertEqual(result, ['/opt/path1']) + self.assertEqual(result, ["/opt/path1"]) # Multiple paths, no trailing colon - cfg.CONF.content.packs_base_paths = '/opt/path1:/opt/path2' + cfg.CONF.content.packs_base_paths = "/opt/path1:/opt/path2" result = get_packs_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) # Multiple paths, trailing colon - cfg.CONF.content.packs_base_paths = '/opt/path1:/opt/path2:' + cfg.CONF.content.packs_base_paths = "/opt/path1:/opt/path2:" result = get_packs_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) # Multiple same paths - cfg.CONF.content.packs_base_paths = '/opt/path1:/opt/path2:/opt/path1:/opt/path2' + cfg.CONF.content.packs_base_paths = ( + "/opt/path1:/opt/path2:/opt/path1:/opt/path2" + ) result = get_packs_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) # Assert system path is always first - cfg.CONF.content.system_packs_base_path = '/opt/system' - cfg.CONF.content.packs_base_paths = '/opt/path2:/opt/path1' + cfg.CONF.content.system_packs_base_path = "/opt/system" + cfg.CONF.content.packs_base_paths = "/opt/path2:/opt/path1" result = get_packs_base_paths() - self.assertEqual(result, ['/opt/system', '/opt/path2', '/opt/path1']) + self.assertEqual(result, ["/opt/system", "/opt/path2", "/opt/path1"]) # More scenarios orig_path = cfg.CONF.content.system_packs_base_path - cfg.CONF.content.system_packs_base_path = '/tests/packs' + cfg.CONF.content.system_packs_base_path = "/tests/packs" - names = [ - 'test_pack_1', - 'test_pack_2', - 'ma_pack' - ] + names = ["test_pack_1", "test_pack_2", "ma_pack"] for name in names: actual = get_pack_base_path(pack_name=name) - expected = os.path.join(cfg.CONF.content.system_packs_base_path, - name) + expected = os.path.join(cfg.CONF.content.system_packs_base_path, name) self.assertEqual(actual, expected) cfg.CONF.content.system_packs_base_path = orig_path def test_get_aliases_base_paths(self): - cfg.CONF.content.aliases_base_paths = '/opt/path1' + cfg.CONF.content.aliases_base_paths = "/opt/path1" result = get_aliases_base_paths() - self.assertEqual(result, ['/opt/path1']) + self.assertEqual(result, ["/opt/path1"]) # Multiple paths, no trailing colon - cfg.CONF.content.aliases_base_paths = '/opt/path1:/opt/path2' + cfg.CONF.content.aliases_base_paths = "/opt/path1:/opt/path2" result = get_aliases_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) # Multiple paths, trailing colon - cfg.CONF.content.aliases_base_paths = '/opt/path1:/opt/path2:' + cfg.CONF.content.aliases_base_paths = "/opt/path1:/opt/path2:" result = get_aliases_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) # Multiple same paths - cfg.CONF.content.aliases_base_paths = '/opt/path1:/opt/path2:/opt/path1:/opt/path2' + cfg.CONF.content.aliases_base_paths = ( + "/opt/path1:/opt/path2:/opt/path1:/opt/path2" + ) result = get_aliases_base_paths() - self.assertEqual(result, ['/opt/path1', '/opt/path2']) + self.assertEqual(result, ["/opt/path1", "/opt/path2"]) def test_get_pack_resource_file_abs_path(self): # Mock the packs path to point to the fixtures directory cfg.CONF.content.packs_base_paths = get_fixtures_packs_base_path() # Invalid resource type - expected_msg = 'Invalid resource type: fooo' - self.assertRaisesRegexp(ValueError, expected_msg, get_pack_resource_file_abs_path, - pack_ref='dummy_pack_1', - resource_type='fooo', - file_path='test.py') + expected_msg = "Invalid resource type: fooo" + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_pack_resource_file_abs_path, + pack_ref="dummy_pack_1", + resource_type="fooo", + file_path="test.py", + ) # Invalid paths (directory traversal and absolute paths) - file_paths = ['/tmp/foo.py', '../foo.py', '/etc/passwd', '../../foo.py', - '/opt/stackstorm/packs/invalid_pack/actions/my_action.py', - '../../foo.py'] + file_paths = [ + "/tmp/foo.py", + "../foo.py", + "/etc/passwd", + "../../foo.py", + "/opt/stackstorm/packs/invalid_pack/actions/my_action.py", + "../../foo.py", + ] for file_path in file_paths: # action resource_type - expected_msg = (r'Invalid file path: ".*%s"\. File path needs to be relative to the ' - r'pack actions directory (.*). For example "my_action.py"\.' % - (file_path)) - self.assertRaisesRegexp(ValueError, expected_msg, get_pack_resource_file_abs_path, - pack_ref='dummy_pack_1', - resource_type='action', - file_path=file_path) + expected_msg = ( + r'Invalid file path: ".*%s"\. File path needs to be relative to the ' + r'pack actions directory (.*). For example "my_action.py"\.' + % (file_path) + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_pack_resource_file_abs_path, + pack_ref="dummy_pack_1", + resource_type="action", + file_path=file_path, + ) # sensor resource_type - expected_msg = (r'Invalid file path: ".*%s"\. File path needs to be relative to the ' - r'pack sensors directory (.*). For example "my_sensor.py"\.' % - (file_path)) - self.assertRaisesRegexp(ValueError, expected_msg, get_pack_resource_file_abs_path, - pack_ref='dummy_pack_1', - resource_type='sensor', - file_path=file_path) + expected_msg = ( + r'Invalid file path: ".*%s"\. File path needs to be relative to the ' + r'pack sensors directory (.*). For example "my_sensor.py"\.' + % (file_path) + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_pack_resource_file_abs_path, + pack_ref="dummy_pack_1", + resource_type="sensor", + file_path=file_path, + ) # no resource type - expected_msg = (r'Invalid file path: ".*%s"\. File path needs to be relative to the ' - r'pack directory (.*). For example "my_action.py"\.' % - (file_path)) - self.assertRaisesRegexp(ValueError, expected_msg, get_pack_file_abs_path, - pack_ref='dummy_pack_1', - file_path=file_path) + expected_msg = ( + r'Invalid file path: ".*%s"\. File path needs to be relative to the ' + r'pack directory (.*). For example "my_action.py"\.' % (file_path) + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_pack_file_abs_path, + pack_ref="dummy_pack_1", + file_path=file_path, + ) # Valid paths - file_paths = ['foo.py', 'a/foo.py', 'a/b/foo.py'] + file_paths = ["foo.py", "a/foo.py", "a/b/foo.py"] for file_path in file_paths: - expected = os.path.join(get_fixtures_packs_base_path(), - 'dummy_pack_1/actions', file_path) - result = get_pack_resource_file_abs_path(pack_ref='dummy_pack_1', - resource_type='action', - file_path=file_path) + expected = os.path.join( + get_fixtures_packs_base_path(), "dummy_pack_1/actions", file_path + ) + result = get_pack_resource_file_abs_path( + pack_ref="dummy_pack_1", resource_type="action", file_path=file_path + ) self.assertEqual(result, expected) def test_get_entry_point_absolute_path(self): orig_path = cfg.CONF.content.system_packs_base_path - cfg.CONF.content.system_packs_base_path = '/tests/packs' + cfg.CONF.content.system_packs_base_path = "/tests/packs" acutal_path = get_entry_point_abs_path( - pack='foo', - entry_point='/tests/packs/foo/bar.py') - self.assertEqual(acutal_path, '/tests/packs/foo/bar.py', 'Entry point path doesn\'t match.') + pack="foo", entry_point="/tests/packs/foo/bar.py" + ) + self.assertEqual( + acutal_path, "/tests/packs/foo/bar.py", "Entry point path doesn't match." + ) cfg.CONF.content.system_packs_base_path = orig_path def test_get_entry_point_absolute_path_empty(self): orig_path = cfg.CONF.content.system_packs_base_path - cfg.CONF.content.system_packs_base_path = '/tests/packs' - acutal_path = get_entry_point_abs_path(pack='foo', entry_point=None) - self.assertEqual(acutal_path, None, 'Entry point path doesn\'t match.') - acutal_path = get_entry_point_abs_path(pack='foo', entry_point='') - self.assertEqual(acutal_path, None, 'Entry point path doesn\'t match.') + cfg.CONF.content.system_packs_base_path = "/tests/packs" + acutal_path = get_entry_point_abs_path(pack="foo", entry_point=None) + self.assertEqual(acutal_path, None, "Entry point path doesn't match.") + acutal_path = get_entry_point_abs_path(pack="foo", entry_point="") + self.assertEqual(acutal_path, None, "Entry point path doesn't match.") cfg.CONF.content.system_packs_base_path = orig_path def test_get_entry_point_relative_path(self): orig_path = cfg.CONF.content.system_packs_base_path - cfg.CONF.content.system_packs_base_path = '/tests/packs' - acutal_path = get_entry_point_abs_path(pack='foo', entry_point='foo/bar.py') - expected_path = os.path.join(cfg.CONF.content.system_packs_base_path, 'foo', 'actions', - 'foo/bar.py') - self.assertEqual(acutal_path, expected_path, 'Entry point path doesn\'t match.') + cfg.CONF.content.system_packs_base_path = "/tests/packs" + acutal_path = get_entry_point_abs_path(pack="foo", entry_point="foo/bar.py") + expected_path = os.path.join( + cfg.CONF.content.system_packs_base_path, "foo", "actions", "foo/bar.py" + ) + self.assertEqual(acutal_path, expected_path, "Entry point path doesn't match.") cfg.CONF.content.system_packs_base_path = orig_path def test_get_action_libs_abs_path(self): orig_path = cfg.CONF.content.system_packs_base_path - cfg.CONF.content.system_packs_base_path = '/tests/packs' + cfg.CONF.content.system_packs_base_path = "/tests/packs" # entry point relative. - acutal_path = get_action_libs_abs_path(pack='foo', entry_point='foo/bar.py') - expected_path = os.path.join(cfg.CONF.content.system_packs_base_path, 'foo', 'actions', - os.path.join('foo', ACTION_LIBS_DIR)) - self.assertEqual(acutal_path, expected_path, 'Action libs path doesn\'t match.') + acutal_path = get_action_libs_abs_path(pack="foo", entry_point="foo/bar.py") + expected_path = os.path.join( + cfg.CONF.content.system_packs_base_path, + "foo", + "actions", + os.path.join("foo", ACTION_LIBS_DIR), + ) + self.assertEqual(acutal_path, expected_path, "Action libs path doesn't match.") # entry point absolute. acutal_path = get_action_libs_abs_path( - pack='foo', - entry_point='/tests/packs/foo/tmp/foo.py') - expected_path = os.path.join('/tests/packs/foo/tmp', ACTION_LIBS_DIR) - self.assertEqual(acutal_path, expected_path, 'Action libs path doesn\'t match.') + pack="foo", entry_point="/tests/packs/foo/tmp/foo.py" + ) + expected_path = os.path.join("/tests/packs/foo/tmp", ACTION_LIBS_DIR) + self.assertEqual(acutal_path, expected_path, "Action libs path doesn't match.") cfg.CONF.content.system_packs_base_path = orig_path def test_get_relative_path_to_pack_file(self): packs_base_paths = get_fixtures_packs_base_path() - pack_ref = 'dummy_pack_1' + pack_ref = "dummy_pack_1" # 1. Valid paths - file_path = os.path.join(packs_base_paths, 'dummy_pack_1/pack.yaml') + file_path = os.path.join(packs_base_paths, "dummy_pack_1/pack.yaml") result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path) - self.assertEqual(result, 'pack.yaml') + self.assertEqual(result, "pack.yaml") - file_path = os.path.join(packs_base_paths, 'dummy_pack_1/actions/action.meta.yaml') + file_path = os.path.join( + packs_base_paths, "dummy_pack_1/actions/action.meta.yaml" + ) result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path) - self.assertEqual(result, 'actions/action.meta.yaml') + self.assertEqual(result, "actions/action.meta.yaml") - file_path = os.path.join(packs_base_paths, 'dummy_pack_1/actions/lib/foo.py') + file_path = os.path.join(packs_base_paths, "dummy_pack_1/actions/lib/foo.py") result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path) - self.assertEqual(result, 'actions/lib/foo.py') + self.assertEqual(result, "actions/lib/foo.py") # Already relative - file_path = 'actions/lib/foo2.py' + file_path = "actions/lib/foo2.py" result = get_relative_path_to_pack_file(pack_ref=pack_ref, file_path=file_path) - self.assertEqual(result, 'actions/lib/foo2.py') + self.assertEqual(result, "actions/lib/foo2.py") # 2. Invalid path - outside pack directory - expected_msg = r'file_path (.*?) is not located inside the pack directory (.*?)' - - file_path = os.path.join(packs_base_paths, 'dummy_pack_2/actions/lib/foo.py') - self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file, - pack_ref=pack_ref, file_path=file_path) - - file_path = '/tmp/foo/bar.py' - self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file, - pack_ref=pack_ref, file_path=file_path) - - file_path = os.path.join(packs_base_paths, '../dummy_pack_1/pack.yaml') - self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file, - pack_ref=pack_ref, file_path=file_path) - - file_path = os.path.join(packs_base_paths, '../../dummy_pack_1/pack.yaml') - self.assertRaisesRegexp(ValueError, expected_msg, get_relative_path_to_pack_file, - pack_ref=pack_ref, file_path=file_path) + expected_msg = r"file_path (.*?) is not located inside the pack directory (.*?)" + + file_path = os.path.join(packs_base_paths, "dummy_pack_2/actions/lib/foo.py") + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_relative_path_to_pack_file, + pack_ref=pack_ref, + file_path=file_path, + ) + + file_path = "/tmp/foo/bar.py" + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_relative_path_to_pack_file, + pack_ref=pack_ref, + file_path=file_path, + ) + + file_path = os.path.join(packs_base_paths, "../dummy_pack_1/pack.yaml") + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_relative_path_to_pack_file, + pack_ref=pack_ref, + file_path=file_path, + ) + + file_path = os.path.join(packs_base_paths, "../../dummy_pack_1/pack.yaml") + self.assertRaisesRegexp( + ValueError, + expected_msg, + get_relative_path_to_pack_file, + pack_ref=pack_ref, + file_path=file_path, + ) diff --git a/st2common/tests/unit/test_crypto_utils.py b/st2common/tests/unit/test_crypto_utils.py index 3bd63ecefe8..5f8f07fa692 100644 --- a/st2common/tests/unit/test_crypto_utils.py +++ b/st2common/tests/unit/test_crypto_utils.py @@ -40,37 +40,32 @@ from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'CryptoUtilsTestCase', - 'CryptoUtilsKeyczarCompatibilityTestCase' -] +__all__ = ["CryptoUtilsTestCase", "CryptoUtilsKeyczarCompatibilityTestCase"] -KEY_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), 'keyczar_keys/') +KEY_FIXTURES_PATH = os.path.join(get_fixtures_base_path(), "keyczar_keys/") class CryptoUtilsTestCase(TestCase): - @classmethod def setUpClass(cls): super(CryptoUtilsTestCase, cls).setUpClass() CryptoUtilsTestCase.test_crypto_key = AESKey.generate() def test_symmetric_encrypt_decrypt_short_string_needs_to_be_padded(self): - original = u'a' + original = "a" crypto = symmetric_encrypt(CryptoUtilsTestCase.test_crypto_key, original) plain = symmetric_decrypt(CryptoUtilsTestCase.test_crypto_key, crypto) self.assertEqual(plain, original) def test_symmetric_encrypt_decrypt_utf8_character(self): values = [ - u'£', - u'£££', - u'££££££', - u'č š hello đ č p ž Ž', - u'hello 💩', - u'💩💩💩💩💩' - u'💩💩💩', - u'💩😁' + "£", + "£££", + "££££££", + "č š hello đ č p ž Ž", + "hello 💩", + "💩💩💩💩💩" "💩💩💩", + "💩😁", ] for index, original in enumerate(values): @@ -81,13 +76,13 @@ def test_symmetric_encrypt_decrypt_utf8_character(self): self.assertEqual(index, (len(values) - 1)) def test_symmetric_encrypt_decrypt(self): - original = 'secret' + original = "secret" crypto = symmetric_encrypt(CryptoUtilsTestCase.test_crypto_key, original) plain = symmetric_decrypt(CryptoUtilsTestCase.test_crypto_key, crypto) self.assertEqual(plain, original) def test_encrypt_output_is_diff_due_to_diff_IV(self): - original = 'Kami is a little boy.' + original = "Kami is a little boy." cryptos = set() for _ in range(0, 10000): @@ -97,7 +92,7 @@ def test_encrypt_output_is_diff_due_to_diff_IV(self): def test_decrypt_ciphertext_is_too_short(self): aes_key = AESKey.generate() - plaintext = 'hello world ponies 1' + plaintext = "hello world ponies 1" encrypted = cryptography_symmetric_encrypt(aes_key, plaintext) # Verify original non manipulated value can be decrypted @@ -117,13 +112,18 @@ def test_decrypt_ciphertext_is_too_short(self): encrypted_malformed = binascii.hexlify(encrypted_malformed) # Verify corrupted value results in an excpetion - expected_msg = 'Invalid or malformed ciphertext' - self.assertRaisesRegexp(ValueError, expected_msg, cryptography_symmetric_decrypt, - aes_key, encrypted_malformed) + expected_msg = "Invalid or malformed ciphertext" + self.assertRaisesRegexp( + ValueError, + expected_msg, + cryptography_symmetric_decrypt, + aes_key, + encrypted_malformed, + ) def test_exception_is_thrown_on_invalid_hmac_signature(self): aes_key = AESKey.generate() - plaintext = 'hello world ponies 2' + plaintext = "hello world ponies 2" encrypted = cryptography_symmetric_encrypt(aes_key, plaintext) # Verify original non manipulated value can be decrypted @@ -133,13 +133,18 @@ def test_exception_is_thrown_on_invalid_hmac_signature(self): # Corrupt the HMAC signature (last part is the HMAC signature) encrypted_malformed = binascii.unhexlify(encrypted) encrypted_malformed = encrypted_malformed[:-3] - encrypted_malformed += b'abc' + encrypted_malformed += b"abc" encrypted_malformed = binascii.hexlify(encrypted_malformed) # Verify corrupted value results in an excpetion - expected_msg = 'Signature did not match digest' - self.assertRaisesRegexp(InvalidSignature, expected_msg, cryptography_symmetric_decrypt, - aes_key, encrypted_malformed) + expected_msg = "Signature did not match digest" + self.assertRaisesRegexp( + InvalidSignature, + expected_msg, + cryptography_symmetric_decrypt, + aes_key, + encrypted_malformed, + ) class CryptoUtilsKeyczarCompatibilityTestCase(TestCase): @@ -150,44 +155,69 @@ class CryptoUtilsKeyczarCompatibilityTestCase(TestCase): def test_aes_key_class(self): # 1. Unsupported mode - expected_msg = 'Unsupported mode: EBC' - self.assertRaisesRegexp(ValueError, expected_msg, AESKey, aes_key_string='a', - hmac_key_string='b', hmac_key_size=128, mode='EBC') + expected_msg = "Unsupported mode: EBC" + self.assertRaisesRegexp( + ValueError, + expected_msg, + AESKey, + aes_key_string="a", + hmac_key_string="b", + hmac_key_size=128, + mode="EBC", + ) # 2. AES key is too small - expected_msg = 'Unsafe key size: 64' - self.assertRaisesRegexp(ValueError, expected_msg, AESKey, aes_key_string='a', - hmac_key_string='b', hmac_key_size=128, mode='CBC', size=64) + expected_msg = "Unsafe key size: 64" + self.assertRaisesRegexp( + ValueError, + expected_msg, + AESKey, + aes_key_string="a", + hmac_key_string="b", + hmac_key_size=128, + mode="CBC", + size=64, + ) def test_loading_keys_from_keyczar_formatted_key_files(self): - key_path = os.path.join(KEY_FIXTURES_PATH, 'one.json') + key_path = os.path.join(KEY_FIXTURES_PATH, "one.json") aes_key = read_crypto_key(key_path=key_path) - self.assertEqual(aes_key.hmac_key_string, 'lgI9YdOKlIOtPQFdgB0B6zr0AZ6L2QJuFQg4gTu2dxc') + self.assertEqual( + aes_key.hmac_key_string, "lgI9YdOKlIOtPQFdgB0B6zr0AZ6L2QJuFQg4gTu2dxc" + ) self.assertEqual(aes_key.hmac_key_size, 256) - self.assertEqual(aes_key.aes_key_string, 'vKmBE2YeQ9ATyovel7NDjdnbvOMcoU5uPtUVxWxWm58') - self.assertEqual(aes_key.mode, 'CBC') + self.assertEqual( + aes_key.aes_key_string, "vKmBE2YeQ9ATyovel7NDjdnbvOMcoU5uPtUVxWxWm58" + ) + self.assertEqual(aes_key.mode, "CBC") self.assertEqual(aes_key.size, 256) - key_path = os.path.join(KEY_FIXTURES_PATH, 'two.json') + key_path = os.path.join(KEY_FIXTURES_PATH, "two.json") aes_key = read_crypto_key(key_path=key_path) - self.assertEqual(aes_key.hmac_key_string, '92ok9S5extxphADmUhObPSD5wugey8eTffoJ2CEg_2s') + self.assertEqual( + aes_key.hmac_key_string, "92ok9S5extxphADmUhObPSD5wugey8eTffoJ2CEg_2s" + ) self.assertEqual(aes_key.hmac_key_size, 256) - self.assertEqual(aes_key.aes_key_string, 'fU9hT9pm-b9hu3VyQACLXe2Z7xnaJMZrXiTltyLUzgs') - self.assertEqual(aes_key.mode, 'CBC') + self.assertEqual( + aes_key.aes_key_string, "fU9hT9pm-b9hu3VyQACLXe2Z7xnaJMZrXiTltyLUzgs" + ) + self.assertEqual(aes_key.mode, "CBC") self.assertEqual(aes_key.size, 256) - key_path = os.path.join(KEY_FIXTURES_PATH, 'five.json') + key_path = os.path.join(KEY_FIXTURES_PATH, "five.json") aes_key = read_crypto_key(key_path=key_path) - self.assertEqual(aes_key.hmac_key_string, 'GCX2uMfOzp1JXYgqH8piEE4_mJOPXydH_fRHPDw9bkM') + self.assertEqual( + aes_key.hmac_key_string, "GCX2uMfOzp1JXYgqH8piEE4_mJOPXydH_fRHPDw9bkM" + ) self.assertEqual(aes_key.hmac_key_size, 256) - self.assertEqual(aes_key.aes_key_string, 'EeBcUcbH14tL0w_fF5siEw') - self.assertEqual(aes_key.mode, 'CBC') + self.assertEqual(aes_key.aes_key_string, "EeBcUcbH14tL0w_fF5siEw") + self.assertEqual(aes_key.mode, "CBC") self.assertEqual(aes_key.size, 128) def test_key_generation_file_format_is_fully_keyczar_compatible(self): @@ -197,13 +227,13 @@ def test_key_generation_file_format_is_fully_keyczar_compatible(self): json_parsed = json.loads(key_json) expected = { - 'hmacKey': { - 'hmacKeyString': aes_key.hmac_key_string, - 'size': aes_key.hmac_key_size + "hmacKey": { + "hmacKeyString": aes_key.hmac_key_string, + "size": aes_key.hmac_key_size, }, - 'aesKeyString': aes_key.aes_key_string, - 'mode': aes_key.mode, - 'size': aes_key.size + "aesKeyString": aes_key.aes_key_string, + "mode": aes_key.mode, + "size": aes_key.size, } self.assertEqual(json_parsed, expected) @@ -211,15 +241,14 @@ def test_key_generation_file_format_is_fully_keyczar_compatible(self): def test_symmetric_encrypt_decrypt_cryptography(self): key = AESKey.generate() plaintexts = [ - 'a b c', - 'ab', - 'hello foo', - 'hell', - 'bar5' - 'hello hello bar bar hello', - 'a', - '', - 'c' + "a b c", + "ab", + "hello foo", + "hell", + "bar5" "hello hello bar bar hello", + "a", + "", + "c", ] for plaintext in plaintexts: @@ -228,13 +257,13 @@ def test_symmetric_encrypt_decrypt_cryptography(self): self.assertEqual(decrypted, plaintext) - @unittest2.skipIf(six.PY3, 'keyczar doesn\'t work under Python 3') + @unittest2.skipIf(six.PY3, "keyczar doesn't work under Python 3") def test_symmetric_encrypt_decrypt_roundtrips_1(self): encrypt_keys = [ AESKey.generate(), AESKey.generate(), AESKey.generate(), - AESKey.generate() + AESKey.generate(), ] # Verify all keys are unique @@ -248,7 +277,7 @@ def test_symmetric_encrypt_decrypt_roundtrips_1(self): self.assertEqual(len(aes_key_strings), 4) self.assertEqual(len(hmac_key_strings), 4) - plaintext = 'hello world test dummy 8 9 5 1 bar2' + plaintext = "hello world test dummy 8 9 5 1 bar2" # Verify that round trips work and that cryptography based primitives are fully compatible # with keyczar format @@ -261,14 +290,19 @@ def test_symmetric_encrypt_decrypt_roundtrips_1(self): self.assertNotEqual(data_enc_keyczar, data_enc_cryptography) data_dec_keyczar_keyczar = keyczar_symmetric_decrypt(key, data_enc_keyczar) - data_dec_keyczar_cryptography = keyczar_symmetric_decrypt(key, data_enc_cryptography) + data_dec_keyczar_cryptography = keyczar_symmetric_decrypt( + key, data_enc_cryptography + ) self.assertEqual(data_dec_keyczar_keyczar, plaintext) self.assertEqual(data_dec_keyczar_cryptography, plaintext) - data_dec_cryptography_cryptography = cryptography_symmetric_decrypt(key, - data_enc_cryptography) - data_dec_cryptography_keyczar = cryptography_symmetric_decrypt(key, data_enc_keyczar) + data_dec_cryptography_cryptography = cryptography_symmetric_decrypt( + key, data_enc_cryptography + ) + data_dec_cryptography_keyczar = cryptography_symmetric_decrypt( + key, data_enc_keyczar + ) self.assertEqual(data_dec_cryptography_cryptography, plaintext) self.assertEqual(data_dec_cryptography_keyczar, plaintext) diff --git a/st2common/tests/unit/test_datastore.py b/st2common/tests/unit/test_datastore.py index 30d3c7dc76c..1e3dc86d30c 100644 --- a/st2common/tests/unit/test_datastore.py +++ b/st2common/tests/unit/test_datastore.py @@ -28,12 +28,10 @@ from st2tests import DbTestCase from st2tests import config -__all__ = [ - 'DatastoreServiceTestCase' -] +__all__ = ["DatastoreServiceTestCase"] CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) -RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources')) +RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources")) class DatastoreServiceTestCase(DbTestCase): @@ -41,9 +39,9 @@ def setUp(self): super(DatastoreServiceTestCase, self).setUp() config.parse_args() - self._datastore_service = BaseDatastoreService(logger=mock.Mock(), - pack_name='core', - class_name='TestSensor') + self._datastore_service = BaseDatastoreService( + logger=mock.Mock(), pack_name="core", class_name="TestSensor" + ) self._datastore_service.get_api_client = mock.Mock() def test_datastore_operations_list_values(self): @@ -53,14 +51,14 @@ def test_datastore_operations_list_values(self): self._set_mock_api_client(mock_api_client) self._datastore_service.list_values(local=True, prefix=None) - mock_api_client.keys.get_all.assert_called_with(prefix='core.TestSensor:') - self._datastore_service.list_values(local=True, prefix='ponies') - mock_api_client.keys.get_all.assert_called_with(prefix='core.TestSensor:ponies') + mock_api_client.keys.get_all.assert_called_with(prefix="core.TestSensor:") + self._datastore_service.list_values(local=True, prefix="ponies") + mock_api_client.keys.get_all.assert_called_with(prefix="core.TestSensor:ponies") self._datastore_service.list_values(local=False, prefix=None) mock_api_client.keys.get_all.assert_called_with(prefix=None) - self._datastore_service.list_values(local=False, prefix='ponies') - mock_api_client.keys.get_all.assert_called_with(prefix='ponies') + self._datastore_service.list_values(local=False, prefix="ponies") + mock_api_client.keys.get_all.assert_called_with(prefix="ponies") # No values in the datastore mock_api_client = mock.Mock() @@ -74,11 +72,11 @@ def test_datastore_operations_list_values(self): # Values in the datastore kvp1 = KeyValuePair() - kvp1.name = 'test1' - kvp1.value = 'bar' + kvp1.name = "test1" + kvp1.value = "bar" kvp2 = KeyValuePair() - kvp2.name = 'test2' - kvp2.value = 'bar' + kvp2.name = "test2" + kvp2.value = "bar" mock_return_value = [kvp1, kvp2] mock_api_client.keys.get_all.return_value = mock_return_value self._set_mock_api_client(mock_api_client) @@ -90,12 +88,12 @@ def test_datastore_operations_list_values(self): def test_datastore_operations_get_value(self): mock_api_client = mock.Mock() kvp1 = KeyValuePair() - kvp1.name = 'test1' - kvp1.value = 'bar' + kvp1.name = "test1" + kvp1.value = "bar" mock_api_client.keys.get_by_id.return_value = kvp1 self._set_mock_api_client(mock_api_client) - value = self._datastore_service.get_value(name='test1', local=False) + value = self._datastore_service.get_value(name="test1", local=False) self.assertEqual(value, kvp1.value) def test_datastore_operations_set_value(self): @@ -103,10 +101,12 @@ def test_datastore_operations_set_value(self): mock_api_client.keys.update.return_value = True self._set_mock_api_client(mock_api_client) - value = self._datastore_service.set_value(name='test1', value='foo', local=False) + value = self._datastore_service.set_value( + name="test1", value="foo", local=False + ) self.assertTrue(value) - kvp = mock_api_client.keys.update.call_args[1]['instance'] - self.assertEqual(kvp.value, 'foo') + kvp = mock_api_client.keys.update.call_args[1]["instance"] + self.assertEqual(kvp.value, "foo") self.assertEqual(kvp.scope, SYSTEM_SCOPE) def test_datastore_operations_delete_value(self): @@ -114,53 +114,69 @@ def test_datastore_operations_delete_value(self): mock_api_client.keys.delete.return_value = True self._set_mock_api_client(mock_api_client) - value = self._datastore_service.delete_value(name='test', local=False) + value = self._datastore_service.delete_value(name="test", local=False) self.assertTrue(value) def test_datastore_operations_set_encrypted_value(self): mock_api_client = mock.Mock() mock_api_client.keys.update.return_value = True self._set_mock_api_client(mock_api_client) - value = self._datastore_service.set_value(name='test1', value='foo', local=False, - encrypt=True) + value = self._datastore_service.set_value( + name="test1", value="foo", local=False, encrypt=True + ) self.assertTrue(value) - kvp = mock_api_client.keys.update.call_args[1]['instance'] - self.assertEqual(kvp.value, 'foo') + kvp = mock_api_client.keys.update.call_args[1]["instance"] + self.assertEqual(kvp.value, "foo") self.assertTrue(kvp.secret) self.assertEqual(kvp.scope, SYSTEM_SCOPE) def test_datastore_unsupported_scope(self): - self.assertRaises(ValueError, self._datastore_service.get_value, name='test1', - scope='NOT_SYSTEM') - self.assertRaises(ValueError, self._datastore_service.set_value, name='test1', - value='foo', scope='NOT_SYSTEM') - self.assertRaises(ValueError, self._datastore_service.delete_value, name='test1', - scope='NOT_SYSTEM') + self.assertRaises( + ValueError, + self._datastore_service.get_value, + name="test1", + scope="NOT_SYSTEM", + ) + self.assertRaises( + ValueError, + self._datastore_service.set_value, + name="test1", + value="foo", + scope="NOT_SYSTEM", + ) + self.assertRaises( + ValueError, + self._datastore_service.delete_value, + name="test1", + scope="NOT_SYSTEM", + ) def test_datastore_get_exception(self): mock_api_client = mock.Mock() mock_api_client.keys.get_by_id.side_effect = ValueError("Exception test") self._set_mock_api_client(mock_api_client) - value = self._datastore_service.get_value(name='test1') + value = self._datastore_service.get_value(name="test1") self.assertEqual(value, None) def test_datastore_delete_exception(self): mock_api_client = mock.Mock() mock_api_client.keys.delete.side_effect = ValueError("Exception test") self._set_mock_api_client(mock_api_client) - delete_success = self._datastore_service.delete_value(name='test1') + delete_success = self._datastore_service.delete_value(name="test1") self.assertEqual(delete_success, False) def test_datastore_token_timeout(self): - datastore_service = SensorDatastoreService(logger=mock.Mock(), - pack_name='core', - class_name='TestSensor', - api_username='sensor_service') + datastore_service = SensorDatastoreService( + logger=mock.Mock(), + pack_name="core", + class_name="TestSensor", + api_username="sensor_service", + ) mock_api_client = mock.Mock() kvp1 = KeyValuePair() - kvp1.name = 'test1' - kvp1.value = 'bar' + kvp1.name = "test1" + kvp1.value = "bar" mock_api_client.keys.get_by_id.return_value = kvp1 token_expire_time = get_datetime_utc_now() - timedelta(seconds=5) @@ -170,10 +186,9 @@ def test_datastore_token_timeout(self): self._set_mock_api_client(mock_api_client) with mock.patch( - 'st2common.services.datastore.Client', - return_value=mock_api_client + "st2common.services.datastore.Client", return_value=mock_api_client ) as datastore_client: - value = datastore_service.get_value(name='test1', local=False) + value = datastore_service.get_value(name="test1", local=False) self.assertTrue(datastore_client.called) self.assertEqual(value, kvp1.value) self.assertGreater(datastore_service._token_expire, token_expire_time) diff --git a/st2common/tests/unit/test_date_utils.py b/st2common/tests/unit/test_date_utils.py index 1b1d3b465c9..d453edb8f7e 100644 --- a/st2common/tests/unit/test_date_utils.py +++ b/st2common/tests/unit/test_date_utils.py @@ -25,44 +25,44 @@ class DateUtilsTestCase(unittest2.TestCase): def test_get_datetime_utc_now(self): date = date_utils.get_datetime_utc_now() - self.assertEqual(date.tzinfo.tzname(None), 'UTC') + self.assertEqual(date.tzinfo.tzname(None), "UTC") def test_add_utc_tz(self): dt = datetime.datetime.utcnow() self.assertIsNone(dt.tzinfo) dt = date_utils.add_utc_tz(dt) self.assertIsNotNone(dt.tzinfo) - self.assertEqual(dt.tzinfo.tzname(None), 'UTC') + self.assertEqual(dt.tzinfo.tzname(None), "UTC") def test_convert_to_utc(self): date_without_tz = datetime.datetime.utcnow() self.assertEqual(date_without_tz.tzinfo, None) result = date_utils.convert_to_utc(date_without_tz) - self.assertEqual(result.tzinfo.tzname(None), 'UTC') + self.assertEqual(result.tzinfo.tzname(None), "UTC") date_with_pdt_tz = datetime.datetime(2015, 10, 28, 10, 0, 0, 0) - date_with_pdt_tz = date_with_pdt_tz.replace(tzinfo=pytz.timezone('US/Pacific')) - self.assertEqual(date_with_pdt_tz.tzinfo.tzname(None), 'US/Pacific') + date_with_pdt_tz = date_with_pdt_tz.replace(tzinfo=pytz.timezone("US/Pacific")) + self.assertEqual(date_with_pdt_tz.tzinfo.tzname(None), "US/Pacific") result = date_utils.convert_to_utc(date_with_pdt_tz) - self.assertEqual(str(result), '2015-10-28 17:53:00+00:00') - self.assertEqual(result.tzinfo.tzname(None), 'UTC') + self.assertEqual(str(result), "2015-10-28 17:53:00+00:00") + self.assertEqual(result.tzinfo.tzname(None), "UTC") def test_parse(self): - date_str_without_tz = 'January 1st, 2014 10:00:00' + date_str_without_tz = "January 1st, 2014 10:00:00" result = date_utils.parse(value=date_str_without_tz) - self.assertEqual(str(result), '2014-01-01 10:00:00+00:00') - self.assertEqual(result.tzinfo.tzname(None), 'UTC') + self.assertEqual(str(result), "2014-01-01 10:00:00+00:00") + self.assertEqual(result.tzinfo.tzname(None), "UTC") # preserve original tz - date_str_with_tz = 'January 1st, 2014 10:00:00 +07:00' + date_str_with_tz = "January 1st, 2014 10:00:00 +07:00" result = date_utils.parse(value=date_str_with_tz, preserve_original_tz=True) - self.assertEqual(str(result), '2014-01-01 10:00:00+07:00') + self.assertEqual(str(result), "2014-01-01 10:00:00+07:00") self.assertEqual(result.tzinfo.utcoffset(result), datetime.timedelta(hours=7)) # convert to utc - date_str_with_tz = 'January 1st, 2014 10:00:00 +07:00' + date_str_with_tz = "January 1st, 2014 10:00:00 +07:00" result = date_utils.parse(value=date_str_with_tz, preserve_original_tz=False) - self.assertEqual(str(result), '2014-01-01 03:00:00+00:00') + self.assertEqual(str(result), "2014-01-01 03:00:00+00:00") self.assertEqual(result.tzinfo.utcoffset(result), datetime.timedelta(hours=0)) - self.assertEqual(result.tzinfo.tzname(None), 'UTC') + self.assertEqual(result.tzinfo.tzname(None), "UTC") diff --git a/st2common/tests/unit/test_db.py b/st2common/tests/unit/test_db.py index 756c0a105e6..da0157127e7 100644 --- a/st2common/tests/unit/test_db.py +++ b/st2common/tests/unit/test_db.py @@ -18,6 +18,7 @@ # NOTE: We need to perform monkeypatch before importing ssl module otherwise tests will fail. # See https://github.com/StackStorm/st2/pull/4834 for details from st2common.util.monkey_patch import monkey_patch + monkey_patch() import ssl @@ -52,47 +53,50 @@ __all__ = [ - 'DbConnectionTestCase', - 'DbConnectionTestCase', - 'ReactorModelTestCase', - 'ActionModelTestCase', - 'KeyValuePairModelTestCase' + "DbConnectionTestCase", + "DbConnectionTestCase", + "ReactorModelTestCase", + "ActionModelTestCase", + "KeyValuePairModelTestCase", ] SKIP_DELETE = False -DUMMY_DESCRIPTION = 'Sample Description.' +DUMMY_DESCRIPTION = "Sample Description." class DbIndexNameTestCase(TestCase): """ Test which verifies that model index name are not longer than the specified limit. """ + LIMIT = 65 def test_index_name_length(self): - db_name = 'st2' + db_name = "st2" for model in ALL_MODELS: collection_name = model._get_collection_name() - model_indexes = model._meta['index_specs'] + model_indexes = model._meta["index_specs"] for index_specs in model_indexes: - index_name = index_specs.get('name', None) + index_name = index_specs.get("name", None) if index_name: # Custom index name defined by the developer index_field_name = index_name else: # No explicit index name specified, one is auto-generated using # .. schema - index_fields = dict(index_specs['fields']).keys() - index_field_name = '.'.join(index_fields) + index_fields = dict(index_specs["fields"]).keys() + index_field_name = ".".join(index_fields) - index_name = '%s.%s.%s' % (db_name, collection_name, index_field_name) + index_name = "%s.%s.%s" % (db_name, collection_name, index_field_name) if len(index_name) > self.LIMIT: - self.fail('Index name "%s" for model "%s" is longer than %s characters. ' - 'Please manually define name for this index so it\'s shorter than ' - 'that' % (index_name, model.__name__, self.LIMIT)) + self.fail( + 'Index name "%s" for model "%s" is longer than %s characters. ' + "Please manually define name for this index so it's shorter than " + "that" % (index_name, model.__name__, self.LIMIT) + ) class DbConnectionTestCase(DbTestCase): @@ -111,210 +115,293 @@ def test_check_connect(self): """ client = mongoengine.connection.get_connection() - expected_str = "host=['%s:%s']" % (cfg.CONF.database.host, cfg.CONF.database.port) - self.assertIn(expected_str, str(client), 'Not connected to desired host.') + expected_str = "host=['%s:%s']" % ( + cfg.CONF.database.host, + cfg.CONF.database.port, + ) + self.assertIn(expected_str, str(client), "Not connected to desired host.") def test_get_ssl_kwargs(self): # 1. No SSL kwargs provided ssl_kwargs = _get_ssl_kwargs() - self.assertEqual(ssl_kwargs, {'ssl': False}) + self.assertEqual(ssl_kwargs, {"ssl": False}) # 2. ssl kwarg provided ssl_kwargs = _get_ssl_kwargs(ssl=True) - self.assertEqual(ssl_kwargs, {'ssl': True, 'ssl_match_hostname': True}) + self.assertEqual(ssl_kwargs, {"ssl": True, "ssl_match_hostname": True}) # 2. authentication_mechanism kwarg provided - ssl_kwargs = _get_ssl_kwargs(authentication_mechanism='MONGODB-X509') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_match_hostname': True, - 'authentication_mechanism': 'MONGODB-X509' - }) + ssl_kwargs = _get_ssl_kwargs(authentication_mechanism="MONGODB-X509") + self.assertEqual( + ssl_kwargs, + { + "ssl": True, + "ssl_match_hostname": True, + "authentication_mechanism": "MONGODB-X509", + }, + ) # 3. ssl_keyfile provided - ssl_kwargs = _get_ssl_kwargs(ssl_keyfile='/tmp/keyfile') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_keyfile': '/tmp/keyfile', - 'ssl_match_hostname': True - }) + ssl_kwargs = _get_ssl_kwargs(ssl_keyfile="/tmp/keyfile") + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ssl_keyfile": "/tmp/keyfile", "ssl_match_hostname": True}, + ) # 4. ssl_certfile provided - ssl_kwargs = _get_ssl_kwargs(ssl_certfile='/tmp/certfile') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_certfile': '/tmp/certfile', - 'ssl_match_hostname': True - }) + ssl_kwargs = _get_ssl_kwargs(ssl_certfile="/tmp/certfile") + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ssl_certfile": "/tmp/certfile", "ssl_match_hostname": True}, + ) # 5. ssl_ca_certs provided - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_ca_certs': '/tmp/ca_certs', - 'ssl_match_hostname': True - }) + ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs") + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ssl_ca_certs": "/tmp/ca_certs", "ssl_match_hostname": True}, + ) # 6. ssl_ca_certs and ssl_cert_reqs combinations - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='none') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_ca_certs': '/tmp/ca_certs', - 'ssl_cert_reqs': ssl.CERT_NONE, - 'ssl_match_hostname': True - }) - - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='optional') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_ca_certs': '/tmp/ca_certs', - 'ssl_cert_reqs': ssl.CERT_OPTIONAL, - 'ssl_match_hostname': True - }) - - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='required') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ssl_ca_certs': '/tmp/ca_certs', - 'ssl_cert_reqs': ssl.CERT_REQUIRED, - 'ssl_match_hostname': True - }) - - @mock.patch('st2common.models.db.mongoengine') + ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="none") + self.assertEqual( + ssl_kwargs, + { + "ssl": True, + "ssl_ca_certs": "/tmp/ca_certs", + "ssl_cert_reqs": ssl.CERT_NONE, + "ssl_match_hostname": True, + }, + ) + + ssl_kwargs = _get_ssl_kwargs( + ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="optional" + ) + self.assertEqual( + ssl_kwargs, + { + "ssl": True, + "ssl_ca_certs": "/tmp/ca_certs", + "ssl_cert_reqs": ssl.CERT_OPTIONAL, + "ssl_match_hostname": True, + }, + ) + + ssl_kwargs = _get_ssl_kwargs( + ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="required" + ) + self.assertEqual( + ssl_kwargs, + { + "ssl": True, + "ssl_ca_certs": "/tmp/ca_certs", + "ssl_cert_reqs": ssl.CERT_REQUIRED, + "ssl_match_hostname": True, + }, + ) + + @mock.patch("st2common.models.db.mongoengine") def test_db_setup(self, mock_mongoengine): - db_setup(db_name='name', db_host='host', db_port=12345, username='username', - password='password', authentication_mechanism='MONGODB-X509') + db_setup( + db_name="name", + db_host="host", + db_port=12345, + username="username", + password="password", + authentication_mechanism="MONGODB-X509", + ) call_args = mock_mongoengine.connection.connect.call_args_list[0][0] call_kwargs = mock_mongoengine.connection.connect.call_args_list[0][1] - self.assertEqual(call_args, ('name',)) - self.assertEqual(call_kwargs, { - 'host': 'host', - 'port': 12345, - 'username': 'username', - 'password': 'password', - 'tz_aware': True, - 'authentication_mechanism': 'MONGODB-X509', - 'ssl': True, - 'ssl_match_hostname': True, - 'connectTimeoutMS': 3000, - 'serverSelectionTimeoutMS': 3000 - }) - - @mock.patch('st2common.models.db.mongoengine') - @mock.patch('st2common.models.db.LOG') + self.assertEqual(call_args, ("name",)) + self.assertEqual( + call_kwargs, + { + "host": "host", + "port": 12345, + "username": "username", + "password": "password", + "tz_aware": True, + "authentication_mechanism": "MONGODB-X509", + "ssl": True, + "ssl_match_hostname": True, + "connectTimeoutMS": 3000, + "serverSelectionTimeoutMS": 3000, + }, + ) + + @mock.patch("st2common.models.db.mongoengine") + @mock.patch("st2common.models.db.LOG") def test_db_setup_connecting_info_logging(self, mock_log, mock_mongoengine): # Verify that password is not included in the log message - db_name = 'st2' - db_port = '27017' - username = 'user_st2' - password = 'pass_st2' + db_name = "st2" + db_port = "27017" + username = "user_st2" + password = "pass_st2" # 1. Password provided as separate argument - db_host = 'localhost' - username = 'user_st2' - password = 'pass_st2' - db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username, - password=password) - - expected_message = 'Connecting to database "st2" @ "localhost:27017" as user "user_st2".' + db_host = "localhost" + username = "user_st2" + password = "pass_st2" + db_setup( + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "localhost:27017" as user "user_st2".' + ) actual_message = mock_log.info.call_args_list[0][0][0] self.assertEqual(expected_message, actual_message) # Check for helpful error messages if the connection is successful - expected_log_message = ('Successfully connected to database "st2" @ "localhost:27017" as ' - 'user "user_st2".') + expected_log_message = ( + 'Successfully connected to database "st2" @ "localhost:27017" as ' + 'user "user_st2".' + ) actual_log_message = mock_log.info.call_args_list[1][0][0] self.assertEqual(expected_log_message, actual_log_message) # 2. Password provided as part of uri string (single host) - db_host = 'mongodb://user_st22:pass_st22@127.0.0.2:5555' + db_host = "mongodb://user_st22:pass_st22@127.0.0.2:5555" username = None password = None - db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username, - password=password) - - expected_message = 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st22".' + db_setup( + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st22".' + ) actual_message = mock_log.info.call_args_list[2][0][0] self.assertEqual(expected_message, actual_message) - expected_log_message = ('Successfully connected to database "st2" @ "127.0.0.2:5555" as ' - 'user "user_st22".') + expected_log_message = ( + 'Successfully connected to database "st2" @ "127.0.0.2:5555" as ' + 'user "user_st22".' + ) actual_log_message = mock_log.info.call_args_list[3][0][0] self.assertEqual(expected_log_message, actual_log_message) # 3. Password provided as part of uri string (single host) - username # provided as argument has precedence - db_host = 'mongodb://user_st210:pass_st23@127.0.0.2:5555' - username = 'user_st23' + db_host = "mongodb://user_st210:pass_st23@127.0.0.2:5555" + username = "user_st23" password = None - db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username, - password=password) - - expected_message = 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st23".' + db_setup( + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st23".' + ) actual_message = mock_log.info.call_args_list[4][0][0] self.assertEqual(expected_message, actual_message) - expected_log_message = ('Successfully connected to database "st2" @ "127.0.0.2:5555" as ' - 'user "user_st23".') + expected_log_message = ( + 'Successfully connected to database "st2" @ "127.0.0.2:5555" as ' + 'user "user_st23".' + ) actual_log_message = mock_log.info.call_args_list[5][0][0] self.assertEqual(expected_log_message, actual_log_message) # 4. Just host provided in the url string - db_host = 'mongodb://127.0.0.2:5555' - username = 'user_st24' - password = 'foobar' - db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username, - password=password) - - expected_message = 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st24".' + db_host = "mongodb://127.0.0.2:5555" + username = "user_st24" + password = "foobar" + db_setup( + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "127.0.0.2:5555" as user "user_st24".' + ) actual_message = mock_log.info.call_args_list[6][0][0] self.assertEqual(expected_message, actual_message) - expected_log_message = ('Successfully connected to database "st2" @ "127.0.0.2:5555" as ' - 'user "user_st24".') + expected_log_message = ( + 'Successfully connected to database "st2" @ "127.0.0.2:5555" as ' + 'user "user_st24".' + ) actual_log_message = mock_log.info.call_args_list[7][0][0] self.assertEqual(expected_log_message, actual_log_message) # 5. Multiple hosts specified as part of connection uri - db_host = 'mongodb://user6:pass6@host1,host2,host3' + db_host = "mongodb://user6:pass6@host1,host2,host3" username = None - password = 'foobar' - db_setup(db_name=db_name, db_host=db_host, db_port=db_port, username=username, - password=password) - - expected_message = ('Connecting to database "st2" @ "host1:27017,host2:27017,host3:27017 ' - '(replica set)" as user "user6".') + password = "foobar" + db_setup( + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "host1:27017,host2:27017,host3:27017 ' + '(replica set)" as user "user6".' + ) actual_message = mock_log.info.call_args_list[8][0][0] self.assertEqual(expected_message, actual_message) - expected_log_message = ('Successfully connected to database "st2" @ ' - '"host1:27017,host2:27017,host3:27017 ' - '(replica set)" as user "user6".') + expected_log_message = ( + 'Successfully connected to database "st2" @ ' + '"host1:27017,host2:27017,host3:27017 ' + '(replica set)" as user "user6".' + ) actual_log_message = mock_log.info.call_args_list[9][0][0] self.assertEqual(expected_log_message, actual_log_message) # 6. Check for error message when failing to establish a connection mock_connect = mock.Mock() - mock_connect.admin.command = mock.Mock(side_effect=ConnectionFailure('Failed to connect')) + mock_connect.admin.command = mock.Mock( + side_effect=ConnectionFailure("Failed to connect") + ) mock_mongoengine.connection.connect.return_value = mock_connect - db_host = 'mongodb://localhost:9797' - username = 'user_st2' - password = 'pass_st2' - - expected_msg = 'Failed to connect' - self.assertRaisesRegexp(ConnectionFailure, expected_msg, db_setup, - db_name=db_name, db_host=db_host, db_port=db_port, - username=username, password=password) - - expected_message = 'Connecting to database "st2" @ "localhost:9797" as user "user_st2".' + db_host = "mongodb://localhost:9797" + username = "user_st2" + password = "pass_st2" + + expected_msg = "Failed to connect" + self.assertRaisesRegexp( + ConnectionFailure, + expected_msg, + db_setup, + db_name=db_name, + db_host=db_host, + db_port=db_port, + username=username, + password=password, + ) + + expected_message = ( + 'Connecting to database "st2" @ "localhost:9797" as user "user_st2".' + ) actual_message = mock_log.info.call_args_list[10][0][0] self.assertEqual(expected_message, actual_message) - expected_message = ('Failed to connect to database "st2" @ "localhost:9797" as user ' - '"user_st2": Failed to connect') + expected_message = ( + 'Failed to connect to database "st2" @ "localhost:9797" as user ' + '"user_st2": Failed to connect' + ) actual_message = mock_log.error.call_args_list[0][0][0] self.assertEqual(expected_message, actual_message) @@ -323,29 +410,43 @@ def test_db_connect_server_selection_timeout_ssl_on_non_ssl_listener(self): # and propagating the error disconnect() - db_name = 'st2' - db_host = 'localhost' + db_name = "st2" + db_host = "localhost" db_port = 27017 - cfg.CONF.set_override(name='connection_timeout', group='database', override=1000) + cfg.CONF.set_override( + name="connection_timeout", group="database", override=1000 + ) start = time.time() - self.assertRaises(ServerSelectionTimeoutError, db_setup, db_name=db_name, db_host=db_host, - db_port=db_port, ssl=True) + self.assertRaises( + ServerSelectionTimeoutError, + db_setup, + db_name=db_name, + db_host=db_host, + db_port=db_port, + ssl=True, + ) end = time.time() - diff = (end - start) + diff = end - start self.assertTrue(diff >= 1) disconnect() - cfg.CONF.set_override(name='connection_timeout', group='database', override=400) + cfg.CONF.set_override(name="connection_timeout", group="database", override=400) start = time.time() - self.assertRaises(ServerSelectionTimeoutError, db_setup, db_name=db_name, db_host=db_host, - db_port=db_port, ssl=True) + self.assertRaises( + ServerSelectionTimeoutError, + db_setup, + db_name=db_name, + db_host=db_host, + db_port=db_port, + ssl=True, + ) end = time.time() - diff = (end - start) + diff = end - start self.assertTrue(diff >= 0.4) @@ -364,60 +465,63 @@ def test_cleanup(self): self.assertNotIn(cfg.CONF.database.db_name, connection.database_names()) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ReactorModelTestCase(DbTestCase): - def test_triggertype_crud(self): saved = ReactorModelTestCase._create_save_triggertype() retrieved = TriggerType.get_by_id(saved.id) - self.assertEqual(saved.name, retrieved.name, - 'Same triggertype was not returned.') + self.assertEqual( + saved.name, retrieved.name, "Same triggertype was not returned." + ) # test update - self.assertEqual(retrieved.description, '') + self.assertEqual(retrieved.description, "") retrieved.description = DUMMY_DESCRIPTION saved = TriggerType.add_or_update(retrieved) retrieved = TriggerType.get_by_id(saved.id) - self.assertEqual(retrieved.description, DUMMY_DESCRIPTION, 'Update to trigger failed.') + self.assertEqual( + retrieved.description, DUMMY_DESCRIPTION, "Update to trigger failed." + ) # cleanup ReactorModelTestCase._delete([retrieved]) try: retrieved = TriggerType.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_trigger_crud(self): triggertype = ReactorModelTestCase._create_save_triggertype() saved = ReactorModelTestCase._create_save_trigger(triggertype) retrieved = Trigger.get_by_id(saved.id) - self.assertEqual(saved.name, retrieved.name, - 'Same trigger was not returned.') + self.assertEqual(saved.name, retrieved.name, "Same trigger was not returned.") # test update - self.assertEqual(retrieved.description, '') + self.assertEqual(retrieved.description, "") retrieved.description = DUMMY_DESCRIPTION saved = Trigger.add_or_update(retrieved) retrieved = Trigger.get_by_id(saved.id) - self.assertEqual(retrieved.description, DUMMY_DESCRIPTION, 'Update to trigger failed.') + self.assertEqual( + retrieved.description, DUMMY_DESCRIPTION, "Update to trigger failed." + ) # cleanup ReactorModelTestCase._delete([retrieved, triggertype]) try: retrieved = Trigger.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_triggerinstance_crud(self): triggertype = ReactorModelTestCase._create_save_triggertype() trigger = ReactorModelTestCase._create_save_trigger(triggertype) saved = ReactorModelTestCase._create_save_triggerinstance(trigger) retrieved = TriggerInstance.get_by_id(saved.id) - self.assertIsNotNone(retrieved, 'No triggerinstance created.') + self.assertIsNotNone(retrieved, "No triggerinstance created.") ReactorModelTestCase._delete([retrieved, trigger, triggertype]) try: retrieved = TriggerInstance.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_rule_crud(self): triggertype = ReactorModelTestCase._create_save_triggertype() @@ -426,20 +530,22 @@ def test_rule_crud(self): action = ActionModelTestCase._create_save_action(runnertype) saved = ReactorModelTestCase._create_save_rule(trigger, action) retrieved = Rule.get_by_id(saved.id) - self.assertEqual(saved.name, retrieved.name, 'Same rule was not returned.') + self.assertEqual(saved.name, retrieved.name, "Same rule was not returned.") # test update self.assertEqual(retrieved.enabled, True) retrieved.enabled = False saved = Rule.add_or_update(retrieved) retrieved = Rule.get_by_id(saved.id) - self.assertEqual(retrieved.enabled, False, 'Update to rule failed.') + self.assertEqual(retrieved.enabled, False, "Update to rule failed.") # cleanup - ReactorModelTestCase._delete([retrieved, trigger, action, runnertype, triggertype]) + ReactorModelTestCase._delete( + [retrieved, trigger, action, runnertype, triggertype] + ) try: retrieved = Rule.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_rule_lookup(self): triggertype = ReactorModelTestCase._create_save_triggertype() @@ -447,10 +553,12 @@ def test_rule_lookup(self): runnertype = ActionModelTestCase._create_save_runnertype() action = ActionModelTestCase._create_save_action(runnertype) saved = ReactorModelTestCase._create_save_rule(trigger, action) - retrievedrules = Rule.query(trigger=reference.get_str_resource_ref_from_model(trigger)) - self.assertEqual(1, len(retrievedrules), 'No rules found.') + retrievedrules = Rule.query( + trigger=reference.get_str_resource_ref_from_model(trigger) + ) + self.assertEqual(1, len(retrievedrules), "No rules found.") for retrievedrule in retrievedrules: - self.assertEqual(saved.id, retrievedrule.id, 'Incorrect rule returned.') + self.assertEqual(saved.id, retrievedrule.id, "Incorrect rule returned.") ReactorModelTestCase._delete([saved, trigger, action, runnertype, triggertype]) def test_rule_lookup_enabled(self): @@ -459,12 +567,12 @@ def test_rule_lookup_enabled(self): runnertype = ActionModelTestCase._create_save_runnertype() action = ActionModelTestCase._create_save_action(runnertype) saved = ReactorModelTestCase._create_save_rule(trigger, action) - retrievedrules = Rule.query(trigger=reference.get_str_resource_ref_from_model(trigger), - enabled=True) - self.assertEqual(1, len(retrievedrules), 'Error looking up enabled rules.') + retrievedrules = Rule.query( + trigger=reference.get_str_resource_ref_from_model(trigger), enabled=True + ) + self.assertEqual(1, len(retrievedrules), "Error looking up enabled rules.") for retrievedrule in retrievedrules: - self.assertEqual(saved.id, retrievedrule.id, - 'Incorrect rule returned.') + self.assertEqual(saved.id, retrievedrule.id, "Incorrect rule returned.") ReactorModelTestCase._delete([saved, trigger, action, runnertype, triggertype]) def test_rule_lookup_disabled(self): @@ -473,49 +581,64 @@ def test_rule_lookup_disabled(self): runnertype = ActionModelTestCase._create_save_runnertype() action = ActionModelTestCase._create_save_action(runnertype) saved = ReactorModelTestCase._create_save_rule(trigger, action, False) - retrievedrules = Rule.query(trigger=reference.get_str_resource_ref_from_model(trigger), - enabled=False) - self.assertEqual(1, len(retrievedrules), 'Error looking up enabled rules.') + retrievedrules = Rule.query( + trigger=reference.get_str_resource_ref_from_model(trigger), enabled=False + ) + self.assertEqual(1, len(retrievedrules), "Error looking up enabled rules.") for retrievedrule in retrievedrules: - self.assertEqual(saved.id, retrievedrule.id, 'Incorrect rule returned.') + self.assertEqual(saved.id, retrievedrule.id, "Incorrect rule returned.") ReactorModelTestCase._delete([saved, trigger, action, runnertype, triggertype]) def test_trigger_lookup(self): triggertype = ReactorModelTestCase._create_save_triggertype() saved = ReactorModelTestCase._create_save_trigger(triggertype) retrievedtriggers = Trigger.query(name=saved.name) - self.assertEqual(1, len(retrievedtriggers), 'No triggers found.') + self.assertEqual(1, len(retrievedtriggers), "No triggers found.") for retrievedtrigger in retrievedtriggers: - self.assertEqual(saved.id, retrievedtrigger.id, - 'Incorrect trigger returned.') + self.assertEqual( + saved.id, retrievedtrigger.id, "Incorrect trigger returned." + ) ReactorModelTestCase._delete([saved, triggertype]) @staticmethod def _create_save_triggertype(): - created = TriggerTypeDB(pack='dummy_pack_1', name='triggertype-1', description='', - payload_schema={}, parameters_schema={}) + created = TriggerTypeDB( + pack="dummy_pack_1", + name="triggertype-1", + description="", + payload_schema={}, + parameters_schema={}, + ) return Trigger.add_or_update(created) @staticmethod def _create_save_trigger(triggertype): - created = TriggerDB(pack='dummy_pack_1', name='trigger-1', description='', - type=triggertype.get_reference().ref, parameters={}) + created = TriggerDB( + pack="dummy_pack_1", + name="trigger-1", + description="", + type=triggertype.get_reference().ref, + parameters={}, + ) return Trigger.add_or_update(created) @staticmethod def _create_save_triggerinstance(trigger): - created = TriggerInstanceDB(trigger=trigger.get_reference().ref, payload={}, - occurrence_time=date_utils.get_datetime_utc_now(), - status=TRIGGER_INSTANCE_PROCESSED) + created = TriggerInstanceDB( + trigger=trigger.get_reference().ref, + payload={}, + occurrence_time=date_utils.get_datetime_utc_now(), + status=TRIGGER_INSTANCE_PROCESSED, + ) return TriggerInstance.add_or_update(created) @staticmethod def _create_save_rule(trigger, action=None, enabled=True): - name = 'rule-1' - pack = 'default' + name = "rule-1" + pack = "default" ref = ResourceReference.to_string_reference(name=name, pack=pack) created = RuleDB(name=name, pack=pack, ref=ref) - created.description = '' + created.description = "" created.enabled = enabled created.trigger = reference.get_str_resource_ref_from_model(trigger) created.criteria = {} @@ -547,44 +670,21 @@ def _delete(model_objects): "description": "awesomeness", "type": "object", "properties": { - "r1": { - "type": "object", - "properties": { - "r1a": { - "type": "string" - } - } - }, - "r2": { - "type": "string", - "required": True - }, - "p1": { - "type": "string", - "required": True - }, - "p2": { - "type": "number", - "default": 2868 - }, - "p3": { - "type": "boolean", - "default": False - }, - "p4": { - "type": "string", - "secret": True - } + "r1": {"type": "object", "properties": {"r1a": {"type": "string"}}}, + "r2": {"type": "string", "required": True}, + "p1": {"type": "string", "required": True}, + "p2": {"type": "number", "default": 2868}, + "p3": {"type": "boolean", "default": False}, + "p4": {"type": "string", "secret": True}, }, - "additionalProperties": False + "additionalProperties": False, } -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ActionModelTestCase(DbTestCase): - def tearDown(self): - runnertype = RunnerType.get_by_name('python') + runnertype = RunnerType.get_by_name("python") self._delete([runnertype]) super(ActionModelTestCase, self).tearDown() @@ -592,15 +692,16 @@ def test_action_crud(self): runnertype = self._create_save_runnertype(metadata=False) saved = self._create_save_action(runnertype, metadata=False) retrieved = Action.get_by_id(saved.id) - self.assertEqual(saved.name, retrieved.name, - 'Same Action was not returned.') + self.assertEqual(saved.name, retrieved.name, "Same Action was not returned.") # test update - self.assertEqual(retrieved.description, 'awesomeness') + self.assertEqual(retrieved.description, "awesomeness") retrieved.description = DUMMY_DESCRIPTION saved = Action.add_or_update(retrieved) retrieved = Action.get_by_id(saved.id) - self.assertEqual(retrieved.description, DUMMY_DESCRIPTION, 'Update to action failed.') + self.assertEqual( + retrieved.description, DUMMY_DESCRIPTION, "Update to action failed." + ) # cleanup self._delete([retrieved]) @@ -608,14 +709,14 @@ def test_action_crud(self): retrieved = Action.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_action_with_notify_crud(self): runnertype = self._create_save_runnertype(metadata=False) saved = self._create_save_action(runnertype, metadata=False) # Update action with notification settings - on_complete = NotificationSubSchema(message='Action complete.') + on_complete = NotificationSubSchema(message="Action complete.") saved.notify = NotificationSchema(on_complete=on_complete) saved = Action.add_or_update(saved) @@ -635,7 +736,7 @@ def test_action_with_notify_crud(self): retrieved = Action.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_parameter_schema(self): runnertype = self._create_save_runnertype(metadata=True) @@ -650,13 +751,30 @@ def test_parameter_schema(self): # use schema to validate parameters jsonschema.validate({"r2": "abc", "p1": "def"}, schema, validator) - jsonschema.validate({"r2": "abc", "p1": "def", "r1": {"r1a": "ghi"}}, schema, validator) - self.assertRaises(jsonschema.ValidationError, jsonschema.validate, - '{"r2": "abc", "p1": "def"}', schema, validator) - self.assertRaises(jsonschema.ValidationError, jsonschema.validate, - {"r2": "abc"}, schema, validator) - self.assertRaises(jsonschema.ValidationError, jsonschema.validate, - {"r2": "abc", "p1": "def", "r1": 123}, schema, validator) + jsonschema.validate( + {"r2": "abc", "p1": "def", "r1": {"r1a": "ghi"}}, schema, validator + ) + self.assertRaises( + jsonschema.ValidationError, + jsonschema.validate, + '{"r2": "abc", "p1": "def"}', + schema, + validator, + ) + self.assertRaises( + jsonschema.ValidationError, + jsonschema.validate, + {"r2": "abc"}, + schema, + validator, + ) + self.assertRaises( + jsonschema.ValidationError, + jsonschema.validate, + {"r2": "abc", "p1": "def", "r1": 123}, + schema, + validator, + ) # cleanup self._delete([retrieved]) @@ -664,7 +782,7 @@ def test_parameter_schema(self): retrieved = Action.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_parameters_schema_runner_and_action_parameters_are_correctly_merged(self): # Test that the runner and action parameters are correctly deep merged when building @@ -673,54 +791,55 @@ def test_parameters_schema_runner_and_action_parameters_are_correctly_merged(sel self._create_save_runnertype(metadata=True) action_db = mock.Mock() - action_db.runner_type = {'name': 'python'} - action_db.parameters = {'r1': {'immutable': True}} + action_db.runner_type = {"name": "python"} + action_db.parameters = {"r1": {"immutable": True}} schema = util_schema.get_schema_for_action_parameters(action_db=action_db) expected = { - u'type': u'object', - u'properties': { - u'r1a': { - u'type': u'string' - } - }, - 'immutable': True + "type": "object", + "properties": {"r1a": {"type": "string"}}, + "immutable": True, } - self.assertEqual(schema['properties']['r1'], expected) + self.assertEqual(schema["properties"]["r1"], expected) @staticmethod def _create_save_runnertype(metadata=False): - created = RunnerTypeDB(name='python') - created.description = '' + created = RunnerTypeDB(name="python") + created.description = "" created.enabled = True if not metadata: - created.runner_parameters = {'r1': None, 'r2': None} + created.runner_parameters = {"r1": None, "r2": None} else: created.runner_parameters = { - 'r1': {'type': 'object', 'properties': {'r1a': {'type': 'string'}}}, - 'r2': {'type': 'string', 'required': True} + "r1": {"type": "object", "properties": {"r1a": {"type": "string"}}}, + "r2": {"type": "string", "required": True}, } - created.runner_module = 'nomodule' + created.runner_module = "nomodule" return RunnerType.add_or_update(created) @staticmethod def _create_save_action(runnertype, metadata=False): - name = 'action-1' - pack = 'wolfpack' + name = "action-1" + pack = "wolfpack" ref = ResourceReference(pack=pack, name=name).ref - created = ActionDB(name=name, description='awesomeness', enabled=True, - entry_point='/tmp/action.py', pack=pack, - ref=ref, - runner_type={'name': runnertype.name}) + created = ActionDB( + name=name, + description="awesomeness", + enabled=True, + entry_point="/tmp/action.py", + pack=pack, + ref=ref, + runner_type={"name": runnertype.name}, + ) if not metadata: - created.parameters = {'p1': None, 'p2': None, 'p3': None, 'p4': None} + created.parameters = {"p1": None, "p2": None, "p3": None, "p4": None} else: created.parameters = { - 'p1': {'type': 'string', 'required': True}, - 'p2': {'type': 'number', 'default': 2868}, - 'p3': {'type': 'boolean', 'default': False}, - 'p4': {'type': 'string', 'secret': True} + "p1": {"type": "string", "required": True}, + "p2": {"type": "number", "default": 2868}, + "p3": {"type": "boolean", "default": False}, + "p4": {"type": "string", "secret": True}, } return Action.add_or_update(created) @@ -738,20 +857,19 @@ def _delete(model_objects): class KeyValuePairModelTestCase(DbTestCase): - def test_kvp_crud(self): saved = KeyValuePairModelTestCase._create_save_kvp() retrieved = KeyValuePair.get_by_name(saved.name) - self.assertEqual(saved.id, retrieved.id, - 'Same KeyValuePair was not returned.') + self.assertEqual(saved.id, retrieved.id, "Same KeyValuePair was not returned.") # test update - self.assertEqual(retrieved.value, '0123456789ABCDEF') - retrieved.value = 'ABCDEF0123456789' + self.assertEqual(retrieved.value, "0123456789ABCDEF") + retrieved.value = "ABCDEF0123456789" saved = KeyValuePair.add_or_update(retrieved) retrieved = KeyValuePair.get_by_name(saved.name) - self.assertEqual(retrieved.value, 'ABCDEF0123456789', - 'Update of key value failed') + self.assertEqual( + retrieved.value, "ABCDEF0123456789", "Update of key value failed" + ) # cleanup KeyValuePairModelTestCase._delete([retrieved]) @@ -759,11 +877,11 @@ def test_kvp_crud(self): retrieved = KeyValuePair.get_by_name(saved.name) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") @staticmethod def _create_save_kvp(): - created = KeyValuePairDB(name='token', value='0123456789ABCDEF') + created = KeyValuePairDB(name="token", value="0123456789ABCDEF") return KeyValuePair.add_or_update(created) @staticmethod diff --git a/st2common/tests/unit/test_db_action_state.py b/st2common/tests/unit/test_db_action_state.py index 3251898e298..47b9d170bdc 100644 --- a/st2common/tests/unit/test_db_action_state.py +++ b/st2common/tests/unit/test_db_action_state.py @@ -34,13 +34,13 @@ def test_state_crud(self): retrieved = ActionExecutionState.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") @staticmethod def _create_save_actionstate(): created = ActionExecutionStateDB() - created.query_context = {'id': 'some_external_service_id'} - created.query_module = 'dummy.modules.query1' + created.query_context = {"id": "some_external_service_id"} + created.query_module = "dummy.modules.query1" created.execution_id = bson.ObjectId() return ActionExecutionState.add_or_update(created) diff --git a/st2common/tests/unit/test_db_auth.py b/st2common/tests/unit/test_db_auth.py index 9cf35bc737d..b1595805051 100644 --- a/st2common/tests/unit/test_db_auth.py +++ b/st2common/tests/unit/test_db_auth.py @@ -26,44 +26,35 @@ from tests.unit.base import BaseDBModelCRUDTestCase -__all__ = [ - 'UserDBModelCRUDTestCase' -] +__all__ = ["UserDBModelCRUDTestCase"] class UserDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = UserDB persistance_class = User model_class_kwargs = { - 'name': 'pony', - 'is_service': False, - 'nicknames': { - 'pony1': 'ponyA' - } + "name": "pony", + "is_service": False, + "nicknames": {"pony1": "ponyA"}, } - update_attribute_name = 'name' + update_attribute_name = "name" class TokenDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = TokenDB persistance_class = Token model_class_kwargs = { - 'user': 'pony', - 'token': 'token-token-token-token', - 'expiry': get_datetime_utc_now(), - 'metadata': { - 'service': 'action-runner' - } + "user": "pony", + "token": "token-token-token-token", + "expiry": get_datetime_utc_now(), + "metadata": {"service": "action-runner"}, } - skip_check_attribute_names = ['expiry'] - update_attribute_name = 'user' + skip_check_attribute_names = ["expiry"] + update_attribute_name = "user" class ApiKeyDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = ApiKeyDB persistance_class = ApiKey - model_class_kwargs = { - 'user': 'pony', - 'key_hash': 'token-token-token-token' - } - update_attribute_name = 'user' + model_class_kwargs = {"user": "pony", "key_hash": "token-token-token-token"} + update_attribute_name = "user" diff --git a/st2common/tests/unit/test_db_base.py b/st2common/tests/unit/test_db_base.py index 0c77c336bf1..68496432434 100644 --- a/st2common/tests/unit/test_db_base.py +++ b/st2common/tests/unit/test_db_base.py @@ -27,11 +27,11 @@ class FakeRuleSpecDB(mongoengine.EmbeddedDocument): def __str__(self): result = [] - result.append('ActionExecutionSpecDB@') - result.append('test') + result.append("ActionExecutionSpecDB@") + result.append("test") result.append('(ref="%s", ' % self.ref) result.append('parameters="%s")' % self.parameters) - return ''.join(result) + return "".join(result) class FakeModel(stormbase.StormBaseDB): @@ -52,30 +52,43 @@ class FakeRuleModel(stormbase.StormBaseDB): class TestBaseModel(DbTestCase): - def test_print(self): - instance = FakeModel(name='seesaw', boolean_field=True, - datetime_field=date_utils.get_datetime_utc_now(), - description=u'fun!', dict_field={'a': 1}, - integer_field=68, list_field=['abc']) - - expected = ('FakeModel(boolean_field=True, datetime_field="%s", description="fun!", ' - 'dict_field={\'a\': 1}, id=None, integer_field=68, list_field=[\'abc\'], ' - 'name="seesaw")' % str(instance.datetime_field)) + instance = FakeModel( + name="seesaw", + boolean_field=True, + datetime_field=date_utils.get_datetime_utc_now(), + description="fun!", + dict_field={"a": 1}, + integer_field=68, + list_field=["abc"], + ) + + expected = ( + 'FakeModel(boolean_field=True, datetime_field="%s", description="fun!", ' + "dict_field={'a': 1}, id=None, integer_field=68, list_field=['abc'], " + 'name="seesaw")' % str(instance.datetime_field) + ) self.assertEqual(str(instance), expected) def test_rule_print(self): - instance = FakeRuleModel(name='seesaw', boolean_field=True, - datetime_field=date_utils.get_datetime_utc_now(), - description=u'fun!', dict_field={'a': 1}, - integer_field=68, list_field=['abc'], - embedded_doc_field={'ref': '1234', 'parameters': {'b': 2}}) - - expected = ('FakeRuleModel(boolean_field=True, datetime_field="%s", description="fun!", ' - 'dict_field={\'a\': 1}, embedded_doc_field=ActionExecutionSpecDB@test(' - 'ref="1234", parameters="{\'b\': 2}"), id=None, integer_field=68, ' - 'list_field=[\'abc\'], ' - 'name="seesaw")' % str(instance.datetime_field)) + instance = FakeRuleModel( + name="seesaw", + boolean_field=True, + datetime_field=date_utils.get_datetime_utc_now(), + description="fun!", + dict_field={"a": 1}, + integer_field=68, + list_field=["abc"], + embedded_doc_field={"ref": "1234", "parameters": {"b": 2}}, + ) + + expected = ( + 'FakeRuleModel(boolean_field=True, datetime_field="%s", description="fun!", ' + "dict_field={'a': 1}, embedded_doc_field=ActionExecutionSpecDB@test(" + 'ref="1234", parameters="{\'b\': 2}"), id=None, integer_field=68, ' + "list_field=['abc'], " + 'name="seesaw")' % str(instance.datetime_field) + ) self.assertEqual(str(instance), expected) diff --git a/st2common/tests/unit/test_db_execution.py b/st2common/tests/unit/test_db_execution.py index 62478ee13e5..e94ccb3d94f 100644 --- a/st2common/tests/unit/test_db_execution.py +++ b/st2common/tests/unit/test_db_execution.py @@ -27,79 +27,71 @@ INQUIRY_RESULT = { - 'users': [], - 'roles': [], - 'route': 'developers', - 'ttl': 1440, - 'response': { - 'secondfactor': 'supersecretvalue' - }, - 'schema': { - 'type': 'object', - 'properties': { - 'secondfactor': { - 'secret': True, - 'required': True, - 'type': 'string', - 'description': 'Please enter second factor for authenticating to "foo" service' + "users": [], + "roles": [], + "route": "developers", + "ttl": 1440, + "response": {"secondfactor": "supersecretvalue"}, + "schema": { + "type": "object", + "properties": { + "secondfactor": { + "secret": True, + "required": True, + "type": "string", + "description": 'Please enter second factor for authenticating to "foo" service', } - } - } + }, + }, } INQUIRY_LIVEACTION = { - 'parameters': { - 'route': 'developers', - 'schema': { - 'type': 'object', - 'properties': { - 'secondfactor': { - 'secret': True, - 'required': True, - 'type': u'string', - 'description': 'Please enter second factor for authenticating to "foo" service' + "parameters": { + "route": "developers", + "schema": { + "type": "object", + "properties": { + "secondfactor": { + "secret": True, + "required": True, + "type": "string", + "description": 'Please enter second factor for authenticating to "foo" service', } - } - } + }, + }, }, - 'action': 'core.ask' + "action": "core.ask", } RESPOND_LIVEACTION = { - 'parameters': { - 'response': { - 'secondfactor': 'omgsupersecret', + "parameters": { + "response": { + "secondfactor": "omgsupersecret", } }, - 'action': 'st2.inquiry.respond' + "action": "st2.inquiry.respond", } ACTIONEXECUTIONS = { "execution_1": { - 'action': {'uid': 'action:core:ask'}, - 'status': 'succeeded', - 'runner': {'name': 'inquirer'}, - 'liveaction': INQUIRY_LIVEACTION, - 'result': INQUIRY_RESULT + "action": {"uid": "action:core:ask"}, + "status": "succeeded", + "runner": {"name": "inquirer"}, + "liveaction": INQUIRY_LIVEACTION, + "result": INQUIRY_RESULT, }, "execution_2": { - 'action': {'uid': 'action:st2:inquiry.respond'}, - 'status': 'succeeded', - 'runner': {'name': 'python-script'}, - 'liveaction': RESPOND_LIVEACTION, - 'result': { - 'exit_code': 0, - 'result': None, - 'stderr': '', - 'stdout': '' - } - } + "action": {"uid": "action:st2:inquiry.respond"}, + "status": "succeeded", + "runner": {"name": "python-script"}, + "liveaction": RESPOND_LIVEACTION, + "result": {"exit_code": 0, "result": None, "stderr": "", "stdout": ""}, + }, } -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ActionExecutionModelTest(DbTestCase): - def setUp(self): self.executions = {} @@ -107,16 +99,17 @@ def setUp(self): for name, execution in ACTIONEXECUTIONS.items(): created = ActionExecutionDB() - created.action = execution['action'] - created.status = execution['status'] - created.runner = execution['runner'] - created.liveaction = execution['liveaction'] - created.result = execution['result'] + created.action = execution["action"] + created.status = execution["status"] + created.runner = execution["runner"] + created.liveaction = execution["liveaction"] + created.result = execution["result"] saved = ActionExecutionModelTest._save_execution(created) retrieved = ActionExecution.get_by_id(saved.id) - self.assertEqual(saved.action, retrieved.action, - 'Same action was not returned.') + self.assertEqual( + saved.action, retrieved.action, "Same action was not returned." + ) self.executions[name] = retrieved @@ -128,15 +121,16 @@ def tearDown(self): retrieved = ActionExecution.get_by_id(execution.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_update_execution(self): - """Test ActionExecutionDb update - """ - self.assertIsNone(self.executions['execution_1'].end_timestamp) - self.executions['execution_1'].end_timestamp = date_utils.get_datetime_utc_now() - updated = ActionExecution.add_or_update(self.executions['execution_1']) - self.assertTrue(updated.end_timestamp == self.executions['execution_1'].end_timestamp) + """Test ActionExecutionDb update""" + self.assertIsNone(self.executions["execution_1"].end_timestamp) + self.executions["execution_1"].end_timestamp = date_utils.get_datetime_utc_now() + updated = ActionExecution.add_or_update(self.executions["execution_1"]) + self.assertTrue( + updated.end_timestamp == self.executions["execution_1"].end_timestamp + ) def test_execution_inquiry_secrets(self): """Corner case test for Inquiry responses that contain secrets. @@ -148,13 +142,15 @@ def test_execution_inquiry_secrets(self): """ # Test Inquiry response masking is done properly within this model - masked = self.executions['execution_1'].mask_secrets( - self.executions['execution_1'].to_serializable_dict() + masked = self.executions["execution_1"].mask_secrets( + self.executions["execution_1"].to_serializable_dict() + ) + self.assertEqual( + masked["result"]["response"]["secondfactor"], MASKED_ATTRIBUTE_VALUE ) - self.assertEqual(masked['result']['response']['secondfactor'], MASKED_ATTRIBUTE_VALUE) self.assertEqual( - self.executions['execution_1'].result['response']['secondfactor'], - "supersecretvalue" + self.executions["execution_1"].result["response"]["secondfactor"], + "supersecretvalue", ) def test_execution_inquiry_response_action(self): @@ -164,10 +160,10 @@ def test_execution_inquiry_response_action(self): so we mask all response values. This test ensures this happens. """ - masked = self.executions['execution_2'].mask_secrets( - self.executions['execution_2'].to_serializable_dict() + masked = self.executions["execution_2"].mask_secrets( + self.executions["execution_2"].to_serializable_dict() ) - for value in masked['parameters']['response'].values(): + for value in masked["parameters"]["response"].values(): self.assertEqual(value, MASKED_ATTRIBUTE_VALUE) @staticmethod diff --git a/st2common/tests/unit/test_db_fields.py b/st2common/tests/unit/test_db_fields.py index eceb70d4c01..86fd3bc6fbf 100644 --- a/st2common/tests/unit/test_db_fields.py +++ b/st2common/tests/unit/test_db_fields.py @@ -37,12 +37,12 @@ def test_round_trip_conversion(self): datetime_values = [ datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=500), datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=0), - datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=999999) + datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=999999), ] datetime_values = [ date_utils.add_utc_tz(datetime_values[0]), date_utils.add_utc_tz(datetime_values[1]), - date_utils.add_utc_tz(datetime_values[2]) + date_utils.add_utc_tz(datetime_values[2]), ] microsecond_values = [] @@ -69,7 +69,7 @@ def test_round_trip_conversion(self): expected_value = datetime_values[index] self.assertEqual(actual_value, expected_value) - @mock.patch('st2common.fields.LongField.__get__') + @mock.patch("st2common.fields.LongField.__get__") def test_get_(self, mock_get): field = ComplexDateTimeField() @@ -79,7 +79,9 @@ def test_get_(self, mock_get): # Already a datetime mock_get.return_value = date_utils.get_datetime_utc_now() - self.assertEqual(field.__get__(instance=None, owner=None), mock_get.return_value) + self.assertEqual( + field.__get__(instance=None, owner=None), mock_get.return_value + ) # Microseconds dt = datetime.datetime(2015, 1, 1, 15, 0, 0).replace(microsecond=500) diff --git a/st2common/tests/unit/test_db_liveaction.py b/st2common/tests/unit/test_db_liveaction.py index 7c8b6aa35f1..605aa759f64 100644 --- a/st2common/tests/unit/test_db_liveaction.py +++ b/st2common/tests/unit/test_db_liveaction.py @@ -26,19 +26,19 @@ from st2tests import DbTestCase -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class LiveActionModelTest(DbTestCase): - def test_liveaction_crud_no_notify(self): created = LiveActionDB() - created.action = 'core.local' - created.description = '' - created.status = 'running' + created.action = "core.local" + created.description = "" + created.status = "running" created.parameters = {} saved = LiveActionModelTest._save_liveaction(created) retrieved = LiveAction.get_by_id(saved.id) - self.assertEqual(saved.action, retrieved.action, - 'Same triggertype was not returned.') + self.assertEqual( + saved.action, retrieved.action, "Same triggertype was not returned." + ) self.assertEqual(retrieved.notify, None) # Test update @@ -52,80 +52,81 @@ def test_liveaction_crud_no_notify(self): retrieved = LiveAction.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") def test_liveaction_create_with_notify_on_complete_only(self): created = LiveActionDB() - created.action = 'core.local' - created.description = '' - created.status = 'running' + created.action = "core.local" + created.description = "" + created.status = "running" created.parameters = {} notify_db = NotificationSchema() notify_sub_schema = NotificationSubSchema() - notify_sub_schema.message = 'Action complete.' - notify_sub_schema.data = { - 'foo': 'bar', - 'bar': 1, - 'baz': {'k1': 'v1'} - } + notify_sub_schema.message = "Action complete." + notify_sub_schema.data = {"foo": "bar", "bar": 1, "baz": {"k1": "v1"}} notify_db.on_complete = notify_sub_schema created.notify = notify_db saved = LiveActionModelTest._save_liveaction(created) retrieved = LiveAction.get_by_id(saved.id) - self.assertEqual(saved.action, retrieved.action, - 'Same triggertype was not returned.') + self.assertEqual( + saved.action, retrieved.action, "Same triggertype was not returned." + ) # Assert notify settings saved are right. - self.assertEqual(notify_sub_schema.message, retrieved.notify.on_complete.message) + self.assertEqual( + notify_sub_schema.message, retrieved.notify.on_complete.message + ) self.assertDictEqual(notify_sub_schema.data, retrieved.notify.on_complete.data) - self.assertListEqual(notify_sub_schema.routes, retrieved.notify.on_complete.routes) + self.assertListEqual( + notify_sub_schema.routes, retrieved.notify.on_complete.routes + ) self.assertEqual(retrieved.notify.on_success, None) self.assertEqual(retrieved.notify.on_failure, None) def test_liveaction_create_with_notify_on_success_only(self): created = LiveActionDB() - created.action = 'core.local' - created.description = '' - created.status = 'running' + created.action = "core.local" + created.description = "" + created.status = "running" created.parameters = {} notify_db = NotificationSchema() notify_sub_schema = NotificationSubSchema() - notify_sub_schema.message = 'Action succeeded.' - notify_sub_schema.data = { - 'foo': 'bar', - 'bar': 1, - 'baz': {'k1': 'v1'} - } + notify_sub_schema.message = "Action succeeded." + notify_sub_schema.data = {"foo": "bar", "bar": 1, "baz": {"k1": "v1"}} notify_db.on_success = notify_sub_schema created.notify = notify_db saved = LiveActionModelTest._save_liveaction(created) retrieved = LiveAction.get_by_id(saved.id) - self.assertEqual(saved.action, retrieved.action, - 'Same triggertype was not returned.') + self.assertEqual( + saved.action, retrieved.action, "Same triggertype was not returned." + ) # Assert notify settings saved are right. - self.assertEqual(notify_sub_schema.message, - retrieved.notify.on_success.message) + self.assertEqual(notify_sub_schema.message, retrieved.notify.on_success.message) self.assertDictEqual(notify_sub_schema.data, retrieved.notify.on_success.data) - self.assertListEqual(notify_sub_schema.routes, retrieved.notify.on_success.routes) + self.assertListEqual( + notify_sub_schema.routes, retrieved.notify.on_success.routes + ) self.assertEqual(retrieved.notify.on_failure, None) self.assertEqual(retrieved.notify.on_complete, None) def test_liveaction_create_with_notify_both_on_success_and_on_error(self): created = LiveActionDB() - created.action = 'core.local' - created.description = '' - created.status = 'running' + created.action = "core.local" + created.description = "" + created.status = "running" created.parameters = {} - on_success = NotificationSubSchema(message='Action succeeded.') - on_failure = NotificationSubSchema(message='Action failed.') - created.notify = NotificationSchema(on_success=on_success, - on_failure=on_failure) + on_success = NotificationSubSchema(message="Action succeeded.") + on_failure = NotificationSubSchema(message="Action failed.") + created.notify = NotificationSchema( + on_success=on_success, on_failure=on_failure + ) saved = LiveActionModelTest._save_liveaction(created) retrieved = LiveAction.get_by_id(saved.id) - self.assertEqual(saved.action, retrieved.action, - 'Same triggertype was not returned.') + self.assertEqual( + saved.action, retrieved.action, "Same triggertype was not returned." + ) # Assert notify settings saved are right. self.assertEqual(on_success.message, retrieved.notify.on_success.message) self.assertEqual(on_failure.message, retrieved.notify.on_failure.message) diff --git a/st2common/tests/unit/test_db_marker.py b/st2common/tests/unit/test_db_marker.py index 72dc8796978..b9cd879ea3e 100644 --- a/st2common/tests/unit/test_db_marker.py +++ b/st2common/tests/unit/test_db_marker.py @@ -26,26 +26,27 @@ class DumperMarkerModelTest(DbTestCase): def test_dumper_marker_crud(self): saved = DumperMarkerModelTest._create_save_dumper_marker() retrieved = DumperMarker.get_by_id(saved.id) - self.assertEqual(saved.marker, retrieved.marker, - 'Same marker was not returned.') + self.assertEqual( + saved.marker, retrieved.marker, "Same marker was not returned." + ) # test update time_now = date_utils.get_datetime_utc_now() retrieved.updated_at = time_now saved = DumperMarker.add_or_update(retrieved) retrieved = DumperMarker.get_by_id(saved.id) - self.assertEqual(retrieved.updated_at, time_now, 'Update to marker failed.') + self.assertEqual(retrieved.updated_at, time_now, "Update to marker failed.") # cleanup DumperMarkerModelTest._delete([retrieved]) try: retrieved = DumperMarker.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after failure.') + self.assertIsNone(retrieved, "managed to retrieve after failure.") @staticmethod def _create_save_dumper_marker(): created = DumperMarkerDB() - created.marker = '2015-06-11T00:35:15.260439Z' + created.marker = "2015-06-11T00:35:15.260439Z" created.updated_at = date_utils.get_datetime_utc_now() return DumperMarker.add_or_update(created) diff --git a/st2common/tests/unit/test_db_model_uids.py b/st2common/tests/unit/test_db_model_uids.py index 3f5ec1ca6cf..2dd3bfb87de 100644 --- a/st2common/tests/unit/test_db_model_uids.py +++ b/st2common/tests/unit/test_db_model_uids.py @@ -30,72 +30,80 @@ from st2common.models.db.policy import PolicyDB from st2common.models.db.auth import ApiKeyDB -__all__ = [ - 'DBModelUIDFieldTestCase' -] +__all__ = ["DBModelUIDFieldTestCase"] class DBModelUIDFieldTestCase(unittest2.TestCase): def test_get_uid(self): - pack_db = PackDB(ref='ma_pack') - self.assertEqual(pack_db.get_uid(), 'pack:ma_pack') + pack_db = PackDB(ref="ma_pack") + self.assertEqual(pack_db.get_uid(), "pack:ma_pack") self.assertTrue(pack_db.has_valid_uid()) - sensor_type_db = SensorTypeDB(name='sname', pack='spack') - self.assertEqual(sensor_type_db.get_uid(), 'sensor_type:spack:sname') + sensor_type_db = SensorTypeDB(name="sname", pack="spack") + self.assertEqual(sensor_type_db.get_uid(), "sensor_type:spack:sname") self.assertTrue(sensor_type_db.has_valid_uid()) - action_db = ActionDB(name='aname', pack='apack', runner_type={}) - self.assertEqual(action_db.get_uid(), 'action:apack:aname') + action_db = ActionDB(name="aname", pack="apack", runner_type={}) + self.assertEqual(action_db.get_uid(), "action:apack:aname") self.assertTrue(action_db.has_valid_uid()) - rule_db = RuleDB(name='rname', pack='rpack') - self.assertEqual(rule_db.get_uid(), 'rule:rpack:rname') + rule_db = RuleDB(name="rname", pack="rpack") + self.assertEqual(rule_db.get_uid(), "rule:rpack:rname") self.assertTrue(rule_db.has_valid_uid()) - trigger_type_db = TriggerTypeDB(name='ttname', pack='ttpack') - self.assertEqual(trigger_type_db.get_uid(), 'trigger_type:ttpack:ttname') + trigger_type_db = TriggerTypeDB(name="ttname", pack="ttpack") + self.assertEqual(trigger_type_db.get_uid(), "trigger_type:ttpack:ttname") self.assertTrue(trigger_type_db.has_valid_uid()) - trigger_db = TriggerDB(name='tname', pack='tpack') - self.assertTrue(trigger_db.get_uid().startswith('trigger:tpack:tname:')) + trigger_db = TriggerDB(name="tname", pack="tpack") + self.assertTrue(trigger_db.get_uid().startswith("trigger:tpack:tname:")) # Verify that same set of parameters always results in the same hash - parameters = {'a': 1, 'b': 'unicode', 'c': [1, 2, 3], 'd': {'g': 1, 'h': 2}} + parameters = {"a": 1, "b": "unicode", "c": [1, 2, 3], "d": {"g": 1, "h": 2}} paramers_hash = json.dumps(parameters, sort_keys=True) paramers_hash = hashlib.md5(paramers_hash.encode()).hexdigest() - parameters = {'a': 1, 'b': 'unicode', 'c': [1, 2, 3], 'd': {'g': 1, 'h': 2}} - trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters) - self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash)) + parameters = {"a": 1, "b": "unicode", "c": [1, 2, 3], "d": {"g": 1, "h": 2}} + trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters) + self.assertEqual( + trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash) + ) self.assertTrue(trigger_db.has_valid_uid()) - parameters = {'c': [1, 2, 3], 'b': u'unicode', 'd': {'h': 2, 'g': 1}, 'a': 1} - trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters) - self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash)) + parameters = {"c": [1, 2, 3], "b": "unicode", "d": {"h": 2, "g": 1}, "a": 1} + trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters) + self.assertEqual( + trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash) + ) self.assertTrue(trigger_db.has_valid_uid()) - parameters = {'b': u'unicode', 'c': [1, 2, 3], 'd': {'h': 2, 'g': 1}, 'a': 1} - trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters) - self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash)) + parameters = {"b": "unicode", "c": [1, 2, 3], "d": {"h": 2, "g": 1}, "a": 1} + trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters) + self.assertEqual( + trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash) + ) self.assertTrue(trigger_db.has_valid_uid()) - parameters = OrderedDict({'c': [1, 2, 3], 'b': u'unicode', 'd': {'h': 2, 'g': 1}, 'a': 1}) - trigger_db = TriggerDB(name='tname', pack='tpack', parameters=parameters) - self.assertEqual(trigger_db.get_uid(), 'trigger:tpack:tname:%s' % (paramers_hash)) + parameters = OrderedDict( + {"c": [1, 2, 3], "b": "unicode", "d": {"h": 2, "g": 1}, "a": 1} + ) + trigger_db = TriggerDB(name="tname", pack="tpack", parameters=parameters) + self.assertEqual( + trigger_db.get_uid(), "trigger:tpack:tname:%s" % (paramers_hash) + ) self.assertTrue(trigger_db.has_valid_uid()) - policy_type_db = PolicyTypeDB(resource_type='action', name='concurrency') - self.assertEqual(policy_type_db.get_uid(), 'policy_type:action:concurrency') + policy_type_db = PolicyTypeDB(resource_type="action", name="concurrency") + self.assertEqual(policy_type_db.get_uid(), "policy_type:action:concurrency") self.assertTrue(policy_type_db.has_valid_uid()) - policy_db = PolicyDB(pack='dummy', name='policy1') - self.assertEqual(policy_db.get_uid(), 'policy:dummy:policy1') + policy_db = PolicyDB(pack="dummy", name="policy1") + self.assertEqual(policy_db.get_uid(), "policy:dummy:policy1") - api_key_db = ApiKeyDB(key_hash='valid') - self.assertEqual(api_key_db.get_uid(), 'api_key:valid') + api_key_db = ApiKeyDB(key_hash="valid") + self.assertEqual(api_key_db.get_uid(), "api_key:valid") self.assertTrue(api_key_db.has_valid_uid()) api_key_db = ApiKeyDB() - self.assertEqual(api_key_db.get_uid(), 'api_key:') + self.assertEqual(api_key_db.get_uid(), "api_key:") self.assertFalse(api_key_db.has_valid_uid()) diff --git a/st2common/tests/unit/test_db_pack.py b/st2common/tests/unit/test_db_pack.py index c8df8b5a28c..d5b5af00f44 100644 --- a/st2common/tests/unit/test_db_pack.py +++ b/st2common/tests/unit/test_db_pack.py @@ -26,21 +26,21 @@ class PackDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = PackDB persistance_class = Pack model_class_kwargs = { - 'name': 'Yolo CI', - 'ref': 'yolo_ci', - 'description': 'YOLO CI pack', - 'version': '0.1.0', - 'author': 'Volkswagen', - 'path': '/opt/stackstorm/packs/yolo_ci/' + "name": "Yolo CI", + "ref": "yolo_ci", + "description": "YOLO CI pack", + "version": "0.1.0", + "author": "Volkswagen", + "path": "/opt/stackstorm/packs/yolo_ci/", } - update_attribute_name = 'author' + update_attribute_name = "author" def test_path_none(self): PackDBModelCRUDTestCase.model_class_kwargs = { - 'name': 'Yolo CI', - 'ref': 'yolo_ci', - 'description': 'YOLO CI pack', - 'version': '0.1.0', - 'author': 'Volkswagen' + "name": "Yolo CI", + "ref": "yolo_ci", + "description": "YOLO CI pack", + "version": "0.1.0", + "author": "Volkswagen", } super(PackDBModelCRUDTestCase, self).test_crud_operations() diff --git a/st2common/tests/unit/test_db_policy.py b/st2common/tests/unit/test_db_policy.py index 9364c61074b..95b682e4a42 100644 --- a/st2common/tests/unit/test_db_policy.py +++ b/st2common/tests/unit/test_db_policy.py @@ -24,64 +24,113 @@ class PolicyTypeReferenceTest(unittest2.TestCase): - def test_is_reference(self): - self.assertTrue(PolicyTypeReference.is_reference('action.concurrency')) - self.assertFalse(PolicyTypeReference.is_reference('concurrency')) - self.assertFalse(PolicyTypeReference.is_reference('')) + self.assertTrue(PolicyTypeReference.is_reference("action.concurrency")) + self.assertFalse(PolicyTypeReference.is_reference("concurrency")) + self.assertFalse(PolicyTypeReference.is_reference("")) self.assertFalse(PolicyTypeReference.is_reference(None)) def test_validate_resource_type(self): - self.assertEqual(PolicyTypeReference.validate_resource_type('action'), 'action') - self.assertRaises(ValueError, PolicyTypeReference.validate_resource_type, 'action.test') + self.assertEqual(PolicyTypeReference.validate_resource_type("action"), "action") + self.assertRaises( + ValueError, PolicyTypeReference.validate_resource_type, "action.test" + ) def test_get_resource_type(self): - self.assertEqual(PolicyTypeReference.get_resource_type('action.concurrency'), 'action') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, '.abc') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, 'abc') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, '') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_resource_type, None) + self.assertEqual( + PolicyTypeReference.get_resource_type("action.concurrency"), "action" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.get_resource_type, ".abc" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.get_resource_type, "abc" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.get_resource_type, "" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.get_resource_type, None + ) def test_get_name(self): - self.assertEqual(PolicyTypeReference.get_name('action.concurrency'), 'concurrency') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, '.abc') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, 'abc') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, '') + self.assertEqual( + PolicyTypeReference.get_name("action.concurrency"), "concurrency" + ) + self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, ".abc") + self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, "abc") + self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, "") self.assertRaises(InvalidReferenceError, PolicyTypeReference.get_name, None) def test_to_string_reference(self): - ref = PolicyTypeReference.to_string_reference(resource_type='action', name='concurrency') - self.assertEqual(ref, 'action.concurrency') - - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type='action.test', name='concurrency') - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type=None, name='concurrency') - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type='', name='concurrency') - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type='action', name=None) - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type='action', name='') - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type=None, name=None) - self.assertRaises(ValueError, PolicyTypeReference.to_string_reference, - resource_type='', name='') + ref = PolicyTypeReference.to_string_reference( + resource_type="action", name="concurrency" + ) + self.assertEqual(ref, "action.concurrency") + + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type="action.test", + name="concurrency", + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type=None, + name="concurrency", + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type="", + name="concurrency", + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type="action", + name=None, + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type="action", + name="", + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type=None, + name=None, + ) + self.assertRaises( + ValueError, + PolicyTypeReference.to_string_reference, + resource_type="", + name="", + ) def test_from_string_reference(self): - ref = PolicyTypeReference.from_string_reference('action.concurrency') - self.assertEqual(ref.resource_type, 'action') - self.assertEqual(ref.name, 'concurrency') - self.assertEqual(ref.ref, 'action.concurrency') - - ref = PolicyTypeReference.from_string_reference('action.concurrency.targeted') - self.assertEqual(ref.resource_type, 'action') - self.assertEqual(ref.name, 'concurrency.targeted') - self.assertEqual(ref.ref, 'action.concurrency.targeted') - - self.assertRaises(InvalidReferenceError, PolicyTypeReference.from_string_reference, '.test') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.from_string_reference, '') - self.assertRaises(InvalidReferenceError, PolicyTypeReference.from_string_reference, None) + ref = PolicyTypeReference.from_string_reference("action.concurrency") + self.assertEqual(ref.resource_type, "action") + self.assertEqual(ref.name, "concurrency") + self.assertEqual(ref.ref, "action.concurrency") + + ref = PolicyTypeReference.from_string_reference("action.concurrency.targeted") + self.assertEqual(ref.resource_type, "action") + self.assertEqual(ref.name, "concurrency.targeted") + self.assertEqual(ref.ref, "action.concurrency.targeted") + + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.from_string_reference, ".test" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.from_string_reference, "" + ) + self.assertRaises( + InvalidReferenceError, PolicyTypeReference.from_string_reference, None + ) class PolicyTypeTest(DbModelTestCase): @@ -89,34 +138,26 @@ class PolicyTypeTest(DbModelTestCase): @staticmethod def _create_instance(): - parameters = { - 'threshold': { - 'type': 'integer', - 'required': True - } - } - - instance = PolicyTypeDB(name='concurrency', - description='TBD', - enabled=None, - ref=None, - resource_type='action', - module='st2action.policies.concurrency', - parameters=parameters) + parameters = {"threshold": {"type": "integer", "required": True}} + + instance = PolicyTypeDB( + name="concurrency", + description="TBD", + enabled=None, + ref=None, + resource_type="action", + module="st2action.policies.concurrency", + parameters=parameters, + ) return instance def test_crud(self): instance = self._create_instance() - defaults = { - 'ref': 'action.concurrency', - 'enabled': True - } + defaults = {"ref": "action.concurrency", "enabled": True} - updates = { - 'description': 'Limits the concurrent executions for the action.' - } + updates = {"description": "Limits the concurrent executions for the action."} self._assert_crud(instance, defaults=defaults, updates=updates) @@ -130,16 +171,16 @@ class PolicyTest(DbModelTestCase): @staticmethod def _create_instance(): - instance = PolicyDB(pack=None, - name='local.concurrency', - description='TBD', - enabled=None, - ref=None, - resource_ref='core.local', - policy_type='action.concurrency', - parameters={ - 'threshold': 25 - }) + instance = PolicyDB( + pack=None, + name="local.concurrency", + description="TBD", + enabled=None, + ref=None, + resource_ref="core.local", + policy_type="action.concurrency", + parameters={"threshold": 25}, + ) return instance @@ -147,13 +188,13 @@ def test_crud(self): instance = self._create_instance() defaults = { - 'pack': pack_constants.DEFAULT_PACK_NAME, - 'ref': '%s.local.concurrency' % pack_constants.DEFAULT_PACK_NAME, - 'enabled': True + "pack": pack_constants.DEFAULT_PACK_NAME, + "ref": "%s.local.concurrency" % pack_constants.DEFAULT_PACK_NAME, + "enabled": True, } updates = { - 'description': 'Limits the concurrent executions for the action "core.local".' + "description": 'Limits the concurrent executions for the action "core.local".' } self._assert_crud(instance, defaults=defaults, updates=updates) @@ -164,7 +205,7 @@ def test_ref(self): self.assertIsNotNone(ref) self.assertEqual(ref.pack, instance.pack) self.assertEqual(ref.name, instance.name) - self.assertEqual(ref.ref, instance.pack + '.' + instance.name) + self.assertEqual(ref.ref, instance.pack + "." + instance.name) self.assertEqual(ref.ref, instance.ref) def test_unique_key(self): diff --git a/st2common/tests/unit/test_db_rbac.py b/st2common/tests/unit/test_db_rbac.py index 62b97632723..d9c3fcc958a 100644 --- a/st2common/tests/unit/test_db_rbac.py +++ b/st2common/tests/unit/test_db_rbac.py @@ -28,10 +28,10 @@ __all__ = [ - 'RoleDBModelCRUDTestCase', - 'UserRoleAssignmentDBModelCRUDTestCase', - 'PermissionGrantDBModelCRUDTestCase', - 'GroupToRoleMappingDBModelCRUDTestCase' + "RoleDBModelCRUDTestCase", + "UserRoleAssignmentDBModelCRUDTestCase", + "PermissionGrantDBModelCRUDTestCase", + "GroupToRoleMappingDBModelCRUDTestCase", ] @@ -39,44 +39,44 @@ class RoleDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = RoleDB persistance_class = Role model_class_kwargs = { - 'name': 'role_one', - 'description': None, - 'system': False, - 'permission_grants': [] + "name": "role_one", + "description": None, + "system": False, + "permission_grants": [], } - update_attribute_name = 'name' + update_attribute_name = "name" class UserRoleAssignmentDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = UserRoleAssignmentDB persistance_class = UserRoleAssignment model_class_kwargs = { - 'user': 'user_one', - 'role': 'role_one', - 'source': 'source_one', - 'is_remote': True + "user": "user_one", + "role": "role_one", + "source": "source_one", + "is_remote": True, } - update_attribute_name = 'role' + update_attribute_name = "role" class PermissionGrantDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = PermissionGrantDB persistance_class = PermissionGrant model_class_kwargs = { - 'resource_uid': 'pack:core', - 'resource_type': 'pack', - 'permission_types': [] + "resource_uid": "pack:core", + "resource_type": "pack", + "permission_types": [], } - update_attribute_name = 'resource_uid' + update_attribute_name = "resource_uid" class GroupToRoleMappingDBModelCRUDTestCase(BaseDBModelCRUDTestCase, DbTestCase): model_class = GroupToRoleMappingDB persistance_class = GroupToRoleMapping model_class_kwargs = { - 'group': 'some group', - 'roles': ['role_one', 'role_two'], - 'description': 'desc', - 'enabled': True + "group": "some group", + "roles": ["role_one", "role_two"], + "description": "desc", + "enabled": True, } - update_attribute_name = 'group' + update_attribute_name = "group" diff --git a/st2common/tests/unit/test_db_rule_enforcement.py b/st2common/tests/unit/test_db_rule_enforcement.py index 734a34ffc36..5cececffa0b 100644 --- a/st2common/tests/unit/test_db_rule_enforcement.py +++ b/st2common/tests/unit/test_db_rule_enforcement.py @@ -28,19 +28,19 @@ SKIP_DELETE = False -__all__ = [ - 'RuleEnforcementModelTest' -] +__all__ = ["RuleEnforcementModelTest"] -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class RuleEnforcementModelTest(DbTestCase): - def test_ruleenforcment_crud(self): saved = RuleEnforcementModelTest._create_save_rule_enforcement() retrieved = RuleEnforcement.get_by_id(saved.id) - self.assertEqual(saved.rule.ref, retrieved.rule.ref, - 'Same rule enforcement was not returned.') + self.assertEqual( + saved.rule.ref, + retrieved.rule.ref, + "Same rule enforcement was not returned.", + ) self.assertIsNotNone(retrieved.enforced_at) # test update RULE_ID = str(bson.ObjectId()) @@ -48,73 +48,82 @@ def test_ruleenforcment_crud(self): retrieved.rule.id = RULE_ID saved = RuleEnforcement.add_or_update(retrieved) retrieved = RuleEnforcement.get_by_id(saved.id) - self.assertEqual(retrieved.rule.id, RULE_ID, - 'Update to rule enforcement failed.') + self.assertEqual( + retrieved.rule.id, RULE_ID, "Update to rule enforcement failed." + ) # cleanup RuleEnforcementModelTest._delete([retrieved]) try: retrieved = RuleEnforcement.get_by_id(saved.id) except StackStormDBObjectNotFoundError: retrieved = None - self.assertIsNone(retrieved, 'managed to retrieve after delete.') + self.assertIsNone(retrieved, "managed to retrieve after delete.") def test_status_set_to_failed_for_objects_which_predate_status_field(self): - rule = { - 'ref': 'foo_pack.foo_rule', - 'uid': 'rule:foo_pack:foo_rule' - } + rule = {"ref": "foo_pack.foo_rule", "uid": "rule:foo_pack:foo_rule"} # 1. No status field explicitly set and no failure reason - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule=rule, - execution_id=str(bson.ObjectId())) + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule=rule, + execution_id=str(bson.ObjectId()), + ) enforcement_db = RuleEnforcement.add_or_update(enforcement_db) self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_SUCCEEDED) # 2. No status field, with failure reason, status should be set to failed - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule=rule, - execution_id=str(bson.ObjectId()), - failure_reason='so much fail') + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule=rule, + execution_id=str(bson.ObjectId()), + failure_reason="so much fail", + ) enforcement_db = RuleEnforcement.add_or_update(enforcement_db) self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_FAILED) # 3. Explcit status field - succeeded + failure reasun - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule=rule, - execution_id=str(bson.ObjectId()), - status=RULE_ENFORCEMENT_STATUS_SUCCEEDED, - failure_reason='so much fail') + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule=rule, + execution_id=str(bson.ObjectId()), + status=RULE_ENFORCEMENT_STATUS_SUCCEEDED, + failure_reason="so much fail", + ) enforcement_db = RuleEnforcement.add_or_update(enforcement_db) self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_FAILED) # 4. Explcit status field - succeeded + no failure reasun - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule=rule, - execution_id=str(bson.ObjectId()), - status=RULE_ENFORCEMENT_STATUS_SUCCEEDED) + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule=rule, + execution_id=str(bson.ObjectId()), + status=RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) enforcement_db = RuleEnforcement.add_or_update(enforcement_db) self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_SUCCEEDED) # 5. Explcit status field - failed + no failure reasun - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule=rule, - execution_id=str(bson.ObjectId()), - status=RULE_ENFORCEMENT_STATUS_FAILED) + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule=rule, + execution_id=str(bson.ObjectId()), + status=RULE_ENFORCEMENT_STATUS_FAILED, + ) enforcement_db = RuleEnforcement.add_or_update(enforcement_db) self.assertEqual(enforcement_db.status, RULE_ENFORCEMENT_STATUS_FAILED) @staticmethod def _create_save_rule_enforcement(): - created = RuleEnforcementDB(trigger_instance_id=str(bson.ObjectId()), - rule={'ref': 'foo_pack.foo_rule', - 'uid': 'rule:foo_pack:foo_rule'}, - execution_id=str(bson.ObjectId())) + created = RuleEnforcementDB( + trigger_instance_id=str(bson.ObjectId()), + rule={"ref": "foo_pack.foo_rule", "uid": "rule:foo_pack:foo_rule"}, + execution_id=str(bson.ObjectId()), + ) return RuleEnforcement.add_or_update(created) @staticmethod diff --git a/st2common/tests/unit/test_db_task.py b/st2common/tests/unit/test_db_task.py index 60285f13666..bc0d3e23824 100644 --- a/st2common/tests/unit/test_db_task.py +++ b/st2common/tests/unit/test_db_task.py @@ -27,19 +27,18 @@ from st2common.util import date as date_utils -@mock.patch.object(publishers.PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(publishers.PoolPublisher, "publish", mock.MagicMock()) class TaskExecutionModelTest(st2tests.DbTestCase): - def test_task_execution_crud(self): initial = wf_db_models.TaskExecutionDB() initial.workflow_execution = uuid.uuid4().hex - initial.task_name = 't1' - initial.task_id = 't1' + initial.task_name = "t1" + initial.task_id = "t1" initial.task_route = 0 - initial.task_spec = {'tasks': {'t1': 'some task'}} + initial.task_spec = {"tasks": {"t1": "some task"}} initial.delay = 180 - initial.status = 'requested' - initial.context = {'var1': 'foobar'} + initial.status = "requested" + initial.context = {"var1": "foobar"} # Test create created = wf_db_access.TaskExecution.add_or_update(initial) @@ -61,7 +60,7 @@ def test_task_execution_crud(self): self.assertDictEqual(created.context, retrieved.context) # Test update - status = 'running' + status = "running" retrieved = wf_db_access.TaskExecution.update(retrieved, status=status) updated = wf_db_access.TaskExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) @@ -79,8 +78,8 @@ def test_task_execution_crud(self): self.assertDictEqual(updated.context, retrieved.context) # Test add or update - retrieved.result = {'output': 'fubar'} - retrieved.status = 'succeeded' + retrieved.result = {"output": "fubar"} + retrieved.status = "succeeded" retrieved.end_timestamp = date_utils.get_datetime_utc_now() retrieved = wf_db_access.TaskExecution.add_or_update(retrieved) updated = wf_db_access.TaskExecution.get_by_id(doc_id) @@ -105,20 +104,20 @@ def test_task_execution_crud(self): self.assertRaises( db_exc.StackStormDBObjectNotFoundError, wf_db_access.TaskExecution.get_by_id, - doc_id + doc_id, ) def test_task_execution_crud_set_itemized_true(self): initial = wf_db_models.TaskExecutionDB() initial.workflow_execution = uuid.uuid4().hex - initial.task_name = 't1' - initial.task_id = 't1' + initial.task_name = "t1" + initial.task_id = "t1" initial.task_route = 0 - initial.task_spec = {'tasks': {'t1': 'some task'}} + initial.task_spec = {"tasks": {"t1": "some task"}} initial.delay = 180 initial.itemized = True - initial.status = 'requested' - initial.context = {'var1': 'foobar'} + initial.status = "requested" + initial.context = {"var1": "foobar"} # Test create created = wf_db_access.TaskExecution.add_or_update(initial) @@ -140,7 +139,7 @@ def test_task_execution_crud_set_itemized_true(self): self.assertDictEqual(created.context, retrieved.context) # Test update - status = 'running' + status = "running" retrieved = wf_db_access.TaskExecution.update(retrieved, status=status) updated = wf_db_access.TaskExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) @@ -158,8 +157,8 @@ def test_task_execution_crud_set_itemized_true(self): self.assertDictEqual(updated.context, retrieved.context) # Test add or update - retrieved.result = {'output': 'fubar'} - retrieved.status = 'succeeded' + retrieved.result = {"output": "fubar"} + retrieved.status = "succeeded" retrieved.end_timestamp = date_utils.get_datetime_utc_now() retrieved = wf_db_access.TaskExecution.add_or_update(retrieved) updated = wf_db_access.TaskExecution.get_by_id(doc_id) @@ -184,19 +183,19 @@ def test_task_execution_crud_set_itemized_true(self): self.assertRaises( db_exc.StackStormDBObjectNotFoundError, wf_db_access.TaskExecution.get_by_id, - doc_id + doc_id, ) def test_task_execution_write_conflict(self): initial = wf_db_models.TaskExecutionDB() initial.workflow_execution = uuid.uuid4().hex - initial.task_name = 't1' - initial.task_id = 't1' + initial.task_name = "t1" + initial.task_id = "t1" initial.task_route = 0 - initial.task_spec = {'tasks': {'t1': 'some task'}} + initial.task_spec = {"tasks": {"t1": "some task"}} initial.delay = 180 - initial.status = 'requested' - initial.context = {'var1': 'foobar'} + initial.status = "requested" + initial.context = {"var1": "foobar"} # Prep record created = wf_db_access.TaskExecution.add_or_update(initial) @@ -208,7 +207,7 @@ def test_task_execution_write_conflict(self): retrieved2 = wf_db_access.TaskExecution.get_by_id(doc_id) # Test update on instance 1, expect success - status = 'running' + status = "running" retrieved1 = wf_db_access.TaskExecution.update(retrieved1, status=status) updated = wf_db_access.TaskExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) @@ -230,7 +229,7 @@ def test_task_execution_write_conflict(self): db_exc.StackStormDBObjectWriteConflictError, wf_db_access.TaskExecution.update, retrieved2, - status='pausing' + status="pausing", ) # Test delete @@ -239,5 +238,5 @@ def test_task_execution_write_conflict(self): self.assertRaises( db_exc.StackStormDBObjectNotFoundError, wf_db_access.TaskExecution.get_by_id, - doc_id + doc_id, ) diff --git a/st2common/tests/unit/test_db_trace.py b/st2common/tests/unit/test_db_trace.py index b9e2ec9c8a4..1e0f884472c 100644 --- a/st2common/tests/unit/test_db_trace.py +++ b/st2common/tests/unit/test_db_trace.py @@ -24,85 +24,103 @@ class TraceDBTest(CleanDbTestCase): - def test_get(self): saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', + trace_tag="test_trace", action_executions=[str(bson.ObjectId()) for _ in range(4)], rules=[str(bson.ObjectId()) for _ in range(4)], - trigger_instances=[str(bson.ObjectId()) for _ in range(5)]) + trigger_instances=[str(bson.ObjectId()) for _ in range(5)], + ) retrieved = Trace.get(id=saved.id) - self.assertEqual(retrieved.id, saved.id, 'Incorrect trace retrieved.') + self.assertEqual(retrieved.id, saved.id, "Incorrect trace retrieved.") def test_query(self): saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', + trace_tag="test_trace", action_executions=[str(bson.ObjectId()) for _ in range(4)], rules=[str(bson.ObjectId()) for _ in range(4)], - trigger_instances=[str(bson.ObjectId()) for _ in range(5)]) + trigger_instances=[str(bson.ObjectId()) for _ in range(5)], + ) retrieved = Trace.query(trace_tag=saved.trace_tag) - self.assertEqual(len(retrieved), 1, 'Should have 1 trace.') - self.assertEqual(retrieved[0].id, saved.id, 'Incorrect trace retrieved.') + self.assertEqual(len(retrieved), 1, "Should have 1 trace.") + self.assertEqual(retrieved[0].id, saved.id, "Incorrect trace retrieved.") # Add another trace with same trace_tag and confirm that we support. # This is most likley an anti-pattern for the trace_tag but it is an unknown. saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', + trace_tag="test_trace", action_executions=[str(bson.ObjectId()) for _ in range(2)], rules=[str(bson.ObjectId()) for _ in range(4)], - trigger_instances=[str(bson.ObjectId()) for _ in range(3)]) + trigger_instances=[str(bson.ObjectId()) for _ in range(3)], + ) retrieved = Trace.query(trace_tag=saved.trace_tag) - self.assertEqual(len(retrieved), 2, 'Should have 2 traces.') + self.assertEqual(len(retrieved), 2, "Should have 2 traces.") def test_update(self): saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', - action_executions=[], - rules=[], - trigger_instances=[]) + trace_tag="test_trace", action_executions=[], rules=[], trigger_instances=[] + ) retrieved = Trace.query(trace_tag=saved.trace_tag) - self.assertEqual(len(retrieved), 1, 'Should have 1 trace.') - self.assertEqual(retrieved[0].id, saved.id, 'Incorrect trace retrieved.') + self.assertEqual(len(retrieved), 1, "Should have 1 trace.") + self.assertEqual(retrieved[0].id, saved.id, "Incorrect trace retrieved.") no_action_executions = 4 no_rules = 4 no_trigger_instances = 5 saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', + trace_tag="test_trace", id_=retrieved[0].id, - action_executions=[str(bson.ObjectId()) for _ in range(no_action_executions)], + action_executions=[ + str(bson.ObjectId()) for _ in range(no_action_executions) + ], rules=[str(bson.ObjectId()) for _ in range(no_rules)], - trigger_instances=[str(bson.ObjectId()) for _ in range(no_trigger_instances)]) + trigger_instances=[ + str(bson.ObjectId()) for _ in range(no_trigger_instances) + ], + ) retrieved = Trace.query(trace_tag=saved.trace_tag) - self.assertEqual(len(retrieved), 1, 'Should have 1 trace.') - self.assertEqual(retrieved[0].id, saved.id, 'Incorrect trace retrieved.') + self.assertEqual(len(retrieved), 1, "Should have 1 trace.") + self.assertEqual(retrieved[0].id, saved.id, "Incorrect trace retrieved.") # validate update - self.assertEqual(len(retrieved[0].action_executions), no_action_executions, - 'Failed to update action_executions.') - self.assertEqual(len(retrieved[0].rules), no_rules, 'Failed to update rules.') - self.assertEqual(len(retrieved[0].trigger_instances), no_trigger_instances, - 'Failed to update trigger_instances.') + self.assertEqual( + len(retrieved[0].action_executions), + no_action_executions, + "Failed to update action_executions.", + ) + self.assertEqual(len(retrieved[0].rules), no_rules, "Failed to update rules.") + self.assertEqual( + len(retrieved[0].trigger_instances), + no_trigger_instances, + "Failed to update trigger_instances.", + ) def test_update_via_list_push(self): no_action_executions = 4 no_rules = 4 no_trigger_instances = 5 saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', - action_executions=[str(bson.ObjectId()) for _ in range(no_action_executions)], + trace_tag="test_trace", + action_executions=[ + str(bson.ObjectId()) for _ in range(no_action_executions) + ], rules=[str(bson.ObjectId()) for _ in range(no_rules)], - trigger_instances=[str(bson.ObjectId()) for _ in range(no_trigger_instances)]) + trigger_instances=[ + str(bson.ObjectId()) for _ in range(no_trigger_instances) + ], + ) # push updates Trace.push_action_execution( - saved, action_execution=TraceComponentDB(object_id=str(bson.ObjectId()))) + saved, action_execution=TraceComponentDB(object_id=str(bson.ObjectId())) + ) Trace.push_rule(saved, rule=TraceComponentDB(object_id=str(bson.ObjectId()))) Trace.push_trigger_instance( - saved, trigger_instance=TraceComponentDB(object_id=str(bson.ObjectId()))) + saved, trigger_instance=TraceComponentDB(object_id=str(bson.ObjectId())) + ) retrieved = Trace.get(id=saved.id) - self.assertEqual(retrieved.id, saved.id, 'Incorrect trace retrieved.') + self.assertEqual(retrieved.id, saved.id, "Incorrect trace retrieved.") self.assertEqual(len(retrieved.action_executions), no_action_executions + 1) self.assertEqual(len(retrieved.rules), no_rules + 1) self.assertEqual(len(retrieved.trigger_instances), no_trigger_instances + 1) @@ -112,33 +130,48 @@ def test_update_via_list_push_components(self): no_rules = 4 no_trigger_instances = 5 saved = TraceDBTest._create_save_trace( - trace_tag='test_trace', - action_executions=[str(bson.ObjectId()) for _ in range(no_action_executions)], + trace_tag="test_trace", + action_executions=[ + str(bson.ObjectId()) for _ in range(no_action_executions) + ], rules=[str(bson.ObjectId()) for _ in range(no_rules)], - trigger_instances=[str(bson.ObjectId()) for _ in range(no_trigger_instances)]) + trigger_instances=[ + str(bson.ObjectId()) for _ in range(no_trigger_instances) + ], + ) retrieved = Trace.push_components( saved, - action_executions=[TraceComponentDB(object_id=str(bson.ObjectId())) - for _ in range(no_action_executions)], - rules=[TraceComponentDB(object_id=str(bson.ObjectId())) - for _ in range(no_rules)], - trigger_instances=[TraceComponentDB(object_id=str(bson.ObjectId())) - for _ in range(no_trigger_instances)]) - - self.assertEqual(retrieved.id, saved.id, 'Incorrect trace retrieved.') + action_executions=[ + TraceComponentDB(object_id=str(bson.ObjectId())) + for _ in range(no_action_executions) + ], + rules=[ + TraceComponentDB(object_id=str(bson.ObjectId())) + for _ in range(no_rules) + ], + trigger_instances=[ + TraceComponentDB(object_id=str(bson.ObjectId())) + for _ in range(no_trigger_instances) + ], + ) + + self.assertEqual(retrieved.id, saved.id, "Incorrect trace retrieved.") self.assertEqual(len(retrieved.action_executions), no_action_executions * 2) self.assertEqual(len(retrieved.rules), no_rules * 2) self.assertEqual(len(retrieved.trigger_instances), no_trigger_instances * 2) @staticmethod - def _create_save_trace(trace_tag, id_=None, action_executions=None, rules=None, - trigger_instances=None): + def _create_save_trace( + trace_tag, id_=None, action_executions=None, rules=None, trigger_instances=None + ): if action_executions is None: action_executions = [] - action_executions = [TraceComponentDB(object_id=action_execution) - for action_execution in action_executions] + action_executions = [ + TraceComponentDB(object_id=action_execution) + for action_execution in action_executions + ] if rules is None: rules = [] @@ -146,12 +179,16 @@ def _create_save_trace(trace_tag, id_=None, action_executions=None, rules=None, if trigger_instances is None: trigger_instances = [] - trigger_instances = [TraceComponentDB(object_id=trigger_instance) - for trigger_instance in trigger_instances] - - created = TraceDB(id=id_, - trace_tag=trace_tag, - trigger_instances=trigger_instances, - rules=rules, - action_executions=action_executions) + trigger_instances = [ + TraceComponentDB(object_id=trigger_instance) + for trigger_instance in trigger_instances + ] + + created = TraceDB( + id=id_, + trace_tag=trace_tag, + trigger_instances=trigger_instances, + rules=rules, + action_executions=action_executions, + ) return Trace.add_or_update(created) diff --git a/st2common/tests/unit/test_db_uid_mixin.py b/st2common/tests/unit/test_db_uid_mixin.py index e3283e6f91a..b7a6a251089 100644 --- a/st2common/tests/unit/test_db_uid_mixin.py +++ b/st2common/tests/unit/test_db_uid_mixin.py @@ -23,28 +23,41 @@ class UIDMixinTestCase(CleanDbTestCase): def test_get_uid(self): - pack_1_db = PackDB(ref='test_pack') - pack_2_db = PackDB(ref='examples') + pack_1_db = PackDB(ref="test_pack") + pack_2_db = PackDB(ref="examples") - self.assertEqual(pack_1_db.get_uid(), 'pack:test_pack') - self.assertEqual(pack_2_db.get_uid(), 'pack:examples') + self.assertEqual(pack_1_db.get_uid(), "pack:test_pack") + self.assertEqual(pack_2_db.get_uid(), "pack:examples") - action_1_db = ActionDB(pack='examples', name='my_action', ref='examples.my_action') - action_2_db = ActionDB(pack='core', name='local', ref='core.local') - self.assertEqual(action_1_db.get_uid(), 'action:examples:my_action') - self.assertEqual(action_2_db.get_uid(), 'action:core:local') + action_1_db = ActionDB( + pack="examples", name="my_action", ref="examples.my_action" + ) + action_2_db = ActionDB(pack="core", name="local", ref="core.local") + self.assertEqual(action_1_db.get_uid(), "action:examples:my_action") + self.assertEqual(action_2_db.get_uid(), "action:core:local") def test_uid_is_populated_on_save(self): - pack_1_db = PackDB(ref='test_pack', name='test', description='foo', version='1.0.0', - author='dev', email='test@example.com') + pack_1_db = PackDB( + ref="test_pack", + name="test", + description="foo", + version="1.0.0", + author="dev", + email="test@example.com", + ) pack_1_db = Pack.add_or_update(pack_1_db) pack_1_db.reload() - self.assertEqual(pack_1_db.uid, 'pack:test_pack') + self.assertEqual(pack_1_db.uid, "pack:test_pack") - action_1_db = ActionDB(name='local', pack='core', ref='core.local', entry_point='', - runner_type={'name': 'local-shell-cmd'}) + action_1_db = ActionDB( + name="local", + pack="core", + ref="core.local", + entry_point="", + runner_type={"name": "local-shell-cmd"}, + ) action_1_db = Action.add_or_update(action_1_db) action_1_db.reload() - self.assertEqual(action_1_db.uid, 'action:core:local') + self.assertEqual(action_1_db.uid, "action:core:local") diff --git a/st2common/tests/unit/test_db_workflow.py b/st2common/tests/unit/test_db_workflow.py index 1f7ce38a4ac..e434d0f9d60 100644 --- a/st2common/tests/unit/test_db_workflow.py +++ b/st2common/tests/unit/test_db_workflow.py @@ -26,14 +26,13 @@ from st2common.exceptions import db as db_exc -@mock.patch.object(publishers.PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(publishers.PoolPublisher, "publish", mock.MagicMock()) class WorkflowExecutionModelTest(st2tests.DbTestCase): - def test_workflow_execution_crud(self): initial = wf_db_models.WorkflowExecutionDB() initial.action_execution = uuid.uuid4().hex - initial.graph = {'var1': 'foobar'} - initial.status = 'requested' + initial.graph = {"var1": "foobar"} + initial.status = "requested" # Test create created = wf_db_access.WorkflowExecution.add_or_update(initial) @@ -47,9 +46,11 @@ def test_workflow_execution_crud(self): self.assertEqual(created.status, retrieved.status) # Test update - graph = {'var1': 'fubar'} - status = 'running' - retrieved = wf_db_access.WorkflowExecution.update(retrieved, graph=graph, status=status) + graph = {"var1": "fubar"} + status = "running" + retrieved = wf_db_access.WorkflowExecution.update( + retrieved, graph=graph, status=status + ) updated = wf_db_access.WorkflowExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) self.assertEqual(retrieved.rev, updated.rev) @@ -58,7 +59,7 @@ def test_workflow_execution_crud(self): self.assertEqual(retrieved.status, updated.status) # Test add or update - retrieved.graph = {'var2': 'fubar'} + retrieved.graph = {"var2": "fubar"} retrieved = wf_db_access.WorkflowExecution.add_or_update(retrieved) updated = wf_db_access.WorkflowExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) @@ -73,14 +74,14 @@ def test_workflow_execution_crud(self): self.assertRaises( db_exc.StackStormDBObjectNotFoundError, wf_db_access.WorkflowExecution.get_by_id, - doc_id + doc_id, ) def test_workflow_execution_write_conflict(self): initial = wf_db_models.WorkflowExecutionDB() initial.action_execution = uuid.uuid4().hex - initial.graph = {'var1': 'foobar'} - initial.status = 'requested' + initial.graph = {"var1": "foobar"} + initial.status = "requested" # Prep record created = wf_db_access.WorkflowExecution.add_or_update(initial) @@ -92,9 +93,11 @@ def test_workflow_execution_write_conflict(self): retrieved2 = wf_db_access.WorkflowExecution.get_by_id(doc_id) # Test update on instance 1, expect success - graph = {'var1': 'fubar'} - status = 'running' - retrieved1 = wf_db_access.WorkflowExecution.update(retrieved1, graph=graph, status=status) + graph = {"var1": "fubar"} + status = "running" + retrieved1 = wf_db_access.WorkflowExecution.update( + retrieved1, graph=graph, status=status + ) updated = wf_db_access.WorkflowExecution.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) self.assertEqual(retrieved1.rev, updated.rev) @@ -107,7 +110,7 @@ def test_workflow_execution_write_conflict(self): db_exc.StackStormDBObjectWriteConflictError, wf_db_access.WorkflowExecution.update, retrieved2, - graph={'var2': 'fubar'} + graph={"var2": "fubar"}, ) # Test delete @@ -116,5 +119,5 @@ def test_workflow_execution_write_conflict(self): self.assertRaises( db_exc.StackStormDBObjectNotFoundError, wf_db_access.WorkflowExecution.get_by_id, - doc_id + doc_id, ) diff --git a/st2common/tests/unit/test_dist_utils.py b/st2common/tests/unit/test_dist_utils.py index 901f8abd44e..1b01d4ff48b 100644 --- a/st2common/tests/unit/test_dist_utils.py +++ b/st2common/tests/unit/test_dist_utils.py @@ -21,7 +21,7 @@ import unittest2 BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -SCRIPTS_PATH = os.path.join(BASE_DIR, '../../../scripts/') +SCRIPTS_PATH = os.path.join(BASE_DIR, "../../../scripts/") # Add scripts/ which contain main dist_utils.py to PYTHONPATH sys.path.insert(0, SCRIPTS_PATH) @@ -32,21 +32,21 @@ from dist_utils import apply_vagrant_workaround from dist_utils import get_version_string -__all__ = [ - 'DistUtilsTestCase' -] +__all__ = ["DistUtilsTestCase"] -REQUIREMENTS_PATH_1 = os.path.join(BASE_DIR, '../fixtures/requirements-used-for-tests.txt') -REQUIREMENTS_PATH_2 = os.path.join(BASE_DIR, '../../../requirements.txt') -VERSION_FILE_PATH = os.path.join(BASE_DIR, '../fixtures/version_file.py') +REQUIREMENTS_PATH_1 = os.path.join( + BASE_DIR, "../fixtures/requirements-used-for-tests.txt" +) +REQUIREMENTS_PATH_2 = os.path.join(BASE_DIR, "../../../requirements.txt") +VERSION_FILE_PATH = os.path.join(BASE_DIR, "../fixtures/version_file.py") class DistUtilsTestCase(unittest2.TestCase): def setUp(self): super(DistUtilsTestCase, self).setUp() - if 'pip' in sys.modules: - del sys.modules['pip'] + if "pip" in sys.modules: + del sys.modules["pip"] def tearDown(self): super(DistUtilsTestCase, self).tearDown() @@ -54,15 +54,15 @@ def tearDown(self): def test_check_pip_is_installed_success(self): self.assertTrue(check_pip_is_installed()) - @mock.patch('sys.exit') + @mock.patch("sys.exit") def test_check_pip_is_installed_failure(self, mock_sys_exit): if six.PY3: - module_name = 'builtins.__import__' + module_name = "builtins.__import__" else: - module_name = '__builtin__.__import__' + module_name = "__builtin__.__import__" with mock.patch(module_name) as mock_import: - mock_import.side_effect = ImportError('not found') + mock_import.side_effect = ImportError("not found") self.assertEqual(mock_sys_exit.call_count, 0) check_pip_is_installed() @@ -72,12 +72,12 @@ def test_check_pip_is_installed_failure(self, mock_sys_exit): def test_check_pip_version_success(self): self.assertTrue(check_pip_version()) - @mock.patch('sys.exit') + @mock.patch("sys.exit") def test_check_pip_version_failure(self, mock_sys_exit): mock_pip = mock.Mock() - mock_pip.__version__ = '0.0.0' - sys.modules['pip'] = mock_pip + mock_pip.__version__ = "0.0.0" + sys.modules["pip"] = mock_pip self.assertEqual(mock_sys_exit.call_count, 0) check_pip_version() @@ -86,50 +86,50 @@ def test_check_pip_version_failure(self, mock_sys_exit): def test_get_version_string(self): version = get_version_string(VERSION_FILE_PATH) - self.assertEqual(version, '1.2.3') + self.assertEqual(version, "1.2.3") def test_apply_vagrant_workaround(self): - with mock.patch('os.link') as _: - os.environ['USER'] = 'stanley' + with mock.patch("os.link") as _: + os.environ["USER"] = "stanley" apply_vagrant_workaround() self.assertTrue(os.link) - with mock.patch('os.link') as _: - os.environ['USER'] = 'vagrant' + with mock.patch("os.link") as _: + os.environ["USER"] = "vagrant" apply_vagrant_workaround() - self.assertFalse(getattr(os, 'link', None)) + self.assertFalse(getattr(os, "link", None)) def test_fetch_requirements(self): expected_reqs = [ - 'RandomWords', - 'amqp==2.5.1', - 'argcomplete', - 'bcrypt==3.1.6', - 'flex==6.14.0', - 'logshipper', - 'orquesta', - 'st2-auth-backend-flat-file', - 'logshipper-editable', - 'python_runner', - 'SomePackageHq', - 'SomePackageSvn', - 'gitpython==2.1.11', - 'ose-timer==0.7.5', - 'oslo.config<1.13,>=1.12.1', - 'requests[security]<2.22.0,>=2.21.0', - 'retrying==1.3.3', - 'zake==0.2.2' + "RandomWords", + "amqp==2.5.1", + "argcomplete", + "bcrypt==3.1.6", + "flex==6.14.0", + "logshipper", + "orquesta", + "st2-auth-backend-flat-file", + "logshipper-editable", + "python_runner", + "SomePackageHq", + "SomePackageSvn", + "gitpython==2.1.11", + "ose-timer==0.7.5", + "oslo.config<1.13,>=1.12.1", + "requests[security]<2.22.0,>=2.21.0", + "retrying==1.3.3", + "zake==0.2.2", ] expected_links = [ - 'git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper', - 'git+https://github.com/StackStorm/orquesta.git@224c1a589a6007eb0598a62ee99d674e7836d369#egg=orquesta', # NOQA - 'git+https://github.com/StackStorm/st2-auth-backend-flat-file.git@master#egg=st2-auth-backend-flat-file', # NOQA - 'git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper-editable', - 'git+https://github.com/StackStorm/st2.git#egg=python_runner&subdirectory=contrib/runners/python_runner', # NOQA - 'hg+https://hg.repo/some_pkg.git#egg=SomePackageHq', - 'svn+svn://svn.repo/some_pkg/trunk/@ma-branch#egg=SomePackageSvn' + "git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper", + "git+https://github.com/StackStorm/orquesta.git@224c1a589a6007eb0598a62ee99d674e7836d369#egg=orquesta", # NOQA + "git+https://github.com/StackStorm/st2-auth-backend-flat-file.git@master#egg=st2-auth-backend-flat-file", # NOQA + "git+https://github.com/Kami/logshipper.git@stackstorm_patched#egg=logshipper-editable", + "git+https://github.com/StackStorm/st2.git#egg=python_runner&subdirectory=contrib/runners/python_runner", # NOQA + "hg+https://hg.repo/some_pkg.git#egg=SomePackageHq", + "svn+svn://svn.repo/some_pkg/trunk/@ma-branch#egg=SomePackageSvn", ] reqs, links = fetch_requirements(REQUIREMENTS_PATH_1) diff --git a/st2common/tests/unit/test_exceptions_workflow.py b/st2common/tests/unit/test_exceptions_workflow.py index 9e37f6c5d9d..a9fbcc549f0 100644 --- a/st2common/tests/unit/test_exceptions_workflow.py +++ b/st2common/tests/unit/test_exceptions_workflow.py @@ -26,7 +26,6 @@ class WorkflowExceptionTest(unittest2.TestCase): - def test_retry_on_transient_db_errors(self): instance = wf_db_models.WorkflowExecutionDB() exc = db_exc.StackStormDBObjectWriteConflictError(instance) @@ -34,13 +33,13 @@ def test_retry_on_transient_db_errors(self): def test_do_not_retry_on_transient_db_errors(self): instance = wf_db_models.WorkflowExecutionDB() - exc = db_exc.StackStormDBObjectConflictError('foobar', '1234', instance) + exc = db_exc.StackStormDBObjectConflictError("foobar", "1234", instance) self.assertFalse(wf_exc.retry_on_transient_db_errors(exc)) self.assertFalse(wf_exc.retry_on_transient_db_errors(NotImplementedError())) self.assertFalse(wf_exc.retry_on_transient_db_errors(Exception())) def test_retry_on_connection_errors(self): - exc = coordination.ToozConnectionError('foobar') + exc = coordination.ToozConnectionError("foobar") self.assertTrue(wf_exc.retry_on_connection_errors(exc)) exc = mongoengine.connection.MongoEngineConnectionError() diff --git a/st2common/tests/unit/test_executions.py b/st2common/tests/unit/test_executions.py index 59353379ac8..0be1ca7c9da 100644 --- a/st2common/tests/unit/test_executions.py +++ b/st2common/tests/unit/test_executions.py @@ -29,94 +29,117 @@ class TestActionExecutionHistoryModel(DbTestCase): - def setUp(self): super(TestActionExecutionHistoryModel, self).setUp() # Fake execution record for action liveactions triggered by workflow runner. self.fake_history_subtasks = [ { - 'id': str(bson.ObjectId()), - 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['local']), - 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['run-local']), - 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['task1']), - 'status': fixture.ARTIFACTS['liveactions']['task1']['status'], - 'start_timestamp': fixture.ARTIFACTS['liveactions']['task1']['start_timestamp'], - 'end_timestamp': fixture.ARTIFACTS['liveactions']['task1']['end_timestamp'] + "id": str(bson.ObjectId()), + "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["local"]), + "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["run-local"]), + "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["task1"]), + "status": fixture.ARTIFACTS["liveactions"]["task1"]["status"], + "start_timestamp": fixture.ARTIFACTS["liveactions"]["task1"][ + "start_timestamp" + ], + "end_timestamp": fixture.ARTIFACTS["liveactions"]["task1"][ + "end_timestamp" + ], }, { - 'id': str(bson.ObjectId()), - 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['local']), - 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['run-local']), - 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['task2']), - 'status': fixture.ARTIFACTS['liveactions']['task2']['status'], - 'start_timestamp': fixture.ARTIFACTS['liveactions']['task2']['start_timestamp'], - 'end_timestamp': fixture.ARTIFACTS['liveactions']['task2']['end_timestamp'] - } + "id": str(bson.ObjectId()), + "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["local"]), + "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["run-local"]), + "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["task2"]), + "status": fixture.ARTIFACTS["liveactions"]["task2"]["status"], + "start_timestamp": fixture.ARTIFACTS["liveactions"]["task2"][ + "start_timestamp" + ], + "end_timestamp": fixture.ARTIFACTS["liveactions"]["task2"][ + "end_timestamp" + ], + }, ] # Fake execution record for a workflow action execution triggered by rule. self.fake_history_workflow = { - 'id': str(bson.ObjectId()), - 'trigger': copy.deepcopy(fixture.ARTIFACTS['trigger']), - 'trigger_type': copy.deepcopy(fixture.ARTIFACTS['trigger_type']), - 'trigger_instance': copy.deepcopy(fixture.ARTIFACTS['trigger_instance']), - 'rule': copy.deepcopy(fixture.ARTIFACTS['rule']), - 'action': copy.deepcopy(fixture.ARTIFACTS['actions']['chain']), - 'runner': copy.deepcopy(fixture.ARTIFACTS['runners']['action-chain']), - 'liveaction': copy.deepcopy(fixture.ARTIFACTS['liveactions']['workflow']), - 'children': [task['id'] for task in self.fake_history_subtasks], - 'status': fixture.ARTIFACTS['liveactions']['workflow']['status'], - 'start_timestamp': fixture.ARTIFACTS['liveactions']['workflow']['start_timestamp'], - 'end_timestamp': fixture.ARTIFACTS['liveactions']['workflow']['end_timestamp'] + "id": str(bson.ObjectId()), + "trigger": copy.deepcopy(fixture.ARTIFACTS["trigger"]), + "trigger_type": copy.deepcopy(fixture.ARTIFACTS["trigger_type"]), + "trigger_instance": copy.deepcopy(fixture.ARTIFACTS["trigger_instance"]), + "rule": copy.deepcopy(fixture.ARTIFACTS["rule"]), + "action": copy.deepcopy(fixture.ARTIFACTS["actions"]["chain"]), + "runner": copy.deepcopy(fixture.ARTIFACTS["runners"]["action-chain"]), + "liveaction": copy.deepcopy(fixture.ARTIFACTS["liveactions"]["workflow"]), + "children": [task["id"] for task in self.fake_history_subtasks], + "status": fixture.ARTIFACTS["liveactions"]["workflow"]["status"], + "start_timestamp": fixture.ARTIFACTS["liveactions"]["workflow"][ + "start_timestamp" + ], + "end_timestamp": fixture.ARTIFACTS["liveactions"]["workflow"][ + "end_timestamp" + ], } # Assign parent to the execution records for the subtasks. for task in self.fake_history_subtasks: - task['parent'] = self.fake_history_workflow['id'] + task["parent"] = self.fake_history_workflow["id"] def test_model_complete(self): # Create API object. obj = ActionExecutionAPI(**copy.deepcopy(self.fake_history_workflow)) - self.assertDictEqual(obj.trigger, self.fake_history_workflow['trigger']) - self.assertDictEqual(obj.trigger_type, self.fake_history_workflow['trigger_type']) - self.assertDictEqual(obj.trigger_instance, self.fake_history_workflow['trigger_instance']) - self.assertDictEqual(obj.rule, self.fake_history_workflow['rule']) - self.assertDictEqual(obj.action, self.fake_history_workflow['action']) - self.assertDictEqual(obj.runner, self.fake_history_workflow['runner']) - self.assertEqual(obj.liveaction, self.fake_history_workflow['liveaction']) - self.assertIsNone(getattr(obj, 'parent', None)) - self.assertListEqual(obj.children, self.fake_history_workflow['children']) + self.assertDictEqual(obj.trigger, self.fake_history_workflow["trigger"]) + self.assertDictEqual( + obj.trigger_type, self.fake_history_workflow["trigger_type"] + ) + self.assertDictEqual( + obj.trigger_instance, self.fake_history_workflow["trigger_instance"] + ) + self.assertDictEqual(obj.rule, self.fake_history_workflow["rule"]) + self.assertDictEqual(obj.action, self.fake_history_workflow["action"]) + self.assertDictEqual(obj.runner, self.fake_history_workflow["runner"]) + self.assertEqual(obj.liveaction, self.fake_history_workflow["liveaction"]) + self.assertIsNone(getattr(obj, "parent", None)) + self.assertListEqual(obj.children, self.fake_history_workflow["children"]) # Convert API object to DB model. model = ActionExecutionAPI.to_model(obj) self.assertEqual(str(model.id), obj.id) - self.assertDictEqual(model.trigger, self.fake_history_workflow['trigger']) - self.assertDictEqual(model.trigger_type, self.fake_history_workflow['trigger_type']) - self.assertDictEqual(model.trigger_instance, self.fake_history_workflow['trigger_instance']) - self.assertDictEqual(model.rule, self.fake_history_workflow['rule']) - self.assertDictEqual(model.action, self.fake_history_workflow['action']) - self.assertDictEqual(model.runner, self.fake_history_workflow['runner']) - doc = copy.deepcopy(self.fake_history_workflow['liveaction']) - doc['start_timestamp'] = doc['start_timestamp'] - doc['end_timestamp'] = doc['end_timestamp'] + self.assertDictEqual(model.trigger, self.fake_history_workflow["trigger"]) + self.assertDictEqual( + model.trigger_type, self.fake_history_workflow["trigger_type"] + ) + self.assertDictEqual( + model.trigger_instance, self.fake_history_workflow["trigger_instance"] + ) + self.assertDictEqual(model.rule, self.fake_history_workflow["rule"]) + self.assertDictEqual(model.action, self.fake_history_workflow["action"]) + self.assertDictEqual(model.runner, self.fake_history_workflow["runner"]) + doc = copy.deepcopy(self.fake_history_workflow["liveaction"]) + doc["start_timestamp"] = doc["start_timestamp"] + doc["end_timestamp"] = doc["end_timestamp"] self.assertDictEqual(model.liveaction, doc) - self.assertIsNone(getattr(model, 'parent', None)) - self.assertListEqual(model.children, self.fake_history_workflow['children']) + self.assertIsNone(getattr(model, "parent", None)) + self.assertListEqual(model.children, self.fake_history_workflow["children"]) # Convert DB model to API object. obj = ActionExecutionAPI.from_model(model) self.assertEqual(str(model.id), obj.id) - self.assertDictEqual(obj.trigger, self.fake_history_workflow['trigger']) - self.assertDictEqual(obj.trigger_type, self.fake_history_workflow['trigger_type']) - self.assertDictEqual(obj.trigger_instance, self.fake_history_workflow['trigger_instance']) - self.assertDictEqual(obj.rule, self.fake_history_workflow['rule']) - self.assertDictEqual(obj.action, self.fake_history_workflow['action']) - self.assertDictEqual(obj.runner, self.fake_history_workflow['runner']) - self.assertDictEqual(obj.liveaction, self.fake_history_workflow['liveaction']) - self.assertIsNone(getattr(obj, 'parent', None)) - self.assertListEqual(obj.children, self.fake_history_workflow['children']) + self.assertDictEqual(obj.trigger, self.fake_history_workflow["trigger"]) + self.assertDictEqual( + obj.trigger_type, self.fake_history_workflow["trigger_type"] + ) + self.assertDictEqual( + obj.trigger_instance, self.fake_history_workflow["trigger_instance"] + ) + self.assertDictEqual(obj.rule, self.fake_history_workflow["rule"]) + self.assertDictEqual(obj.action, self.fake_history_workflow["action"]) + self.assertDictEqual(obj.runner, self.fake_history_workflow["runner"]) + self.assertDictEqual(obj.liveaction, self.fake_history_workflow["liveaction"]) + self.assertIsNone(getattr(obj, "parent", None)) + self.assertListEqual(obj.children, self.fake_history_workflow["children"]) def test_crud_complete(self): # Create the DB record. @@ -124,18 +147,22 @@ def test_crud_complete(self): ActionExecution.add_or_update(ActionExecutionAPI.to_model(obj)) model = ActionExecution.get_by_id(obj.id) self.assertEqual(str(model.id), obj.id) - self.assertDictEqual(model.trigger, self.fake_history_workflow['trigger']) - self.assertDictEqual(model.trigger_type, self.fake_history_workflow['trigger_type']) - self.assertDictEqual(model.trigger_instance, self.fake_history_workflow['trigger_instance']) - self.assertDictEqual(model.rule, self.fake_history_workflow['rule']) - self.assertDictEqual(model.action, self.fake_history_workflow['action']) - self.assertDictEqual(model.runner, self.fake_history_workflow['runner']) - doc = copy.deepcopy(self.fake_history_workflow['liveaction']) - doc['start_timestamp'] = doc['start_timestamp'] - doc['end_timestamp'] = doc['end_timestamp'] + self.assertDictEqual(model.trigger, self.fake_history_workflow["trigger"]) + self.assertDictEqual( + model.trigger_type, self.fake_history_workflow["trigger_type"] + ) + self.assertDictEqual( + model.trigger_instance, self.fake_history_workflow["trigger_instance"] + ) + self.assertDictEqual(model.rule, self.fake_history_workflow["rule"]) + self.assertDictEqual(model.action, self.fake_history_workflow["action"]) + self.assertDictEqual(model.runner, self.fake_history_workflow["runner"]) + doc = copy.deepcopy(self.fake_history_workflow["liveaction"]) + doc["start_timestamp"] = doc["start_timestamp"] + doc["end_timestamp"] = doc["end_timestamp"] self.assertDictEqual(model.liveaction, doc) - self.assertIsNone(getattr(model, 'parent', None)) - self.assertListEqual(model.children, self.fake_history_workflow['children']) + self.assertIsNone(getattr(model, "parent", None)) + self.assertListEqual(model.children, self.fake_history_workflow["children"]) # Update the DB record. children = [str(bson.ObjectId()), str(bson.ObjectId())] @@ -146,20 +173,24 @@ def test_crud_complete(self): # Delete the DB record. ActionExecution.delete(model) - self.assertRaises(StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id) + self.assertRaises( + StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id + ) def test_model_partial(self): # Create API object. obj = ActionExecutionAPI(**copy.deepcopy(self.fake_history_subtasks[0])) - self.assertIsNone(getattr(obj, 'trigger', None)) - self.assertIsNone(getattr(obj, 'trigger_type', None)) - self.assertIsNone(getattr(obj, 'trigger_instance', None)) - self.assertIsNone(getattr(obj, 'rule', None)) - self.assertDictEqual(obj.action, self.fake_history_subtasks[0]['action']) - self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]['runner']) - self.assertDictEqual(obj.liveaction, self.fake_history_subtasks[0]['liveaction']) - self.assertEqual(obj.parent, self.fake_history_subtasks[0]['parent']) - self.assertIsNone(getattr(obj, 'children', None)) + self.assertIsNone(getattr(obj, "trigger", None)) + self.assertIsNone(getattr(obj, "trigger_type", None)) + self.assertIsNone(getattr(obj, "trigger_instance", None)) + self.assertIsNone(getattr(obj, "rule", None)) + self.assertDictEqual(obj.action, self.fake_history_subtasks[0]["action"]) + self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]["runner"]) + self.assertDictEqual( + obj.liveaction, self.fake_history_subtasks[0]["liveaction"] + ) + self.assertEqual(obj.parent, self.fake_history_subtasks[0]["parent"]) + self.assertIsNone(getattr(obj, "children", None)) # Convert API object to DB model. model = ActionExecutionAPI.to_model(obj) @@ -168,28 +199,30 @@ def test_model_partial(self): self.assertDictEqual(model.trigger_type, {}) self.assertDictEqual(model.trigger_instance, {}) self.assertDictEqual(model.rule, {}) - self.assertDictEqual(model.action, self.fake_history_subtasks[0]['action']) - self.assertDictEqual(model.runner, self.fake_history_subtasks[0]['runner']) - doc = copy.deepcopy(self.fake_history_subtasks[0]['liveaction']) - doc['start_timestamp'] = doc['start_timestamp'] - doc['end_timestamp'] = doc['end_timestamp'] + self.assertDictEqual(model.action, self.fake_history_subtasks[0]["action"]) + self.assertDictEqual(model.runner, self.fake_history_subtasks[0]["runner"]) + doc = copy.deepcopy(self.fake_history_subtasks[0]["liveaction"]) + doc["start_timestamp"] = doc["start_timestamp"] + doc["end_timestamp"] = doc["end_timestamp"] self.assertDictEqual(model.liveaction, doc) - self.assertEqual(model.parent, self.fake_history_subtasks[0]['parent']) + self.assertEqual(model.parent, self.fake_history_subtasks[0]["parent"]) self.assertListEqual(model.children, []) # Convert DB model to API object. obj = ActionExecutionAPI.from_model(model) self.assertEqual(str(model.id), obj.id) - self.assertIsNone(getattr(obj, 'trigger', None)) - self.assertIsNone(getattr(obj, 'trigger_type', None)) - self.assertIsNone(getattr(obj, 'trigger_instance', None)) - self.assertIsNone(getattr(obj, 'rule', None)) - self.assertDictEqual(obj.action, self.fake_history_subtasks[0]['action']) - self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]['runner']) - self.assertDictEqual(obj.liveaction, self.fake_history_subtasks[0]['liveaction']) - self.assertEqual(obj.parent, self.fake_history_subtasks[0]['parent']) - self.assertIsNone(getattr(obj, 'children', None)) + self.assertIsNone(getattr(obj, "trigger", None)) + self.assertIsNone(getattr(obj, "trigger_type", None)) + self.assertIsNone(getattr(obj, "trigger_instance", None)) + self.assertIsNone(getattr(obj, "rule", None)) + self.assertDictEqual(obj.action, self.fake_history_subtasks[0]["action"]) + self.assertDictEqual(obj.runner, self.fake_history_subtasks[0]["runner"]) + self.assertDictEqual( + obj.liveaction, self.fake_history_subtasks[0]["liveaction"] + ) + self.assertEqual(obj.parent, self.fake_history_subtasks[0]["parent"]) + self.assertIsNone(getattr(obj, "children", None)) def test_crud_partial(self): # Create the DB record. @@ -201,13 +234,13 @@ def test_crud_partial(self): self.assertDictEqual(model.trigger_type, {}) self.assertDictEqual(model.trigger_instance, {}) self.assertDictEqual(model.rule, {}) - self.assertDictEqual(model.action, self.fake_history_subtasks[0]['action']) - self.assertDictEqual(model.runner, self.fake_history_subtasks[0]['runner']) - doc = copy.deepcopy(self.fake_history_subtasks[0]['liveaction']) - doc['start_timestamp'] = doc['start_timestamp'] - doc['end_timestamp'] = doc['end_timestamp'] + self.assertDictEqual(model.action, self.fake_history_subtasks[0]["action"]) + self.assertDictEqual(model.runner, self.fake_history_subtasks[0]["runner"]) + doc = copy.deepcopy(self.fake_history_subtasks[0]["liveaction"]) + doc["start_timestamp"] = doc["start_timestamp"] + doc["end_timestamp"] = doc["end_timestamp"] self.assertDictEqual(model.liveaction, doc) - self.assertEqual(model.parent, self.fake_history_subtasks[0]['parent']) + self.assertEqual(model.parent, self.fake_history_subtasks[0]["parent"]) self.assertListEqual(model.children, []) # Update the DB record. @@ -219,23 +252,25 @@ def test_crud_partial(self): # Delete the DB record. ActionExecution.delete(model) - self.assertRaises(StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id) + self.assertRaises( + StackStormDBObjectNotFoundError, ActionExecution.get_by_id, obj.id + ) def test_datetime_range(self): base = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0)) for i in range(60): timestamp = base + datetime.timedelta(seconds=i) doc = copy.deepcopy(self.fake_history_subtasks[0]) - doc['id'] = str(bson.ObjectId()) - doc['start_timestamp'] = isotime.format(timestamp) + doc["id"] = str(bson.ObjectId()) + doc["start_timestamp"] = isotime.format(timestamp) obj = ActionExecutionAPI(**doc) ActionExecution.add_or_update(ActionExecutionAPI.to_model(obj)) - dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z' + dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z" objs = ActionExecution.query(start_timestamp=dt_range) self.assertEqual(len(objs), 10) - dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z' + dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z" objs = ActionExecution.query(start_timestamp=dt_range) self.assertEqual(len(objs), 10) @@ -244,19 +279,19 @@ def test_sort_by_start_timestamp(self): for i in range(60): timestamp = base + datetime.timedelta(seconds=i) doc = copy.deepcopy(self.fake_history_subtasks[0]) - doc['id'] = str(bson.ObjectId()) - doc['start_timestamp'] = isotime.format(timestamp) + doc["id"] = str(bson.ObjectId()) + doc["start_timestamp"] = isotime.format(timestamp) obj = ActionExecutionAPI(**doc) ActionExecution.add_or_update(ActionExecutionAPI.to_model(obj)) - dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z' - objs = ActionExecution.query(start_timestamp=dt_range, - order_by=['start_timestamp']) - self.assertLess(objs[0]['start_timestamp'], - objs[9]['start_timestamp']) + dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z" + objs = ActionExecution.query( + start_timestamp=dt_range, order_by=["start_timestamp"] + ) + self.assertLess(objs[0]["start_timestamp"], objs[9]["start_timestamp"]) - dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z' - objs = ActionExecution.query(start_timestamp=dt_range, - order_by=['-start_timestamp']) - self.assertLess(objs[9]['start_timestamp'], - objs[0]['start_timestamp']) + dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z" + objs = ActionExecution.query( + start_timestamp=dt_range, order_by=["-start_timestamp"] + ) + self.assertLess(objs[9]["start_timestamp"], objs[0]["start_timestamp"]) diff --git a/st2common/tests/unit/test_executions_util.py b/st2common/tests/unit/test_executions_util.py index 9177493188d..f7702614d32 100644 --- a/st2common/tests/unit/test_executions_util.py +++ b/st2common/tests/unit/test_executions_util.py @@ -35,25 +35,28 @@ import st2tests.config as tests_config from six.moves import range + tests_config.parse_args() -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_FIXTURES = { - 'liveactions': ['liveaction1.yaml', 'parentliveaction.yaml', 'childliveaction.yaml', - 'successful_liveaction.yaml'], - 'actions': ['local.yaml'], - 'executions': ['execution1.yaml'], - 'runners': ['run-local.yaml'], - 'triggertypes': ['triggertype2.yaml'], - 'rules': ['rule3.yaml'], - 'triggers': ['trigger2.yaml'], - 'triggerinstances': ['trigger_instance_1.yaml'] + "liveactions": [ + "liveaction1.yaml", + "parentliveaction.yaml", + "childliveaction.yaml", + "successful_liveaction.yaml", + ], + "actions": ["local.yaml"], + "executions": ["execution1.yaml"], + "runners": ["run-local.yaml"], + "triggertypes": ["triggertype2.yaml"], + "rules": ["rule3.yaml"], + "triggers": ["trigger2.yaml"], + "triggerinstances": ["trigger_instance_1.yaml"], } -DYNAMIC_FIXTURES = { - 'liveactions': ['liveaction3.yaml'] -} +DYNAMIC_FIXTURES = {"liveactions": ["liveaction3.yaml"]} class ExecutionsUtilTestCase(CleanDbTestCase): @@ -63,118 +66,144 @@ def __init__(self, *args, **kwargs): def setUp(self): super(ExecutionsUtilTestCase, self).setUp() - self.MODELS = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_FIXTURES) - self.FIXTURES = FixturesLoader().load_fixtures(fixtures_pack=FIXTURES_PACK, - fixtures_dict=DYNAMIC_FIXTURES) + self.MODELS = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) + self.FIXTURES = FixturesLoader().load_fixtures( + fixtures_pack=FIXTURES_PACK, fixtures_dict=DYNAMIC_FIXTURES + ) def test_execution_creation_manual_action_run(self): - liveaction = self.MODELS['liveactions']['liveaction1.yaml'] + liveaction = self.MODELS["liveactions"]["liveaction1.yaml"] pre_creation_timestamp = date_utils.get_datetime_utc_now() executions_util.create_execution_object(liveaction) post_creation_timestamp = date_utils.get_datetime_utc_now() - execution = self._get_action_execution(liveaction__id=str(liveaction.id), - raise_exception=True) + execution = self._get_action_execution( + liveaction__id=str(liveaction.id), raise_exception=True + ) self.assertDictEqual(execution.trigger, {}) self.assertDictEqual(execution.trigger_type, {}) self.assertDictEqual(execution.trigger_instance, {}) self.assertDictEqual(execution.rule, {}) - action = action_utils.get_action_by_ref('core.local') + action = action_utils.get_action_by_ref("core.local") self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action))) - runner = RunnerType.get_by_name(action.runner_type['name']) + runner = RunnerType.get_by_name(action.runner_type["name"]) self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner))) liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(execution.liveaction['id'], str(liveaction.id)) + self.assertEqual(execution.liveaction["id"], str(liveaction.id)) self.assertEqual(len(execution.log), 1) - self.assertEqual(execution.log[0]['status'], liveaction.status) - self.assertGreater(execution.log[0]['timestamp'], pre_creation_timestamp) - self.assertLess(execution.log[0]['timestamp'], post_creation_timestamp) + self.assertEqual(execution.log[0]["status"], liveaction.status) + self.assertGreater(execution.log[0]["timestamp"], pre_creation_timestamp) + self.assertLess(execution.log[0]["timestamp"], post_creation_timestamp) def test_execution_creation_action_triggered_by_rule(self): # Wait for the action execution to complete and then confirm outcome. - trigger_type = self.MODELS['triggertypes']['triggertype2.yaml'] - trigger = self.MODELS['triggers']['trigger2.yaml'] - trigger_instance = self.MODELS['triggerinstances']['trigger_instance_1.yaml'] - test_liveaction = self.FIXTURES['liveactions']['liveaction3.yaml'] - rule = self.MODELS['rules']['rule3.yaml'] + trigger_type = self.MODELS["triggertypes"]["triggertype2.yaml"] + trigger = self.MODELS["triggers"]["trigger2.yaml"] + trigger_instance = self.MODELS["triggerinstances"]["trigger_instance_1.yaml"] + test_liveaction = self.FIXTURES["liveactions"]["liveaction3.yaml"] + rule = self.MODELS["rules"]["rule3.yaml"] # Setup LiveAction to point to right rule and trigger_instance. # XXX: We need support for dynamic fixtures. - test_liveaction['context']['rule']['id'] = str(rule.id) - test_liveaction['context']['trigger_instance']['id'] = str(trigger_instance.id) + test_liveaction["context"]["rule"]["id"] = str(rule.id) + test_liveaction["context"]["trigger_instance"]["id"] = str(trigger_instance.id) test_liveaction_api = LiveActionAPI(**test_liveaction) - test_liveaction = LiveAction.add_or_update(LiveActionAPI.to_model(test_liveaction_api)) - liveaction = LiveAction.get(context__trigger_instance__id=str(trigger_instance.id)) + test_liveaction = LiveAction.add_or_update( + LiveActionAPI.to_model(test_liveaction_api) + ) + liveaction = LiveAction.get( + context__trigger_instance__id=str(trigger_instance.id) + ) self.assertIsNotNone(liveaction) - self.assertEqual(liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED) + self.assertEqual( + liveaction.status, action_constants.LIVEACTION_STATUS_REQUESTED + ) executions_util.create_execution_object(liveaction) - execution = self._get_action_execution(liveaction__id=str(liveaction.id), - raise_exception=True) + execution = self._get_action_execution( + liveaction__id=str(liveaction.id), raise_exception=True + ) self.assertDictEqual(execution.trigger, vars(TriggerAPI.from_model(trigger))) - self.assertDictEqual(execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type))) - self.assertDictEqual(execution.trigger_instance, - vars(TriggerInstanceAPI.from_model(trigger_instance))) + self.assertDictEqual( + execution.trigger_type, vars(TriggerTypeAPI.from_model(trigger_type)) + ) + self.assertDictEqual( + execution.trigger_instance, + vars(TriggerInstanceAPI.from_model(trigger_instance)), + ) self.assertDictEqual(execution.rule, vars(RuleAPI.from_model(rule))) action = action_utils.get_action_by_ref(liveaction.action) self.assertDictEqual(execution.action, vars(ActionAPI.from_model(action))) - runner = RunnerType.get_by_name(action.runner_type['name']) + runner = RunnerType.get_by_name(action.runner_type["name"]) self.assertDictEqual(execution.runner, vars(RunnerTypeAPI.from_model(runner))) liveaction = LiveAction.get_by_id(str(liveaction.id)) - self.assertEqual(execution.liveaction['id'], str(liveaction.id)) + self.assertEqual(execution.liveaction["id"], str(liveaction.id)) def test_execution_creation_with_web_url(self): - liveaction = self.MODELS['liveactions']['liveaction1.yaml'] + liveaction = self.MODELS["liveactions"]["liveaction1.yaml"] executions_util.create_execution_object(liveaction) - execution = self._get_action_execution(liveaction__id=str(liveaction.id), - raise_exception=True) + execution = self._get_action_execution( + liveaction__id=str(liveaction.id), raise_exception=True + ) self.assertIsNotNone(execution.web_url) execution_id = str(execution.id) - self.assertIn(('history/%s/general' % execution_id), execution.web_url) + self.assertIn(("history/%s/general" % execution_id), execution.web_url) def test_execution_creation_chains(self): - childliveaction = self.MODELS['liveactions']['childliveaction.yaml'] + childliveaction = self.MODELS["liveactions"]["childliveaction.yaml"] child_exec = executions_util.create_execution_object(childliveaction) - parent_execution_id = childliveaction.context['parent']['execution_id'] + parent_execution_id = childliveaction.context["parent"]["execution_id"] parent_execution = ActionExecution.get_by_id(parent_execution_id) child_execs = parent_execution.children self.assertIn(str(child_exec.id), child_execs) def test_execution_update(self): - liveaction = self.MODELS['liveactions']['liveaction1.yaml'] + liveaction = self.MODELS["liveactions"]["liveaction1.yaml"] executions_util.create_execution_object(liveaction) - liveaction.status = 'running' + liveaction.status = "running" pre_update_timestamp = date_utils.get_datetime_utc_now() executions_util.update_execution(liveaction) post_update_timestamp = date_utils.get_datetime_utc_now() - execution = self._get_action_execution(liveaction__id=str(liveaction.id), - raise_exception=True) + execution = self._get_action_execution( + liveaction__id=str(liveaction.id), raise_exception=True + ) self.assertEqual(len(execution.log), 2) - self.assertEqual(execution.log[1]['status'], liveaction.status) - self.assertGreater(execution.log[1]['timestamp'], pre_update_timestamp) - self.assertLess(execution.log[1]['timestamp'], post_update_timestamp) - - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) - @mock.patch.object(runners_utils, 'invoke_post_run', mock.MagicMock(return_value=None)) + self.assertEqual(execution.log[1]["status"], liveaction.status) + self.assertGreater(execution.log[1]["timestamp"], pre_update_timestamp) + self.assertLess(execution.log[1]["timestamp"], post_update_timestamp) + + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) + @mock.patch.object( + runners_utils, "invoke_post_run", mock.MagicMock(return_value=None) + ) def test_abandon_executions(self): - liveaction_db = self.MODELS['liveactions']['liveaction1.yaml'] + liveaction_db = self.MODELS["liveactions"]["liveaction1.yaml"] executions_util.create_execution_object(liveaction_db) execution_db = executions_util.abandon_execution_if_incomplete( - liveaction_id=str(liveaction_db.id)) + liveaction_id=str(liveaction_db.id) + ) - self.assertEqual(execution_db.status, 'abandoned') + self.assertEqual(execution_db.status, "abandoned") runners_utils.invoke_post_run.assert_called_once() - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) - @mock.patch.object(runners_utils, 'invoke_post_run', mock.MagicMock(return_value=None)) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) + @mock.patch.object( + runners_utils, "invoke_post_run", mock.MagicMock(return_value=None) + ) def test_abandon_executions_on_complete(self): - liveaction_db = self.MODELS['liveactions']['successful_liveaction.yaml'] + liveaction_db = self.MODELS["liveactions"]["successful_liveaction.yaml"] executions_util.create_execution_object(liveaction_db) - expected_msg = r'LiveAction %s already in a completed state %s\.' % \ - (str(liveaction_db.id), liveaction_db.status) - - self.assertRaisesRegexp(ValueError, expected_msg, - executions_util.abandon_execution_if_incomplete, - liveaction_id=str(liveaction_db.id)) + expected_msg = r"LiveAction %s already in a completed state %s\." % ( + str(liveaction_db.id), + liveaction_db.status, + ) + + self.assertRaisesRegexp( + ValueError, + expected_msg, + executions_util.abandon_execution_if_incomplete, + liveaction_id=str(liveaction_db.id), + ) runners_utils.invoke_post_run.assert_not_called() @@ -184,12 +213,20 @@ def _get_action_execution(self, **kwargs): # descendants test section -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } @@ -200,75 +237,90 @@ def __init__(self, *args, **kwargs): def setUp(self): super(ExecutionsUtilDescendantsTestCase, self).setUp() - self.MODELS = FixturesLoader().save_fixtures_to_db(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) + self.MODELS = FixturesLoader().save_fixtures_to_db( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) def test_get_all_descendants_sorted(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - all_descendants = executions_util.get_descendants(str(root_execution.id), - result_fmt='sorted') + root_execution = self.MODELS["executions"]["root_execution.yaml"] + all_descendants = executions_util.get_descendants( + str(root_execution.id), result_fmt="sorted" + ) all_descendants_ids = [str(descendant.id) for descendant in all_descendants] all_descendants_ids.sort() # everything except the root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.id != root_execution.id] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.id != root_execution.id + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) # verify sort order for idx in range(len(all_descendants) - 1): - self.assertLess(all_descendants[idx].start_timestamp, - all_descendants[idx + 1].start_timestamp) + self.assertLess( + all_descendants[idx].start_timestamp, + all_descendants[idx + 1].start_timestamp, + ) def test_get_all_descendants(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] + root_execution = self.MODELS["executions"]["root_execution.yaml"] all_descendants = executions_util.get_descendants(str(root_execution.id)) all_descendants_ids = [str(descendant.id) for descendant in all_descendants] all_descendants_ids.sort() # everything except the root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.id != root_execution.id] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.id != root_execution.id + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) def test_get_1_level_descendants_sorted(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - all_descendants = executions_util.get_descendants(str(root_execution.id), - descendant_depth=1, - result_fmt='sorted') + root_execution = self.MODELS["executions"]["root_execution.yaml"] + all_descendants = executions_util.get_descendants( + str(root_execution.id), descendant_depth=1, result_fmt="sorted" + ) all_descendants_ids = [str(descendant.id) for descendant in all_descendants] all_descendants_ids.sort() # All children of root_execution - expected_ids = [str(v.id) for _, v in six.iteritems(self.MODELS['executions']) - if v.parent == str(root_execution.id)] + expected_ids = [ + str(v.id) + for _, v in six.iteritems(self.MODELS["executions"]) + if v.parent == str(root_execution.id) + ] expected_ids.sort() self.assertListEqual(all_descendants_ids, expected_ids) # verify sort order for idx in range(len(all_descendants) - 1): - self.assertLess(all_descendants[idx].start_timestamp, - all_descendants[idx + 1].start_timestamp) + self.assertLess( + all_descendants[idx].start_timestamp, + all_descendants[idx + 1].start_timestamp, + ) def test_get_2_level_descendants_sorted(self): - root_execution = self.MODELS['executions']['root_execution.yaml'] - all_descendants = executions_util.get_descendants(str(root_execution.id), - descendant_depth=2, - result_fmt='sorted') + root_execution = self.MODELS["executions"]["root_execution.yaml"] + all_descendants = executions_util.get_descendants( + str(root_execution.id), descendant_depth=2, result_fmt="sorted" + ) all_descendants_ids = [str(descendant.id) for descendant in all_descendants] all_descendants_ids.sort() # All children of root_execution - root_execution = self.MODELS['executions']['root_execution.yaml'] + root_execution = self.MODELS["executions"]["root_execution.yaml"] expected_ids = [] traverse = [(child_id, 1) for child_id in root_execution.children] while traverse: @@ -282,7 +334,7 @@ def test_get_2_level_descendants_sorted(self): self.assertListEqual(all_descendants_ids, expected_ids) def _get_action_execution(self, ae_id): - for _, execution in six.iteritems(self.MODELS['executions']): + for _, execution in six.iteritems(self.MODELS["executions"]): if str(execution.id) == ae_id: return execution return None diff --git a/st2common/tests/unit/test_greenpooldispatch.py b/st2common/tests/unit/test_greenpooldispatch.py index 84c411d1405..45cc568759a 100644 --- a/st2common/tests/unit/test_greenpooldispatch.py +++ b/st2common/tests/unit/test_greenpooldispatch.py @@ -23,7 +23,6 @@ class TestGreenPoolDispatch(TestCase): - def test_dispatch_simple(self): dispatcher = BufferedDispatcher(dispatch_pool_size=10) mock_handler = mock.MagicMock() @@ -34,13 +33,17 @@ def test_dispatch_simple(self): while mock_handler.call_count < 10: eventlet.sleep(0.01) dispatcher.shutdown() - call_args_list = [(args[0][0], args[0][1]) for args in mock_handler.call_args_list] + call_args_list = [ + (args[0][0], args[0][1]) for args in mock_handler.call_args_list + ] self.assertItemsEqual(expected, call_args_list) def test_dispatch_starved(self): - dispatcher = BufferedDispatcher(dispatch_pool_size=2, - monitor_thread_empty_q_sleep_time=0.01, - monitor_thread_no_workers_sleep_time=0.01) + dispatcher = BufferedDispatcher( + dispatch_pool_size=2, + monitor_thread_empty_q_sleep_time=0.01, + monitor_thread_no_workers_sleep_time=0.01, + ) mock_handler = mock.MagicMock() expected = [] for i in range(10): @@ -49,5 +52,7 @@ def test_dispatch_starved(self): while mock_handler.call_count < 10: eventlet.sleep(0.01) dispatcher.shutdown() - call_args_list = [(args[0][0], args[0][1]) for args in mock_handler.call_args_list] + call_args_list = [ + (args[0][0], args[0][1]) for args in mock_handler.call_args_list + ] self.assertItemsEqual(expected, call_args_list) diff --git a/st2common/tests/unit/test_hash.py b/st2common/tests/unit/test_hash.py index 7211879ff63..234d4969dae 100644 --- a/st2common/tests/unit/test_hash.py +++ b/st2common/tests/unit/test_hash.py @@ -22,15 +22,14 @@ class TestHashWithApiKeys(unittest2.TestCase): - def test_hash_repeatability(self): api_key = auth_utils.generate_api_key() hash1 = hash_utils.hash(api_key) hash2 = hash_utils.hash(api_key) - self.assertEqual(hash1, hash2, 'Expected a repeated hash.') + self.assertEqual(hash1, hash2, "Expected a repeated hash.") def test_hash_uniqueness(self): count = 10000 api_keys = [auth_utils.generate_api_key() for _ in range(count)] hashes = set([hash_utils.hash(api_key) for api_key in api_keys]) - self.assertEqual(len(hashes), count, 'Expected all unique hashes.') + self.assertEqual(len(hashes), count, "Expected all unique hashes.") diff --git a/st2common/tests/unit/test_ip_utils.py b/st2common/tests/unit/test_ip_utils.py index a33c220d71c..cd1339be737 100644 --- a/st2common/tests/unit/test_ip_utils.py +++ b/st2common/tests/unit/test_ip_utils.py @@ -20,73 +20,72 @@ class IPUtilsTests(unittest2.TestCase): - def test_host_port_split(self): # Simple IPv4 - host_str = '1.2.3.4' + host_str = "1.2.3.4" host, port = split_host_port(host_str) self.assertEqual(host, host_str) self.assertEqual(port, None) # Simple IPv4 with port - host_str = '1.2.3.4:55' + host_str = "1.2.3.4:55" host, port = split_host_port(host_str) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 55) # Simple IPv6 - host_str = 'fec2::10' + host_str = "fec2::10" host, port = split_host_port(host_str) - self.assertEqual(host, 'fec2::10') + self.assertEqual(host, "fec2::10") self.assertEqual(port, None) # IPv6 with square brackets no port - host_str = '[fec2::10]' + host_str = "[fec2::10]" host, port = split_host_port(host_str) - self.assertEqual(host, 'fec2::10') + self.assertEqual(host, "fec2::10") self.assertEqual(port, None) # IPv6 with square brackets with port - host_str = '[fec2::10]:55' + host_str = "[fec2::10]:55" host, port = split_host_port(host_str) - self.assertEqual(host, 'fec2::10') + self.assertEqual(host, "fec2::10") self.assertEqual(port, 55) # IPv4 inside bracket - host_str = '[1.2.3.4]' + host_str = "[1.2.3.4]" host, port = split_host_port(host_str) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, None) # IPv4 inside bracket and port - host_str = '[1.2.3.4]:55' + host_str = "[1.2.3.4]:55" host, port = split_host_port(host_str) - self.assertEqual(host, '1.2.3.4') + self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 55) # Hostname inside bracket - host_str = '[st2build001]:55' + host_str = "[st2build001]:55" host, port = split_host_port(host_str) - self.assertEqual(host, 'st2build001') + self.assertEqual(host, "st2build001") self.assertEqual(port, 55) # Simple hostname - host_str = 'st2build001' + host_str = "st2build001" host, port = split_host_port(host_str) - self.assertEqual(host, 'st2build001') + self.assertEqual(host, "st2build001") self.assertEqual(port, None) # Simple hostname with port - host_str = 'st2build001:55' + host_str = "st2build001:55" host, port = split_host_port(host_str) - self.assertEqual(host, 'st2build001') + self.assertEqual(host, "st2build001") self.assertEqual(port, 55) # No-bracket invalid port - host_str = 'st2build001:abc' + host_str = "st2build001:abc" self.assertRaises(Exception, split_host_port, host_str) # Bracket invalid port - host_str = '[fec2::10]:abc' + host_str = "[fec2::10]:abc" self.assertRaises(Exception, split_host_port, host_str) diff --git a/st2common/tests/unit/test_isotime_utils.py b/st2common/tests/unit/test_isotime_utils.py index 5ec5495ca91..34d785031b5 100644 --- a/st2common/tests/unit/test_isotime_utils.py +++ b/st2common/tests/unit/test_isotime_utils.py @@ -24,50 +24,54 @@ class IsoTimeUtilsTestCase(unittest.TestCase): def test_validate(self): - self.assertTrue(isotime.validate('2000-01-01 12:00:00Z')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00+00')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00+0000')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00+00:00')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000Z')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000+00')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000+0000')) - self.assertTrue(isotime.validate('2000-01-01 12:00:00.000000+00:00')) - self.assertTrue(isotime.validate('2000-01-01T12:00:00Z')) - self.assertTrue(isotime.validate('2000-01-01T12:00:00.000000Z')) - self.assertTrue(isotime.validate('2000-01-01T12:00:00+00:00')) - self.assertTrue(isotime.validate('2000-01-01T12:00:00.000000+00:00')) - self.assertTrue(isotime.validate('2015-02-10T21:21:53.399Z')) - self.assertFalse(isotime.validate('2000-01-01', raise_exception=False)) - self.assertFalse(isotime.validate('2000-01-01T12:00:00', raise_exception=False)) - self.assertFalse(isotime.validate('2000-01-01T12:00:00+00:00Z', raise_exception=False)) - self.assertFalse(isotime.validate('2000-01-01T12:00:00.000000', raise_exception=False)) - self.assertFalse(isotime.validate('Epic!', raise_exception=False)) + self.assertTrue(isotime.validate("2000-01-01 12:00:00Z")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00+00")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00+0000")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00+00:00")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000Z")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000+00")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000+0000")) + self.assertTrue(isotime.validate("2000-01-01 12:00:00.000000+00:00")) + self.assertTrue(isotime.validate("2000-01-01T12:00:00Z")) + self.assertTrue(isotime.validate("2000-01-01T12:00:00.000000Z")) + self.assertTrue(isotime.validate("2000-01-01T12:00:00+00:00")) + self.assertTrue(isotime.validate("2000-01-01T12:00:00.000000+00:00")) + self.assertTrue(isotime.validate("2015-02-10T21:21:53.399Z")) + self.assertFalse(isotime.validate("2000-01-01", raise_exception=False)) + self.assertFalse(isotime.validate("2000-01-01T12:00:00", raise_exception=False)) + self.assertFalse( + isotime.validate("2000-01-01T12:00:00+00:00Z", raise_exception=False) + ) + self.assertFalse( + isotime.validate("2000-01-01T12:00:00.000000", raise_exception=False) + ) + self.assertFalse(isotime.validate("Epic!", raise_exception=False)) self.assertFalse(isotime.validate(object(), raise_exception=False)) - self.assertRaises(ValueError, isotime.validate, 'Epic!', True) + self.assertRaises(ValueError, isotime.validate, "Epic!", True) def test_parse(self): dt = date.add_utc_tz(datetime.datetime(2000, 1, 1, 12)) - self.assertEqual(isotime.parse('2000-01-01 12:00:00Z'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00+00'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00+0000'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00+00:00'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000Z'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000+00'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000+0000'), dt) - self.assertEqual(isotime.parse('2000-01-01 12:00:00.000000+00:00'), dt) - self.assertEqual(isotime.parse('2000-01-01T12:00:00Z'), dt) - self.assertEqual(isotime.parse('2000-01-01T12:00:00+00:00'), dt) - self.assertEqual(isotime.parse('2000-01-01T12:00:00.000000Z'), dt) - self.assertEqual(isotime.parse('2000-01-01T12:00:00.000000+00:00'), dt) - self.assertEqual(isotime.parse('2000-01-01T12:00:00.000Z'), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00Z"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00+00"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00+0000"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00+00:00"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000Z"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000+00"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000+0000"), dt) + self.assertEqual(isotime.parse("2000-01-01 12:00:00.000000+00:00"), dt) + self.assertEqual(isotime.parse("2000-01-01T12:00:00Z"), dt) + self.assertEqual(isotime.parse("2000-01-01T12:00:00+00:00"), dt) + self.assertEqual(isotime.parse("2000-01-01T12:00:00.000000Z"), dt) + self.assertEqual(isotime.parse("2000-01-01T12:00:00.000000+00:00"), dt) + self.assertEqual(isotime.parse("2000-01-01T12:00:00.000Z"), dt) def test_format(self): dt = date.add_utc_tz(datetime.datetime(2000, 1, 1, 12)) - dt_str_usec_offset = '2000-01-01T12:00:00.000000+00:00' - dt_str_usec = '2000-01-01T12:00:00.000000Z' - dt_str_offset = '2000-01-01T12:00:00+00:00' - dt_str = '2000-01-01T12:00:00Z' - dt_unicode = u'2000-01-01T12:00:00Z' + dt_str_usec_offset = "2000-01-01T12:00:00.000000+00:00" + dt_str_usec = "2000-01-01T12:00:00.000000Z" + dt_str_offset = "2000-01-01T12:00:00+00:00" + dt_str = "2000-01-01T12:00:00Z" + dt_unicode = "2000-01-01T12:00:00Z" # datetime object self.assertEqual(isotime.format(dt, usec=True, offset=True), dt_str_usec_offset) @@ -75,16 +79,22 @@ def test_format(self): self.assertEqual(isotime.format(dt, usec=False, offset=True), dt_str_offset) self.assertEqual(isotime.format(dt, usec=False, offset=False), dt_str) self.assertEqual(isotime.format(dt_str, usec=False, offset=False), dt_str) - self.assertEqual(isotime.format(dt_unicode, usec=False, offset=False), dt_unicode) + self.assertEqual( + isotime.format(dt_unicode, usec=False, offset=False), dt_unicode + ) # unix timestamp (epoch) dt = 1557390483 - self.assertEqual(isotime.format(dt, usec=True, offset=True), - '2019-05-09T08:28:03.000000+00:00') - self.assertEqual(isotime.format(dt, usec=False, offset=False), - '2019-05-09T08:28:03Z') - self.assertEqual(isotime.format(dt, usec=False, offset=True), - '2019-05-09T08:28:03+00:00') + self.assertEqual( + isotime.format(dt, usec=True, offset=True), + "2019-05-09T08:28:03.000000+00:00", + ) + self.assertEqual( + isotime.format(dt, usec=False, offset=False), "2019-05-09T08:28:03Z" + ) + self.assertEqual( + isotime.format(dt, usec=False, offset=True), "2019-05-09T08:28:03+00:00" + ) def test_format_tz_naive(self): dt1 = datetime.datetime.utcnow() @@ -99,6 +109,8 @@ def test_format_tz_aware(self): def test_format_sec_truncated(self): dt1 = date.add_utc_tz(datetime.datetime.utcnow()) dt2 = isotime.parse(isotime.format(dt1, usec=False)) - dt3 = datetime.datetime(dt1.year, dt1.month, dt1.day, dt1.hour, dt1.minute, dt1.second) + dt3 = datetime.datetime( + dt1.year, dt1.month, dt1.day, dt1.hour, dt1.minute, dt1.second + ) self.assertLess(dt2, dt1) self.assertEqual(dt2, date.add_utc_tz(dt3)) diff --git a/st2common/tests/unit/test_jinja_render_crypto_filters.py b/st2common/tests/unit/test_jinja_render_crypto_filters.py index 1a026e83ed2..f58edb13093 100644 --- a/st2common/tests/unit/test_jinja_render_crypto_filters.py +++ b/st2common/tests/unit/test_jinja_render_crypto_filters.py @@ -38,72 +38,101 @@ def setUp(self): crypto_key_path = cfg.CONF.keyvalue.encryption_key_path crypto_key = read_crypto_key(key_path=crypto_key_path) - self.secret = 'Build a wall' - self.secret_value = symmetric_encrypt(encrypt_key=crypto_key, plaintext=self.secret) + self.secret = "Build a wall" + self.secret_value = symmetric_encrypt( + encrypt_key=crypto_key, plaintext=self.secret + ) self.env = jinja_utils.get_jinja_environment() def test_filter_decrypt_kv(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='k8', value=self.secret_value, - scope=FULL_SYSTEM_SCOPE, - secret=True)) + KeyValuePair.add_or_update( + KeyValuePairDB( + name="k8", value=self.secret_value, scope=FULL_SYSTEM_SCOPE, secret=True + ) + ) context = {} context.update({SYSTEM_SCOPE: KeyValueLookup(scope=SYSTEM_SCOPE)}) - context.update({ - DATASTORE_PARENT_SCOPE: { - SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + context.update( + { + DATASTORE_PARENT_SCOPE: { + SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + } } - }) + ) - template = '{{st2kv.system.k8 | decrypt_kv}}' + template = "{{st2kv.system.k8 | decrypt_kv}}" actual = self.env.from_string(template).render(context) self.assertEqual(actual, self.secret) def test_filter_decrypt_kv_datastore_value_doesnt_exist(self): context = {} context.update({SYSTEM_SCOPE: KeyValueLookup(scope=SYSTEM_SCOPE)}) - context.update({ - DATASTORE_PARENT_SCOPE: { - SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + context.update( + { + DATASTORE_PARENT_SCOPE: { + SYSTEM_SCOPE: KeyValueLookup(scope=FULL_SYSTEM_SCOPE) + } } - }) + ) - template = '{{st2kv.system.doesnt_exist | decrypt_kv}}' + template = "{{st2kv.system.doesnt_exist | decrypt_kv}}" - expected_msg = ('Referenced datastore item "st2kv.system.doesnt_exist" doesn\'t exist or ' - 'it contains an empty string') - self.assertRaisesRegexp(ValueError, expected_msg, self.env.from_string(template).render, - context) + expected_msg = ( + 'Referenced datastore item "st2kv.system.doesnt_exist" doesn\'t exist or ' + "it contains an empty string" + ) + self.assertRaisesRegexp( + ValueError, expected_msg, self.env.from_string(template).render, context + ) def test_filter_decrypt_kv_with_user_scope_value(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:k8', value=self.secret_value, - scope=FULL_USER_SCOPE, - secret=True)) + KeyValuePair.add_or_update( + KeyValuePairDB( + name="stanley:k8", + value=self.secret_value, + scope=FULL_USER_SCOPE, + secret=True, + ) + ) context = {} - context.update({USER_SCOPE: UserKeyValueLookup(user='stanley', scope=USER_SCOPE)}) - context.update({ - DATASTORE_PARENT_SCOPE: { - USER_SCOPE: UserKeyValueLookup(user='stanley', scope=FULL_USER_SCOPE) + context.update( + {USER_SCOPE: UserKeyValueLookup(user="stanley", scope=USER_SCOPE)} + ) + context.update( + { + DATASTORE_PARENT_SCOPE: { + USER_SCOPE: UserKeyValueLookup( + user="stanley", scope=FULL_USER_SCOPE + ) + } } - }) + ) - template = '{{st2kv.user.k8 | decrypt_kv}}' + template = "{{st2kv.user.k8 | decrypt_kv}}" actual = self.env.from_string(template).render(context) self.assertEqual(actual, self.secret) def test_filter_decrypt_kv_with_user_scope_value_datastore_value_doesnt_exist(self): context = {} context.update({SYSTEM_SCOPE: KeyValueLookup(scope=SYSTEM_SCOPE)}) - context.update({ - DATASTORE_PARENT_SCOPE: { - USER_SCOPE: UserKeyValueLookup(user='stanley', scope=FULL_USER_SCOPE) + context.update( + { + DATASTORE_PARENT_SCOPE: { + USER_SCOPE: UserKeyValueLookup( + user="stanley", scope=FULL_USER_SCOPE + ) + } } - }) + ) - template = '{{st2kv.user.doesnt_exist | decrypt_kv}}' + template = "{{st2kv.user.doesnt_exist | decrypt_kv}}" - expected_msg = ('Referenced datastore item "st2kv.user.doesnt_exist" doesn\'t exist or ' - 'it contains an empty string') - self.assertRaisesRegexp(ValueError, expected_msg, self.env.from_string(template).render, - context) + expected_msg = ( + 'Referenced datastore item "st2kv.user.doesnt_exist" doesn\'t exist or ' + "it contains an empty string" + ) + self.assertRaisesRegexp( + ValueError, expected_msg, self.env.from_string(template).render, context + ) diff --git a/st2common/tests/unit/test_jinja_render_data_filters.py b/st2common/tests/unit/test_jinja_render_data_filters.py index fd923e870f2..44d2f296f92 100644 --- a/st2common/tests/unit/test_jinja_render_data_filters.py +++ b/st2common/tests/unit/test_jinja_render_data_filters.py @@ -24,77 +24,68 @@ class JinjaUtilsDataFilterTestCase(unittest2.TestCase): - def test_filter_from_json_string(self): env = jinja_utils.get_jinja_environment() - expected_obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}} + expected_obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}} obj_json_str = '{"a": "b", "c": {"d": "e", "f": 1, "g": true}}' - template = '{{k1 | from_json_string}}' + template = "{{k1 | from_json_string}}" - obj_str = env.from_string(template).render({'k1': obj_json_str}) + obj_str = env.from_string(template).render({"k1": obj_json_str}) obj = eval(obj_str) self.assertDictEqual(obj, expected_obj) # With KeyValueLookup object env = jinja_utils.get_jinja_environment() obj_json_str = '["a", "b", "c"]' - expected_obj = ['a', 'b', 'c'] + expected_obj = ["a", "b", "c"] - template = '{{ k1 | from_json_string}}' + template = "{{ k1 | from_json_string}}" - lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix='a') - lookup._value_cache['a'] = obj_json_str - obj_str = env.from_string(template).render({'k1': lookup}) + lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix="a") + lookup._value_cache["a"] = obj_json_str + obj_str = env.from_string(template).render({"k1": lookup}) obj = eval(obj_str) self.assertEqual(obj, expected_obj) def test_filter_from_yaml_string(self): env = jinja_utils.get_jinja_environment() - expected_obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}} - obj_yaml_str = ("---\n" - "a: b\n" - "c:\n" - " d: e\n" - " f: 1\n" - " g: true\n") - - template = '{{k1 | from_yaml_string}}' - obj_str = env.from_string(template).render({'k1': obj_yaml_str}) + expected_obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}} + obj_yaml_str = "---\n" "a: b\n" "c:\n" " d: e\n" " f: 1\n" " g: true\n" + + template = "{{k1 | from_yaml_string}}" + obj_str = env.from_string(template).render({"k1": obj_yaml_str}) obj = eval(obj_str) self.assertDictEqual(obj, expected_obj) # With KeyValueLookup object env = jinja_utils.get_jinja_environment() - obj_yaml_str = ("---\n" - "- a\n" - "- b\n" - "- c\n") - expected_obj = ['a', 'b', 'c'] + obj_yaml_str = "---\n" "- a\n" "- b\n" "- c\n" + expected_obj = ["a", "b", "c"] - template = '{{ k1 | from_yaml_string }}' + template = "{{ k1 | from_yaml_string }}" - lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix='b') - lookup._value_cache['b'] = obj_yaml_str - obj_str = env.from_string(template).render({'k1': lookup}) + lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE, key_prefix="b") + lookup._value_cache["b"] = obj_yaml_str + obj_str = env.from_string(template).render({"k1": lookup}) obj = eval(obj_str) self.assertEqual(obj, expected_obj) def test_filter_to_json_string(self): env = jinja_utils.get_jinja_environment() - obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}} + obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}} - template = '{{k1 | to_json_string}}' + template = "{{k1 | to_json_string}}" - obj_json_str = env.from_string(template).render({'k1': obj}) + obj_json_str = env.from_string(template).render({"k1": obj}) actual_obj = json.loads(obj_json_str) self.assertDictEqual(obj, actual_obj) def test_filter_to_yaml_string(self): env = jinja_utils.get_jinja_environment() - obj = {'a': 'b', 'c': {'d': 'e', 'f': 1, 'g': True}} + obj = {"a": "b", "c": {"d": "e", "f": 1, "g": True}} - template = '{{k1 | to_yaml_string}}' - obj_yaml_str = env.from_string(template).render({'k1': obj}) + template = "{{k1 | to_yaml_string}}" + obj_yaml_str = env.from_string(template).render({"k1": obj}) actual_obj = yaml.safe_load(obj_yaml_str) self.assertDictEqual(obj, actual_obj) diff --git a/st2common/tests/unit/test_jinja_render_json_escape_filters.py b/st2common/tests/unit/test_jinja_render_json_escape_filters.py index 82534100c5a..48fef776c1b 100644 --- a/st2common/tests/unit/test_jinja_render_json_escape_filters.py +++ b/st2common/tests/unit/test_jinja_render_json_escape_filters.py @@ -21,52 +21,51 @@ class JinjaUtilsJsonEscapeTestCase(unittest2.TestCase): - def test_doublequotes(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo """ bar'}) + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": 'foo """ bar'}) expected = 'foo \\"\\"\\" bar' self.assertEqual(actual, expected) def test_backslashes(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': r'foo \ bar'}) - expected = 'foo \\\\ bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": r"foo \ bar"}) + expected = "foo \\\\ bar" self.assertEqual(actual, expected) def test_backspace(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo \b bar'}) - expected = 'foo \\b bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": "foo \b bar"}) + expected = "foo \\b bar" self.assertEqual(actual, expected) def test_formfeed(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo \f bar'}) - expected = 'foo \\f bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": "foo \f bar"}) + expected = "foo \\f bar" self.assertEqual(actual, expected) def test_newline(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo \n bar'}) - expected = 'foo \\n bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": "foo \n bar"}) + expected = "foo \\n bar" self.assertEqual(actual, expected) def test_carriagereturn(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo \r bar'}) - expected = 'foo \\r bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": "foo \r bar"}) + expected = "foo \\r bar" self.assertEqual(actual, expected) def test_tab(self): env = jinja_utils.get_jinja_environment() - template = '{{ test_str | json_escape }}' - actual = env.from_string(template).render({'test_str': 'foo \t bar'}) - expected = 'foo \\t bar' + template = "{{ test_str | json_escape }}" + actual = env.from_string(template).render({"test_str": "foo \t bar"}) + expected = "foo \\t bar" self.assertEqual(actual, expected) diff --git a/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py b/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py index fd199ebf64d..934aa04de8b 100644 --- a/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py +++ b/st2common/tests/unit/test_jinja_render_jsonpath_query_filters.py @@ -21,49 +21,58 @@ class JinjaUtilsJsonpathQueryTestCase(unittest2.TestCase): - def test_jsonpath_query_static(self): env = jinja_utils.get_jinja_environment() - obj = {'people': [{'first': 'James', 'last': 'd'}, - {'first': 'Jacob', 'last': 'e'}, - {'first': 'Jayden', 'last': 'f'}, - {'missing': 'different'}], - 'foo': {'bar': 'baz'}} + obj = { + "people": [ + {"first": "James", "last": "d"}, + {"first": "Jacob", "last": "e"}, + {"first": "Jayden", "last": "f"}, + {"missing": "different"}, + ], + "foo": {"bar": "baz"}, + } template = '{{ obj | jsonpath_query("people[*].first") }}' - actual_str = env.from_string(template).render({'obj': obj}) + actual_str = env.from_string(template).render({"obj": obj}) actual = eval(actual_str) - expected = ['James', 'Jacob', 'Jayden'] + expected = ["James", "Jacob", "Jayden"] self.assertEqual(actual, expected) def test_jsonpath_query_dynamic(self): env = jinja_utils.get_jinja_environment() - obj = {'people': [{'first': 'James', 'last': 'd'}, - {'first': 'Jacob', 'last': 'e'}, - {'first': 'Jayden', 'last': 'f'}, - {'missing': 'different'}], - 'foo': {'bar': 'baz'}} + obj = { + "people": [ + {"first": "James", "last": "d"}, + {"first": "Jacob", "last": "e"}, + {"first": "Jayden", "last": "f"}, + {"missing": "different"}, + ], + "foo": {"bar": "baz"}, + } query = "people[*].last" - template = '{{ obj | jsonpath_query(query) }}' - actual_str = env.from_string(template).render({'obj': obj, - 'query': query}) + template = "{{ obj | jsonpath_query(query) }}" + actual_str = env.from_string(template).render({"obj": obj, "query": query}) actual = eval(actual_str) - expected = ['d', 'e', 'f'] + expected = ["d", "e", "f"] self.assertEqual(actual, expected) def test_jsonpath_query_no_results(self): env = jinja_utils.get_jinja_environment() - obj = {'people': [{'first': 'James', 'last': 'd'}, - {'first': 'Jacob', 'last': 'e'}, - {'first': 'Jayden', 'last': 'f'}, - {'missing': 'different'}], - 'foo': {'bar': 'baz'}} + obj = { + "people": [ + {"first": "James", "last": "d"}, + {"first": "Jacob", "last": "e"}, + {"first": "Jayden", "last": "f"}, + {"missing": "different"}, + ], + "foo": {"bar": "baz"}, + } query = "query_returns_no_results" - template = '{{ obj | jsonpath_query(query) }}' - actual_str = env.from_string(template).render({'obj': obj, - 'query': query}) + template = "{{ obj | jsonpath_query(query) }}" + actual_str = env.from_string(template).render({"obj": obj, "query": query}) actual = eval(actual_str) expected = None self.assertEqual(actual, expected) diff --git a/st2common/tests/unit/test_jinja_render_path_filters.py b/st2common/tests/unit/test_jinja_render_path_filters.py index 504b6454bbb..23507bbbc19 100644 --- a/st2common/tests/unit/test_jinja_render_path_filters.py +++ b/st2common/tests/unit/test_jinja_render_path_filters.py @@ -21,29 +21,28 @@ class JinjaUtilsPathFilterTestCase(unittest2.TestCase): - def test_basename(self): env = jinja_utils.get_jinja_environment() - template = '{{k1 | basename}}' - actual = env.from_string(template).render({'k1': '/some/path/to/file.txt'}) - self.assertEqual(actual, 'file.txt') + template = "{{k1 | basename}}" + actual = env.from_string(template).render({"k1": "/some/path/to/file.txt"}) + self.assertEqual(actual, "file.txt") - actual = env.from_string(template).render({'k1': '/some/path/to/dir'}) - self.assertEqual(actual, 'dir') + actual = env.from_string(template).render({"k1": "/some/path/to/dir"}) + self.assertEqual(actual, "dir") - actual = env.from_string(template).render({'k1': '/some/path/to/dir/'}) - self.assertEqual(actual, '') + actual = env.from_string(template).render({"k1": "/some/path/to/dir/"}) + self.assertEqual(actual, "") def test_dirname(self): env = jinja_utils.get_jinja_environment() - template = '{{k1 | dirname}}' - actual = env.from_string(template).render({'k1': '/some/path/to/file.txt'}) - self.assertEqual(actual, '/some/path/to') + template = "{{k1 | dirname}}" + actual = env.from_string(template).render({"k1": "/some/path/to/file.txt"}) + self.assertEqual(actual, "/some/path/to") - actual = env.from_string(template).render({'k1': '/some/path/to/dir'}) - self.assertEqual(actual, '/some/path/to') + actual = env.from_string(template).render({"k1": "/some/path/to/dir"}) + self.assertEqual(actual, "/some/path/to") - actual = env.from_string(template).render({'k1': '/some/path/to/dir/'}) - self.assertEqual(actual, '/some/path/to/dir') + actual = env.from_string(template).render({"k1": "/some/path/to/dir/"}) + self.assertEqual(actual, "/some/path/to/dir") diff --git a/st2common/tests/unit/test_jinja_render_regex_filters.py b/st2common/tests/unit/test_jinja_render_regex_filters.py index 081d0686823..df2e3477799 100644 --- a/st2common/tests/unit/test_jinja_render_regex_filters.py +++ b/st2common/tests/unit/test_jinja_render_regex_filters.py @@ -20,54 +20,53 @@ class JinjaUtilsRegexFilterTestCase(unittest2.TestCase): - def test_filters_regex_match(self): env = jinja_utils.get_jinja_environment() template = '{{k1 | regex_match("x")}}' - actual = env.from_string(template).render({'k1': 'xyz'}) - expected = 'True' + actual = env.from_string(template).render({"k1": "xyz"}) + expected = "True" self.assertEqual(actual, expected) template = '{{k1 | regex_match("y")}}' - actual = env.from_string(template).render({'k1': 'xyz'}) - expected = 'False' + actual = env.from_string(template).render({"k1": "xyz"}) + expected = "False" self.assertEqual(actual, expected) template = '{{k1 | regex_match("^v(\\d+\\.)?(\\d+\\.)?(\\*|\\d+)$")}}' - actual = env.from_string(template).render({'k1': 'v0.10.1'}) - expected = 'True' + actual = env.from_string(template).render({"k1": "v0.10.1"}) + expected = "True" self.assertEqual(actual, expected) def test_filters_regex_replace(self): env = jinja_utils.get_jinja_environment() template = '{{k1 | regex_replace("x", "y")}}' - actual = env.from_string(template).render({'k1': 'xyz'}) - expected = 'yyz' + actual = env.from_string(template).render({"k1": "xyz"}) + expected = "yyz" self.assertEqual(actual, expected) template = '{{k1 | regex_replace("(blue|white|red)", "color")}}' - actual = env.from_string(template).render({'k1': 'blue socks and red shoes'}) - expected = 'color socks and color shoes' + actual = env.from_string(template).render({"k1": "blue socks and red shoes"}) + expected = "color socks and color shoes" self.assertEqual(actual, expected) def test_filters_regex_search(self): env = jinja_utils.get_jinja_environment() template = '{{k1 | regex_search("x")}}' - actual = env.from_string(template).render({'k1': 'xyz'}) - expected = 'True' + actual = env.from_string(template).render({"k1": "xyz"}) + expected = "True" self.assertEqual(actual, expected) template = '{{k1 | regex_search("y")}}' - actual = env.from_string(template).render({'k1': 'xyz'}) - expected = 'True' + actual = env.from_string(template).render({"k1": "xyz"}) + expected = "True" self.assertEqual(actual, expected) template = '{{k1 | regex_search("^v(\\d+\\.)?(\\d+\\.)?(\\*|\\d+)$")}}' - actual = env.from_string(template).render({'k1': 'v0.10.1'}) - expected = 'True' + actual = env.from_string(template).render({"k1": "v0.10.1"}) + expected = "True" self.assertEqual(actual, expected) def test_filters_regex_substring(self): @@ -76,29 +75,31 @@ def test_filters_regex_substring(self): # Normal (match) template = r'{{input_str | regex_substring("([0-9]{3} \w+ (?:Ave|St|Dr))")}}' actual = env.from_string(template).render( - {'input_str': 'My address is 123 Somewhere Ave. See you soon!'} + {"input_str": "My address is 123 Somewhere Ave. See you soon!"} ) - expected = '123 Somewhere Ave' + expected = "123 Somewhere Ave" self.assertEqual(actual, expected) # Selecting second match explicitly template = r'{{input_str | regex_substring("([0-9]{3} \w+ (?:Ave|St|Dr))", 1)}}' actual = env.from_string(template).render( - {'input_str': 'Your address is 567 Elsewhere Dr. My address is 123 Somewhere Ave.'} + { + "input_str": "Your address is 567 Elsewhere Dr. My address is 123 Somewhere Ave." + } ) - expected = '123 Somewhere Ave' + expected = "123 Somewhere Ave" self.assertEqual(actual, expected) # Selecting second match explicitly, but doesn't exist template = r'{{input_str | regex_substring("([0-9]{3} \w+ (?:Ave|St|Dr))", 1)}}' with self.assertRaises(IndexError): actual = env.from_string(template).render( - {'input_str': 'Your address is 567 Elsewhere Dr.'} + {"input_str": "Your address is 567 Elsewhere Dr."} ) # No match template = r'{{input_str | regex_substring("([0-3]{3} \w+ (?:Ave|St|Dr))")}}' with self.assertRaises(IndexError): actual = env.from_string(template).render( - {'input_str': 'My address is 986 Somewhere Ave. See you soon!'} + {"input_str": "My address is 986 Somewhere Ave. See you soon!"} ) diff --git a/st2common/tests/unit/test_jinja_render_time_filters.py b/st2common/tests/unit/test_jinja_render_time_filters.py index 5151cec6951..2cf002a0e3a 100644 --- a/st2common/tests/unit/test_jinja_render_time_filters.py +++ b/st2common/tests/unit/test_jinja_render_time_filters.py @@ -20,16 +20,16 @@ class JinjaUtilsTimeFilterTestCase(unittest2.TestCase): - def test_to_human_time_filter(self): env = jinja_utils.get_jinja_environment() - template = '{{k1 | to_human_time_from_seconds}}' - actual = env.from_string(template).render({'k1': 12345}) - self.assertEqual(actual, '3h25m45s') + template = "{{k1 | to_human_time_from_seconds}}" + actual = env.from_string(template).render({"k1": 12345}) + self.assertEqual(actual, "3h25m45s") - actual = env.from_string(template).render({'k1': 0}) - self.assertEqual(actual, '0s') + actual = env.from_string(template).render({"k1": 0}) + self.assertEqual(actual, "0s") - self.assertRaises(AssertionError, env.from_string(template).render, - {'k1': 'stuff'}) + self.assertRaises( + AssertionError, env.from_string(template).render, {"k1": "stuff"} + ) diff --git a/st2common/tests/unit/test_jinja_render_version_filters.py b/st2common/tests/unit/test_jinja_render_version_filters.py index 9cbacd7dcbe..41b2b236709 100644 --- a/st2common/tests/unit/test_jinja_render_version_filters.py +++ b/st2common/tests/unit/test_jinja_render_version_filters.py @@ -21,134 +21,133 @@ class JinjaUtilsVersionsFilterTestCase(unittest2.TestCase): - def test_version_compare(self): env = jinja_utils.get_jinja_environment() template = '{{version | version_compare("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.9.0'}) - expected = '-1' + actual = env.from_string(template).render({"version": "0.9.0"}) + expected = "-1" self.assertEqual(actual, expected) template = '{{version | version_compare("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = '1' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "1" self.assertEqual(actual, expected) template = '{{version | version_compare("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.0'}) - expected = '0' + actual = env.from_string(template).render({"version": "0.10.0"}) + expected = "0" self.assertEqual(actual, expected) def test_version_more_than(self): env = jinja_utils.get_jinja_environment() template = '{{version | version_more_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.9.0'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.9.0"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_more_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "True" self.assertEqual(actual, expected) template = '{{version | version_more_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.0'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.10.0"}) + expected = "False" self.assertEqual(actual, expected) def test_version_less_than(self): env = jinja_utils.get_jinja_environment() template = '{{version | version_less_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.9.0'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.9.0"}) + expected = "True" self.assertEqual(actual, expected) template = '{{version | version_less_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_less_than("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.0'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.10.0"}) + expected = "False" self.assertEqual(actual, expected) def test_version_equal(self): env = jinja_utils.get_jinja_environment() template = '{{version | version_equal("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.9.0'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.9.0"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_equal("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_equal("0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.0'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.10.0"}) + expected = "True" self.assertEqual(actual, expected) def test_version_match(self): env = jinja_utils.get_jinja_environment() template = '{{version | version_match(">0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "True" self.assertEqual(actual, expected) - actual = env.from_string(template).render({'version': '0.1.1'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.1.1"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_match("<0.10.0")}}' - actual = env.from_string(template).render({'version': '0.1.0'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.1.0"}) + expected = "True" self.assertEqual(actual, expected) - actual = env.from_string(template).render({'version': '1.1.0'}) - expected = 'False' + actual = env.from_string(template).render({"version": "1.1.0"}) + expected = "False" self.assertEqual(actual, expected) template = '{{version | version_match("==0.10.0")}}' - actual = env.from_string(template).render({'version': '0.10.0'}) - expected = 'True' + actual = env.from_string(template).render({"version": "0.10.0"}) + expected = "True" self.assertEqual(actual, expected) - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = 'False' + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "False" self.assertEqual(actual, expected) def test_version_bump_major(self): env = jinja_utils.get_jinja_environment() - template = '{{version | version_bump_major}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = '1.0.0' + template = "{{version | version_bump_major}}" + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "1.0.0" self.assertEqual(actual, expected) def test_version_bump_minor(self): env = jinja_utils.get_jinja_environment() - template = '{{version | version_bump_minor}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = '0.11.0' + template = "{{version | version_bump_minor}}" + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "0.11.0" self.assertEqual(actual, expected) def test_version_bump_patch(self): env = jinja_utils.get_jinja_environment() - template = '{{version | version_bump_patch}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = '0.10.2' + template = "{{version | version_bump_patch}}" + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "0.10.2" self.assertEqual(actual, expected) def test_version_strip_patch(self): env = jinja_utils.get_jinja_environment() - template = '{{version | version_strip_patch}}' - actual = env.from_string(template).render({'version': '0.10.1'}) - expected = '0.10' + template = "{{version | version_strip_patch}}" + actual = env.from_string(template).render({"version": "0.10.1"}) + expected = "0.10" self.assertEqual(actual, expected) diff --git a/st2common/tests/unit/test_json_schema.py b/st2common/tests/unit/test_json_schema.py index 42b94efb703..892e2604ed6 100644 --- a/st2common/tests/unit/test_json_schema.py +++ b/st2common/tests/unit/test_json_schema.py @@ -20,158 +20,127 @@ from st2common.util import schema as util_schema TEST_SCHEMA_1 = { - 'additionalProperties': False, - 'title': 'foo', - 'description': 'Foo.', - 'type': 'object', - 'properties': { - 'arg_required_no_default': { - 'description': 'Foo', - 'required': True, - 'type': 'string' + "additionalProperties": False, + "title": "foo", + "description": "Foo.", + "type": "object", + "properties": { + "arg_required_no_default": { + "description": "Foo", + "required": True, + "type": "string", }, - 'arg_optional_no_type': { - 'description': 'Bar' + "arg_optional_no_type": {"description": "Bar"}, + "arg_optional_multi_type": { + "description": "Mirror mirror", + "type": ["string", "boolean", "number"], }, - 'arg_optional_multi_type': { - 'description': 'Mirror mirror', - 'type': ['string', 'boolean', 'number'] + "arg_optional_multi_type_none": { + "description": "Mirror mirror on the wall", + "type": ["string", "boolean", "number", "null"], }, - 'arg_optional_multi_type_none': { - 'description': 'Mirror mirror on the wall', - 'type': ['string', 'boolean', 'number', 'null'] + "arg_optional_type_array": { + "description": "Who" "s the fairest?", + "type": "array", }, - 'arg_optional_type_array': { - 'description': 'Who''s the fairest?', - 'type': 'array' + "arg_optional_type_object": { + "description": "Who" "s the fairest of them?", + "type": "object", }, - 'arg_optional_type_object': { - 'description': 'Who''s the fairest of them?', - 'type': 'object' + "arg_optional_multi_collection_type": { + "description": "Who" "s the fairest of them all?", + "type": ["array", "object"], }, - 'arg_optional_multi_collection_type': { - 'description': 'Who''s the fairest of them all?', - 'type': ['array', 'object'] - } - } + }, } TEST_SCHEMA_2 = { - 'additionalProperties': False, - 'title': 'foo', - 'description': 'Foo.', - 'type': 'object', - 'properties': { - 'arg_required_default': { - 'default': 'date', - 'description': 'Foo', - 'required': True, - 'type': 'string' + "additionalProperties": False, + "title": "foo", + "description": "Foo.", + "type": "object", + "properties": { + "arg_required_default": { + "default": "date", + "description": "Foo", + "required": True, + "type": "string", } - } + }, } TEST_SCHEMA_3 = { - 'additionalProperties': False, - 'title': 'foo', - 'description': 'Foo.', - 'type': 'object', - 'properties': { - 'arg_optional_default': { - 'default': 'bar', - 'description': 'Foo', - 'type': 'string' + "additionalProperties": False, + "title": "foo", + "description": "Foo.", + "type": "object", + "properties": { + "arg_optional_default": { + "default": "bar", + "description": "Foo", + "type": "string", }, - 'arg_optional_default_none': { - 'default': None, - 'description': 'Foo', - 'type': 'string' + "arg_optional_default_none": { + "default": None, + "description": "Foo", + "type": "string", }, - 'arg_optional_no_default': { - 'description': 'Foo', - 'type': 'string' - } - } + "arg_optional_no_default": {"description": "Foo", "type": "string"}, + }, } TEST_SCHEMA_4 = { - 'additionalProperties': False, - 'title': 'foo', - 'description': 'Foo.', - 'type': 'object', - 'properties': { - 'arg_optional_default': { - 'default': 'bar', - 'description': 'Foo', - 'anyOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "additionalProperties": False, + "title": "foo", + "description": "Foo.", + "type": "object", + "properties": { + "arg_optional_default": { + "default": "bar", + "description": "Foo", + "anyOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_default_none': { - 'default': None, - 'description': 'Foo', - 'anyOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "arg_optional_default_none": { + "default": None, + "description": "Foo", + "anyOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_no_default': { - 'description': 'Foo', - 'anyOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "arg_optional_no_default": { + "description": "Foo", + "anyOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_no_default_anyof_none': { - 'description': 'Foo', - 'anyOf': [ - {'type': 'string'}, - {'type': 'boolean'}, - {'type': 'null'} - ] - } - } + "arg_optional_no_default_anyof_none": { + "description": "Foo", + "anyOf": [{"type": "string"}, {"type": "boolean"}, {"type": "null"}], + }, + }, } TEST_SCHEMA_5 = { - 'additionalProperties': False, - 'title': 'foo', - 'description': 'Foo.', - 'type': 'object', - 'properties': { - 'arg_optional_default': { - 'default': 'bar', - 'description': 'Foo', - 'oneOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "additionalProperties": False, + "title": "foo", + "description": "Foo.", + "type": "object", + "properties": { + "arg_optional_default": { + "default": "bar", + "description": "Foo", + "oneOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_default_none': { - 'default': None, - 'description': 'Foo', - 'oneOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "arg_optional_default_none": { + "default": None, + "description": "Foo", + "oneOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_no_default': { - 'description': 'Foo', - 'oneOf': [ - {'type': 'string'}, - {'type': 'boolean'} - ] + "arg_optional_no_default": { + "description": "Foo", + "oneOf": [{"type": "string"}, {"type": "boolean"}], }, - 'arg_optional_no_default_oneof_none': { - 'description': 'Foo', - 'oneOf': [ - {'type': 'string'}, - {'type': 'boolean'}, - {'type': 'null'} - ] - } - } + "arg_optional_no_default_oneof_none": { + "description": "Foo", + "oneOf": [{"type": "string"}, {"type": "boolean"}, {"type": "null"}], + }, + }, } @@ -181,192 +150,265 @@ def test_use_default_value(self): instance = {} validator = util_schema.get_validator() - expected_msg = '\'arg_required_no_default\' is a required property' - self.assertRaisesRegexp(ValidationError, expected_msg, util_schema.validate, - instance=instance, schema=TEST_SCHEMA_1, cls=validator, - use_default=True) + expected_msg = "'arg_required_no_default' is a required property" + self.assertRaisesRegexp( + ValidationError, + expected_msg, + util_schema.validate, + instance=instance, + schema=TEST_SCHEMA_1, + cls=validator, + use_default=True, + ) # No default, value provided - instance = {'arg_required_no_default': 'foo'} - util_schema.validate(instance=instance, schema=TEST_SCHEMA_1, cls=validator, - use_default=True) + instance = {"arg_required_no_default": "foo"} + util_schema.validate( + instance=instance, schema=TEST_SCHEMA_1, cls=validator, use_default=True + ) # default value provided, no value, should pass instance = {} validator = util_schema.get_validator() - util_schema.validate(instance=instance, schema=TEST_SCHEMA_2, cls=validator, - use_default=True) + util_schema.validate( + instance=instance, schema=TEST_SCHEMA_2, cls=validator, use_default=True + ) # default value provided, value provided, should pass - instance = {'arg_required_default': 'foo'} + instance = {"arg_required_default": "foo"} validator = util_schema.get_validator() - util_schema.validate(instance=instance, schema=TEST_SCHEMA_2, cls=validator, - use_default=True) + util_schema.validate( + instance=instance, schema=TEST_SCHEMA_2, cls=validator, use_default=True + ) def test_allow_default_none(self): # Let validator take care of default validator = util_schema.get_validator() - util_schema.validate(instance=dict(), schema=TEST_SCHEMA_3, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=dict(), + schema=TEST_SCHEMA_3, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_allow_default_explicit_none(self): # Explicitly pass None to arguments instance = { - 'arg_optional_default': None, - 'arg_optional_default_none': None, - 'arg_optional_no_default': None + "arg_optional_default": None, + "arg_optional_default_none": None, + "arg_optional_no_default": None, } validator = util_schema.get_validator() - util_schema.validate(instance=instance, schema=TEST_SCHEMA_3, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=instance, + schema=TEST_SCHEMA_3, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_anyof_type_allow_default_none(self): # Let validator take care of default validator = util_schema.get_validator() - util_schema.validate(instance=dict(), schema=TEST_SCHEMA_4, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=dict(), + schema=TEST_SCHEMA_4, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_anyof_allow_default_explicit_none(self): # Explicitly pass None to arguments instance = { - 'arg_optional_default': None, - 'arg_optional_default_none': None, - 'arg_optional_no_default': None, - 'arg_optional_no_default_anyof_none': None + "arg_optional_default": None, + "arg_optional_default_none": None, + "arg_optional_no_default": None, + "arg_optional_no_default_anyof_none": None, } validator = util_schema.get_validator() - util_schema.validate(instance=instance, schema=TEST_SCHEMA_4, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=instance, + schema=TEST_SCHEMA_4, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_oneof_type_allow_default_none(self): # Let validator take care of default validator = util_schema.get_validator() - util_schema.validate(instance=dict(), schema=TEST_SCHEMA_5, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=dict(), + schema=TEST_SCHEMA_5, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_oneof_allow_default_explicit_none(self): # Explicitly pass None to arguments instance = { - 'arg_optional_default': None, - 'arg_optional_default_none': None, - 'arg_optional_no_default': None, - 'arg_optional_no_default_oneof_none': None + "arg_optional_default": None, + "arg_optional_default_none": None, + "arg_optional_no_default": None, + "arg_optional_no_default_oneof_none": None, } validator = util_schema.get_validator() - util_schema.validate(instance=instance, schema=TEST_SCHEMA_5, cls=validator, - use_default=True, allow_default_none=True) + util_schema.validate( + instance=instance, + schema=TEST_SCHEMA_5, + cls=validator, + use_default=True, + allow_default_none=True, + ) def test_is_property_type_single(self): - typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default'] + typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"] self.assertTrue(util_schema.is_property_type_single(typed_property)) - untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type'] + untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"] self.assertTrue(util_schema.is_property_type_single(untyped_property)) - multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] + multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] self.assertFalse(util_schema.is_property_type_single(multi_typed_property)) - anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default'] + anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_single(anyof_property)) - oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default'] + oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_single(oneof_property)) def test_is_property_type_anyof(self): - anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default'] + anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"] self.assertTrue(util_schema.is_property_type_anyof(anyof_property)) - typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default'] + typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"] self.assertFalse(util_schema.is_property_type_anyof(typed_property)) - untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type'] + untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"] self.assertFalse(util_schema.is_property_type_anyof(untyped_property)) - multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] + multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] self.assertFalse(util_schema.is_property_type_anyof(multi_typed_property)) - oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default'] + oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_anyof(oneof_property)) def test_is_property_type_oneof(self): - oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default'] + oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"] self.assertTrue(util_schema.is_property_type_oneof(oneof_property)) - typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default'] + typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"] self.assertFalse(util_schema.is_property_type_oneof(typed_property)) - untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type'] + untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"] self.assertFalse(util_schema.is_property_type_oneof(untyped_property)) - multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] + multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] self.assertFalse(util_schema.is_property_type_oneof(multi_typed_property)) - anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default'] + anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_oneof(anyof_property)) def test_is_property_type_list(self): - multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] + multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] self.assertTrue(util_schema.is_property_type_list(multi_typed_property)) - typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default'] + typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"] self.assertFalse(util_schema.is_property_type_list(typed_property)) - untyped_property = TEST_SCHEMA_1['properties']['arg_optional_no_type'] + untyped_property = TEST_SCHEMA_1["properties"]["arg_optional_no_type"] self.assertFalse(util_schema.is_property_type_list(untyped_property)) - anyof_property = TEST_SCHEMA_4['properties']['arg_optional_default'] + anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_list(anyof_property)) - oneof_property = TEST_SCHEMA_5['properties']['arg_optional_default'] + oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_default"] self.assertFalse(util_schema.is_property_type_list(oneof_property)) def test_is_property_nullable(self): - multi_typed_prop_nullable = TEST_SCHEMA_1['properties']['arg_optional_multi_type_none'] - self.assertTrue(util_schema.is_property_nullable(multi_typed_prop_nullable.get('type'))) - - anyof_property_nullable = TEST_SCHEMA_4['properties']['arg_optional_no_default_anyof_none'] - self.assertTrue(util_schema.is_property_nullable(anyof_property_nullable.get('anyOf'))) - - oneof_property_nullable = TEST_SCHEMA_5['properties']['arg_optional_no_default_oneof_none'] - self.assertTrue(util_schema.is_property_nullable(oneof_property_nullable.get('oneOf'))) - - typed_property = TEST_SCHEMA_1['properties']['arg_required_no_default'] + multi_typed_prop_nullable = TEST_SCHEMA_1["properties"][ + "arg_optional_multi_type_none" + ] + self.assertTrue( + util_schema.is_property_nullable(multi_typed_prop_nullable.get("type")) + ) + + anyof_property_nullable = TEST_SCHEMA_4["properties"][ + "arg_optional_no_default_anyof_none" + ] + self.assertTrue( + util_schema.is_property_nullable(anyof_property_nullable.get("anyOf")) + ) + + oneof_property_nullable = TEST_SCHEMA_5["properties"][ + "arg_optional_no_default_oneof_none" + ] + self.assertTrue( + util_schema.is_property_nullable(oneof_property_nullable.get("oneOf")) + ) + + typed_property = TEST_SCHEMA_1["properties"]["arg_required_no_default"] self.assertFalse(util_schema.is_property_nullable(typed_property)) - multi_typed_property = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] - self.assertFalse(util_schema.is_property_nullable(multi_typed_property.get('type'))) + multi_typed_property = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] + self.assertFalse( + util_schema.is_property_nullable(multi_typed_property.get("type")) + ) - anyof_property = TEST_SCHEMA_4['properties']['arg_optional_no_default'] - self.assertFalse(util_schema.is_property_nullable(anyof_property.get('anyOf'))) + anyof_property = TEST_SCHEMA_4["properties"]["arg_optional_no_default"] + self.assertFalse(util_schema.is_property_nullable(anyof_property.get("anyOf"))) - oneof_property = TEST_SCHEMA_5['properties']['arg_optional_no_default'] - self.assertFalse(util_schema.is_property_nullable(oneof_property.get('oneOf'))) + oneof_property = TEST_SCHEMA_5["properties"]["arg_optional_no_default"] + self.assertFalse(util_schema.is_property_nullable(oneof_property.get("oneOf"))) def test_is_attribute_type_array(self): - multi_coll_typed_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_collection_type'] - self.assertTrue(util_schema.is_attribute_type_array(multi_coll_typed_prop.get('type'))) - - array_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_array'] - self.assertTrue(util_schema.is_attribute_type_array(array_type_property.get('type'))) - - multi_non_coll_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] - self.assertFalse(util_schema.is_attribute_type_array(multi_non_coll_prop.get('type'))) - - object_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_object'] - self.assertFalse(util_schema.is_attribute_type_array(object_type_property.get('type'))) + multi_coll_typed_prop = TEST_SCHEMA_1["properties"][ + "arg_optional_multi_collection_type" + ] + self.assertTrue( + util_schema.is_attribute_type_array(multi_coll_typed_prop.get("type")) + ) + + array_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_array"] + self.assertTrue( + util_schema.is_attribute_type_array(array_type_property.get("type")) + ) + + multi_non_coll_prop = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] + self.assertFalse( + util_schema.is_attribute_type_array(multi_non_coll_prop.get("type")) + ) + + object_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_object"] + self.assertFalse( + util_schema.is_attribute_type_array(object_type_property.get("type")) + ) def test_is_attribute_type_object(self): - multi_coll_typed_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_collection_type'] - self.assertTrue(util_schema.is_attribute_type_object(multi_coll_typed_prop.get('type'))) - - object_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_object'] - self.assertTrue(util_schema.is_attribute_type_object(object_type_property.get('type'))) - - multi_non_coll_prop = TEST_SCHEMA_1['properties']['arg_optional_multi_type'] - self.assertFalse(util_schema.is_attribute_type_object(multi_non_coll_prop.get('type'))) - - array_type_property = TEST_SCHEMA_1['properties']['arg_optional_type_array'] - self.assertFalse(util_schema.is_attribute_type_object(array_type_property.get('type'))) + multi_coll_typed_prop = TEST_SCHEMA_1["properties"][ + "arg_optional_multi_collection_type" + ] + self.assertTrue( + util_schema.is_attribute_type_object(multi_coll_typed_prop.get("type")) + ) + + object_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_object"] + self.assertTrue( + util_schema.is_attribute_type_object(object_type_property.get("type")) + ) + + multi_non_coll_prop = TEST_SCHEMA_1["properties"]["arg_optional_multi_type"] + self.assertFalse( + util_schema.is_attribute_type_object(multi_non_coll_prop.get("type")) + ) + + array_type_property = TEST_SCHEMA_1["properties"]["arg_optional_type_array"] + self.assertFalse( + util_schema.is_attribute_type_object(array_type_property.get("type")) + ) diff --git a/st2common/tests/unit/test_jsonify.py b/st2common/tests/unit/test_jsonify.py index 801d9123334..1feaac96b05 100644 --- a/st2common/tests/unit/test_jsonify.py +++ b/st2common/tests/unit/test_jsonify.py @@ -20,33 +20,32 @@ class JsonifyTests(unittest2.TestCase): - def test_none_object(self): obj = None self.assertIsNone(jsonify.json_loads(obj)) def test_no_keys(self): - obj = {'foo': '{"bar": "baz"}'} + obj = {"foo": '{"bar": "baz"}'} transformed_obj = jsonify.json_loads(obj) - self.assertTrue(transformed_obj['foo']['bar'] == 'baz') + self.assertTrue(transformed_obj["foo"]["bar"] == "baz") def test_no_json_value(self): - obj = {'foo': 'bar'} + obj = {"foo": "bar"} transformed_obj = jsonify.json_loads(obj) - self.assertTrue(transformed_obj['foo'] == 'bar') + self.assertTrue(transformed_obj["foo"] == "bar") def test_happy_case(self): - obj = {'foo': '{"bar": "baz"}', 'yo': 'bibimbao'} - transformed_obj = jsonify.json_loads(obj, ['yo']) - self.assertTrue(transformed_obj['yo'] == 'bibimbao') + obj = {"foo": '{"bar": "baz"}', "yo": "bibimbao"} + transformed_obj = jsonify.json_loads(obj, ["yo"]) + self.assertTrue(transformed_obj["yo"] == "bibimbao") def test_try_loads(self): # The function json.loads will fail and the function should return the original value. - values = ['abc', 123, True, object()] + values = ["abc", 123, True, object()] for value in values: self.assertEqual(jsonify.try_loads(value), value) # The function json.loads succeed. d = '{"a": 1, "b": true}' - expected = {'a': 1, 'b': True} + expected = {"a": 1, "b": True} self.assertDictEqual(jsonify.try_loads(d), expected) diff --git a/st2common/tests/unit/test_keyvalue_lookup.py b/st2common/tests/unit/test_keyvalue_lookup.py index f37cc04dc98..afcd76901a6 100644 --- a/st2common/tests/unit/test_keyvalue_lookup.py +++ b/st2common/tests/unit/test_keyvalue_lookup.py @@ -24,23 +24,29 @@ class TestKeyValueLookup(CleanDbTestCase): def test_lookup_with_key_prefix(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='some:prefix:stanley:k5', value='v5', - scope=FULL_USER_SCOPE)) + KeyValuePair.add_or_update( + KeyValuePairDB( + name="some:prefix:stanley:k5", value="v5", scope=FULL_USER_SCOPE + ) + ) # No prefix provided, should return None - lookup = UserKeyValueLookup(user='stanley', scope=FULL_USER_SCOPE) - self.assertEqual(str(lookup.k5), '') + lookup = UserKeyValueLookup(user="stanley", scope=FULL_USER_SCOPE) + self.assertEqual(str(lookup.k5), "") # Prefix provided - lookup = UserKeyValueLookup(prefix='some:prefix', user='stanley', scope=FULL_USER_SCOPE) - self.assertEqual(str(lookup.k5), 'v5') + lookup = UserKeyValueLookup( + prefix="some:prefix", user="stanley", scope=FULL_USER_SCOPE + ) + self.assertEqual(str(lookup.k5), "v5") def test_non_hierarchical_lookup(self): - k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='k1', value='v1')) - k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2')) - k3 = KeyValuePair.add_or_update(KeyValuePairDB(name='k3', value='v3')) - k4 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:k4', value='v4', - scope=FULL_USER_SCOPE)) + k1 = KeyValuePair.add_or_update(KeyValuePairDB(name="k1", value="v1")) + k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2")) + k3 = KeyValuePair.add_or_update(KeyValuePairDB(name="k3", value="v3")) + k4 = KeyValuePair.add_or_update( + KeyValuePairDB(name="stanley:k4", value="v4", scope=FULL_USER_SCOPE) + ) lookup = KeyValueLookup() self.assertEqual(str(lookup.k1), k1.value) @@ -49,108 +55,119 @@ def test_non_hierarchical_lookup(self): # Scoped lookup lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE) - self.assertEqual(str(lookup.k4), '') - user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') + self.assertEqual(str(lookup.k4), "") + user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") self.assertEqual(str(user_lookup.k4), k4.value) def test_hierarchical_lookup_dotted(self): - k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b', value='v1')) - k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b.c', value='v2')) - k3 = KeyValuePair.add_or_update(KeyValuePairDB(name='b.c', value='v3')) - k4 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r.i.p', value='v4', - scope=FULL_USER_SCOPE)) + k1 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b", value="v1")) + k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b.c", value="v2")) + k3 = KeyValuePair.add_or_update(KeyValuePairDB(name="b.c", value="v3")) + k4 = KeyValuePair.add_or_update( + KeyValuePairDB(name="stanley:r.i.p", value="v4", scope=FULL_USER_SCOPE) + ) lookup = KeyValueLookup() self.assertEqual(str(lookup.a.b), k1.value) self.assertEqual(str(lookup.a.b.c), k2.value) self.assertEqual(str(lookup.b.c), k3.value) - self.assertEqual(str(lookup.a), '') + self.assertEqual(str(lookup.a), "") # Scoped lookup lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE) - self.assertEqual(str(lookup.r.i.p), '') - user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') + self.assertEqual(str(lookup.r.i.p), "") + user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") self.assertEqual(str(user_lookup.r.i.p), k4.value) def test_hierarchical_lookup_dict(self): - k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b', value='v1')) - k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b.c', value='v2')) - k3 = KeyValuePair.add_or_update(KeyValuePairDB(name='b.c', value='v3')) - k4 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r.i.p', value='v4', - scope=FULL_USER_SCOPE)) + k1 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b", value="v1")) + k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="a.b.c", value="v2")) + k3 = KeyValuePair.add_or_update(KeyValuePairDB(name="b.c", value="v3")) + k4 = KeyValuePair.add_or_update( + KeyValuePairDB(name="stanley:r.i.p", value="v4", scope=FULL_USER_SCOPE) + ) lookup = KeyValueLookup() - self.assertEqual(str(lookup['a']['b']), k1.value) - self.assertEqual(str(lookup['a']['b']['c']), k2.value) - self.assertEqual(str(lookup['b']['c']), k3.value) - self.assertEqual(str(lookup['a']), '') + self.assertEqual(str(lookup["a"]["b"]), k1.value) + self.assertEqual(str(lookup["a"]["b"]["c"]), k2.value) + self.assertEqual(str(lookup["b"]["c"]), k3.value) + self.assertEqual(str(lookup["a"]), "") # Scoped lookup lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE) - self.assertEqual(str(lookup['r']['i']['p']), '') - user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') - self.assertEqual(str(user_lookup['r']['i']['p']), k4.value) + self.assertEqual(str(lookup["r"]["i"]["p"]), "") + user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") + self.assertEqual(str(user_lookup["r"]["i"]["p"]), k4.value) def test_lookups_older_scope_names_backward_compatibility(self): - k1 = KeyValuePair.add_or_update(KeyValuePairDB(name='a.b', value='v1', - scope=FULL_SYSTEM_SCOPE)) + k1 = KeyValuePair.add_or_update( + KeyValuePairDB(name="a.b", value="v1", scope=FULL_SYSTEM_SCOPE) + ) lookup = KeyValueLookup(scope=SYSTEM_SCOPE) - self.assertEqual(str(lookup['a']['b']), k1.value) + self.assertEqual(str(lookup["a"]["b"]), k1.value) - k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r.i.p', value='v4', - scope=FULL_USER_SCOPE)) - user_lookup = UserKeyValueLookup(scope=USER_SCOPE, user='stanley') - self.assertEqual(str(user_lookup['r']['i']['p']), k2.value) + k2 = KeyValuePair.add_or_update( + KeyValuePairDB(name="stanley:r.i.p", value="v4", scope=FULL_USER_SCOPE) + ) + user_lookup = UserKeyValueLookup(scope=USER_SCOPE, user="stanley") + self.assertEqual(str(user_lookup["r"]["i"]["p"]), k2.value) def test_user_scope_lookups_dot_in_user(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='first.last:r.i.p', value='v4', - scope=FULL_USER_SCOPE)) - lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='first.last') - self.assertEqual(str(lookup.r.i.p), 'v4') - self.assertEqual(str(lookup['r']['i']['p']), 'v4') + KeyValuePair.add_or_update( + KeyValuePairDB(name="first.last:r.i.p", value="v4", scope=FULL_USER_SCOPE) + ) + lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="first.last") + self.assertEqual(str(lookup.r.i.p), "v4") + self.assertEqual(str(lookup["r"]["i"]["p"]), "v4") def test_user_scope_lookups_user_sep_in_name(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='stanley:r:i:p', value='v4', - scope=FULL_USER_SCOPE)) - lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') + KeyValuePair.add_or_update( + KeyValuePairDB(name="stanley:r:i:p", value="v4", scope=FULL_USER_SCOPE) + ) + lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") # This is the only way to lookup because USER_SEPARATOR (':') cannot be a part of # variable name in Python. - self.assertEqual(str(lookup['r:i:p']), 'v4') + self.assertEqual(str(lookup["r:i:p"]), "v4") def test_missing_key_lookup(self): lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE) - self.assertEqual(str(lookup.missing_key), '') - self.assertTrue(lookup.missing_key, 'Should be not none.') + self.assertEqual(str(lookup.missing_key), "") + self.assertTrue(lookup.missing_key, "Should be not none.") - user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') - self.assertEqual(str(user_lookup.missing_key), '') - self.assertTrue(user_lookup.missing_key, 'Should be not none.') + user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") + self.assertEqual(str(user_lookup.missing_key), "") + self.assertTrue(user_lookup.missing_key, "Should be not none.") def test_secret_lookup(self): - secret_value = '0055A2D9A09E1071931925933744965EEA7E23DCF59A8D1D7A3' + \ - '64338294916D37E83C4796283C584751750E39844E2FD97A3727DB5D553F638' - k1 = KeyValuePair.add_or_update(KeyValuePairDB( - name='k1', value=secret_value, - secret=True) + secret_value = ( + "0055A2D9A09E1071931925933744965EEA7E23DCF59A8D1D7A3" + + "64338294916D37E83C4796283C584751750E39844E2FD97A3727DB5D553F638" + ) + k1 = KeyValuePair.add_or_update( + KeyValuePairDB(name="k1", value=secret_value, secret=True) ) - k2 = KeyValuePair.add_or_update(KeyValuePairDB(name='k2', value='v2')) - k3 = KeyValuePair.add_or_update(KeyValuePairDB( - name='stanley:k3', value=secret_value, scope=FULL_USER_SCOPE, - secret=True) + k2 = KeyValuePair.add_or_update(KeyValuePairDB(name="k2", value="v2")) + k3 = KeyValuePair.add_or_update( + KeyValuePairDB( + name="stanley:k3", + value=secret_value, + scope=FULL_USER_SCOPE, + secret=True, + ) ) lookup = KeyValueLookup() self.assertEqual(str(lookup.k1), k1.value) self.assertEqual(str(lookup.k2), k2.value) - self.assertEqual(str(lookup.k3), '') + self.assertEqual(str(lookup.k3), "") - user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user='stanley') + user_lookup = UserKeyValueLookup(scope=FULL_USER_SCOPE, user="stanley") self.assertEqual(str(user_lookup.k3), k3.value) def test_lookup_cast(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='count', value='5.5')) + KeyValuePair.add_or_update(KeyValuePairDB(name="count", value="5.5")) lookup = KeyValueLookup(scope=FULL_SYSTEM_SCOPE) - self.assertEqual(str(lookup.count), '5.5') + self.assertEqual(str(lookup.count), "5.5") self.assertEqual(float(lookup.count), 5.5) self.assertEqual(int(lookup.count), 5) diff --git a/st2common/tests/unit/test_keyvalue_system_model.py b/st2common/tests/unit/test_keyvalue_system_model.py index ff834f2d6de..a8ea10822b3 100644 --- a/st2common/tests/unit/test_keyvalue_system_model.py +++ b/st2common/tests/unit/test_keyvalue_system_model.py @@ -21,15 +21,19 @@ class UserKeyReferenceSystemModelTest(unittest2.TestCase): - def test_to_string_reference(self): - key_ref = UserKeyReference.to_string_reference(user='stanley', name='foo') - self.assertEqual(key_ref, 'stanley:foo') - self.assertRaises(ValueError, UserKeyReference.to_string_reference, user=None, name='foo') + key_ref = UserKeyReference.to_string_reference(user="stanley", name="foo") + self.assertEqual(key_ref, "stanley:foo") + self.assertRaises( + ValueError, UserKeyReference.to_string_reference, user=None, name="foo" + ) def test_from_string_reference(self): - user, name = UserKeyReference.from_string_reference('stanley:foo') - self.assertEqual(user, 'stanley') - self.assertEqual(name, 'foo') - self.assertRaises(InvalidUserKeyReferenceError, UserKeyReference.from_string_reference, - 'this_key_has_no_sep') + user, name = UserKeyReference.from_string_reference("stanley:foo") + self.assertEqual(user, "stanley") + self.assertEqual(name, "foo") + self.assertRaises( + InvalidUserKeyReferenceError, + UserKeyReference.from_string_reference, + "this_key_has_no_sep", + ) diff --git a/st2common/tests/unit/test_logger.py b/st2common/tests/unit/test_logger.py index 30d18e9f893..79158b8b7f3 100644 --- a/st2common/tests/unit/test_logger.py +++ b/st2common/tests/unit/test_logger.py @@ -36,13 +36,13 @@ import st2tests.config as tests_config CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) -RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources')) -CONFIG_FILE_PATH = os.path.join(RESOURCES_DIR, 'logging.conf') +RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources")) +CONFIG_FILE_PATH = os.path.join(RESOURCES_DIR, "logging.conf") MOCK_MASKED_ATTRIBUTES_BLACKLIST = [ - 'blacklisted_1', - 'blacklisted_2', - 'blacklisted_3', + "blacklisted_1", + "blacklisted_2", + "blacklisted_3", ] @@ -69,9 +69,8 @@ def setUp(self): self.cfg_fd, self.cfg_path = tempfile.mkstemp() self.info_log_fd, self.info_log_path = tempfile.mkstemp() self.audit_log_fd, self.audit_log_path = tempfile.mkstemp() - with open(self.cfg_path, 'a') as f: - f.write(self.config_text.format(self.info_log_path, - self.audit_log_path)) + with open(self.cfg_path, "a") as f: + f.write(self.config_text.format(self.info_log_path, self.audit_log_path)) def tearDown(self): self._remove_tempfile(self.cfg_fd, self.cfg_path) @@ -84,7 +83,7 @@ def _remove_tempfile(self, fd, path): os.unlink(path) def test_logger_setup_failure(self): - config_file = '/tmp/abc123' + config_file = "/tmp/abc123" self.assertFalse(os.path.exists(config_file)) self.assertRaises(Exception, logging.setup, config_file) @@ -146,7 +145,7 @@ def test_format(self): formatter = ConsoleLogFormatter() # No extra attributes - mock_message = 'test message 1' + mock_message = "test message 1" record = MockRecord() record.msg = mock_message @@ -155,94 +154,109 @@ def test_format(self): self.assertEqual(message, mock_message) # Some extra attributes - mock_message = 'test message 2' + mock_message = "test message 2" record = MockRecord() record.msg = mock_message # Add "extra" attributes record._user_id = 1 - record._value = 'bar' - record.ignored = 'foo' # this one is ignored since it doesnt have a prefix + record._value = "bar" + record.ignored = "foo" # this one is ignored since it doesnt have a prefix message = formatter.format(record=record) - expected = 'test message 2 (value=\'bar\',user_id=1)' + expected = "test message 2 (value='bar',user_id=1)" self.assertEqual(sorted(message), sorted(expected)) - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) def test_format_blacklisted_attributes_are_masked(self): formatter = ConsoleLogFormatter() - mock_message = 'test message 1' + mock_message = "test message 1" record = MockRecord() record.msg = mock_message # Add "extra" attributes - record._blacklisted_1 = 'test value 1' - record._blacklisted_2 = 'test value 2' - record._blacklisted_3 = {'key1': 'val1', 'blacklisted_1': 'val2', 'key3': 'val3'} - record._foo1 = 'bar' + record._blacklisted_1 = "test value 1" + record._blacklisted_2 = "test value 2" + record._blacklisted_3 = { + "key1": "val1", + "blacklisted_1": "val2", + "key3": "val3", + } + record._foo1 = "bar" message = formatter.format(record=record) - expected = ("test message 1 (blacklisted_1='********',blacklisted_2='********'," - "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'}," - "foo1='bar')") + expected = ( + "test message 1 (blacklisted_1='********',blacklisted_2='********'," + "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'}," + "foo1='bar')" + ) self.assertEqual(sorted(message), sorted(expected)) - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) def test_format_custom_blacklist_attributes_are_masked(self): - cfg.CONF.set_override(group='log', name='mask_secrets_blacklist', - override=['blacklisted_4', 'blacklisted_5']) + cfg.CONF.set_override( + group="log", + name="mask_secrets_blacklist", + override=["blacklisted_4", "blacklisted_5"], + ) formatter = ConsoleLogFormatter() - mock_message = 'test message 1' + mock_message = "test message 1" record = MockRecord() record.msg = mock_message # Add "extra" attributes - record._blacklisted_1 = 'test value 1' - record._blacklisted_2 = 'test value 2' - record._blacklisted_3 = {'key1': 'val1', 'blacklisted_1': 'val2', 'key3': 'val3'} - record._blacklisted_4 = 'fowa' - record._blacklisted_5 = 'fiva' - record._foo1 = 'bar' + record._blacklisted_1 = "test value 1" + record._blacklisted_2 = "test value 2" + record._blacklisted_3 = { + "key1": "val1", + "blacklisted_1": "val2", + "key3": "val3", + } + record._blacklisted_4 = "fowa" + record._blacklisted_5 = "fiva" + record._foo1 = "bar" message = formatter.format(record=record) - expected = ("test message 1 (foo1='bar',blacklisted_1='********',blacklisted_2='********'," - "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'}," - "blacklisted_4='********',blacklisted_5='********')") + expected = ( + "test message 1 (foo1='bar',blacklisted_1='********',blacklisted_2='********'," + "blacklisted_3={'key3': 'val3', 'key1': 'val1', 'blacklisted_1': '********'}," + "blacklisted_4='********',blacklisted_5='********')" + ) self.assertEqual(sorted(message), sorted(expected)) - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) def test_format_secret_action_parameters_are_masked(self): formatter = ConsoleLogFormatter() - mock_message = 'test message 1' + mock_message = "test message 1" parameters = { - 'parameter1': { - 'type': 'string', - 'required': False - }, - 'parameter2': { - 'type': 'string', - 'required': False, - 'secret': True - } + "parameter1": {"type": "string", "required": False}, + "parameter2": {"type": "string", "required": False, "secret": True}, } - mock_action_db = ActionDB(pack='testpack', name='test.action', parameters=parameters) + mock_action_db = ActionDB( + pack="testpack", name="test.action", parameters=parameters + ) action = mock_action_db.to_serializable_dict() - parameters = { - 'parameter1': 'value1', - 'parameter2': 'value2' - } - mock_action_execution_db = ActionExecutionDB(action=action, parameters=parameters) + parameters = {"parameter1": "value1", "parameter2": "value2"} + mock_action_execution_db = ActionExecutionDB( + action=action, parameters=parameters + ) record = MockRecord() record.msg = mock_message @@ -250,97 +264,94 @@ def test_format_secret_action_parameters_are_masked(self): # Add "extra" attributes record._action_execution_db = mock_action_execution_db - expected_msg_part = (r"'parameters': {u?'parameter1': u?'value1', " - r"u?'parameter2': u?'\*\*\*\*\*\*\*\*'}") + expected_msg_part = ( + r"'parameters': {u?'parameter1': u?'value1', " + r"u?'parameter2': u?'\*\*\*\*\*\*\*\*'}" + ) message = formatter.format(record=record) - self.assertIn('test message 1', message) + self.assertIn("test message 1", message) self.assertRegexpMatches(message, expected_msg_part) - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) def test_format_rule(self): expected_result = { - 'description': 'Test description', - 'tags': [], - 'type': { - 'ref': 'standard', - 'parameters': {}}, - 'enabled': True, - 'trigger': 'test tigger', - 'metadata_file': None, - 'context': {}, - 'criteria': {}, - 'action': { - 'ref': '1234', - 'parameters': {'b': 2}}, - 'uid': 'rule:testpack:test.action', - 'pack': 'testpack', - 'ref': 'testpack.test.action', - 'id': None, - 'name': 'test.action' + "description": "Test description", + "tags": [], + "type": {"ref": "standard", "parameters": {}}, + "enabled": True, + "trigger": "test tigger", + "metadata_file": None, + "context": {}, + "criteria": {}, + "action": {"ref": "1234", "parameters": {"b": 2}}, + "uid": "rule:testpack:test.action", + "pack": "testpack", + "ref": "testpack.test.action", + "id": None, + "name": "test.action", } - mock_rule_db = RuleDB(pack='testpack', - name='test.action', - description='Test description', - trigger='test tigger', - action={'ref': '1234', 'parameters': {'b': 2}}) + mock_rule_db = RuleDB( + pack="testpack", + name="test.action", + description="Test description", + trigger="test tigger", + action={"ref": "1234", "parameters": {"b": 2}}, + ) result = mock_rule_db.to_serializable_dict() self.assertEqual(expected_result, result) - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) - @mock.patch('st2common.models.db.rule.RuleDB._get_referenced_action_model') - def test_format_secret_rule_parameters_are_masked(self, mock__get_referenced_action_model): + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) + @mock.patch("st2common.models.db.rule.RuleDB._get_referenced_action_model") + def test_format_secret_rule_parameters_are_masked( + self, mock__get_referenced_action_model + ): expected_result = { - 'description': 'Test description', - 'tags': [], - 'type': { - 'ref': 'standard', - 'parameters': {}}, - 'enabled': True, - 'trigger': 'test tigger', - 'metadata_file': None, - 'context': {}, - 'criteria': {}, - 'action': { - 'ref': '1234', - 'parameters': { - 'parameter1': 'value1', - 'parameter2': '********' - }}, - 'uid': 'rule:testpack:test.action', - 'pack': 'testpack', - 'ref': 'testpack.test.action', - 'id': None, - 'name': 'test.action' + "description": "Test description", + "tags": [], + "type": {"ref": "standard", "parameters": {}}, + "enabled": True, + "trigger": "test tigger", + "metadata_file": None, + "context": {}, + "criteria": {}, + "action": { + "ref": "1234", + "parameters": {"parameter1": "value1", "parameter2": "********"}, + }, + "uid": "rule:testpack:test.action", + "pack": "testpack", + "ref": "testpack.test.action", + "id": None, + "name": "test.action", } parameters = { - 'parameter1': { - 'type': 'string', - 'required': False - }, - 'parameter2': { - 'type': 'string', - 'required': False, - 'secret': True - } + "parameter1": {"type": "string", "required": False}, + "parameter2": {"type": "string", "required": False, "secret": True}, } - mock_action_db = ActionDB(pack='testpack', name='test.action', parameters=parameters) + mock_action_db = ActionDB( + pack="testpack", name="test.action", parameters=parameters + ) mock__get_referenced_action_model.return_value = mock_action_db - cfg.CONF.set_override(group='log', name='mask_secrets', - override=True) - mock_rule_db = RuleDB(pack='testpack', - name='test.action', - description='Test description', - trigger='test tigger', - action={'ref': '1234', - 'parameters': { - 'parameter1': 'value1', - 'parameter2': 'value2' - }}) + cfg.CONF.set_override(group="log", name="mask_secrets", override=True) + mock_rule_db = RuleDB( + pack="testpack", + name="test.action", + description="Test description", + trigger="test tigger", + action={ + "ref": "1234", + "parameters": {"parameter1": "value1", "parameter2": "value2"}, + }, + ) result = mock_rule_db.to_serializable_dict(True) @@ -355,11 +366,18 @@ def setUpClass(cls): def test_format(self): formatter = GelfLogFormatter() - expected_keys = ['version', 'host', 'short_message', 'full_message', - 'timestamp', 'timestamp_f', 'level'] + expected_keys = [ + "version", + "host", + "short_message", + "full_message", + "timestamp", + "timestamp_f", + "level", + ] # No extra attributes - mock_message = 'test message 1' + mock_message = "test message 1" record = MockRecord() record.msg = mock_message @@ -370,19 +388,19 @@ def test_format(self): for key in expected_keys: self.assertIn(key, parsed) - self.assertEqual(parsed['short_message'], mock_message) - self.assertEqual(parsed['full_message'], mock_message) + self.assertEqual(parsed["short_message"], mock_message) + self.assertEqual(parsed["full_message"], mock_message) # Some extra attributes - mock_message = 'test message 2' + mock_message = "test message 2" record = MockRecord() record.msg = mock_message # Add "extra" attributes record._user_id = 1 - record._value = 'bar' - record.ignored = 'foo' # this one is ignored since it doesnt have a prefix + record._value = "bar" + record.ignored = "foo" # this one is ignored since it doesnt have a prefix record.created = 1234.5678 message = formatter.format(record=record) @@ -391,16 +409,16 @@ def test_format(self): for key in expected_keys: self.assertIn(key, parsed) - self.assertEqual(parsed['short_message'], mock_message) - self.assertEqual(parsed['full_message'], mock_message) - self.assertEqual(parsed['_user_id'], 1) - self.assertEqual(parsed['_value'], 'bar') - self.assertEqual(parsed['timestamp'], 1234) - self.assertEqual(parsed['timestamp_f'], 1234.5678) - self.assertNotIn('ignored', parsed) + self.assertEqual(parsed["short_message"], mock_message) + self.assertEqual(parsed["full_message"], mock_message) + self.assertEqual(parsed["_user_id"], 1) + self.assertEqual(parsed["_value"], "bar") + self.assertEqual(parsed["timestamp"], 1234) + self.assertEqual(parsed["timestamp_f"], 1234.5678) + self.assertNotIn("ignored", parsed) # Record with an exception - mock_exception = Exception('mock exception bar') + mock_exception = Exception("mock exception bar") try: raise mock_exception @@ -408,7 +426,7 @@ def test_format(self): mock_exc_info = sys.exc_info() # Some extra attributes - mock_message = 'test message 3' + mock_message = "test message 3" record = MockRecord() record.msg = mock_message @@ -420,69 +438,77 @@ def test_format(self): for key in expected_keys: self.assertIn(key, parsed) - self.assertEqual(parsed['short_message'], mock_message) - self.assertIn(mock_message, parsed['full_message']) - self.assertIn('Traceback', parsed['full_message']) - self.assertIn('_exception', parsed) - self.assertIn('_traceback', parsed) + self.assertEqual(parsed["short_message"], mock_message) + self.assertIn(mock_message, parsed["full_message"]) + self.assertIn("Traceback", parsed["full_message"]) + self.assertIn("_exception", parsed) + self.assertIn("_traceback", parsed) def test_extra_object_serialization(self): class MyClass1(object): def __repr__(self): - return 'repr' + return "repr" class MyClass2(object): def to_dict(self): - return 'to_dict' + return "to_dict" class MyClass3(object): def to_serializable_dict(self, mask_secrets=False): - return 'to_serializable_dict' + return "to_serializable_dict" formatter = GelfLogFormatter() record = MockRecord() - record.msg = 'message' + record.msg = "message" record._obj1 = MyClass1() record._obj2 = MyClass2() record._obj3 = MyClass3() message = formatter.format(record=record) parsed = json.loads(message) - self.assertEqual(parsed['_obj1'], 'repr') - self.assertEqual(parsed['_obj2'], 'to_dict') - self.assertEqual(parsed['_obj3'], 'to_serializable_dict') - - @mock.patch('st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST', - MOCK_MASKED_ATTRIBUTES_BLACKLIST) + self.assertEqual(parsed["_obj1"], "repr") + self.assertEqual(parsed["_obj2"], "to_dict") + self.assertEqual(parsed["_obj3"], "to_serializable_dict") + + @mock.patch( + "st2common.logging.formatters.MASKED_ATTRIBUTES_BLACKLIST", + MOCK_MASKED_ATTRIBUTES_BLACKLIST, + ) def test_format_blacklisted_attributes_are_masked(self): formatter = GelfLogFormatter() # Some extra attributes - mock_message = 'test message 1' + mock_message = "test message 1" record = MockRecord() record.msg = mock_message # Add "extra" attributes - record._blacklisted_1 = 'test value 1' - record._blacklisted_2 = 'test value 2' - record._blacklisted_3 = {'key1': 'val1', 'blacklisted_1': 'val2', 'key3': 'val3'} - record._foo1 = 'bar' + record._blacklisted_1 = "test value 1" + record._blacklisted_2 = "test value 2" + record._blacklisted_3 = { + "key1": "val1", + "blacklisted_1": "val2", + "key3": "val3", + } + record._foo1 = "bar" message = formatter.format(record=record) parsed = json.loads(message) - self.assertEqual(parsed['_blacklisted_1'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(parsed['_blacklisted_2'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(parsed['_blacklisted_3']['key1'], 'val1') - self.assertEqual(parsed['_blacklisted_3']['blacklisted_1'], MASKED_ATTRIBUTE_VALUE) - self.assertEqual(parsed['_blacklisted_3']['key3'], 'val3') - self.assertEqual(parsed['_foo1'], 'bar') + self.assertEqual(parsed["_blacklisted_1"], MASKED_ATTRIBUTE_VALUE) + self.assertEqual(parsed["_blacklisted_2"], MASKED_ATTRIBUTE_VALUE) + self.assertEqual(parsed["_blacklisted_3"]["key1"], "val1") + self.assertEqual( + parsed["_blacklisted_3"]["blacklisted_1"], MASKED_ATTRIBUTE_VALUE + ) + self.assertEqual(parsed["_blacklisted_3"]["key3"], "val3") + self.assertEqual(parsed["_foo1"], "bar") # Assert that the original dict is left unmodified - self.assertEqual(record._blacklisted_1, 'test value 1') - self.assertEqual(record._blacklisted_2, 'test value 2') - self.assertEqual(record._blacklisted_3['key1'], 'val1') - self.assertEqual(record._blacklisted_3['blacklisted_1'], 'val2') - self.assertEqual(record._blacklisted_3['key3'], 'val3') + self.assertEqual(record._blacklisted_1, "test value 1") + self.assertEqual(record._blacklisted_2, "test value 2") + self.assertEqual(record._blacklisted_3["key1"], "val1") + self.assertEqual(record._blacklisted_3["blacklisted_1"], "val2") + self.assertEqual(record._blacklisted_3["key3"], "val3") diff --git a/st2common/tests/unit/test_logging.py b/st2common/tests/unit/test_logging.py index 7dc4fc1b6d7..ebb75b6f1de 100644 --- a/st2common/tests/unit/test_logging.py +++ b/st2common/tests/unit/test_logging.py @@ -21,25 +21,29 @@ from python_runner import python_runner from st2common import runners -__all__ = [ - 'LoggingMiscUtilsTestCase' -] +__all__ = ["LoggingMiscUtilsTestCase"] class LoggingMiscUtilsTestCase(unittest2.TestCase): def test_get_logger_name_for_module(self): logger_name = get_logger_name_for_module(sensormanager) - self.assertEqual(logger_name, 'st2reactor.cmd.sensormanager') + self.assertEqual(logger_name, "st2reactor.cmd.sensormanager") logger_name = get_logger_name_for_module(python_runner) - result = logger_name.endswith('contrib.runners.python_runner.python_runner.python_runner') + result = logger_name.endswith( + "contrib.runners.python_runner.python_runner.python_runner" + ) self.assertTrue(result) - logger_name = get_logger_name_for_module(python_runner, exclude_module_name=True) - self.assertTrue(logger_name.endswith('contrib.runners.python_runner.python_runner')) + logger_name = get_logger_name_for_module( + python_runner, exclude_module_name=True + ) + self.assertTrue( + logger_name.endswith("contrib.runners.python_runner.python_runner") + ) logger_name = get_logger_name_for_module(runners) - self.assertEqual(logger_name, 'st2common.runners.__init__') + self.assertEqual(logger_name, "st2common.runners.__init__") logger_name = get_logger_name_for_module(runners, exclude_module_name=True) - self.assertEqual(logger_name, 'st2common.runners') + self.assertEqual(logger_name, "st2common.runners") diff --git a/st2common/tests/unit/test_logging_middleware.py b/st2common/tests/unit/test_logging_middleware.py index 8a59177bebb..b7d34de0bc7 100644 --- a/st2common/tests/unit/test_logging_middleware.py +++ b/st2common/tests/unit/test_logging_middleware.py @@ -21,18 +21,15 @@ from st2common.middleware.logging import LoggingMiddleware from st2common.constants.secrets import MASKED_ATTRIBUTE_VALUE -__all__ = [ - 'LoggingMiddlewareTestCase' -] +__all__ = ["LoggingMiddlewareTestCase"] class LoggingMiddlewareTestCase(unittest2.TestCase): - @mock.patch('st2common.middleware.logging.LOG') - @mock.patch('st2common.middleware.logging.Request') + @mock.patch("st2common.middleware.logging.LOG") + @mock.patch("st2common.middleware.logging.Request") def test_secret_parameters_are_masked_in_log_message(self, mock_request, mock_log): - def app(environ, custom_start_response): - custom_start_response(status='200 OK', headers=[('Content-Length', 100)]) + custom_start_response(status="200 OK", headers=[("Content-Length", 100)]) return [None] router = mock.Mock() @@ -40,35 +37,38 @@ def app(environ, custom_start_response): router.match.return_value = (endpoint, None) middleware = LoggingMiddleware(app=app, router=router) - cfg.CONF.set_override(group='log', name='mask_secrets_blacklist', - override=['blacklisted_4', 'blacklisted_5']) + cfg.CONF.set_override( + group="log", + name="mask_secrets_blacklist", + override=["blacklisted_4", "blacklisted_5"], + ) environ = {} mock_request.return_value.GET.dict_of_lists.return_value = { - 'foo': 'bar', - 'bar': 'baz', - 'x-auth-token': 'secret', - 'st2-api-key': 'secret', - 'password': 'secret', - 'st2_auth_token': 'secret', - 'token': 'secret', - 'blacklisted_4': 'super secret', - 'blacklisted_5': 'super secret', + "foo": "bar", + "bar": "baz", + "x-auth-token": "secret", + "st2-api-key": "secret", + "password": "secret", + "st2_auth_token": "secret", + "token": "secret", + "blacklisted_4": "super secret", + "blacklisted_5": "super secret", } middleware(environ=environ, start_response=mock.Mock()) expected_query = { - 'foo': 'bar', - 'bar': 'baz', - 'x-auth-token': MASKED_ATTRIBUTE_VALUE, - 'st2-api-key': MASKED_ATTRIBUTE_VALUE, - 'password': MASKED_ATTRIBUTE_VALUE, - 'token': MASKED_ATTRIBUTE_VALUE, - 'st2_auth_token': MASKED_ATTRIBUTE_VALUE, - 'blacklisted_4': MASKED_ATTRIBUTE_VALUE, - 'blacklisted_5': MASKED_ATTRIBUTE_VALUE, + "foo": "bar", + "bar": "baz", + "x-auth-token": MASKED_ATTRIBUTE_VALUE, + "st2-api-key": MASKED_ATTRIBUTE_VALUE, + "password": MASKED_ATTRIBUTE_VALUE, + "token": MASKED_ATTRIBUTE_VALUE, + "st2_auth_token": MASKED_ATTRIBUTE_VALUE, + "blacklisted_4": MASKED_ATTRIBUTE_VALUE, + "blacklisted_5": MASKED_ATTRIBUTE_VALUE, } call_kwargs = mock_log.info.call_args_list[0][1] - query = call_kwargs['extra']['query'] + query = call_kwargs["extra"]["query"] self.assertEqual(query, expected_query) diff --git a/st2common/tests/unit/test_metrics.py b/st2common/tests/unit/test_metrics.py index 084db97c626..4b0df66aa19 100644 --- a/st2common/tests/unit/test_metrics.py +++ b/st2common/tests/unit/test_metrics.py @@ -29,16 +29,16 @@ from st2common.util.date import get_datetime_utc_now __all__ = [ - 'TestBaseMetricsDriver', - 'TestStatsDMetricsDriver', - 'TestCounterContextManager', - 'TestTimerContextManager', - 'TestCounterWithTimerContextManager' + "TestBaseMetricsDriver", + "TestStatsDMetricsDriver", + "TestCounterContextManager", + "TestTimerContextManager", + "TestCounterWithTimerContextManager", ] -cfg.CONF.set_override('driver', 'noop', group='metrics') -cfg.CONF.set_override('host', '127.0.0.1', group='metrics') -cfg.CONF.set_override('port', 8080, group='metrics') +cfg.CONF.set_override("driver", "noop", group="metrics") +cfg.CONF.set_override("host", "127.0.0.1", group="metrics") +cfg.CONF.set_override("port", 8080, group="metrics") class TestBaseMetricsDriver(unittest2.TestCase): @@ -48,45 +48,43 @@ def setUp(self): self._driver = base.BaseMetricsDriver() def test_time(self): - self._driver.time('test', 10) + self._driver.time("test", 10) def test_inc_counter(self): - self._driver.inc_counter('test') + self._driver.inc_counter("test") def test_dec_timer(self): - self._driver.dec_counter('test') + self._driver.dec_counter("test") class TestStatsDMetricsDriver(unittest2.TestCase): _driver = None - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def setUp(self, statsd): - cfg.CONF.set_override(name='prefix', group='metrics', override=None) + cfg.CONF.set_override(name="prefix", group="metrics", override=None) self._driver = StatsdDriver() statsd.Connection.set_defaults.assert_called_once_with( - host=cfg.CONF.metrics.host, - port=cfg.CONF.metrics.port, - sample_rate=1.0 + host=cfg.CONF.metrics.host, port=cfg.CONF.metrics.port, sample_rate=1.0 ) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_time(self, statsd): mock_timer = MagicMock() - statsd.Timer('').send.side_effect = mock_timer - params = ('test', 10) + statsd.Timer("").send.side_effect = mock_timer + params = ("test", 10) self._driver.time(*params) - statsd.Timer('').send.assert_called_with('st2.test', 10) + statsd.Timer("").send.assert_called_with("st2.test", 10) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_time_with_float(self, statsd): mock_timer = MagicMock() - statsd.Timer('').send.side_effect = mock_timer - params = ('test', 10.5) + statsd.Timer("").send.side_effect = mock_timer + params = ("test", 10.5) self._driver.time(*params) - statsd.Timer().send.assert_called_with('st2.test', 10.5) + statsd.Timer().send.assert_called_with("st2.test", 10.5) def test_time_with_invalid_key(self): params = (2, 2) @@ -94,21 +92,21 @@ def test_time_with_invalid_key(self): self._driver.time(*params) def test_time_with_invalid_time(self): - params = ('test', '1') + params = ("test", "1") with self.assertRaises(AssertionError): self._driver.time(*params) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_inc_counter_with_default_amount(self, statsd): - key = 'test' + key = "test" mock_counter = MagicMock() statsd.Counter(key).increment.side_effect = mock_counter self._driver.inc_counter(key) mock_counter.assert_called_once_with(delta=1) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_inc_counter_with_amount(self, statsd): - params = ('test', 2) + params = ("test", 2) mock_counter = MagicMock() statsd.Counter(params[0]).increment.side_effect = mock_counter self._driver.inc_counter(*params) @@ -120,21 +118,21 @@ def test_inc_timer_with_invalid_key(self): self._driver.inc_counter(*params) def test_inc_timer_with_invalid_amount(self): - params = ('test', '1') + params = ("test", "1") with self.assertRaises(AssertionError): self._driver.inc_counter(*params) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_dec_timer_with_default_amount(self, statsd): - key = 'test' + key = "test" mock_counter = MagicMock() statsd.Counter().decrement.side_effect = mock_counter self._driver.dec_counter(key) mock_counter.assert_called_once_with(delta=1) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_dec_timer_with_amount(self, statsd): - params = ('test', 2) + params = ("test", 2) mock_counter = MagicMock() statsd.Counter().decrement.side_effect = mock_counter self._driver.dec_counter(*params) @@ -146,41 +144,41 @@ def test_dec_timer_with_invalid_key(self): self._driver.dec_counter(*params) def test_dec_timer_with_invalid_amount(self): - params = ('test', '1') + params = ("test", "1") with self.assertRaises(AssertionError): self._driver.dec_counter(*params) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_set_gauge_success(self, statsd): - params = ('key', 100) + params = ("key", 100) mock_gauge = MagicMock() statsd.Gauge().send.side_effect = mock_gauge self._driver.set_gauge(*params) mock_gauge.assert_called_once_with(None, params[1]) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_inc_gauge_success(self, statsd): - params = ('key1',) + params = ("key1",) mock_gauge = MagicMock() statsd.Gauge().increment.side_effect = mock_gauge self._driver.inc_gauge(*params) mock_gauge.assert_called_once_with(None, 1) - params = ('key2', 100) + params = ("key2", 100) mock_gauge = MagicMock() statsd.Gauge().increment.side_effect = mock_gauge self._driver.inc_gauge(*params) mock_gauge.assert_called_once_with(None, params[1]) - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_dec_gauge_success(self, statsd): - params = ('key1',) + params = ("key1",) mock_gauge = MagicMock() statsd.Gauge().decrement.side_effect = mock_gauge self._driver.dec_gauge(*params) mock_gauge.assert_called_once_with(None, 1) - params = ('key2', 100) + params = ("key2", 100) mock_gauge = MagicMock() statsd.Gauge().decrement.side_effect = mock_gauge self._driver.dec_gauge(*params) @@ -188,71 +186,71 @@ def test_dec_gauge_success(self, statsd): def test_get_full_key_name(self): # No prefix specified in the config - cfg.CONF.set_override(name='prefix', group='metrics', override=None) + cfg.CONF.set_override(name="prefix", group="metrics", override=None) - result = get_full_key_name('api.requests') - self.assertEqual(result, 'st2.api.requests') + result = get_full_key_name("api.requests") + self.assertEqual(result, "st2.api.requests") # Prefix is defined in the config - cfg.CONF.set_override(name='prefix', group='metrics', override='staging') + cfg.CONF.set_override(name="prefix", group="metrics", override="staging") - result = get_full_key_name('api.requests') - self.assertEqual(result, 'st2.staging.api.requests') + result = get_full_key_name("api.requests") + self.assertEqual(result, "st2.staging.api.requests") - cfg.CONF.set_override(name='prefix', group='metrics', override='prod') + cfg.CONF.set_override(name="prefix", group="metrics", override="prod") - result = get_full_key_name('api.requests') - self.assertEqual(result, 'st2.prod.api.requests') + result = get_full_key_name("api.requests") + self.assertEqual(result, "st2.prod.api.requests") - @patch('st2common.metrics.drivers.statsd_driver.LOG') - @patch('st2common.metrics.drivers.statsd_driver.statsd') + @patch("st2common.metrics.drivers.statsd_driver.LOG") + @patch("st2common.metrics.drivers.statsd_driver.statsd") def test_driver_socket_exceptions_are_not_fatal(self, statsd, mock_log): # Socket errors such as DNS resolution errors should be considered non fatal and ignored mock_logger = mock.Mock() StatsdDriver.logger = mock_logger # 1. timer - mock_timer = MagicMock(side_effect=socket.error('error 1')) - statsd.Timer('').send.side_effect = mock_timer - params = ('test', 10) + mock_timer = MagicMock(side_effect=socket.error("error 1")) + statsd.Timer("").send.side_effect = mock_timer + params = ("test", 10) self._driver.time(*params) - statsd.Timer('').send.assert_called_with('st2.test', 10) + statsd.Timer("").send.assert_called_with("st2.test", 10) # 2. counter - key = 'test' - mock_counter = MagicMock(side_effect=socket.error('error 2')) + key = "test" + mock_counter = MagicMock(side_effect=socket.error("error 2")) statsd.Counter(key).increment.side_effect = mock_counter self._driver.inc_counter(key) mock_counter.assert_called_once_with(delta=1) - key = 'test' - mock_counter = MagicMock(side_effect=socket.error('error 3')) + key = "test" + mock_counter = MagicMock(side_effect=socket.error("error 3")) statsd.Counter(key).decrement.side_effect = mock_counter self._driver.dec_counter(key) mock_counter.assert_called_once_with(delta=1) # 3. gauge - params = ('key', 100) - mock_gauge = MagicMock(side_effect=socket.error('error 4')) + params = ("key", 100) + mock_gauge = MagicMock(side_effect=socket.error("error 4")) statsd.Gauge().send.side_effect = mock_gauge self._driver.set_gauge(*params) mock_gauge.assert_called_once_with(None, params[1]) - params = ('key1',) - mock_gauge = MagicMock(side_effect=socket.error('error 5')) + params = ("key1",) + mock_gauge = MagicMock(side_effect=socket.error("error 5")) statsd.Gauge().increment.side_effect = mock_gauge self._driver.inc_gauge(*params) mock_gauge.assert_called_once_with(None, 1) - params = ('key1',) - mock_gauge = MagicMock(side_effect=socket.error('error 6')) + params = ("key1",) + mock_gauge = MagicMock(side_effect=socket.error("error 6")) statsd.Gauge().decrement.side_effect = mock_gauge self._driver.dec_gauge(*params) mock_gauge.assert_called_once_with(None, 1) class TestCounterContextManager(unittest2.TestCase): - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.METRICS") def test_counter(self, metrics_patch): test_key = "test_key" with base.Counter(test_key): @@ -261,8 +259,8 @@ def test_counter(self, metrics_patch): class TestTimerContextManager(unittest2.TestCase): - @patch('st2common.metrics.base.get_datetime_utc_now') - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.get_datetime_utc_now") + @patch("st2common.metrics.base.METRICS") def test_time(self, metrics_patch, datetime_patch): start_time = get_datetime_utc_now() middle_time = start_time + timedelta(seconds=1) @@ -272,7 +270,7 @@ def test_time(self, metrics_patch, datetime_patch): middle_time, middle_time, middle_time, - end_time + end_time, ] test_key = "test_key" with base.Timer(test_key) as timer: @@ -280,23 +278,19 @@ def test_time(self, metrics_patch, datetime_patch): metrics_patch.time.assert_not_called() timer.send_time() metrics_patch.time.assert_called_with( - test_key, - (end_time - middle_time).total_seconds() + test_key, (end_time - middle_time).total_seconds() ) second_test_key = "lakshmi_has_toes" timer.send_time(second_test_key) metrics_patch.time.assert_called_with( - second_test_key, - (end_time - middle_time).total_seconds() + second_test_key, (end_time - middle_time).total_seconds() ) time_delta = timer.get_time_delta() self.assertEqual( - time_delta.total_seconds(), - (end_time - middle_time).total_seconds() + time_delta.total_seconds(), (end_time - middle_time).total_seconds() ) metrics_patch.time.assert_called_with( - test_key, - (end_time - start_time).total_seconds() + test_key, (end_time - start_time).total_seconds() ) @@ -306,46 +300,44 @@ def setUp(self): self.middle_time = self.start_time + timedelta(seconds=1) self.end_time = self.middle_time + timedelta(seconds=1) - @patch('st2common.metrics.base.get_datetime_utc_now') - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.get_datetime_utc_now") + @patch("st2common.metrics.base.METRICS") def test_time(self, metrics_patch, datetime_patch): datetime_patch.side_effect = [ self.start_time, self.middle_time, self.middle_time, self.middle_time, - self.end_time + self.end_time, ] test_key = "test_key" with base.CounterWithTimer(test_key) as timer: self.assertIsInstance(timer._start_time, datetime) metrics_patch.time.assert_not_called() timer.send_time() - metrics_patch.time.assert_called_with(test_key, - (self.end_time - self.middle_time).total_seconds() + metrics_patch.time.assert_called_with( + test_key, (self.end_time - self.middle_time).total_seconds() ) second_test_key = "lakshmi_has_a_nose" timer.send_time(second_test_key) metrics_patch.time.assert_called_with( - second_test_key, - (self.end_time - self.middle_time).total_seconds() + second_test_key, (self.end_time - self.middle_time).total_seconds() ) time_delta = timer.get_time_delta() self.assertEqual( time_delta.total_seconds(), - (self.end_time - self.middle_time).total_seconds() + (self.end_time - self.middle_time).total_seconds(), ) metrics_patch.inc_counter.assert_called_with(test_key) metrics_patch.dec_counter.assert_not_called() metrics_patch.time.assert_called_with( - test_key, - (self.end_time - self.start_time).total_seconds() + test_key, (self.end_time - self.start_time).total_seconds() ) class TestCounterWithTimerDecorator(unittest2.TestCase): - @patch('st2common.metrics.base.get_datetime_utc_now') - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.get_datetime_utc_now") + @patch("st2common.metrics.base.METRICS") def test_time(self, metrics_patch, datetime_patch): start_time = get_datetime_utc_now() middle_time = start_time + timedelta(seconds=1) @@ -355,7 +347,7 @@ def test_time(self, metrics_patch, datetime_patch): middle_time, middle_time, middle_time, - end_time + end_time, ] test_key = "test_key" @@ -364,32 +356,30 @@ def _get_tested(metrics_counter_with_timer=None): self.assertIsInstance(metrics_counter_with_timer._start_time, datetime) metrics_patch.time.assert_not_called() metrics_counter_with_timer.send_time() - metrics_patch.time.assert_called_with(test_key, - (end_time - middle_time).total_seconds() + metrics_patch.time.assert_called_with( + test_key, (end_time - middle_time).total_seconds() ) second_test_key = "lakshmi_has_a_nose" metrics_counter_with_timer.send_time(second_test_key) metrics_patch.time.assert_called_with( - second_test_key, - (end_time - middle_time).total_seconds() + second_test_key, (end_time - middle_time).total_seconds() ) time_delta = metrics_counter_with_timer.get_time_delta() self.assertEqual( - time_delta.total_seconds(), - (end_time - middle_time).total_seconds() + time_delta.total_seconds(), (end_time - middle_time).total_seconds() ) metrics_patch.inc_counter.assert_called_with(test_key) metrics_patch.dec_counter.assert_not_called() _get_tested() - metrics_patch.time.assert_called_with(test_key, - (end_time - start_time).total_seconds() + metrics_patch.time.assert_called_with( + test_key, (end_time - start_time).total_seconds() ) class TestCounterDecorator(unittest2.TestCase): - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.METRICS") def test_counter(self, metrics_patch): test_key = "test_key" @@ -397,12 +387,13 @@ def test_counter(self, metrics_patch): def _get_tested(): metrics_patch.inc_counter.assert_called_with(test_key) metrics_patch.dec_counter.assert_not_called() + _get_tested() class TestTimerDecorator(unittest2.TestCase): - @patch('st2common.metrics.base.get_datetime_utc_now') - @patch('st2common.metrics.base.METRICS') + @patch("st2common.metrics.base.get_datetime_utc_now") + @patch("st2common.metrics.base.METRICS") def test_time(self, metrics_patch, datetime_patch): start_time = get_datetime_utc_now() middle_time = start_time + timedelta(seconds=1) @@ -412,7 +403,7 @@ def test_time(self, metrics_patch, datetime_patch): middle_time, middle_time, middle_time, - end_time + end_time, ] test_key = "test_key" @@ -422,22 +413,19 @@ def _get_tested(metrics_timer=None): metrics_patch.time.assert_not_called() metrics_timer.send_time() metrics_patch.time.assert_called_with( - test_key, - (end_time - middle_time).total_seconds() + test_key, (end_time - middle_time).total_seconds() ) second_test_key = "lakshmi_has_toes" metrics_timer.send_time(second_test_key) metrics_patch.time.assert_called_with( - second_test_key, - (end_time - middle_time).total_seconds() + second_test_key, (end_time - middle_time).total_seconds() ) time_delta = metrics_timer.get_time_delta() self.assertEqual( - time_delta.total_seconds(), - (end_time - middle_time).total_seconds() + time_delta.total_seconds(), (end_time - middle_time).total_seconds() ) + _get_tested() metrics_patch.time.assert_called_with( - test_key, - (end_time - start_time).total_seconds() + test_key, (end_time - start_time).total_seconds() ) diff --git a/st2common/tests/unit/test_misc_utils.py b/st2common/tests/unit/test_misc_utils.py index d7008e921c3..f05615573bb 100644 --- a/st2common/tests/unit/test_misc_utils.py +++ b/st2common/tests/unit/test_misc_utils.py @@ -24,71 +24,61 @@ from st2common.util.misc import sanitize_output from st2common.util.ujson import fast_deepcopy -__all__ = [ - 'MiscUtilTestCase' -] +__all__ = ["MiscUtilTestCase"] class MiscUtilTestCase(unittest2.TestCase): def test_rstrip_last_char(self): - self.assertEqual(rstrip_last_char(None, '\n'), None) - self.assertEqual(rstrip_last_char('stuff', None), 'stuff') - self.assertEqual(rstrip_last_char('', '\n'), '') - self.assertEqual(rstrip_last_char('foo', '\n'), 'foo') - self.assertEqual(rstrip_last_char('foo\n', '\n'), 'foo') - self.assertEqual(rstrip_last_char('foo\n\n', '\n'), 'foo\n') - self.assertEqual(rstrip_last_char('foo\r', '\r'), 'foo') - self.assertEqual(rstrip_last_char('foo\r\r', '\r'), 'foo\r') - self.assertEqual(rstrip_last_char('foo\r\n', '\r\n'), 'foo') - self.assertEqual(rstrip_last_char('foo\r\r\n', '\r\n'), 'foo\r') - self.assertEqual(rstrip_last_char('foo\n\r', '\r\n'), 'foo\n\r') + self.assertEqual(rstrip_last_char(None, "\n"), None) + self.assertEqual(rstrip_last_char("stuff", None), "stuff") + self.assertEqual(rstrip_last_char("", "\n"), "") + self.assertEqual(rstrip_last_char("foo", "\n"), "foo") + self.assertEqual(rstrip_last_char("foo\n", "\n"), "foo") + self.assertEqual(rstrip_last_char("foo\n\n", "\n"), "foo\n") + self.assertEqual(rstrip_last_char("foo\r", "\r"), "foo") + self.assertEqual(rstrip_last_char("foo\r\r", "\r"), "foo\r") + self.assertEqual(rstrip_last_char("foo\r\n", "\r\n"), "foo") + self.assertEqual(rstrip_last_char("foo\r\r\n", "\r\n"), "foo\r") + self.assertEqual(rstrip_last_char("foo\n\r", "\r\n"), "foo\n\r") def test_strip_shell_chars(self): self.assertEqual(strip_shell_chars(None), None) - self.assertEqual(strip_shell_chars('foo'), 'foo') - self.assertEqual(strip_shell_chars('foo\r'), 'foo') - self.assertEqual(strip_shell_chars('fo\ro\r'), 'fo\ro') - self.assertEqual(strip_shell_chars('foo\n'), 'foo') - self.assertEqual(strip_shell_chars('fo\no\n'), 'fo\no') - self.assertEqual(strip_shell_chars('foo\r\n'), 'foo') - self.assertEqual(strip_shell_chars('fo\no\r\n'), 'fo\no') - self.assertEqual(strip_shell_chars('foo\r\n\r\n'), 'foo\r\n') + self.assertEqual(strip_shell_chars("foo"), "foo") + self.assertEqual(strip_shell_chars("foo\r"), "foo") + self.assertEqual(strip_shell_chars("fo\ro\r"), "fo\ro") + self.assertEqual(strip_shell_chars("foo\n"), "foo") + self.assertEqual(strip_shell_chars("fo\no\n"), "fo\no") + self.assertEqual(strip_shell_chars("foo\r\n"), "foo") + self.assertEqual(strip_shell_chars("fo\no\r\n"), "fo\no") + self.assertEqual(strip_shell_chars("foo\r\n\r\n"), "foo\r\n") def test_lowercase_value(self): - value = 'TEST' - expected_value = 'test' + value = "TEST" + expected_value = "test" self.assertEqual(expected_value, lowercase_value(value=value)) - value = ['testA', 'TESTb', 'TESTC'] - expected_value = ['testa', 'testb', 'testc'] + value = ["testA", "TESTb", "TESTC"] + expected_value = ["testa", "testb", "testc"] self.assertEqual(expected_value, lowercase_value(value=value)) - value = { - 'testA': 'testB', - 'testC': 'TESTD', - 'TESTE': 'TESTE' - } - expected_value = { - 'testa': 'testb', - 'testc': 'testd', - 'teste': 'teste' - } + value = {"testA": "testB", "testC": "TESTD", "TESTE": "TESTE"} + expected_value = {"testa": "testb", "testc": "testd", "teste": "teste"} self.assertEqual(expected_value, lowercase_value(value=value)) def test_fast_deepcopy_success(self): values = [ - 'a', - u'٩(̾●̮̮̃̾•̃̾)۶', + "a", + "٩(̾●̮̮̃̾•̃̾)۶", 1, - [1, 2, '3', 'b'], - {'a': 1, 'b': '3333', 'c': 'd'}, + [1, 2, "3", "b"], + {"a": 1, "b": "3333", "c": "d"}, ] expected_values = [ - 'a', - u'٩(̾●̮̮̃̾•̃̾)۶', + "a", + "٩(̾●̮̮̃̾•̃̾)۶", 1, - [1, 2, '3', 'b'], - {'a': 1, 'b': '3333', 'c': 'd'}, + [1, 2, "3", "b"], + {"a": 1, "b": "3333", "c": "d"}, ] for value, expected_value in zip(values, expected_values): @@ -99,18 +89,18 @@ def test_fast_deepcopy_success(self): def test_sanitize_output_use_pyt_false(self): # pty is not used, \r\n shouldn't be replaced with \n input_strs = [ - 'foo', - 'foo\n', - 'foo\r\n', - 'foo\nbar\nbaz\n', - 'foo\r\nbar\r\nbaz\r\n', + "foo", + "foo\n", + "foo\r\n", + "foo\nbar\nbaz\n", + "foo\r\nbar\r\nbaz\r\n", ] expected = [ - 'foo', - 'foo', - 'foo', - 'foo\nbar\nbaz', - 'foo\r\nbar\r\nbaz', + "foo", + "foo", + "foo", + "foo\nbar\nbaz", + "foo\r\nbar\r\nbaz", ] for input_str, expected_output in zip(input_strs, expected): @@ -120,18 +110,18 @@ def test_sanitize_output_use_pyt_false(self): def test_sanitize_output_use_pyt_true(self): # pty is used, \r\n should be replaced with \n input_strs = [ - 'foo', - 'foo\n', - 'foo\r\n', - 'foo\nbar\nbaz\n', - 'foo\r\nbar\r\nbaz\r\n', + "foo", + "foo\n", + "foo\r\n", + "foo\nbar\nbaz\n", + "foo\r\nbar\r\nbaz\r\n", ] expected = [ - 'foo', - 'foo', - 'foo', - 'foo\nbar\nbaz', - 'foo\nbar\nbaz', + "foo", + "foo", + "foo", + "foo\nbar\nbaz", + "foo\nbar\nbaz", ] for input_str, expected_output in zip(input_strs, expected): diff --git a/st2common/tests/unit/test_model_utils_profiling.py b/st2common/tests/unit/test_model_utils_profiling.py index 2225e39a7e0..db37039c800 100644 --- a/st2common/tests/unit/test_model_utils_profiling.py +++ b/st2common/tests/unit/test_model_utils_profiling.py @@ -28,31 +28,37 @@ def setUp(self): super(MongoDBProfilingTestCase, self).setUp() disable_profiling() - @mock.patch('st2common.models.utils.profiling.LOG') + @mock.patch("st2common.models.utils.profiling.LOG") def test_logging_profiling_is_disabled(self, mock_log): disable_profiling() - queryset = User.query(name__in=['test1', 'test2'], order_by=['+aa', '-bb'], limit=1) + queryset = User.query( + name__in=["test1", "test2"], order_by=["+aa", "-bb"], limit=1 + ) result = log_query_and_profile_data_for_queryset(queryset=queryset) self.assertEqual(queryset, result) call_args_list = mock_log.debug.call_args_list self.assertItemsEqual(call_args_list, []) - @mock.patch('st2common.models.utils.profiling.LOG') + @mock.patch("st2common.models.utils.profiling.LOG") def test_logging_profiling_is_enabled(self, mock_log): enable_profiling() - queryset = User.query(name__in=['test1', 'test2'], order_by=['+aa', '-bb'], limit=1) + queryset = User.query( + name__in=["test1", "test2"], order_by=["+aa", "-bb"], limit=1 + ) result = log_query_and_profile_data_for_queryset(queryset=queryset) call_args_list = mock_log.debug.call_args_list call_args = call_args_list[0][0] call_kwargs = call_args_list[0][1] - expected_result = ("db.user_d_b.find({'name': {'$in': ['test1', 'test2']}})" - ".sort({aa: 1, bb: -1}).limit(1);") + expected_result = ( + "db.user_d_b.find({'name': {'$in': ['test1', 'test2']}})" + ".sort({aa: 1, bb: -1}).limit(1);" + ) self.assertEqual(queryset, result) self.assertIn(expected_result, call_args[0]) - self.assertIn('mongo_query', call_kwargs['extra']) - self.assertIn('mongo_shell_query', call_kwargs['extra']) + self.assertIn("mongo_query", call_kwargs["extra"]) + self.assertIn("mongo_shell_query", call_kwargs["extra"]) def test_logging_profiling_is_enabled_non_queryset_object(self): enable_profiling() diff --git a/st2common/tests/unit/test_mongoescape.py b/st2common/tests/unit/test_mongoescape.py index 05e3b7962f6..0ad12e28236 100644 --- a/st2common/tests/unit/test_mongoescape.py +++ b/st2common/tests/unit/test_mongoescape.py @@ -21,68 +21,70 @@ class TestMongoEscape(unittest.TestCase): def test_unnested(self): - field = {'k1.k1.k1': 'v1', 'k2$': 'v2', '$k3.': 'v3'} + field = {"k1.k1.k1": "v1", "k2$": "v2", "$k3.": "v3"} escaped = mongoescape.escape_chars(field) - self.assertEqual(escaped, {u'k1\uff0ek1\uff0ek1': 'v1', - u'k2\uff04': 'v2', - u'\uff04k3\uff0e': 'v3'}, 'Escaping failed.') + self.assertEqual( + escaped, + {"k1\uff0ek1\uff0ek1": "v1", "k2\uff04": "v2", "\uff04k3\uff0e": "v3"}, + "Escaping failed.", + ) unescaped = mongoescape.unescape_chars(escaped) - self.assertEqual(unescaped, field, 'Unescaping failed.') + self.assertEqual(unescaped, field, "Unescaping failed.") def test_nested(self): - nested_field = {'nk1.nk1.nk1': 'v1', 'nk2$': 'v2', '$nk3.': 'v3'} - field = {'k1.k1.k1': nested_field, 'k2$': 'v2', '$k3.': 'v3'} + nested_field = {"nk1.nk1.nk1": "v1", "nk2$": "v2", "$nk3.": "v3"} + field = {"k1.k1.k1": nested_field, "k2$": "v2", "$k3.": "v3"} escaped = mongoescape.escape_chars(field) - self.assertEqual(escaped, {u'k1\uff0ek1\uff0ek1': {u'\uff04nk3\uff0e': 'v3', - u'nk1\uff0enk1\uff0enk1': 'v1', - u'nk2\uff04': 'v2'}, - u'k2\uff04': 'v2', - u'\uff04k3\uff0e': 'v3'}, 'un-escaping failed.') + self.assertEqual( + escaped, + { + "k1\uff0ek1\uff0ek1": { + "\uff04nk3\uff0e": "v3", + "nk1\uff0enk1\uff0enk1": "v1", + "nk2\uff04": "v2", + }, + "k2\uff04": "v2", + "\uff04k3\uff0e": "v3", + }, + "un-escaping failed.", + ) unescaped = mongoescape.unescape_chars(escaped) - self.assertEqual(unescaped, field, 'Unescaping failed.') + self.assertEqual(unescaped, field, "Unescaping failed.") def test_unescaping_of_rule_criteria(self): # Verify that dot escaped in rule criteria is correctly escaped. # Note: In the past we used different character to escape dot in the # rule criteria. - escaped = { - u'k1\u2024k1\u2024k1': 'v1', - u'k2$': 'v2', - u'$k3\u2024': 'v3' - } - unescaped = { - 'k1.k1.k1': 'v1', - 'k2$': 'v2', - '$k3.': 'v3' - } + escaped = {"k1\u2024k1\u2024k1": "v1", "k2$": "v2", "$k3\u2024": "v3"} + unescaped = {"k1.k1.k1": "v1", "k2$": "v2", "$k3.": "v3"} result = mongoescape.unescape_chars(escaped) self.assertEqual(result, unescaped) def test_original_value(self): - field = {'k1.k2.k3': 'v1'} + field = {"k1.k2.k3": "v1"} escaped = mongoescape.escape_chars(field) - self.assertIn('k1.k2.k3', list(field.keys())) - self.assertIn(u'k1\uff0ek2\uff0ek3', list(escaped.keys())) + self.assertIn("k1.k2.k3", list(field.keys())) + self.assertIn("k1\uff0ek2\uff0ek3", list(escaped.keys())) unescaped = mongoescape.unescape_chars(escaped) - self.assertIn('k1.k2.k3', list(unescaped.keys())) - self.assertIn(u'k1\uff0ek2\uff0ek3', list(escaped.keys())) + self.assertIn("k1.k2.k3", list(unescaped.keys())) + self.assertIn("k1\uff0ek2\uff0ek3", list(escaped.keys())) def test_complex(self): field = { - 'k1.k2': [{'l1.l2': '123'}, {'l3.l4': '456'}], - 'k3': [{'l5.l6': '789'}], - 'k4.k5': [1, 2, 3], - 'k6': ['a', 'b'] + "k1.k2": [{"l1.l2": "123"}, {"l3.l4": "456"}], + "k3": [{"l5.l6": "789"}], + "k4.k5": [1, 2, 3], + "k6": ["a", "b"], } expected = { - u'k1\uff0ek2': [{u'l1\uff0el2': '123'}, {u'l3\uff0el4': '456'}], - 'k3': [{u'l5\uff0el6': '789'}], - u'k4\uff0ek5': [1, 2, 3], - 'k6': ['a', 'b'] + "k1\uff0ek2": [{"l1\uff0el2": "123"}, {"l3\uff0el4": "456"}], + "k3": [{"l5\uff0el6": "789"}], + "k4\uff0ek5": [1, 2, 3], + "k6": ["a", "b"], } escaped = mongoescape.escape_chars(field) @@ -93,17 +95,17 @@ def test_complex(self): def test_complex_list(self): field = [ - {'k1.k2': [{'l1.l2': '123'}, {'l3.l4': '456'}]}, - {'k3': [{'l5.l6': '789'}]}, - {'k4.k5': [1, 2, 3]}, - {'k6': ['a', 'b']} + {"k1.k2": [{"l1.l2": "123"}, {"l3.l4": "456"}]}, + {"k3": [{"l5.l6": "789"}]}, + {"k4.k5": [1, 2, 3]}, + {"k6": ["a", "b"]}, ] expected = [ - {u'k1\uff0ek2': [{u'l1\uff0el2': '123'}, {u'l3\uff0el4': '456'}]}, - {'k3': [{u'l5\uff0el6': '789'}]}, - {u'k4\uff0ek5': [1, 2, 3]}, - {'k6': ['a', 'b']} + {"k1\uff0ek2": [{"l1\uff0el2": "123"}, {"l3\uff0el4": "456"}]}, + {"k3": [{"l5\uff0el6": "789"}]}, + {"k4\uff0ek5": [1, 2, 3]}, + {"k6": ["a", "b"]}, ] escaped = mongoescape.escape_chars(field) diff --git a/st2common/tests/unit/test_notification_helper.py b/st2common/tests/unit/test_notification_helper.py index d169dd5a5f2..9c00ea4771a 100644 --- a/st2common/tests/unit/test_notification_helper.py +++ b/st2common/tests/unit/test_notification_helper.py @@ -20,7 +20,6 @@ class NotificationsHelperTestCase(unittest2.TestCase): - def test_model_transformations(self): notify = {} @@ -31,42 +30,56 @@ def test_model_transformations(self): notify_api = NotificationsHelper.from_model(notify_model) self.assertEqual(notify_api, {}) - notify['on-complete'] = { - 'message': 'Action completed.', - 'routes': [ - '66' - ], - 'data': { - 'foo': '{{foo}}', - 'bar': 1, - 'baz': [1, 2, 3] - } + notify["on-complete"] = { + "message": "Action completed.", + "routes": ["66"], + "data": {"foo": "{{foo}}", "bar": 1, "baz": [1, 2, 3]}, } - notify['on-success'] = { - 'message': 'Action succeeded.', - 'routes': [ - '100' - ], - 'data': { - 'foo': '{{foo}}', - 'bar': 1, - } + notify["on-success"] = { + "message": "Action succeeded.", + "routes": ["100"], + "data": { + "foo": "{{foo}}", + "bar": 1, + }, } notify_model = NotificationsHelper.to_model(notify) - self.assertEqual(notify['on-complete']['message'], notify_model.on_complete.message) - self.assertDictEqual(notify['on-complete']['data'], notify_model.on_complete.data) - self.assertListEqual(notify['on-complete']['routes'], notify_model.on_complete.routes) - self.assertEqual(notify['on-success']['message'], notify_model.on_success.message) - self.assertDictEqual(notify['on-success']['data'], notify_model.on_success.data) - self.assertListEqual(notify['on-success']['routes'], notify_model.on_success.routes) + self.assertEqual( + notify["on-complete"]["message"], notify_model.on_complete.message + ) + self.assertDictEqual( + notify["on-complete"]["data"], notify_model.on_complete.data + ) + self.assertListEqual( + notify["on-complete"]["routes"], notify_model.on_complete.routes + ) + self.assertEqual( + notify["on-success"]["message"], notify_model.on_success.message + ) + self.assertDictEqual(notify["on-success"]["data"], notify_model.on_success.data) + self.assertListEqual( + notify["on-success"]["routes"], notify_model.on_success.routes + ) notify_api = NotificationsHelper.from_model(notify_model) - self.assertEqual(notify['on-complete']['message'], notify_api['on-complete']['message']) - self.assertDictEqual(notify['on-complete']['data'], notify_api['on-complete']['data']) - self.assertListEqual(notify['on-complete']['routes'], notify_api['on-complete']['routes']) - self.assertEqual(notify['on-success']['message'], notify_api['on-success']['message']) - self.assertDictEqual(notify['on-success']['data'], notify_api['on-success']['data']) - self.assertListEqual(notify['on-success']['routes'], notify_api['on-success']['routes']) + self.assertEqual( + notify["on-complete"]["message"], notify_api["on-complete"]["message"] + ) + self.assertDictEqual( + notify["on-complete"]["data"], notify_api["on-complete"]["data"] + ) + self.assertListEqual( + notify["on-complete"]["routes"], notify_api["on-complete"]["routes"] + ) + self.assertEqual( + notify["on-success"]["message"], notify_api["on-success"]["message"] + ) + self.assertDictEqual( + notify["on-success"]["data"], notify_api["on-success"]["data"] + ) + self.assertListEqual( + notify["on-success"]["routes"], notify_api["on-success"]["routes"] + ) def test_model_transformations_missing_fields(self): notify = {} @@ -78,33 +91,39 @@ def test_model_transformations_missing_fields(self): notify_api = NotificationsHelper.from_model(notify_model) self.assertEqual(notify_api, {}) - notify['on-complete'] = { - 'routes': [ - '66' - ], - 'data': { - 'foo': '{{foo}}', - 'bar': 1, - 'baz': [1, 2, 3] - } + notify["on-complete"] = { + "routes": ["66"], + "data": {"foo": "{{foo}}", "bar": 1, "baz": [1, 2, 3]}, } - notify['on-success'] = { - 'routes': [ - '100' - ], - 'data': { - 'foo': '{{foo}}', - 'bar': 1, - } + notify["on-success"] = { + "routes": ["100"], + "data": { + "foo": "{{foo}}", + "bar": 1, + }, } notify_model = NotificationsHelper.to_model(notify) - self.assertDictEqual(notify['on-complete']['data'], notify_model.on_complete.data) - self.assertListEqual(notify['on-complete']['routes'], notify_model.on_complete.routes) - self.assertDictEqual(notify['on-success']['data'], notify_model.on_success.data) - self.assertListEqual(notify['on-success']['routes'], notify_model.on_success.routes) + self.assertDictEqual( + notify["on-complete"]["data"], notify_model.on_complete.data + ) + self.assertListEqual( + notify["on-complete"]["routes"], notify_model.on_complete.routes + ) + self.assertDictEqual(notify["on-success"]["data"], notify_model.on_success.data) + self.assertListEqual( + notify["on-success"]["routes"], notify_model.on_success.routes + ) notify_api = NotificationsHelper.from_model(notify_model) - self.assertDictEqual(notify['on-complete']['data'], notify_api['on-complete']['data']) - self.assertListEqual(notify['on-complete']['routes'], notify_api['on-complete']['routes']) - self.assertDictEqual(notify['on-success']['data'], notify_api['on-success']['data']) - self.assertListEqual(notify['on-success']['routes'], notify_api['on-success']['routes']) + self.assertDictEqual( + notify["on-complete"]["data"], notify_api["on-complete"]["data"] + ) + self.assertListEqual( + notify["on-complete"]["routes"], notify_api["on-complete"]["routes"] + ) + self.assertDictEqual( + notify["on-success"]["data"], notify_api["on-success"]["data"] + ) + self.assertListEqual( + notify["on-success"]["routes"], notify_api["on-success"]["routes"] + ) diff --git a/st2common/tests/unit/test_operators.py b/st2common/tests/unit/test_operators.py index 48f693af301..5917e4277c9 100644 --- a/st2common/tests/unit/test_operators.py +++ b/st2common/tests/unit/test_operators.py @@ -44,6 +44,7 @@ class ListOfDictsStrictEqualTest(unittest2.TestCase): We should test our comparison functions, even if they're only used in our other tests. """ + def test_empty_lists(self): self.assertTrue(list_of_dicts_strict_equal([], [])) @@ -54,65 +55,105 @@ def test_multiple_empty_dicts(self): self.assertTrue(list_of_dicts_strict_equal([{}, {}], [{}, {}])) def test_simple_dicts(self): - self.assertTrue(list_of_dicts_strict_equal([ - {'a': 1}, - ], [ - {'a': 1}, - ])) - - self.assertFalse(list_of_dicts_strict_equal([ - {'a': 1}, - ], [ - {'a': 2}, - ])) + self.assertTrue( + list_of_dicts_strict_equal( + [ + {"a": 1}, + ], + [ + {"a": 1}, + ], + ) + ) + + self.assertFalse( + list_of_dicts_strict_equal( + [ + {"a": 1}, + ], + [ + {"a": 2}, + ], + ) + ) def test_lists_of_different_lengths(self): - self.assertFalse(list_of_dicts_strict_equal([ - {'a': 1}, - ], [ - {'a': 1}, - {'b': 2}, - ])) - - self.assertFalse(list_of_dicts_strict_equal([ - {'a': 1}, - {'b': 2}, - ], [ - {'a': 1}, - ])) + self.assertFalse( + list_of_dicts_strict_equal( + [ + {"a": 1}, + ], + [ + {"a": 1}, + {"b": 2}, + ], + ) + ) + + self.assertFalse( + list_of_dicts_strict_equal( + [ + {"a": 1}, + {"b": 2}, + ], + [ + {"a": 1}, + ], + ) + ) def test_less_simple_dicts(self): - self.assertTrue(list_of_dicts_strict_equal([ - {'a': 1}, - {'b': 2}, - ], [ - {'a': 1}, - {'b': 2}, - ])) - - self.assertTrue(list_of_dicts_strict_equal([ - {'a': 1}, - {'a': 1}, - ], [ - {'a': 1}, - {'a': 1}, - ])) - - self.assertFalse(list_of_dicts_strict_equal([ - {'a': 1}, - {'a': 1}, - ], [ - {'a': 1}, - {'b': 2}, - ])) - - self.assertFalse(list_of_dicts_strict_equal([ - {'a': 1}, - {'b': 2}, - ], [ - {'a': 1}, - {'a': 1}, - ])) + self.assertTrue( + list_of_dicts_strict_equal( + [ + {"a": 1}, + {"b": 2}, + ], + [ + {"a": 1}, + {"b": 2}, + ], + ) + ) + + self.assertTrue( + list_of_dicts_strict_equal( + [ + {"a": 1}, + {"a": 1}, + ], + [ + {"a": 1}, + {"a": 1}, + ], + ) + ) + + self.assertFalse( + list_of_dicts_strict_equal( + [ + {"a": 1}, + {"a": 1}, + ], + [ + {"a": 1}, + {"b": 2}, + ], + ) + ) + + self.assertFalse( + list_of_dicts_strict_equal( + [ + {"a": 1}, + {"b": 2}, + ], + [ + {"a": 1}, + {"a": 1}, + ], + ) + ) class SearchOperatorTest(unittest2.TestCase): @@ -120,774 +161,850 @@ class SearchOperatorTest(unittest2.TestCase): # parser. As such, its tests are much more complex than other commands, so we # pull its tests out into their own test case. def test_search_with_weird_condition(self): - op = operators.get_operator('search') + op = operators.get_operator("search") with self.assertRaises(operators.UnrecognizedConditionError): - op([], [], 'weird', None) + op([], [], "weird", None) def test_search_any_true(self): - op = operators.get_operator('search') + op = operators.get_operator("search") called_function_args = [] def record_function_args(criterion_k, criterion_v, payload_lookup): - called_function_args.append({ - 'criterion_k': criterion_k, - 'criterion_v': criterion_v, - 'payload_lookup': { - 'field_name': payload_lookup.get_value('item.field_name')[0], - 'to_value': payload_lookup.get_value('item.to_value')[0], - }, - }) - return (len(called_function_args) < 3) + called_function_args.append( + { + "criterion_k": criterion_k, + "criterion_v": criterion_v, + "payload_lookup": { + "field_name": payload_lookup.get_value("item.field_name")[0], + "to_value": payload_lookup.get_value("item.to_value")[0], + }, + } + ) + return len(called_function_args) < 3 payload = [ { - 'field_name': "Status", - 'to_value': "Approved", - }, { - 'field_name': "Assigned to", - 'to_value': "Stanley", - } + "field_name": "Status", + "to_value": "Approved", + }, + { + "field_name": "Assigned to", + "to_value": "Stanley", + }, ] criteria_pattern = { - 'item.field_name': { - 'type': "equals", - 'pattern': "Status", + "item.field_name": { + "type": "equals", + "pattern": "Status", + }, + "item.to_value": { + "type": "equals", + "pattern": "Approved", }, - 'item.to_value': { - 'type': "equals", - 'pattern': "Approved", - } } - result = op(payload, criteria_pattern, 'any', record_function_args) + result = op(payload, criteria_pattern, "any", record_function_args) self.assertTrue(result) - self.assertTrue(list_of_dicts_strict_equal(called_function_args, [ - # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} - { - # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, { - # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, - # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} - { - # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - }, { - # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - } - ])) + self.assertTrue( + list_of_dicts_strict_equal( + called_function_args, + [ + # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} + { + # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + { + # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} + { + # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + { + # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + ], + ) + ) def test_search_any_false(self): - op = operators.get_operator('search') + op = operators.get_operator("search") called_function_args = [] def record_function_args(criterion_k, criterion_v, payload_lookup): - called_function_args.append({ - 'criterion_k': criterion_k, - 'criterion_v': criterion_v, - 'payload_lookup': { - 'field_name': payload_lookup.get_value('item.field_name')[0], - 'to_value': payload_lookup.get_value('item.to_value')[0], - }, - }) + called_function_args.append( + { + "criterion_k": criterion_k, + "criterion_v": criterion_v, + "payload_lookup": { + "field_name": payload_lookup.get_value("item.field_name")[0], + "to_value": payload_lookup.get_value("item.to_value")[0], + }, + } + ) return (len(called_function_args) % 2) == 0 payload = [ { - 'field_name': "Status", - 'to_value': "Denied", - }, { - 'field_name': "Assigned to", - 'to_value': "Stanley", - } + "field_name": "Status", + "to_value": "Denied", + }, + { + "field_name": "Assigned to", + "to_value": "Stanley", + }, ] criteria_pattern = { - 'item.field_name': { - 'type': "equals", - 'pattern': "Status", + "item.field_name": { + "type": "equals", + "pattern": "Status", + }, + "item.to_value": { + "type": "equals", + "pattern": "Approved", }, - 'item.to_value': { - 'type': "equals", - 'pattern': "Approved", - } } - result = op(payload, criteria_pattern, 'any', record_function_args) + result = op(payload, criteria_pattern, "any", record_function_args) self.assertFalse(result) - self.assertEqual(called_function_args, [ - # Outer loop: payload -> {'field_name': "Status", 'to_value': "Denied"} - { - # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Denied", - }, - }, { - # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Denied", - }, - }, - # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} - { - # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - }, { - # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - } - ]) + self.assertEqual( + called_function_args, + [ + # Outer loop: payload -> {'field_name': "Status", 'to_value': "Denied"} + { + # Inner loop: criterion -> item.field_name: {'type': "equals", 'pattern': "Status"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Denied", + }, + }, + { + # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Denied", + }, + }, + # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} + { + # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + { + # Inner loop: criterion -> item.to_value: {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + ], + ) def test_search_all_false(self): - op = operators.get_operator('search') + op = operators.get_operator("search") called_function_args = [] def record_function_args(criterion_k, criterion_v, payload_lookup): - called_function_args.append({ - 'criterion_k': criterion_k, - 'criterion_v': criterion_v, - 'payload_lookup': { - 'field_name': payload_lookup.get_value('item.field_name')[0], - 'to_value': payload_lookup.get_value('item.to_value')[0], - }, - }) + called_function_args.append( + { + "criterion_k": criterion_k, + "criterion_v": criterion_v, + "payload_lookup": { + "field_name": payload_lookup.get_value("item.field_name")[0], + "to_value": payload_lookup.get_value("item.to_value")[0], + }, + } + ) return (len(called_function_args) % 2) == 0 payload = [ { - 'field_name': "Status", - 'to_value': "Approved", - }, { - 'field_name': "Assigned to", - 'to_value': "Stanley", - } + "field_name": "Status", + "to_value": "Approved", + }, + { + "field_name": "Assigned to", + "to_value": "Stanley", + }, ] criteria_pattern = { - 'item.field_name': { - 'type': "equals", - 'pattern': "Status", + "item.field_name": { + "type": "equals", + "pattern": "Status", + }, + "item.to_value": { + "type": "equals", + "pattern": "Approved", }, - 'item.to_value': { - 'type': "equals", - 'pattern': "Approved", - } } - result = op(payload, criteria_pattern, 'all', record_function_args) + result = op(payload, criteria_pattern, "all", record_function_args) self.assertFalse(result) - self.assertEqual(called_function_args, [ - # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} - { - # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, { - # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, - # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} - { - # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "equals", - 'pattern': "Status", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - }, { - # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Assigned to", - 'to_value': "Stanley", - }, - } - ]) + self.assertEqual( + called_function_args, + [ + # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} + { + # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + { + # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + # Outer loop: payload -> {'field_name': "Assigned to", 'to_value': "Stanley"} + { + # Inner loop: item.field_name -> {'type': "equals", 'pattern': "Status"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "equals", + "pattern": "Status", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + { + # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Assigned to", + "to_value": "Stanley", + }, + }, + ], + ) def test_search_all_true(self): - op = operators.get_operator('search') + op = operators.get_operator("search") called_function_args = [] def record_function_args(criterion_k, criterion_v, payload_lookup): - called_function_args.append({ - 'criterion_k': criterion_k, - 'criterion_v': criterion_v, - 'payload_lookup': { - 'field_name': payload_lookup.get_value('item.field_name')[0], - 'to_value': payload_lookup.get_value('item.to_value')[0], - }, - }) + called_function_args.append( + { + "criterion_k": criterion_k, + "criterion_v": criterion_v, + "payload_lookup": { + "field_name": payload_lookup.get_value("item.field_name")[0], + "to_value": payload_lookup.get_value("item.to_value")[0], + }, + } + ) return True payload = [ { - 'field_name': "Status", - 'to_value': "Approved", - }, { - 'field_name': "Signed off by", - 'to_value': "Approved", - } + "field_name": "Status", + "to_value": "Approved", + }, + { + "field_name": "Signed off by", + "to_value": "Approved", + }, ] criteria_pattern = { - 'item.field_name': { - 'type': "startswith", - 'pattern': "S", + "item.field_name": { + "type": "startswith", + "pattern": "S", + }, + "item.to_value": { + "type": "equals", + "pattern": "Approved", }, - 'item.to_value': { - 'type': "equals", - 'pattern': "Approved", - } } - result = op(payload, criteria_pattern, 'all', record_function_args) + result = op(payload, criteria_pattern, "all", record_function_args) self.assertTrue(result) - self.assertEqual(called_function_args, [ - # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} - { - # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "startswith", - 'pattern': "S", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, { - # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Status", - 'to_value': "Approved", - }, - }, - # Outer loop: payload -> {'field_name': "Signed off by", 'to_value': "Approved"} - { - # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"} - 'criterion_k': 'item.field_name', - 'criterion_v': { - 'type': "startswith", - 'pattern': "S", - }, - 'payload_lookup': { - 'field_name': "Signed off by", - 'to_value': "Approved", - }, - }, { - # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} - 'criterion_k': 'item.to_value', - 'criterion_v': { - 'type': "equals", - 'pattern': "Approved", - }, - 'payload_lookup': { - 'field_name': "Signed off by", - 'to_value': "Approved", - }, - } - ]) + self.assertEqual( + called_function_args, + [ + # Outer loop: payload -> {'field_name': "Status", 'to_value': "Approved"} + { + # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "startswith", + "pattern": "S", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + { + # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Status", + "to_value": "Approved", + }, + }, + # Outer loop: payload -> {'field_name': "Signed off by", 'to_value': "Approved"} + { + # Inner loop: item.field_name -> {'type': "startswith", 'pattern': "S"} + "criterion_k": "item.field_name", + "criterion_v": { + "type": "startswith", + "pattern": "S", + }, + "payload_lookup": { + "field_name": "Signed off by", + "to_value": "Approved", + }, + }, + { + # Inner loop: item.to_value -> {'type': "equals", 'pattern': "Approved"} + "criterion_k": "item.to_value", + "criterion_v": { + "type": "equals", + "pattern": "Approved", + }, + "payload_lookup": { + "field_name": "Signed off by", + "to_value": "Approved", + }, + }, + ], + ) class OperatorTest(unittest2.TestCase): def test_matchwildcard(self): - op = operators.get_operator('matchwildcard') - self.assertTrue(op('v1', 'v1'), 'Failed matchwildcard.') + op = operators.get_operator("matchwildcard") + self.assertTrue(op("v1", "v1"), "Failed matchwildcard.") - self.assertFalse(op('test foo test', 'foo'), 'Passed matchwildcard.') - self.assertTrue(op('test foo test', '*foo*'), 'Failed matchwildcard.') - self.assertTrue(op('bar', 'b*r'), 'Failed matchwildcard.') - self.assertTrue(op('bar', 'b?r'), 'Failed matchwildcard.') + self.assertFalse(op("test foo test", "foo"), "Passed matchwildcard.") + self.assertTrue(op("test foo test", "*foo*"), "Failed matchwildcard.") + self.assertTrue(op("bar", "b*r"), "Failed matchwildcard.") + self.assertTrue(op("bar", "b?r"), "Failed matchwildcard.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'bar', 'b?r'), 'Failed matchwildcard.') - self.assertTrue(op('bar', b'b?r'), 'Failed matchwildcard.') - self.assertTrue(op(b'bar', b'b?r'), 'Failed matchwildcard.') - self.assertTrue(op(u'bar', b'b?r'), 'Failed matchwildcard.') - self.assertTrue(op(u'bar', u'b?r'), 'Failed matchwildcard.') + self.assertTrue(op(b"bar", "b?r"), "Failed matchwildcard.") + self.assertTrue(op("bar", b"b?r"), "Failed matchwildcard.") + self.assertTrue(op(b"bar", b"b?r"), "Failed matchwildcard.") + self.assertTrue(op("bar", b"b?r"), "Failed matchwildcard.") + self.assertTrue(op("bar", "b?r"), "Failed matchwildcard.") - self.assertFalse(op('1', None), 'Passed matchwildcard with None as criteria_pattern.') + self.assertFalse( + op("1", None), "Passed matchwildcard with None as criteria_pattern." + ) def test_matchregex(self): - op = operators.get_operator('matchregex') - self.assertTrue(op('v1', 'v1$'), 'Failed matchregex.') + op = operators.get_operator("matchregex") + self.assertTrue(op("v1", "v1$"), "Failed matchregex.") # Multi line string, make sure re.DOTALL is used - string = '''ponies + string = """ponies moar foo bar yeah! - ''' - self.assertTrue(op(string, '.*bar.*'), 'Failed matchregex.') + """ + self.assertTrue(op(string, ".*bar.*"), "Failed matchregex.") - string = 'foo\r\nponies\nbar\nfooooo' - self.assertTrue(op(string, '.*ponies.*'), 'Failed matchregex.') + string = "foo\r\nponies\nbar\nfooooo" + self.assertTrue(op(string, ".*ponies.*"), "Failed matchregex.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'foo ponies bar', '.*ponies.*'), 'Failed matchregex.') - self.assertTrue(op('foo ponies bar', b'.*ponies.*'), 'Failed matchregex.') - self.assertTrue(op(b'foo ponies bar', b'.*ponies.*'), 'Failed matchregex.') - self.assertTrue(op(b'foo ponies bar', u'.*ponies.*'), 'Failed matchregex.') - self.assertTrue(op(u'foo ponies bar', u'.*ponies.*'), 'Failed matchregex.') + self.assertTrue(op(b"foo ponies bar", ".*ponies.*"), "Failed matchregex.") + self.assertTrue(op("foo ponies bar", b".*ponies.*"), "Failed matchregex.") + self.assertTrue(op(b"foo ponies bar", b".*ponies.*"), "Failed matchregex.") + self.assertTrue(op(b"foo ponies bar", ".*ponies.*"), "Failed matchregex.") + self.assertTrue(op("foo ponies bar", ".*ponies.*"), "Failed matchregex.") def test_iregex(self): - op = operators.get_operator('iregex') - self.assertTrue(op('V1', 'v1$'), 'Failed iregex.') + op = operators.get_operator("iregex") + self.assertTrue(op("V1", "v1$"), "Failed iregex.") - string = 'fooPONIESbarfooooo' - self.assertTrue(op(string, 'ponies'), 'Failed iregex.') + string = "fooPONIESbarfooooo" + self.assertTrue(op(string, "ponies"), "Failed iregex.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'fooPONIESbarfooooo', 'ponies'), 'Failed iregex.') - self.assertTrue(op('fooPONIESbarfooooo', b'ponies'), 'Failed iregex.') - self.assertTrue(op(b'fooPONIESbarfooooo', b'ponies'), 'Failed iregex.') - self.assertTrue(op(b'fooPONIESbarfooooo', u'ponies'), 'Failed iregex.') - self.assertTrue(op(u'fooPONIESbarfooooo', u'ponies'), 'Failed iregex.') + self.assertTrue(op(b"fooPONIESbarfooooo", "ponies"), "Failed iregex.") + self.assertTrue(op("fooPONIESbarfooooo", b"ponies"), "Failed iregex.") + self.assertTrue(op(b"fooPONIESbarfooooo", b"ponies"), "Failed iregex.") + self.assertTrue(op(b"fooPONIESbarfooooo", "ponies"), "Failed iregex.") + self.assertTrue(op("fooPONIESbarfooooo", "ponies"), "Failed iregex.") def test_iregex_fail(self): - op = operators.get_operator('iregex') - self.assertFalse(op('V1_foo', 'v1$'), 'Passed iregex.') - self.assertFalse(op('1', None), 'Passed iregex with None as criteria_pattern.') + op = operators.get_operator("iregex") + self.assertFalse(op("V1_foo", "v1$"), "Passed iregex.") + self.assertFalse(op("1", None), "Passed iregex with None as criteria_pattern.") def test_regex(self): - op = operators.get_operator('regex') - self.assertTrue(op('v1', 'v1$'), 'Failed regex.') + op = operators.get_operator("regex") + self.assertTrue(op("v1", "v1$"), "Failed regex.") - string = 'fooponiesbarfooooo' - self.assertTrue(op(string, 'ponies'), 'Failed regex.') + string = "fooponiesbarfooooo" + self.assertTrue(op(string, "ponies"), "Failed regex.") # Example with | modifier - string = 'apple ponies oranges' - self.assertTrue(op(string, '(ponies|unicorns)'), 'Failed regex.') + string = "apple ponies oranges" + self.assertTrue(op(string, "(ponies|unicorns)"), "Failed regex.") - string = 'apple unicorns oranges' - self.assertTrue(op(string, '(ponies|unicorns)'), 'Failed regex.') + string = "apple unicorns oranges" + self.assertTrue(op(string, "(ponies|unicorns)"), "Failed regex.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'apples unicorns oranges', '(ponies|unicorns)'), 'Failed regex.') - self.assertTrue(op('apples unicorns oranges', b'(ponies|unicorns)'), 'Failed regex.') - self.assertTrue(op(b'apples unicorns oranges', b'(ponies|unicorns)'), 'Failed regex.') - self.assertTrue(op(b'apples unicorns oranges', u'(ponies|unicorns)'), 'Failed regex.') - self.assertTrue(op(u'apples unicorns oranges', u'(ponies|unicorns)'), 'Failed regex.') - - string = 'apple unicorns oranges' - self.assertFalse(op(string, '(pikachu|snorlax|charmander)'), 'Passed regex.') + self.assertTrue( + op(b"apples unicorns oranges", "(ponies|unicorns)"), "Failed regex." + ) + self.assertTrue( + op("apples unicorns oranges", b"(ponies|unicorns)"), "Failed regex." + ) + self.assertTrue( + op(b"apples unicorns oranges", b"(ponies|unicorns)"), "Failed regex." + ) + self.assertTrue( + op(b"apples unicorns oranges", "(ponies|unicorns)"), "Failed regex." + ) + self.assertTrue( + op("apples unicorns oranges", "(ponies|unicorns)"), "Failed regex." + ) + + string = "apple unicorns oranges" + self.assertFalse(op(string, "(pikachu|snorlax|charmander)"), "Passed regex.") def test_regex_fail(self): - op = operators.get_operator('regex') - self.assertFalse(op('v1_foo', 'v1$'), 'Passed regex.') + op = operators.get_operator("regex") + self.assertFalse(op("v1_foo", "v1$"), "Passed regex.") - string = 'fooPONIESbarfooooo' - self.assertFalse(op(string, 'ponies'), 'Passed regex.') + string = "fooPONIESbarfooooo" + self.assertFalse(op(string, "ponies"), "Passed regex.") - self.assertFalse(op('1', None), 'Passed regex with None as criteria_pattern.') + self.assertFalse(op("1", None), "Passed regex with None as criteria_pattern.") def test_matchregex_case_variants(self): - op = operators.get_operator('MATCHREGEX') - self.assertTrue(op('v1', 'v1$'), 'Failed matchregex.') - op = operators.get_operator('MATCHregex') - self.assertTrue(op('v1', 'v1$'), 'Failed matchregex.') + op = operators.get_operator("MATCHREGEX") + self.assertTrue(op("v1", "v1$"), "Failed matchregex.") + op = operators.get_operator("MATCHregex") + self.assertTrue(op("v1", "v1$"), "Failed matchregex.") def test_matchregex_fail(self): - op = operators.get_operator('matchregex') - self.assertFalse(op('v1_foo', 'v1$'), 'Passed matchregex.') - self.assertFalse(op('1', None), 'Passed matchregex with None as criteria_pattern.') + op = operators.get_operator("matchregex") + self.assertFalse(op("v1_foo", "v1$"), "Passed matchregex.") + self.assertFalse( + op("1", None), "Passed matchregex with None as criteria_pattern." + ) def test_equals_numeric(self): - op = operators.get_operator('equals') - self.assertTrue(op(1, 1), 'Failed equals.') + op = operators.get_operator("equals") + self.assertTrue(op(1, 1), "Failed equals.") def test_equals_string(self): - op = operators.get_operator('equals') - self.assertTrue(op('1', '1'), 'Failed equals.') - self.assertTrue(op('', ''), 'Failed equals.') + op = operators.get_operator("equals") + self.assertTrue(op("1", "1"), "Failed equals.") + self.assertTrue(op("", ""), "Failed equals.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'1', '1'), 'Failed equals.') - self.assertTrue(op('1', b'1'), 'Failed equals.') - self.assertTrue(op(b'1', b'1'), 'Failed equals.') - self.assertTrue(op(b'1', u'1'), 'Failed equals.') - self.assertTrue(op(u'1', u'1'), 'Failed equals.') + self.assertTrue(op(b"1", "1"), "Failed equals.") + self.assertTrue(op("1", b"1"), "Failed equals.") + self.assertTrue(op(b"1", b"1"), "Failed equals.") + self.assertTrue(op(b"1", "1"), "Failed equals.") + self.assertTrue(op("1", "1"), "Failed equals.") def test_equals_fail(self): - op = operators.get_operator('equals') - self.assertFalse(op('1', '2'), 'Passed equals.') - self.assertFalse(op('1', None), 'Passed equals with None as criteria_pattern.') + op = operators.get_operator("equals") + self.assertFalse(op("1", "2"), "Passed equals.") + self.assertFalse(op("1", None), "Passed equals with None as criteria_pattern.") def test_nequals(self): - op = operators.get_operator('nequals') - self.assertTrue(op('foo', 'bar')) - self.assertTrue(op('foo', 'foo1')) - self.assertTrue(op('foo', 'FOO')) - self.assertTrue(op('True', True)) - self.assertTrue(op('None', None)) - - self.assertFalse(op('True', 'True')) + op = operators.get_operator("nequals") + self.assertTrue(op("foo", "bar")) + self.assertTrue(op("foo", "foo1")) + self.assertTrue(op("foo", "FOO")) + self.assertTrue(op("True", True)) + self.assertTrue(op("None", None)) + + self.assertFalse(op("True", "True")) self.assertFalse(op(None, None)) def test_iequals(self): - op = operators.get_operator('iequals') - self.assertTrue(op('ABC', 'ABC'), 'Failed iequals.') - self.assertTrue(op('ABC', 'abc'), 'Failed iequals.') - self.assertTrue(op('AbC', 'aBc'), 'Failed iequals.') + op = operators.get_operator("iequals") + self.assertTrue(op("ABC", "ABC"), "Failed iequals.") + self.assertTrue(op("ABC", "abc"), "Failed iequals.") + self.assertTrue(op("AbC", "aBc"), "Failed iequals.") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'AbC', 'aBc'), 'Failed iequals.') - self.assertTrue(op('AbC', b'aBc'), 'Failed iequals.') - self.assertTrue(op(b'AbC', b'aBc'), 'Failed iequals.') - self.assertTrue(op(b'AbC', u'aBc'), 'Failed iequals.') - self.assertTrue(op(u'AbC', u'aBc'), 'Failed iequals.') + self.assertTrue(op(b"AbC", "aBc"), "Failed iequals.") + self.assertTrue(op("AbC", b"aBc"), "Failed iequals.") + self.assertTrue(op(b"AbC", b"aBc"), "Failed iequals.") + self.assertTrue(op(b"AbC", "aBc"), "Failed iequals.") + self.assertTrue(op("AbC", "aBc"), "Failed iequals.") def test_iequals_fail(self): - op = operators.get_operator('iequals') - self.assertFalse(op('ABC', 'BCA'), 'Passed iequals.') - self.assertFalse(op('1', None), 'Passed iequals with None as criteria_pattern.') + op = operators.get_operator("iequals") + self.assertFalse(op("ABC", "BCA"), "Passed iequals.") + self.assertFalse(op("1", None), "Passed iequals with None as criteria_pattern.") def test_contains(self): - op = operators.get_operator('contains') - self.assertTrue(op('hasystack needle haystack', 'needle')) - self.assertTrue(op('needle', 'needle')) - self.assertTrue(op('needlehaystack', 'needle')) - self.assertTrue(op('needle haystack', 'needle')) - self.assertTrue(op('haystackneedle', 'needle')) - self.assertTrue(op('haystack needle', 'needle')) + op = operators.get_operator("contains") + self.assertTrue(op("hasystack needle haystack", "needle")) + self.assertTrue(op("needle", "needle")) + self.assertTrue(op("needlehaystack", "needle")) + self.assertTrue(op("needle haystack", "needle")) + self.assertTrue(op("haystackneedle", "needle")) + self.assertTrue(op("haystack needle", "needle")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'haystack needle', 'needle')) - self.assertTrue(op('haystack needle', b'needle')) - self.assertTrue(op(b'haystack needle', b'needle')) - self.assertTrue(op(b'haystack needle', u'needle')) - self.assertTrue(op(u'haystack needle', b'needle')) + self.assertTrue(op(b"haystack needle", "needle")) + self.assertTrue(op("haystack needle", b"needle")) + self.assertTrue(op(b"haystack needle", b"needle")) + self.assertTrue(op(b"haystack needle", "needle")) + self.assertTrue(op("haystack needle", b"needle")) def test_contains_fail(self): - op = operators.get_operator('contains') - self.assertFalse(op('hasystack needl haystack', 'needle')) - self.assertFalse(op('needla', 'needle')) - self.assertFalse(op('1', None), 'Passed contains with None as criteria_pattern.') + op = operators.get_operator("contains") + self.assertFalse(op("hasystack needl haystack", "needle")) + self.assertFalse(op("needla", "needle")) + self.assertFalse( + op("1", None), "Passed contains with None as criteria_pattern." + ) def test_icontains(self): - op = operators.get_operator('icontains') - self.assertTrue(op('hasystack nEEdle haystack', 'needle')) - self.assertTrue(op('neeDle', 'NeedlE')) - self.assertTrue(op('needlehaystack', 'needle')) - self.assertTrue(op('NEEDLE haystack', 'NEEDLE')) - self.assertTrue(op('haystackNEEDLE', 'needle')) - self.assertTrue(op('haystack needle', 'NEEDLE')) + op = operators.get_operator("icontains") + self.assertTrue(op("hasystack nEEdle haystack", "needle")) + self.assertTrue(op("neeDle", "NeedlE")) + self.assertTrue(op("needlehaystack", "needle")) + self.assertTrue(op("NEEDLE haystack", "NEEDLE")) + self.assertTrue(op("haystackNEEDLE", "needle")) + self.assertTrue(op("haystack needle", "NEEDLE")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'haystack needle', 'NEEDLE')) - self.assertTrue(op('haystack needle', b'NEEDLE')) - self.assertTrue(op(b'haystack needle', b'NEEDLE')) - self.assertTrue(op(b'haystack needle', u'NEEDLE')) - self.assertTrue(op(u'haystack needle', b'NEEDLE')) + self.assertTrue(op(b"haystack needle", "NEEDLE")) + self.assertTrue(op("haystack needle", b"NEEDLE")) + self.assertTrue(op(b"haystack needle", b"NEEDLE")) + self.assertTrue(op(b"haystack needle", "NEEDLE")) + self.assertTrue(op("haystack needle", b"NEEDLE")) def test_icontains_fail(self): - op = operators.get_operator('icontains') - self.assertFalse(op('hasystack needl haystack', 'needle')) - self.assertFalse(op('needla', 'needle')) - self.assertFalse(op('1', None), 'Passed icontains with None as criteria_pattern.') + op = operators.get_operator("icontains") + self.assertFalse(op("hasystack needl haystack", "needle")) + self.assertFalse(op("needla", "needle")) + self.assertFalse( + op("1", None), "Passed icontains with None as criteria_pattern." + ) def test_ncontains(self): - op = operators.get_operator('ncontains') - self.assertTrue(op('hasystack needle haystack', 'foo')) - self.assertTrue(op('needle', 'foo')) - self.assertTrue(op('needlehaystack', 'needlex')) - self.assertTrue(op('needle haystack', 'needlex')) - self.assertTrue(op('haystackneedle', 'needlex')) - self.assertTrue(op('haystack needle', 'needlex')) + op = operators.get_operator("ncontains") + self.assertTrue(op("hasystack needle haystack", "foo")) + self.assertTrue(op("needle", "foo")) + self.assertTrue(op("needlehaystack", "needlex")) + self.assertTrue(op("needle haystack", "needlex")) + self.assertTrue(op("haystackneedle", "needlex")) + self.assertTrue(op("haystack needle", "needlex")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'haystack needle', 'needlex')) - self.assertTrue(op('haystack needle', b'needlex')) - self.assertTrue(op(b'haystack needle', b'needlex')) - self.assertTrue(op(b'haystack needle', u'needlex')) - self.assertTrue(op(u'haystack needle', b'needlex')) + self.assertTrue(op(b"haystack needle", "needlex")) + self.assertTrue(op("haystack needle", b"needlex")) + self.assertTrue(op(b"haystack needle", b"needlex")) + self.assertTrue(op(b"haystack needle", "needlex")) + self.assertTrue(op("haystack needle", b"needlex")) def test_ncontains_fail(self): - op = operators.get_operator('ncontains') - self.assertFalse(op('hasystack needle haystack', 'needle')) - self.assertFalse(op('needla', 'needla')) - self.assertFalse(op('1', None), 'Passed ncontains with None as criteria_pattern.') + op = operators.get_operator("ncontains") + self.assertFalse(op("hasystack needle haystack", "needle")) + self.assertFalse(op("needla", "needla")) + self.assertFalse( + op("1", None), "Passed ncontains with None as criteria_pattern." + ) def test_incontains(self): - op = operators.get_operator('incontains') - self.assertTrue(op('hasystack needle haystack', 'FOO')) - self.assertTrue(op('needle', 'FOO')) - self.assertTrue(op('needlehaystack', 'needlex')) - self.assertTrue(op('needle haystack', 'needlex')) - self.assertTrue(op('haystackneedle', 'needlex')) - self.assertTrue(op('haystack needle', 'needlex')) + op = operators.get_operator("incontains") + self.assertTrue(op("hasystack needle haystack", "FOO")) + self.assertTrue(op("needle", "FOO")) + self.assertTrue(op("needlehaystack", "needlex")) + self.assertTrue(op("needle haystack", "needlex")) + self.assertTrue(op("haystackneedle", "needlex")) + self.assertTrue(op("haystack needle", "needlex")) def test_incontains_fail(self): - op = operators.get_operator('incontains') - self.assertFalse(op('hasystack needle haystack', 'nEeDle')) - self.assertFalse(op('needlA', 'needlA')) - self.assertFalse(op('1', None), 'Passed incontains with None as criteria_pattern.') + op = operators.get_operator("incontains") + self.assertFalse(op("hasystack needle haystack", "nEeDle")) + self.assertFalse(op("needlA", "needlA")) + self.assertFalse( + op("1", None), "Passed incontains with None as criteria_pattern." + ) def test_startswith(self): - op = operators.get_operator('startswith') - self.assertTrue(op('hasystack needle haystack', 'hasystack')) - self.assertTrue(op('a hasystack needle haystack', 'a ')) + op = operators.get_operator("startswith") + self.assertTrue(op("hasystack needle haystack", "hasystack")) + self.assertTrue(op("a hasystack needle haystack", "a ")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'haystack needle', 'haystack')) - self.assertTrue(op('haystack needle', b'haystack')) - self.assertTrue(op(b'haystack needle', b'haystack')) - self.assertTrue(op(b'haystack needle', u'haystack')) - self.assertTrue(op(u'haystack needle', b'haystack')) + self.assertTrue(op(b"haystack needle", "haystack")) + self.assertTrue(op("haystack needle", b"haystack")) + self.assertTrue(op(b"haystack needle", b"haystack")) + self.assertTrue(op(b"haystack needle", "haystack")) + self.assertTrue(op("haystack needle", b"haystack")) def test_startswith_fail(self): - op = operators.get_operator('startswith') - self.assertFalse(op('hasystack needle haystack', 'needle')) - self.assertFalse(op('a hasystack needle haystack', 'haystack')) - self.assertFalse(op('1', None), 'Passed startswith with None as criteria_pattern.') + op = operators.get_operator("startswith") + self.assertFalse(op("hasystack needle haystack", "needle")) + self.assertFalse(op("a hasystack needle haystack", "haystack")) + self.assertFalse( + op("1", None), "Passed startswith with None as criteria_pattern." + ) def test_istartswith(self): - op = operators.get_operator('istartswith') - self.assertTrue(op('haystack needle haystack', 'HAYstack')) - self.assertTrue(op('HAYSTACK needle haystack', 'haystack')) + op = operators.get_operator("istartswith") + self.assertTrue(op("haystack needle haystack", "HAYstack")) + self.assertTrue(op("HAYSTACK needle haystack", "haystack")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'HAYSTACK needle haystack', 'haystack')) - self.assertTrue(op('HAYSTACK needle haystack', b'haystack')) - self.assertTrue(op(b'HAYSTACK needle haystack', b'haystack')) - self.assertTrue(op(b'HAYSTACK needle haystack', u'haystack')) - self.assertTrue(op(u'HAYSTACK needle haystack', b'haystack')) + self.assertTrue(op(b"HAYSTACK needle haystack", "haystack")) + self.assertTrue(op("HAYSTACK needle haystack", b"haystack")) + self.assertTrue(op(b"HAYSTACK needle haystack", b"haystack")) + self.assertTrue(op(b"HAYSTACK needle haystack", "haystack")) + self.assertTrue(op("HAYSTACK needle haystack", b"haystack")) def test_istartswith_fail(self): - op = operators.get_operator('istartswith') - self.assertFalse(op('hasystack needle haystack', 'NEEDLE')) - self.assertFalse(op('a hasystack needle haystack', 'haystack')) - self.assertFalse(op('1', None), 'Passed istartswith with None as criteria_pattern.') + op = operators.get_operator("istartswith") + self.assertFalse(op("hasystack needle haystack", "NEEDLE")) + self.assertFalse(op("a hasystack needle haystack", "haystack")) + self.assertFalse( + op("1", None), "Passed istartswith with None as criteria_pattern." + ) def test_endswith(self): - op = operators.get_operator('endswith') - self.assertTrue(op('hasystack needle haystackend', 'haystackend')) - self.assertTrue(op('a hasystack needle haystack b', 'b')) + op = operators.get_operator("endswith") + self.assertTrue(op("hasystack needle haystackend", "haystackend")) + self.assertTrue(op("a hasystack needle haystack b", "b")) # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'a hasystack needle haystack b', 'b')) - self.assertTrue(op('a hasystack needle haystack b', b'b')) - self.assertTrue(op(b'a hasystack needle haystack b', b'b')) - self.assertTrue(op(b'a hasystack needle haystack b', u'b')) - self.assertTrue(op(u'a hasystack needle haystack b', b'b')) + self.assertTrue(op(b"a hasystack needle haystack b", "b")) + self.assertTrue(op("a hasystack needle haystack b", b"b")) + self.assertTrue(op(b"a hasystack needle haystack b", b"b")) + self.assertTrue(op(b"a hasystack needle haystack b", "b")) + self.assertTrue(op("a hasystack needle haystack b", b"b")) def test_endswith_fail(self): - op = operators.get_operator('endswith') - self.assertFalse(op('hasystack needle haystackend', 'haystack')) - self.assertFalse(op('a hasystack needle haystack', 'a')) - self.assertFalse(op('1', None), 'Passed endswith with None as criteria_pattern.') + op = operators.get_operator("endswith") + self.assertFalse(op("hasystack needle haystackend", "haystack")) + self.assertFalse(op("a hasystack needle haystack", "a")) + self.assertFalse( + op("1", None), "Passed endswith with None as criteria_pattern." + ) def test_iendswith(self): - op = operators.get_operator('iendswith') - self.assertTrue(op('haystack needle haystackEND', 'HAYstackend')) - self.assertTrue(op('HAYSTACK needle haystackend', 'haystackEND')) + op = operators.get_operator("iendswith") + self.assertTrue(op("haystack needle haystackEND", "HAYstackend")) + self.assertTrue(op("HAYSTACK needle haystackend", "haystackEND")) def test_iendswith_fail(self): - op = operators.get_operator('iendswith') - self.assertFalse(op('hasystack needle haystack', 'NEEDLE')) - self.assertFalse(op('a hasystack needle haystack', 'a ')) - self.assertFalse(op('1', None), 'Passed iendswith with None as criteria_pattern.') + op = operators.get_operator("iendswith") + self.assertFalse(op("hasystack needle haystack", "NEEDLE")) + self.assertFalse(op("a hasystack needle haystack", "a ")) + self.assertFalse( + op("1", None), "Passed iendswith with None as criteria_pattern." + ) def test_lt(self): - op = operators.get_operator('lessthan') - self.assertTrue(op(1, 2), 'Failed lessthan.') + op = operators.get_operator("lessthan") + self.assertTrue(op(1, 2), "Failed lessthan.") def test_lt_char(self): - op = operators.get_operator('lessthan') - self.assertTrue(op('a', 'b'), 'Failed lessthan.') + op = operators.get_operator("lessthan") + self.assertTrue(op("a", "b"), "Failed lessthan.") def test_lt_fail(self): - op = operators.get_operator('lessthan') - self.assertFalse(op(1, 1), 'Passed lessthan.') - self.assertFalse(op('1', None), 'Passed lessthan with None as criteria_pattern.') + op = operators.get_operator("lessthan") + self.assertFalse(op(1, 1), "Passed lessthan.") + self.assertFalse( + op("1", None), "Passed lessthan with None as criteria_pattern." + ) def test_gt(self): - op = operators.get_operator('greaterthan') - self.assertTrue(op(2, 1), 'Failed greaterthan.') + op = operators.get_operator("greaterthan") + self.assertTrue(op(2, 1), "Failed greaterthan.") def test_gt_str(self): - op = operators.get_operator('lessthan') - self.assertTrue(op('aba', 'bcb'), 'Failed greaterthan.') + op = operators.get_operator("lessthan") + self.assertTrue(op("aba", "bcb"), "Failed greaterthan.") def test_gt_fail(self): - op = operators.get_operator('greaterthan') - self.assertFalse(op(2, 3), 'Passed greaterthan.') - self.assertFalse(op('1', None), 'Passed greaterthan with None as criteria_pattern.') + op = operators.get_operator("greaterthan") + self.assertFalse(op(2, 3), "Passed greaterthan.") + self.assertFalse( + op("1", None), "Passed greaterthan with None as criteria_pattern." + ) def test_timediff_lt(self): - op = operators.get_operator('timediff_lt') - self.assertTrue(op(date_utils.get_datetime_utc_now().isoformat(), 10), - 'Failed test_timediff_lt.') + op = operators.get_operator("timediff_lt") + self.assertTrue( + op(date_utils.get_datetime_utc_now().isoformat(), 10), + "Failed test_timediff_lt.", + ) def test_timediff_lt_fail(self): - op = operators.get_operator('timediff_lt') - self.assertFalse(op('2014-07-01T00:01:01.000000', 10), - 'Passed test_timediff_lt.') - self.assertFalse(op('2014-07-01T00:01:01.000000', None), - 'Passed test_timediff_lt with None as criteria_pattern.') + op = operators.get_operator("timediff_lt") + self.assertFalse( + op("2014-07-01T00:01:01.000000", 10), "Passed test_timediff_lt." + ) + self.assertFalse( + op("2014-07-01T00:01:01.000000", None), + "Passed test_timediff_lt with None as criteria_pattern.", + ) def test_timediff_gt(self): - op = operators.get_operator('timediff_gt') - self.assertTrue(op('2014-07-01T00:01:01.000000', 1), - 'Failed test_timediff_gt.') + op = operators.get_operator("timediff_gt") + self.assertTrue(op("2014-07-01T00:01:01.000000", 1), "Failed test_timediff_gt.") def test_timediff_gt_fail(self): - op = operators.get_operator('timediff_gt') - self.assertFalse(op(date_utils.get_datetime_utc_now().isoformat(), 10), - 'Passed test_timediff_gt.') - self.assertFalse(op('2014-07-01T00:01:01.000000', None), - 'Passed test_timediff_gt with None as criteria_pattern.') + op = operators.get_operator("timediff_gt") + self.assertFalse( + op(date_utils.get_datetime_utc_now().isoformat(), 10), + "Passed test_timediff_gt.", + ) + self.assertFalse( + op("2014-07-01T00:01:01.000000", None), + "Passed test_timediff_gt with None as criteria_pattern.", + ) def test_exists(self): - op = operators.get_operator('exists') - self.assertTrue(op(False, None), 'Should return True') - self.assertTrue(op(1, None), 'Should return True') - self.assertTrue(op('foo', None), 'Should return True') - self.assertFalse(op(None, None), 'Should return False') + op = operators.get_operator("exists") + self.assertTrue(op(False, None), "Should return True") + self.assertTrue(op(1, None), "Should return True") + self.assertTrue(op("foo", None), "Should return True") + self.assertFalse(op(None, None), "Should return False") def test_nexists(self): - op = operators.get_operator('nexists') - self.assertFalse(op(False, None), 'Should return False') - self.assertFalse(op(1, None), 'Should return False') - self.assertFalse(op('foo', None), 'Should return False') - self.assertTrue(op(None, None), 'Should return True') + op = operators.get_operator("nexists") + self.assertFalse(op(False, None), "Should return False") + self.assertFalse(op(1, None), "Should return False") + self.assertFalse(op("foo", None), "Should return False") + self.assertTrue(op(None, None), "Should return True") def test_inside(self): - op = operators.get_operator('inside') - self.assertFalse(op('a', None), 'Should return False') - self.assertFalse(op('a', 'bcd'), 'Should return False') - self.assertTrue(op('a', 'abc'), 'Should return True') + op = operators.get_operator("inside") + self.assertFalse(op("a", None), "Should return False") + self.assertFalse(op("a", "bcd"), "Should return False") + self.assertTrue(op("a", "abc"), "Should return True") # Mixing bytes and strings / unicode should still work - self.assertTrue(op(b'a', 'abc'), 'Should return True') - self.assertTrue(op('a', b'abc'), 'Should return True') - self.assertTrue(op(b'a', b'abc'), 'Should return True') + self.assertTrue(op(b"a", "abc"), "Should return True") + self.assertTrue(op("a", b"abc"), "Should return True") + self.assertTrue(op(b"a", b"abc"), "Should return True") def test_ninside(self): - op = operators.get_operator('ninside') - self.assertFalse(op('a', None), 'Should return False') - self.assertFalse(op('a', 'abc'), 'Should return False') - self.assertTrue(op('a', 'bcd'), 'Should return True') + op = operators.get_operator("ninside") + self.assertFalse(op("a", None), "Should return False") + self.assertFalse(op("a", "abc"), "Should return False") + self.assertTrue(op("a", "bcd"), "Should return True") class GetOperatorsTest(unittest2.TestCase): def test_get_operator(self): - self.assertTrue(operators.get_operator('equals')) - self.assertTrue(operators.get_operator('EQUALS')) + self.assertTrue(operators.get_operator("equals")) + self.assertTrue(operators.get_operator("EQUALS")) def test_get_operator_returns_same_operator_with_different_cases(self): - equals = operators.get_operator('equals') - EQUALS = operators.get_operator('EQUALS') - Equals = operators.get_operator('Equals') + equals = operators.get_operator("equals") + EQUALS = operators.get_operator("EQUALS") + Equals = operators.get_operator("Equals") self.assertEqual(equals, EQUALS) self.assertEqual(equals, Equals) def test_get_operator_with_nonexistent_operator(self): with self.assertRaises(Exception): - operators.get_operator('weird') + operators.get_operator("weird") def test_get_allowed_operators(self): # This test will need to change as operators are deprecated diff --git a/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py b/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py index 355e680c85f..630b14ed372 100644 --- a/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py +++ b/st2common/tests/unit/test_pack_action_alias_unit_testing_utils.py @@ -23,111 +23,117 @@ from st2common.exceptions.content import ParseException from st2common.models.db.actionalias import ActionAliasDB -__all__ = [ - 'PackActionAliasUnitTestUtils' -] +__all__ = ["PackActionAliasUnitTestUtils"] -PACK_PATH_1 = os.path.join(get_fixtures_base_path(), 'packs/pack_dir_name_doesnt_match_ref') +PACK_PATH_1 = os.path.join( + get_fixtures_base_path(), "packs/pack_dir_name_doesnt_match_ref" +) class PackActionAliasUnitTestUtils(BaseActionAliasTestCase): - action_alias_name = 'mock' + action_alias_name = "mock" mock_get_action_alias_db_by_name = True def test_assertExtractedParametersMatch_success(self): format_string = self.action_alias_db.formats[0] - command = 'show last 3 metrics for my.host' - expected_parameters = { - 'count': '3', - 'server': 'my.host' - } - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + command = "show last 3 metrics for my.host" + expected_parameters = {"count": "3", "server": "my.host"} + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) format_string = self.action_alias_db.formats[0] - command = 'show last 10 metrics for my.host.example' - expected_parameters = { - 'count': '10', - 'server': 'my.host.example' - } - self.assertExtractedParametersMatch(format_string=format_string, - command=command, - parameters=expected_parameters) + command = "show last 10 metrics for my.host.example" + expected_parameters = {"count": "10", "server": "my.host.example"} + self.assertExtractedParametersMatch( + format_string=format_string, command=command, parameters=expected_parameters + ) def test_assertExtractedParametersMatch_command_doesnt_match_format_string(self): format_string = self.action_alias_db.formats[0] - command = 'show last foo' + command = "show last foo" expected_parameters = {} - expected_msg = ('Command "show last foo" doesn\'t match format string ' - '"show last {{count}} metrics for {{server}}"') - - self.assertRaisesRegexp(ParseException, expected_msg, - self.assertExtractedParametersMatch, - format_string=format_string, - command=command, - parameters=expected_parameters) + expected_msg = ( + 'Command "show last foo" doesn\'t match format string ' + '"show last {{count}} metrics for {{server}}"' + ) + + self.assertRaisesRegexp( + ParseException, + expected_msg, + self.assertExtractedParametersMatch, + format_string=format_string, + command=command, + parameters=expected_parameters, + ) def test_assertCommandMatchesExactlyOneFormatString(self): # Matches single format string - format_strings = [ - 'foo bar {{bar}}', - 'foo bar {{baz}} baz' - ] - command = 'foo bar a test=1' - self.assertCommandMatchesExactlyOneFormatString(format_strings=format_strings, - command=command) + format_strings = ["foo bar {{bar}}", "foo bar {{baz}} baz"] + command = "foo bar a test=1" + self.assertCommandMatchesExactlyOneFormatString( + format_strings=format_strings, command=command + ) # Matches multiple format strings - format_strings = [ - 'foo bar {{bar}}', - 'foo bar {{baz}}' - ] - command = 'foo bar a test=1' - - expected_msg = ('Command "foo bar a test=1" matched multiple format ' - 'strings: foo bar {{bar}}, foo bar {{baz}}') - self.assertRaisesRegexp(AssertionError, expected_msg, - self.assertCommandMatchesExactlyOneFormatString, - format_strings=format_strings, - command=command) + format_strings = ["foo bar {{bar}}", "foo bar {{baz}}"] + command = "foo bar a test=1" + + expected_msg = ( + 'Command "foo bar a test=1" matched multiple format ' + "strings: foo bar {{bar}}, foo bar {{baz}}" + ) + self.assertRaisesRegexp( + AssertionError, + expected_msg, + self.assertCommandMatchesExactlyOneFormatString, + format_strings=format_strings, + command=command, + ) # Doesn't matches any format strings - format_strings = [ - 'foo bar {{bar}}', - 'foo bar {{baz}}' - ] - command = 'does not match foo' - - expected_msg = ('Command "does not match foo" didn\'t match any of the provided format ' - 'strings') - self.assertRaisesRegexp(AssertionError, expected_msg, - self.assertCommandMatchesExactlyOneFormatString, - format_strings=format_strings, - command=command) - - @mock.patch.object(BaseActionAliasTestCase, '_get_base_pack_path', - mock.Mock(return_value=PACK_PATH_1)) + format_strings = ["foo bar {{bar}}", "foo bar {{baz}}"] + command = "does not match foo" + + expected_msg = ( + 'Command "does not match foo" didn\'t match any of the provided format ' + "strings" + ) + self.assertRaisesRegexp( + AssertionError, + expected_msg, + self.assertCommandMatchesExactlyOneFormatString, + format_strings=format_strings, + command=command, + ) + + @mock.patch.object( + BaseActionAliasTestCase, + "_get_base_pack_path", + mock.Mock(return_value=PACK_PATH_1), + ) def test_base_class_works_when_pack_directory_name_doesnt_match_pack_name(self): # Verify that the alias can still be succesfuly loaded from disk even if the pack directory # name doesn't match "pack" resource attribute (aka pack ref) self.mock_get_action_alias_db_by_name = False - action_alias_db = self._get_action_alias_db_by_name(name='alias1') - self.assertEqual(action_alias_db.name, 'alias1') - self.assertEqual(action_alias_db.pack, 'pack_name_not_the_same_as_dir_name') + action_alias_db = self._get_action_alias_db_by_name(name="alias1") + self.assertEqual(action_alias_db.name, "alias1") + self.assertEqual(action_alias_db.pack, "pack_name_not_the_same_as_dir_name") # Note: We mock the original method to make testing of all the edge cases easier def _get_action_alias_db_by_name(self, name): if not self.mock_get_action_alias_db_by_name: - return super(PackActionAliasUnitTestUtils, self)._get_action_alias_db_by_name(name) + return super( + PackActionAliasUnitTestUtils, self + )._get_action_alias_db_by_name(name) values = { - 'name': self.action_alias_name, - 'pack': 'mock', - 'formats': [ - 'show last {{count}} metrics for {{server}}', - ] + "name": self.action_alias_name, + "pack": "mock", + "formats": [ + "show last {{count}} metrics for {{server}}", + ], } action_alias_db = ActionAliasDB(**values) return action_alias_db diff --git a/st2common/tests/unit/test_pack_management.py b/st2common/tests/unit/test_pack_management.py index abc04984892..b350c7d98fd 100644 --- a/st2common/tests/unit/test_pack_management.py +++ b/st2common/tests/unit/test_pack_management.py @@ -21,37 +21,35 @@ import unittest2 BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -PACK_ACTIONS_DIR = os.path.join(BASE_DIR, '../../../contrib/packs/actions') +PACK_ACTIONS_DIR = os.path.join(BASE_DIR, "../../../contrib/packs/actions") PACK_ACTIONS_DIR = os.path.abspath(PACK_ACTIONS_DIR) sys.path.insert(0, PACK_ACTIONS_DIR) from st2common.util.monkey_patch import use_select_poll_workaround + use_select_poll_workaround() from st2common.util.pack_management import eval_repo_url -__all__ = [ - 'InstallPackTestCase' -] +__all__ = ["InstallPackTestCase"] class InstallPackTestCase(unittest2.TestCase): - def test_eval_repo(self): - result = eval_repo_url('stackstorm/st2contrib') - self.assertEqual(result, 'https://github.com/stackstorm/st2contrib') + result = eval_repo_url("stackstorm/st2contrib") + self.assertEqual(result, "https://github.com/stackstorm/st2contrib") - result = eval_repo_url('git@github.com:StackStorm/st2contrib.git') - self.assertEqual(result, 'git@github.com:StackStorm/st2contrib.git') + result = eval_repo_url("git@github.com:StackStorm/st2contrib.git") + self.assertEqual(result, "git@github.com:StackStorm/st2contrib.git") - result = eval_repo_url('gitlab@gitlab.com:StackStorm/st2contrib.git') - self.assertEqual(result, 'gitlab@gitlab.com:StackStorm/st2contrib.git') + result = eval_repo_url("gitlab@gitlab.com:StackStorm/st2contrib.git") + self.assertEqual(result, "gitlab@gitlab.com:StackStorm/st2contrib.git") - repo_url = 'https://github.com/StackStorm/st2contrib.git' + repo_url = "https://github.com/StackStorm/st2contrib.git" result = eval_repo_url(repo_url) self.assertEqual(result, repo_url) - repo_url = 'https://git-wip-us.apache.org/repos/asf/libcloud.git' + repo_url = "https://git-wip-us.apache.org/repos/asf/libcloud.git" result = eval_repo_url(repo_url) self.assertEqual(result, repo_url) diff --git a/st2common/tests/unit/test_param_utils.py b/st2common/tests/unit/test_param_utils.py index 695d17f448a..c2e5810815f 100644 --- a/st2common/tests/unit/test_param_utils.py +++ b/st2common/tests/unit/test_param_utils.py @@ -36,30 +36,31 @@ from st2tests.fixturesloader import FixturesLoader -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_MODELS = { - 'actions': ['action_4_action_context_param.yaml', 'action_system_default.yaml'], - 'runners': ['testrunner1.yaml'] + "actions": ["action_4_action_context_param.yaml", "action_system_default.yaml"], + "runners": ["testrunner1.yaml"], } -FIXTURES = FixturesLoader().load_models(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS) +FIXTURES = FixturesLoader().load_models( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS +) -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ParamsUtilsTest(DbTestCase): - action_db = FIXTURES['actions']['action_4_action_context_param.yaml'] - action_system_default_db = FIXTURES['actions']['action_system_default.yaml'] - runnertype_db = FIXTURES['runners']['testrunner1.yaml'] + action_db = FIXTURES["actions"]["action_4_action_context_param.yaml"] + action_system_default_db = FIXTURES["actions"]["action_system_default.yaml"] + runnertype_db = FIXTURES["runners"]["testrunner1.yaml"] def test_get_finalized_params(self): params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555, - 'runnerimmutable': 'failed_override', - 'actionimmutable': 'failed_override' + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, + "runnerimmutable": "failed_override", + "actionimmutable": "failed_override", } liveaction_db = self._get_liveaction_model(params) @@ -67,289 +68,320 @@ def test_get_finalized_params(self): ParamsUtilsTest.runnertype_db.runner_parameters, ParamsUtilsTest.action_db.parameters, liveaction_db.parameters, - liveaction_db.context) + liveaction_db.context, + ) # Asserts for runner params. # Assert that default values for runner params are resolved. - self.assertEqual(runner_params.get('runnerstr'), 'defaultfoo') + self.assertEqual(runner_params.get("runnerstr"), "defaultfoo") # Assert that a runner param from action exec is picked up. - self.assertEqual(runner_params.get('runnerint'), 555) + self.assertEqual(runner_params.get("runnerint"), 555) # Assert that a runner param can be overridden by action param default. - self.assertEqual(runner_params.get('runnerdummy'), 'actiondummy') + self.assertEqual(runner_params.get("runnerdummy"), "actiondummy") # Assert that a runner param default can be overridden by 'falsey' action param default, # (timeout: 0 case). - self.assertEqual(runner_params.get('runnerdefaultint'), 0) + self.assertEqual(runner_params.get("runnerdefaultint"), 0) # Assert that an immutable param cannot be overridden by action param or execution param. - self.assertEqual(runner_params.get('runnerimmutable'), 'runnerimmutable') + self.assertEqual(runner_params.get("runnerimmutable"), "runnerimmutable") # Asserts for action params. - self.assertEqual(action_params.get('actionstr'), 'foo') + self.assertEqual(action_params.get("actionstr"), "foo") # Assert that a param that is provided in action exec that isn't in action or runner params # isn't in resolved params. - self.assertEqual(action_params.get('some_key_that_aint_exist_in_action_or_runner'), None) + self.assertEqual( + action_params.get("some_key_that_aint_exist_in_action_or_runner"), None + ) # Assert that an immutable param cannot be overridden by execution param. - self.assertEqual(action_params.get('actionimmutable'), 'actionimmutable') + self.assertEqual(action_params.get("actionimmutable"), "actionimmutable") # Assert that an action context param is set correctly. - self.assertEqual(action_params.get('action_api_user'), 'noob') + self.assertEqual(action_params.get("action_api_user"), "noob") # Assert that none of runner params are present in action_params. for k in action_params: - self.assertNotIn(k, runner_params, 'Param ' + k + ' is a runner param.') + self.assertNotIn(k, runner_params, "Param " + k + " is a runner param.") def test_get_finalized_params_system_values(self): - KeyValuePair.add_or_update(KeyValuePairDB(name='actionstr', value='foo')) - KeyValuePair.add_or_update(KeyValuePairDB(name='actionnumber', value='1.0')) - params = { - 'runnerint': 555 - } + KeyValuePair.add_or_update(KeyValuePairDB(name="actionstr", value="foo")) + KeyValuePair.add_or_update(KeyValuePairDB(name="actionnumber", value="1.0")) + params = {"runnerint": 555} liveaction_db = self._get_liveaction_model(params) runner_params, action_params = param_utils.get_finalized_params( ParamsUtilsTest.runnertype_db.runner_parameters, ParamsUtilsTest.action_system_default_db.parameters, liveaction_db.parameters, - liveaction_db.context) + liveaction_db.context, + ) # Asserts for runner params. # Assert that default values for runner params are resolved. - self.assertEqual(runner_params.get('runnerstr'), 'defaultfoo') + self.assertEqual(runner_params.get("runnerstr"), "defaultfoo") # Assert that a runner param from action exec is picked up. - self.assertEqual(runner_params.get('runnerint'), 555) + self.assertEqual(runner_params.get("runnerint"), 555) # Assert that an immutable param cannot be overridden by action param or execution param. - self.assertEqual(runner_params.get('runnerimmutable'), 'runnerimmutable') + self.assertEqual(runner_params.get("runnerimmutable"), "runnerimmutable") # Asserts for action params. - self.assertEqual(action_params.get('actionstr'), 'foo') - self.assertEqual(action_params.get('actionnumber'), 1.0) + self.assertEqual(action_params.get("actionstr"), "foo") + self.assertEqual(action_params.get("actionnumber"), 1.0) def test_get_finalized_params_action_immutable(self): params = { - 'actionstr': 'foo', - 'some_key_that_aint_exist_in_action_or_runner': 'bar', - 'runnerint': 555, - 'actionimmutable': 'failed_override' + "actionstr": "foo", + "some_key_that_aint_exist_in_action_or_runner": "bar", + "runnerint": 555, + "actionimmutable": "failed_override", } liveaction_db = self._get_liveaction_model(params) - action_context = {'api_user': None} + action_context = {"api_user": None} runner_params, action_params = param_utils.get_finalized_params( ParamsUtilsTest.runnertype_db.runner_parameters, ParamsUtilsTest.action_db.parameters, liveaction_db.parameters, - action_context) + action_context, + ) # Asserts for runner params. # Assert that default values for runner params are resolved. - self.assertEqual(runner_params.get('runnerstr'), 'defaultfoo') + self.assertEqual(runner_params.get("runnerstr"), "defaultfoo") # Assert that a runner param from action exec is picked up. - self.assertEqual(runner_params.get('runnerint'), 555) + self.assertEqual(runner_params.get("runnerint"), 555) # Assert that a runner param can be overridden by action param default. - self.assertEqual(runner_params.get('runnerdummy'), 'actiondummy') + self.assertEqual(runner_params.get("runnerdummy"), "actiondummy") # Asserts for action params. - self.assertEqual(action_params.get('actionstr'), 'foo') + self.assertEqual(action_params.get("actionstr"), "foo") # Assert that a param that is provided in action exec that isn't in action or runner params # isn't in resolved params. - self.assertEqual(action_params.get('some_key_that_aint_exist_in_action_or_runner'), None) + self.assertEqual( + action_params.get("some_key_that_aint_exist_in_action_or_runner"), None + ) def test_get_finalized_params_empty(self): params = {} runner_param_info = {} action_param_info = {} - action_context = {'user': None} + action_context = {"user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) + runner_param_info, action_param_info, params, action_context + ) self.assertEqual(r_runner_params, params) self.assertEqual(r_action_params, params) def test_get_finalized_params_none(self): - params = { - 'r1': None, - 'a1': None - } - runner_param_info = {'r1': {}} - action_param_info = {'a1': {}} - action_context = {'api_user': None} + params = {"r1": None, "a1": None} + runner_param_info = {"r1": {}} + action_param_info = {"a1": {}} + action_context = {"api_user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': None}) - self.assertEqual(r_action_params, {'a1': None}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": None}) + self.assertEqual(r_action_params, {"a1": None}) def test_get_finalized_params_no_cast(self): params = { - 'r1': '{{r2}}', - 'r2': 1, - 'a1': True, - 'a2': '{{r1}} {{a1}}', - 'a3': '{{action_context.api_user}}' - } - runner_param_info = {'r1': {}, 'r2': {}} - action_param_info = {'a1': {}, 'a2': {}, 'a3': {}} - action_context = {'api_user': 'noob'} + "r1": "{{r2}}", + "r2": 1, + "a1": True, + "a2": "{{r1}} {{a1}}", + "a3": "{{action_context.api_user}}", + } + runner_param_info = {"r1": {}, "r2": {}} + action_param_info = {"a1": {}, "a2": {}, "a3": {}} + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': u'1', 'r2': 1}) - self.assertEqual(r_action_params, {'a1': True, 'a2': u'1 True', 'a3': 'noob'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": "1", "r2": 1}) + self.assertEqual(r_action_params, {"a1": True, "a2": "1 True", "a3": "noob"}) def test_get_finalized_params_with_cast(self): # Note : In this test runner_params.r1 has a string value. However per runner_param_info the # type is an integer. The expected type is considered and cast is performed accordingly. params = { - 'r1': '{{r2}}', - 'r2': 1, - 'a1': True, - 'a2': '{{a1}}', - 'a3': '{{action_context.api_user}}' + "r1": "{{r2}}", + "r2": 1, + "a1": True, + "a2": "{{a1}}", + "a3": "{{action_context.api_user}}", } - runner_param_info = {'r1': {'type': 'integer'}, 'r2': {'type': 'integer'}} - action_param_info = {'a1': {'type': 'boolean'}, 'a2': {'type': 'boolean'}, 'a3': {}} - action_context = {'api_user': 'noob'} + runner_param_info = {"r1": {"type": "integer"}, "r2": {"type": "integer"}} + action_param_info = { + "a1": {"type": "boolean"}, + "a2": {"type": "boolean"}, + "a3": {}, + } + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': 1, 'r2': 1}) - self.assertEqual(r_action_params, {'a1': True, 'a2': True, 'a3': 'noob'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": 1, "r2": 1}) + self.assertEqual(r_action_params, {"a1": True, "a2": True, "a3": "noob"}) def test_get_finalized_params_with_cast_overriden(self): params = { - 'r1': '{{r2}}', - 'r2': 1, - 'a1': '{{r1}}', - 'a2': '{{r1}}', - 'a3': '{{r1}}' + "r1": "{{r2}}", + "r2": 1, + "a1": "{{r1}}", + "a2": "{{r1}}", + "a3": "{{r1}}", } - runner_param_info = {'r1': {'type': 'integer'}, 'r2': {'type': 'integer'}} - action_param_info = {'a1': {'type': 'boolean'}, 'a2': {'type': 'string'}, - 'a3': {'type': 'integer'}, 'r1': {'type': 'string'}} - action_context = {'api_user': 'noob'} + runner_param_info = {"r1": {"type": "integer"}, "r2": {"type": "integer"}} + action_param_info = { + "a1": {"type": "boolean"}, + "a2": {"type": "string"}, + "a3": {"type": "integer"}, + "r1": {"type": "string"}, + } + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': 1, 'r2': 1}) - self.assertEqual(r_action_params, {'a1': 1, 'a2': u'1', 'a3': 1}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": 1, "r2": 1}) + self.assertEqual(r_action_params, {"a1": 1, "a2": "1", "a3": 1}) def test_get_finalized_params_cross_talk_no_cast(self): params = { - 'r1': '{{a1}}', - 'r2': 1, - 'a1': True, - 'a2': '{{r1}} {{a1}}', - 'a3': '{{action_context.api_user}}' - } - runner_param_info = {'r1': {}, 'r2': {}} - action_param_info = {'a1': {}, 'a2': {}, 'a3': {}} - action_context = {'api_user': 'noob'} + "r1": "{{a1}}", + "r2": 1, + "a1": True, + "a2": "{{r1}} {{a1}}", + "a3": "{{action_context.api_user}}", + } + runner_param_info = {"r1": {}, "r2": {}} + action_param_info = {"a1": {}, "a2": {}, "a3": {}} + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': u'True', 'r2': 1}) - self.assertEqual(r_action_params, {'a1': True, 'a2': u'True True', 'a3': 'noob'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": "True", "r2": 1}) + self.assertEqual(r_action_params, {"a1": True, "a2": "True True", "a3": "noob"}) def test_get_finalized_params_cross_talk_with_cast(self): params = { - 'r1': '{{a1}}', - 'r2': 1, - 'r3': 1, - 'a1': True, - 'a2': '{{r1}},{{a1}},{{a3}},{{r3}}', - 'a3': '{{a1}}' + "r1": "{{a1}}", + "r2": 1, + "r3": 1, + "a1": True, + "a2": "{{r1}},{{a1}},{{a3}},{{r3}}", + "a3": "{{a1}}", } - runner_param_info = {'r1': {'type': 'boolean'}, 'r2': {'type': 'integer'}, 'r3': {}} - action_param_info = {'a1': {'type': 'boolean'}, 'a2': {'type': 'array'}, 'a3': {}} - action_context = {'user': None} + runner_param_info = { + "r1": {"type": "boolean"}, + "r2": {"type": "integer"}, + "r3": {}, + } + action_param_info = { + "a1": {"type": "boolean"}, + "a2": {"type": "array"}, + "a3": {}, + } + action_context = {"user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': True, 'r2': 1, 'r3': 1}) - self.assertEqual(r_action_params, {'a1': True, 'a2': (True, True, True, 1), 'a3': u'True'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": True, "r2": 1, "r3": 1}) + self.assertEqual( + r_action_params, {"a1": True, "a2": (True, True, True, 1), "a3": "True"} + ) def test_get_finalized_params_order(self): - params = { - 'r1': 'p1', - 'r2': 'p2', - 'r3': 'p3', - 'a1': 'p4', - 'a2': 'p5' - } - runner_param_info = {'r1': {}, 'r2': {'default': 'r2'}, 'r3': {'default': 'r3'}} - action_param_info = {'a1': {}, 'a2': {'default': 'a2'}, 'r3': {'default': 'a3'}} - action_context = {'api_user': 'noob'} + params = {"r1": "p1", "r2": "p2", "r3": "p3", "a1": "p4", "a2": "p5"} + runner_param_info = {"r1": {}, "r2": {"default": "r2"}, "r3": {"default": "r3"}} + action_param_info = {"a1": {}, "a2": {"default": "a2"}, "r3": {"default": "a3"}} + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': u'p1', 'r2': u'p2', 'r3': u'p3'}) - self.assertEqual(r_action_params, {'a1': u'p4', 'a2': u'p5'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": "p1", "r2": "p2", "r3": "p3"}) + self.assertEqual(r_action_params, {"a1": "p4", "a2": "p5"}) params = {} - runner_param_info = {'r1': {}, 'r2': {'default': 'r2'}, 'r3': {'default': 'r3'}} - action_param_info = {'a1': {}, 'a2': {'default': 'a2'}, 'r3': {'default': 'a3'}} - action_context = {'api_user': 'noob'} + runner_param_info = {"r1": {}, "r2": {"default": "r2"}, "r3": {"default": "r3"}} + action_param_info = {"a1": {}, "a2": {"default": "a2"}, "r3": {"default": "a3"}} + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': None, 'r2': u'r2', 'r3': u'a3'}) - self.assertEqual(r_action_params, {'a1': None, 'a2': u'a2'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": None, "r2": "r2", "r3": "a3"}) + self.assertEqual(r_action_params, {"a1": None, "a2": "a2"}) params = {} - runner_param_info = {'r1': {}, 'r2': {'default': 'r2'}, 'r3': {}} - action_param_info = {'r1': {}, 'r2': {}, 'r3': {'default': 'a3'}} - action_context = {'api_user': 'noob'} + runner_param_info = {"r1": {}, "r2": {"default": "r2"}, "r3": {}} + action_param_info = {"r1": {}, "r2": {}, "r3": {"default": "a3"}} + action_context = {"api_user": "noob"} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': None, 'r2': u'r2', 'r3': u'a3'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": None, "r2": "r2", "r3": "a3"}) def test_get_finalized_params_non_existent_template_key_in_action_context(self): params = { - 'r1': 'foo', - 'r2': 2, - 'a1': 'i love tests', - 'a2': '{{action_context.lorem_ipsum}}' - } - runner_param_info = {'r1': {'type': 'string'}, 'r2': {'type': 'integer'}} - action_param_info = {'a1': {'type': 'string'}, 'a2': {'type': 'string'}} - action_context = {'api_user': 'noob', 'source_channel': 'reddit'} + "r1": "foo", + "r2": 2, + "a1": "i love tests", + "a2": "{{action_context.lorem_ipsum}}", + } + runner_param_info = {"r1": {"type": "string"}, "r2": {"type": "integer"}} + action_param_info = {"a1": {"type": "string"}, "a2": {"type": "string"}} + action_context = {"api_user": "noob", "source_channel": "reddit"} try: r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.fail('This should have thrown because we are trying to deref a key in ' + - 'action context that ain\'t exist.') + runner_param_info, action_param_info, params, action_context + ) + self.fail( + "This should have thrown because we are trying to deref a key in " + + "action context that ain't exist." + ) except ParamException as e: - error_msg = 'Failed to render parameter "a2": \'dict object\' ' + \ - 'has no attribute \'lorem_ipsum\'' + error_msg = ( + "Failed to render parameter \"a2\": 'dict object' " + + "has no attribute 'lorem_ipsum'" + ) self.assertIn(error_msg, six.text_type(e)) pass def test_unicode_value_casting(self): - rendered = {'a1': 'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2'} - parameter_schemas = {'a1': {'type': 'string'}} + rendered = {"a1": "unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2"} + parameter_schemas = {"a1": {"type": "string"}} - result = param_utils._cast_params(rendered=rendered, - parameter_schemas=parameter_schemas) + result = param_utils._cast_params( + rendered=rendered, parameter_schemas=parameter_schemas + ) if six.PY3: - expected = { - 'a1': (u'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2') - } + expected = {"a1": ("unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2")} else: expected = { - 'a1': (u'unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc' - u'\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2') + "a1": ( + "unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc" + "\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2" + ) } self.assertEqual(result, expected) def test_get_finalized_params_with_casting_unicode_values(self): - params = {'a1': 'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2'} + params = {"a1": "unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2"} runner_param_info = {} - action_param_info = {'a1': {'type': 'string'}} + action_param_info = {"a1": {"type": "string"}} - action_context = {'user': None} + action_context = {"user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) + runner_param_info, action_param_info, params, action_context + ) if six.PY3: - expected_action_params = { - 'a1': (u'unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2') - } + expected_action_params = {"a1": ("unicode1 ٩(̾●̮̮̃̾•̃̾)۶ unicode2")} else: expected_action_params = { - 'a1': (u'unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc' - u'\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2') + "a1": ( + "unicode1 \xd9\xa9(\xcc\xbe\xe2\x97\x8f\xcc\xae\xcc\xae\xcc" + "\x83\xcc\xbe\xe2\x80\xa2\xcc\x83\xcc\xbe)\xdb\xb6 unicode2" + ) } self.assertEqual(r_runner_params, {}) @@ -359,59 +391,53 @@ def test_get_finalized_params_with_dict(self): # Note : In this test runner_params.r1 has a string value. However per runner_param_info the # type is an integer. The expected type is considered and cast is performed accordingly. params = { - 'r1': '{{r2}}', - 'r2': {'r2.1': 1}, - 'a1': True, - 'a2': '{{a1}}', - 'a3': { - 'test': '{{a1}}', - 'test1': '{{a4}}', - 'test2': '{{a5}}', + "r1": "{{r2}}", + "r2": {"r2.1": 1}, + "a1": True, + "a2": "{{a1}}", + "a3": { + "test": "{{a1}}", + "test1": "{{a4}}", + "test2": "{{a5}}", }, - 'a4': 3, - 'a5': ['1', '{{a1}}'] + "a4": 3, + "a5": ["1", "{{a1}}"], } - runner_param_info = {'r1': {'type': 'object'}, 'r2': {'type': 'object'}} + runner_param_info = {"r1": {"type": "object"}, "r2": {"type": "object"}} action_param_info = { - 'a1': { - 'type': 'boolean', + "a1": { + "type": "boolean", }, - 'a2': { - 'type': 'boolean', + "a2": { + "type": "boolean", }, - 'a3': { - 'type': 'object', + "a3": { + "type": "object", }, - 'a4': { - 'type': 'integer', + "a4": { + "type": "integer", }, - 'a5': { - 'type': 'array', + "a5": { + "type": "array", }, } r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, {'user': None}) - self.assertEqual( - r_runner_params, {'r1': {'r2.1': 1}, 'r2': {'r2.1': 1}}) + runner_param_info, action_param_info, params, {"user": None} + ) + self.assertEqual(r_runner_params, {"r1": {"r2.1": 1}, "r2": {"r2.1": 1}}) self.assertEqual( r_action_params, { - 'a1': True, - 'a2': True, - 'a3': { - 'test': True, - 'test1': 3, - 'test2': [ - '1', - True - ], + "a1": True, + "a2": True, + "a3": { + "test": True, + "test1": 3, + "test2": ["1", True], }, - 'a4': 3, - 'a5': [ - '1', - True - ], - } + "a4": 3, + "a5": ["1", True], + }, ) def test_get_finalized_params_with_list(self): @@ -419,183 +445,177 @@ def test_get_finalized_params_with_list(self): # type is an integer. The expected type is considered and cast is performed accordingly. self.maxDiff = None params = { - 'r1': '{{r2}}', - 'r2': ['1', '2'], - 'a1': True, - 'a2': 'Test', - 'a3': 'Test2', - 'a4': '{{a1}}', - 'a5': ['{{a2}}', '{{a3}}'], - 'a6': [ - ['{{r2}}', '{{a2}}'], - ['{{a3}}', '{{a1}}'], + "r1": "{{r2}}", + "r2": ["1", "2"], + "a1": True, + "a2": "Test", + "a3": "Test2", + "a4": "{{a1}}", + "a5": ["{{a2}}", "{{a3}}"], + "a6": [ + ["{{r2}}", "{{a2}}"], + ["{{a3}}", "{{a1}}"], [ - '{{a7}}', - 'This should be rendered as a string {{a1}}', - '{{a1}} This, too, should be rendered as a string {{a1}}', - ] + "{{a7}}", + "This should be rendered as a string {{a1}}", + "{{a1}} This, too, should be rendered as a string {{a1}}", + ], ], - 'a7': 5, + "a7": 5, } - runner_param_info = {'r1': {'type': 'array'}, 'r2': {'type': 'array'}} + runner_param_info = {"r1": {"type": "array"}, "r2": {"type": "array"}} action_param_info = { - 'a1': {'type': 'boolean'}, - 'a2': {'type': 'string'}, - 'a3': {'type': 'string'}, - 'a4': {'type': 'boolean'}, - 'a5': {'type': 'array'}, - 'a6': {'type': 'array'}, - 'a7': {'type': 'integer'}, + "a1": {"type": "boolean"}, + "a2": {"type": "string"}, + "a3": {"type": "string"}, + "a4": {"type": "boolean"}, + "a5": {"type": "array"}, + "a6": {"type": "array"}, + "a7": {"type": "integer"}, } r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, {'user': None}) - self.assertEqual(r_runner_params, {'r1': ['1', '2'], 'r2': ['1', '2']}) + runner_param_info, action_param_info, params, {"user": None} + ) + self.assertEqual(r_runner_params, {"r1": ["1", "2"], "r2": ["1", "2"]}) self.assertEqual( r_action_params, { - 'a1': True, - 'a2': 'Test', - 'a3': 'Test2', - 'a4': True, - 'a5': ['Test', 'Test2'], - 'a6': [ - [['1', '2'], 'Test'], - ['Test2', True], + "a1": True, + "a2": "Test", + "a3": "Test2", + "a4": True, + "a5": ["Test", "Test2"], + "a6": [ + [["1", "2"], "Test"], + ["Test2", True], [ 5, - u'This should be rendered as a string True', - u'True This, too, should be rendered as a string True' - ] + "This should be rendered as a string True", + "True This, too, should be rendered as a string True", + ], ], - 'a7': 5, - } + "a7": 5, + }, ) def test_get_finalized_params_with_cyclic_dependency(self): - params = {'r1': '{{r2}}', 'r2': '{{r1}}'} - runner_param_info = {'r1': {}, 'r2': {}} + params = {"r1": "{{r2}}", "r2": "{{r1}}"} + runner_param_info = {"r1": {}, "r2": {}} action_param_info = {} test_pass = True try: - param_utils.get_finalized_params(runner_param_info, - action_param_info, - params, - {'user': None}) + param_utils.get_finalized_params( + runner_param_info, action_param_info, params, {"user": None} + ) test_pass = False except ParamException as e: - test_pass = six.text_type(e).find('Cyclic') == 0 + test_pass = six.text_type(e).find("Cyclic") == 0 self.assertTrue(test_pass) def test_get_finalized_params_with_missing_dependency(self): - params = {'r1': '{{r3}}', 'r2': '{{r3}}'} - runner_param_info = {'r1': {}, 'r2': {}} + params = {"r1": "{{r3}}", "r2": "{{r3}}"} + runner_param_info = {"r1": {}, "r2": {}} action_param_info = {} test_pass = True try: - param_utils.get_finalized_params(runner_param_info, - action_param_info, - params, - {'user': None}) + param_utils.get_finalized_params( + runner_param_info, action_param_info, params, {"user": None} + ) test_pass = False except ParamException as e: - test_pass = six.text_type(e).find('Dependency') == 0 + test_pass = six.text_type(e).find("Dependency") == 0 self.assertTrue(test_pass) params = {} - runner_param_info = {'r1': {'default': '{{r3}}'}, 'r2': {'default': '{{r3}}'}} + runner_param_info = {"r1": {"default": "{{r3}}"}, "r2": {"default": "{{r3}}"}} action_param_info = {} test_pass = True try: - param_utils.get_finalized_params(runner_param_info, - action_param_info, - params, - {'user': None}) + param_utils.get_finalized_params( + runner_param_info, action_param_info, params, {"user": None} + ) test_pass = False except ParamException as e: - test_pass = six.text_type(e).find('Dependency') == 0 + test_pass = six.text_type(e).find("Dependency") == 0 self.assertTrue(test_pass) def test_get_finalized_params_no_double_rendering(self): - params = { - 'r1': '{{ action_context.h1 }}{{ action_context.h2 }}' - } - runner_param_info = {'r1': {}} + params = {"r1": "{{ action_context.h1 }}{{ action_context.h2 }}"} + runner_param_info = {"r1": {}} action_param_info = {} - action_context = { - 'h1': '{', - 'h2': '{ missing }}', - 'user': None - } + action_context = {"h1": "{", "h2": "{ missing }}", "user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) - self.assertEqual(r_runner_params, {'r1': '{{ missing }}'}) + runner_param_info, action_param_info, params, action_context + ) + self.assertEqual(r_runner_params, {"r1": "{{ missing }}"}) self.assertEqual(r_action_params, {}) def test_get_finalized_params_jinja_filters(self): - params = {'cmd': 'echo {{"1.6.0" | version_bump_minor}}'} - runner_param_info = {'r1': {}} - action_param_info = {'cmd': {}} - action_context = {'user': None} + params = {"cmd": 'echo {{"1.6.0" | version_bump_minor}}'} + runner_param_info = {"r1": {}} + action_param_info = {"cmd": {}} + action_context = {"user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) + runner_param_info, action_param_info, params, action_context + ) - self.assertEqual(r_action_params['cmd'], "echo 1.7.0") + self.assertEqual(r_action_params["cmd"], "echo 1.7.0") def test_get_finalized_params_param_rendering_failure(self): - params = {'cmd': '{{a2.foo}}', 'a2': 'test'} - action_param_info = {'cmd': {}, 'a2': {}} + params = {"cmd": "{{a2.foo}}", "a2": "test"} + action_param_info = {"cmd": {}, "a2": {}} expected_msg = 'Failed to render parameter "cmd": .*' - self.assertRaisesRegexp(ParamException, - expected_msg, - param_utils.get_finalized_params, - runnertype_parameter_info={}, - action_parameter_info=action_param_info, - liveaction_parameters=params, - action_context={'user': None}) + self.assertRaisesRegexp( + ParamException, + expected_msg, + param_utils.get_finalized_params, + runnertype_parameter_info={}, + action_parameter_info=action_param_info, + liveaction_parameters=params, + action_context={"user": None}, + ) def test_get_finalized_param_object_contains_template_notation_in_the_value(self): - runner_param_info = {'r1': {}} + runner_param_info = {"r1": {}} action_param_info = { - 'params': { - 'type': 'object', - 'default': { - 'host': '{{host}}', - 'port': '{{port}}', - 'path': '/bar'} + "params": { + "type": "object", + "default": {"host": "{{host}}", "port": "{{port}}", "path": "/bar"}, } } - params = { - 'host': 'lolcathost', - 'port': 5555 - } - action_context = {'user': None} + params = {"host": "lolcathost", "port": 5555} + action_context = {"user": None} r_runner_params, r_action_params = param_utils.get_finalized_params( - runner_param_info, action_param_info, params, action_context) + runner_param_info, action_param_info, params, action_context + ) - expected_params = { - 'host': 'lolcathost', - 'port': 5555, - 'path': '/bar' - } - self.assertEqual(r_action_params['params'], expected_params) + expected_params = {"host": "lolcathost", "port": 5555, "path": "/bar"} + self.assertEqual(r_action_params["params"], expected_params) def test_cast_param_referenced_action_doesnt_exist(self): # Make sure the function throws if the action doesnt exist expected_msg = 'Action with ref "foo.doesntexist" doesn\'t exist' - self.assertRaisesRegexp(ValueError, expected_msg, action_param_utils.cast_params, - action_ref='foo.doesntexist', params={}) + self.assertRaisesRegexp( + ValueError, + expected_msg, + action_param_utils.cast_params, + action_ref="foo.doesntexist", + params={}, + ) def test_get_finalized_params_with_config(self): - with mock.patch('st2common.util.config_loader.ContentPackConfigLoader') as config_loader: + with mock.patch( + "st2common.util.config_loader.ContentPackConfigLoader" + ) as config_loader: config_loader().get_config.return_value = { - 'generic_config_param': 'So generic' + "generic_config_param": "So generic" } params = { - 'config_param': '{{config_context.generic_config_param}}', + "config_param": "{{config_context.generic_config_param}}", } liveaction_db = self._get_liveaction_model(params, True) @@ -603,369 +623,327 @@ def test_get_finalized_params_with_config(self): ParamsUtilsTest.runnertype_db.runner_parameters, ParamsUtilsTest.action_db.parameters, liveaction_db.parameters, - liveaction_db.context) - self.assertEqual( - action_params.get('config_param'), - 'So generic' + liveaction_db.context, ) + self.assertEqual(action_params.get("config_param"), "So generic") def test_get_config(self): - with mock.patch('st2common.util.config_loader.ContentPackConfigLoader') as config_loader: - mock_config_return = { - 'generic_config_param': 'So generic' - } + with mock.patch( + "st2common.util.config_loader.ContentPackConfigLoader" + ) as config_loader: + mock_config_return = {"generic_config_param": "So generic"} config_loader().get_config.return_value = mock_config_return self.assertEqual(get_config(None, None), {}) - self.assertEqual(get_config('pack', None), {}) - self.assertEqual(get_config(None, 'user'), {}) - self.assertEqual( - get_config('pack', 'user'), mock_config_return - ) + self.assertEqual(get_config("pack", None), {}) + self.assertEqual(get_config(None, "user"), {}) + self.assertEqual(get_config("pack", "user"), mock_config_return) - config_loader.assert_called_with(pack_name='pack', user='user') + config_loader.assert_called_with(pack_name="pack", user="user") config_loader().get_config.assert_called_once() def _get_liveaction_model(self, params, with_config_context=False): - status = 'initializing' + status = "initializing" start_timestamp = date_utils.get_datetime_utc_now() - action_ref = ResourceReference(name=ParamsUtilsTest.action_db.name, - pack=ParamsUtilsTest.action_db.pack).ref - liveaction_db = LiveActionDB(status=status, start_timestamp=start_timestamp, - action=action_ref, parameters=params) + action_ref = ResourceReference( + name=ParamsUtilsTest.action_db.name, pack=ParamsUtilsTest.action_db.pack + ).ref + liveaction_db = LiveActionDB( + status=status, + start_timestamp=start_timestamp, + action=action_ref, + parameters=params, + ) liveaction_db.context = { - 'api_user': 'noob', - 'source_channel': 'reddit', + "api_user": "noob", + "source_channel": "reddit", } if with_config_context: - liveaction_db.context.update( - { - 'pack': 'generic', - 'user': 'st2admin' - } - ) + liveaction_db.context.update({"pack": "generic", "user": "st2admin"}) return liveaction_db def test_get_value_from_datastore_through_render_live_params(self): # Register datastore value to be refered by this test-case register_kwargs = [ - {'name': 'test_key', 'value': 'foo'}, - {'name': 'user1:test_key', 'value': 'bar', 'scope': FULL_USER_SCOPE}, - {'name': '%s:test_key' % cfg.CONF.system_user.user, 'value': 'baz', - 'scope': FULL_USER_SCOPE}, + {"name": "test_key", "value": "foo"}, + {"name": "user1:test_key", "value": "bar", "scope": FULL_USER_SCOPE}, + { + "name": "%s:test_key" % cfg.CONF.system_user.user, + "value": "baz", + "scope": FULL_USER_SCOPE, + }, ] for kwargs in register_kwargs: KeyValuePair.add_or_update(KeyValuePairDB(**kwargs)) # Assert that datastore value can be got via the Jinja expression from individual scopes. - context = {'user': 'user1'} + context = {"user": "user1"} param = { - 'system_value': {'default': '{{ st2kv.system.test_key }}'}, - 'user_value': {'default': '{{ st2kv.user.test_key }}'}, + "system_value": {"default": "{{ st2kv.system.test_key }}"}, + "user_value": {"default": "{{ st2kv.user.test_key }}"}, } - live_params = param_utils.render_live_params(runner_parameters={}, - action_parameters=param, - params={}, - action_context=context) + live_params = param_utils.render_live_params( + runner_parameters={}, + action_parameters=param, + params={}, + action_context=context, + ) - self.assertEqual(live_params['system_value'], 'foo') - self.assertEqual(live_params['user_value'], 'bar') + self.assertEqual(live_params["system_value"], "foo") + self.assertEqual(live_params["user_value"], "bar") # Assert that datastore value in the user-scope that is registered by user1 # cannot be got by the operation of user2. - context = {'user': 'user2'} - param = {'user_value': {'default': '{{ st2kv.user.test_key }}'}} - live_params = param_utils.render_live_params(runner_parameters={}, - action_parameters=param, - params={}, - action_context=context) + context = {"user": "user2"} + param = {"user_value": {"default": "{{ st2kv.user.test_key }}"}} + live_params = param_utils.render_live_params( + runner_parameters={}, + action_parameters=param, + params={}, + action_context=context, + ) - self.assertEqual(live_params['user_value'], '') + self.assertEqual(live_params["user_value"], "") # Assert that system-user's scope is selected when user and api_user parameter specified context = {} - param = {'user_value': {'default': '{{ st2kv.user.test_key }}'}} - live_params = param_utils.render_live_params(runner_parameters={}, - action_parameters=param, - params={}, - action_context=context) + param = {"user_value": {"default": "{{ st2kv.user.test_key }}"}} + live_params = param_utils.render_live_params( + runner_parameters={}, + action_parameters=param, + params={}, + action_context=context, + ) - self.assertEqual(live_params['user_value'], 'baz') + self.assertEqual(live_params["user_value"], "baz") def test_get_live_params_with_additional_context(self): - runner_param_info = { - 'r1': { - 'default': 'some' - } - } - action_param_info = { - 'r2': { - 'default': '{{ r1 }}' - } - } - params = { - 'r3': 'lolcathost', - 'r1': '{{ additional.stuff }}' - } - action_context = {'user': None} - additional_contexts = { - 'additional': { - 'stuff': 'generic' - } - } + runner_param_info = {"r1": {"default": "some"}} + action_param_info = {"r2": {"default": "{{ r1 }}"}} + params = {"r3": "lolcathost", "r1": "{{ additional.stuff }}"} + action_context = {"user": None} + additional_contexts = {"additional": {"stuff": "generic"}} live_params = param_utils.render_live_params( - runner_param_info, action_param_info, params, action_context, additional_contexts) + runner_param_info, + action_param_info, + params, + action_context, + additional_contexts, + ) - expected_params = { - 'r1': 'generic', - 'r2': 'generic', - 'r3': 'lolcathost' - } + expected_params = {"r1": "generic", "r2": "generic", "r3": "lolcathost"} self.assertEqual(live_params, expected_params) def test_cyclic_dependency_friendly_error_message(self): runner_param_info = { - 'r1': { - 'default': 'some', - 'cyclic': 'cyclic value', - 'morecyclic': 'cyclic value' - } - } - action_param_info = { - 'r2': { - 'default': '{{ r1 }}' + "r1": { + "default": "some", + "cyclic": "cyclic value", + "morecyclic": "cyclic value", } } + action_param_info = {"r2": {"default": "{{ r1 }}"}} params = { - 'r3': 'lolcathost', - 'cyclic': '{{ cyclic }}', - 'morecyclic': '{{ morecyclic }}' + "r3": "lolcathost", + "cyclic": "{{ cyclic }}", + "morecyclic": "{{ morecyclic }}", } - action_context = {'user': None} + action_context = {"user": None} - expected_msg = 'Cyclic dependency found in the following variables: cyclic, morecyclic' - self.assertRaisesRegexp(ParamException, expected_msg, param_utils.render_live_params, - runner_param_info, action_param_info, params, action_context) + expected_msg = ( + "Cyclic dependency found in the following variables: cyclic, morecyclic" + ) + self.assertRaisesRegexp( + ParamException, + expected_msg, + param_utils.render_live_params, + runner_param_info, + action_param_info, + params, + action_context, + ) def test_unsatisfied_dependency_friendly_error_message(self): runner_param_info = { - 'r1': { - 'default': 'some', - } - } - action_param_info = { - 'r2': { - 'default': '{{ r1 }}' + "r1": { + "default": "some", } } + action_param_info = {"r2": {"default": "{{ r1 }}"}} params = { - 'r3': 'lolcathost', - 'r4': '{{ variable_not_defined }}', + "r3": "lolcathost", + "r4": "{{ variable_not_defined }}", } - action_context = {'user': None} + action_context = {"user": None} expected_msg = 'Dependency unsatisfied in variable "variable_not_defined"' - self.assertRaisesRegexp(ParamException, expected_msg, param_utils.render_live_params, - runner_param_info, action_param_info, params, action_context) + self.assertRaisesRegexp( + ParamException, + expected_msg, + param_utils.render_live_params, + runner_param_info, + action_param_info, + params, + action_context, + ) def test_add_default_templates_to_live_params(self): - """Test addition of template values in defaults to live params - """ + """Test addition of template values in defaults to live params""" # Ensure parameter is skipped if the parameter has immutable set to true in schema schemas = [ { - 'templateparam': { - 'default': '{{ 3 | int }}', - 'type': 'integer', - 'immutable': True + "templateparam": { + "default": "{{ 3 | int }}", + "type": "integer", + "immutable": True, } } ] - context = { - 'templateparam': '3' - } + context = {"templateparam": "3"} result = param_utils._cast_params_from({}, context, schemas) self.assertEqual(result, {}) # Test with no live params, and two parameters - one should make it through because # it was a template, and the other shouldn't because its default wasn't a template - schemas = [ - { - 'templateparam': { - 'default': '{{ 3 | int }}', - 'type': 'integer' - } - } - ] - context = { - 'templateparam': '3' - } + schemas = [{"templateparam": {"default": "{{ 3 | int }}", "type": "integer"}}] + context = {"templateparam": "3"} result = param_utils._cast_params_from({}, context, schemas) - self.assertEqual(result, {'templateparam': 3}) + self.assertEqual(result, {"templateparam": 3}) # Ensure parameter is skipped if the value in context is identical to default - schemas = [ - { - 'nottemplateparam': { - 'default': '4', - 'type': 'integer' - } - } - ] + schemas = [{"nottemplateparam": {"default": "4", "type": "integer"}}] context = { - 'nottemplateparam': '4', + "nottemplateparam": "4", } result = param_utils._cast_params_from({}, context, schemas) self.assertEqual(result, {}) # Ensure parameter is skipped if the parameter doesn't have a default - schemas = [ - { - 'nottemplateparam': { - 'type': 'integer' - } - } - ] + schemas = [{"nottemplateparam": {"type": "integer"}}] context = { - 'nottemplateparam': '4', + "nottemplateparam": "4", } result = param_utils._cast_params_from({}, context, schemas) self.assertEqual(result, {}) # Skip if the default value isn't a Jinja expression - schemas = [ - { - 'nottemplateparam': { - 'default': '5', - 'type': 'integer' - } - } - ] + schemas = [{"nottemplateparam": {"default": "5", "type": "integer"}}] context = { - 'nottemplateparam': '4', + "nottemplateparam": "4", } result = param_utils._cast_params_from({}, context, schemas) self.assertEqual(result, {}) # Ensure parameter is skipped if the parameter is being overridden - schemas = [ - { - 'templateparam': { - 'default': '{{ 3 | int }}', - 'type': 'integer' - } - } - ] + schemas = [{"templateparam": {"default": "{{ 3 | int }}", "type": "integer"}}] context = { - 'templateparam': '4', + "templateparam": "4", } - result = param_utils._cast_params_from({'templateparam': '4'}, context, schemas) - self.assertEqual(result, {'templateparam': 4}) + result = param_utils._cast_params_from({"templateparam": "4"}, context, schemas) + self.assertEqual(result, {"templateparam": 4}) def test_render_final_params_and_shell_script_action_command_strings(self): runner_parameters = {} action_db_parameters = { - 'project': { - 'type': 'string', - 'default': 'st2', - 'position': 0, + "project": { + "type": "string", + "default": "st2", + "position": 0, }, - 'version': { - 'type': 'string', - 'position': 1, - 'required': True + "version": {"type": "string", "position": 1, "required": True}, + "fork": { + "type": "string", + "position": 2, + "default": "StackStorm", }, - 'fork': { - 'type': 'string', - 'position': 2, - 'default': 'StackStorm', + "branch": { + "type": "string", + "position": 3, + "default": "master", }, - 'branch': { - 'type': 'string', - 'position': 3, - 'default': 'master', + "update_changelog": {"type": "boolean", "position": 4, "default": False}, + "local_repo": { + "type": "string", + "position": 5, }, - 'update_changelog': { - 'type': 'boolean', - 'position': 4, - 'default': False - }, - 'local_repo': { - 'type': 'string', - 'position': 5, - } } context = {} # 1. All default values used live_action_db_parameters = { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'local_repo': '/tmp/repo' + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "local_repo": "/tmp/repo", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) - self.assertDictEqual(action_params, { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repo' - }) + self.assertDictEqual( + action_params, + { + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repo", + }, + ) # 2. Some default values used live_action_db_parameters = { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'update_changelog': True, - 'local_repo': '/tmp/repob' + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "update_changelog": True, + "local_repo": "/tmp/repob", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) - self.assertDictEqual(action_params, { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'branch': 'master', # default value used - 'update_changelog': True, # default value used - 'local_repo': '/tmp/repob' - }) + self.assertDictEqual( + action_params, + { + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "branch": "master", # default value used + "update_changelog": True, # default value used + "local_repo": "/tmp/repob", + }, + ) # 3. None is specified for a boolean parameter, should use a default live_action_db_parameters = { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'update_changelog': None, - 'local_repo': '/tmp/repoc' + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "update_changelog": None, + "local_repo": "/tmp/repoc", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) - - self.assertDictEqual(action_params, { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repoc' - }) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) + + self.assertDictEqual( + action_params, + { + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repoc", + }, + ) diff --git a/st2common/tests/unit/test_paramiko_command_action_model.py b/st2common/tests/unit/test_paramiko_command_action_model.py index 2ce7bbfed3a..0d023d4f8ac 100644 --- a/st2common/tests/unit/test_paramiko_command_action_model.py +++ b/st2common/tests/unit/test_paramiko_command_action_model.py @@ -18,76 +18,84 @@ from st2common.models.system.paramiko_command_action import ParamikoRemoteCommandAction -__all__ = [ - 'ParamikoRemoteCommandActionTestCase' -] +__all__ = ["ParamikoRemoteCommandActionTestCase"] class ParamikoRemoteCommandActionTestCase(unittest2.TestCase): - def test_get_command_string_no_env_vars(self): cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - 'echo boo bah baz') - ex = 'cd /tmp && echo boo bah baz' + "echo boo bah baz" + ) + ex = "cd /tmp && echo boo bah baz" self.assertEqual(cmd_action.get_full_command_string(), ex) # With sudo cmd_action.sudo = True - ex = 'sudo -E -- bash -c \'cd /tmp && echo boo bah baz\'' + ex = "sudo -E -- bash -c 'cd /tmp && echo boo bah baz'" self.assertEqual(cmd_action.get_full_command_string(), ex) # Executing a path command requires user to provide an escaped input. # E.g. st2 run core.remote hosts=localhost cmd='"/tmp/space stuff.sh"' cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - '"/t/space stuff.sh"') + '"/t/space stuff.sh"' + ) ex = 'cd /tmp && "/t/space stuff.sh"' self.assertEqual(cmd_action.get_full_command_string(), ex) # sudo_password provided cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - 'echo boo bah baz') + "echo boo bah baz" + ) cmd_action.sudo = True - cmd_action.sudo_password = 'sudo pass' + cmd_action.sudo_password = "sudo pass" - ex = ('set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- ' - 'bash -c \'cd /tmp && echo boo bah baz\'') + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- " + "bash -c 'cd /tmp && echo boo bah baz'" + ) self.assertEqual(cmd_action.get_full_command_string(), ex) def test_get_command_string_with_env_vars(self): cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - 'echo boo bah baz') - cmd_action.env_vars = {'FOO': 'BAR', 'BAR': 'BEET CAFE'} - ex = 'export BAR=\'BEET CAFE\' ' + \ - 'FOO=BAR' + \ - ' && cd /tmp && echo boo bah baz' + "echo boo bah baz" + ) + cmd_action.env_vars = {"FOO": "BAR", "BAR": "BEET CAFE"} + ex = "export BAR='BEET CAFE' " + "FOO=BAR" + " && cd /tmp && echo boo bah baz" self.assertEqual(cmd_action.get_full_command_string(), ex) # With sudo cmd_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'export FOO=BAR ' + \ - 'BAR=\'"\'"\'BEET CAFE\'"\'"\'' + \ - ' && cd /tmp && echo boo bah baz\'' - ex = 'sudo -E -- bash -c ' + \ - '\'export BAR=\'"\'"\'BEET CAFE\'"\'"\' ' + \ - 'FOO=BAR' + \ - ' && cd /tmp && echo boo bah baz\'' + ex = ( + "sudo -E -- bash -c " + + "'export FOO=BAR " + + "BAR='\"'\"'BEET CAFE'\"'\"'" + + " && cd /tmp && echo boo bah baz'" + ) + ex = ( + "sudo -E -- bash -c " + + "'export BAR='\"'\"'BEET CAFE'\"'\"' " + + "FOO=BAR" + + " && cd /tmp && echo boo bah baz'" + ) self.assertEqual(cmd_action.get_full_command_string(), ex) # with sudo_password cmd_action.sudo = True - cmd_action.sudo_password = 'sudo pass' - ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \ - '\'export BAR=\'"\'"\'BEET CAFE\'"\'"\' ' + \ - 'FOO=BAR HISTFILE=/dev/null HISTSIZE=0' + \ - ' && cd /tmp && echo boo bah baz\'' + cmd_action.sudo_password = "sudo pass" + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c " + + "'export BAR='\"'\"'BEET CAFE'\"'\"' " + + "FOO=BAR HISTFILE=/dev/null HISTSIZE=0" + + " && cd /tmp && echo boo bah baz'" + ) self.assertEqual(cmd_action.get_full_command_string(), ex) def test_get_command_string_no_user(self): cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - 'echo boo bah baz') + "echo boo bah baz" + ) cmd_action.user = None - ex = 'cd /tmp && echo boo bah baz' + ex = "cd /tmp && echo boo bah baz" self.assertEqual(cmd_action.get_full_command_string(), ex) # Executing a path command requires user to provide an escaped input. @@ -99,25 +107,28 @@ def test_get_command_string_no_user(self): def test_get_command_string_no_user_env_vars(self): cmd_action = ParamikoRemoteCommandActionTestCase._get_test_command_action( - 'echo boo bah baz') + "echo boo bah baz" + ) cmd_action.user = None - cmd_action.env_vars = {'FOO': 'BAR'} - ex = 'export FOO=BAR && cd /tmp && echo boo bah baz' + cmd_action.env_vars = {"FOO": "BAR"} + ex = "export FOO=BAR && cd /tmp && echo boo bah baz" self.assertEqual(cmd_action.get_full_command_string(), ex) @staticmethod def _get_test_command_action(command): - cmd_action = ParamikoRemoteCommandAction('fixtures.remote_command', - '55ce39d532ed3543aecbe71d', - command=command, - env_vars={}, - on_behalf_user='svetlana', - user='estee', - password=None, - private_key='---PRIVATE-KEY---', - hosts='127.0.0.1', - parallel=True, - sudo=False, - timeout=None, - cwd='/tmp') + cmd_action = ParamikoRemoteCommandAction( + "fixtures.remote_command", + "55ce39d532ed3543aecbe71d", + command=command, + env_vars={}, + on_behalf_user="svetlana", + user="estee", + password=None, + private_key="---PRIVATE-KEY---", + hosts="127.0.0.1", + parallel=True, + sudo=False, + timeout=None, + cwd="/tmp", + ) return cmd_action diff --git a/st2common/tests/unit/test_paramiko_script_action_model.py b/st2common/tests/unit/test_paramiko_script_action_model.py index e05350e46d4..3efae1053f4 100644 --- a/st2common/tests/unit/test_paramiko_script_action_model.py +++ b/st2common/tests/unit/test_paramiko_script_action_model.py @@ -18,75 +18,81 @@ from st2common.models.system.paramiko_script_action import ParamikoRemoteScriptAction -__all__ = [ - 'ParamikoRemoteScriptActionTestCase' -] +__all__ = ["ParamikoRemoteScriptActionTestCase"] class ParamikoRemoteScriptActionTestCase(unittest2.TestCase): - def test_get_command_string_no_env_vars(self): script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action() - ex = 'cd /tmp && /tmp/remote_script.sh song=\'b s\' \'taylor swift\'' + ex = "cd /tmp && /tmp/remote_script.sh song='b s' 'taylor swift'" self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'cd /tmp && ' + \ - '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\'' + ex = ( + "sudo -E -- bash -c " + + "'cd /tmp && " + + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) # with sudo password script_action.sudo = True - script_action.sudo_password = 'sudo pass' - ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \ - '\'cd /tmp && ' + \ - '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\'' + script_action.sudo_password = "sudo pass" + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c " + + "'cd /tmp && " + + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) def test_get_command_string_with_env_vars(self): script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action() script_action.env_vars = { - 'ST2_ACTION_EXECUTION_ID': '55ce39d532ed3543aecbe71d', - 'FOO': 'BAR BAZ BOOZ' + "ST2_ACTION_EXECUTION_ID": "55ce39d532ed3543aecbe71d", + "FOO": "BAR BAZ BOOZ", } - ex = 'export FOO=\'BAR BAZ BOOZ\' ' + \ - 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \ - 'cd /tmp && /tmp/remote_script.sh song=\'b s\' \'taylor swift\'' + ex = ( + "export FOO='BAR BAZ BOOZ' " + + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && " + + "cd /tmp && /tmp/remote_script.sh song='b s' 'taylor swift'" + ) self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'export FOO=\'"\'"\'BAR BAZ BOOZ\'"\'"\' ' + \ - 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \ - 'cd /tmp && ' + \ - '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\'' + ex = ( + "sudo -E -- bash -c " + + "'export FOO='\"'\"'BAR BAZ BOOZ'\"'\"' " + + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && " + + "cd /tmp && " + + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) # with sudo password script_action.sudo = True - script_action.sudo_password = 'sudo pass' + script_action.sudo_password = "sudo pass" - ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \ - '\'export FOO=\'"\'"\'BAR BAZ BOOZ\'"\'"\' HISTFILE=/dev/null HISTSIZE=0 ' + \ - 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \ - 'cd /tmp && ' + \ - '/tmp/remote_script.sh song=\'"\'"\'b s\'"\'"\' \'"\'"\'taylor swift\'"\'"\'\'' + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c " + + "'export FOO='\"'\"'BAR BAZ BOOZ'\"'\"' HISTFILE=/dev/null HISTSIZE=0 " + + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && " + + "cd /tmp && " + + "/tmp/remote_script.sh song='\"'\"'b s'\"'\"' '\"'\"'taylor swift'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) def test_get_command_string_no_script_args_no_env_args(self): script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action() script_action.named_args = {} script_action.positional_args = [] - ex = 'cd /tmp && /tmp/remote_script.sh' + ex = "cd /tmp && /tmp/remote_script.sh" self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'cd /tmp && /tmp/remote_script.sh\'' + ex = "sudo -E -- bash -c " + "'cd /tmp && /tmp/remote_script.sh'" self.assertEqual(script_action.get_full_command_string(), ex) def test_get_command_string_no_script_args_with_env_args(self): @@ -94,88 +100,100 @@ def test_get_command_string_no_script_args_with_env_args(self): script_action.named_args = {} script_action.positional_args = [] script_action.env_vars = { - 'ST2_ACTION_EXECUTION_ID': '55ce39d532ed3543aecbe71d', - 'FOO': 'BAR BAZ BOOZ' + "ST2_ACTION_EXECUTION_ID": "55ce39d532ed3543aecbe71d", + "FOO": "BAR BAZ BOOZ", } - ex = 'export FOO=\'BAR BAZ BOOZ\' ' + \ - 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \ - 'cd /tmp && /tmp/remote_script.sh' + ex = ( + "export FOO='BAR BAZ BOOZ' " + + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && " + + "cd /tmp && /tmp/remote_script.sh" + ) self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'export FOO=\'"\'"\'BAR BAZ BOOZ\'"\'"\' ' + \ - 'ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && ' + \ - 'cd /tmp && ' + \ - '/tmp/remote_script.sh\'' + ex = ( + "sudo -E -- bash -c " + + "'export FOO='\"'\"'BAR BAZ BOOZ'\"'\"' " + + "ST2_ACTION_EXECUTION_ID=55ce39d532ed3543aecbe71d && " + + "cd /tmp && " + + "/tmp/remote_script.sh'" + ) self.assertEqual(script_action.get_full_command_string(), ex) def test_script_path_shell_injection_safe(self): script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action() - test_path = '/tmp/remote script.sh' + test_path = "/tmp/remote script.sh" script_action.remote_script = test_path script_action.named_args = {} script_action.positional_args = [] - ex = 'cd /tmp && \'/tmp/remote script.sh\'' + ex = "cd /tmp && '/tmp/remote script.sh'" self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\'' + ex = "sudo -E -- bash -c " + "'cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''" self.assertEqual(script_action.get_full_command_string(), ex) # With sudo_password script_action.sudo = True - script_action.sudo_password = 'sudo pass' + script_action.sudo_password = "sudo pass" - ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \ - '\'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\'' + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c " + + "'cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) def test_script_path_shell_injection_safe_with_env_vars(self): script_action = ParamikoRemoteScriptActionTestCase._get_test_script_action() - test_path = '/tmp/remote script.sh' + test_path = "/tmp/remote script.sh" script_action.remote_script = test_path script_action.named_args = {} script_action.positional_args = [] - script_action.env_vars = {'FOO': 'BAR'} - ex = 'export FOO=BAR && cd /tmp && \'/tmp/remote script.sh\'' + script_action.env_vars = {"FOO": "BAR"} + ex = "export FOO=BAR && cd /tmp && '/tmp/remote script.sh'" self.assertEqual(script_action.get_full_command_string(), ex) # Test with sudo script_action.sudo = True - ex = 'sudo -E -- bash -c ' + \ - '\'export FOO=BAR && ' + \ - 'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\'' + ex = ( + "sudo -E -- bash -c " + + "'export FOO=BAR && " + + "cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) # With sudo_password script_action.sudo = True - script_action.sudo_password = 'sudo pass' + script_action.sudo_password = "sudo pass" - ex = 'set +o history ; echo -e \'sudo pass\n\' | sudo -S -E -- bash -c ' + \ - '\'export FOO=BAR HISTFILE=/dev/null HISTSIZE=0 && ' + \ - 'cd /tmp && \'"\'"\'/tmp/remote script.sh\'"\'"\'\'' + ex = ( + "set +o history ; echo -e 'sudo pass\n' | sudo -S -E -- bash -c " + + "'export FOO=BAR HISTFILE=/dev/null HISTSIZE=0 && " + + "cd /tmp && '\"'\"'/tmp/remote script.sh'\"'\"''" + ) self.assertEqual(script_action.get_full_command_string(), ex) @staticmethod def _get_test_script_action(): - local_script_path = '/opt/stackstorm/packs/fixtures/actions/remote_script.sh' - script_action = ParamikoRemoteScriptAction('fixtures.remote_script', - '55ce39d532ed3543aecbe71d', - local_script_path, - '/opt/stackstorm/packs/fixtures/actions/lib/', - named_args={'song': 'b s'}, - positional_args=['taylor swift'], - env_vars={}, - on_behalf_user='stanley', - user='vagrant', - private_key='/home/vagrant/.ssh/stanley_rsa', - remote_dir='/tmp', - hosts=['127.0.0.1'], - parallel=True, - sudo=False, - timeout=60, cwd='/tmp') + local_script_path = "/opt/stackstorm/packs/fixtures/actions/remote_script.sh" + script_action = ParamikoRemoteScriptAction( + "fixtures.remote_script", + "55ce39d532ed3543aecbe71d", + local_script_path, + "/opt/stackstorm/packs/fixtures/actions/lib/", + named_args={"song": "b s"}, + positional_args=["taylor swift"], + env_vars={}, + on_behalf_user="stanley", + user="vagrant", + private_key="/home/vagrant/.ssh/stanley_rsa", + remote_dir="/tmp", + hosts=["127.0.0.1"], + parallel=True, + sudo=False, + timeout=60, + cwd="/tmp", + ) return script_action diff --git a/st2common/tests/unit/test_persistence.py b/st2common/tests/unit/test_persistence.py index 14f25731ff9..6fce36c18d3 100644 --- a/st2common/tests/unit/test_persistence.py +++ b/st2common/tests/unit/test_persistence.py @@ -27,7 +27,6 @@ class TestPersistence(DbTestCase): - @classmethod def setUpClass(cls): super(TestPersistence, cls).setUpClass() @@ -38,7 +37,7 @@ def tearDown(self): super(TestPersistence, self).tearDown() def test_crud(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'a': 1}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"a": 1}) obj1 = self.access.add_or_update(obj1) obj2 = self.access.get(name=obj1.name) self.assertIsNotNone(obj2) @@ -59,16 +58,16 @@ def test_crud(self): self.assertIsNone(obj2) def test_count(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"}) obj1 = self.access.add_or_update(obj1) - obj2 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'stanley'}) + obj2 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "stanley"}) obj2 = self.access.add_or_update(obj2) self.assertEqual(self.access.count(), 2) def test_get_all(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"}) obj1 = self.access.add_or_update(obj1) - obj2 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'stanley'}) + obj2 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "stanley"}) obj2 = self.access.add_or_update(obj2) objs = self.access.get_all() self.assertIsNotNone(objs) @@ -76,33 +75,35 @@ def test_get_all(self): self.assertListEqual(list(objs), [obj1, obj2]) def test_query_by_id(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"}) obj1 = self.access.add_or_update(obj1) obj2 = self.access.get_by_id(str(obj1.id)) self.assertIsNotNone(obj2) self.assertEqual(obj1.id, obj2.id) self.assertEqual(obj1.name, obj2.name) self.assertDictEqual(obj1.context, obj2.context) - self.assertRaises(StackStormDBObjectNotFoundError, - self.access.get_by_id, str(bson.ObjectId())) + self.assertRaises( + StackStormDBObjectNotFoundError, self.access.get_by_id, str(bson.ObjectId()) + ) def test_query_by_name(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"}) obj1 = self.access.add_or_update(obj1) obj2 = self.access.get_by_name(obj1.name) self.assertIsNotNone(obj2) self.assertEqual(obj1.id, obj2.id) self.assertEqual(obj1.name, obj2.name) self.assertDictEqual(obj1.context, obj2.context) - self.assertRaises(StackStormDBObjectNotFoundError, self.access.get_by_name, - uuid.uuid4().hex) + self.assertRaises( + StackStormDBObjectNotFoundError, self.access.get_by_name, uuid.uuid4().hex + ) def test_query_filter(self): - obj1 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'system'}) + obj1 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "system"}) obj1 = self.access.add_or_update(obj1) - obj2 = FakeModelDB(name=uuid.uuid4().hex, context={'user': 'stanley'}) + obj2 = FakeModelDB(name=uuid.uuid4().hex, context={"user": "stanley"}) obj2 = self.access.add_or_update(obj2) - objs = self.access.query(context__user='system') + objs = self.access.query(context__user="system") self.assertIsNotNone(objs) self.assertGreater(len(objs), 0) self.assertEqual(obj1.id, objs[0].id) @@ -113,17 +114,17 @@ def test_null_filter(self): obj1 = FakeModelDB(name=uuid.uuid4().hex) obj1 = self.access.add_or_update(obj1) - objs = self.access.query(index='null') + objs = self.access.query(index="null") self.assertEqual(len(objs), 1) self.assertEqual(obj1.id, objs[0].id) self.assertEqual(obj1.name, objs[0].name) - self.assertIsNone(getattr(obj1, 'index', None)) + self.assertIsNone(getattr(obj1, "index", None)) objs = self.access.query(index=None) self.assertEqual(len(objs), 1) self.assertEqual(obj1.id, objs[0].id) self.assertEqual(obj1.name, objs[0].name) - self.assertIsNone(getattr(obj1, 'index', None)) + self.assertIsNone(getattr(obj1, "index", None)) def test_datetime_range(self): base = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0)) @@ -132,12 +133,12 @@ def test_datetime_range(self): obj = FakeModelDB(name=uuid.uuid4().hex, timestamp=timestamp) self.access.add_or_update(obj) - dt_range = '2014-12-25T00:00:10Z..2014-12-25T00:00:19Z' + dt_range = "2014-12-25T00:00:10Z..2014-12-25T00:00:19Z" objs = self.access.query(timestamp=dt_range) self.assertEqual(len(objs), 10) self.assertLess(objs[0].timestamp, objs[9].timestamp) - dt_range = '2014-12-25T00:00:19Z..2014-12-25T00:00:10Z' + dt_range = "2014-12-25T00:00:19Z..2014-12-25T00:00:10Z" objs = self.access.query(timestamp=dt_range) self.assertEqual(len(objs), 10) self.assertLess(objs[9].timestamp, objs[0].timestamp) @@ -146,52 +147,61 @@ def test_pagination(self): count = 100 page_size = 25 pages = int(count / page_size) - users = ['Peter', 'Susan', 'Edmund', 'Lucy'] + users = ["Peter", "Susan", "Edmund", "Lucy"] for user in users: - context = {'user': user} + context = {"user": user} for i in range(count): - self.access.add_or_update(FakeModelDB(name=uuid.uuid4().hex, - context=context, index=i)) + self.access.add_or_update( + FakeModelDB(name=uuid.uuid4().hex, context=context, index=i) + ) self.assertEqual(self.access.count(), len(users) * count) for user in users: for i in range(pages): offset = i * page_size - objs = self.access.query(context__user=user, order_by=['index'], - offset=offset, limit=page_size) + objs = self.access.query( + context__user=user, + order_by=["index"], + offset=offset, + limit=page_size, + ) self.assertEqual(len(objs), page_size) for j in range(page_size): - self.assertEqual(objs[j].context['user'], user) + self.assertEqual(objs[j].context["user"], user) self.assertEqual(objs[j].index, (i * page_size) + j) def test_sort_multiple(self): count = 60 base = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0)) for i in range(count): - category = 'type1' if i % 2 else 'type2' + category = "type1" if i % 2 else "type2" timestamp = base + datetime.timedelta(seconds=i) - obj = FakeModelDB(name=uuid.uuid4().hex, timestamp=timestamp, category=category) + obj = FakeModelDB( + name=uuid.uuid4().hex, timestamp=timestamp, category=category + ) self.access.add_or_update(obj) - objs = self.access.query(order_by=['category', 'timestamp']) + objs = self.access.query(order_by=["category", "timestamp"]) self.assertEqual(len(objs), count) for i in range(count): - category = 'type1' if i < count / 2 else 'type2' + category = "type1" if i < count / 2 else "type2" self.assertEqual(objs[i].category, category) self.assertLess(objs[0].timestamp, objs[(int(count / 2)) - 1].timestamp) - self.assertLess(objs[int(count / 2)].timestamp, objs[(int(count / 2)) - 1].timestamp) + self.assertLess( + objs[int(count / 2)].timestamp, objs[(int(count / 2)) - 1].timestamp + ) self.assertLess(objs[int(count / 2)].timestamp, objs[count - 1].timestamp) def test_escaped_field(self): - context = {'a.b.c': 'abc'} + context = {"a.b.c": "abc"} obj1 = FakeModelDB(name=uuid.uuid4().hex, context=context) obj2 = self.access.add_or_update(obj1) # Check that the original dict has not been altered. - self.assertIn('a.b.c', list(context.keys())) - self.assertNotIn('a\uff0eb\uff0ec', list(context.keys())) + self.assertIn("a.b.c", list(context.keys())) + self.assertNotIn("a\uff0eb\uff0ec", list(context.keys())) # Check to_python has run and context is not left escaped. self.assertDictEqual(obj2.context, context) @@ -206,26 +216,26 @@ def test_query_only_fields(self): count = 5 ts = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0)) for i in range(count): - category = 'type1' - obj = FakeModelDB(name='test-%s' % (i), timestamp=ts, category=category) + category = "type1" + obj = FakeModelDB(name="test-%s" % (i), timestamp=ts, category=category) self.access.add_or_update(obj) model_dbs = FakeModel.query() - self.assertEqual(model_dbs[0].name, 'test-0') + self.assertEqual(model_dbs[0].name, "test-0") self.assertEqual(model_dbs[0].timestamp, ts) - self.assertEqual(model_dbs[0].category, 'type1') + self.assertEqual(model_dbs[0].category, "type1") # only id - model_dbs = FakeModel.query(only_fields=['id']) + model_dbs = FakeModel.query(only_fields=["id"]) self.assertTrue(model_dbs[0].id) self.assertEqual(model_dbs[0].name, None) self.assertEqual(model_dbs[0].timestamp, None) self.assertEqual(model_dbs[0].category, None) # only name - note: id is always included - model_dbs = FakeModel.query(only_fields=['name']) + model_dbs = FakeModel.query(only_fields=["name"]) self.assertTrue(model_dbs[0].id) - self.assertEqual(model_dbs[0].name, 'test-0') + self.assertEqual(model_dbs[0].name, "test-0") self.assertEqual(model_dbs[0].timestamp, None) self.assertEqual(model_dbs[0].category, None) @@ -233,28 +243,28 @@ def test_query_exclude_fields(self): count = 5 ts = date_utils.add_utc_tz(datetime.datetime(2014, 12, 25, 0, 0, 0)) for i in range(count): - category = 'type1' - obj = FakeModelDB(name='test-2-%s' % (i), timestamp=ts, category=category) + category = "type1" + obj = FakeModelDB(name="test-2-%s" % (i), timestamp=ts, category=category) self.access.add_or_update(obj) model_dbs = FakeModel.query() - self.assertEqual(model_dbs[0].name, 'test-2-0') + self.assertEqual(model_dbs[0].name, "test-2-0") self.assertEqual(model_dbs[0].timestamp, ts) - self.assertEqual(model_dbs[0].category, 'type1') + self.assertEqual(model_dbs[0].category, "type1") - model_dbs = FakeModel.query(exclude_fields=['name']) + model_dbs = FakeModel.query(exclude_fields=["name"]) self.assertTrue(model_dbs[0].id) self.assertEqual(model_dbs[0].name, None) self.assertEqual(model_dbs[0].timestamp, ts) - self.assertEqual(model_dbs[0].category, 'type1') + self.assertEqual(model_dbs[0].category, "type1") - model_dbs = FakeModel.query(exclude_fields=['name', 'timestamp']) + model_dbs = FakeModel.query(exclude_fields=["name", "timestamp"]) self.assertTrue(model_dbs[0].id) self.assertEqual(model_dbs[0].name, None) self.assertEqual(model_dbs[0].timestamp, None) - self.assertEqual(model_dbs[0].category, 'type1') + self.assertEqual(model_dbs[0].category, "type1") - model_dbs = FakeModel.query(exclude_fields=['name', 'timestamp', 'category']) + model_dbs = FakeModel.query(exclude_fields=["name", "timestamp", "category"]) self.assertTrue(model_dbs[0].id) self.assertEqual(model_dbs[0].name, None) self.assertEqual(model_dbs[0].timestamp, None) diff --git a/st2common/tests/unit/test_persistence_change_revision.py b/st2common/tests/unit/test_persistence_change_revision.py index f9e31e1c732..c268fa86b5a 100644 --- a/st2common/tests/unit/test_persistence_change_revision.py +++ b/st2common/tests/unit/test_persistence_change_revision.py @@ -24,7 +24,6 @@ class TestChangeRevision(DbTestCase): - @classmethod def setUpClass(cls): super(TestChangeRevision, cls).setUpClass() @@ -35,7 +34,7 @@ def tearDown(self): super(TestChangeRevision, self).tearDown() def test_crud(self): - initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={'a': 1}) + initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={"a": 1}) # Test create created = self.access.add_or_update(initial) @@ -47,14 +46,14 @@ def test_crud(self): self.assertDictEqual(created.context, retrieved.context) # Test update - retrieved = self.access.update(retrieved, context={'a': 2}) + retrieved = self.access.update(retrieved, context={"a": 2}) updated = self.access.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) self.assertEqual(retrieved.rev, updated.rev) self.assertDictEqual(retrieved.context, updated.context) # Test add or update - retrieved.context = {'a': 1, 'b': 2} + retrieved.context = {"a": 1, "b": 2} retrieved = self.access.add_or_update(retrieved) updated = self.access.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) @@ -65,13 +64,11 @@ def test_crud(self): created.delete() self.assertRaises( - db_exc.StackStormDBObjectNotFoundError, - self.access.get_by_id, - doc_id + db_exc.StackStormDBObjectNotFoundError, self.access.get_by_id, doc_id ) def test_write_conflict(self): - initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={'a': 1}) + initial = ChangeRevFakeModelDB(name=uuid.uuid4().hex, context={"a": 1}) # Prep record created = self.access.add_or_update(initial) @@ -83,7 +80,7 @@ def test_write_conflict(self): retrieved2 = self.access.get_by_id(doc_id) # Test update on instance 1, expect success - retrieved1 = self.access.update(retrieved1, context={'a': 2}) + retrieved1 = self.access.update(retrieved1, context={"a": 2}) updated = self.access.get_by_id(doc_id) self.assertNotEqual(created.rev, updated.rev) self.assertEqual(retrieved1.rev, updated.rev) @@ -94,5 +91,5 @@ def test_write_conflict(self): db_exc.StackStormDBObjectWriteConflictError, self.access.update, retrieved2, - context={'a': 1, 'b': 2} + context={"a": 1, "b": 2}, ) diff --git a/st2common/tests/unit/test_plugin_loader.py b/st2common/tests/unit/test_plugin_loader.py index 4b78b6b4cc3..4641b66e9c3 100644 --- a/st2common/tests/unit/test_plugin_loader.py +++ b/st2common/tests/unit/test_plugin_loader.py @@ -24,8 +24,8 @@ import st2common.util.loader as plugin_loader -PLUGIN_FOLDER = 'loadableplugin' -SRC_RELATIVE = os.path.join('../resources', PLUGIN_FOLDER) +PLUGIN_FOLDER = "loadableplugin" +SRC_RELATIVE = os.path.join("../resources", PLUGIN_FOLDER) SRC_ROOT = os.path.join(os.path.abspath(os.path.dirname(__file__)), SRC_RELATIVE) @@ -51,64 +51,71 @@ def tearDown(self): sys.path = LoaderTest.sys_path def test_module_load_from_file(self): - plugin_path = os.path.join(SRC_ROOT, 'plugin/standaloneplugin.py') + plugin_path = os.path.join(SRC_ROOT, "plugin/standaloneplugin.py") plugin_classes = plugin_loader.register_plugin( - LoaderTest.DummyPlugin, plugin_path) + LoaderTest.DummyPlugin, plugin_path + ) # Even though there are two classes in that file, only one # matches the specs of DummyPlugin class. self.assertEqual(1, len(plugin_classes)) # Validate sys.path now contains the plugin directory. - self.assertIn(os.path.abspath(os.path.join(SRC_ROOT, 'plugin')), sys.path) + self.assertIn(os.path.abspath(os.path.join(SRC_ROOT, "plugin")), sys.path) # Validate the individual plugins for plugin_class in plugin_classes: try: plugin_instance = plugin_class() ret_val = plugin_instance.do_work() - self.assertIsNotNone(ret_val, 'Should be non-null.') + self.assertIsNotNone(ret_val, "Should be non-null.") except: pass def test_module_load_from_file_fail(self): try: - plugin_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin.py') + plugin_path = os.path.join(SRC_ROOT, "plugin/sampleplugin.py") plugin_loader.register_plugin(LoaderTest.DummyPlugin, plugin_path) - self.assertTrue(False, 'Import error is expected.') + self.assertTrue(False, "Import error is expected.") except ImportError: self.assertTrue(True) def test_syspath_unchanged_load_multiple_plugins(self): - plugin_1_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin.py') + plugin_1_path = os.path.join(SRC_ROOT, "plugin/sampleplugin.py") try: - plugin_loader.register_plugin( - LoaderTest.DummyPlugin, plugin_1_path) + plugin_loader.register_plugin(LoaderTest.DummyPlugin, plugin_1_path) except ImportError: pass old_sys_path = copy.copy(sys.path) - plugin_2_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin2.py') + plugin_2_path = os.path.join(SRC_ROOT, "plugin/sampleplugin2.py") try: - plugin_loader.register_plugin( - LoaderTest.DummyPlugin, plugin_2_path) + plugin_loader.register_plugin(LoaderTest.DummyPlugin, plugin_2_path) except ImportError: pass - self.assertEqual(old_sys_path, sys.path, 'Should be equal.') + self.assertEqual(old_sys_path, sys.path, "Should be equal.") def test_register_plugin_class_class_doesnt_exist(self): - file_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin3.py') + file_path = os.path.join(SRC_ROOT, "plugin/sampleplugin3.py") expected_msg = 'doesn\'t expose class named "SamplePluginNotExists"' - self.assertRaisesRegexp(Exception, expected_msg, - plugin_loader.register_plugin_class, - base_class=LoaderTest.DummyPlugin, - file_path=file_path, - class_name='SamplePluginNotExists') + self.assertRaisesRegexp( + Exception, + expected_msg, + plugin_loader.register_plugin_class, + base_class=LoaderTest.DummyPlugin, + file_path=file_path, + class_name="SamplePluginNotExists", + ) def test_register_plugin_class_abstract_method_not_implemented(self): - file_path = os.path.join(SRC_ROOT, 'plugin/sampleplugin3.py') - - expected_msg = 'doesn\'t implement required "do_work" method from the base class' - self.assertRaisesRegexp(plugin_loader.IncompatiblePluginException, expected_msg, - plugin_loader.register_plugin_class, - base_class=LoaderTest.DummyPlugin, - file_path=file_path, - class_name='SamplePlugin') + file_path = os.path.join(SRC_ROOT, "plugin/sampleplugin3.py") + + expected_msg = ( + 'doesn\'t implement required "do_work" method from the base class' + ) + self.assertRaisesRegexp( + plugin_loader.IncompatiblePluginException, + expected_msg, + plugin_loader.register_plugin_class, + base_class=LoaderTest.DummyPlugin, + file_path=file_path, + class_name="SamplePlugin", + ) diff --git a/st2common/tests/unit/test_policies.py b/st2common/tests/unit/test_policies.py index 5491e7482bc..f6dd9a47de6 100644 --- a/st2common/tests/unit/test_policies.py +++ b/st2common/tests/unit/test_policies.py @@ -19,55 +19,43 @@ from st2tests import DbTestCase from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'PolicyTestCase' -] +__all__ = ["PolicyTestCase"] -PACK = 'generic' +PACK = "generic" TEST_FIXTURES = { - 'runners': [ - 'testrunner1.yaml' - ], - 'actions': [ - 'action1.yaml' - ], - 'policytypes': [ - 'fake_policy_type_1.yaml', - 'fake_policy_type_2.yaml' - ], - 'policies': [ - 'policy_1.yaml', - 'policy_2.yaml' - ] + "runners": ["testrunner1.yaml"], + "actions": ["action1.yaml"], + "policytypes": ["fake_policy_type_1.yaml", "fake_policy_type_2.yaml"], + "policies": ["policy_1.yaml", "policy_2.yaml"], } class PolicyTestCase(DbTestCase): - @classmethod def setUpClass(cls): super(PolicyTestCase, cls).setUpClass() loader = FixturesLoader() - loader.save_fixtures_to_db(fixtures_pack=PACK, - fixtures_dict=TEST_FIXTURES) + loader.save_fixtures_to_db(fixtures_pack=PACK, fixtures_dict=TEST_FIXTURES) def test_get_by_ref(self): - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency') + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency") self.assertIsNotNone(policy_db) - self.assertEqual(policy_db.pack, 'wolfpack') - self.assertEqual(policy_db.name, 'action-1.concurrency') + self.assertEqual(policy_db.pack, "wolfpack") + self.assertEqual(policy_db.name, "action-1.concurrency") policy_type_db = PolicyType.get_by_ref(policy_db.policy_type) self.assertIsNotNone(policy_type_db) - self.assertEqual(policy_type_db.resource_type, 'action') - self.assertEqual(policy_type_db.name, 'concurrency') + self.assertEqual(policy_type_db.resource_type, "action") + self.assertEqual(policy_type_db.name, "concurrency") def test_get_driver(self): - policy_db = Policy.get_by_ref('wolfpack.action-1.concurrency') - policy = get_driver(policy_db.ref, policy_db.policy_type, **policy_db.parameters) + policy_db = Policy.get_by_ref("wolfpack.action-1.concurrency") + policy = get_driver( + policy_db.ref, policy_db.policy_type, **policy_db.parameters + ) self.assertIsInstance(policy, ResourcePolicyApplicator) self.assertEqual(policy._policy_ref, policy_db.ref) self.assertEqual(policy._policy_type, policy_db.policy_type) - self.assertTrue(hasattr(policy, 'threshold')) + self.assertTrue(hasattr(policy, "threshold")) self.assertEqual(policy.threshold, 3) diff --git a/st2common/tests/unit/test_policies_registrar.py b/st2common/tests/unit/test_policies_registrar.py index b46515e08af..85c1d344906 100644 --- a/st2common/tests/unit/test_policies_registrar.py +++ b/st2common/tests/unit/test_policies_registrar.py @@ -29,9 +29,7 @@ from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import get_fixtures_packs_base_path -__all__ = [ - 'PoliciesRegistrarTestCase' -] +__all__ = ["PoliciesRegistrarTestCase"] class PoliciesRegistrarTestCase(CleanDbTestCase): @@ -44,13 +42,13 @@ def setUp(self): def test_register_policy_types(self): self.assertEqual(register_policy_types(st2tests), 2) - type1 = PolicyType.get_by_ref('action.concurrency') - self.assertEqual(type1.name, 'concurrency') - self.assertEqual(type1.resource_type, 'action') + type1 = PolicyType.get_by_ref("action.concurrency") + self.assertEqual(type1.name, "concurrency") + self.assertEqual(type1.resource_type, "action") - type2 = PolicyType.get_by_ref('action.mock_policy_error') - self.assertEqual(type2.name, 'mock_policy_error') - self.assertEqual(type2.resource_type, 'action') + type2 = PolicyType.get_by_ref("action.mock_policy_error") + self.assertEqual(type2.name, "mock_policy_error") + self.assertEqual(type2.resource_type, "action") def test_register_all_policies(self): policies_dbs = Policy.get_all() @@ -64,38 +62,29 @@ def test_register_all_policies(self): policies = { policies_db.name: { - 'pack': policies_db.pack, - 'type': policies_db.policy_type, - 'parameters': policies_db.parameters + "pack": policies_db.pack, + "type": policies_db.policy_type, + "parameters": policies_db.parameters, } for policies_db in policies_dbs } expected_policies = { - 'test_policy_1': { - 'pack': 'dummy_pack_1', - 'type': 'action.concurrency', - 'parameters': { - 'action': 'delay', - 'threshold': 3 - } + "test_policy_1": { + "pack": "dummy_pack_1", + "type": "action.concurrency", + "parameters": {"action": "delay", "threshold": 3}, }, - 'test_policy_3': { - 'pack': 'dummy_pack_1', - 'type': 'action.retry', - 'parameters': { - 'retry_on': 'timeout', - 'max_retry_count': 5 - } + "test_policy_3": { + "pack": "dummy_pack_1", + "type": "action.retry", + "parameters": {"retry_on": "timeout", "max_retry_count": 5}, + }, + "sequential.retry_on_failure": { + "pack": "orquesta_tests", + "type": "action.retry", + "parameters": {"retry_on": "failure", "max_retry_count": 1}, }, - 'sequential.retry_on_failure': { - 'pack': 'orquesta_tests', - 'type': 'action.retry', - 'parameters': { - 'retry_on': 'failure', - 'max_retry_count': 1 - } - } } self.assertEqual(len(expected_policies), count) @@ -103,39 +92,49 @@ def test_register_all_policies(self): self.assertDictEqual(expected_policies, policies) def test_register_policies_from_pack(self): - pack_dir = os.path.join(get_fixtures_packs_base_path(), 'dummy_pack_1') + pack_dir = os.path.join(get_fixtures_packs_base_path(), "dummy_pack_1") self.assertEqual(register_policies(pack_dir=pack_dir), 2) - p1 = Policy.get_by_ref('dummy_pack_1.test_policy_1') - self.assertEqual(p1.name, 'test_policy_1') - self.assertEqual(p1.pack, 'dummy_pack_1') - self.assertEqual(p1.resource_ref, 'dummy_pack_1.local') - self.assertEqual(p1.policy_type, 'action.concurrency') + p1 = Policy.get_by_ref("dummy_pack_1.test_policy_1") + self.assertEqual(p1.name, "test_policy_1") + self.assertEqual(p1.pack, "dummy_pack_1") + self.assertEqual(p1.resource_ref, "dummy_pack_1.local") + self.assertEqual(p1.policy_type, "action.concurrency") # Verify that a default value for parameter "action" which isn't provided in the file is set - self.assertEqual(p1.parameters['action'], 'delay') - self.assertEqual(p1.metadata_file, 'policies/policy_1.yaml') + self.assertEqual(p1.parameters["action"], "delay") + self.assertEqual(p1.metadata_file, "policies/policy_1.yaml") - p2 = Policy.get_by_ref('dummy_pack_1.test_policy_2') + p2 = Policy.get_by_ref("dummy_pack_1.test_policy_2") self.assertEqual(p2, None) def test_register_policy_invalid_policy_type_references(self): # Policy references an invalid (inexistent) policy type registrar = PolicyRegistrar() - policy_path = os.path.join(get_fixtures_packs_base_path(), - 'dummy_pack_1/policies/policy_2.yaml') + policy_path = os.path.join( + get_fixtures_packs_base_path(), "dummy_pack_1/policies/policy_2.yaml" + ) expected_msg = 'Referenced policy_type "action.mock_policy_error" doesnt exist' - self.assertRaisesRegexp(ValueError, expected_msg, registrar._register_policy, - pack='dummy_pack_1', policy=policy_path) + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar._register_policy, + pack="dummy_pack_1", + policy=policy_path, + ) def test_make_sure_policy_parameters_are_validated_during_register(self): # Policy where specified parameters fail schema validation registrar = PolicyRegistrar() - policy_path = os.path.join(get_fixtures_packs_base_path(), - 'dummy_pack_2/policies/policy_3.yaml') - - expected_msg = '100 is greater than the maximum of 5' - self.assertRaisesRegexp(jsonschema.ValidationError, expected_msg, - registrar._register_policy, - pack='dummy_pack_2', - policy=policy_path) + policy_path = os.path.join( + get_fixtures_packs_base_path(), "dummy_pack_2/policies/policy_3.yaml" + ) + + expected_msg = "100 is greater than the maximum of 5" + self.assertRaisesRegexp( + jsonschema.ValidationError, + expected_msg, + registrar._register_policy, + pack="dummy_pack_2", + policy=policy_path, + ) diff --git a/st2common/tests/unit/test_purge_executions.py b/st2common/tests/unit/test_purge_executions.py index 5362cc753e6..64ee4cfa677 100644 --- a/st2common/tests/unit/test_purge_executions.py +++ b/st2common/tests/unit/test_purge_executions.py @@ -34,18 +34,10 @@ LOG = logging.getLogger(__name__) -TEST_FIXTURES = { - 'executions': [ - 'execution1.yaml' - ], - 'liveactions': [ - 'liveaction4.yaml' - ] -} +TEST_FIXTURES = {"executions": ["execution1.yaml"], "liveactions": ["liveaction4.yaml"]} class TestPurgeExecutions(CleanDbTestCase): - @classmethod def setUpClass(cls): CleanDbTestCase.setUpClass() @@ -54,114 +46,128 @@ def setUpClass(cls): def setUp(self): super(TestPurgeExecutions, self).setUp() fixtures_loader = FixturesLoader() - self.models = fixtures_loader.load_models(fixtures_pack='generic', - fixtures_dict=TEST_FIXTURES) + self.models = fixtures_loader.load_models( + fixtures_pack="generic", fixtures_dict=TEST_FIXTURES + ) def test_no_timestamp_doesnt_delete_things(self): now = date_utils.get_datetime_utc_now() - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = now - timedelta(days=15) - exec_model['end_timestamp'] = now - timedelta(days=14) - exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - exec_model['id'] = bson.ObjectId() + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = now - timedelta(days=15) + exec_model["end_timestamp"] = now - timedelta(days=14) + exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3) + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=3 + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 1) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 3) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 3) - expected_msg = 'Specify a valid timestamp' - self.assertRaisesRegexp(ValueError, expected_msg, purge_executions, - logger=LOG, timestamp=None) + expected_msg = "Specify a valid timestamp" + self.assertRaisesRegexp( + ValueError, expected_msg, purge_executions, logger=LOG, timestamp=None + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 1) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 3) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 3) def test_purge_executions_with_action_ref(self): now = date_utils.get_datetime_utc_now() - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = now - timedelta(days=15) - exec_model['end_timestamp'] = now - timedelta(days=14) - exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - exec_model['id'] = bson.ObjectId() + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = now - timedelta(days=15) + exec_model["end_timestamp"] = now - timedelta(days=14) + exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3) + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=3 + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 1) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 3) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 3) # Invalid action reference, nothing should be deleted - purge_executions(logger=LOG, action_ref='core.localzzz', timestamp=now - timedelta(days=10)) + purge_executions( + logger=LOG, action_ref="core.localzzz", timestamp=now - timedelta(days=10) + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 1) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 3) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 3) - purge_executions(logger=LOG, action_ref='core.local', timestamp=now - timedelta(days=10)) + purge_executions( + logger=LOG, action_ref="core.local", timestamp=now - timedelta(days=10) + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 0) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 0) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 0) def test_purge_executions_with_timestamp(self): now = date_utils.get_datetime_utc_now() # Write one execution after cut-off threshold - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = now - timedelta(days=15) - exec_model['end_timestamp'] = now - timedelta(days=14) - exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - exec_model['id'] = bson.ObjectId() + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = now - timedelta(days=15) + exec_model["end_timestamp"] = now - timedelta(days=14) + exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3) + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=3 + ) # Write one execution before cut-off threshold - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = now - timedelta(days=22) - exec_model['end_timestamp'] = now - timedelta(days=21) - exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - exec_model['id'] = bson.ObjectId() + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = now - timedelta(days=22) + exec_model["end_timestamp"] = now - timedelta(days=21) + exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=3) + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=3 + ) execs = ActionExecution.get_all() self.assertEqual(len(execs), 2) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 6) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 6) purge_executions(logger=LOG, timestamp=now - timedelta(days=20)) execs = ActionExecution.get_all() self.assertEqual(len(execs), 1) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 3) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 3) def test_liveaction_gets_deleted(self): @@ -169,19 +175,19 @@ def test_liveaction_gets_deleted(self): start_ts = now - timedelta(days=15) end_ts = now - timedelta(days=14) - liveaction_model = copy.deepcopy(self.models['liveactions']['liveaction4.yaml']) - liveaction_model['start_timestamp'] = start_ts - liveaction_model['end_timestamp'] = end_ts - liveaction_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED + liveaction_model = copy.deepcopy(self.models["liveactions"]["liveaction4.yaml"]) + liveaction_model["start_timestamp"] = start_ts + liveaction_model["end_timestamp"] = end_ts + liveaction_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED liveaction = LiveAction.add_or_update(liveaction_model) # Write one execution before cut-off threshold - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['end_timestamp'] = end_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_SUCCEEDED - exec_model['id'] = bson.ObjectId() - exec_model['liveaction']['id'] = str(liveaction.id) + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["end_timestamp"] = end_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_SUCCEEDED + exec_model["id"] = bson.ObjectId() + exec_model["liveaction"]["id"] = str(liveaction.id) ActionExecution.add_or_update(exec_model) liveactions = LiveAction.get_all() @@ -201,110 +207,143 @@ def test_purge_incomplete(self): start_ts = now - timedelta(days=15) # Write executions before cut-off threshold - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_SCHEDULED - exec_model['id'] = bson.ObjectId() + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_SCHEDULED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1) - - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_RUNNING - exec_model['id'] = bson.ObjectId() + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=1 + ) + + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_RUNNING + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1) - - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_DELAYED - exec_model['id'] = bson.ObjectId() + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=1 + ) + + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_DELAYED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1) - - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_CANCELING - exec_model['id'] = bson.ObjectId() + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=1 + ) + + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_CANCELING + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1) - - exec_model = copy.deepcopy(self.models['executions']['execution1.yaml']) - exec_model['start_timestamp'] = start_ts - exec_model['status'] = action_constants.LIVEACTION_STATUS_REQUESTED - exec_model['id'] = bson.ObjectId() + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=1 + ) + + exec_model = copy.deepcopy(self.models["executions"]["execution1.yaml"]) + exec_model["start_timestamp"] = start_ts + exec_model["status"] = action_constants.LIVEACTION_STATUS_REQUESTED + exec_model["id"] = bson.ObjectId() ActionExecution.add_or_update(exec_model) # Insert corresponding stdout and stderr db mock models - self._insert_mock_stdout_and_stderr_objects_for_execution(exec_model['id'], count=1) + self._insert_mock_stdout_and_stderr_objects_for_execution( + exec_model["id"], count=1 + ) self.assertEqual(len(ActionExecution.get_all()), 5) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 5) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 5) # Incompleted executions shouldnt be purged - purge_executions(logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=False) + purge_executions( + logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=False + ) self.assertEqual(len(ActionExecution.get_all()), 5) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 5) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 5) - purge_executions(logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True) + purge_executions( + logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True + ) self.assertEqual(len(ActionExecution.get_all()), 0) - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), 0) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), 0) - @mock.patch('st2common.garbage_collection.executions.LiveAction') - @mock.patch('st2common.garbage_collection.executions.ActionExecution') - def test_purge_executions_whole_model_is_not_loaded_in_memory(self, mock_ActionExecution, - mock_LiveAction): + @mock.patch("st2common.garbage_collection.executions.LiveAction") + @mock.patch("st2common.garbage_collection.executions.ActionExecution") + def test_purge_executions_whole_model_is_not_loaded_in_memory( + self, mock_ActionExecution, mock_LiveAction + ): # Verify that whole execution objects are not loaded in memory and we just retrieve the # id field self.assertEqual(mock_ActionExecution.query.call_count, 0) self.assertEqual(mock_LiveAction.query.call_count, 0) now = date_utils.get_datetime_utc_now() - purge_executions(logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True) + purge_executions( + logger=LOG, timestamp=now - timedelta(days=10), purge_incomplete=True + ) self.assertEqual(mock_ActionExecution.query.call_count, 2) self.assertEqual(mock_LiveAction.query.call_count, 1) - self.assertEqual(mock_ActionExecution.query.call_args_list[0][1]['only_fields'], ['id']) - self.assertTrue(mock_ActionExecution.query.call_args_list[0][1]['no_dereference']) - self.assertEqual(mock_ActionExecution.query.call_args_list[1][1]['only_fields'], ['id']) - self.assertTrue(mock_ActionExecution.query.call_args_list[1][1]['no_dereference']) - self.assertEqual(mock_LiveAction.query.call_args_list[0][1]['only_fields'], ['id']) - self.assertTrue(mock_LiveAction.query.call_args_list[0][1]['no_dereference']) - - def _insert_mock_stdout_and_stderr_objects_for_execution(self, execution_id, count=5): + self.assertEqual( + mock_ActionExecution.query.call_args_list[0][1]["only_fields"], ["id"] + ) + self.assertTrue( + mock_ActionExecution.query.call_args_list[0][1]["no_dereference"] + ) + self.assertEqual( + mock_ActionExecution.query.call_args_list[1][1]["only_fields"], ["id"] + ) + self.assertTrue( + mock_ActionExecution.query.call_args_list[1][1]["no_dereference"] + ) + self.assertEqual( + mock_LiveAction.query.call_args_list[0][1]["only_fields"], ["id"] + ) + self.assertTrue(mock_LiveAction.query.call_args_list[0][1]["no_dereference"]) + + def _insert_mock_stdout_and_stderr_objects_for_execution( + self, execution_id, count=5 + ): execution_id = str(execution_id) stdout_dbs, stderr_dbs = [], [] for i in range(0, count): - stdout_db = ActionExecutionOutputDB(execution_id=execution_id, - action_ref='dummy.pack', - runner_ref='dummy', - output_type='stdout', - data='stdout %s' % (i)) + stdout_db = ActionExecutionOutputDB( + execution_id=execution_id, + action_ref="dummy.pack", + runner_ref="dummy", + output_type="stdout", + data="stdout %s" % (i), + ) ActionExecutionOutput.add_or_update(stdout_db) - stderr_db = ActionExecutionOutputDB(execution_id=execution_id, - action_ref='dummy.pack', - runner_ref='dummy', - output_type='stderr', - data='stderr%s' % (i)) + stderr_db = ActionExecutionOutputDB( + execution_id=execution_id, + action_ref="dummy.pack", + runner_ref="dummy", + output_type="stderr", + data="stderr%s" % (i), + ) ActionExecutionOutput.add_or_update(stderr_db) return stdout_dbs, stderr_dbs diff --git a/st2common/tests/unit/test_purge_trigger_instances.py b/st2common/tests/unit/test_purge_trigger_instances.py index 2cc9f6ffed8..515c4040c31 100644 --- a/st2common/tests/unit/test_purge_trigger_instances.py +++ b/st2common/tests/unit/test_purge_trigger_instances.py @@ -28,7 +28,6 @@ class TestPurgeTriggerInstances(CleanDbTestCase): - @classmethod def setUpClass(cls): CleanDbTestCase.setUpClass() @@ -40,32 +39,42 @@ def setUp(self): def test_no_timestamp_doesnt_delete(self): now = date_utils.get_datetime_utc_now() - instance_db = TriggerInstanceDB(trigger='purge_tool.dummy.trigger', - payload={'hola': 'hi', 'kuraci': 'chicken'}, - occurrence_time=now - timedelta(days=20), - status=TRIGGER_INSTANCE_PROCESSED) + instance_db = TriggerInstanceDB( + trigger="purge_tool.dummy.trigger", + payload={"hola": "hi", "kuraci": "chicken"}, + occurrence_time=now - timedelta(days=20), + status=TRIGGER_INSTANCE_PROCESSED, + ) TriggerInstance.add_or_update(instance_db) self.assertEqual(len(TriggerInstance.get_all()), 1) - expected_msg = 'Specify a valid timestamp' - self.assertRaisesRegexp(ValueError, expected_msg, - purge_trigger_instances, - logger=LOG, timestamp=None) + expected_msg = "Specify a valid timestamp" + self.assertRaisesRegexp( + ValueError, + expected_msg, + purge_trigger_instances, + logger=LOG, + timestamp=None, + ) self.assertEqual(len(TriggerInstance.get_all()), 1) def test_purge(self): now = date_utils.get_datetime_utc_now() - instance_db = TriggerInstanceDB(trigger='purge_tool.dummy.trigger', - payload={'hola': 'hi', 'kuraci': 'chicken'}, - occurrence_time=now - timedelta(days=20), - status=TRIGGER_INSTANCE_PROCESSED) + instance_db = TriggerInstanceDB( + trigger="purge_tool.dummy.trigger", + payload={"hola": "hi", "kuraci": "chicken"}, + occurrence_time=now - timedelta(days=20), + status=TRIGGER_INSTANCE_PROCESSED, + ) TriggerInstance.add_or_update(instance_db) - instance_db = TriggerInstanceDB(trigger='purge_tool.dummy.trigger', - payload={'hola': 'hi', 'kuraci': 'chicken'}, - occurrence_time=now - timedelta(days=5), - status=TRIGGER_INSTANCE_PROCESSED) + instance_db = TriggerInstanceDB( + trigger="purge_tool.dummy.trigger", + payload={"hola": "hi", "kuraci": "chicken"}, + occurrence_time=now - timedelta(days=5), + status=TRIGGER_INSTANCE_PROCESSED, + ) TriggerInstance.add_or_update(instance_db) self.assertEqual(len(TriggerInstance.get_all()), 2) diff --git a/st2common/tests/unit/test_queue_consumer.py b/st2common/tests/unit/test_queue_consumer.py index 4f54c325aa4..463eb0def4b 100644 --- a/st2common/tests/unit/test_queue_consumer.py +++ b/st2common/tests/unit/test_queue_consumer.py @@ -23,8 +23,8 @@ from tests.unit.base import FakeModelDB -FAKE_XCHG = Exchange('st2.tests', type='topic') -FAKE_WORK_Q = Queue('st2.tests.unit', FAKE_XCHG) +FAKE_XCHG = Exchange("st2.tests", type="topic") +FAKE_WORK_Q = Queue("st2.tests.unit", FAKE_XCHG) class FakeMessageHandler(consumers.MessageHandler): @@ -39,15 +39,14 @@ def get_handler(): class QueueConsumerTest(DbTestCase): - - @mock.patch.object(FakeMessageHandler, 'process', mock.MagicMock()) + @mock.patch.object(FakeMessageHandler, "process", mock.MagicMock()) def test_process_message(self): payload = FakeModelDB() handler = get_handler() handler._queue_consumer._process_message(payload) FakeMessageHandler.process.assert_called_once_with(payload) - @mock.patch.object(FakeMessageHandler, 'process', mock.MagicMock()) + @mock.patch.object(FakeMessageHandler, "process", mock.MagicMock()) def test_process_message_wrong_payload_type(self): payload = 100 handler = get_handler() @@ -72,8 +71,7 @@ def get_staged_handler(): class StagedQueueConsumerTest(DbTestCase): - - @mock.patch.object(FakeStagedMessageHandler, 'pre_ack_process', mock.MagicMock()) + @mock.patch.object(FakeStagedMessageHandler, "pre_ack_process", mock.MagicMock()) def test_process_message_pre_ack(self): payload = FakeModelDB() handler = get_staged_handler() @@ -82,15 +80,16 @@ def test_process_message_pre_ack(self): FakeStagedMessageHandler.pre_ack_process.assert_called_once_with(payload) self.assertTrue(mock_message.ack.called) - @mock.patch.object(BufferedDispatcher, 'dispatch', mock.MagicMock()) - @mock.patch.object(FakeStagedMessageHandler, 'process', mock.MagicMock()) + @mock.patch.object(BufferedDispatcher, "dispatch", mock.MagicMock()) + @mock.patch.object(FakeStagedMessageHandler, "process", mock.MagicMock()) def test_process_message(self): payload = FakeModelDB() handler = get_staged_handler() mock_message = mock.MagicMock() handler._queue_consumer.process(payload, mock_message) BufferedDispatcher.dispatch.assert_called_once_with( - handler._queue_consumer._process_message, payload) + handler._queue_consumer._process_message, payload + ) handler._queue_consumer._process_message(payload) FakeStagedMessageHandler.process.assert_called_once_with(payload) self.assertTrue(mock_message.ack.called) @@ -104,13 +103,10 @@ def test_process_message_wrong_payload_type(self): class FakeVariableMessageHandler(consumers.VariableMessageHandler): - def __init__(self, connection, queues): super(FakeVariableMessageHandler, self).__init__(connection, queues) - self.message_types = { - FakeModelDB: self.handle_fake_model - } + self.message_types = {FakeModelDB: self.handle_fake_model} def process(self, message): handler_function = self.message_types.get(type(message)) @@ -125,15 +121,16 @@ def get_variable_messages_handler(): class VariableMessageQueueConsumerTest(DbTestCase): - - @mock.patch.object(FakeVariableMessageHandler, 'handle_fake_model', mock.MagicMock()) + @mock.patch.object( + FakeVariableMessageHandler, "handle_fake_model", mock.MagicMock() + ) def test_process_message(self): payload = FakeModelDB() handler = get_variable_messages_handler() handler._queue_consumer._process_message(payload) FakeVariableMessageHandler.handle_fake_model.assert_called_once_with(payload) - @mock.patch.object(FakeVariableMessageHandler, 'process', mock.MagicMock()) + @mock.patch.object(FakeVariableMessageHandler, "process", mock.MagicMock()) def test_process_message_wrong_payload_type(self): payload = 100 handler = get_variable_messages_handler() diff --git a/st2common/tests/unit/test_queue_utils.py b/st2common/tests/unit/test_queue_utils.py index 52ad7a60dc2..db77fc01c20 100644 --- a/st2common/tests/unit/test_queue_utils.py +++ b/st2common/tests/unit/test_queue_utils.py @@ -22,31 +22,42 @@ class TestQueueUtils(TestCase): - def test_get_queue_name(self): - self.assertRaises(ValueError, - queue_utils.get_queue_name, - queue_name_base=None, queue_name_suffix=None) - self.assertRaises(ValueError, - queue_utils.get_queue_name, - queue_name_base='', queue_name_suffix=None) - self.assertEqual(queue_utils.get_queue_name(queue_name_base='st2.test.watch', - queue_name_suffix=None), - 'st2.test.watch') - self.assertEqual(queue_utils.get_queue_name(queue_name_base='st2.test.watch', - queue_name_suffix=''), - 'st2.test.watch') + self.assertRaises( + ValueError, + queue_utils.get_queue_name, + queue_name_base=None, + queue_name_suffix=None, + ) + self.assertRaises( + ValueError, + queue_utils.get_queue_name, + queue_name_base="", + queue_name_suffix=None, + ) + self.assertEqual( + queue_utils.get_queue_name( + queue_name_base="st2.test.watch", queue_name_suffix=None + ), + "st2.test.watch", + ) + self.assertEqual( + queue_utils.get_queue_name( + queue_name_base="st2.test.watch", queue_name_suffix="" + ), + "st2.test.watch", + ) queue_name = queue_utils.get_queue_name( - queue_name_base='st2.test.watch', - queue_name_suffix='foo', - add_random_uuid_to_suffix=True + queue_name_base="st2.test.watch", + queue_name_suffix="foo", + add_random_uuid_to_suffix=True, ) - pattern = re.compile(r'st2.test.watch.foo-\w') + pattern = re.compile(r"st2.test.watch.foo-\w") self.assertTrue(re.match(pattern, queue_name)) queue_name = queue_utils.get_queue_name( - queue_name_base='st2.test.watch', - queue_name_suffix='foo', - add_random_uuid_to_suffix=False + queue_name_base="st2.test.watch", + queue_name_suffix="foo", + add_random_uuid_to_suffix=False, ) - self.assertEqual(queue_name, 'st2.test.watch.foo') + self.assertEqual(queue_name, "st2.test.watch.foo") diff --git a/st2common/tests/unit/test_rbac_types.py b/st2common/tests/unit/test_rbac_types.py index d9d0a1dae8e..03b5350cc97 100644 --- a/st2common/tests/unit/test_rbac_types.py +++ b/st2common/tests/unit/test_rbac_types.py @@ -22,158 +22,274 @@ class RBACPermissionTypeTestCase(TestCase): - def test_get_valid_permission_for_resource_type(self): - valid_action_permissions = PermissionType.get_valid_permissions_for_resource_type( - resource_type=ResourceType.ACTION) + valid_action_permissions = ( + PermissionType.get_valid_permissions_for_resource_type( + resource_type=ResourceType.ACTION + ) + ) for name in valid_action_permissions: - self.assertTrue(name.startswith(ResourceType.ACTION + '_')) + self.assertTrue(name.startswith(ResourceType.ACTION + "_")) valid_rule_permissions = PermissionType.get_valid_permissions_for_resource_type( - resource_type=ResourceType.RULE) + resource_type=ResourceType.RULE + ) for name in valid_rule_permissions: - self.assertTrue(name.startswith(ResourceType.RULE + '_')) + self.assertTrue(name.startswith(ResourceType.RULE + "_")) def test_get_resource_type(self): - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_LIST), - SystemType.PACK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_VIEW), - SystemType.PACK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_CREATE), - SystemType.PACK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_MODIFY), - SystemType.PACK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_DELETE), - SystemType.PACK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.PACK_ALL), - SystemType.PACK) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_LIST), - SystemType.SENSOR_TYPE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_VIEW), - SystemType.SENSOR_TYPE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_MODIFY), - SystemType.SENSOR_TYPE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.SENSOR_ALL), - SystemType.SENSOR_TYPE) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_LIST), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_VIEW), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_CREATE), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_MODIFY), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_DELETE), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_EXECUTE), - SystemType.ACTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.ACTION_ALL), - SystemType.ACTION) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_LIST), - SystemType.EXECUTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_VIEW), - SystemType.EXECUTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_RE_RUN), - SystemType.EXECUTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_STOP), - SystemType.EXECUTION) - self.assertEqual(PermissionType.get_resource_type(PermissionType.EXECUTION_ALL), - SystemType.EXECUTION) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_LIST), - SystemType.RULE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_VIEW), - SystemType.RULE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_CREATE), - SystemType.RULE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_MODIFY), - SystemType.RULE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_DELETE), - SystemType.RULE) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_ALL), - SystemType.RULE) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_LIST), - SystemType.RULE_ENFORCEMENT) - self.assertEqual(PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_VIEW), - SystemType.RULE_ENFORCEMENT) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.KEY_VALUE_VIEW), - SystemType.KEY_VALUE_PAIR) - self.assertEqual(PermissionType.get_resource_type(PermissionType.KEY_VALUE_SET), - SystemType.KEY_VALUE_PAIR) - self.assertEqual(PermissionType.get_resource_type(PermissionType.KEY_VALUE_DELETE), - SystemType.KEY_VALUE_PAIR) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_CREATE), - SystemType.WEBHOOK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_SEND), - SystemType.WEBHOOK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_DELETE), - SystemType.WEBHOOK) - self.assertEqual(PermissionType.get_resource_type(PermissionType.WEBHOOK_ALL), - SystemType.WEBHOOK) - - self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_LIST), - SystemType.API_KEY) - self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_VIEW), - SystemType.API_KEY) - self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_CREATE), - SystemType.API_KEY) - self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_DELETE), - SystemType.API_KEY) - self.assertEqual(PermissionType.get_resource_type(PermissionType.API_KEY_ALL), - SystemType.API_KEY) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_LIST), SystemType.PACK + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_VIEW), SystemType.PACK + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_CREATE), + SystemType.PACK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_MODIFY), + SystemType.PACK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_DELETE), + SystemType.PACK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.PACK_ALL), SystemType.PACK + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.SENSOR_LIST), + SystemType.SENSOR_TYPE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.SENSOR_VIEW), + SystemType.SENSOR_TYPE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.SENSOR_MODIFY), + SystemType.SENSOR_TYPE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.SENSOR_ALL), + SystemType.SENSOR_TYPE, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_LIST), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_VIEW), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_CREATE), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_MODIFY), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_DELETE), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_EXECUTE), + SystemType.ACTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.ACTION_ALL), + SystemType.ACTION, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.EXECUTION_LIST), + SystemType.EXECUTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.EXECUTION_VIEW), + SystemType.EXECUTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.EXECUTION_RE_RUN), + SystemType.EXECUTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.EXECUTION_STOP), + SystemType.EXECUTION, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.EXECUTION_ALL), + SystemType.EXECUTION, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_LIST), SystemType.RULE + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_VIEW), SystemType.RULE + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_CREATE), + SystemType.RULE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_MODIFY), + SystemType.RULE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_DELETE), + SystemType.RULE, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_ALL), SystemType.RULE + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_LIST), + SystemType.RULE_ENFORCEMENT, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.RULE_ENFORCEMENT_VIEW), + SystemType.RULE_ENFORCEMENT, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.KEY_VALUE_VIEW), + SystemType.KEY_VALUE_PAIR, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.KEY_VALUE_SET), + SystemType.KEY_VALUE_PAIR, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.KEY_VALUE_DELETE), + SystemType.KEY_VALUE_PAIR, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.WEBHOOK_CREATE), + SystemType.WEBHOOK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.WEBHOOK_SEND), + SystemType.WEBHOOK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.WEBHOOK_DELETE), + SystemType.WEBHOOK, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.WEBHOOK_ALL), + SystemType.WEBHOOK, + ) + + self.assertEqual( + PermissionType.get_resource_type(PermissionType.API_KEY_LIST), + SystemType.API_KEY, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.API_KEY_VIEW), + SystemType.API_KEY, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.API_KEY_CREATE), + SystemType.API_KEY, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.API_KEY_DELETE), + SystemType.API_KEY, + ) + self.assertEqual( + PermissionType.get_resource_type(PermissionType.API_KEY_ALL), + SystemType.API_KEY, + ) def test_get_permission_type(self): - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.ACTION, - permission_name='view'), - PermissionType.ACTION_VIEW) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.ACTION, - permission_name='all'), - PermissionType.ACTION_ALL) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.ACTION, - permission_name='execute'), - PermissionType.ACTION_EXECUTE) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.RULE, - permission_name='view'), - PermissionType.RULE_VIEW) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.RULE, - permission_name='delete'), - PermissionType.RULE_DELETE) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.SENSOR, - permission_name='view'), - PermissionType.SENSOR_VIEW) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.SENSOR, - permission_name='all'), - PermissionType.SENSOR_ALL) - self.assertEqual(PermissionType.get_permission_type(resource_type=ResourceType.SENSOR, - permission_name='modify'), - PermissionType.SENSOR_MODIFY) - self.assertEqual( - PermissionType.get_permission_type(resource_type=ResourceType.RULE_ENFORCEMENT, - permission_name='view'), - PermissionType.RULE_ENFORCEMENT_VIEW) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.ACTION, permission_name="view" + ), + PermissionType.ACTION_VIEW, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.ACTION, permission_name="all" + ), + PermissionType.ACTION_ALL, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.ACTION, permission_name="execute" + ), + PermissionType.ACTION_EXECUTE, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.RULE, permission_name="view" + ), + PermissionType.RULE_VIEW, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.RULE, permission_name="delete" + ), + PermissionType.RULE_DELETE, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.SENSOR, permission_name="view" + ), + PermissionType.SENSOR_VIEW, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.SENSOR, permission_name="all" + ), + PermissionType.SENSOR_ALL, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.SENSOR, permission_name="modify" + ), + PermissionType.SENSOR_MODIFY, + ) + self.assertEqual( + PermissionType.get_permission_type( + resource_type=ResourceType.RULE_ENFORCEMENT, permission_name="view" + ), + PermissionType.RULE_ENFORCEMENT_VIEW, + ) def test_get_permission_name(self): - self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_LIST), - 'list') - self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_CREATE), - 'create') - self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_DELETE), - 'delete') - self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_ALL), - 'all') - self.assertEqual(PermissionType.get_permission_name(PermissionType.PACK_ALL), - 'all') - self.assertEqual(PermissionType.get_permission_name(PermissionType.SENSOR_MODIFY), - 'modify') - self.assertEqual(PermissionType.get_permission_name(PermissionType.ACTION_EXECUTE), - 'execute') - self.assertEqual(PermissionType.get_permission_name(PermissionType.RULE_ENFORCEMENT_LIST), - 'list') + self.assertEqual( + PermissionType.get_permission_name(PermissionType.ACTION_LIST), "list" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.ACTION_CREATE), "create" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.ACTION_DELETE), "delete" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.ACTION_ALL), "all" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.PACK_ALL), "all" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.SENSOR_MODIFY), "modify" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.ACTION_EXECUTE), "execute" + ) + self.assertEqual( + PermissionType.get_permission_name(PermissionType.RULE_ENFORCEMENT_LIST), + "list", + ) diff --git a/st2common/tests/unit/test_reference.py b/st2common/tests/unit/test_reference.py index f39800c2dd1..ced486a867d 100644 --- a/st2common/tests/unit/test_reference.py +++ b/st2common/tests/unit/test_reference.py @@ -26,35 +26,34 @@ from st2tests import DbTestCase -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ReferenceTest(DbTestCase): __model = None __ref = None @classmethod - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def setUpClass(cls): super(ReferenceTest, cls).setUpClass() - trigger = TriggerDB(pack='dummy_pack_1', name='trigger-1') + trigger = TriggerDB(pack="dummy_pack_1", name="trigger-1") cls.__model = Trigger.add_or_update(trigger) - cls.__ref = {'id': str(cls.__model.id), - 'name': cls.__model.name} + cls.__ref = {"id": str(cls.__model.id), "name": cls.__model.name} @classmethod - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def tearDownClass(cls): Trigger.delete(cls.__model) super(ReferenceTest, cls).tearDownClass() def test_to_reference(self): ref = reference.get_ref_from_model(self.__model) - self.assertEqual(ref, self.__ref, 'Failed to generated equivalent ref.') + self.assertEqual(ref, self.__ref, "Failed to generated equivalent ref.") def test_to_reference_no_model(self): try: reference.get_ref_from_model(None) - self.assertTrue(False, 'Exception expected.') + self.assertTrue(False, "Exception expected.") except ValueError: self.assertTrue(True) @@ -63,37 +62,37 @@ def test_to_reference_no_model_id(self): model = copy.copy(self.__model) model.id = None reference.get_ref_from_model(model) - self.assertTrue(False, 'Exception expected.') + self.assertTrue(False, "Exception expected.") except db.StackStormDBObjectMalformedError: self.assertTrue(True) def test_to_model_with_id(self): model = reference.get_model_from_ref(Trigger, self.__ref) - self.assertEqual(model, self.__model, 'Failed to return correct model.') + self.assertEqual(model, self.__model, "Failed to return correct model.") def test_to_model_with_name(self): ref = copy.copy(self.__ref) - ref['id'] = None + ref["id"] = None model = reference.get_model_from_ref(Trigger, ref) - self.assertEqual(model, self.__model, 'Failed to return correct model.') + self.assertEqual(model, self.__model, "Failed to return correct model.") def test_to_model_no_name_no_id(self): try: reference.get_model_from_ref(Trigger, {}) - self.assertTrue(False, 'Exception expected.') + self.assertTrue(False, "Exception expected.") except db.StackStormDBObjectNotFoundError: self.assertTrue(True) def test_to_model_unknown_id(self): try: - reference.get_model_from_ref(Trigger, {'id': '1'}) - self.assertTrue(False, 'Exception expected.') + reference.get_model_from_ref(Trigger, {"id": "1"}) + self.assertTrue(False, "Exception expected.") except mongoengine.ValidationError: self.assertTrue(True) def test_to_model_unknown_name(self): try: - reference.get_model_from_ref(Trigger, {'name': 'unknown'}) - self.assertTrue(False, 'Exception expected.') + reference.get_model_from_ref(Trigger, {"name": "unknown"}) + self.assertTrue(False, "Exception expected.") except db.StackStormDBObjectNotFoundError: self.assertTrue(True) diff --git a/st2common/tests/unit/test_register_internal_trigger.py b/st2common/tests/unit/test_register_internal_trigger.py index dd4959611fc..3d33e325483 100644 --- a/st2common/tests/unit/test_register_internal_trigger.py +++ b/st2common/tests/unit/test_register_internal_trigger.py @@ -20,7 +20,6 @@ class TestRegisterInternalTriggers(DbTestCase): - def test_register_internal_trigger_types(self): registered_trigger_types_db = register_internal_trigger_types() for trigger_type_db in registered_trigger_types_db: @@ -31,4 +30,6 @@ def _validate_shadow_trigger(self, trigger_type_db): return trigger_type_ref = trigger_type_db.get_reference().ref triggers = Trigger.query(type=trigger_type_ref) - self.assertTrue(len(triggers) > 0, 'Shadow trigger not created for %s.' % trigger_type_ref) + self.assertTrue( + len(triggers) > 0, "Shadow trigger not created for %s." % trigger_type_ref + ) diff --git a/st2common/tests/unit/test_resource_reference.py b/st2common/tests/unit/test_resource_reference.py index 04dfdc93571..95533022ed7 100644 --- a/st2common/tests/unit/test_resource_reference.py +++ b/st2common/tests/unit/test_resource_reference.py @@ -22,45 +22,64 @@ class ResourceReferenceTestCase(unittest2.TestCase): def test_resource_reference_success(self): - value = 'pack1.name1' + value = "pack1.name1" ref = ResourceReference.from_string_reference(ref=value) - self.assertEqual(ref.pack, 'pack1') - self.assertEqual(ref.name, 'name1') + self.assertEqual(ref.pack, "pack1") + self.assertEqual(ref.name, "name1") self.assertEqual(ref.ref, value) - ref = ResourceReference(pack='pack1', name='name1') - self.assertEqual(ref.ref, 'pack1.name1') + ref = ResourceReference(pack="pack1", name="name1") + self.assertEqual(ref.ref, "pack1.name1") - ref = ResourceReference(pack='pack1', name='name1.name2') - self.assertEqual(ref.ref, 'pack1.name1.name2') + ref = ResourceReference(pack="pack1", name="name1.name2") + self.assertEqual(ref.ref, "pack1.name1.name2") def test_resource_reference_failure(self): - self.assertRaises(InvalidResourceReferenceError, - ResourceReference.from_string_reference, - ref='blah') + self.assertRaises( + InvalidResourceReferenceError, + ResourceReference.from_string_reference, + ref="blah", + ) - self.assertRaises(InvalidResourceReferenceError, - ResourceReference.from_string_reference, - ref=None) + self.assertRaises( + InvalidResourceReferenceError, + ResourceReference.from_string_reference, + ref=None, + ) def test_to_string_reference(self): - ref = ResourceReference.to_string_reference(pack='mapack', name='moname') - self.assertEqual(ref, 'mapack.moname') + ref = ResourceReference.to_string_reference(pack="mapack", name="moname") + self.assertEqual(ref, "mapack.moname") expected_msg = r'Pack name should not contain "\."' - self.assertRaisesRegexp(ValueError, expected_msg, ResourceReference.to_string_reference, - pack='pack.invalid', name='bar') + self.assertRaisesRegexp( + ValueError, + expected_msg, + ResourceReference.to_string_reference, + pack="pack.invalid", + name="bar", + ) - expected_msg = 'Both pack and name needed for building' - self.assertRaisesRegexp(ValueError, expected_msg, ResourceReference.to_string_reference, - pack='pack', name=None) + expected_msg = "Both pack and name needed for building" + self.assertRaisesRegexp( + ValueError, + expected_msg, + ResourceReference.to_string_reference, + pack="pack", + name=None, + ) - expected_msg = 'Both pack and name needed for building' - self.assertRaisesRegexp(ValueError, expected_msg, ResourceReference.to_string_reference, - pack=None, name='name') + expected_msg = "Both pack and name needed for building" + self.assertRaisesRegexp( + ValueError, + expected_msg, + ResourceReference.to_string_reference, + pack=None, + name="name", + ) def test_is_resource_reference(self): - self.assertTrue(ResourceReference.is_resource_reference('foo.bar')) - self.assertTrue(ResourceReference.is_resource_reference('foo.bar.ponies')) - self.assertFalse(ResourceReference.is_resource_reference('foo')) + self.assertTrue(ResourceReference.is_resource_reference("foo.bar")) + self.assertTrue(ResourceReference.is_resource_reference("foo.bar.ponies")) + self.assertFalse(ResourceReference.is_resource_reference("foo")) diff --git a/st2common/tests/unit/test_resource_registrar.py b/st2common/tests/unit/test_resource_registrar.py index 9850785f211..2a1c61ad6ab 100644 --- a/st2common/tests/unit/test_resource_registrar.py +++ b/st2common/tests/unit/test_resource_registrar.py @@ -30,23 +30,21 @@ from st2tests.fixturesloader import get_fixtures_base_path -__all__ = [ - 'ResourceRegistrarTestCase' -] - -PACK_PATH_1 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_1') -PACK_PATH_6 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_6') -PACK_PATH_7 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_7') -PACK_PATH_8 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_8') -PACK_PATH_9 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_9') -PACK_PATH_10 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_10') -PACK_PATH_12 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_12') -PACK_PATH_13 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_13') -PACK_PATH_14 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_14') -PACK_PATH_17 = os.path.join(get_fixtures_base_path(), 'packs_invalid/dummy_pack_17') -PACK_PATH_18 = os.path.join(get_fixtures_base_path(), 'packs_invalid/dummy_pack_18') -PACK_PATH_20 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_20') -PACK_PATH_21 = os.path.join(get_fixtures_base_path(), 'packs/dummy_pack_21') +__all__ = ["ResourceRegistrarTestCase"] + +PACK_PATH_1 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_1") +PACK_PATH_6 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_6") +PACK_PATH_7 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_7") +PACK_PATH_8 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_8") +PACK_PATH_9 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_9") +PACK_PATH_10 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_10") +PACK_PATH_12 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_12") +PACK_PATH_13 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_13") +PACK_PATH_14 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_14") +PACK_PATH_17 = os.path.join(get_fixtures_base_path(), "packs_invalid/dummy_pack_17") +PACK_PATH_18 = os.path.join(get_fixtures_base_path(), "packs_invalid/dummy_pack_18") +PACK_PATH_20 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_20") +PACK_PATH_21 = os.path.join(get_fixtures_base_path(), "packs/dummy_pack_21") class ResourceRegistrarTestCase(CleanDbTestCase): @@ -60,7 +58,7 @@ def test_register_packs(self): registrar = ResourceRegistrar(use_pack_cache=False) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_1': PACK_PATH_1} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_1": PACK_PATH_1} packs_base_paths = content_utils.get_packs_base_paths() registrar.register_packs(base_dirs=packs_base_paths) @@ -74,20 +72,20 @@ def test_register_packs(self): pack_db = pack_dbs[0] config_schema_db = config_schema_dbs[0] - self.assertEqual(pack_db.name, 'dummy_pack_1') + self.assertEqual(pack_db.name, "dummy_pack_1") self.assertEqual(len(pack_db.contributors), 2) - self.assertEqual(pack_db.contributors[0], 'John Doe1 ') - self.assertEqual(pack_db.contributors[1], 'John Doe2 ') - self.assertIn('api_key', config_schema_db.attributes) - self.assertIn('api_secret', config_schema_db.attributes) + self.assertEqual(pack_db.contributors[0], "John Doe1 ") + self.assertEqual(pack_db.contributors[1], "John Doe2 ") + self.assertIn("api_key", config_schema_db.attributes) + self.assertIn("api_secret", config_schema_db.attributes) # Verify pack_db.files is correct and doesn't contain excluded files (*.pyc, .git/*, etc.) # Note: We can't test that .git/* files are excluded since git doesn't allow you to add # .git directory to existing repo index :/ excluded_files = [ - '__init__.pyc', - 'actions/dummy1.pyc', - 'actions/dummy2.pyc', + "__init__.pyc", + "actions/dummy1.pyc", + "actions/dummy2.pyc", ] for excluded_file in excluded_files: @@ -100,14 +98,14 @@ def test_register_pack_arbitrary_properties_are_allowed(self): registrar = ResourceRegistrar(use_pack_cache=False) registrar._pack_loader.get_packs = mock.Mock() registrar._pack_loader.get_packs.return_value = { - 'dummy_pack_20': PACK_PATH_20, + "dummy_pack_20": PACK_PATH_20, } packs_base_paths = content_utils.get_packs_base_paths() registrar.register_packs(base_dirs=packs_base_paths) # Ref is provided - pack_db = Pack.get_by_name('dummy_pack_20') - self.assertEqual(pack_db.ref, 'dummy_pack_20_ref') + pack_db = Pack.get_by_name("dummy_pack_20") + self.assertEqual(pack_db.ref, "dummy_pack_20_ref") self.assertEqual(len(pack_db.contributors), 0) def test_register_pack_pack_ref(self): @@ -119,53 +117,74 @@ def test_register_pack_pack_ref(self): registrar = ResourceRegistrar(use_pack_cache=False) registrar._pack_loader.get_packs = mock.Mock() registrar._pack_loader.get_packs.return_value = { - 'dummy_pack_1': PACK_PATH_1, - 'dummy_pack_6': PACK_PATH_6 + "dummy_pack_1": PACK_PATH_1, + "dummy_pack_6": PACK_PATH_6, } packs_base_paths = content_utils.get_packs_base_paths() registrar.register_packs(base_dirs=packs_base_paths) # Ref is provided - pack_db = Pack.get_by_name('dummy_pack_6') - self.assertEqual(pack_db.ref, 'dummy_pack_6_ref') + pack_db = Pack.get_by_name("dummy_pack_6") + self.assertEqual(pack_db.ref, "dummy_pack_6_ref") self.assertEqual(len(pack_db.contributors), 0) # Ref is not provided, directory name should be used - pack_db = Pack.get_by_name('dummy_pack_1') - self.assertEqual(pack_db.ref, 'dummy_pack_1') + pack_db = Pack.get_by_name("dummy_pack_1") + self.assertEqual(pack_db.ref, "dummy_pack_1") # "ref" is not provided, but "name" is registrar._register_pack_db(pack_name=None, pack_dir=PACK_PATH_7) - pack_db = Pack.get_by_name('dummy_pack_7_name') - self.assertEqual(pack_db.ref, 'dummy_pack_7_name') + pack_db = Pack.get_by_name("dummy_pack_7_name") + self.assertEqual(pack_db.ref, "dummy_pack_7_name") # "ref" is not provided and "name" contains invalid characters - expected_msg = 'contains invalid characters' - self.assertRaisesRegexp(ValueError, expected_msg, registrar._register_pack_db, - pack_name=None, pack_dir=PACK_PATH_8) + expected_msg = "contains invalid characters" + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar._register_pack_db, + pack_name=None, + pack_dir=PACK_PATH_8, + ) def test_register_pack_invalid_ref_name_friendly_error_message(self): registrar = ResourceRegistrar(use_pack_cache=False) # Invalid ref - expected_msg = (r'Pack ref / name can only contain valid word characters .*?,' - ' dashes are not allowed.') - self.assertRaisesRegexp(ValidationError, expected_msg, registrar._register_pack_db, - pack_name=None, pack_dir=PACK_PATH_13) + expected_msg = ( + r"Pack ref / name can only contain valid word characters .*?," + " dashes are not allowed." + ) + self.assertRaisesRegexp( + ValidationError, + expected_msg, + registrar._register_pack_db, + pack_name=None, + pack_dir=PACK_PATH_13, + ) try: registrar._register_pack_db(pack_name=None, pack_dir=PACK_PATH_13) except ValidationError as e: - self.assertIn("'invalid-has-dash' does not match '^[a-z0-9_]+$'", six.text_type(e)) + self.assertIn( + "'invalid-has-dash' does not match '^[a-z0-9_]+$'", six.text_type(e) + ) else: - self.fail('Exception not thrown') + self.fail("Exception not thrown") # Pack ref not provided and name doesn't contain valid characters - expected_msg = (r'Pack name "dummy pack 14" contains invalid characters and "ref" ' - 'attribute is not available. You either need to add') - self.assertRaisesRegexp(ValueError, expected_msg, registrar._register_pack_db, - pack_name=None, pack_dir=PACK_PATH_14) + expected_msg = ( + r'Pack name "dummy pack 14" contains invalid characters and "ref" ' + "attribute is not available. You either need to add" + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar._register_pack_db, + pack_name=None, + pack_dir=PACK_PATH_14, + ) def test_register_pack_pack_stackstorm_version_and_future_parameters(self): # Verify DB is empty @@ -174,53 +193,74 @@ def test_register_pack_pack_stackstorm_version_and_future_parameters(self): registrar = ResourceRegistrar(use_pack_cache=False) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_9': PACK_PATH_9} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_9": PACK_PATH_9} packs_base_paths = content_utils.get_packs_base_paths() registrar.register_packs(base_dirs=packs_base_paths) # Dependencies, stackstorm_version and future values - pack_db = Pack.get_by_name('dummy_pack_9_deps') - self.assertEqual(pack_db.dependencies, ['core=0.2.0']) - self.assertEqual(pack_db.stackstorm_version, '>=1.6.0, <2.2.0') - self.assertEqual(pack_db.system, {'centos': {'foo': '>= 1.0'}}) - self.assertEqual(pack_db.python_versions, ['2', '3']) + pack_db = Pack.get_by_name("dummy_pack_9_deps") + self.assertEqual(pack_db.dependencies, ["core=0.2.0"]) + self.assertEqual(pack_db.stackstorm_version, ">=1.6.0, <2.2.0") + self.assertEqual(pack_db.system, {"centos": {"foo": ">= 1.0"}}) + self.assertEqual(pack_db.python_versions, ["2", "3"]) # Note: We only store parameters which are defined in the schema, all other custom user # defined attributes are ignored - self.assertTrue(not hasattr(pack_db, 'future')) - self.assertTrue(not hasattr(pack_db, 'this')) + self.assertTrue(not hasattr(pack_db, "future")) + self.assertTrue(not hasattr(pack_db, "this")) # Wrong characters in the required st2 version expected_msg = "'wrongstackstormversion' does not match" - self.assertRaisesRegexp(ValidationError, expected_msg, registrar._register_pack_db, - pack_name=None, pack_dir=PACK_PATH_10) + self.assertRaisesRegexp( + ValidationError, + expected_msg, + registrar._register_pack_db, + pack_name=None, + pack_dir=PACK_PATH_10, + ) def test_register_pack_empty_and_invalid_config_schema(self): registrar = ResourceRegistrar(use_pack_cache=False, fail_on_failure=True) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_17': PACK_PATH_17} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_17": PACK_PATH_17} packs_base_paths = content_utils.get_packs_base_paths() - expected_msg = 'Config schema ".*?dummy_pack_17/config.schema.yaml" is empty and invalid.' - self.assertRaisesRegexp(ValueError, expected_msg, registrar.register_packs, - base_dirs=packs_base_paths) + expected_msg = ( + 'Config schema ".*?dummy_pack_17/config.schema.yaml" is empty and invalid.' + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_packs, + base_dirs=packs_base_paths, + ) def test_register_pack_invalid_config_schema_invalid_attribute(self): registrar = ResourceRegistrar(use_pack_cache=False, fail_on_failure=True) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_18': PACK_PATH_18} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_18": PACK_PATH_18} packs_base_paths = content_utils.get_packs_base_paths() - expected_msg = r'Additional properties are not allowed \(\'invalid\' was unexpected\)' - self.assertRaisesRegexp(ValueError, expected_msg, registrar.register_packs, - base_dirs=packs_base_paths) + expected_msg = ( + r"Additional properties are not allowed \(\'invalid\' was unexpected\)" + ) + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_packs, + base_dirs=packs_base_paths, + ) def test_register_pack_invalid_python_versions_attribute(self): registrar = ResourceRegistrar(use_pack_cache=False, fail_on_failure=True) registrar._pack_loader.get_packs = mock.Mock() - registrar._pack_loader.get_packs.return_value = {'dummy_pack_21': PACK_PATH_21} + registrar._pack_loader.get_packs.return_value = {"dummy_pack_21": PACK_PATH_21} packs_base_paths = content_utils.get_packs_base_paths() expected_msg = r"'4' is not one of \['2', '3'\]" - self.assertRaisesRegexp(ValueError, expected_msg, registrar.register_packs, - base_dirs=packs_base_paths) + self.assertRaisesRegexp( + ValueError, + expected_msg, + registrar.register_packs, + base_dirs=packs_base_paths, + ) diff --git a/st2common/tests/unit/test_runners_base.py b/st2common/tests/unit/test_runners_base.py index 34ede41adfd..7490b40cd67 100644 --- a/st2common/tests/unit/test_runners_base.py +++ b/st2common/tests/unit/test_runners_base.py @@ -23,11 +23,12 @@ class RunnersLoaderUtilsTestCase(DbTestCase): def test_get_runner_success(self): - runner = get_runner('local-shell-cmd') + runner = get_runner("local-shell-cmd") self.assertTrue(runner) - self.assertEqual(runner.__class__.__name__, 'LocalShellCommandRunner') + self.assertEqual(runner.__class__.__name__, "LocalShellCommandRunner") def test_get_runner_failure_not_found(self): - expected_msg = 'Failed to find runner invalid-name-not-found.*' - self.assertRaisesRegexp(ActionRunnerCreateError, expected_msg, - get_runner, 'invalid-name-not-found') + expected_msg = "Failed to find runner invalid-name-not-found.*" + self.assertRaisesRegexp( + ActionRunnerCreateError, expected_msg, get_runner, "invalid-name-not-found" + ) diff --git a/st2common/tests/unit/test_runners_utils.py b/st2common/tests/unit/test_runners_utils.py index dc988482233..bc6acfcf7e7 100644 --- a/st2common/tests/unit/test_runners_utils.py +++ b/st2common/tests/unit/test_runners_utils.py @@ -24,16 +24,17 @@ from st2tests import config as tests_config + tests_config.parse_args() -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_FIXTURES = { - 'liveactions': ['liveaction1.yaml'], - 'actions': ['local.yaml'], - 'executions': ['execution1.yaml'], - 'runners': ['run-local.yaml'] + "liveactions": ["liveaction1.yaml"], + "actions": ["local.yaml"], + "executions": ["execution1.yaml"], + "runners": ["run-local.yaml"], } @@ -48,15 +49,16 @@ def setUp(self): loader = fixturesloader.FixturesLoader() self.models = loader.save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_FIXTURES + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES ) - self.liveaction_db = self.models['liveactions']['liveaction1.yaml'] + self.liveaction_db = self.models["liveactions"]["liveaction1.yaml"] exe_svc.create_execution_object(self.liveaction_db) self.action_db = action_db_utils.get_action_by_ref(self.liveaction_db.action) - @mock.patch.object(action_db_utils, 'get_action_by_ref', mock.MagicMock(return_value=None)) + @mock.patch.object( + action_db_utils, "get_action_by_ref", mock.MagicMock(return_value=None) + ) def test_invoke_post_run_action_provided(self): utils.invoke_post_run(self.liveaction_db, action_db=self.action_db) action_db_utils.get_action_by_ref.assert_not_called() @@ -64,8 +66,12 @@ def test_invoke_post_run_action_provided(self): def test_invoke_post_run_action_exists(self): utils.invoke_post_run(self.liveaction_db) - @mock.patch.object(action_db_utils, 'get_action_by_ref', mock.MagicMock(return_value=None)) - @mock.patch.object(action_db_utils, 'get_runnertype_by_name', mock.MagicMock(return_value=None)) + @mock.patch.object( + action_db_utils, "get_action_by_ref", mock.MagicMock(return_value=None) + ) + @mock.patch.object( + action_db_utils, "get_runnertype_by_name", mock.MagicMock(return_value=None) + ) def test_invoke_post_run_action_does_not_exist(self): utils.invoke_post_run(self.liveaction_db) action_db_utils.get_action_by_ref.assert_called_once() diff --git a/st2common/tests/unit/test_sensor_type_utils.py b/st2common/tests/unit/test_sensor_type_utils.py index 08269ebcf2c..657054c4536 100644 --- a/st2common/tests/unit/test_sensor_type_utils.py +++ b/st2common/tests/unit/test_sensor_type_utils.py @@ -22,59 +22,67 @@ class SensorTypeUtilsTestCase(unittest2.TestCase): - def test_to_sensor_db_model_no_trigger_types(self): sensor_meta = { - 'artifact_uri': 'file:///data/st2contrib/packs/jira/sensors/jira_sensor.py', - 'class_name': 'JIRASensor', - 'pack': 'jira' + "artifact_uri": "file:///data/st2contrib/packs/jira/sensors/jira_sensor.py", + "class_name": "JIRASensor", + "pack": "jira", } sensor_api = SensorTypeAPI(**sensor_meta) sensor_model = SensorTypeAPI.to_model(sensor_api) - self.assertEqual(sensor_model.name, sensor_meta['class_name']) - self.assertEqual(sensor_model.pack, sensor_meta['pack']) - self.assertEqual(sensor_model.artifact_uri, sensor_meta['artifact_uri']) + self.assertEqual(sensor_model.name, sensor_meta["class_name"]) + self.assertEqual(sensor_model.pack, sensor_meta["pack"]) + self.assertEqual(sensor_model.artifact_uri, sensor_meta["artifact_uri"]) self.assertListEqual(sensor_model.trigger_types, []) - @mock.patch.object(sensor_type_utils, 'create_trigger_types', mock.MagicMock( - return_value=['mock.trigger_ref'])) + @mock.patch.object( + sensor_type_utils, + "create_trigger_types", + mock.MagicMock(return_value=["mock.trigger_ref"]), + ) def test_to_sensor_db_model_with_trigger_types(self): sensor_meta = { - 'artifact_uri': 'file:///data/st2contrib/packs/jira/sensors/jira_sensor.py', - 'class_name': 'JIRASensor', - 'pack': 'jira', - 'trigger_types': [{'pack': 'jira', 'name': 'issue_created', 'parameters': {}}] + "artifact_uri": "file:///data/st2contrib/packs/jira/sensors/jira_sensor.py", + "class_name": "JIRASensor", + "pack": "jira", + "trigger_types": [ + {"pack": "jira", "name": "issue_created", "parameters": {}} + ], } sensor_api = SensorTypeAPI(**sensor_meta) sensor_model = SensorTypeAPI.to_model(sensor_api) - self.assertListEqual(sensor_model.trigger_types, ['mock.trigger_ref']) + self.assertListEqual(sensor_model.trigger_types, ["mock.trigger_ref"]) def test_get_sensor_entry_point(self): # System packs - file_path = 'file:///data/st/st2reactor/st2reactor/' + \ - 'contrib/sensors/st2_generic_webhook_sensor.py' - class_name = 'St2GenericWebhooksSensor' + file_path = ( + "file:///data/st/st2reactor/st2reactor/" + + "contrib/sensors/st2_generic_webhook_sensor.py" + ) + class_name = "St2GenericWebhooksSensor" - sensor = {'artifact_uri': file_path, 'class_name': class_name, 'pack': 'core'} + sensor = {"artifact_uri": file_path, "class_name": class_name, "pack": "core"} sensor_api = SensorTypeAPI(**sensor) entry_point = sensor_type_utils.get_sensor_entry_point(sensor_api) self.assertEqual(entry_point, class_name) # Non system packs - file_path = 'file:///data/st2contrib/packs/jira/sensors/jira_sensor.py' - class_name = 'JIRASensor' - sensor = {'artifact_uri': file_path, 'class_name': class_name, 'pack': 'jira'} + file_path = "file:///data/st2contrib/packs/jira/sensors/jira_sensor.py" + class_name = "JIRASensor" + sensor = {"artifact_uri": file_path, "class_name": class_name, "pack": "jira"} sensor_api = SensorTypeAPI(**sensor) entry_point = sensor_type_utils.get_sensor_entry_point(sensor_api) - self.assertEqual(entry_point, 'sensors.jira_sensor.JIRASensor') + self.assertEqual(entry_point, "sensors.jira_sensor.JIRASensor") - file_path = 'file:///data/st2contrib/packs/docker/sensors/docker_container_sensor.py' - class_name = 'DockerSensor' - sensor = {'artifact_uri': file_path, 'class_name': class_name, 'pack': 'docker'} + file_path = ( + "file:///data/st2contrib/packs/docker/sensors/docker_container_sensor.py" + ) + class_name = "DockerSensor" + sensor = {"artifact_uri": file_path, "class_name": class_name, "pack": "docker"} sensor_api = SensorTypeAPI(**sensor) entry_point = sensor_type_utils.get_sensor_entry_point(sensor_api) - self.assertEqual(entry_point, 'sensors.docker_container_sensor.DockerSensor') + self.assertEqual(entry_point, "sensors.docker_container_sensor.DockerSensor") diff --git a/st2common/tests/unit/test_sensor_watcher.py b/st2common/tests/unit/test_sensor_watcher.py index 65f61965df7..2379f815620 100644 --- a/st2common/tests/unit/test_sensor_watcher.py +++ b/st2common/tests/unit/test_sensor_watcher.py @@ -22,39 +22,44 @@ from st2common.models.db.sensor import SensorTypeDB from st2common.transport.publishers import PoolPublisher -MOCK_SENSOR_DB = SensorTypeDB(name='foo', pack='test') +MOCK_SENSOR_DB = SensorTypeDB(name="foo", pack="test") class SensorWatcherTests(unittest2.TestCase): - - @mock.patch.object(Message, 'ack', mock.MagicMock()) - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(Message, "ack", mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def test_assert_handlers_called(self): handler_vars = { - 'create_handler_called': False, - 'update_handler_called': False, - 'delete_handler_called': False + "create_handler_called": False, + "update_handler_called": False, + "delete_handler_called": False, } def create_handler(sensor_db): - handler_vars['create_handler_called'] = True + handler_vars["create_handler_called"] = True def update_handler(sensor_db): - handler_vars['update_handler_called'] = True + handler_vars["update_handler_called"] = True def delete_handler(sensor_db): - handler_vars['delete_handler_called'] = True + handler_vars["delete_handler_called"] = True sensor_watcher = SensorWatcher(create_handler, update_handler, delete_handler) - message = Message(None, delivery_info={'routing_key': 'create'}) + message = Message(None, delivery_info={"routing_key": "create"}) sensor_watcher.process_task(MOCK_SENSOR_DB, message) - self.assertTrue(handler_vars['create_handler_called'], 'create handler should be called.') + self.assertTrue( + handler_vars["create_handler_called"], "create handler should be called." + ) - message = Message(None, delivery_info={'routing_key': 'update'}) + message = Message(None, delivery_info={"routing_key": "update"}) sensor_watcher.process_task(MOCK_SENSOR_DB, message) - self.assertTrue(handler_vars['update_handler_called'], 'update handler should be called.') + self.assertTrue( + handler_vars["update_handler_called"], "update handler should be called." + ) - message = Message(None, delivery_info={'routing_key': 'delete'}) + message = Message(None, delivery_info={"routing_key": "delete"}) sensor_watcher.process_task(MOCK_SENSOR_DB, message) - self.assertTrue(handler_vars['delete_handler_called'], 'delete handler should be called.') + self.assertTrue( + handler_vars["delete_handler_called"], "delete handler should be called." + ) diff --git a/st2common/tests/unit/test_service_setup.py b/st2common/tests/unit/test_service_setup.py index 4000f6ce811..b1358f295d5 100644 --- a/st2common/tests/unit/test_service_setup.py +++ b/st2common/tests/unit/test_service_setup.py @@ -31,9 +31,7 @@ from st2tests.base import CleanFilesTestCase from st2tests import config -__all__ = [ - 'ServiceSetupTestCase' -] +__all__ = ["ServiceSetupTestCase"] MOCK_LOGGING_CONFIG_INVALID_LOG_LEVEL = """ [loggers] @@ -61,11 +59,11 @@ datefmt= """.strip() -MOCK_DEFAULT_CONFIG_FILE_PATH = '/etc/st2/st2.conf-test-patched' +MOCK_DEFAULT_CONFIG_FILE_PATH = "/etc/st2/st2.conf-test-patched" def mock_get_logging_config_path(): - return '' + return "" class ServiceSetupTestCase(CleanFilesTestCase): @@ -78,19 +76,24 @@ def test_no_logging_config_found(self): else: expected_msg = "No section: .*" - self.assertRaisesRegexp(Exception, expected_msg, - service_setup.setup, service='api', - config=config, - setup_db=False, register_mq_exchanges=False, - register_signal_handlers=False, - register_internal_trigger_types=False, - run_migrations=False) + self.assertRaisesRegexp( + Exception, + expected_msg, + service_setup.setup, + service="api", + config=config, + setup_db=False, + register_mq_exchanges=False, + register_signal_handlers=False, + register_internal_trigger_types=False, + run_migrations=False, + ) def test_invalid_log_level_friendly_error_message(self): _, mock_logging_config_path = tempfile.mkstemp() self.to_delete_files.append(mock_logging_config_path) - with open(mock_logging_config_path, 'w') as fp: + with open(mock_logging_config_path, "w") as fp: fp.write(MOCK_LOGGING_CONFIG_INVALID_LOG_LEVEL) def mock_get_logging_config_path(): @@ -99,21 +102,28 @@ def mock_get_logging_config_path(): config.get_logging_config_path = mock_get_logging_config_path if six.PY3: - expected_msg = 'ValueError: Unknown level: \'invalid_log_level\'' + expected_msg = "ValueError: Unknown level: 'invalid_log_level'" exc_type = ValueError else: - expected_msg = 'Invalid log level selected. Log level names need to be all uppercase' + expected_msg = ( + "Invalid log level selected. Log level names need to be all uppercase" + ) exc_type = KeyError - self.assertRaisesRegexp(exc_type, expected_msg, - service_setup.setup, service='api', - config=config, - setup_db=False, register_mq_exchanges=False, - register_signal_handlers=False, - register_internal_trigger_types=False, - run_migrations=False) - - @mock.patch('kombu.Queue.declare') + self.assertRaisesRegexp( + exc_type, + expected_msg, + service_setup.setup, + service="api", + config=config, + setup_db=False, + register_mq_exchanges=False, + register_signal_handlers=False, + register_internal_trigger_types=False, + run_migrations=False, + ) + + @mock.patch("kombu.Queue.declare") def test_register_exchanges_predeclare_queues(self, mock_declare): # Verify that queues are correctly pre-declared self.assertEqual(mock_declare.call_count, 0) @@ -121,34 +131,50 @@ def test_register_exchanges_predeclare_queues(self, mock_declare): register_exchanges() self.assertEqual(mock_declare.call_count, len(QUEUES)) - @mock.patch('st2common.constants.system.DEFAULT_CONFIG_FILE_PATH', - MOCK_DEFAULT_CONFIG_FILE_PATH) - @mock.patch('st2common.config.DEFAULT_CONFIG_FILE_PATH', MOCK_DEFAULT_CONFIG_FILE_PATH) + @mock.patch( + "st2common.constants.system.DEFAULT_CONFIG_FILE_PATH", + MOCK_DEFAULT_CONFIG_FILE_PATH, + ) + @mock.patch( + "st2common.config.DEFAULT_CONFIG_FILE_PATH", MOCK_DEFAULT_CONFIG_FILE_PATH + ) def test_service_setup_default_st2_conf_config_is_used(self): st2common_config.get_logging_config_path = mock_get_logging_config_path cfg.CONF.reset() # 1. DEFAULT_CONFIG_FILE_PATH config path should be used by default (/etc/st2/st2.conf) - expected_msg = 'Failed to find some config files: %s' % (MOCK_DEFAULT_CONFIG_FILE_PATH) - self.assertRaisesRegexp(ConfigFilesNotFoundError, expected_msg, service_setup.setup, - service='api', - config=st2common_config, - config_args=['--debug'], - setup_db=False, register_mq_exchanges=False, - register_signal_handlers=False, - register_internal_trigger_types=False, - run_migrations=False) + expected_msg = "Failed to find some config files: %s" % ( + MOCK_DEFAULT_CONFIG_FILE_PATH + ) + self.assertRaisesRegexp( + ConfigFilesNotFoundError, + expected_msg, + service_setup.setup, + service="api", + config=st2common_config, + config_args=["--debug"], + setup_db=False, + register_mq_exchanges=False, + register_signal_handlers=False, + register_internal_trigger_types=False, + run_migrations=False, + ) cfg.CONF.reset() # 2. --config-file should still override default config file path option - config_file_path = '/etc/st2/config.override.test' - expected_msg = 'Failed to find some config files: %s' % (config_file_path) - self.assertRaisesRegexp(ConfigFilesNotFoundError, expected_msg, service_setup.setup, - service='api', - config=st2common_config, - config_args=['--config-file', config_file_path], - setup_db=False, register_mq_exchanges=False, - register_signal_handlers=False, - register_internal_trigger_types=False, - run_migrations=False) + config_file_path = "/etc/st2/config.override.test" + expected_msg = "Failed to find some config files: %s" % (config_file_path) + self.assertRaisesRegexp( + ConfigFilesNotFoundError, + expected_msg, + service_setup.setup, + service="api", + config=st2common_config, + config_args=["--config-file", config_file_path], + setup_db=False, + register_mq_exchanges=False, + register_signal_handlers=False, + register_internal_trigger_types=False, + run_migrations=False, + ) diff --git a/st2common/tests/unit/test_shell_action_system_model.py b/st2common/tests/unit/test_shell_action_system_model.py index 6fdc7d17165..76609ab9530 100644 --- a/st2common/tests/unit/test_shell_action_system_model.py +++ b/st2common/tests/unit/test_shell_action_system_model.py @@ -32,90 +32,87 @@ from local_runner.local_shell_script_runner import LocalShellScriptRunner CURRENT_DIR = os.path.dirname(os.path.realpath(__file__)) -FIXTURES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../fixtures')) +FIXTURES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../fixtures")) LOGGED_USER_USERNAME = pwd.getpwuid(os.getuid())[0] -__all__ = [ - 'ShellCommandActionTestCase', - 'ShellScriptActionTestCase' -] +__all__ = ["ShellCommandActionTestCase", "ShellScriptActionTestCase"] class ShellCommandActionTestCase(unittest2.TestCase): def setUp(self): self._base_kwargs = { - 'name': 'test action', - 'action_exec_id': '1', - 'command': 'ls -la', - 'env_vars': {}, - 'timeout': None + "name": "test action", + "action_exec_id": "1", + "command": "ls -la", + "env_vars": {}, + "timeout": None, } def test_user_argument(self): # User is the same as logged user, no sudo should be used kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME action = ShellCommandAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, 'ls -la') + self.assertEqual(command, "ls -la") # User is different, sudo should be used kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' + kwargs["sudo"] = False + kwargs["user"] = "mauser" action = ShellCommandAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, 'sudo -E -H -u mauser -- bash -c \'ls -la\'') + self.assertEqual(command, "sudo -E -H -u mauser -- bash -c 'ls -la'") # sudo with password kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['sudo_password'] = 'sudopass' - kwargs['user'] = 'mauser' + kwargs["sudo"] = False + kwargs["sudo_password"] = "sudopass" + kwargs["user"] = "mauser" action = ShellCommandAction(**kwargs) command = action.get_full_command_string() - expected_command = 'sudo -S -E -H -u mauser -- bash -c \'ls -la\'' + expected_command = "sudo -S -E -H -u mauser -- bash -c 'ls -la'" self.assertEqual(command, expected_command) # sudo is used, it doesn't matter what user is specified since the # command should run as root kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = True - kwargs['user'] = 'mauser' + kwargs["sudo"] = True + kwargs["user"] = "mauser" action = ShellCommandAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, 'sudo -E -- bash -c \'ls -la\'') + self.assertEqual(command, "sudo -E -- bash -c 'ls -la'") # sudo with passwd kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = True - kwargs['user'] = 'mauser' - kwargs['sudo_password'] = 'sudopass' + kwargs["sudo"] = True + kwargs["user"] = "mauser" + kwargs["sudo_password"] = "sudopass" action = ShellCommandAction(**kwargs) command = action.get_full_command_string() - expected_command = 'sudo -S -E -- bash -c \'ls -la\'' + expected_command = "sudo -S -E -- bash -c 'ls -la'" self.assertEqual(command, expected_command) class ShellScriptActionTestCase(unittest2.TestCase): def setUp(self): self._base_kwargs = { - 'name': 'test action', - 'action_exec_id': '1', - 'script_local_path_abs': '/tmp/foo.sh', - 'named_args': {}, - 'positional_args': [], - 'env_vars': {}, - 'timeout': None + "name": "test action", + "action_exec_id": "1", + "script_local_path_abs": "/tmp/foo.sh", + "named_args": {}, + "positional_args": [], + "env_vars": {}, + "timeout": None, } def _get_fixture(self, name): - path = os.path.join(FIXTURES_DIR, 'local_runner', name) + path = os.path.join(FIXTURES_DIR, "local_runner", name) - with open(path, 'r') as fp: + with open(path, "r") as fp: content = fp.read().strip() return content @@ -123,371 +120,374 @@ def _get_fixture(self, name): def test_user_argument(self): # User is the same as logged user, no sudo should be used kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, '/tmp/foo.sh') + self.assertEqual(command, "/tmp/foo.sh") # User is different, sudo should be used kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' + kwargs["sudo"] = False + kwargs["user"] = "mauser" action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, 'sudo -E -H -u mauser -- bash -c /tmp/foo.sh') + self.assertEqual(command, "sudo -E -H -u mauser -- bash -c /tmp/foo.sh") # sudo with password kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' - kwargs['sudo_password'] = 'sudopass' + kwargs["sudo"] = False + kwargs["user"] = "mauser" + kwargs["sudo_password"] = "sudopass" action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected_command = 'sudo -S -E -H -u mauser -- bash -c /tmp/foo.sh' + expected_command = "sudo -S -E -H -u mauser -- bash -c /tmp/foo.sh" self.assertEqual(command, expected_command) # complex sudo password which needs escaping kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' - kwargs['sudo_password'] = '$udo p\'as"sss' + kwargs["sudo"] = False + kwargs["user"] = "mauser" + kwargs["sudo_password"] = "$udo p'as\"sss" action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected_command = ('sudo -S -E -H ' - '-u mauser -- bash -c /tmp/foo.sh') + expected_command = "sudo -S -E -H " "-u mauser -- bash -c /tmp/foo.sh" self.assertEqual(command, expected_command) command = action.get_sanitized_full_command_string() - expected_command = ('echo -e \'%s\n\' | sudo -S -E -H ' - '-u mauser -- bash -c /tmp/foo.sh' % (MASKED_ATTRIBUTE_VALUE)) + expected_command = ( + "echo -e '%s\n' | sudo -S -E -H " + "-u mauser -- bash -c /tmp/foo.sh" % (MASKED_ATTRIBUTE_VALUE) + ) self.assertEqual(command, expected_command) # sudo is used, it doesn't matter what user is specified since the # command should run as root kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = True - kwargs['user'] = 'mauser' - kwargs['sudo_password'] = 'sudopass' + kwargs["sudo"] = True + kwargs["user"] = "mauser" + kwargs["sudo_password"] = "sudopass" action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected_command = 'sudo -S -E -- bash -c /tmp/foo.sh' + expected_command = "sudo -S -E -- bash -c /tmp/foo.sh" self.assertEqual(command, expected_command) def test_command_construction_with_parameters(self): # same user, named args, no positional args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2') - ]) - kwargs['positional_args'] = [] + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")]) + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, '/tmp/foo.sh key1=value1 key2=value2') + self.assertEqual(command, "/tmp/foo.sh key1=value1 key2=value2") # same user, named args, no positional args, sudo password kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = True - kwargs['sudo_password'] = 'sudopass' - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2') - ]) - kwargs['positional_args'] = [] + kwargs["sudo"] = True + kwargs["sudo_password"] = "sudopass" + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")]) + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = ('sudo -S -E -- bash -c ' - '\'/tmp/foo.sh key1=value1 key2=value2\'') + expected = "sudo -S -E -- bash -c " "'/tmp/foo.sh key1=value1 key2=value2'" self.assertEqual(command, expected) # different user, named args, no positional args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2') - ]) - kwargs['positional_args'] = [] + kwargs["sudo"] = False + kwargs["user"] = "mauser" + kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")]) + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = 'sudo -E -H -u mauser -- bash -c \'/tmp/foo.sh key1=value1 key2=value2\'' + expected = ( + "sudo -E -H -u mauser -- bash -c '/tmp/foo.sh key1=value1 key2=value2'" + ) self.assertEqual(command, expected) # different user, named args, no positional args, sudo password kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['sudo_password'] = 'sudopass' - kwargs['user'] = 'mauser' - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2') - ]) - kwargs['positional_args'] = [] + kwargs["sudo"] = False + kwargs["sudo_password"] = "sudopass" + kwargs["user"] = "mauser" + kwargs["named_args"] = OrderedDict([("key1", "value1"), ("key2", "value2")]) + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = ('sudo -S -E -H -u mauser -- bash -c ' - '\'/tmp/foo.sh key1=value1 key2=value2\'') + expected = ( + "sudo -S -E -H -u mauser -- bash -c " + "'/tmp/foo.sh key1=value1 key2=value2'" + ) self.assertEqual(command, expected) # same user, positional args, no named args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = {} - kwargs['positional_args'] = ['ein', 'zwei', 'drei', 'mamma mia', 'foo\nbar'] + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = {} + kwargs["positional_args"] = ["ein", "zwei", "drei", "mamma mia", "foo\nbar"] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, '/tmp/foo.sh ein zwei drei \'mamma mia\' \'foo\nbar\'') + self.assertEqual(command, "/tmp/foo.sh ein zwei drei 'mamma mia' 'foo\nbar'") # different user, named args, positional args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' - kwargs['named_args'] = {} - kwargs['positional_args'] = ['ein', 'zwei', 'drei', 'mamma mia'] + kwargs["sudo"] = False + kwargs["user"] = "mauser" + kwargs["named_args"] = {} + kwargs["positional_args"] = ["ein", "zwei", "drei", "mamma mia"] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - ex = ('sudo -E -H -u mauser -- ' - 'bash -c \'/tmp/foo.sh ein zwei drei \'"\'"\'mamma mia\'"\'"\'\'') + ex = ( + "sudo -E -H -u mauser -- " + "bash -c '/tmp/foo.sh ein zwei drei '\"'\"'mamma mia'\"'\"''" + ) self.assertEqual(command, ex) # same user, positional and named args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2'), - ('key3', 'value 3') - ]) - - kwargs['positional_args'] = ['ein', 'zwei', 'drei'] + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = OrderedDict( + [("key1", "value1"), ("key2", "value2"), ("key3", "value 3")] + ) + + kwargs["positional_args"] = ["ein", "zwei", "drei"] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - exp = '/tmp/foo.sh key1=value1 key2=value2 key3=\'value 3\' ein zwei drei' + exp = "/tmp/foo.sh key1=value1 key2=value2 key3='value 3' ein zwei drei" self.assertEqual(command, exp) # different user, positional and named args kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = 'mauser' - kwargs['named_args'] = OrderedDict([ - ('key1', 'value1'), - ('key2', 'value2'), - ('key3', 'value 3') - ]) - kwargs['positional_args'] = ['ein', 'zwei', 'drei'] + kwargs["sudo"] = False + kwargs["user"] = "mauser" + kwargs["named_args"] = OrderedDict( + [("key1", "value1"), ("key2", "value2"), ("key3", "value 3")] + ) + kwargs["positional_args"] = ["ein", "zwei", "drei"] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = ('sudo -E -H -u mauser -- bash -c \'/tmp/foo.sh key1=value1 key2=value2 ' - 'key3=\'"\'"\'value 3\'"\'"\' ein zwei drei\'') + expected = ( + "sudo -E -H -u mauser -- bash -c '/tmp/foo.sh key1=value1 key2=value2 " + "key3='\"'\"'value 3'\"'\"' ein zwei drei'" + ) self.assertEqual(command, expected) def test_named_parameter_escaping(self): # no sudo kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = OrderedDict([ - ('key1', 'value foo bar'), - ('key2', 'value "bar" foo'), - ('key3', 'date ; whoami'), - ('key4', '"date ; whoami"'), - ]) + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = OrderedDict( + [ + ("key1", "value foo bar"), + ("key2", 'value "bar" foo'), + ("key3", "date ; whoami"), + ("key4", '"date ; whoami"'), + ] + ) action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = self._get_fixture('escaping_test_command_1.txt') + expected = self._get_fixture("escaping_test_command_1.txt") self.assertEqual(command, expected) # sudo kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = True - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = OrderedDict([ - ('key1', 'value foo bar'), - ('key2', 'value "bar" foo'), - ('key3', 'date ; whoami'), - ('key4', '"date ; whoami"'), - ]) + kwargs["sudo"] = True + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = OrderedDict( + [ + ("key1", "value foo bar"), + ("key2", 'value "bar" foo'), + ("key3", "date ; whoami"), + ("key4", '"date ; whoami"'), + ] + ) action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - expected = self._get_fixture('escaping_test_command_2.txt') + expected = self._get_fixture("escaping_test_command_2.txt") self.assertEqual(command, expected) def test_various_ascii_parameters(self): kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = {'foo1': 'bar1', 'foo2': 'bar2'} - kwargs['positional_args'] = [] + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = {"foo1": "bar1", "foo2": "bar2"} + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, u"/tmp/foo.sh foo1=bar1 foo2=bar2") + self.assertEqual(command, "/tmp/foo.sh foo1=bar1 foo2=bar2") def test_unicode_parameter_specifing(self): kwargs = copy.deepcopy(self._base_kwargs) - kwargs['sudo'] = False - kwargs['user'] = LOGGED_USER_USERNAME - kwargs['named_args'] = {u'foo': u'bar'} - kwargs['positional_args'] = [] + kwargs["sudo"] = False + kwargs["user"] = LOGGED_USER_USERNAME + kwargs["named_args"] = {"foo": "bar"} + kwargs["positional_args"] = [] action = ShellScriptAction(**kwargs) command = action.get_full_command_string() - self.assertEqual(command, u"/tmp/foo.sh 'foo'='bar'") + self.assertEqual(command, "/tmp/foo.sh 'foo'='bar'") def test_command_construction_correct_default_parameter_values_are_used(self): runner_parameters = {} action_db_parameters = { - 'project': { - 'type': 'string', - 'default': 'st2', - 'position': 0, - }, - 'version': { - 'type': 'string', - 'position': 1, - 'required': True + "project": { + "type": "string", + "default": "st2", + "position": 0, }, - 'fork': { - 'type': 'string', - 'position': 2, - 'default': 'StackStorm', + "version": {"type": "string", "position": 1, "required": True}, + "fork": { + "type": "string", + "position": 2, + "default": "StackStorm", }, - 'branch': { - 'type': 'string', - 'position': 3, - 'default': 'master', + "branch": { + "type": "string", + "position": 3, + "default": "master", }, - 'update_changelog': { - 'type': 'boolean', - 'position': 4, - 'default': False + "update_changelog": {"type": "boolean", "position": 4, "default": False}, + "local_repo": { + "type": "string", + "position": 5, }, - 'local_repo': { - 'type': 'string', - 'position': 5, - } } context = {} - action_db = ActionDB(pack='dummy', name='action') + action_db = ActionDB(pack="dummy", name="action") - runner = LocalShellScriptRunner('id') + runner = LocalShellScriptRunner("id") runner.runner_parameters = {} runner.action = action_db # 1. All default values used live_action_db_parameters = { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'local_repo': '/tmp/repo' + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "local_repo": "/tmp/repo", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) - - self.assertDictEqual(action_params, { - 'project': 'st2flow', - 'version': '3.0.0', - 'fork': 'StackStorm', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repo' - }) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) + + self.assertDictEqual( + action_params, + { + "project": "st2flow", + "version": "3.0.0", + "fork": "StackStorm", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repo", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) - shell_script_action = ShellScriptAction(name='dummy', action_exec_id='dummy', - script_local_path_abs='/tmp/local.sh', - named_args=named_args, - positional_args=positional_args) + shell_script_action = ShellScriptAction( + name="dummy", + action_exec_id="dummy", + script_local_path_abs="/tmp/local.sh", + named_args=named_args, + positional_args=positional_args, + ) command_string = shell_script_action.get_full_command_string() - expected = '/tmp/local.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo' + expected = "/tmp/local.sh st2flow 3.0.0 StackStorm master 0 /tmp/repo" self.assertEqual(command_string, expected) # 2. Some default values used live_action_db_parameters = { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'update_changelog': True, - 'local_repo': '/tmp/repob' + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "update_changelog": True, + "local_repo": "/tmp/repob", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) - - self.assertDictEqual(action_params, { - 'project': 'st2web', - 'version': '3.1.0', - 'fork': 'StackStorm1', - 'branch': 'master', # default value used - 'update_changelog': True, # default value used - 'local_repo': '/tmp/repob' - }) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) + + self.assertDictEqual( + action_params, + { + "project": "st2web", + "version": "3.1.0", + "fork": "StackStorm1", + "branch": "master", # default value used + "update_changelog": True, # default value used + "local_repo": "/tmp/repob", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) - shell_script_action = ShellScriptAction(name='dummy', action_exec_id='dummy', - script_local_path_abs='/tmp/local.sh', - named_args=named_args, - positional_args=positional_args) + shell_script_action = ShellScriptAction( + name="dummy", + action_exec_id="dummy", + script_local_path_abs="/tmp/local.sh", + named_args=named_args, + positional_args=positional_args, + ) command_string = shell_script_action.get_full_command_string() - expected = '/tmp/local.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob' + expected = "/tmp/local.sh st2web 3.1.0 StackStorm1 master 1 /tmp/repob" self.assertEqual(command_string, expected) # 3. None is specified for a boolean parameter, should use a default live_action_db_parameters = { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'update_changelog': None, - 'local_repo': '/tmp/repoc' + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "update_changelog": None, + "local_repo": "/tmp/repoc", } - runner_params, action_params = param_utils.render_final_params(runner_parameters, - action_db_parameters, - live_action_db_parameters, - context) - - self.assertDictEqual(action_params, { - 'project': 'st2rbac', - 'version': '3.2.0', - 'fork': 'StackStorm2', - 'branch': 'master', # default value used - 'update_changelog': False, # default value used - 'local_repo': '/tmp/repoc' - }) + runner_params, action_params = param_utils.render_final_params( + runner_parameters, action_db_parameters, live_action_db_parameters, context + ) + + self.assertDictEqual( + action_params, + { + "project": "st2rbac", + "version": "3.2.0", + "fork": "StackStorm2", + "branch": "master", # default value used + "update_changelog": False, # default value used + "local_repo": "/tmp/repoc", + }, + ) action_db.parameters = action_db_parameters positional_args, named_args = runner._get_script_args(action_params) named_args = runner._transform_named_args(named_args) - shell_script_action = ShellScriptAction(name='dummy', action_exec_id='dummy', - script_local_path_abs='/tmp/local.sh', - named_args=named_args, - positional_args=positional_args) + shell_script_action = ShellScriptAction( + name="dummy", + action_exec_id="dummy", + script_local_path_abs="/tmp/local.sh", + named_args=named_args, + positional_args=positional_args, + ) command_string = shell_script_action.get_full_command_string() - expected = '/tmp/local.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc' + expected = "/tmp/local.sh st2rbac 3.2.0 StackStorm2 master 0 /tmp/repoc" self.assertEqual(command_string, expected) diff --git a/st2common/tests/unit/test_state_publisher.py b/st2common/tests/unit/test_state_publisher.py index 1fa87b8487f..99dbabda7f7 100644 --- a/st2common/tests/unit/test_state_publisher.py +++ b/st2common/tests/unit/test_state_publisher.py @@ -27,7 +27,7 @@ from st2tests import DbTestCase -FAKE_STATE_MGMT_XCHG = kombu.Exchange('st2.fake.state', type='topic') +FAKE_STATE_MGMT_XCHG = kombu.Exchange("st2.fake.state", type="topic") class FakeModelPublisher(publishers.StatePublisherMixin): @@ -57,7 +57,7 @@ def _get_publisher(cls): def publish_state(cls, model_object): publisher = cls._get_publisher() if publisher: - publisher.publish_state(model_object, getattr(model_object, 'state', None)) + publisher.publish_state(model_object, getattr(model_object, "state", None)) @classmethod def _get_by_object(cls, object): @@ -65,7 +65,6 @@ def _get_by_object(cls, object): class StatePublisherTest(DbTestCase): - @classmethod def setUpClass(cls): super(StatePublisherTest, cls).setUpClass() @@ -75,13 +74,13 @@ def tearDown(self): FakeModelDB.drop_collection() super(StatePublisherTest, self).tearDown() - @mock.patch.object(publishers.PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(publishers.PoolPublisher, "publish", mock.MagicMock()) def test_publish(self): - instance = FakeModelDB(state='faked') + instance = FakeModelDB(state="faked") self.access.publish_state(instance) - publishers.PoolPublisher.publish.assert_called_with(instance, - FAKE_STATE_MGMT_XCHG, - instance.state) + publishers.PoolPublisher.publish.assert_called_with( + instance, FAKE_STATE_MGMT_XCHG, instance.state + ) def test_publish_unset(self): instance = FakeModelDB() @@ -92,5 +91,5 @@ def test_publish_none(self): self.assertRaises(Exception, self.access.publish_state, instance) def test_publish_empty_str(self): - instance = FakeModelDB(state='') + instance = FakeModelDB(state="") self.assertRaises(Exception, self.access.publish_state, instance) diff --git a/st2common/tests/unit/test_stream_generator.py b/st2common/tests/unit/test_stream_generator.py index 9c44db46573..a184220b80d 100644 --- a/st2common/tests/unit/test_stream_generator.py +++ b/st2common/tests/unit/test_stream_generator.py @@ -20,7 +20,6 @@ class MockBody(object): - def __init__(self, id): self.id = id self.status = "succeeded" @@ -32,8 +31,7 @@ def __init__(self, id): EVENTS = [(INCLUDE, MockBody("notend")), (END_EVENT, MockBody(END_ID))] -class MockQueue(): - +class MockQueue: def __init__(self): self.items = EVENTS @@ -47,7 +45,6 @@ def put(self, event): class MockListener(listener.BaseListener): - def __init__(self, *args, **kwargs): super(MockListener, self).__init__(*args, **kwargs) @@ -56,19 +53,19 @@ def get_consumers(self, consumer, channel): class TestStream(unittest2.TestCase): - - @mock.patch('st2common.stream.listener.BaseListener._get_action_ref_for_body') - @mock.patch('eventlet.Queue') - def test_generator(self, mock_queue, - get_action_ref_for_body): + @mock.patch("st2common.stream.listener.BaseListener._get_action_ref_for_body") + @mock.patch("eventlet.Queue") + def test_generator(self, mock_queue, get_action_ref_for_body): get_action_ref_for_body.return_value = None mock_queue.return_value = MockQueue() mock_consumer = MockListener(connection=None) mock_consumer._stopped = False - app_iter = mock_consumer.generator(events=INCLUDE, + app_iter = mock_consumer.generator( + events=INCLUDE, end_event=END_EVENT, end_statuses=["succeeded"], - end_execution_id=END_ID) - events = EVENTS.append('') + end_execution_id=END_ID, + ) + events = EVENTS.append("") for index, val in enumerate(app_iter): self.assertEquals(val, events[index]) diff --git a/st2common/tests/unit/test_system_info.py b/st2common/tests/unit/test_system_info.py index e7ddb20bef9..c840a7aa8b3 100644 --- a/st2common/tests/unit/test_system_info.py +++ b/st2common/tests/unit/test_system_info.py @@ -23,8 +23,7 @@ class TestLogger(unittest.TestCase): - def test_process_info(self): process_info = system_info.get_process_info() - self.assertEqual(process_info['hostname'], socket.gethostname()) - self.assertEqual(process_info['pid'], os.getpid()) + self.assertEqual(process_info["hostname"], socket.gethostname()) + self.assertEqual(process_info["pid"], os.getpid()) diff --git a/st2common/tests/unit/test_tags.py b/st2common/tests/unit/test_tags.py index 3ffc59b50aa..6230cedea6f 100644 --- a/st2common/tests/unit/test_tags.py +++ b/st2common/tests/unit/test_tags.py @@ -28,53 +28,69 @@ class TaggedModel(stormbase.StormFoundationDB, stormbase.TagsMixin): class TestTags(DbTestCase): - def test_simple_count(self): instance = TaggedModel() - instance.tags = [stormbase.TagField(name='tag1', value='v1'), - stormbase.TagField(name='tag2', value='v2')] + instance.tags = [ + stormbase.TagField(name="tag1", value="v1"), + stormbase.TagField(name="tag2", value="v2"), + ] saved = instance.save() retrieved = TaggedModel.objects(id=instance.id).first() - self.assertEqual(len(saved.tags), len(retrieved.tags), 'Failed to retrieve tags.') + self.assertEqual( + len(saved.tags), len(retrieved.tags), "Failed to retrieve tags." + ) def test_simple_value(self): instance = TaggedModel() - instance.tags = [stormbase.TagField(name='tag1', value='v1')] + instance.tags = [stormbase.TagField(name="tag1", value="v1")] saved = instance.save() retrieved = TaggedModel.objects(id=instance.id).first() - self.assertEqual(len(saved.tags), len(retrieved.tags), 'Failed to retrieve tags.') + self.assertEqual( + len(saved.tags), len(retrieved.tags), "Failed to retrieve tags." + ) saved_tag = saved.tags[0] retrieved_tag = retrieved.tags[0] - self.assertEqual(saved_tag.name, retrieved_tag.name, 'Failed to retrieve tag.') - self.assertEqual(saved_tag.value, retrieved_tag.value, 'Failed to retrieve tag.') + self.assertEqual(saved_tag.name, retrieved_tag.name, "Failed to retrieve tag.") + self.assertEqual( + saved_tag.value, retrieved_tag.value, "Failed to retrieve tag." + ) def test_tag_max_size_restriction(self): instance = TaggedModel() - instance.tags = [stormbase.TagField(name=self._gen_random_string(), - value=self._gen_random_string())] + instance.tags = [ + stormbase.TagField( + name=self._gen_random_string(), value=self._gen_random_string() + ) + ] saved = instance.save() retrieved = TaggedModel.objects(id=instance.id).first() - self.assertEqual(len(saved.tags), len(retrieved.tags), 'Failed to retrieve tags.') + self.assertEqual( + len(saved.tags), len(retrieved.tags), "Failed to retrieve tags." + ) def test_name_exceeds_max_size(self): instance = TaggedModel() - instance.tags = [stormbase.TagField(name=self._gen_random_string(1025), - value='v1')] + instance.tags = [ + stormbase.TagField(name=self._gen_random_string(1025), value="v1") + ] try: instance.save() - self.assertTrue(False, 'Expected save to fail') + self.assertTrue(False, "Expected save to fail") except ValidationError: pass def test_value_exceeds_max_size(self): instance = TaggedModel() - instance.tags = [stormbase.TagField(name='n1', - value=self._gen_random_string(1025))] + instance.tags = [ + stormbase.TagField(name="n1", value=self._gen_random_string(1025)) + ] try: instance.save() - self.assertTrue(False, 'Expected save to fail') + self.assertTrue(False, "Expected save to fail") except ValidationError: pass - def _gen_random_string(self, size=1024, chars=string.ascii_lowercase + string.digits): - return ''.join([random.choice(chars) for _ in range(size)]) + def _gen_random_string( + self, size=1024, chars=string.ascii_lowercase + string.digits + ): + return "".join([random.choice(chars) for _ in range(size)]) diff --git a/st2common/tests/unit/test_time_jinja_filters.py b/st2common/tests/unit/test_time_jinja_filters.py index c61473cdfd7..5a343a5c293 100644 --- a/st2common/tests/unit/test_time_jinja_filters.py +++ b/st2common/tests/unit/test_time_jinja_filters.py @@ -20,14 +20,16 @@ class TestTimeJinjaFilters(TestCase): - def test_to_human_time_from_seconds(self): - self.assertEqual('0s', time.to_human_time_from_seconds(seconds=0)) - self.assertEqual('0.1\u03BCs', time.to_human_time_from_seconds(seconds=0.1)) - self.assertEqual('56s', time.to_human_time_from_seconds(seconds=56)) - self.assertEqual('56s', time.to_human_time_from_seconds(seconds=56.2)) - self.assertEqual('7m36s', time.to_human_time_from_seconds(seconds=456)) - self.assertEqual('1h16m0s', time.to_human_time_from_seconds(seconds=4560)) - self.assertEqual('1y12d16h36m37s', time.to_human_time_from_seconds(seconds=45678997)) - self.assertRaises(AssertionError, time.to_human_time_from_seconds, - seconds='stuff') + self.assertEqual("0s", time.to_human_time_from_seconds(seconds=0)) + self.assertEqual("0.1\u03BCs", time.to_human_time_from_seconds(seconds=0.1)) + self.assertEqual("56s", time.to_human_time_from_seconds(seconds=56)) + self.assertEqual("56s", time.to_human_time_from_seconds(seconds=56.2)) + self.assertEqual("7m36s", time.to_human_time_from_seconds(seconds=456)) + self.assertEqual("1h16m0s", time.to_human_time_from_seconds(seconds=4560)) + self.assertEqual( + "1y12d16h36m37s", time.to_human_time_from_seconds(seconds=45678997) + ) + self.assertRaises( + AssertionError, time.to_human_time_from_seconds, seconds="stuff" + ) diff --git a/st2common/tests/unit/test_transport.py b/st2common/tests/unit/test_transport.py index 9e4d4789b28..75e35ae2c99 100644 --- a/st2common/tests/unit/test_transport.py +++ b/st2common/tests/unit/test_transport.py @@ -19,9 +19,7 @@ from st2common.transport.utils import _get_ssl_kwargs -__all__ = [ - 'TransportUtilsTestCase' -] +__all__ = ["TransportUtilsTestCase"] class TransportUtilsTestCase(unittest2.TestCase): @@ -32,49 +30,39 @@ def test_get_ssl_kwargs(self): # 2. ssl kwarg provided ssl_kwargs = _get_ssl_kwargs(ssl=True) - self.assertEqual(ssl_kwargs, { - 'ssl': True - }) + self.assertEqual(ssl_kwargs, {"ssl": True}) # 3. ssl_keyfile provided - ssl_kwargs = _get_ssl_kwargs(ssl_keyfile='/tmp/keyfile') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'keyfile': '/tmp/keyfile' - }) + ssl_kwargs = _get_ssl_kwargs(ssl_keyfile="/tmp/keyfile") + self.assertEqual(ssl_kwargs, {"ssl": True, "keyfile": "/tmp/keyfile"}) # 4. ssl_certfile provided - ssl_kwargs = _get_ssl_kwargs(ssl_certfile='/tmp/certfile') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'certfile': '/tmp/certfile' - }) + ssl_kwargs = _get_ssl_kwargs(ssl_certfile="/tmp/certfile") + self.assertEqual(ssl_kwargs, {"ssl": True, "certfile": "/tmp/certfile"}) # 5. ssl_ca_certs provided - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ca_certs': '/tmp/ca_certs' - }) + ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs") + self.assertEqual(ssl_kwargs, {"ssl": True, "ca_certs": "/tmp/ca_certs"}) # 6. ssl_ca_certs and ssl_cert_reqs combinations - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='none') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ca_certs': '/tmp/ca_certs', - 'cert_reqs': ssl.CERT_NONE - }) + ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="none") + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ca_certs": "/tmp/ca_certs", "cert_reqs": ssl.CERT_NONE}, + ) - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='optional') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ca_certs': '/tmp/ca_certs', - 'cert_reqs': ssl.CERT_OPTIONAL - }) + ssl_kwargs = _get_ssl_kwargs( + ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="optional" + ) + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ca_certs": "/tmp/ca_certs", "cert_reqs": ssl.CERT_OPTIONAL}, + ) - ssl_kwargs = _get_ssl_kwargs(ssl_ca_certs='/tmp/ca_certs', ssl_cert_reqs='required') - self.assertEqual(ssl_kwargs, { - 'ssl': True, - 'ca_certs': '/tmp/ca_certs', - 'cert_reqs': ssl.CERT_REQUIRED - }) + ssl_kwargs = _get_ssl_kwargs( + ssl_ca_certs="/tmp/ca_certs", ssl_cert_reqs="required" + ) + self.assertEqual( + ssl_kwargs, + {"ssl": True, "ca_certs": "/tmp/ca_certs", "cert_reqs": ssl.CERT_REQUIRED}, + ) diff --git a/st2common/tests/unit/test_trigger_services.py b/st2common/tests/unit/test_trigger_services.py index 6f66a5f55b4..b843526bc9d 100644 --- a/st2common/tests/unit/test_trigger_services.py +++ b/st2common/tests/unit/test_trigger_services.py @@ -18,124 +18,147 @@ from st2common.models.api.rule import RuleAPI from st2common.models.system.common import ResourceReference from st2common.models.db.trigger import TriggerDB -from st2common.persistence.trigger import (Trigger, TriggerType) +from st2common.persistence.trigger import Trigger, TriggerType import st2common.services.triggers as trigger_service from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import FixturesLoader -MOCK_TRIGGER = TriggerDB(pack='dummy_pack_1', name='trigger-test.name', parameters={}, - type='dummy_pack_1.trigger-type-test.name') +MOCK_TRIGGER = TriggerDB( + pack="dummy_pack_1", + name="trigger-test.name", + parameters={}, + type="dummy_pack_1.trigger-type-test.name", +) class TriggerServiceTests(CleanDbTestCase): - def test_create_trigger_db_from_rule(self): - test_fixtures = { - 'rules': ['cron_timer_rule_1.yaml', 'cron_timer_rule_3.yaml'] - } + test_fixtures = {"rules": ["cron_timer_rule_1.yaml", "cron_timer_rule_3.yaml"]} loader = FixturesLoader() - fixtures = loader.load_fixtures(fixtures_pack='generic', fixtures_dict=test_fixtures) - rules = fixtures['rules'] + fixtures = loader.load_fixtures( + fixtures_pack="generic", fixtures_dict=test_fixtures + ) + rules = fixtures["rules"] trigger_db_ret_1 = trigger_service.create_trigger_db_from_rule( - RuleAPI(**rules['cron_timer_rule_1.yaml'])) + RuleAPI(**rules["cron_timer_rule_1.yaml"]) + ) self.assertIsNotNone(trigger_db_ret_1) trigger_db = Trigger.get_by_id(trigger_db_ret_1.id) - self.assertDictEqual(trigger_db.parameters, - rules['cron_timer_rule_1.yaml']['trigger']['parameters']) + self.assertDictEqual( + trigger_db.parameters, + rules["cron_timer_rule_1.yaml"]["trigger"]["parameters"], + ) trigger_db_ret_2 = trigger_service.create_trigger_db_from_rule( - RuleAPI(**rules['cron_timer_rule_3.yaml'])) + RuleAPI(**rules["cron_timer_rule_3.yaml"]) + ) self.assertIsNotNone(trigger_db_ret_2) self.assertTrue(trigger_db_ret_2.id != trigger_db_ret_1.id) def test_create_trigger_db_from_rule_duplicate(self): - test_fixtures = { - 'rules': ['cron_timer_rule_1.yaml', 'cron_timer_rule_2.yaml'] - } + test_fixtures = {"rules": ["cron_timer_rule_1.yaml", "cron_timer_rule_2.yaml"]} loader = FixturesLoader() - fixtures = loader.load_fixtures(fixtures_pack='generic', fixtures_dict=test_fixtures) - rules = fixtures['rules'] + fixtures = loader.load_fixtures( + fixtures_pack="generic", fixtures_dict=test_fixtures + ) + rules = fixtures["rules"] trigger_db_ret_1 = trigger_service.create_trigger_db_from_rule( - RuleAPI(**rules['cron_timer_rule_1.yaml'])) + RuleAPI(**rules["cron_timer_rule_1.yaml"]) + ) self.assertIsNotNone(trigger_db_ret_1) trigger_db_ret_2 = trigger_service.create_trigger_db_from_rule( - RuleAPI(**rules['cron_timer_rule_2.yaml'])) + RuleAPI(**rules["cron_timer_rule_2.yaml"]) + ) self.assertIsNotNone(trigger_db_ret_2) - self.assertEqual(trigger_db_ret_1, trigger_db_ret_2, 'Should reuse same trigger.') + self.assertEqual( + trigger_db_ret_1, trigger_db_ret_2, "Should reuse same trigger." + ) trigger_db = Trigger.get_by_id(trigger_db_ret_1.id) - self.assertDictEqual(trigger_db.parameters, - rules['cron_timer_rule_1.yaml']['trigger']['parameters']) + self.assertDictEqual( + trigger_db.parameters, + rules["cron_timer_rule_1.yaml"]["trigger"]["parameters"], + ) def test_create_or_update_trigger_db_simple_triggers(self): - test_fixtures = { - 'triggertypes': ['triggertype1.yaml'] - } + test_fixtures = {"triggertypes": ["triggertype1.yaml"]} loader = FixturesLoader() - fixtures = loader.save_fixtures_to_db(fixtures_pack='generic', fixtures_dict=test_fixtures) - triggertypes = fixtures['triggertypes'] + fixtures = loader.save_fixtures_to_db( + fixtures_pack="generic", fixtures_dict=test_fixtures + ) + triggertypes = fixtures["triggertypes"] trigger_type_ref = ResourceReference.to_string_reference( - name=triggertypes['triggertype1.yaml']['name'], - pack=triggertypes['triggertype1.yaml']['pack']) + name=triggertypes["triggertype1.yaml"]["name"], + pack=triggertypes["triggertype1.yaml"]["pack"], + ) trigger = { - 'name': triggertypes['triggertype1.yaml']['name'], - 'pack': triggertypes['triggertype1.yaml']['pack'], - 'type': trigger_type_ref + "name": triggertypes["triggertype1.yaml"]["name"], + "pack": triggertypes["triggertype1.yaml"]["pack"], + "type": trigger_type_ref, } trigger_service.create_or_update_trigger_db(trigger) triggers = Trigger.get_all() - self.assertTrue(len(triggers) == 1, 'Only one trigger should be created.') - self.assertTrue(triggers[0]['name'] == triggertypes['triggertype1.yaml']['name']) + self.assertTrue(len(triggers) == 1, "Only one trigger should be created.") + self.assertTrue( + triggers[0]["name"] == triggertypes["triggertype1.yaml"]["name"] + ) # Try adding duplicate trigger_service.create_or_update_trigger_db(trigger) triggers = Trigger.get_all() - self.assertTrue(len(triggers) == 1, 'Only one trigger should be present.') - self.assertTrue(triggers[0]['name'] == triggertypes['triggertype1.yaml']['name']) + self.assertTrue(len(triggers) == 1, "Only one trigger should be present.") + self.assertTrue( + triggers[0]["name"] == triggertypes["triggertype1.yaml"]["name"] + ) def test_exception_thrown_when_rule_creation_no_trigger_yes_triggertype(self): - test_fixtures = { - 'triggertypes': ['triggertype1.yaml'] - } + test_fixtures = {"triggertypes": ["triggertype1.yaml"]} loader = FixturesLoader() - fixtures = loader.save_fixtures_to_db(fixtures_pack='generic', fixtures_dict=test_fixtures) - triggertypes = fixtures['triggertypes'] + fixtures = loader.save_fixtures_to_db( + fixtures_pack="generic", fixtures_dict=test_fixtures + ) + triggertypes = fixtures["triggertypes"] trigger_type_ref = ResourceReference.to_string_reference( - name=triggertypes['triggertype1.yaml']['name'], - pack=triggertypes['triggertype1.yaml']['pack']) + name=triggertypes["triggertype1.yaml"]["name"], + pack=triggertypes["triggertype1.yaml"]["pack"], + ) rule = { - 'name': 'fancyrule', - 'trigger': { - 'type': trigger_type_ref - }, - 'criteria': { - - }, - 'action': { - 'ref': 'core.local', - 'parameters': { - 'cmd': 'date' - } - } + "name": "fancyrule", + "trigger": {"type": trigger_type_ref}, + "criteria": {}, + "action": {"ref": "core.local", "parameters": {"cmd": "date"}}, } rule_api = RuleAPI(**rule) - self.assertRaises(TriggerDoesNotExistException, - trigger_service.create_trigger_db_from_rule, rule_api) + self.assertRaises( + TriggerDoesNotExistException, + trigger_service.create_trigger_db_from_rule, + rule_api, + ) def test_get_trigger_db_given_type_and_params(self): # Add dummy triggers - trigger_1 = TriggerDB(pack='testpack', name='testtrigger1', type='testpack.testtrigger1') + trigger_1 = TriggerDB( + pack="testpack", name="testtrigger1", type="testpack.testtrigger1" + ) - trigger_2 = TriggerDB(pack='testpack', name='testtrigger2', type='testpack.testtrigger2') + trigger_2 = TriggerDB( + pack="testpack", name="testtrigger2", type="testpack.testtrigger2" + ) - trigger_3 = TriggerDB(pack='testpack', name='testtrigger3', type='testpack.testtrigger3') + trigger_3 = TriggerDB( + pack="testpack", name="testtrigger3", type="testpack.testtrigger3" + ) - trigger_4 = TriggerDB(pack='testpack', name='testtrigger4', type='testpack.testtrigger4', - parameters={'ponies': 'unicorn'}) + trigger_4 = TriggerDB( + pack="testpack", + name="testtrigger4", + type="testpack.testtrigger4", + parameters={"ponies": "unicorn"}, + ) Trigger.add_or_update(trigger_1) Trigger.add_or_update(trigger_2) @@ -143,64 +166,73 @@ def test_get_trigger_db_given_type_and_params(self): Trigger.add_or_update(trigger_4) # Trigger with no parameters, parameters={} in db - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_1.type, - parameters={}) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_1.type, parameters={} + ) self.assertEqual(trigger_db, trigger_1) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_1.type, - parameters=None) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_1.type, parameters=None + ) self.assertEqual(trigger_db, trigger_1) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_1.type, - parameters={'fo': 'bar'}) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_1.type, parameters={"fo": "bar"} + ) self.assertEqual(trigger_db, None) # Trigger with no parameters, no parameters attribute in the db - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_2.type, - parameters={}) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_2.type, parameters={} + ) self.assertEqual(trigger_db, trigger_2) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_2.type, - parameters=None) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_2.type, parameters=None + ) self.assertEqual(trigger_db, trigger_2) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_2.type, - parameters={'fo': 'bar'}) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_2.type, parameters={"fo": "bar"} + ) self.assertEqual(trigger_db, None) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_3.type, - parameters={}) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_3.type, parameters={} + ) self.assertEqual(trigger_db, trigger_3) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_3.type, - parameters=None) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_3.type, parameters=None + ) self.assertEqual(trigger_db, trigger_3) # Trigger with parameters trigger_db = trigger_service.get_trigger_db_given_type_and_params( - type=trigger_4.type, - parameters=trigger_4.parameters) + type=trigger_4.type, parameters=trigger_4.parameters + ) self.assertEqual(trigger_db, trigger_4) - trigger_db = trigger_service.get_trigger_db_given_type_and_params(type=trigger_4.type, - parameters=None) + trigger_db = trigger_service.get_trigger_db_given_type_and_params( + type=trigger_4.type, parameters=None + ) self.assertEqual(trigger_db, None) def test_add_trigger_type_no_params(self): # Trigger type with no params should create a trigger with same name as trigger type. trig_type = { - 'name': 'myawesometriggertype', - 'pack': 'dummy_pack_1', - 'description': 'Words cannot describe how awesome I am.', - 'parameters_schema': {}, - 'payload_schema': {} + "name": "myawesometriggertype", + "pack": "dummy_pack_1", + "description": "Words cannot describe how awesome I am.", + "parameters_schema": {}, + "payload_schema": {}, } trigtype_dbs = trigger_service.add_trigger_models(trigger_types=[trig_type]) trigger_type, trigger = trigtype_dbs[0] trigtype_db = TriggerType.get_by_id(trigger_type.id) - self.assertEqual(trigtype_db.pack, 'dummy_pack_1') - self.assertEqual(trigtype_db.name, trig_type.get('name')) + self.assertEqual(trigtype_db.pack, "dummy_pack_1") + self.assertEqual(trigtype_db.name, trig_type.get("name")) self.assertIsNotNone(trigger) self.assertEqual(trigger.name, trigtype_db.name) @@ -210,35 +242,34 @@ def test_add_trigger_type_no_params(self): self.assertTrue(len(triggers) == 1) def test_add_trigger_type_with_params(self): - MOCK_TRIGGER.type = 'system.test' + MOCK_TRIGGER.type = "system.test" # Trigger type with params should not create a trigger. PARAMETERS_SCHEMA = { "type": "object", - "properties": { - "url": {"type": "string"} - }, - "required": ['url'], - "additionalProperties": False + "properties": {"url": {"type": "string"}}, + "required": ["url"], + "additionalProperties": False, } trig_type = { - 'name': 'myawesometriggertype2', - 'pack': 'my_pack_1', - 'description': 'Words cannot describe how awesome I am.', - 'parameters_schema': PARAMETERS_SCHEMA, - 'payload_schema': {} + "name": "myawesometriggertype2", + "pack": "my_pack_1", + "description": "Words cannot describe how awesome I am.", + "parameters_schema": PARAMETERS_SCHEMA, + "payload_schema": {}, } trigtype_dbs = trigger_service.add_trigger_models(trigger_types=[trig_type]) trigger_type, trigger = trigtype_dbs[0] trigtype_db = TriggerType.get_by_id(trigger_type.id) - self.assertEqual(trigtype_db.pack, 'my_pack_1') - self.assertEqual(trigtype_db.name, trig_type.get('name')) + self.assertEqual(trigtype_db.pack, "my_pack_1") + self.assertEqual(trigtype_db.name, trig_type.get("name")) self.assertEqual(trigger, None) def test_add_trigger_type(self): """ This sensor has misconfigured trigger type. We shouldn't explode. """ + class FailTestSensor(object): started = False @@ -252,12 +283,12 @@ def stop(self): pass def get_trigger_types(self): - return [ - {'description': 'Ain\'t got no name'} - ] + return [{"description": "Ain't got no name"}] try: trigger_service.add_trigger_models(FailTestSensor().get_trigger_types()) - self.assertTrue(False, 'Trigger type doesn\'t have \'name\' field. Should have thrown.') + self.assertTrue( + False, "Trigger type doesn't have 'name' field. Should have thrown." + ) except Exception: self.assertTrue(True) diff --git a/st2common/tests/unit/test_triggers_registrar.py b/st2common/tests/unit/test_triggers_registrar.py index 53595d28673..5ceda4f851f 100644 --- a/st2common/tests/unit/test_triggers_registrar.py +++ b/st2common/tests/unit/test_triggers_registrar.py @@ -22,9 +22,7 @@ from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import get_fixtures_packs_base_path -__all__ = [ - 'TriggersRegistrarTestCase' -] +__all__ = ["TriggersRegistrarTestCase"] class TriggersRegistrarTestCase(CleanDbTestCase): @@ -44,7 +42,7 @@ def test_register_all_triggers(self): def test_register_triggers_from_pack(self): base_path = get_fixtures_packs_base_path() - pack_dir = os.path.join(base_path, 'dummy_pack_1') + pack_dir = os.path.join(base_path, "dummy_pack_1") trigger_type_dbs = TriggerType.get_all() self.assertEqual(len(trigger_type_dbs), 0) @@ -58,12 +56,12 @@ def test_register_triggers_from_pack(self): self.assertEqual(len(trigger_type_dbs), 2) self.assertEqual(len(trigger_dbs), 2) - self.assertEqual(trigger_type_dbs[0].name, 'event_handler') - self.assertEqual(trigger_type_dbs[0].pack, 'dummy_pack_1') - self.assertEqual(trigger_dbs[0].name, 'event_handler') - self.assertEqual(trigger_dbs[0].pack, 'dummy_pack_1') - self.assertEqual(trigger_dbs[0].type, 'dummy_pack_1.event_handler') + self.assertEqual(trigger_type_dbs[0].name, "event_handler") + self.assertEqual(trigger_type_dbs[0].pack, "dummy_pack_1") + self.assertEqual(trigger_dbs[0].name, "event_handler") + self.assertEqual(trigger_dbs[0].pack, "dummy_pack_1") + self.assertEqual(trigger_dbs[0].type, "dummy_pack_1.event_handler") - self.assertEqual(trigger_type_dbs[1].name, 'head_sha_monitor') - self.assertEqual(trigger_type_dbs[1].pack, 'dummy_pack_1') - self.assertEqual(trigger_type_dbs[1].payload_schema['type'], 'object') + self.assertEqual(trigger_type_dbs[1].name, "head_sha_monitor") + self.assertEqual(trigger_type_dbs[1].pack, "dummy_pack_1") + self.assertEqual(trigger_type_dbs[1].payload_schema["type"], "object") diff --git a/st2common/tests/unit/test_unit_testing_mocks.py b/st2common/tests/unit/test_unit_testing_mocks.py index ce63dd78342..742ca85da12 100644 --- a/st2common/tests/unit/test_unit_testing_mocks.py +++ b/st2common/tests/unit/test_unit_testing_mocks.py @@ -23,9 +23,9 @@ from st2tests.mocks.action import MockActionService __all__ = [ - 'BaseSensorTestCaseTestCase', - 'MockSensorServiceTestCase', - 'MockActionServiceTestCase' + "BaseSensorTestCaseTestCase", + "MockSensorServiceTestCase", + "MockActionServiceTestCase", ] @@ -37,36 +37,38 @@ class BaseMockResourceServiceTestCase(object): class TestCase(unittest2.TestCase): def test_get_user_info(self): result = self.mock_service.get_user_info() - self.assertEqual(result['username'], 'admin') - self.assertEqual(result['rbac']['roles'], ['admin']) + self.assertEqual(result["username"], "admin") + self.assertEqual(result["rbac"]["roles"], ["admin"]) def test_list_set_get_delete_values(self): # list_values, set_value result = self.mock_service.list_values() self.assertSequenceEqual(result, []) - self.mock_service.set_value(name='t1.local', value='test1', local=True) - self.mock_service.set_value(name='t1.global', value='test1', local=False) + self.mock_service.set_value(name="t1.local", value="test1", local=True) + self.mock_service.set_value(name="t1.global", value="test1", local=False) result = self.mock_service.list_values(local=True) self.assertEqual(len(result), 1) - self.assertEqual(result[0].name, 'dummy.test:t1.local') + self.assertEqual(result[0].name, "dummy.test:t1.local") result = self.mock_service.list_values(local=False) - self.assertEqual(result[0].name, 'dummy.test:t1.local') - self.assertEqual(result[1].name, 't1.global') + self.assertEqual(result[0].name, "dummy.test:t1.local") + self.assertEqual(result[1].name, "t1.global") self.assertEqual(len(result), 2) # get_value - self.assertEqual(self.mock_service.get_value('inexistent'), None) - self.assertEqual(self.mock_service.get_value(name='t1.local', local=True), 'test1') + self.assertEqual(self.mock_service.get_value("inexistent"), None) + self.assertEqual( + self.mock_service.get_value(name="t1.local", local=True), "test1" + ) # delete_value self.assertEqual(len(self.mock_service.list_values(local=True)), 1) - self.assertEqual(self.mock_service.delete_value('inexistent'), False) + self.assertEqual(self.mock_service.delete_value("inexistent"), False) self.assertEqual(len(self.mock_service.list_values(local=True)), 1) - self.assertEqual(self.mock_service.delete_value('t1.local'), True) + self.assertEqual(self.mock_service.delete_value("t1.local"), True) self.assertEqual(len(self.mock_service.list_values(local=True)), 0) @@ -77,47 +79,50 @@ def test_dispatch_and_assertTriggerDispatched(self): sensor_service = self.sensor_service expected_msg = 'Trigger "nope" hasn\'t been dispatched' - self.assertRaisesRegexp(AssertionError, expected_msg, - self.assertTriggerDispatched, trigger='nope') + self.assertRaisesRegexp( + AssertionError, expected_msg, self.assertTriggerDispatched, trigger="nope" + ) - sensor_service.dispatch(trigger='test1', payload={'a': 'b'}) - result = self.assertTriggerDispatched(trigger='test1') + sensor_service.dispatch(trigger="test1", payload={"a": "b"}) + result = self.assertTriggerDispatched(trigger="test1") self.assertTrue(result) - result = self.assertTriggerDispatched(trigger='test1', payload={'a': 'b'}) + result = self.assertTriggerDispatched(trigger="test1", payload={"a": "b"}) self.assertTrue(result) expected_msg = 'Trigger "test1" hasn\'t been dispatched' - self.assertRaisesRegexp(AssertionError, expected_msg, - self.assertTriggerDispatched, - trigger='test1', - payload={'a': 'c'}) + self.assertRaisesRegexp( + AssertionError, + expected_msg, + self.assertTriggerDispatched, + trigger="test1", + payload={"a": "c"}, + ) class MockSensorServiceTestCase(BaseMockResourceServiceTestCase.TestCase): - def setUp(self): - mock_sensor_wrapper = MockSensorWrapper(pack='dummy', class_name='test') + mock_sensor_wrapper = MockSensorWrapper(pack="dummy", class_name="test") self.mock_service = MockSensorService(sensor_wrapper=mock_sensor_wrapper) def test_get_logger(self): sensor_service = self.mock_service - logger = sensor_service.get_logger('test') - logger.info('test info') - logger.debug('test debug') + logger = sensor_service.get_logger("test") + logger.info("test info") + logger.debug("test debug") self.assertEqual(len(logger.method_calls), 2) method_name, method_args, method_kwargs = tuple(logger.method_calls[0]) - self.assertEqual(method_name, 'info') - self.assertEqual(method_args, ('test info',)) + self.assertEqual(method_name, "info") + self.assertEqual(method_args, ("test info",)) self.assertEqual(method_kwargs, {}) method_name, method_args, method_kwargs = tuple(logger.method_calls[1]) - self.assertEqual(method_name, 'debug') - self.assertEqual(method_args, ('test debug',)) + self.assertEqual(method_name, "debug") + self.assertEqual(method_args, ("test debug",)) self.assertEqual(method_kwargs, {}) class MockActionServiceTestCase(BaseMockResourceServiceTestCase.TestCase): def setUp(self): - mock_action_wrapper = MockActionWrapper(pack='dummy', class_name='test') + mock_action_wrapper = MockActionWrapper(pack="dummy", class_name="test") self.mock_service = MockActionService(action_wrapper=mock_action_wrapper) diff --git a/st2common/tests/unit/test_util_actionalias_helpstrings.py b/st2common/tests/unit/test_util_actionalias_helpstrings.py index a7726dd1777..e543bd471a5 100644 --- a/st2common/tests/unit/test_util_actionalias_helpstrings.py +++ b/st2common/tests/unit/test_util_actionalias_helpstrings.py @@ -25,62 +25,101 @@ ALIASES = [ - MemoryActionAliasDB(name="kyle_reese", ref="terminator.1", - pack="the80s", enabled=True, - formats=["Come with me if you want to live"] + MemoryActionAliasDB( + name="kyle_reese", + ref="terminator.1", + pack="the80s", + enabled=True, + formats=["Come with me if you want to live"], ), - MemoryActionAliasDB(name="terminator", ref="terminator.2", - pack="the80s", enabled=True, - formats=["I need your {{item}}, your {{item2}}" - " and your {{vehicle}}"] + MemoryActionAliasDB( + name="terminator", + ref="terminator.2", + pack="the80s", + enabled=True, + formats=["I need your {{item}}, your {{item2}}" " and your {{vehicle}}"], ), - MemoryActionAliasDB(name="johnny_five_alive", ref="short_circuit.3", - pack="the80s", enabled=True, - formats=[{'display': 'Number 5 is {{status}}', - 'representation': ['Number 5 is {{status=alive}}']}, - 'Hey, laser lips, your mama was a snow blower.'] + MemoryActionAliasDB( + name="johnny_five_alive", + ref="short_circuit.3", + pack="the80s", + enabled=True, + formats=[ + { + "display": "Number 5 is {{status}}", + "representation": ["Number 5 is {{status=alive}}"], + }, + "Hey, laser lips, your mama was a snow blower.", + ], ), - MemoryActionAliasDB(name="i_feel_alive", ref="short_circuit.4", - pack="the80s", enabled=True, - formats=["How do I feel? I feel... {{status}}!"] + MemoryActionAliasDB( + name="i_feel_alive", + ref="short_circuit.4", + pack="the80s", + enabled=True, + formats=["How do I feel? I feel... {{status}}!"], ), - MemoryActionAliasDB(name='andy', ref='the_goonies.1', - pack="the80s", enabled=True, - formats=[{'display': 'Watch this.'}] + MemoryActionAliasDB( + name="andy", + ref="the_goonies.1", + pack="the80s", + enabled=True, + formats=[{"display": "Watch this."}], ), - MemoryActionAliasDB(name='andy', ref='the_goonies.5', - pack="the80s", enabled=True, - formats=[{'display': "He's just like his {{relation}}."}] + MemoryActionAliasDB( + name="andy", + ref="the_goonies.5", + pack="the80s", + enabled=True, + formats=[{"display": "He's just like his {{relation}}."}], ), - MemoryActionAliasDB(name='data', ref='the_goonies.6', - pack="the80s", enabled=True, - formats=[{'representation': "That's okay daddy. You can't hug a {{object}}."}] + MemoryActionAliasDB( + name="data", + ref="the_goonies.6", + pack="the80s", + enabled=True, + formats=[{"representation": "That's okay daddy. You can't hug a {{object}}."}], ), - MemoryActionAliasDB(name='mr_wang', ref='the_goonies.7', - pack="the80s", enabled=True, - formats=[{'representation': 'You are my greatest invention.'}] + MemoryActionAliasDB( + name="mr_wang", + ref="the_goonies.7", + pack="the80s", + enabled=True, + formats=[{"representation": "You are my greatest invention."}], ), - MemoryActionAliasDB(name="Ferris", ref="ferris_buellers_day_off.8", - pack="the80s", enabled=True, - formats=["Life moves pretty fast.", - "If you don't stop and look around once in a while, you could miss it."] + MemoryActionAliasDB( + name="Ferris", + ref="ferris_buellers_day_off.8", + pack="the80s", + enabled=True, + formats=[ + "Life moves pretty fast.", + "If you don't stop and look around once in a while, you could miss it.", + ], ), - MemoryActionAliasDB(name="economics.teacher", ref="ferris_buellers_day_off.10", - pack="the80s", enabled=False, - formats=["Bueller?... Bueller?... Bueller? "] + MemoryActionAliasDB( + name="economics.teacher", + ref="ferris_buellers_day_off.10", + pack="the80s", + enabled=False, + formats=["Bueller?... Bueller?... Bueller? "], + ), + MemoryActionAliasDB( + name="spengler", + ref="ghostbusters.10", + pack="the80s", + enabled=True, + formats=["{{choice}} cross the {{target}}"], ), - MemoryActionAliasDB(name="spengler", ref="ghostbusters.10", - pack="the80s", enabled=True, - formats=["{{choice}} cross the {{target}}"] - ) ] -@mock.patch.object(MemoryActionAliasDB, 'get_uid') +@mock.patch.object(MemoryActionAliasDB, "get_uid") class ActionAliasTestCase(unittest2.TestCase): - ''' + """ Test scenarios must consist of 80s movie quotes. - ''' + """ + def check_data_structure(self, result): tmp = list(result.keys()) tmp.sort() @@ -93,7 +132,9 @@ def test_filtering_no_arg(self, mock): result = generate_helpstring_result(ALIASES) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -115,7 +156,9 @@ def test_filtering_match(self, mock): result = generate_helpstring_result(ALIASES, "you") self.check_data_structure(result) self.check_available_count(result, 4) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 4) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -123,12 +166,16 @@ def test_pack_empty_string(self, mock): result = generate_helpstring_result(ALIASES, "", "") self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") def test_pack_no_match(self, mock): - result = generate_helpstring_result(ALIASES, "", "you_will_not_find_this_string") + result = generate_helpstring_result( + ALIASES, "", "you_will_not_find_this_string" + ) self.check_data_structure(result) self.check_available_count(result, 0) self.assertEqual(result.get("helpstrings"), []) @@ -137,7 +184,9 @@ def test_pack_match(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s") self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -153,7 +202,9 @@ def test_limit_neg_out_of_bounds(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s", -3) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -161,7 +212,9 @@ def test_limit_pos_out_of_bounds(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s", 30) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -169,7 +222,9 @@ def test_limit_in_bounds(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s", 3) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 3) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -185,7 +240,9 @@ def test_offset_negative_out_of_bounds(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s", 0, -1) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 10) self.assertEqual(the80s[0].get("display"), "Come with me if you want to live") @@ -199,6 +256,8 @@ def test_offset_in_bounds(self, mock): result = generate_helpstring_result(ALIASES, "", "the80s", 0, 6) self.check_data_structure(result) self.check_available_count(result, 10) - the80s = [line for line in result.get("helpstrings") if line['pack'] == "the80s"] + the80s = [ + line for line in result.get("helpstrings") if line["pack"] == "the80s" + ] self.assertEqual(len(the80s), 4) self.assertEqual(the80s[0].get("display"), "He's just like his {{relation}}.") diff --git a/st2common/tests/unit/test_util_actionalias_matching.py b/st2common/tests/unit/test_util_actionalias_matching.py index c22ccab3e6a..082fa40b98a 100644 --- a/st2common/tests/unit/test_util_actionalias_matching.py +++ b/st2common/tests/unit/test_util_actionalias_matching.py @@ -24,89 +24,130 @@ MemoryActionAliasDB = ActionAliasDB -@mock.patch.object(MemoryActionAliasDB, 'get_uid') +@mock.patch.object(MemoryActionAliasDB, "get_uid") class ActionAliasTestCase(unittest2.TestCase): - ''' + """ Test scenarios must consist of 80s movie quotes. - ''' + """ + def test_list_format_strings_from_aliases(self, mock): ALIASES = [ - MemoryActionAliasDB(name="kyle_reese", ref="terminator.1", - formats=["Come with me if you want to live"]), - MemoryActionAliasDB(name="terminator", ref="terminator.2", - formats=["I need your {{item}}, your {{item2}}" - " and your {{vehicle}}"]) + MemoryActionAliasDB( + name="kyle_reese", + ref="terminator.1", + formats=["Come with me if you want to live"], + ), + MemoryActionAliasDB( + name="terminator", + ref="terminator.2", + formats=[ + "I need your {{item}}, your {{item2}}" " and your {{vehicle}}" + ], + ), ] result = matching.list_format_strings_from_aliases(ALIASES) self.assertEqual(len(result), 2) - self.assertEqual(result[0]['display'], "Come with me if you want to live") - self.assertEqual(result[1]['display'], - "I need your {{item}}, your {{item2}} and" - " your {{vehicle}}") + self.assertEqual(result[0]["display"], "Come with me if you want to live") + self.assertEqual( + result[1]["display"], + "I need your {{item}}, your {{item2}} and" " your {{vehicle}}", + ) def test_list_format_strings_from_aliases_with_display(self, mock): ALIASES = [ - MemoryActionAliasDB(name="johnny_five_alive", ref="short_circuit.1", formats=[ - {'display': 'Number 5 is {{status}}', - 'representation': ['Number 5 is {{status=alive}}']}, - 'Hey, laser lips, your mama was a snow blower.']), - MemoryActionAliasDB(name="i_feel_alive", ref="short_circuit.2", - formats=["How do I feel? I feel... {{status}}!"]) + MemoryActionAliasDB( + name="johnny_five_alive", + ref="short_circuit.1", + formats=[ + { + "display": "Number 5 is {{status}}", + "representation": ["Number 5 is {{status=alive}}"], + }, + "Hey, laser lips, your mama was a snow blower.", + ], + ), + MemoryActionAliasDB( + name="i_feel_alive", + ref="short_circuit.2", + formats=["How do I feel? I feel... {{status}}!"], + ), ] result = matching.list_format_strings_from_aliases(ALIASES) self.assertEqual(len(result), 3) - self.assertEqual(result[0]['display'], "Number 5 is {{status}}") - self.assertEqual(result[0]['representation'], "Number 5 is {{status=alive}}") - self.assertEqual(result[1]['display'], "Hey, laser lips, your mama was a snow blower.") - self.assertEqual(result[1]['representation'], - "Hey, laser lips, your mama was a snow blower.") - self.assertEqual(result[2]['display'], "How do I feel? I feel... {{status}}!") - self.assertEqual(result[2]['representation'], "How do I feel? I feel... {{status}}!") + self.assertEqual(result[0]["display"], "Number 5 is {{status}}") + self.assertEqual(result[0]["representation"], "Number 5 is {{status=alive}}") + self.assertEqual( + result[1]["display"], "Hey, laser lips, your mama was a snow blower." + ) + self.assertEqual( + result[1]["representation"], "Hey, laser lips, your mama was a snow blower." + ) + self.assertEqual(result[2]["display"], "How do I feel? I feel... {{status}}!") + self.assertEqual( + result[2]["representation"], "How do I feel? I feel... {{status}}!" + ) def test_list_format_strings_from_aliases_with_display_only(self, mock): ALIASES = [ - MemoryActionAliasDB(name='andy', - ref='the_goonies.1', formats=[{'display': 'Watch this.'}]), - MemoryActionAliasDB(name='andy', ref='the_goonies.2', - formats=[{'display': "He's just like his {{relation}}."}]) + MemoryActionAliasDB( + name="andy", ref="the_goonies.1", formats=[{"display": "Watch this."}] + ), + MemoryActionAliasDB( + name="andy", + ref="the_goonies.2", + formats=[{"display": "He's just like his {{relation}}."}], + ), ] result = matching.list_format_strings_from_aliases(ALIASES) self.assertEqual(len(result), 2) - self.assertEqual(result[0]['display'], 'Watch this.') - self.assertEqual(result[0]['representation'], '') - self.assertEqual(result[1]['display'], "He's just like his {{relation}}.") - self.assertEqual(result[1]['representation'], '') + self.assertEqual(result[0]["display"], "Watch this.") + self.assertEqual(result[0]["representation"], "") + self.assertEqual(result[1]["display"], "He's just like his {{relation}}.") + self.assertEqual(result[1]["representation"], "") def test_list_format_strings_from_aliases_with_representation_only(self, mock): ALIASES = [ - MemoryActionAliasDB(name='data', ref='the_goonies.1', formats=[ - {'representation': "That's okay daddy. You can't hug a {{object}}."}]), - MemoryActionAliasDB(name='mr_wang', ref='the_goonies.2', formats=[ - {'representation': 'You are my greatest invention.'}]) + MemoryActionAliasDB( + name="data", + ref="the_goonies.1", + formats=[ + {"representation": "That's okay daddy. You can't hug a {{object}}."} + ], + ), + MemoryActionAliasDB( + name="mr_wang", + ref="the_goonies.2", + formats=[{"representation": "You are my greatest invention."}], + ), ] result = matching.list_format_strings_from_aliases(ALIASES) self.assertEqual(len(result), 2) - self.assertEqual(result[0]['display'], None) - self.assertEqual(result[0]['representation'], - "That's okay daddy. You can't hug a {{object}}.") - self.assertEqual(result[1]['display'], None) - self.assertEqual(result[1]['representation'], 'You are my greatest invention.') + self.assertEqual(result[0]["display"], None) + self.assertEqual( + result[0]["representation"], + "That's okay daddy. You can't hug a {{object}}.", + ) + self.assertEqual(result[1]["display"], None) + self.assertEqual(result[1]["representation"], "You are my greatest invention.") def test_normalise_alias_format_string(self, mock): result = matching.normalise_alias_format_string( - 'Quite an experience to live in fear, isn\'t it?') + "Quite an experience to live in fear, isn't it?" + ) self.assertEqual([result[0]], result[1]) self.assertEqual(result[0], "Quite an experience to live in fear, isn't it?") def test_normalise_alias_format_string_error(self, mock): alias_list = ["Quite an experience to live in fear, isn't it?"] - expected_msg = ("alias_format '%s' is neither a dictionary or string type." - % repr(alias_list)) + expected_msg = ( + "alias_format '%s' is neither a dictionary or string type." + % repr(alias_list) + ) with self.assertRaises(TypeError) as cm: matching.normalise_alias_format_string(alias_list) @@ -115,13 +156,16 @@ def test_normalise_alias_format_string_error(self, mock): def test_matching(self, mock): ALIASES = [ - MemoryActionAliasDB(name="spengler", ref="ghostbusters.1", - formats=["{{choice}} cross the {{target}}"]), + MemoryActionAliasDB( + name="spengler", + ref="ghostbusters.1", + formats=["{{choice}} cross the {{target}}"], + ), ] COMMAND = "Don't cross the streams" match = matching.match_command_to_alias(COMMAND, ALIASES) self.assertEqual(len(match), 1) - self.assertEqual(match[0]['alias'].ref, "ghostbusters.1") - self.assertEqual(match[0]['representation'], "{{choice}} cross the {{target}}") + self.assertEqual(match[0]["alias"].ref, "ghostbusters.1") + self.assertEqual(match[0]["representation"], "{{choice}} cross the {{target}}") # we need some more complex scenarios in here. diff --git a/st2common/tests/unit/test_util_api.py b/st2common/tests/unit/test_util_api.py index bc0e385df18..2333939b13f 100644 --- a/st2common/tests/unit/test_util_api.py +++ b/st2common/tests/unit/test_util_api.py @@ -23,24 +23,25 @@ from st2common.util.api import get_full_public_api_url from st2tests.config import parse_args from six.moves import zip + parse_args() class APIUtilsTestCase(unittest2.TestCase): def test_get_base_public_api_url(self): values = [ - 'http://foo.bar.com', - 'http://foo.bar.com/', - 'http://foo.bar.com:8080', - 'http://foo.bar.com:8080/', - 'http://localhost:8080/', + "http://foo.bar.com", + "http://foo.bar.com/", + "http://foo.bar.com:8080", + "http://foo.bar.com:8080/", + "http://localhost:8080/", ] expected = [ - 'http://foo.bar.com', - 'http://foo.bar.com', - 'http://foo.bar.com:8080', - 'http://foo.bar.com:8080', - 'http://localhost:8080', + "http://foo.bar.com", + "http://foo.bar.com", + "http://foo.bar.com:8080", + "http://foo.bar.com:8080", + "http://localhost:8080", ] for mock_value, expected_result in zip(values, expected): @@ -50,18 +51,18 @@ def test_get_base_public_api_url(self): def test_get_full_public_api_url(self): values = [ - 'http://foo.bar.com', - 'http://foo.bar.com/', - 'http://foo.bar.com:8080', - 'http://foo.bar.com:8080/', - 'http://localhost:8080/', + "http://foo.bar.com", + "http://foo.bar.com/", + "http://foo.bar.com:8080", + "http://foo.bar.com:8080/", + "http://localhost:8080/", ] expected = [ - 'http://foo.bar.com/' + DEFAULT_API_VERSION, - 'http://foo.bar.com/' + DEFAULT_API_VERSION, - 'http://foo.bar.com:8080/' + DEFAULT_API_VERSION, - 'http://foo.bar.com:8080/' + DEFAULT_API_VERSION, - 'http://localhost:8080/' + DEFAULT_API_VERSION, + "http://foo.bar.com/" + DEFAULT_API_VERSION, + "http://foo.bar.com/" + DEFAULT_API_VERSION, + "http://foo.bar.com:8080/" + DEFAULT_API_VERSION, + "http://foo.bar.com:8080/" + DEFAULT_API_VERSION, + "http://localhost:8080/" + DEFAULT_API_VERSION, ] for mock_value, expected_result in zip(values, expected): diff --git a/st2common/tests/unit/test_util_compat.py b/st2common/tests/unit/test_util_compat.py index 0e1ac9efe7c..74face7ea64 100644 --- a/st2common/tests/unit/test_util_compat.py +++ b/st2common/tests/unit/test_util_compat.py @@ -19,18 +19,16 @@ from st2common.util.compat import to_ascii -__all__ = [ - 'CompatUtilsTestCase' -] +__all__ = ["CompatUtilsTestCase"] class CompatUtilsTestCase(unittest2.TestCase): def test_to_ascii(self): expected_values = [ - ('already ascii', 'already ascii'), - (u'foo', 'foo'), - ('٩(̾●̮̮̃̾•̃̾)۶', '()'), - ('\xd9\xa9', '') + ("already ascii", "already ascii"), + ("foo", "foo"), + ("٩(̾●̮̮̃̾•̃̾)۶", "()"), + ("\xd9\xa9", ""), ] for input_value, expected_value in expected_values: diff --git a/st2common/tests/unit/test_util_db.py b/st2common/tests/unit/test_util_db.py index dd230e6ae11..f94a2fe39a0 100644 --- a/st2common/tests/unit/test_util_db.py +++ b/st2common/tests/unit/test_util_db.py @@ -22,88 +22,73 @@ class DatabaseUtilTestCase(unittest2.TestCase): - def test_noop_mongodb_to_python_types(self): - data = [ - 123, - 999.99, - True, - [10, 20, 30], - {'a': 1, 'b': 2}, - None - ] + data = [123, 999.99, True, [10, 20, 30], {"a": 1, "b": 2}, None] for item in data: self.assertEqual(db_util.mongodb_to_python_types(item), item) def test_mongodb_basedict_to_dict(self): - data = {'a': 1, 'b': 2} + data = {"a": 1, "b": 2} - obj = mongoengine.base.datastructures.BaseDict(data, None, 'foobar') + obj = mongoengine.base.datastructures.BaseDict(data, None, "foobar") self.assertDictEqual(db_util.mongodb_to_python_types(obj), data) def test_mongodb_baselist_to_list(self): data = [2, 4, 6] - obj = mongoengine.base.datastructures.BaseList(data, None, 'foobar') + obj = mongoengine.base.datastructures.BaseList(data, None, "foobar") self.assertListEqual(db_util.mongodb_to_python_types(obj), data) def test_nested_mongdb_to_python_types(self): data = { - 'a': mongoengine.base.datastructures.BaseList([1, 2, 3], None, 'a'), - 'b': mongoengine.base.datastructures.BaseDict({'a': 1, 'b': 2}, None, 'b'), - 'c': { - 'd': mongoengine.base.datastructures.BaseList([4, 5, 6], None, 'd'), - 'e': mongoengine.base.datastructures.BaseDict({'c': 3, 'd': 4}, None, 'e') + "a": mongoengine.base.datastructures.BaseList([1, 2, 3], None, "a"), + "b": mongoengine.base.datastructures.BaseDict({"a": 1, "b": 2}, None, "b"), + "c": { + "d": mongoengine.base.datastructures.BaseList([4, 5, 6], None, "d"), + "e": mongoengine.base.datastructures.BaseDict( + {"c": 3, "d": 4}, None, "e" + ), }, - 'f': mongoengine.base.datastructures.BaseList( + "f": mongoengine.base.datastructures.BaseList( [ - mongoengine.base.datastructures.BaseDict({'e': 5}, None, 'f1'), - mongoengine.base.datastructures.BaseDict({'f': 6}, None, 'f2') + mongoengine.base.datastructures.BaseDict({"e": 5}, None, "f1"), + mongoengine.base.datastructures.BaseDict({"f": 6}, None, "f2"), ], None, - 'f' + "f", ), - 'g': mongoengine.base.datastructures.BaseDict( + "g": mongoengine.base.datastructures.BaseDict( { - 'h': mongoengine.base.datastructures.BaseList( + "h": mongoengine.base.datastructures.BaseList( [ - mongoengine.base.datastructures.BaseDict({'g': 7}, None, 'h1'), - mongoengine.base.datastructures.BaseDict({'h': 8}, None, 'h2') + mongoengine.base.datastructures.BaseDict( + {"g": 7}, None, "h1" + ), + mongoengine.base.datastructures.BaseDict( + {"h": 8}, None, "h2" + ), ], None, - 'h' + "h", + ), + "i": mongoengine.base.datastructures.BaseDict( + {"j": 9, "k": 10}, None, "i" ), - 'i': mongoengine.base.datastructures.BaseDict({'j': 9, 'k': 10}, None, 'i') }, None, - 'g' + "g", ), } expected = { - 'a': [1, 2, 3], - 'b': {'a': 1, 'b': 2}, - 'c': { - 'd': [4, 5, 6], - 'e': {'c': 3, 'd': 4} - }, - 'f': [ - {'e': 5}, - {'f': 6} - ], - 'g': { - 'h': [ - {'g': 7}, - {'h': 8} - ], - 'i': { - 'j': 9, - 'k': 10 - } - } + "a": [1, 2, 3], + "b": {"a": 1, "b": 2}, + "c": {"d": [4, 5, 6], "e": {"c": 3, "d": 4}}, + "f": [{"e": 5}, {"f": 6}], + "g": {"h": [{"g": 7}, {"h": 8}], "i": {"j": 9, "k": 10}}, } self.assertDictEqual(db_util.mongodb_to_python_types(data), expected) diff --git a/st2common/tests/unit/test_util_file_system.py b/st2common/tests/unit/test_util_file_system.py index ea46a0b943d..a1af0c957ac 100644 --- a/st2common/tests/unit/test_util_file_system.py +++ b/st2common/tests/unit/test_util_file_system.py @@ -22,30 +22,32 @@ from st2common.util.file_system import get_file_list CURRENT_DIR = os.path.dirname(__file__) -ST2TESTS_DIR = os.path.join(CURRENT_DIR, '../../../st2tests/st2tests') +ST2TESTS_DIR = os.path.join(CURRENT_DIR, "../../../st2tests/st2tests") class FileSystemUtilsTestCase(unittest2.TestCase): def test_get_file_list(self): # Standard exclude pattern - directory = os.path.join(ST2TESTS_DIR, 'policies') + directory = os.path.join(ST2TESTS_DIR, "policies") expected = [ - 'mock_exception.py', - 'concurrency.py', - '__init__.py', - 'meta/mock_exception.yaml', - 'meta/concurrency.yaml', - 'meta/__init__.py' + "mock_exception.py", + "concurrency.py", + "__init__.py", + "meta/mock_exception.yaml", + "meta/concurrency.yaml", + "meta/__init__.py", ] - result = get_file_list(directory=directory, exclude_patterns=['*.pyc']) + result = get_file_list(directory=directory, exclude_patterns=["*.pyc"]) self.assertItemsEqual(expected, result) # Custom exclude pattern expected = [ - 'mock_exception.py', - 'concurrency.py', - '__init__.py', - 'meta/__init__.py' + "mock_exception.py", + "concurrency.py", + "__init__.py", + "meta/__init__.py", ] - result = get_file_list(directory=directory, exclude_patterns=['*.pyc', '*.yaml']) + result = get_file_list( + directory=directory, exclude_patterns=["*.pyc", "*.yaml"] + ) self.assertItemsEqual(expected, result) diff --git a/st2common/tests/unit/test_util_http.py b/st2common/tests/unit/test_util_http.py index 2bfbc22f049..a97aa8c7f1f 100644 --- a/st2common/tests/unit/test_util_http.py +++ b/st2common/tests/unit/test_util_http.py @@ -19,24 +19,22 @@ from st2common.util.http import parse_content_type_header from six.moves import zip -__all__ = [ - 'HTTPUtilTestCase' -] +__all__ = ["HTTPUtilTestCase"] class HTTPUtilTestCase(unittest2.TestCase): def test_parse_content_type_header(self): values = [ - 'application/json', - 'foo/bar', - 'application/json; charset=utf-8', - 'application/json; charset=utf-8; foo=bar', + "application/json", + "foo/bar", + "application/json; charset=utf-8", + "application/json; charset=utf-8; foo=bar", ] expected_results = [ - ('application/json', {}), - ('foo/bar', {}), - ('application/json', {'charset': 'utf-8'}), - ('application/json', {'charset': 'utf-8', 'foo': 'bar'}) + ("application/json", {}), + ("foo/bar", {}), + ("application/json", {"charset": "utf-8"}), + ("application/json", {"charset": "utf-8", "foo": "bar"}), ] for value, expected_result in zip(values, expected_results): diff --git a/st2common/tests/unit/test_util_jinja.py b/st2common/tests/unit/test_util_jinja.py index 1b56adc0e94..127570f54b2 100644 --- a/st2common/tests/unit/test_util_jinja.py +++ b/st2common/tests/unit/test_util_jinja.py @@ -21,97 +21,95 @@ class JinjaUtilsRenderTestCase(unittest2.TestCase): - def test_render_values(self): actual = jinja_utils.render_values( - mapping={'k1': '{{a}}', 'k2': '{{b}}'}, - context={'a': 'v1', 'b': 'v2'}) - expected = {'k2': 'v2', 'k1': 'v1'} + mapping={"k1": "{{a}}", "k2": "{{b}}"}, context={"a": "v1", "b": "v2"} + ) + expected = {"k2": "v2", "k1": "v1"} self.assertEqual(actual, expected) def test_render_values_skip_missing(self): actual = jinja_utils.render_values( - mapping={'k1': '{{a}}', 'k2': '{{b}}', 'k3': '{{c}}'}, - context={'a': 'v1', 'b': 'v2'}, - allow_undefined=True) - expected = {'k2': 'v2', 'k1': 'v1', 'k3': ''} + mapping={"k1": "{{a}}", "k2": "{{b}}", "k3": "{{c}}"}, + context={"a": "v1", "b": "v2"}, + allow_undefined=True, + ) + expected = {"k2": "v2", "k1": "v1", "k3": ""} self.assertEqual(actual, expected) def test_render_values_ascii_and_unicode_values(self): - mapping = { - u'k_ascii': '{{a}}', - u'k_unicode': '{{b}}', - u'k_ascii_unicode': '{{c}}'} + mapping = {"k_ascii": "{{a}}", "k_unicode": "{{b}}", "k_ascii_unicode": "{{c}}"} context = { - 'a': u'some ascii value', - 'b': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž', - 'c': u'some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ' + "a": "some ascii value", + "b": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž", + "c": "some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ", } expected = { - 'k_ascii': u'some ascii value', - 'k_unicode': u'٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž', - 'k_ascii_unicode': u'some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ' + "k_ascii": "some ascii value", + "k_unicode": "٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ćšž", + "k_ascii_unicode": "some ascii some ٩(̾●̮̮̃̾•̃̾)۶ ٩(̾●̮̮̃̾•̃̾)۶ ", } actual = jinja_utils.render_values( - mapping=mapping, - context=context, - allow_undefined=True) + mapping=mapping, context=context, allow_undefined=True + ) self.assertEqual(actual, expected) def test_convert_str_to_raw(self): - jinja_expr = '{{foobar}}' - expected_raw_block = '{% raw %}{{foobar}}{% endraw %}' - self.assertEqual(expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr)) + jinja_expr = "{{foobar}}" + expected_raw_block = "{% raw %}{{foobar}}{% endraw %}" + self.assertEqual( + expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr) + ) - jinja_block_expr = '{% for item in items %}foobar{% end for %}' - expected_raw_block = '{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}' + jinja_block_expr = "{% for item in items %}foobar{% end for %}" + expected_raw_block = ( + "{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}" + ) self.assertEqual( - expected_raw_block, - jinja_utils.convert_jinja_to_raw_block(jinja_block_expr) + expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_block_expr) ) def test_convert_list_to_raw(self): jinja_expr = [ - 'foobar', - '{{foo}}', - '{{bar}}', - '{% for item in items %}foobar{% end for %}', - {'foobar': '{{foobar}}'} + "foobar", + "{{foo}}", + "{{bar}}", + "{% for item in items %}foobar{% end for %}", + {"foobar": "{{foobar}}"}, ] expected_raw_block = [ - 'foobar', - '{% raw %}{{foo}}{% endraw %}', - '{% raw %}{{bar}}{% endraw %}', - '{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}', - {'foobar': '{% raw %}{{foobar}}{% endraw %}'} + "foobar", + "{% raw %}{{foo}}{% endraw %}", + "{% raw %}{{bar}}{% endraw %}", + "{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}", + {"foobar": "{% raw %}{{foobar}}{% endraw %}"}, ] - self.assertListEqual(expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr)) + self.assertListEqual( + expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr) + ) def test_convert_dict_to_raw(self): jinja_expr = { - 'var1': 'foobar', - 'var2': ['{{foo}}', '{{bar}}'], - 'var3': {'foobar': '{{foobar}}'}, - 'var4': {'foobar': '{% for item in items %}foobar{% end for %}'} + "var1": "foobar", + "var2": ["{{foo}}", "{{bar}}"], + "var3": {"foobar": "{{foobar}}"}, + "var4": {"foobar": "{% for item in items %}foobar{% end for %}"}, } expected_raw_block = { - 'var1': 'foobar', - 'var2': [ - '{% raw %}{{foo}}{% endraw %}', - '{% raw %}{{bar}}{% endraw %}' - ], - 'var3': { - 'foobar': '{% raw %}{{foobar}}{% endraw %}' + "var1": "foobar", + "var2": ["{% raw %}{{foo}}{% endraw %}", "{% raw %}{{bar}}{% endraw %}"], + "var3": {"foobar": "{% raw %}{{foobar}}{% endraw %}"}, + "var4": { + "foobar": "{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}" }, - 'var4': { - 'foobar': '{% raw %}{% for item in items %}foobar{% end for %}{% endraw %}' - } } - self.assertDictEqual(expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr)) + self.assertDictEqual( + expected_raw_block, jinja_utils.convert_jinja_to_raw_block(jinja_expr) + ) diff --git a/st2common/tests/unit/test_util_keyvalue.py b/st2common/tests/unit/test_util_keyvalue.py index 07f061e60a9..5a8c15a3f76 100644 --- a/st2common/tests/unit/test_util_keyvalue.py +++ b/st2common/tests/unit/test_util_keyvalue.py @@ -18,14 +18,19 @@ import unittest2 from st2common.util import keyvalue as kv_utl -from st2common.constants.keyvalue import (FULL_SYSTEM_SCOPE, FULL_USER_SCOPE, USER_SCOPE, - ALL_SCOPE, DATASTORE_PARENT_SCOPE, - DATASTORE_SCOPE_SEPARATOR) +from st2common.constants.keyvalue import ( + FULL_SYSTEM_SCOPE, + FULL_USER_SCOPE, + USER_SCOPE, + ALL_SCOPE, + DATASTORE_PARENT_SCOPE, + DATASTORE_SCOPE_SEPARATOR, +) from st2common.exceptions.rbac import AccessDeniedError from st2common.models.db import auth as auth_db -USER = 'stanley' +USER = "stanley" class TestKeyValueUtil(unittest2.TestCase): @@ -38,48 +43,26 @@ def test_validate_scope(self): kv_utl._validate_scope(scope) def test_validate_scope_with_invalid_scope(self): - scope = 'INVALID_SCOPE' + scope = "INVALID_SCOPE" self.assertRaises(ValueError, kv_utl._validate_scope, scope) def test_validate_decrypt_query_parameter(self): test_params = [ - [ - False, - USER_SCOPE, - False, - {} - ], - [ - True, - USER_SCOPE, - False, - {} - ], - [ - True, - FULL_SYSTEM_SCOPE, - True, - {} - ], + [False, USER_SCOPE, False, {}], + [True, USER_SCOPE, False, {}], + [True, FULL_SYSTEM_SCOPE, True, {}], ] for params in test_params: kv_utl._validate_decrypt_query_parameter(*params) def test_validate_decrypt_query_parameter_access_denied(self): - test_params = [ - [ - True, - FULL_SYSTEM_SCOPE, - False, - {} - ] - ] + test_params = [[True, FULL_SYSTEM_SCOPE, False, {}]] for params in test_params: assert_params = [ AccessDeniedError, - kv_utl._validate_decrypt_query_parameter + kv_utl._validate_decrypt_query_parameter, ] assert_params.extend(params) @@ -88,81 +71,58 @@ def test_validate_decrypt_query_parameter_access_denied(self): def test_get_datastore_full_scope(self): self.assertEqual( kv_utl.get_datastore_full_scope(USER_SCOPE), - DATASTORE_SCOPE_SEPARATOR.join([DATASTORE_PARENT_SCOPE, USER_SCOPE]) + DATASTORE_SCOPE_SEPARATOR.join([DATASTORE_PARENT_SCOPE, USER_SCOPE]), ) def test_get_datastore_full_scope_all_scope(self): - self.assertEqual( - kv_utl.get_datastore_full_scope(ALL_SCOPE), - ALL_SCOPE - ) + self.assertEqual(kv_utl.get_datastore_full_scope(ALL_SCOPE), ALL_SCOPE) def test_get_datastore_full_scope_datastore_parent_scope(self): self.assertEqual( kv_utl.get_datastore_full_scope(DATASTORE_PARENT_SCOPE), - DATASTORE_PARENT_SCOPE + DATASTORE_PARENT_SCOPE, ) def test_derive_scope_and_key(self): - key = 'test' + key = "test" scope = USER_SCOPE result = kv_utl._derive_scope_and_key(key, scope) - self.assertEqual( - (FULL_USER_SCOPE, 'user:%s' % key), - result - ) + self.assertEqual((FULL_USER_SCOPE, "user:%s" % key), result) def test_derive_scope_and_key_without_scope(self): - key = 'test' + key = "test" scope = None result = kv_utl._derive_scope_and_key(key, scope) - self.assertEqual( - (FULL_USER_SCOPE, 'None:%s' % key), - result - ) + self.assertEqual((FULL_USER_SCOPE, "None:%s" % key), result) def test_derive_scope_and_key_system_key(self): - key = 'system.test' + key = "system.test" scope = None result = kv_utl._derive_scope_and_key(key, scope) - self.assertEqual( - (FULL_SYSTEM_SCOPE, key.split('.')[1]), - result - ) + self.assertEqual((FULL_SYSTEM_SCOPE, key.split(".")[1]), result) - @mock.patch('st2common.util.keyvalue.KeyValuePair') - @mock.patch('st2common.util.keyvalue.deserialize_key_value') + @mock.patch("st2common.util.keyvalue.KeyValuePair") + @mock.patch("st2common.util.keyvalue.deserialize_key_value") def test_get_key(self, deseralize_key_value, KeyValuePair): - key, value = ('Lindsay', 'Lohan') + key, value = ("Lindsay", "Lohan") decrypt = False KeyValuePair.get_by_scope_and_name().value = value deseralize_key_value.return_value = value - result = kv_utl.get_key(key=key, user_db=auth_db.UserDB(name=USER), decrypt=decrypt) + result = kv_utl.get_key( + key=key, user_db=auth_db.UserDB(name=USER), decrypt=decrypt + ) self.assertEqual(result, value) KeyValuePair.get_by_scope_and_name.assert_called_with( - FULL_USER_SCOPE, - 'stanley:%s' % key - ) - deseralize_key_value.assert_called_once_with( - value, - decrypt + FULL_USER_SCOPE, "stanley:%s" % key ) + deseralize_key_value.assert_called_once_with(value, decrypt) def test_get_key_invalid_input(self): - self.assertRaises( - TypeError, - kv_utl.get_key, - key=1 - ) - self.assertRaises( - TypeError, - kv_utl.get_key, - key='test', - decrypt='yep' - ) + self.assertRaises(TypeError, kv_utl.get_key, key=1) + self.assertRaises(TypeError, kv_utl.get_key, key="test", decrypt="yep") diff --git a/st2common/tests/unit/test_util_output_schema.py b/st2common/tests/unit/test_util_output_schema.py index d3ef387a266..af9570d4fa9 100644 --- a/st2common/tests/unit/test_util_output_schema.py +++ b/st2common/tests/unit/test_util_output_schema.py @@ -19,58 +19,46 @@ from st2common.constants.action import ( LIVEACTION_STATUS_SUCCEEDED, - LIVEACTION_STATUS_FAILED + LIVEACTION_STATUS_FAILED, ) ACTION_RESULT = { - 'output': { - 'output_1': 'Bobby', - 'output_2': 5, - 'deep_output': { - 'deep_item_1': 'Jindal', + "output": { + "output_1": "Bobby", + "output_2": 5, + "deep_output": { + "deep_item_1": "Jindal", }, } } RUNNER_SCHEMA = { - 'output': { - 'type': 'object' - }, - 'error': { - 'type': 'array' - }, + "output": {"type": "object"}, + "error": {"type": "array"}, } ACTION_SCHEMA = { - 'output_1': { - 'type': 'string' - }, - 'output_2': { - 'type': 'integer' - }, - 'deep_output': { - 'type': 'object', - 'parameters': { - 'deep_item_1': { - 'type': 'string', + "output_1": {"type": "string"}, + "output_2": {"type": "integer"}, + "deep_output": { + "type": "object", + "parameters": { + "deep_item_1": { + "type": "string", }, }, }, } RUNNER_SCHEMA_FAIL = { - 'not_a_key_you_have': { - 'type': 'string' - }, + "not_a_key_you_have": {"type": "string"}, } ACTION_SCHEMA_FAIL = { - 'not_a_key_you_have': { - 'type': 'string' - }, + "not_a_key_you_have": {"type": "string"}, } -OUTPUT_KEY = 'output' +OUTPUT_KEY = "output" class OutputSchemaTestCase(unittest2.TestCase): @@ -96,7 +84,7 @@ def test_invalid_runner_schema(self): ) expected_result = { - 'error': ( + "error": ( "Additional properties are not allowed ('output' was unexpected)" "\n\nFailed validating 'additionalProperties' in schema:\n {'addi" "tionalProperties': False,\n 'properties': {'not_a_key_you_have': " @@ -104,7 +92,7 @@ def test_invalid_runner_schema(self): "output': {'deep_output': {'deep_item_1': 'Jindal'},\n " "'output_1': 'Bobby',\n 'output_2': 5}}" ), - 'message': 'Error validating output. See error output for more details.' + "message": "Error validating output. See error output for more details.", } self.assertEqual(result, expected_result) @@ -120,12 +108,12 @@ def test_invalid_action_schema(self): ) expected_result = { - 'error': "Additional properties are not allowed", - 'message': u'Error validating output. See error output for more details.' + "error": "Additional properties are not allowed", + "message": "Error validating output. See error output for more details.", } # To avoid random failures (especially in python3) this assert cant be # exact since the parameters can be ordered differently per execution. - self.assertIn(expected_result['error'], result['error']) - self.assertEqual(result['message'], expected_result['message']) + self.assertIn(expected_result["error"], result["error"]) + self.assertEqual(result["message"], expected_result["message"]) self.assertEqual(status, LIVEACTION_STATUS_FAILED) diff --git a/st2common/tests/unit/test_util_pack.py b/st2common/tests/unit/test_util_pack.py index 20522b8b188..0b476b73362 100644 --- a/st2common/tests/unit/test_util_pack.py +++ b/st2common/tests/unit/test_util_pack.py @@ -22,59 +22,47 @@ class PackUtilsTestCase(unittest2.TestCase): - def test_get_pack_common_libs_path_for_pack_db(self): pack_model_args = { - 'name': 'Yolo CI', - 'ref': 'yolo_ci', - 'description': 'YOLO CI pack', - 'version': '0.1.0', - 'author': 'Volkswagen', - 'path': '/opt/stackstorm/packs/yolo_ci/' + "name": "Yolo CI", + "ref": "yolo_ci", + "description": "YOLO CI pack", + "version": "0.1.0", + "author": "Volkswagen", + "path": "/opt/stackstorm/packs/yolo_ci/", } pack_db = PackDB(**pack_model_args) lib_path = get_pack_common_libs_path_for_pack_db(pack_db) - self.assertEqual('/opt/stackstorm/packs/yolo_ci/lib', lib_path) + self.assertEqual("/opt/stackstorm/packs/yolo_ci/lib", lib_path) def test_get_pack_common_libs_path_for_pack_db_no_path_in_pack_db(self): pack_model_args = { - 'name': 'Yolo CI', - 'ref': 'yolo_ci', - 'description': 'YOLO CI pack', - 'version': '0.1.0', - 'author': 'Volkswagen' + "name": "Yolo CI", + "ref": "yolo_ci", + "description": "YOLO CI pack", + "version": "0.1.0", + "author": "Volkswagen", } pack_db = PackDB(**pack_model_args) lib_path = get_pack_common_libs_path_for_pack_db(pack_db) self.assertEqual(None, lib_path) def test_get_pack_warnings_python2_only(self): - pack_metadata = { - 'python_versions': ['2'], - 'name': 'Pack2' - } + pack_metadata = {"python_versions": ["2"], "name": "Pack2"} warning = get_pack_warnings(pack_metadata) self.assertTrue("DEPRECATION WARNING" in warning) def test_get_pack_warnings_python3_only(self): - pack_metadata = { - 'python_versions': ['3'], - 'name': 'Pack3' - } + pack_metadata = {"python_versions": ["3"], "name": "Pack3"} warning = get_pack_warnings(pack_metadata) self.assertEqual(None, warning) def test_get_pack_warnings_python2_and_3(self): - pack_metadata = { - 'python_versions': ['2', '3'], - 'name': 'Pack23' - } + pack_metadata = {"python_versions": ["2", "3"], "name": "Pack23"} warning = get_pack_warnings(pack_metadata) self.assertEqual(None, warning) def test_get_pack_warnings_no_python(self): - pack_metadata = { - 'name': 'PackNone' - } + pack_metadata = {"name": "PackNone"} warning = get_pack_warnings(pack_metadata) self.assertEqual(None, warning) diff --git a/st2common/tests/unit/test_util_payload.py b/st2common/tests/unit/test_util_payload.py index 207d4c17661..2621e3de91e 100644 --- a/st2common/tests/unit/test_util_payload.py +++ b/st2common/tests/unit/test_util_payload.py @@ -19,27 +19,31 @@ from st2common.util.payload import PayloadLookup -__all__ = [ - 'PayloadLookupTestCase' -] +__all__ = ["PayloadLookupTestCase"] class PayloadLookupTestCase(unittest2.TestCase): @classmethod def setUpClass(cls): - cls.payload = PayloadLookup({ - 'pikachu': "Has no ears", - 'charmander': "Plays with fire", - }) + cls.payload = PayloadLookup( + { + "pikachu": "Has no ears", + "charmander": "Plays with fire", + } + ) super(PayloadLookupTestCase, cls).setUpClass() def test_get_key(self): - self.assertEqual(self.payload.get_value('trigger.pikachu'), ["Has no ears"]) - self.assertEqual(self.payload.get_value('trigger.charmander'), ["Plays with fire"]) + self.assertEqual(self.payload.get_value("trigger.pikachu"), ["Has no ears"]) + self.assertEqual( + self.payload.get_value("trigger.charmander"), ["Plays with fire"] + ) def test_explicitly_get_multiple_keys(self): - self.assertEqual(self.payload.get_value('trigger.pikachu[*]'), ["Has no ears"]) - self.assertEqual(self.payload.get_value('trigger.charmander[*]'), ["Plays with fire"]) + self.assertEqual(self.payload.get_value("trigger.pikachu[*]"), ["Has no ears"]) + self.assertEqual( + self.payload.get_value("trigger.charmander[*]"), ["Plays with fire"] + ) def test_get_nonexistent_key(self): - self.assertIsNone(self.payload.get_value('trigger.squirtle')) + self.assertIsNone(self.payload.get_value("trigger.squirtle")) diff --git a/st2common/tests/unit/test_util_sandboxing.py b/st2common/tests/unit/test_util_sandboxing.py index 5f387e0067f..3926c9f74c6 100644 --- a/st2common/tests/unit/test_util_sandboxing.py +++ b/st2common/tests/unit/test_util_sandboxing.py @@ -32,9 +32,7 @@ import st2tests.config as tests_config -__all__ = [ - 'SandboxingUtilsTestCase' -] +__all__ = ["SandboxingUtilsTestCase"] class SandboxingUtilsTestCase(unittest.TestCase): @@ -69,8 +67,10 @@ def assertEndsWith(self, string, ending_substr, msg=None): def test_get_sandbox_python_binary_path(self): # Non-system content pack, should use pack specific virtualenv binary - result = get_sandbox_python_binary_path(pack='mapack') - expected = os.path.join(cfg.CONF.system.base_path, 'virtualenvs/mapack/bin/python') + result = get_sandbox_python_binary_path(pack="mapack") + expected = os.path.join( + cfg.CONF.system.base_path, "virtualenvs/mapack/bin/python" + ) self.assertEqual(result, expected) # System content pack, should use current process (system) python binary @@ -78,159 +78,190 @@ def test_get_sandbox_python_binary_path(self): self.assertEqual(result, sys.executable) def test_get_sandbox_path(self): - virtualenv_path = '/home/venv/test' + virtualenv_path = "/home/venv/test" # Mock the current PATH value - with mock.patch.dict(os.environ, {'PATH': '/home/path1:/home/path2:/home/path3:'}): + with mock.patch.dict( + os.environ, {"PATH": "/home/path1:/home/path2:/home/path3:"} + ): result = get_sandbox_path(virtualenv_path=virtualenv_path) - self.assertEqual(result, f'{virtualenv_path}/bin/:/home/path1:/home/path2:/home/path3') + self.assertEqual( + result, f"{virtualenv_path}/bin/:/home/path1:/home/path2:/home/path3" + ) - @mock.patch('st2common.util.sandboxing.get_python_lib') + @mock.patch("st2common.util.sandboxing.get_python_lib") def test_get_sandbox_python_path(self, mock_get_python_lib): # No inheritance - python_path = get_sandbox_python_path(inherit_from_parent=False, - inherit_parent_virtualenv=False) - self.assertEqual(python_path, ':') + python_path = get_sandbox_python_path( + inherit_from_parent=False, inherit_parent_virtualenv=False + ) + self.assertEqual(python_path, ":") # Inherit python path from current process # Mock the current process python path - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=False) + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=False + ) - self.assertEqual(python_path, ':/data/test1:/data/test2') + self.assertEqual(python_path, ":/data/test1:/data/test2") # Inherit from current process and from virtualenv (not running inside virtualenv) clear_virtualenv_prefix() - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=False) + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=False + ) - self.assertEqual(python_path, ':/data/test1:/data/test2') + self.assertEqual(python_path, ":/data/test1:/data/test2") # Inherit from current process and from virtualenv (running inside virtualenv) - sys.real_prefix = '/usr' - mock_get_python_lib.return_value = f'{sys.prefix}/virtualenvtest' + sys.real_prefix = "/usr" + mock_get_python_lib.return_value = f"{sys.prefix}/virtualenvtest" - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=True) + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=True + ) - self.assertEqual(python_path, f':/data/test1:/data/test2:{sys.prefix}/virtualenvtest') + self.assertEqual( + python_path, f":/data/test1:/data/test2:{sys.prefix}/virtualenvtest" + ) - @mock.patch('os.path.isdir', mock.Mock(return_value=True)) - @mock.patch('os.listdir', mock.Mock(return_value=['python3.6'])) - @mock.patch('st2common.util.sandboxing.get_python_lib') - def test_get_sandbox_python_path_for_python_action_no_inheritance(self, - mock_get_python_lib): + @mock.patch("os.path.isdir", mock.Mock(return_value=True)) + @mock.patch("os.listdir", mock.Mock(return_value=["python3.6"])) + @mock.patch("st2common.util.sandboxing.get_python_lib") + def test_get_sandbox_python_path_for_python_action_no_inheritance( + self, mock_get_python_lib + ): # No inheritance - python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack', - inherit_from_parent=False, - inherit_parent_virtualenv=False) + python_path = get_sandbox_python_path_for_python_action( + pack="dummy_pack", + inherit_from_parent=False, + inherit_parent_virtualenv=False, + ) - actual_path = python_path.strip(':').split(':') + actual_path = python_path.strip(":").split(":") self.assertEqual(len(actual_path), 3) # First entry should be lib/python3 dir from venv - self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6') + self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6") # Second entry should be python3 site-packages dir from venv - self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages') + self.assertEndsWith( + actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages" + ) # Third entry should be actions/lib dir from pack root directory - self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib') + self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib") - @mock.patch('os.path.isdir', mock.Mock(return_value=True)) - @mock.patch('os.listdir', mock.Mock(return_value=['python3.6'])) - @mock.patch('st2common.util.sandboxing.get_python_lib') - def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_only(self, - mock_get_python_lib): + @mock.patch("os.path.isdir", mock.Mock(return_value=True)) + @mock.patch("os.listdir", mock.Mock(return_value=["python3.6"])) + @mock.patch("st2common.util.sandboxing.get_python_lib") + def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_only( + self, mock_get_python_lib + ): # Inherit python path from current process # Mock the current process python path - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=False) + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=False + ) - self.assertEqual(python_path, ':/data/test1:/data/test2') + self.assertEqual(python_path, ":/data/test1:/data/test2") - python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack', - inherit_from_parent=True, - inherit_parent_virtualenv=False) + python_path = get_sandbox_python_path_for_python_action( + pack="dummy_pack", + inherit_from_parent=True, + inherit_parent_virtualenv=False, + ) - actual_path = python_path.strip(':').split(':') + actual_path = python_path.strip(":").split(":") self.assertEqual(len(actual_path), 6) # First entry should be lib/python3 dir from venv - self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6') + self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6") # Second entry should be python3 site-packages dir from venv - self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages') + self.assertEndsWith( + actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages" + ) # Third entry should be actions/lib dir from pack root directory - self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib') + self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib") # And the rest of the paths from get_sandbox_python_path - self.assertEqual(actual_path[3], '') - self.assertEqual(actual_path[4], '/data/test1') - self.assertEqual(actual_path[5], '/data/test2') + self.assertEqual(actual_path[3], "") + self.assertEqual(actual_path[4], "/data/test1") + self.assertEqual(actual_path[5], "/data/test2") - @mock.patch('os.path.isdir', mock.Mock(return_value=True)) - @mock.patch('os.listdir', mock.Mock(return_value=['python3.6'])) - @mock.patch('st2common.util.sandboxing.get_python_lib') - def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_and_venv(self, - mock_get_python_lib): + @mock.patch("os.path.isdir", mock.Mock(return_value=True)) + @mock.patch("os.listdir", mock.Mock(return_value=["python3.6"])) + @mock.patch("st2common.util.sandboxing.get_python_lib") + def test_get_sandbox_python_path_for_python_action_inherit_from_parent_process_and_venv( + self, mock_get_python_lib + ): # Inherit from current process and from virtualenv (not running inside virtualenv) clear_virtualenv_prefix() # Inherit python path from current process # Mock the current process python path - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=False) + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=False + ) - self.assertEqual(python_path, ':/data/test1:/data/test2') + self.assertEqual(python_path, ":/data/test1:/data/test2") - python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack', - inherit_from_parent=True, - inherit_parent_virtualenv=True) + python_path = get_sandbox_python_path_for_python_action( + pack="dummy_pack", + inherit_from_parent=True, + inherit_parent_virtualenv=True, + ) - actual_path = python_path.strip(':').split(':') + actual_path = python_path.strip(":").split(":") self.assertEqual(len(actual_path), 6) # First entry should be lib/python3 dir from venv - self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6') + self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6") # Second entry should be python3 site-packages dir from venv - self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages') + self.assertEndsWith( + actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages" + ) # Third entry should be actions/lib dir from pack root directory - self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib') + self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib") # And the rest of the paths from get_sandbox_python_path - self.assertEqual(actual_path[3], '') - self.assertEqual(actual_path[4], '/data/test1') - self.assertEqual(actual_path[5], '/data/test2') + self.assertEqual(actual_path[3], "") + self.assertEqual(actual_path[4], "/data/test1") + self.assertEqual(actual_path[5], "/data/test2") # Inherit from current process and from virtualenv (running inside virtualenv) - sys.real_prefix = '/usr' - mock_get_python_lib.return_value = f'{sys.prefix}/virtualenvtest' + sys.real_prefix = "/usr" + mock_get_python_lib.return_value = f"{sys.prefix}/virtualenvtest" # Inherit python path from current process # Mock the current process python path - with mock.patch.dict(os.environ, {'PYTHONPATH': ':/data/test1:/data/test2'}): - python_path = get_sandbox_python_path_for_python_action(pack='dummy_pack', - inherit_from_parent=True, - inherit_parent_virtualenv=True) - - actual_path = python_path.strip(':').split(':') + with mock.patch.dict(os.environ, {"PYTHONPATH": ":/data/test1:/data/test2"}): + python_path = get_sandbox_python_path_for_python_action( + pack="dummy_pack", + inherit_from_parent=True, + inherit_parent_virtualenv=True, + ) + + actual_path = python_path.strip(":").split(":") self.assertEqual(len(actual_path), 7) # First entry should be lib/python3 dir from venv - self.assertEndsWith(actual_path[0], 'virtualenvs/dummy_pack/lib/python3.6') + self.assertEndsWith(actual_path[0], "virtualenvs/dummy_pack/lib/python3.6") # Second entry should be python3 site-packages dir from venv - self.assertEndsWith(actual_path[1], 'virtualenvs/dummy_pack/lib/python3.6/site-packages') + self.assertEndsWith( + actual_path[1], "virtualenvs/dummy_pack/lib/python3.6/site-packages" + ) # Third entry should be actions/lib dir from pack root directory - self.assertEndsWith(actual_path[2], 'packs/dummy_pack/actions/lib') + self.assertEndsWith(actual_path[2], "packs/dummy_pack/actions/lib") # The paths from get_sandbox_python_path - self.assertEqual(actual_path[3], '') - self.assertEqual(actual_path[4], '/data/test1') - self.assertEqual(actual_path[5], '/data/test2') + self.assertEqual(actual_path[3], "") + self.assertEqual(actual_path[4], "/data/test1") + self.assertEqual(actual_path[5], "/data/test2") # And the parent virtualenv - self.assertEqual(actual_path[6], f'{sys.prefix}/virtualenvtest') + self.assertEqual(actual_path[6], f"{sys.prefix}/virtualenvtest") diff --git a/st2common/tests/unit/test_util_secrets.py b/st2common/tests/unit/test_util_secrets.py index f49f8f76a9e..8c77c34f49f 100644 --- a/st2common/tests/unit/test_util_secrets.py +++ b/st2common/tests/unit/test_util_secrets.py @@ -22,38 +22,30 @@ ################################################################################ TEST_FLAT_SCHEMA = { - 'arg_required_no_default': { - 'description': 'Foo', - 'required': True, - 'type': 'string', - 'secret': False + "arg_required_no_default": { + "description": "Foo", + "required": True, + "type": "string", + "secret": False, }, - 'arg_optional_no_type_secret': { - 'description': 'Bar', - 'secret': True - }, - 'arg_optional_type_array': { - 'description': 'Who''s the fairest?', - 'type': 'array' - }, - 'arg_optional_type_object': { - 'description': 'Who''s the fairest of them?', - 'type': 'object' + "arg_optional_no_type_secret": {"description": "Bar", "secret": True}, + "arg_optional_type_array": {"description": "Who" "s the fairest?", "type": "array"}, + "arg_optional_type_object": { + "description": "Who" "s the fairest of them?", + "type": "object", }, } -TEST_FLAT_SECRET_PARAMS = { - 'arg_optional_no_type_secret': None -} +TEST_FLAT_SECRET_PARAMS = {"arg_optional_no_type_secret": None} ################################################################################ TEST_NO_SECRETS_SCHEMA = { - 'arg_required_no_default': { - 'description': 'Foo', - 'required': True, - 'type': 'string', - 'secret': False + "arg_required_no_default": { + "description": "Foo", + "required": True, + "type": "string", + "secret": False, } } @@ -62,497 +54,397 @@ ################################################################################ TEST_NESTED_OBJECTS_SCHEMA = { - 'arg_string': { - 'description': 'Junk', - 'type': 'string', + "arg_string": { + "description": "Junk", + "type": "string", }, - 'arg_optional_object': { - 'description': 'Mirror', - 'type': 'object', - 'properties': { - 'arg_nested_object': { - 'description': 'Mirror mirror', - 'type': 'object', - 'properties': { - 'arg_double_nested_secret': { - 'description': 'Deep, deep down', - 'type': 'string', - 'secret': True + "arg_optional_object": { + "description": "Mirror", + "type": "object", + "properties": { + "arg_nested_object": { + "description": "Mirror mirror", + "type": "object", + "properties": { + "arg_double_nested_secret": { + "description": "Deep, deep down", + "type": "string", + "secret": True, } - } + }, }, - 'arg_nested_secret': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } - } + "arg_nested_secret": { + "description": "Deep down", + "type": "string", + "secret": True, + }, + }, + }, } TEST_NESTED_OBJECTS_SECRET_PARAMS = { - 'arg_optional_object': { - 'arg_nested_secret': 'string', - 'arg_nested_object': { - 'arg_double_nested_secret': 'string', - } + "arg_optional_object": { + "arg_nested_secret": "string", + "arg_nested_object": { + "arg_double_nested_secret": "string", + }, } } ################################################################################ TEST_ARRAY_SCHEMA = { - 'arg_optional_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'down', - 'type': 'string', - 'secret': True - } + "arg_optional_array": { + "description": "Mirror", + "type": "array", + "items": {"description": "down", "type": "string", "secret": True}, } } -TEST_ARRAY_SECRET_PARAMS = { - 'arg_optional_array': [ - 'string' - ] -} +TEST_ARRAY_SECRET_PARAMS = {"arg_optional_array": ["string"]} ################################################################################ TEST_ROOT_ARRAY_SCHEMA = { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'down', - 'type': 'object', - 'properties': { - 'secret_field_in_object': { - 'type': 'string', - 'secret': True - } - } - } + "description": "Mirror", + "type": "array", + "items": { + "description": "down", + "type": "object", + "properties": {"secret_field_in_object": {"type": "string", "secret": True}}, + }, } -TEST_ROOT_ARRAY_SECRET_PARAMS = [ - { - 'secret_field_in_object': 'string' - } -] +TEST_ROOT_ARRAY_SECRET_PARAMS = [{"secret_field_in_object": "string"}] ################################################################################ TEST_ROOT_OBJECT_SCHEMA = { - 'description': 'root', - 'type': 'object', - 'properties': { - 'arg_level_one': { - 'description': 'down', - 'type': 'object', - 'properties': { - 'secret_field_in_object': { - 'type': 'string', - 'secret': True - } - } + "description": "root", + "type": "object", + "properties": { + "arg_level_one": { + "description": "down", + "type": "object", + "properties": { + "secret_field_in_object": {"type": "string", "secret": True} + }, } - } + }, } -TEST_ROOT_OBJECT_SECRET_PARAMS = { - 'arg_level_one': - { - 'secret_field_in_object': 'string' - } -} +TEST_ROOT_OBJECT_SECRET_PARAMS = {"arg_level_one": {"secret_field_in_object": "string"}} ################################################################################ TEST_NESTED_ARRAYS_SCHEMA = { - 'arg_optional_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } + "arg_optional_array": { + "description": "Mirror", + "type": "array", + "items": {"description": "Deep down", "type": "string", "secret": True}, }, - 'arg_optional_double_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } + "arg_optional_double_array": { + "description": "Mirror", + "type": "array", + "items": { + "type": "array", + "items": {"description": "Deep down", "type": "string", "secret": True}, + }, }, - 'arg_optional_tripple_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } - } + "arg_optional_tripple_array": { + "description": "Mirror", + "type": "array", + "items": { + "type": "array", + "items": { + "type": "array", + "items": {"description": "Deep down", "type": "string", "secret": True}, + }, + }, + }, + "arg_optional_quad_array": { + "description": "Mirror", + "type": "array", + "items": { + "type": "array", + "items": { + "type": "array", + "items": { + "type": "array", + "items": { + "description": "Deep down", + "type": "string", + "secret": True, + }, + }, + }, + }, }, - 'arg_optional_quad_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } - } - } - } } TEST_NESTED_ARRAYS_SECRET_PARAMS = { - 'arg_optional_array': [ - 'string' - ], - 'arg_optional_double_array': [ - [ - 'string' - ] - ], - 'arg_optional_tripple_array': [ - [ - [ - 'string' - ] - ] - ], - 'arg_optional_quad_array': [ - [ - [ - [ - 'string' - ] - ] - ] - ] + "arg_optional_array": ["string"], + "arg_optional_double_array": [["string"]], + "arg_optional_tripple_array": [[["string"]]], + "arg_optional_quad_array": [[[["string"]]]], } ################################################################################ TEST_NESTED_OBJECT_WITH_ARRAY_SCHEMA = { - 'arg_optional_object_with_array': { - 'description': 'Mirror', - 'type': 'object', - 'properties': { - 'arg_nested_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } + "arg_optional_object_with_array": { + "description": "Mirror", + "type": "object", + "properties": { + "arg_nested_array": { + "description": "Mirror", + "type": "array", + "items": {"description": "Deep down", "type": "string", "secret": True}, } - } + }, } } TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS = { - 'arg_optional_object_with_array': { - 'arg_nested_array': [ - 'string' - ] - } + "arg_optional_object_with_array": {"arg_nested_array": ["string"]} } ################################################################################ TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SCHEMA = { - 'arg_optional_object_with_double_array': { - 'description': 'Mirror', - 'type': 'object', - 'properties': { - 'arg_double_nested_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } + "arg_optional_object_with_double_array": { + "description": "Mirror", + "type": "object", + "properties": { + "arg_double_nested_array": { + "description": "Mirror", + "type": "array", + "items": { + "description": "Mirror", + "type": "array", + "items": { + "description": "Deep down", + "type": "string", + "secret": True, + }, + }, } - } + }, } } TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS = { - 'arg_optional_object_with_double_array': { - 'arg_double_nested_array': [ - [ - 'string' - ] - ] - } + "arg_optional_object_with_double_array": {"arg_double_nested_array": [["string"]]} } ################################################################################ TEST_NESTED_ARRAY_WITH_OBJECT_SCHEMA = { - 'arg_optional_array_with_object': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'description': 'Mirror', - 'type': 'object', - 'properties': { - 'arg_nested_secret': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True + "arg_optional_array_with_object": { + "description": "Mirror", + "type": "array", + "items": { + "description": "Mirror", + "type": "object", + "properties": { + "arg_nested_secret": { + "description": "Deep down", + "type": "string", + "secret": True, } - } - } + }, + }, } } TEST_NESTED_ARRAY_WITH_OBJECT_SECRET_PARAMS = { - 'arg_optional_array_with_object': [ - { - 'arg_nested_secret': 'string' - } - ] + "arg_optional_array_with_object": [{"arg_nested_secret": "string"}] } ################################################################################ TEST_SECRET_ARRAY_SCHEMA = { - 'arg_secret_array': { - 'description': 'Mirror', - 'type': 'array', - 'secret': True, + "arg_secret_array": { + "description": "Mirror", + "type": "array", + "secret": True, } } -TEST_SECRET_ARRAY_SECRET_PARAMS = { - 'arg_secret_array': 'array' -} +TEST_SECRET_ARRAY_SECRET_PARAMS = {"arg_secret_array": "array"} ################################################################################ TEST_SECRET_OBJECT_SCHEMA = { - 'arg_secret_object': { - 'type': 'object', - 'secret': True, + "arg_secret_object": { + "type": "object", + "secret": True, } } -TEST_SECRET_OBJECT_SECRET_PARAMS = { - 'arg_secret_object': 'object' -} +TEST_SECRET_OBJECT_SECRET_PARAMS = {"arg_secret_object": "object"} ################################################################################ TEST_SECRET_ROOT_ARRAY_SCHEMA = { - 'description': 'secret array', - 'type': 'array', - 'secret': True, - 'items': { - 'description': 'down', - 'type': 'object', - 'properties': { - 'secret_field_in_object': { - 'type': 'string', - 'secret': True - } - } - } + "description": "secret array", + "type": "array", + "secret": True, + "items": { + "description": "down", + "type": "object", + "properties": {"secret_field_in_object": {"type": "string", "secret": True}}, + }, } -TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS = 'array' +TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS = "array" ################################################################################ TEST_SECRET_ROOT_OBJECT_SCHEMA = { - 'description': 'secret object', - 'type': 'object', - 'secret': True, - 'proeprteis': { - 'arg_level_one': { - 'description': 'down', - 'type': 'object', - 'properties': { - 'secret_field_in_object': { - 'type': 'string', - 'secret': True - } - } + "description": "secret object", + "type": "object", + "secret": True, + "proeprteis": { + "arg_level_one": { + "description": "down", + "type": "object", + "properties": { + "secret_field_in_object": {"type": "string", "secret": True} + }, } - } + }, } -TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS = 'object' +TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS = "object" ################################################################################ TEST_SECRET_NESTED_OBJECTS_SCHEMA = { - 'arg_object': { - 'description': 'Mirror', - 'type': 'object', - 'properties': { - 'arg_nested_object': { - 'description': 'Mirror mirror', - 'type': 'object', - 'secret': True, - 'properties': { - 'arg_double_nested_secret': { - 'description': 'Deep, deep down', - 'type': 'string', - 'secret': True + "arg_object": { + "description": "Mirror", + "type": "object", + "properties": { + "arg_nested_object": { + "description": "Mirror mirror", + "type": "object", + "secret": True, + "properties": { + "arg_double_nested_secret": { + "description": "Deep, deep down", + "type": "string", + "secret": True, } - } + }, }, - 'arg_nested_secret': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } + "arg_nested_secret": { + "description": "Deep down", + "type": "string", + "secret": True, + }, + }, }, - 'arg_secret_object': { - 'description': 'Mirror', - 'type': 'object', - 'secret': True, - 'properties': { - 'arg_nested_object': { - 'description': 'Mirror mirror', - 'type': 'object', - 'secret': True, - 'properties': { - 'arg_double_nested_secret': { - 'description': 'Deep, deep down', - 'type': 'string', - 'secret': True + "arg_secret_object": { + "description": "Mirror", + "type": "object", + "secret": True, + "properties": { + "arg_nested_object": { + "description": "Mirror mirror", + "type": "object", + "secret": True, + "properties": { + "arg_double_nested_secret": { + "description": "Deep, deep down", + "type": "string", + "secret": True, } - } + }, }, - 'arg_nested_secret': { - 'description': 'Deep down', - 'type': 'string', - 'secret': True - } - } - } + "arg_nested_secret": { + "description": "Deep down", + "type": "string", + "secret": True, + }, + }, + }, } TEST_SECRET_NESTED_OBJECTS_SECRET_PARAMS = { - 'arg_object': { - 'arg_nested_secret': 'string', - 'arg_nested_object': 'object' - }, - 'arg_secret_object': 'object' + "arg_object": {"arg_nested_secret": "string", "arg_nested_object": "object"}, + "arg_secret_object": "object", } ################################################################################ TEST_SECRET_NESTED_ARRAYS_SCHEMA = { - 'arg_optional_array': { - 'description': 'Mirror', - 'type': 'array', - 'secret': True, - 'items': { - 'description': 'Deep down', - 'type': 'string' - } + "arg_optional_array": { + "description": "Mirror", + "type": "array", + "secret": True, + "items": {"description": "Deep down", "type": "string"}, }, - 'arg_optional_double_array': { - 'description': 'Mirror', - 'type': 'array', - 'secret': True, - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - } - } + "arg_optional_double_array": { + "description": "Mirror", + "type": "array", + "secret": True, + "items": { + "type": "array", + "items": { + "description": "Deep down", + "type": "string", + }, + }, }, - 'arg_optional_tripple_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'type': 'array', - 'secret': True, - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - } - } - } + "arg_optional_tripple_array": { + "description": "Mirror", + "type": "array", + "items": { + "type": "array", + "secret": True, + "items": { + "type": "array", + "items": { + "description": "Deep down", + "type": "string", + }, + }, + }, + }, + "arg_optional_quad_array": { + "description": "Mirror", + "type": "array", + "items": { + "type": "array", + "items": { + "type": "array", + "secret": True, + "items": { + "type": "array", + "items": { + "description": "Deep down", + "type": "string", + }, + }, + }, + }, }, - 'arg_optional_quad_array': { - 'description': 'Mirror', - 'type': 'array', - 'items': { - 'type': 'array', - 'items': { - 'type': 'array', - 'secret': True, - 'items': { - 'type': 'array', - 'items': { - 'description': 'Deep down', - 'type': 'string', - } - } - } - } - } } TEST_SECRET_NESTED_ARRAYS_SECRET_PARAMS = { - 'arg_optional_array': 'array', - 'arg_optional_double_array': 'array', - 'arg_optional_tripple_array': [ - 'array' - ], - 'arg_optional_quad_array': [ - [ - 'array' - ] - ] + "arg_optional_array": "array", + "arg_optional_double_array": "array", + "arg_optional_tripple_array": ["array"], + "arg_optional_quad_array": [["array"]], } ################################################################################ class SecretUtilsTestCase(unittest2.TestCase): - def test_get_secret_parameters_flat(self): result = secrets.get_secret_parameters(TEST_FLAT_SCHEMA) self.assertEqual(TEST_FLAT_SECRET_PARAMS, result) @@ -586,7 +478,9 @@ def test_get_secret_parameters_nested_object_with_array(self): self.assertEqual(TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS, result) def test_get_secret_parameters_nested_object_with_double_array(self): - result = secrets.get_secret_parameters(TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SCHEMA) + result = secrets.get_secret_parameters( + TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SCHEMA + ) self.assertEqual(TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS, result) def test_get_secret_parameters_nested_array_with_object(self): @@ -621,178 +515,128 @@ def test_get_secret_parameters_secret_nested_objects(self): def test_mask_secret_parameters_flat(self): parameters = { - 'arg_required_no_default': 'test', - 'arg_optional_no_type_secret': None + "arg_required_no_default": "test", + "arg_optional_no_type_secret": None, } - result = secrets.mask_secret_parameters(parameters, - TEST_FLAT_SECRET_PARAMS) + result = secrets.mask_secret_parameters(parameters, TEST_FLAT_SECRET_PARAMS) expected = { - 'arg_required_no_default': 'test', - 'arg_optional_no_type_secret': MASKED_ATTRIBUTE_VALUE + "arg_required_no_default": "test", + "arg_optional_no_type_secret": MASKED_ATTRIBUTE_VALUE, } self.assertEqual(expected, result) def test_mask_secret_parameters_no_secrets(self): - parameters = {'arg_required_no_default': 'junk'} - result = secrets.mask_secret_parameters(parameters, - TEST_NO_SECRETS_SECRET_PARAMS) - expected = { - 'arg_required_no_default': 'junk' - } + parameters = {"arg_required_no_default": "junk"} + result = secrets.mask_secret_parameters( + parameters, TEST_NO_SECRETS_SECRET_PARAMS + ) + expected = {"arg_required_no_default": "junk"} self.assertEqual(expected, result) def test_mask_secret_parameters_nested_objects(self): parameters = { - 'arg_optional_object': { - 'arg_nested_secret': 'nested Secret', - 'arg_nested_object': { - 'arg_double_nested_secret': 'double nested $ecret', - } + "arg_optional_object": { + "arg_nested_secret": "nested Secret", + "arg_nested_object": { + "arg_double_nested_secret": "double nested $ecret", + }, } } - result = secrets.mask_secret_parameters(parameters, - TEST_NESTED_OBJECTS_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_NESTED_OBJECTS_SECRET_PARAMS + ) expected = { - 'arg_optional_object': { - 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE, - 'arg_nested_object': { - 'arg_double_nested_secret': MASKED_ATTRIBUTE_VALUE, - } + "arg_optional_object": { + "arg_nested_secret": MASKED_ATTRIBUTE_VALUE, + "arg_nested_object": { + "arg_double_nested_secret": MASKED_ATTRIBUTE_VALUE, + }, } } self.assertEqual(expected, result) def test_mask_secret_parameters_array(self): parameters = { - 'arg_optional_array': [ - '$ecret $tring 1', - '$ecret $tring 2', - '$ecret $tring 3' + "arg_optional_array": [ + "$ecret $tring 1", + "$ecret $tring 2", + "$ecret $tring 3", ] } - result = secrets.mask_secret_parameters(parameters, - TEST_ARRAY_SECRET_PARAMS) + result = secrets.mask_secret_parameters(parameters, TEST_ARRAY_SECRET_PARAMS) expected = { - 'arg_optional_array': [ + "arg_optional_array": [ + MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE ] } self.assertEqual(expected, result) def test_mask_secret_parameters_root_array(self): parameters = [ - { - 'secret_field_in_object': 'Secret $tr!ng' - }, - { - 'secret_field_in_object': 'Secret $tr!ng 2' - }, - { - 'secret_field_in_object': 'Secret $tr!ng 3' - }, - { - 'secret_field_in_object': 'Secret $tr!ng 4' - } + {"secret_field_in_object": "Secret $tr!ng"}, + {"secret_field_in_object": "Secret $tr!ng 2"}, + {"secret_field_in_object": "Secret $tr!ng 3"}, + {"secret_field_in_object": "Secret $tr!ng 4"}, ] - result = secrets.mask_secret_parameters(parameters, TEST_ROOT_ARRAY_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_ROOT_ARRAY_SECRET_PARAMS + ) expected = [ - { - 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE - }, - { - 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE - }, - { - 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE - }, - { - 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE - } + {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE}, + {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE}, + {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE}, + {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE}, ] self.assertEqual(expected, result) def test_mask_secret_parameters_root_object(self): - parameters = { - 'arg_level_one': - { - 'secret_field_in_object': 'Secret $tr!ng' - } - } + parameters = {"arg_level_one": {"secret_field_in_object": "Secret $tr!ng"}} - result = secrets.mask_secret_parameters(parameters, TEST_ROOT_OBJECT_SECRET_PARAMS) - expected = { - 'arg_level_one': - { - 'secret_field_in_object': MASKED_ATTRIBUTE_VALUE - } - } + result = secrets.mask_secret_parameters( + parameters, TEST_ROOT_OBJECT_SECRET_PARAMS + ) + expected = {"arg_level_one": {"secret_field_in_object": MASKED_ATTRIBUTE_VALUE}} self.assertEqual(expected, result) def test_mask_secret_parameters_nested_arrays(self): parameters = { - 'arg_optional_array': [ - 'secret 1', - 'secret 2', - 'secret 3', + "arg_optional_array": [ + "secret 1", + "secret 2", + "secret 3", ], - 'arg_optional_double_array': [ + "arg_optional_double_array": [ [ - 'secret 4', - 'secret 5', - 'secret 6', + "secret 4", + "secret 5", + "secret 6", ], [ - 'secret 7', - 'secret 8', - 'secret 9', - ] - ], - 'arg_optional_tripple_array': [ - [ - [ - 'secret 10', - 'secret 11' - ], - [ - 'secret 12', - 'secret 13', - 'secret 14' - ] + "secret 7", + "secret 8", + "secret 9", ], - [ - [ - 'secret 15', - 'secret 16' - ] - ] ], - 'arg_optional_quad_array': [ - [ - [ - [ - 'secret 17', - 'secret 18' - ], - [ - 'secret 19' - ] - ] - ] - ] + "arg_optional_tripple_array": [ + [["secret 10", "secret 11"], ["secret 12", "secret 13", "secret 14"]], + [["secret 15", "secret 16"]], + ], + "arg_optional_quad_array": [[[["secret 17", "secret 18"], ["secret 19"]]]], } - result = secrets.mask_secret_parameters(parameters, - TEST_NESTED_ARRAYS_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_NESTED_ARRAYS_SECRET_PARAMS + ) expected = { - 'arg_optional_array': [ + "arg_optional_array": [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, ], - 'arg_optional_double_array': [ + "arg_optional_double_array": [ [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, @@ -802,58 +646,46 @@ def test_mask_secret_parameters_nested_arrays(self): MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, - ] + ], ], - 'arg_optional_tripple_array': [ + "arg_optional_tripple_array": [ [ + [MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE], [ MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE - ], - [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE - ] + ], ], - [ - [ - MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE - ] - ] + [[MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE]], ], - 'arg_optional_quad_array': [ + "arg_optional_quad_array": [ [ [ - [ - MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE - ], - [ - MASKED_ATTRIBUTE_VALUE - ] + [MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE], + [MASKED_ATTRIBUTE_VALUE], ] ] - ] + ], } self.assertEqual(expected, result) def test_mask_secret_parameters_nested_object_with_array(self): parameters = { - 'arg_optional_object_with_array': { - 'arg_nested_array': [ - 'secret array value 1', - 'secret array value 2', - 'secret array value 3', + "arg_optional_object_with_array": { + "arg_nested_array": [ + "secret array value 1", + "secret array value 2", + "secret array value 3", ] } } - result = secrets.mask_secret_parameters(parameters, - TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_NESTED_OBJECT_WITH_ARRAY_SECRET_PARAMS + ) expected = { - 'arg_optional_object_with_array': { - 'arg_nested_array': [ + "arg_optional_object_with_array": { + "arg_nested_array": [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, @@ -864,36 +696,33 @@ def test_mask_secret_parameters_nested_object_with_array(self): def test_mask_secret_parameters_nested_object_with_double_array(self): parameters = { - 'arg_optional_object_with_double_array': { - 'arg_double_nested_array': [ + "arg_optional_object_with_double_array": { + "arg_double_nested_array": [ + ["secret 1", "secret 2", "secret 3"], [ - 'secret 1', - 'secret 2', - 'secret 3' + "secret 4", + "secret 5", + "secret 6", ], [ - 'secret 4', - 'secret 5', - 'secret 6', + "secret 7", + "secret 8", + "secret 9", + "secret 10", ], - [ - 'secret 7', - 'secret 8', - 'secret 9', - 'secret 10', - ] ] } } - result = secrets.mask_secret_parameters(parameters, - TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_NESTED_OBJECT_WITH_DOUBLE_ARRAY_SECRET_PARAMS + ) expected = { - 'arg_optional_object_with_double_array': { - 'arg_double_nested_array': [ + "arg_optional_object_with_double_array": { + "arg_double_nested_array": [ [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, - MASKED_ATTRIBUTE_VALUE + MASKED_ATTRIBUTE_VALUE, ], [ MASKED_ATTRIBUTE_VALUE, @@ -905,7 +734,7 @@ def test_mask_secret_parameters_nested_object_with_double_array(self): MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, - ] + ], ] } } @@ -913,187 +742,132 @@ def test_mask_secret_parameters_nested_object_with_double_array(self): def test_mask_secret_parameters_nested_array_with_object(self): parameters = { - 'arg_optional_array_with_object': [ - { - 'arg_nested_secret': 'secret 1' - }, - { - 'arg_nested_secret': 'secret 2' - }, - { - 'arg_nested_secret': 'secret 3' - } + "arg_optional_array_with_object": [ + {"arg_nested_secret": "secret 1"}, + {"arg_nested_secret": "secret 2"}, + {"arg_nested_secret": "secret 3"}, ] } - result = secrets.mask_secret_parameters(parameters, - TEST_NESTED_ARRAY_WITH_OBJECT_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_NESTED_ARRAY_WITH_OBJECT_SECRET_PARAMS + ) expected = { - 'arg_optional_array_with_object': [ - { - 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE - }, - { - 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE - }, - { - 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE - } + "arg_optional_array_with_object": [ + {"arg_nested_secret": MASKED_ATTRIBUTE_VALUE}, + {"arg_nested_secret": MASKED_ATTRIBUTE_VALUE}, + {"arg_nested_secret": MASKED_ATTRIBUTE_VALUE}, ] } self.assertEqual(expected, result) def test_mask_secret_parameters_secret_array(self): - parameters = { - 'arg_secret_array': [ - "abc", - 123, - True - ] - } - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_ARRAY_SECRET_PARAMS) - expected = { - 'arg_secret_array': MASKED_ATTRIBUTE_VALUE - } + parameters = {"arg_secret_array": ["abc", 123, True]} + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_ARRAY_SECRET_PARAMS + ) + expected = {"arg_secret_array": MASKED_ATTRIBUTE_VALUE} self.assertEqual(expected, result) def test_mask_secret_parameters_secret_object(self): parameters = { - 'arg_secret_object': - { + "arg_secret_object": { "abc": 123, "key": "value", "bool": True, "array": ["x", "y", "z"], - "obj": - { - "x": "deep" - } + "obj": {"x": "deep"}, } } - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_OBJECT_SECRET_PARAMS) - expected = { - 'arg_secret_object': MASKED_ATTRIBUTE_VALUE - } + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_OBJECT_SECRET_PARAMS + ) + expected = {"arg_secret_object": MASKED_ATTRIBUTE_VALUE} self.assertEqual(expected, result) def test_mask_secret_parameters_secret_root_array(self): - parameters = [ - "abc", - 123, - True - ] - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS) + parameters = ["abc", 123, True] + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_ROOT_ARRAY_SECRET_PARAMS + ) expected = MASKED_ATTRIBUTE_VALUE self.assertEqual(expected, result) def test_mask_secret_parameters_secret_root_object(self): - parameters = { - 'arg_level_one': - { - 'secret_field_in_object': 'Secret $tr!ng' - } - } - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS) + parameters = {"arg_level_one": {"secret_field_in_object": "Secret $tr!ng"}} + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_ROOT_OBJECT_SECRET_PARAMS + ) expected = MASKED_ATTRIBUTE_VALUE self.assertEqual(expected, result) def test_mask_secret_parameters_secret_nested_arrays(self): parameters = { - 'arg_optional_array': [ - 'secret 1', - 'secret 2', - 'secret 3', + "arg_optional_array": [ + "secret 1", + "secret 2", + "secret 3", ], - 'arg_optional_double_array': [ + "arg_optional_double_array": [ [ - 'secret 4', - 'secret 5', - 'secret 6', + "secret 4", + "secret 5", + "secret 6", ], [ - 'secret 7', - 'secret 8', - 'secret 9', - ] - ], - 'arg_optional_tripple_array': [ - [ - [ - 'secret 10', - 'secret 11' - ], - [ - 'secret 12', - 'secret 13', - 'secret 14' - ] + "secret 7", + "secret 8", + "secret 9", ], - [ - [ - 'secret 15', - 'secret 16' - ] - ] ], - 'arg_optional_quad_array': [ - [ - [ - [ - 'secret 17', - 'secret 18' - ], - [ - 'secret 19' - ] - ] - ] - ] + "arg_optional_tripple_array": [ + [["secret 10", "secret 11"], ["secret 12", "secret 13", "secret 14"]], + [["secret 15", "secret 16"]], + ], + "arg_optional_quad_array": [[[["secret 17", "secret 18"], ["secret 19"]]]], } - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_NESTED_ARRAYS_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_NESTED_ARRAYS_SECRET_PARAMS + ) expected = { - 'arg_optional_array': MASKED_ATTRIBUTE_VALUE, - 'arg_optional_double_array': MASKED_ATTRIBUTE_VALUE, - 'arg_optional_tripple_array': [ + "arg_optional_array": MASKED_ATTRIBUTE_VALUE, + "arg_optional_double_array": MASKED_ATTRIBUTE_VALUE, + "arg_optional_tripple_array": [ MASKED_ATTRIBUTE_VALUE, MASKED_ATTRIBUTE_VALUE, ], - 'arg_optional_quad_array': [ + "arg_optional_quad_array": [ [ MASKED_ATTRIBUTE_VALUE, ] - ] + ], } self.assertEqual(expected, result) def test_mask_secret_parameters_secret_nested_objects(self): parameters = { - 'arg_object': { - 'arg_nested_secret': 'nested Secret', - 'arg_nested_object': { - 'arg_double_nested_secret': 'double nested $ecret', - } + "arg_object": { + "arg_nested_secret": "nested Secret", + "arg_nested_object": { + "arg_double_nested_secret": "double nested $ecret", + }, + }, + "arg_secret_object": { + "arg_nested_secret": "secret data", + "arg_nested_object": { + "arg_double_nested_secret": "double nested $ecret", + }, }, - 'arg_secret_object': { - 'arg_nested_secret': 'secret data', - 'arg_nested_object': { - 'arg_double_nested_secret': 'double nested $ecret', - } - } } - result = secrets.mask_secret_parameters(parameters, - TEST_SECRET_NESTED_OBJECTS_SECRET_PARAMS) + result = secrets.mask_secret_parameters( + parameters, TEST_SECRET_NESTED_OBJECTS_SECRET_PARAMS + ) expected = { - 'arg_object': { - 'arg_nested_secret': MASKED_ATTRIBUTE_VALUE, - 'arg_nested_object': MASKED_ATTRIBUTE_VALUE, + "arg_object": { + "arg_nested_secret": MASKED_ATTRIBUTE_VALUE, + "arg_nested_object": MASKED_ATTRIBUTE_VALUE, }, - 'arg_secret_object': MASKED_ATTRIBUTE_VALUE, + "arg_secret_object": MASKED_ATTRIBUTE_VALUE, } self.assertEqual(expected, result) diff --git a/st2common/tests/unit/test_util_shell.py b/st2common/tests/unit/test_util_shell.py index 86c37f2ad16..4a2a00e3433 100644 --- a/st2common/tests/unit/test_util_shell.py +++ b/st2common/tests/unit/test_util_shell.py @@ -23,38 +23,26 @@ class ShellUtilsTestCase(unittest2.TestCase): def test_quote_unix(self): - arguments = [ - 'foo', - 'foo bar', - 'foo1 bar1', - '"foo"', - '"foo" "bar"', - "'foo bar'" - ] + arguments = ["foo", "foo bar", "foo1 bar1", '"foo"', '"foo" "bar"', "'foo bar'"] expected_values = [ """ foo """, - """ 'foo bar' """, - """ 'foo1 bar1' """, - """ '"foo"' """, - """ '"foo" "bar"' """, - """ ''"'"'foo bar'"'"'' - """ + """, ] for argument, expected_value in zip(arguments, expected_values): @@ -63,38 +51,26 @@ def test_quote_unix(self): self.assertEqual(actual_value, expected_value.strip()) def test_quote_windows(self): - arguments = [ - 'foo', - 'foo bar', - 'foo1 bar1', - '"foo"', - '"foo" "bar"', - "'foo bar'" - ] + arguments = ["foo", "foo bar", "foo1 bar1", '"foo"', '"foo" "bar"', "'foo bar'"] expected_values = [ """ foo """, - """ "foo bar" """, - """ "foo1 bar1" """, - """ \\"foo\\" """, - """ "\\"foo\\" \\"bar\\"" """, - """ "'foo bar'" - """ + """, ] for argument, expected_value in zip(arguments, expected_values): diff --git a/st2common/tests/unit/test_util_templating.py b/st2common/tests/unit/test_util_templating.py index 1756590bc18..c6cd5398493 100644 --- a/st2common/tests/unit/test_util_templating.py +++ b/st2common/tests/unit/test_util_templating.py @@ -26,41 +26,45 @@ def setUp(self): super(TemplatingUtilsTestCase, self).setUp() # Insert mock DB objects - kvp_1_db = KeyValuePairDB(name='key1', value='valuea') + kvp_1_db = KeyValuePairDB(name="key1", value="valuea") kvp_1_db = KeyValuePair.add_or_update(kvp_1_db) - kvp_2_db = KeyValuePairDB(name='key2', value='valueb') + kvp_2_db = KeyValuePairDB(name="key2", value="valueb") kvp_2_db = KeyValuePair.add_or_update(kvp_2_db) - kvp_3_db = KeyValuePairDB(name='stanley:key1', value='valuestanley1', scope=FULL_USER_SCOPE) + kvp_3_db = KeyValuePairDB( + name="stanley:key1", value="valuestanley1", scope=FULL_USER_SCOPE + ) kvp_3_db = KeyValuePair.add_or_update(kvp_3_db) - kvp_4_db = KeyValuePairDB(name='joe:key1', value='valuejoe1', scope=FULL_USER_SCOPE) + kvp_4_db = KeyValuePairDB( + name="joe:key1", value="valuejoe1", scope=FULL_USER_SCOPE + ) kvp_4_db = KeyValuePair.add_or_update(kvp_4_db) def test_render_template_with_system_and_user_context(self): # 1. No reference to the user inside the template - template = '{{st2kv.system.key1}}' - user = 'stanley' + template = "{{st2kv.system.key1}}" + user = "stanley" result = render_template_with_system_and_user_context(value=template, user=user) - self.assertEqual(result, 'valuea') + self.assertEqual(result, "valuea") - template = '{{st2kv.system.key2}}' - user = 'stanley' + template = "{{st2kv.system.key2}}" + user = "stanley" result = render_template_with_system_and_user_context(value=template, user=user) - self.assertEqual(result, 'valueb') + self.assertEqual(result, "valueb") # 2. Reference to the user inside the template - template = '{{st2kv.user.key1}}' - user = 'stanley' + template = "{{st2kv.user.key1}}" + user = "stanley" result = render_template_with_system_and_user_context(value=template, user=user) - self.assertEqual(result, 'valuestanley1') + self.assertEqual(result, "valuestanley1") - template = '{{st2kv.user.key1}}' - user = 'joe' + template = "{{st2kv.user.key1}}" + user = "joe" result = render_template_with_system_and_user_context(value=template, user=user) - self.assertEqual(result, 'valuejoe1') + self.assertEqual(result, "valuejoe1") diff --git a/st2common/tests/unit/test_util_types.py b/st2common/tests/unit/test_util_types.py index 1213eb69d1e..8b7ef788640 100644 --- a/st2common/tests/unit/test_util_types.py +++ b/st2common/tests/unit/test_util_types.py @@ -17,9 +17,7 @@ from st2common.util.types import OrderedSet -__all__ = [ - 'OrderedTestTypeTestCase' -] +__all__ = ["OrderedTestTypeTestCase"] class OrderedTestTypeTestCase(unittest2.TestCase): diff --git a/st2common/tests/unit/test_util_url.py b/st2common/tests/unit/test_util_url.py index 551aed3e8c7..8b236195932 100644 --- a/st2common/tests/unit/test_util_url.py +++ b/st2common/tests/unit/test_util_url.py @@ -23,16 +23,16 @@ class URLUtilsTestCase(unittest2.TestCase): def test_get_url_without_trailing_slash(self): values = [ - 'http://localhost:1818/foo/bar/', - 'http://localhost:1818/foo/bar', - 'http://localhost:1818/', - 'http://localhost:1818', + "http://localhost:1818/foo/bar/", + "http://localhost:1818/foo/bar", + "http://localhost:1818/", + "http://localhost:1818", ] expected = [ - 'http://localhost:1818/foo/bar', - 'http://localhost:1818/foo/bar', - 'http://localhost:1818', - 'http://localhost:1818', + "http://localhost:1818/foo/bar", + "http://localhost:1818/foo/bar", + "http://localhost:1818", + "http://localhost:1818", ] for value, expected_result in zip(values, expected): diff --git a/st2common/tests/unit/test_versioning_utils.py b/st2common/tests/unit/test_versioning_utils.py index 73d118aa891..de7bbbfeafa 100644 --- a/st2common/tests/unit/test_versioning_utils.py +++ b/st2common/tests/unit/test_versioning_utils.py @@ -23,40 +23,40 @@ class VersioningUtilsTestCase(unittest2.TestCase): def test_complex_semver_match(self): # Positive test case - self.assertTrue(complex_semver_match('1.6.0', '>=1.6.0, <2.2.0')) - self.assertTrue(complex_semver_match('1.6.1', '>=1.6.0, <2.2.0')) - self.assertTrue(complex_semver_match('2.0.0', '>=1.6.0, <2.2.0')) - self.assertTrue(complex_semver_match('2.1.0', '>=1.6.0, <2.2.0')) - self.assertTrue(complex_semver_match('2.1.9', '>=1.6.0, <2.2.0')) + self.assertTrue(complex_semver_match("1.6.0", ">=1.6.0, <2.2.0")) + self.assertTrue(complex_semver_match("1.6.1", ">=1.6.0, <2.2.0")) + self.assertTrue(complex_semver_match("2.0.0", ">=1.6.0, <2.2.0")) + self.assertTrue(complex_semver_match("2.1.0", ">=1.6.0, <2.2.0")) + self.assertTrue(complex_semver_match("2.1.9", ">=1.6.0, <2.2.0")) - self.assertTrue(complex_semver_match('1.6.0', 'all')) - self.assertTrue(complex_semver_match('1.6.1', 'all')) - self.assertTrue(complex_semver_match('2.0.0', 'all')) - self.assertTrue(complex_semver_match('2.1.0', 'all')) + self.assertTrue(complex_semver_match("1.6.0", "all")) + self.assertTrue(complex_semver_match("1.6.1", "all")) + self.assertTrue(complex_semver_match("2.0.0", "all")) + self.assertTrue(complex_semver_match("2.1.0", "all")) - self.assertTrue(complex_semver_match('1.6.0', '>=1.6.0')) - self.assertTrue(complex_semver_match('1.6.1', '>=1.6.0')) - self.assertTrue(complex_semver_match('2.1.0', '>=1.6.0')) + self.assertTrue(complex_semver_match("1.6.0", ">=1.6.0")) + self.assertTrue(complex_semver_match("1.6.1", ">=1.6.0")) + self.assertTrue(complex_semver_match("2.1.0", ">=1.6.0")) # Negative test case - self.assertFalse(complex_semver_match('1.5.0', '>=1.6.0, <2.2.0')) - self.assertFalse(complex_semver_match('0.1.0', '>=1.6.0, <2.2.0')) - self.assertFalse(complex_semver_match('2.2.1', '>=1.6.0, <2.2.0')) - self.assertFalse(complex_semver_match('2.3.0', '>=1.6.0, <2.2.0')) - self.assertFalse(complex_semver_match('3.0.0', '>=1.6.0, <2.2.0')) + self.assertFalse(complex_semver_match("1.5.0", ">=1.6.0, <2.2.0")) + self.assertFalse(complex_semver_match("0.1.0", ">=1.6.0, <2.2.0")) + self.assertFalse(complex_semver_match("2.2.1", ">=1.6.0, <2.2.0")) + self.assertFalse(complex_semver_match("2.3.0", ">=1.6.0, <2.2.0")) + self.assertFalse(complex_semver_match("3.0.0", ">=1.6.0, <2.2.0")) - self.assertFalse(complex_semver_match('1.5.0', '>=1.6.0')) - self.assertFalse(complex_semver_match('0.1.0', '>=1.6.0')) - self.assertFalse(complex_semver_match('1.5.9', '>=1.6.0')) + self.assertFalse(complex_semver_match("1.5.0", ">=1.6.0")) + self.assertFalse(complex_semver_match("0.1.0", ">=1.6.0")) + self.assertFalse(complex_semver_match("1.5.9", ">=1.6.0")) def test_normalize_pack_version(self): # Already a valid semver version string - self.assertEqual(normalize_pack_version('0.2.0'), '0.2.0') - self.assertEqual(normalize_pack_version('0.2.1'), '0.2.1') - self.assertEqual(normalize_pack_version('1.2.1'), '1.2.1') + self.assertEqual(normalize_pack_version("0.2.0"), "0.2.0") + self.assertEqual(normalize_pack_version("0.2.1"), "0.2.1") + self.assertEqual(normalize_pack_version("1.2.1"), "1.2.1") # Not a valid semver version string - self.assertEqual(normalize_pack_version('0.2'), '0.2.0') - self.assertEqual(normalize_pack_version('0.3'), '0.3.0') - self.assertEqual(normalize_pack_version('1.3'), '1.3.0') - self.assertEqual(normalize_pack_version('2.0'), '2.0.0') + self.assertEqual(normalize_pack_version("0.2"), "0.2.0") + self.assertEqual(normalize_pack_version("0.3"), "0.3.0") + self.assertEqual(normalize_pack_version("1.3"), "1.3.0") + self.assertEqual(normalize_pack_version("2.0"), "2.0.0") diff --git a/st2common/tests/unit/test_virtualenvs.py b/st2common/tests/unit/test_virtualenvs.py index 90c0f4e9890..439801f67ab 100644 --- a/st2common/tests/unit/test_virtualenvs.py +++ b/st2common/tests/unit/test_virtualenvs.py @@ -30,30 +30,28 @@ from st2common.util.virtualenvs import setup_pack_virtualenv -__all__ = [ - 'VirtualenvUtilsTestCase' -] +__all__ = ["VirtualenvUtilsTestCase"] # Note: We set base requirements to an empty list to speed up the tests -@mock.patch('st2common.util.virtualenvs.BASE_PACK_REQUIREMENTS', []) +@mock.patch("st2common.util.virtualenvs.BASE_PACK_REQUIREMENTS", []) class VirtualenvUtilsTestCase(CleanFilesTestCase): def setUp(self): super(VirtualenvUtilsTestCase, self).setUp() config.parse_args() dir_path = tempfile.mkdtemp() - cfg.CONF.set_override(name='base_path', override=dir_path, group='system') + cfg.CONF.set_override(name="base_path", override=dir_path, group="system") self.base_path = dir_path - self.virtualenvs_path = os.path.join(self.base_path, 'virtualenvs/') + self.virtualenvs_path = os.path.join(self.base_path, "virtualenvs/") # Make sure dir is deleted on tearDown self.to_delete_directories.append(self.base_path) def test_setup_pack_virtualenv_doesnt_exist_yet(self): # Test a fresh virtualenv creation - pack_name = 'dummy_pack_1' + pack_name = "dummy_pack_1" pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name) # Verify virtualenv directory doesn't exist @@ -61,58 +59,81 @@ def test_setup_pack_virtualenv_doesnt_exist_yet(self): # Create virtualenv # Note: This pack has no requirements - setup_pack_virtualenv(pack_name=pack_name, update=False, - include_pip=False, include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + include_pip=False, + include_setuptools=False, + include_wheel=False, + ) # Verify that virtualenv has been created self.assertVirtualenvExists(pack_virtualenv_dir) def test_setup_pack_virtualenv_already_exists(self): # Test a scenario where virtualenv already exists - pack_name = 'dummy_pack_1' + pack_name = "dummy_pack_1" pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name) # Verify virtualenv directory doesn't exist self.assertFalse(os.path.exists(pack_virtualenv_dir)) # Create virtualenv - setup_pack_virtualenv(pack_name=pack_name, update=False, - include_pip=False, include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + include_pip=False, + include_setuptools=False, + include_wheel=False, + ) # Verify that virtualenv has been created self.assertVirtualenvExists(pack_virtualenv_dir) # Re-create virtualenv - setup_pack_virtualenv(pack_name=pack_name, update=False, - include_pip=False, include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + include_pip=False, + include_setuptools=False, + include_wheel=False, + ) # Verify virtrualenv is still there self.assertVirtualenvExists(pack_virtualenv_dir) def test_setup_virtualenv_update(self): # Test a virtualenv update with pack which has requirements.txt - pack_name = 'dummy_pack_2' + pack_name = "dummy_pack_2" pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name) # Verify virtualenv directory doesn't exist self.assertFalse(os.path.exists(pack_virtualenv_dir)) # Create virtualenv - setup_pack_virtualenv(pack_name=pack_name, update=False, - include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + include_setuptools=False, + include_wheel=False, + ) # Verify that virtualenv has been created self.assertVirtualenvExists(pack_virtualenv_dir) # Update it - setup_pack_virtualenv(pack_name=pack_name, update=True, - include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=True, + include_setuptools=False, + include_wheel=False, + ) # Verify virtrualenv is still there self.assertVirtualenvExists(pack_virtualenv_dir) def test_setup_virtualenv_invalid_dependency_in_requirements_file(self): - pack_name = 'pack_invalid_requirements' + pack_name = "pack_invalid_requirements" pack_virtualenv_dir = os.path.join(self.virtualenvs_path, pack_name) # Verify virtualenv directory doesn't exist @@ -120,182 +141,240 @@ def test_setup_virtualenv_invalid_dependency_in_requirements_file(self): # Try to create virtualenv, assert that it fails try: - setup_pack_virtualenv(pack_name=pack_name, update=False, - include_setuptools=False, include_wheel=False) + setup_pack_virtualenv( + pack_name=pack_name, + update=False, + include_setuptools=False, + include_wheel=False, + ) except Exception as e: - self.assertIn('Failed to install requirements from', six.text_type(e)) - self.assertTrue('No matching distribution found for someinvalidname' in - six.text_type(e)) + self.assertIn("Failed to install requirements from", six.text_type(e)) + self.assertTrue( + "No matching distribution found for someinvalidname" in six.text_type(e) + ) else: - self.fail('Exception not thrown') - - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + self.fail("Exception not thrown") + + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirement_without_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirement = 'six>=1.9.0' + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirement = "six>=1.9.0" install_requirement(pack_virtualenv_dir, requirement, proxy_config=None) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - 'install', 'six>=1.9.0' + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "install", + "six>=1.9.0", ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirement_with_http_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirement = 'six>=1.9.0' - proxy_config = { - 'http_proxy': 'http://192.168.1.5:8080' - } + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirement = "six>=1.9.0" + proxy_config = {"http_proxy": "http://192.168.1.5:8080"} install_requirement(pack_virtualenv_dir, requirement, proxy_config=proxy_config) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'http://192.168.1.5:8080', - 'install', 'six>=1.9.0' + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "http://192.168.1.5:8080", + "install", + "six>=1.9.0", ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirement_with_https_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirement = 'six>=1.9.0' + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirement = "six>=1.9.0" proxy_config = { - 'https_proxy': 'https://192.168.1.5:8080', - 'proxy_ca_bundle_path': '/etc/ssl/certs/mitmproxy-ca.pem' + "https_proxy": "https://192.168.1.5:8080", + "proxy_ca_bundle_path": "/etc/ssl/certs/mitmproxy-ca.pem", } install_requirement(pack_virtualenv_dir, requirement, proxy_config=proxy_config) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'https://192.168.1.5:8080', - '--cert', '/etc/ssl/certs/mitmproxy-ca.pem', - 'install', 'six>=1.9.0' + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "https://192.168.1.5:8080", + "--cert", + "/etc/ssl/certs/mitmproxy-ca.pem", + "install", + "six>=1.9.0", ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirement_with_https_proxy_no_cert(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirement = 'six>=1.9.0' + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirement = "six>=1.9.0" proxy_config = { - 'https_proxy': 'https://192.168.1.5:8080', + "https_proxy": "https://192.168.1.5:8080", } install_requirement(pack_virtualenv_dir, requirement, proxy_config=proxy_config) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'https://192.168.1.5:8080', - 'install', 'six>=1.9.0' + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "https://192.168.1.5:8080", + "install", + "six>=1.9.0", ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirements_without_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt' - install_requirements(pack_virtualenv_dir, requirements_file_path, proxy_config=None) + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirements_file_path = ( + "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt" + ) + install_requirements( + pack_virtualenv_dir, requirements_file_path, proxy_config=None + ) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - 'install', '-U', - '-r', requirements_file_path + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "install", + "-U", + "-r", + requirements_file_path, ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirements_with_http_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt' - proxy_config = { - 'http_proxy': 'http://192.168.1.5:8080' - } - install_requirements(pack_virtualenv_dir, requirements_file_path, - proxy_config=proxy_config) + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirements_file_path = ( + "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt" + ) + proxy_config = {"http_proxy": "http://192.168.1.5:8080"} + install_requirements( + pack_virtualenv_dir, requirements_file_path, proxy_config=proxy_config + ) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'http://192.168.1.5:8080', - 'install', '-U', - '-r', requirements_file_path + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "http://192.168.1.5:8080", + "install", + "-U", + "-r", + requirements_file_path, ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirements_with_https_proxy(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt' + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirements_file_path = ( + "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt" + ) proxy_config = { - 'https_proxy': 'https://192.168.1.5:8080', - 'proxy_ca_bundle_path': '/etc/ssl/certs/mitmproxy-ca.pem' + "https_proxy": "https://192.168.1.5:8080", + "proxy_ca_bundle_path": "/etc/ssl/certs/mitmproxy-ca.pem", } - install_requirements(pack_virtualenv_dir, requirements_file_path, - proxy_config=proxy_config) + install_requirements( + pack_virtualenv_dir, requirements_file_path, proxy_config=proxy_config + ) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'https://192.168.1.5:8080', - '--cert', '/etc/ssl/certs/mitmproxy-ca.pem', - 'install', '-U', - '-r', requirements_file_path + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "https://192.168.1.5:8080", + "--cert", + "/etc/ssl/certs/mitmproxy-ca.pem", + "install", + "-U", + "-r", + requirements_file_path, ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) - @mock.patch.object(virtualenvs, 'run_command', mock.MagicMock(return_value=(0, '', ''))) - @mock.patch.object(virtualenvs, 'get_env_for_subprocess_command', - mock.MagicMock(return_value={})) + @mock.patch.object( + virtualenvs, "run_command", mock.MagicMock(return_value=(0, "", "")) + ) + @mock.patch.object( + virtualenvs, "get_env_for_subprocess_command", mock.MagicMock(return_value={}) + ) def test_install_requirements_with_https_proxy_no_cert(self): - pack_virtualenv_dir = '/opt/stackstorm/virtualenvs/dummy_pack_tests/' - requirements_file_path = '/opt/stackstorm/packs/dummy_pack_tests/requirements.txt' + pack_virtualenv_dir = "/opt/stackstorm/virtualenvs/dummy_pack_tests/" + requirements_file_path = ( + "/opt/stackstorm/packs/dummy_pack_tests/requirements.txt" + ) proxy_config = { - 'https_proxy': 'https://192.168.1.5:8080', + "https_proxy": "https://192.168.1.5:8080", } - install_requirements(pack_virtualenv_dir, requirements_file_path, - proxy_config=proxy_config) + install_requirements( + pack_virtualenv_dir, requirements_file_path, proxy_config=proxy_config + ) expected_args = { - 'cmd': [ - '/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip', - '--proxy', 'https://192.168.1.5:8080', - 'install', '-U', - '-r', requirements_file_path + "cmd": [ + "/opt/stackstorm/virtualenvs/dummy_pack_tests/bin/pip", + "--proxy", + "https://192.168.1.5:8080", + "install", + "-U", + "-r", + requirements_file_path, ], - 'env': {} + "env": {}, } virtualenvs.run_command.assert_called_once_with(**expected_args) def assertVirtualenvExists(self, virtualenv_dir): self.assertTrue(os.path.exists(virtualenv_dir)) self.assertTrue(os.path.isdir(virtualenv_dir)) - self.assertTrue(os.path.isdir(os.path.join(virtualenv_dir, 'bin/'))) + self.assertTrue(os.path.isdir(os.path.join(virtualenv_dir, "bin/"))) return True diff --git a/st2exporter/dist_utils.py b/st2exporter/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/st2exporter/dist_utils.py +++ b/st2exporter/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2exporter/setup.py b/st2exporter/setup.py index bfd01f7061f..afaae79cacc 100644 --- a/st2exporter/setup.py +++ b/st2exporter/setup.py @@ -22,9 +22,9 @@ from dist_utils import apply_vagrant_workaround from st2exporter import __version__ -ST2_COMPONENT = 'st2exporter' +ST2_COMPONENT = "st2exporter" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -32,18 +32,18 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - scripts=[ - 'bin/st2exporter' - ] + packages=find_packages(exclude=["setuptools", "tests"]), + scripts=["bin/st2exporter"], ) diff --git a/st2exporter/st2exporter/cmd/st2exporter_starter.py b/st2exporter/st2exporter/cmd/st2exporter_starter.py index c5ce157e246..2b86ef2707e 100644 --- a/st2exporter/st2exporter/cmd/st2exporter_starter.py +++ b/st2exporter/st2exporter/cmd/st2exporter_starter.py @@ -14,6 +14,7 @@ # limitations under the License. from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -25,26 +26,29 @@ from st2exporter import config from st2exporter import worker -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _setup(): - common_setup(service='exporter', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True) + common_setup( + service="exporter", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + ) def _run_worker(): - LOG.info('(PID=%s) Exporter started.', os.getpid()) + LOG.info("(PID=%s) Exporter started.", os.getpid()) export_worker = worker.get_worker() try: export_worker.start(wait=True) except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) Exporter stopped.', os.getpid()) + LOG.info("(PID=%s) Exporter stopped.", os.getpid()) export_worker.shutdown() except: return 1 @@ -62,7 +66,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except: - LOG.exception('(PID=%s) Exporter quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) Exporter quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2exporter/st2exporter/config.py b/st2exporter/st2exporter/config.py index 456b09e3653..83f4f45d5df 100644 --- a/st2exporter/st2exporter/config.py +++ b/st2exporter/st2exporter/config.py @@ -31,8 +31,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def get_logging_config_path(): @@ -51,16 +54,20 @@ def _register_common_opts(): def _register_app_opts(): dump_opts = [ cfg.StrOpt( - 'dump_dir', default='/opt/stackstorm/exports/', - help='Directory to dump data to.') + "dump_dir", + default="/opt/stackstorm/exports/", + help="Directory to dump data to.", + ) ] - CONF.register_opts(dump_opts, group='exporter') + CONF.register_opts(dump_opts, group="exporter") logging_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.exporter.conf', - help='location of the logging.exporter.conf file') + "logging", + default="/etc/st2/logging.exporter.conf", + help="location of the logging.exporter.conf file", + ) ] - CONF.register_opts(logging_opts, group='exporter') + CONF.register_opts(logging_opts, group="exporter") diff --git a/st2exporter/st2exporter/exporter/dumper.py b/st2exporter/st2exporter/exporter/dumper.py index 20595574207..12fbeb4f83b 100644 --- a/st2exporter/st2exporter/exporter/dumper.py +++ b/st2exporter/st2exporter/exporter/dumper.py @@ -26,40 +26,43 @@ from st2common.util import date as date_utils from st2common.util import isotime -__all__ = [ - 'Dumper' -] +__all__ = ["Dumper"] -ALLOWED_EXTENSIONS = ['json'] +ALLOWED_EXTENSIONS = ["json"] -CONVERTERS = { - 'json': JsonConverter -} +CONVERTERS = {"json": JsonConverter} LOG = logging.getLogger(__name__) class Dumper(object): - - def __init__(self, queue, export_dir, file_format='json', - file_prefix='st2-executions-', - batch_size=1000, sleep_interval=60, - max_files_per_sleep=5, - file_writer=None): + def __init__( + self, + queue, + export_dir, + file_format="json", + file_prefix="st2-executions-", + batch_size=1000, + sleep_interval=60, + max_files_per_sleep=5, + file_writer=None, + ): if not queue: - raise Exception('Need a queue to consume data from.') + raise Exception("Need a queue to consume data from.") if not export_dir: - raise Exception('Export dir needed to dump files to.') + raise Exception("Export dir needed to dump files to.") self._export_dir = export_dir if not os.path.exists(self._export_dir): - raise Exception('Dir path %s does not exist. Create one before using exporter.' % - self._export_dir) + raise Exception( + "Dir path %s does not exist. Create one before using exporter." + % self._export_dir + ) self._file_format = file_format.lower() if self._file_format not in ALLOWED_EXTENSIONS: - raise ValueError('Disallowed extension %s.' % file_format) + raise ValueError("Disallowed extension %s." % file_format) self._file_prefix = file_prefix self._batch_size = batch_size @@ -99,8 +102,8 @@ def _get_batch(self): else: executions_to_write.append(item) - LOG.debug('Returning %d items in batch.', len(executions_to_write)) - LOG.debug('Remaining items in queue: %d', self._queue.qsize()) + LOG.debug("Returning %d items in batch.", len(executions_to_write)) + LOG.debug("Remaining items in queue: %d", self._queue.qsize()) return executions_to_write def _flush(self): @@ -111,7 +114,7 @@ def _flush(self): try: self._write_to_disk() except: - LOG.error('Failed writing data to disk.') + LOG.error("Failed writing data to disk.") def _write_to_disk(self): count = 0 @@ -128,7 +131,7 @@ def _write_to_disk(self): self._update_marker(batch) count += 1 except: - LOG.exception('Writing batch to disk failed.') + LOG.exception("Writing batch to disk failed.") return count def _create_date_folder(self): @@ -139,7 +142,7 @@ def _create_date_folder(self): try: os.makedirs(folder_path) except: - LOG.exception('Unable to create sub-folder %s for export.', folder_name) + LOG.exception("Unable to create sub-folder %s for export.", folder_name) raise def _write_batch_to_disk(self, batch): @@ -147,42 +150,44 @@ def _write_batch_to_disk(self, batch): self._file_writer.write_text(doc_to_write, self._get_file_name()) def _get_file_name(self): - timestring = date_utils.get_datetime_utc_now().strftime('%Y-%m-%dT%H:%M:%S.%fZ') - file_name = self._file_prefix + timestring + '.' + self._file_format + timestring = date_utils.get_datetime_utc_now().strftime("%Y-%m-%dT%H:%M:%S.%fZ") + file_name = self._file_prefix + timestring + "." + self._file_format file_name = os.path.join(self._export_dir, self._get_date_folder(), file_name) return file_name def _get_date_folder(self): - return date_utils.get_datetime_utc_now().strftime('%Y-%m-%d') + return date_utils.get_datetime_utc_now().strftime("%Y-%m-%d") def _update_marker(self, batch): timestamps = [isotime.parse(item.end_timestamp) for item in batch] new_marker = max(timestamps) if self._persisted_marker and self._persisted_marker > new_marker: - LOG.warn('Older executions are being exported. Perhaps out of order messages.') + LOG.warn( + "Older executions are being exported. Perhaps out of order messages." + ) try: self._write_marker_to_db(new_marker) except: - LOG.exception('Failed persisting dumper marker to db.') + LOG.exception("Failed persisting dumper marker to db.") else: self._persisted_marker = new_marker return self._persisted_marker def _write_marker_to_db(self, new_marker): - LOG.info('Updating marker in db to: %s', new_marker) + LOG.info("Updating marker in db to: %s", new_marker) markers = DumperMarker.get_all() if len(markers) > 1: - LOG.exception('More than one dumper marker found. Using first found one.') + LOG.exception("More than one dumper marker found. Using first found one.") marker = isotime.format(new_marker, offset=False) updated_at = date_utils.get_datetime_utc_now() if markers: - marker_id = markers[0]['id'] + marker_id = markers[0]["id"] else: marker_id = None diff --git a/st2exporter/st2exporter/exporter/file_writer.py b/st2exporter/st2exporter/exporter/file_writer.py index ec7e4d876c3..49b5b4d63a6 100644 --- a/st2exporter/st2exporter/exporter/file_writer.py +++ b/st2exporter/st2exporter/exporter/file_writer.py @@ -18,15 +18,11 @@ import abc import six -__all__ = [ - 'FileWriter', - 'TextFileWriter' -] +__all__ = ["FileWriter", "TextFileWriter"] @six.add_metaclass(abc.ABCMeta) class FileWriter(object): - @abc.abstractmethod def write(self, data, file_path, replace=False): """ @@ -40,13 +36,13 @@ class TextFileWriter(FileWriter): def write_text(self, text_data, file_path, replace=False, compressed=False): if compressed: - return Exception('Compression not supported.') + return Exception("Compression not supported.") self.write(text_data, file_path, replace=replace) def write(self, data, file_path, replace=False): if os.path.exists(file_path) and not replace: - raise Exception('File %s already exists.' % file_path) + raise Exception("File %s already exists." % file_path) - with open(file_path, 'w') as f: + with open(file_path, "w") as f: f.write(data) diff --git a/st2exporter/st2exporter/exporter/json_converter.py b/st2exporter/st2exporter/exporter/json_converter.py index a288197d417..ba7e95c0a5c 100644 --- a/st2exporter/st2exporter/exporter/json_converter.py +++ b/st2exporter/st2exporter/exporter/json_converter.py @@ -15,15 +15,12 @@ from st2common.util.jsonify import json_encode -__all__ = [ - 'JsonConverter' -] +__all__ = ["JsonConverter"] class JsonConverter(object): - def convert(self, items_list): if not isinstance(items_list, list): - raise ValueError('Items to be converted should be a list.') + raise ValueError("Items to be converted should be a list.") json_doc = json_encode(items_list) return json_doc diff --git a/st2exporter/st2exporter/worker.py b/st2exporter/st2exporter/worker.py index 13273fd587a..a5557ee41fd 100644 --- a/st2exporter/st2exporter/worker.py +++ b/st2exporter/st2exporter/worker.py @@ -18,8 +18,11 @@ from oslo_config import cfg from st2common import log as logging -from st2common.constants.action import (LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED, - LIVEACTION_STATUS_CANCELED) +from st2common.constants.action import ( + LIVEACTION_STATUS_SUCCEEDED, + LIVEACTION_STATUS_FAILED, + LIVEACTION_STATUS_CANCELED, +) from st2common.models.api.execution import ActionExecutionAPI from st2common.models.db.execution import ActionExecutionDB from st2common.persistence.execution import ActionExecution @@ -30,13 +33,13 @@ from st2exporter.exporter.dumper import Dumper from st2common.transport.queues import EXPORTER_WORK_QUEUE -__all__ = [ - 'ExecutionsExporter', - 'get_worker' -] +__all__ = ["ExecutionsExporter", "get_worker"] -COMPLETION_STATUSES = [LIVEACTION_STATUS_SUCCEEDED, LIVEACTION_STATUS_FAILED, - LIVEACTION_STATUS_CANCELED] +COMPLETION_STATUSES = [ + LIVEACTION_STATUS_SUCCEEDED, + LIVEACTION_STATUS_FAILED, + LIVEACTION_STATUS_CANCELED, +] LOG = logging.getLogger(__name__) @@ -46,18 +49,21 @@ class ExecutionsExporter(consumers.MessageHandler): def __init__(self, connection, queues): super(ExecutionsExporter, self).__init__(connection, queues) self.pending_executions = queue.Queue() - self._dumper = Dumper(queue=self.pending_executions, - export_dir=cfg.CONF.exporter.dump_dir) + self._dumper = Dumper( + queue=self.pending_executions, export_dir=cfg.CONF.exporter.dump_dir + ) self._consumer_thread = None def start(self, wait=False): - LOG.info('Bootstrapping executions from db...') + LOG.info("Bootstrapping executions from db...") try: self._bootstrap() except: - LOG.exception('Unable to bootstrap executions from db. Aborting.') + LOG.exception("Unable to bootstrap executions from db. Aborting.") raise - self._consumer_thread = eventlet.spawn(super(ExecutionsExporter, self).start, wait=True) + self._consumer_thread = eventlet.spawn( + super(ExecutionsExporter, self).start, wait=True + ) self._dumper.start() if wait: self.wait() @@ -71,7 +77,7 @@ def shutdown(self): super(ExecutionsExporter, self).shutdown() def process(self, execution): - LOG.debug('Got execution from queue: %s', execution) + LOG.debug("Got execution from queue: %s", execution) if execution.status not in COMPLETION_STATUSES: return execution_api = ActionExecutionAPI.from_model(execution, mask_secrets=True) @@ -80,21 +86,23 @@ def process(self, execution): def _bootstrap(self): marker = self._get_export_marker_from_db() - LOG.info('Using marker %s...' % marker) + LOG.info("Using marker %s..." % marker) missed_executions = self._get_missed_executions_from_db(export_marker=marker) - LOG.info('Found %d executions not exported yet...', len(missed_executions)) + LOG.info("Found %d executions not exported yet...", len(missed_executions)) for missed_execution in missed_executions: if missed_execution.status not in COMPLETION_STATUSES: continue - execution_api = ActionExecutionAPI.from_model(missed_execution, mask_secrets=True) + execution_api = ActionExecutionAPI.from_model( + missed_execution, mask_secrets=True + ) try: - LOG.debug('Missed execution %s', execution_api) + LOG.debug("Missed execution %s", execution_api) self.pending_executions.put_nowait(execution_api) except: - LOG.exception('Failed adding execution to in-memory queue.') + LOG.exception("Failed adding execution to in-memory queue.") continue - LOG.info('Bootstrapped executions...') + LOG.info("Bootstrapped executions...") def _get_export_marker_from_db(self): try: @@ -114,8 +122,8 @@ def _get_missed_executions_from_db(self, export_marker=None): # XXX: Should adapt this query to get only executions with status # in COMPLETION_STATUSES. - filters = {'end_timestamp__gt': export_marker} - LOG.info('Querying for executions with filters: %s', filters) + filters = {"end_timestamp__gt": export_marker} + LOG.info("Querying for executions with filters: %s", filters) return ActionExecution.query(**filters) def _get_all_executions_from_db(self): diff --git a/st2exporter/tests/integration/test_dumper_integration.py b/st2exporter/tests/integration/test_dumper_integration.py index bdb87b12495..0de7b91ed02 100644 --- a/st2exporter/tests/integration/test_dumper_integration.py +++ b/st2exporter/tests/integration/test_dumper_integration.py @@ -28,21 +28,30 @@ from st2tests.base import DbTestCase from st2tests.fixturesloader import FixturesLoader -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } class TestDumper(DbTestCase): fixtures_loader = FixturesLoader() - loaded_fixtures = fixtures_loader.load_fixtures(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) - loaded_executions = loaded_fixtures['executions'] + loaded_fixtures = fixtures_loader.load_fixtures( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) + loaded_executions = loaded_fixtures["executions"] execution_apis = [] for execution in loaded_executions.values(): execution_apis.append(ActionExecutionAPI(**execution)) @@ -54,31 +63,45 @@ def get_queue(self): executions_queue.put(execution) return executions_queue - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_write_marker_to_db(self): executions_queue = self.get_queue() - dumper = Dumper(queue=executions_queue, - export_dir='/tmp', batch_size=5, - max_files_per_sleep=1, - file_prefix='st2-stuff-', file_format='json') - timestamps = [isotime.parse(execution.end_timestamp) for execution in self.execution_apis] + dumper = Dumper( + queue=executions_queue, + export_dir="/tmp", + batch_size=5, + max_files_per_sleep=1, + file_prefix="st2-stuff-", + file_format="json", + ) + timestamps = [ + isotime.parse(execution.end_timestamp) for execution in self.execution_apis + ] max_timestamp = max(timestamps) marker_db = dumper._write_marker_to_db(max_timestamp) persisted_marker = marker_db.marker self.assertIsInstance(persisted_marker, six.string_types) self.assertEqual(isotime.parse(persisted_marker), max_timestamp) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_write_marker_to_db_marker_exists(self): executions_queue = self.get_queue() - dumper = Dumper(queue=executions_queue, - export_dir='/tmp', batch_size=5, - max_files_per_sleep=1, - file_prefix='st2-stuff-', file_format='json') - timestamps = [isotime.parse(execution.end_timestamp) for execution in self.execution_apis] + dumper = Dumper( + queue=executions_queue, + export_dir="/tmp", + batch_size=5, + max_files_per_sleep=1, + file_prefix="st2-stuff-", + file_format="json", + ) + timestamps = [ + isotime.parse(execution.end_timestamp) for execution in self.execution_apis + ] max_timestamp = max(timestamps) first_marker_db = dumper._write_marker_to_db(max_timestamp) - second_marker_db = dumper._write_marker_to_db(max_timestamp + datetime.timedelta(hours=1)) + second_marker_db = dumper._write_marker_to_db( + max_timestamp + datetime.timedelta(hours=1) + ) markers = DumperMarker.get_all() self.assertEqual(len(markers), 1) final_marker_id = markers[0].id diff --git a/st2exporter/tests/integration/test_export_worker.py b/st2exporter/tests/integration/test_export_worker.py index 8b0caf7d863..9237aab0e83 100644 --- a/st2exporter/tests/integration/test_export_worker.py +++ b/st2exporter/tests/integration/test_export_worker.py @@ -27,75 +27,92 @@ from st2tests.base import DbTestCase from st2tests.fixturesloader import FixturesLoader import st2tests.config as tests_config + tests_config.parse_args() -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } class TestExportWorker(DbTestCase): - @classmethod def setUpClass(cls): super(TestExportWorker, cls).setUpClass() fixtures_loader = FixturesLoader() - loaded_fixtures = fixtures_loader.save_fixtures_to_db(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) - TestExportWorker.saved_executions = loaded_fixtures['executions'] + loaded_fixtures = fixtures_loader.save_fixtures_to_db( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) + TestExportWorker.saved_executions = loaded_fixtures["executions"] - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_marker_from_db(self): marker_dt = date_utils.get_datetime_utc_now() - datetime.timedelta(minutes=5) - marker_db = DumperMarkerDB(marker=isotime.format(marker_dt, offset=False), - updated_at=date_utils.get_datetime_utc_now()) + marker_db = DumperMarkerDB( + marker=isotime.format(marker_dt, offset=False), + updated_at=date_utils.get_datetime_utc_now(), + ) DumperMarker.add_or_update(marker_db) exec_exporter = ExecutionsExporter(None, None) export_marker = exec_exporter._get_export_marker_from_db() self.assertEqual(export_marker, date_utils.add_utc_tz(marker_dt)) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_missed_executions_from_db_no_marker(self): exec_exporter = ExecutionsExporter(None, None) all_execs = exec_exporter._get_missed_executions_from_db(export_marker=None) self.assertEqual(len(all_execs), len(self.saved_executions.values())) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_missed_executions_from_db_with_marker(self): exec_exporter = ExecutionsExporter(None, None) all_execs = exec_exporter._get_missed_executions_from_db(export_marker=None) min_timestamp = min([item.end_timestamp for item in all_execs]) marker = min_timestamp + datetime.timedelta(seconds=1) - execs_greater_than_marker = [item for item in all_execs if item.end_timestamp > marker] + execs_greater_than_marker = [ + item for item in all_execs if item.end_timestamp > marker + ] all_execs = exec_exporter._get_missed_executions_from_db(export_marker=marker) self.assertTrue(len(all_execs) > 0) self.assertTrue(len(all_execs) == len(execs_greater_than_marker)) for item in all_execs: self.assertTrue(item.end_timestamp > marker) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_bootstrap(self): exec_exporter = ExecutionsExporter(None, None) exec_exporter._bootstrap() - self.assertEqual(exec_exporter.pending_executions.qsize(), len(self.saved_executions)) + self.assertEqual( + exec_exporter.pending_executions.qsize(), len(self.saved_executions) + ) count = 0 while count < exec_exporter.pending_executions.qsize(): - self.assertIsInstance(exec_exporter.pending_executions.get(), ActionExecutionAPI) + self.assertIsInstance( + exec_exporter.pending_executions.get(), ActionExecutionAPI + ) count += 1 - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_process(self): some_execution = list(self.saved_executions.values())[5] exec_exporter = ExecutionsExporter(None, None) self.assertEqual(exec_exporter.pending_executions.qsize(), 0) exec_exporter.process(some_execution) self.assertEqual(exec_exporter.pending_executions.qsize(), 1) - some_execution.status = 'scheduled' + some_execution.status = "scheduled" exec_exporter.process(some_execution) self.assertEqual(exec_exporter.pending_executions.qsize(), 1) diff --git a/st2exporter/tests/unit/test_dumper.py b/st2exporter/tests/unit/test_dumper.py index 98e42e60f14..0ddec72e3b9 100644 --- a/st2exporter/tests/unit/test_dumper.py +++ b/st2exporter/tests/unit/test_dumper.py @@ -28,21 +28,30 @@ from st2tests.fixturesloader import FixturesLoader from st2common.util import date as date_utils -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } class TestDumper(EventletTestCase): fixtures_loader = FixturesLoader() - loaded_fixtures = fixtures_loader.load_fixtures(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) - loaded_executions = loaded_fixtures['executions'] + loaded_fixtures = fixtures_loader.load_fixtures( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) + loaded_executions = loaded_fixtures["executions"] execution_apis = [] for execution in loaded_executions.values(): execution_apis.append(ActionExecutionAPI(**execution)) @@ -54,81 +63,101 @@ def get_queue(self): executions_queue.put(execution) return executions_queue - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_batch_batch_size_greater_than_actual(self): executions_queue = self.get_queue() qsize = executions_queue.qsize() self.assertTrue(qsize > 0) - dumper = Dumper(queue=executions_queue, batch_size=2 * qsize, - export_dir='/tmp') + dumper = Dumper(queue=executions_queue, batch_size=2 * qsize, export_dir="/tmp") batch = dumper._get_batch() self.assertEqual(len(batch), qsize) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_batch_batch_size_lesser_than_actual(self): executions_queue = self.get_queue() qsize = executions_queue.qsize() self.assertTrue(qsize > 0) expected_batch_size = int(qsize / 2) - dumper = Dumper(queue=executions_queue, - batch_size=expected_batch_size, - export_dir='/tmp') + dumper = Dumper( + queue=executions_queue, batch_size=expected_batch_size, export_dir="/tmp" + ) batch = dumper._get_batch() self.assertEqual(len(batch), expected_batch_size) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_get_file_name(self): - dumper = Dumper(queue=self.get_queue(), - export_dir='/tmp', - file_prefix='st2-stuff-', file_format='json') + dumper = Dumper( + queue=self.get_queue(), + export_dir="/tmp", + file_prefix="st2-stuff-", + file_format="json", + ) file_name = dumper._get_file_name() - export_date = date_utils.get_datetime_utc_now().strftime('%Y-%m-%d') - self.assertTrue(file_name.startswith('/tmp/' + export_date + '/st2-stuff-')) - self.assertTrue(file_name.endswith('json')) + export_date = date_utils.get_datetime_utc_now().strftime("%Y-%m-%d") + self.assertTrue(file_name.startswith("/tmp/" + export_date + "/st2-stuff-")) + self.assertTrue(file_name.endswith("json")) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_write_to_disk_empty_queue(self): - dumper = Dumper(queue=queue.Queue(), - export_dir='/tmp', - file_prefix='st2-stuff-', file_format='json') + dumper = Dumper( + queue=queue.Queue(), + export_dir="/tmp", + file_prefix="st2-stuff-", + file_format="json", + ) # We just make sure this doesn't blow up. ret = dumper._write_to_disk() self.assertEqual(ret, 0) - @mock.patch.object(TextFileWriter, 'write_text', mock.MagicMock(return_value=True)) - @mock.patch.object(Dumper, '_update_marker', mock.MagicMock(return_value=None)) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) + @mock.patch.object(TextFileWriter, "write_text", mock.MagicMock(return_value=True)) + @mock.patch.object(Dumper, "_update_marker", mock.MagicMock(return_value=None)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) def test_write_to_disk(self): executions_queue = self.get_queue() max_files_per_sleep = 5 - dumper = Dumper(queue=executions_queue, - export_dir='/tmp', batch_size=1, max_files_per_sleep=max_files_per_sleep, - file_prefix='st2-stuff-', file_format='json') + dumper = Dumper( + queue=executions_queue, + export_dir="/tmp", + batch_size=1, + max_files_per_sleep=max_files_per_sleep, + file_prefix="st2-stuff-", + file_format="json", + ) # We just make sure this doesn't blow up. ret = dumper._write_to_disk() self.assertEqual(ret, max_files_per_sleep) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) - @mock.patch.object(TextFileWriter, 'write_text', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) + @mock.patch.object(TextFileWriter, "write_text", mock.MagicMock(return_value=True)) def test_start_stop_dumper(self): executions_queue = self.get_queue() sleep_interval = 0.01 - dumper = Dumper(queue=executions_queue, sleep_interval=sleep_interval, - export_dir='/tmp', batch_size=1, max_files_per_sleep=5, - file_prefix='st2-stuff-', file_format='json') + dumper = Dumper( + queue=executions_queue, + sleep_interval=sleep_interval, + export_dir="/tmp", + batch_size=1, + max_files_per_sleep=5, + file_prefix="st2-stuff-", + file_format="json", + ) dumper.start() # Call stop after at least one batch was written to disk. eventlet.sleep(10 * sleep_interval) dumper.stop() - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) - @mock.patch.object(Dumper, '_write_marker_to_db', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) + @mock.patch.object(Dumper, "_write_marker_to_db", mock.MagicMock(return_value=True)) def test_update_marker(self): executions_queue = self.get_queue() - dumper = Dumper(queue=executions_queue, - export_dir='/tmp', batch_size=5, - max_files_per_sleep=1, - file_prefix='st2-stuff-', file_format='json') + dumper = Dumper( + queue=executions_queue, + export_dir="/tmp", + batch_size=5, + max_files_per_sleep=1, + file_prefix="st2-stuff-", + file_format="json", + ) # Batch 1 batch = self.execution_apis[0:5] new_marker = dumper._update_marker(batch) @@ -145,15 +174,21 @@ def test_update_marker(self): self.assertEqual(new_marker, max_timestamp) dumper._write_marker_to_db.assert_called_with(new_marker) - @mock.patch.object(os.path, 'exists', mock.MagicMock(return_value=True)) - @mock.patch.object(Dumper, '_write_marker_to_db', mock.MagicMock(return_value=True)) + @mock.patch.object(os.path, "exists", mock.MagicMock(return_value=True)) + @mock.patch.object(Dumper, "_write_marker_to_db", mock.MagicMock(return_value=True)) def test_update_marker_out_of_order_batch(self): executions_queue = self.get_queue() - dumper = Dumper(queue=executions_queue, - export_dir='/tmp', batch_size=5, - max_files_per_sleep=1, - file_prefix='st2-stuff-', file_format='json') - timestamps = [isotime.parse(execution.end_timestamp) for execution in self.execution_apis] + dumper = Dumper( + queue=executions_queue, + export_dir="/tmp", + batch_size=5, + max_files_per_sleep=1, + file_prefix="st2-stuff-", + file_format="json", + ) + timestamps = [ + isotime.parse(execution.end_timestamp) for execution in self.execution_apis + ] max_timestamp = max(timestamps) # set dumper persisted timestamp to something less than min timestamp in the batch diff --git a/st2exporter/tests/unit/test_json_converter.py b/st2exporter/tests/unit/test_json_converter.py index ce2f484bca8..07f82a8bf05 100644 --- a/st2exporter/tests/unit/test_json_converter.py +++ b/st2exporter/tests/unit/test_json_converter.py @@ -20,34 +20,43 @@ from st2tests.fixturesloader import FixturesLoader from st2exporter.exporter.json_converter import JsonConverter -DESCENDANTS_PACK = 'descendants' +DESCENDANTS_PACK = "descendants" DESCENDANTS_FIXTURES = { - 'executions': ['root_execution.yaml', 'child1_level1.yaml', 'child2_level1.yaml', - 'child1_level2.yaml', 'child2_level2.yaml', 'child3_level2.yaml', - 'child1_level3.yaml', 'child2_level3.yaml', 'child3_level3.yaml'] + "executions": [ + "root_execution.yaml", + "child1_level1.yaml", + "child2_level1.yaml", + "child1_level2.yaml", + "child2_level2.yaml", + "child3_level2.yaml", + "child1_level3.yaml", + "child2_level3.yaml", + "child3_level3.yaml", + ] } class TestJsonConverter(unittest2.TestCase): fixtures_loader = FixturesLoader() - loaded_fixtures = fixtures_loader.load_fixtures(fixtures_pack=DESCENDANTS_PACK, - fixtures_dict=DESCENDANTS_FIXTURES) + loaded_fixtures = fixtures_loader.load_fixtures( + fixtures_pack=DESCENDANTS_PACK, fixtures_dict=DESCENDANTS_FIXTURES + ) def test_convert(self): - executions_list = list(self.loaded_fixtures['executions'].values()) + executions_list = list(self.loaded_fixtures["executions"].values()) converter = JsonConverter() converted_doc = converter.convert(executions_list) - self.assertTrue(type(converted_doc), 'string') + self.assertTrue(type(converted_doc), "string") reversed_doc = json.loads(converted_doc) self.assertListEqual(executions_list, reversed_doc) def test_convert_non_list(self): - executions_dict = self.loaded_fixtures['executions'] + executions_dict = self.loaded_fixtures["executions"] converter = JsonConverter() try: converter.convert(executions_dict) - self.fail('Should have thrown exception.') + self.fail("Should have thrown exception.") except ValueError: pass diff --git a/st2reactor/dist_utils.py b/st2reactor/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/st2reactor/dist_utils.py +++ b/st2reactor/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2reactor/setup.py b/st2reactor/setup.py index 0379240b8fd..adb3e7accc6 100644 --- a/st2reactor/setup.py +++ b/st2reactor/setup.py @@ -23,9 +23,9 @@ from dist_utils import apply_vagrant_workaround from st2reactor import __version__ -ST2_COMPONENT = 'st2reactor' +ST2_COMPONENT = "st2reactor" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -33,23 +33,25 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), + packages=find_packages(exclude=["setuptools", "tests"]), scripts=[ - 'bin/st2-rule-tester', - 'bin/st2-trigger-refire', - 'bin/st2rulesengine', - 'bin/st2sensorcontainer', - 'bin/st2garbagecollector', - 'bin/st2timersengine', - ] + "bin/st2-rule-tester", + "bin/st2-trigger-refire", + "bin/st2rulesengine", + "bin/st2sensorcontainer", + "bin/st2garbagecollector", + "bin/st2timersengine", + ], ) diff --git a/st2reactor/st2reactor/__init__.py b/st2reactor/st2reactor/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/st2reactor/st2reactor/__init__.py +++ b/st2reactor/st2reactor/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2reactor/st2reactor/cmd/garbagecollector.py b/st2reactor/st2reactor/cmd/garbagecollector.py index ab3c64409b9..b4be4dfa8bd 100644 --- a/st2reactor/st2reactor/cmd/garbagecollector.py +++ b/st2reactor/st2reactor/cmd/garbagecollector.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -31,9 +32,7 @@ from st2reactor.garbage_collector import config from st2reactor.garbage_collector.base import GarbageCollectorService -__all__ = [ - 'main' -] +__all__ = ["main"] LOGGER_NAME = get_logger_name_for_module(sys.modules[__name__]) @@ -41,14 +40,17 @@ def _setup(): - capabilities = { - 'name': 'garbagecollector', - 'type': 'passive' - } - common_setup(service='garbagecollector', config=config, setup_db=True, - register_mq_exchanges=True, register_signal_handlers=True, - register_runners=False, service_registry=True, - capabilities=capabilities) + capabilities = {"name": "garbagecollector", "type": "passive"} + common_setup( + service="garbagecollector", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_runners=False, + service_registry=True, + capabilities=capabilities, + ) def _teardown(): @@ -61,13 +63,14 @@ def main(): collection_interval = cfg.CONF.garbagecollector.collection_interval sleep_delay = cfg.CONF.garbagecollector.sleep_delay - garbage_collector = GarbageCollectorService(collection_interval=collection_interval, - sleep_delay=sleep_delay) + garbage_collector = GarbageCollectorService( + collection_interval=collection_interval, sleep_delay=sleep_delay + ) exit_code = garbage_collector.run() except SystemExit as exit_code: return exit_code except: - LOG.exception('(PID:%s) GarbageCollector quit due to exception.', os.getpid()) + LOG.exception("(PID:%s) GarbageCollector quit due to exception.", os.getpid()) return FAILURE_EXIT_CODE finally: _teardown() diff --git a/st2reactor/st2reactor/cmd/rule_tester.py b/st2reactor/st2reactor/cmd/rule_tester.py index 926a27a4ff9..b346168cb55 100644 --- a/st2reactor/st2reactor/cmd/rule_tester.py +++ b/st2reactor/st2reactor/cmd/rule_tester.py @@ -25,23 +25,27 @@ from st2common.script_setup import teardown as common_teardown from st2reactor.rules.tester import RuleTester -__all__ = [ - 'main' -] +__all__ = ["main"] LOG = logging.getLogger(__name__) def _register_cli_opts(): cli_opts = [ - cfg.StrOpt('rule', default=None, - help='Path to the file containing rule definition.'), - cfg.StrOpt('rule-ref', default=None, - help='Ref of the rule.'), - cfg.StrOpt('trigger-instance', default=None, - help='Path to the file containing trigger instance definition'), - cfg.StrOpt('trigger-instance-id', default=None, - help='Id of the Trigger Instance to use for validation.') + cfg.StrOpt( + "rule", default=None, help="Path to the file containing rule definition." + ), + cfg.StrOpt("rule-ref", default=None, help="Ref of the rule."), + cfg.StrOpt( + "trigger-instance", + default=None, + help="Path to the file containing trigger instance definition", + ), + cfg.StrOpt( + "trigger-instance-id", + default=None, + help="Id of the Trigger Instance to use for validation.", + ), ] do_register_cli_opts(cli_opts) @@ -51,17 +55,19 @@ def main(): common_setup(config=config, setup_db=True, register_mq_exchanges=False) try: - tester = RuleTester(rule_file_path=cfg.CONF.rule, - rule_ref=cfg.CONF.rule_ref, - trigger_instance_file_path=cfg.CONF.trigger_instance, - trigger_instance_id=cfg.CONF.trigger_instance_id) + tester = RuleTester( + rule_file_path=cfg.CONF.rule, + rule_ref=cfg.CONF.rule_ref, + trigger_instance_file_path=cfg.CONF.trigger_instance, + trigger_instance_id=cfg.CONF.trigger_instance_id, + ) matches = tester.evaluate() finally: common_teardown() if matches: - LOG.info('=== RULE MATCHES ===') + LOG.info("=== RULE MATCHES ===") sys.exit(0) else: - LOG.info('=== RULE DOES NOT MATCH ===') + LOG.info("=== RULE DOES NOT MATCH ===") sys.exit(1) diff --git a/st2reactor/st2reactor/cmd/rulesengine.py b/st2reactor/st2reactor/cmd/rulesengine.py index f372cc252ef..895fbe42d9b 100644 --- a/st2reactor/st2reactor/cmd/rulesengine.py +++ b/st2reactor/st2reactor/cmd/rulesengine.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -34,13 +35,18 @@ def _setup(): - capabilities = { - 'name': 'rulesengine', - 'type': 'passive' - } - common_setup(service='rulesengine', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, register_internal_trigger_types=True, - register_runners=False, service_registry=True, capabilities=capabilities) + capabilities = {"name": "rulesengine", "type": "passive"} + common_setup( + service="rulesengine", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=True, + register_runners=False, + service_registry=True, + capabilities=capabilities, + ) def _teardown(): @@ -48,7 +54,7 @@ def _teardown(): def _run_worker(): - LOG.info('(PID=%s) RulesEngine started.', os.getpid()) + LOG.info("(PID=%s) RulesEngine started.", os.getpid()) rules_engine_worker = worker.get_worker() @@ -56,10 +62,10 @@ def _run_worker(): rules_engine_worker.start() return rules_engine_worker.wait() except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) RulesEngine stopped.', os.getpid()) + LOG.info("(PID=%s) RulesEngine stopped.", os.getpid()) rules_engine_worker.shutdown() except: - LOG.exception('(PID:%s) RulesEngine quit due to exception.', os.getpid()) + LOG.exception("(PID:%s) RulesEngine quit due to exception.", os.getpid()) return 1 return 0 @@ -72,7 +78,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except: - LOG.exception('(PID=%s) RulesEngine quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) RulesEngine quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2reactor/st2reactor/cmd/sensormanager.py b/st2reactor/st2reactor/cmd/sensormanager.py index df2be8e7ac7..f3d27afb5b8 100644 --- a/st2reactor/st2reactor/cmd/sensormanager.py +++ b/st2reactor/st2reactor/cmd/sensormanager.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -33,9 +34,7 @@ from st2reactor.container.manager import SensorContainerManager from st2reactor.container.partitioner_lookup import get_sensors_partitioner -__all__ = [ - 'main' -] +__all__ = ["main"] LOGGER_NAME = get_logger_name_for_module(sys.modules[__name__]) @@ -43,13 +42,17 @@ def _setup(): - capabilities = { - 'name': 'sensorcontainer', - 'type': 'passive' - } - common_setup(service='sensorcontainer', config=config, setup_db=True, - register_mq_exchanges=True, register_signal_handlers=True, - register_runners=False, service_registry=True, capabilities=capabilities) + capabilities = {"name": "sensorcontainer", "type": "passive"} + common_setup( + service="sensorcontainer", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_runners=False, + service_registry=True, + capabilities=capabilities, + ) def _teardown(): @@ -60,16 +63,21 @@ def main(): try: _setup() - single_sensor_mode = (cfg.CONF.single_sensor_mode or - cfg.CONF.sensorcontainer.single_sensor_mode) + single_sensor_mode = ( + cfg.CONF.single_sensor_mode or cfg.CONF.sensorcontainer.single_sensor_mode + ) if single_sensor_mode and not cfg.CONF.sensor_ref: - raise ValueError('--sensor-ref argument must be provided when running in single ' - 'sensor mode') + raise ValueError( + "--sensor-ref argument must be provided when running in single " + "sensor mode" + ) sensors_partitioner = get_sensors_partitioner() - container_manager = SensorContainerManager(sensors_partitioner=sensors_partitioner, - single_sensor_mode=single_sensor_mode) + container_manager = SensorContainerManager( + sensors_partitioner=sensors_partitioner, + single_sensor_mode=single_sensor_mode, + ) return container_manager.run_sensors() except SystemExit as exit_code: return exit_code @@ -77,7 +85,7 @@ def main(): LOG.exception(e) return 1 except: - LOG.exception('(PID:%s) SensorContainer quit due to exception.', os.getpid()) + LOG.exception("(PID:%s) SensorContainer quit due to exception.", os.getpid()) return FAILURE_EXIT_CODE finally: _teardown() diff --git a/st2reactor/st2reactor/cmd/timersengine.py b/st2reactor/st2reactor/cmd/timersengine.py index 0b0cc4b5dd8..9b4edd52b5e 100644 --- a/st2reactor/st2reactor/cmd/timersengine.py +++ b/st2reactor/st2reactor/cmd/timersengine.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -38,12 +39,16 @@ def _setup(): - capabilities = { - 'name': 'timerengine', - 'type': 'passive' - } - common_setup(service='timer_engine', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, service_registry=True, capabilities=capabilities) + capabilities = {"name": "timerengine", "type": "passive"} + common_setup( + service="timer_engine", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + service_registry=True, + capabilities=capabilities, + ) def _teardown(): @@ -55,14 +60,16 @@ def _kickoff_timer(timer): def _run_worker(): - LOG.info('(PID=%s) TimerEngine started.', os.getpid()) + LOG.info("(PID=%s) TimerEngine started.", os.getpid()) timer = None try: timer_thread = None if cfg.CONF.timer.enable or cfg.CONF.timersengine.enable: - local_tz = cfg.CONF.timer.local_timezone or cfg.CONF.timersengine.local_timezone + local_tz = ( + cfg.CONF.timer.local_timezone or cfg.CONF.timersengine.local_timezone + ) timer = St2Timer(local_timezone=local_tz) timer_thread = concurrency.spawn(_kickoff_timer, timer) LOG.info(TIMER_ENABLED_LOG_LINE) @@ -70,9 +77,9 @@ def _run_worker(): else: LOG.info(TIMER_DISABLED_LOG_LINE) except (KeyboardInterrupt, SystemExit): - LOG.info('(PID=%s) TimerEngine stopped.', os.getpid()) + LOG.info("(PID=%s) TimerEngine stopped.", os.getpid()) except: - LOG.exception('(PID:%s) TimerEngine quit due to exception.', os.getpid()) + LOG.exception("(PID:%s) TimerEngine quit due to exception.", os.getpid()) return 1 finally: if timer: @@ -88,7 +95,7 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except Exception: - LOG.exception('(PID=%s) TimerEngine quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) TimerEngine quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2reactor/st2reactor/cmd/trigger_re_fire.py b/st2reactor/st2reactor/cmd/trigger_re_fire.py index 4f2c8f9ca17..8282a5decfe 100644 --- a/st2reactor/st2reactor/cmd/trigger_re_fire.py +++ b/st2reactor/st2reactor/cmd/trigger_re_fire.py @@ -27,24 +27,23 @@ from st2common.persistence.trigger import TriggerInstance from st2common.transport.reactor import TriggerDispatcher -__all__ = [ - 'main' -] +__all__ = ["main"] CONF = cfg.CONF def _parse_config(): cli_opts = [ - cfg.BoolOpt('verbose', - short='v', - default=False, - help='Print more verbose output'), - cfg.StrOpt('trigger-instance-id', - short='t', - required=True, - dest='trigger_instance_id', - help='Id of trigger instance'), + cfg.BoolOpt( + "verbose", short="v", default=False, help="Print more verbose output" + ), + cfg.StrOpt( + "trigger-instance-id", + short="t", + required=True, + dest="trigger_instance_id", + help="Id of trigger instance", + ), ] CONF.register_cli_opts(cli_opts) st2cfg.register_opts(ignore_errors=False) @@ -54,22 +53,17 @@ def _parse_config(): def _setup_logging(): logging_config = { - 'version': 1, - 'disable_existing_loggers': False, - 'formatters': { - 'default': { - 'format': '%(asctime)s %(levelname)s %(name)s %(message)s' - }, + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": {"format": "%(asctime)s %(levelname)s %(name)s %(message)s"}, }, - 'handlers': { - 'console': { - '()': std_logging.StreamHandler, - 'formatter': 'default' - } + "handlers": { + "console": {"()": std_logging.StreamHandler, "formatter": "default"} }, - 'root': { - 'handlers': ['console'], - 'level': 'DEBUG', + "root": { + "handlers": ["console"], + "level": "DEBUG", }, } std_logging.config.dictConfig(logging_config) @@ -82,8 +76,9 @@ def _setup_db(): def _refire_trigger_instance(trigger_instance_id, log_): trigger_instance = TriggerInstance.get_by_id(trigger_instance_id) trigger_dispatcher = TriggerDispatcher(log_) - trigger_dispatcher.dispatch(trigger=trigger_instance.trigger, - payload=trigger_instance.payload) + trigger_dispatcher.dispatch( + trigger=trigger_instance.trigger, payload=trigger_instance.payload + ) def main(): @@ -94,7 +89,8 @@ def main(): else: output = pprint.pprint _setup_db() - _refire_trigger_instance(trigger_instance_id=CONF.trigger_instance_id, - log_=logging.getLogger(__name__)) - output('Trigger re-fired') + _refire_trigger_instance( + trigger_instance_id=CONF.trigger_instance_id, log_=logging.getLogger(__name__) + ) + output("Trigger re-fired") db_teardown() diff --git a/st2reactor/st2reactor/container/hash_partitioner.py b/st2reactor/st2reactor/container/hash_partitioner.py index 9ed0cb78bea..b9e5e466581 100644 --- a/st2reactor/st2reactor/container/hash_partitioner.py +++ b/st2reactor/st2reactor/container/hash_partitioner.py @@ -17,25 +17,25 @@ import ctypes import hashlib -from st2reactor.container.partitioners import DefaultPartitioner, get_all_enabled_sensors +from st2reactor.container.partitioners import ( + DefaultPartitioner, + get_all_enabled_sensors, +) -__all__ = [ - 'HashPartitioner', - 'Range' -] +__all__ = ["HashPartitioner", "Range"] # The range expression serialized is of the form `RANGE_START..RANGE_END|RANGE_START..RANGE_END ...` -SUB_RANGE_SEPARATOR = '|' -RANGE_BOUNDARY_SEPARATOR = '..' +SUB_RANGE_SEPARATOR = "|" +RANGE_BOUNDARY_SEPARATOR = ".." class Range(object): - RANGE_MIN_ENUM = 'min' + RANGE_MIN_ENUM = "min" RANGE_MIN_VALUE = 0 - RANGE_MAX_ENUM = 'max' - RANGE_MAX_VALUE = 2**32 + RANGE_MAX_ENUM = "max" + RANGE_MAX_VALUE = 2 ** 32 def __init__(self, range_repr): self.range_start, self.range_end = self._get_range_boundaries(range_repr) @@ -44,15 +44,17 @@ def __contains__(self, item): return item >= self.range_start and item < self.range_end def _get_range_boundaries(self, range_repr): - range_repr = [value.strip() for value in range_repr.split(RANGE_BOUNDARY_SEPARATOR)] + range_repr = [ + value.strip() for value in range_repr.split(RANGE_BOUNDARY_SEPARATOR) + ] if len(range_repr) != 2: - raise ValueError('Unsupported sub-range format %s.' % range_repr) + raise ValueError("Unsupported sub-range format %s." % range_repr) range_start = self._get_valid_range_boundary(range_repr[0]) range_end = self._get_valid_range_boundary(range_repr[1]) if range_start > range_end: - raise ValueError('Misconfigured range [%d..%d]' % (range_start, range_end)) + raise ValueError("Misconfigured range [%d..%d]" % (range_start, range_end)) return (range_start, range_end) def _get_valid_range_boundary(self, boundary_value): @@ -73,7 +75,6 @@ def _get_valid_range_boundary(self, boundary_value): class HashPartitioner(DefaultPartitioner): - def __init__(self, sensor_node_name, hash_ranges): super(HashPartitioner, self).__init__(sensor_node_name=sensor_node_name) self._hash_ranges = self._create_hash_ranges(hash_ranges) @@ -112,7 +113,7 @@ def _hash_sensor_ref(self, sensor_ref): h = ctypes.c_uint(0) for d in reversed(str(md5_hash_int_repr)): d = ctypes.c_uint(int(d)) - higherorder = ctypes.c_uint(h.value & 0xf8000000) + higherorder = ctypes.c_uint(h.value & 0xF8000000) h = ctypes.c_uint(h.value << 5) h = ctypes.c_uint(h.value ^ (higherorder.value >> 27)) h = ctypes.c_uint(h.value ^ d.value) diff --git a/st2reactor/st2reactor/container/manager.py b/st2reactor/st2reactor/container/manager.py index 694d3ce337e..e9f251aebc1 100644 --- a/st2reactor/st2reactor/container/manager.py +++ b/st2reactor/st2reactor/container/manager.py @@ -27,16 +27,13 @@ LOG = logging.getLogger(__name__) -__all__ = [ - 'SensorContainerManager' -] +__all__ = ["SensorContainerManager"] class SensorContainerManager(object): - def __init__(self, sensors_partitioner, single_sensor_mode=False): if not sensors_partitioner: - raise ValueError('sensors_partitioner should be non-None.') + raise ValueError("sensors_partitioner should be non-None.") self._sensors_partitioner = sensors_partitioner self._single_sensor_mode = single_sensor_mode @@ -44,10 +41,12 @@ def __init__(self, sensors_partitioner, single_sensor_mode=False): self._sensor_container = None self._container_thread = None - self._sensors_watcher = SensorWatcher(create_handler=self._handle_create_sensor, - update_handler=self._handle_update_sensor, - delete_handler=self._handle_delete_sensor, - queue_suffix='sensor_container') + self._sensors_watcher = SensorWatcher( + create_handler=self._handle_create_sensor, + update_handler=self._handle_update_sensor, + delete_handler=self._handle_delete_sensor, + queue_suffix="sensor_container", + ) def run_sensors(self): """ @@ -55,15 +54,18 @@ def run_sensors(self): """ sensors = self._sensors_partitioner.get_sensors() if sensors: - LOG.info('Setting up container to run %d sensors.', len(sensors)) - LOG.info('\tSensors list - %s.', [self._get_sensor_ref(sensor) for sensor in sensors]) + LOG.info("Setting up container to run %d sensors.", len(sensors)) + LOG.info( + "\tSensors list - %s.", + [self._get_sensor_ref(sensor) for sensor in sensors], + ) sensors_to_run = [] for sensor in sensors: # TODO: Directly pass DB object to the ProcessContainer sensors_to_run.append(self._to_sensor_object(sensor)) - LOG.info('(PID:%s) SensorContainer started.', os.getpid()) + LOG.info("(PID:%s) SensorContainer started.", os.getpid()) self._setup_sigterm_handler() exit_code = self._spin_container_and_wait(sensors_to_run) @@ -74,22 +76,25 @@ def _spin_container_and_wait(self, sensors): try: self._sensor_container = ProcessSensorContainer( - sensors=sensors, - single_sensor_mode=self._single_sensor_mode) + sensors=sensors, single_sensor_mode=self._single_sensor_mode + ) self._container_thread = concurrency.spawn(self._sensor_container.run) - LOG.debug('Starting sensor CUD watcher...') + LOG.debug("Starting sensor CUD watcher...") self._sensors_watcher.start() exit_code = self._container_thread.wait() - LOG.error('Process container quit with exit_code %d.', exit_code) - LOG.error('(PID:%s) SensorContainer stopped.', os.getpid()) + LOG.error("Process container quit with exit_code %d.", exit_code) + LOG.error("(PID:%s) SensorContainer stopped.", os.getpid()) except (KeyboardInterrupt, SystemExit): self._sensor_container.shutdown() self._sensors_watcher.stop() - LOG.info('(PID:%s) SensorContainer stopped. Reason - %s', os.getpid(), - sys.exc_info()[0].__name__) + LOG.info( + "(PID:%s) SensorContainer stopped. Reason - %s", + os.getpid(), + sys.exc_info()[0].__name__, + ) concurrency.kill(self._container_thread) self._container_thread = None @@ -99,7 +104,6 @@ def _spin_container_and_wait(self, sensors): return exit_code def _setup_sigterm_handler(self): - def sigterm_handler(signum=None, frame=None): # This will cause SystemExit to be throw and we call sensor_container.shutdown() # there which cleans things up. @@ -110,16 +114,16 @@ def sigterm_handler(signum=None, frame=None): signal.signal(signal.SIGTERM, sigterm_handler) def _to_sensor_object(self, sensor_db): - file_path = sensor_db.artifact_uri.replace('file://', '') - class_name = sensor_db.entry_point.split('.')[-1] + file_path = sensor_db.artifact_uri.replace("file://", "") + class_name = sensor_db.entry_point.split(".")[-1] sensor_obj = { - 'pack': sensor_db.pack, - 'file_path': file_path, - 'class_name': class_name, - 'trigger_types': sensor_db.trigger_types, - 'poll_interval': sensor_db.poll_interval, - 'ref': self._get_sensor_ref(sensor_db) + "pack": sensor_db.pack, + "file_path": file_path, + "class_name": class_name, + "trigger_types": sensor_db.trigger_types, + "poll_interval": sensor_db.poll_interval, + "ref": self._get_sensor_ref(sensor_db), } return sensor_obj @@ -130,42 +134,50 @@ def _to_sensor_object(self, sensor_db): def _handle_create_sensor(self, sensor): if not self._sensors_partitioner.is_sensor_owner(sensor): - LOG.info('sensor %s is not supported. Ignoring create.', self._get_sensor_ref(sensor)) + LOG.info( + "sensor %s is not supported. Ignoring create.", + self._get_sensor_ref(sensor), + ) return if not sensor.enabled: - LOG.info('sensor %s is not enabled.', self._get_sensor_ref(sensor)) + LOG.info("sensor %s is not enabled.", self._get_sensor_ref(sensor)) return - LOG.info('Adding sensor %s.', self._get_sensor_ref(sensor)) + LOG.info("Adding sensor %s.", self._get_sensor_ref(sensor)) self._sensor_container.add_sensor(sensor=self._to_sensor_object(sensor)) def _handle_update_sensor(self, sensor): if not self._sensors_partitioner.is_sensor_owner(sensor): - LOG.info('sensor %s is not assigned to this partition. Ignoring update. ', - self._get_sensor_ref(sensor)) + LOG.info( + "sensor %s is not assigned to this partition. Ignoring update. ", + self._get_sensor_ref(sensor), + ) return sensor_ref = self._get_sensor_ref(sensor) sensor_obj = self._to_sensor_object(sensor) # Handle disabling sensor if not sensor.enabled: - LOG.info('Sensor %s disabled. Unloading sensor.', sensor_ref) + LOG.info("Sensor %s disabled. Unloading sensor.", sensor_ref) self._sensor_container.remove_sensor(sensor=sensor_obj) return - LOG.info('Sensor %s updated. Reloading sensor.', sensor_ref) + LOG.info("Sensor %s updated. Reloading sensor.", sensor_ref) try: self._sensor_container.remove_sensor(sensor=sensor_obj) except: - LOG.exception('Failed to reload sensor %s', sensor_ref) + LOG.exception("Failed to reload sensor %s", sensor_ref) else: self._sensor_container.add_sensor(sensor=sensor_obj) - LOG.info('Sensor %s reloaded.', sensor_ref) + LOG.info("Sensor %s reloaded.", sensor_ref) def _handle_delete_sensor(self, sensor): if not self._sensors_partitioner.is_sensor_owner(sensor): - LOG.info('sensor %s is not supported. Ignoring delete.', self._get_sensor_ref(sensor)) + LOG.info( + "sensor %s is not supported. Ignoring delete.", + self._get_sensor_ref(sensor), + ) return - LOG.info('Unloading sensor %s.', self._get_sensor_ref(sensor)) + LOG.info("Unloading sensor %s.", self._get_sensor_ref(sensor)) self._sensor_container.remove_sensor(sensor=self._to_sensor_object(sensor)) def _get_sensor_ref(self, sensor): diff --git a/st2reactor/st2reactor/container/partitioner_lookup.py b/st2reactor/st2reactor/container/partitioner_lookup.py index c4f43db6da3..1469b3c63c2 100644 --- a/st2reactor/st2reactor/container/partitioner_lookup.py +++ b/st2reactor/st2reactor/container/partitioner_lookup.py @@ -18,16 +18,22 @@ from oslo_config import cfg from st2common import log as logging -from st2common.constants.sensors import DEFAULT_PARTITION_LOADER, KVSTORE_PARTITION_LOADER, \ - FILE_PARTITION_LOADER, HASH_PARTITION_LOADER +from st2common.constants.sensors import ( + DEFAULT_PARTITION_LOADER, + KVSTORE_PARTITION_LOADER, + FILE_PARTITION_LOADER, + HASH_PARTITION_LOADER, +) from st2common.exceptions.sensors import SensorPartitionerNotSupportedException -from st2reactor.container.partitioners import DefaultPartitioner, KVStorePartitioner, \ - FileBasedPartitioner, SingleSensorPartitioner +from st2reactor.container.partitioners import ( + DefaultPartitioner, + KVStorePartitioner, + FileBasedPartitioner, + SingleSensorPartitioner, +) from st2reactor.container.hash_partitioner import HashPartitioner -__all__ = [ - 'get_sensors_partitioner' -] +__all__ = ["get_sensors_partitioner"] LOG = logging.getLogger(__name__) @@ -35,25 +41,28 @@ DEFAULT_PARTITION_LOADER: DefaultPartitioner, KVSTORE_PARTITION_LOADER: KVStorePartitioner, FILE_PARTITION_LOADER: FileBasedPartitioner, - HASH_PARTITION_LOADER: HashPartitioner + HASH_PARTITION_LOADER: HashPartitioner, } def get_sensors_partitioner(): if cfg.CONF.sensor_ref: - LOG.info('Running in single sensor mode, using a single sensor partitioner...') + LOG.info("Running in single sensor mode, using a single sensor partitioner...") return SingleSensorPartitioner(sensor_ref=cfg.CONF.sensor_ref) partition_provider_config = copy.copy(cfg.CONF.sensorcontainer.partition_provider) - partition_provider = partition_provider_config.pop('name') + partition_provider = partition_provider_config.pop("name") sensor_node_name = cfg.CONF.sensorcontainer.sensor_node_name provider = PROVIDERS.get(partition_provider.lower(), None) if not provider: - raise SensorPartitionerNotSupportedException('Partition provider %s not found.' % - (partition_provider)) + raise SensorPartitionerNotSupportedException( + "Partition provider %s not found." % (partition_provider) + ) - LOG.info('Using partitioner %s with sensornode %s.', partition_provider, sensor_node_name) + LOG.info( + "Using partitioner %s with sensornode %s.", partition_provider, sensor_node_name + ) # pass in extra config with no analysis return provider(sensor_node_name=sensor_node_name, **partition_provider_config) diff --git a/st2reactor/st2reactor/container/partitioners.py b/st2reactor/st2reactor/container/partitioners.py index 12a17f9081b..02a6d6137b7 100644 --- a/st2reactor/st2reactor/container/partitioners.py +++ b/st2reactor/st2reactor/container/partitioners.py @@ -18,18 +18,20 @@ import yaml from st2common import log as logging -from st2common.exceptions.sensors import SensorNotFoundException, \ - SensorPartitionMapMissingException +from st2common.exceptions.sensors import ( + SensorNotFoundException, + SensorPartitionMapMissingException, +) from st2common.persistence.keyvalue import KeyValuePair from st2common.persistence.sensor import SensorType __all__ = [ - 'get_all_enabled_sensors', - 'DefaultPartitioner', - 'KVStorePartitioner', - 'FileBasedPartitioner', - 'SingleSensorPartitioner' + "get_all_enabled_sensors", + "DefaultPartitioner", + "KVStorePartitioner", + "FileBasedPartitioner", + "SingleSensorPartitioner", ] LOG = logging.getLogger(__name__) @@ -38,12 +40,11 @@ def get_all_enabled_sensors(): # only query for enabled sensors. sensors = SensorType.query(enabled=True) - LOG.info('Found %d registered sensors in db scan.', len(sensors)) + LOG.info("Found %d registered sensors in db scan.", len(sensors)) return sensors class DefaultPartitioner(object): - def __init__(self, sensor_node_name): self.sensor_node_name = sensor_node_name @@ -78,7 +79,6 @@ def get_required_sensor_refs(self): class KVStorePartitioner(DefaultPartitioner): - def __init__(self, sensor_node_name): super(KVStorePartitioner, self).__init__(sensor_node_name=sensor_node_name) self._supported_sensor_refs = None @@ -90,46 +90,51 @@ def get_required_sensor_refs(self): partition_lookup_key = self._get_partition_lookup_key(self.sensor_node_name) kvp = KeyValuePair.get_by_name(partition_lookup_key) - sensor_refs_str = kvp.value if kvp.value else '' - self._supported_sensor_refs = set([ - sensor_ref.strip() for sensor_ref in sensor_refs_str.split(',')]) + sensor_refs_str = kvp.value if kvp.value else "" + self._supported_sensor_refs = set( + [sensor_ref.strip() for sensor_ref in sensor_refs_str.split(",")] + ) return list(self._supported_sensor_refs) def _get_partition_lookup_key(self, sensor_node_name): - return '{}.sensor_partition'.format(sensor_node_name) + return "{}.sensor_partition".format(sensor_node_name) class FileBasedPartitioner(DefaultPartitioner): - def __init__(self, sensor_node_name, partition_file): super(FileBasedPartitioner, self).__init__(sensor_node_name=sensor_node_name) self.partition_file = partition_file self._supported_sensor_refs = None def is_sensor_owner(self, sensor_db): - return sensor_db.get_reference().ref in self._supported_sensor_refs and sensor_db.enabled + return ( + sensor_db.get_reference().ref in self._supported_sensor_refs + and sensor_db.enabled + ) def get_required_sensor_refs(self): - with open(self.partition_file, 'r') as f: + with open(self.partition_file, "r") as f: partition_map = yaml.safe_load(f) sensor_refs = partition_map.get(self.sensor_node_name, None) if sensor_refs is None: - raise SensorPartitionMapMissingException('Sensor partition not found for %s in %s.' - % (self.sensor_node_name, - self.partition_file)) + raise SensorPartitionMapMissingException( + "Sensor partition not found for %s in %s." + % (self.sensor_node_name, self.partition_file) + ) self._supported_sensor_refs = set(sensor_refs) return list(self._supported_sensor_refs) class SingleSensorPartitioner(object): - def __init__(self, sensor_ref): self._sensor_ref = sensor_ref def get_sensors(self): sensor = SensorType.get_by_ref(self._sensor_ref) if not sensor: - raise SensorNotFoundException('Sensor %s not found in db.' % self._sensor_ref) + raise SensorNotFoundException( + "Sensor %s not found in db." % self._sensor_ref + ) return [sensor] def is_sensor_owner(self, sensor_db): diff --git a/st2reactor/st2reactor/container/process_container.py b/st2reactor/st2reactor/container/process_container.py index f8f1638d71b..890bcccbb9d 100644 --- a/st2reactor/st2reactor/container/process_container.py +++ b/st2reactor/st2reactor/container/process_container.py @@ -31,7 +31,7 @@ from st2common.constants.error_messages import PACK_VIRTUALENV_DOESNT_EXIST from st2common.constants.system import API_URL_ENV_VARIABLE_NAME from st2common.constants.system import AUTH_TOKEN_ENV_VARIABLE_NAME -from st2common.constants.triggers import (SENSOR_SPAWN_TRIGGER, SENSOR_EXIT_TRIGGER) +from st2common.constants.triggers import SENSOR_SPAWN_TRIGGER, SENSOR_EXIT_TRIGGER from st2common.constants.exit_codes import SUCCESS_EXIT_CODE from st2common.constants.exit_codes import FAILURE_EXIT_CODE from st2common.models.system.common import ResourceReference @@ -44,14 +44,12 @@ from st2common.util.sandboxing import get_sandbox_python_binary_path from st2common.util.sandboxing import get_sandbox_virtualenv_path -__all__ = [ - 'ProcessSensorContainer' -] +__all__ = ["ProcessSensorContainer"] -LOG = logging.getLogger('st2reactor.process_sensor_container') +LOG = logging.getLogger("st2reactor.process_sensor_container") BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -WRAPPER_SCRIPT_NAME = 'sensor_wrapper.py' +WRAPPER_SCRIPT_NAME = "sensor_wrapper.py" WRAPPER_SCRIPT_PATH = os.path.join(BASE_DIR, WRAPPER_SCRIPT_NAME) # How many times to try to subsequently respawn a sensor after a non-zero exit before giving up @@ -78,8 +76,15 @@ class ProcessSensorContainer(object): Sensor container which runs sensors in a separate process. """ - def __init__(self, sensors, poll_interval=5, single_sensor_mode=False, dispatcher=None, - wrapper_script_path=WRAPPER_SCRIPT_PATH, create_token=True): + def __init__( + self, + sensors, + poll_interval=5, + single_sensor_mode=False, + dispatcher=None, + wrapper_script_path=WRAPPER_SCRIPT_PATH, + create_token=True, + ): """ :param sensors: A list of sensor dicts. :type sensors: ``list`` of ``dict`` @@ -119,7 +124,9 @@ def __init__(self, sensors, poll_interval=5, single_sensor_mode=False, dispatche # Stores information needed for respawning dead sensors self._sensor_start_times = {} # maps sensor_id -> sensor start time - self._sensor_respawn_counts = defaultdict(int) # maps sensor_id -> number of respawns + self._sensor_respawn_counts = defaultdict( + int + ) # maps sensor_id -> number of respawns # A list of all the instance variables which hold internal state information about a # particular_sensor @@ -144,10 +151,10 @@ def run(self): sensor_ids = list(self._sensors.keys()) if len(sensor_ids) >= 1: - LOG.debug('%d active sensor(s)' % (len(sensor_ids))) + LOG.debug("%d active sensor(s)" % (len(sensor_ids))) self._poll_sensors_for_results(sensor_ids) else: - LOG.debug('No active sensors') + LOG.debug("No active sensors") concurrency.sleep(self._poll_interval) except success_exception_cls: @@ -157,12 +164,12 @@ def run(self): self._stopped = True return SUCCESS_EXIT_CODE except: - LOG.exception('Container failed to run sensors.') + LOG.exception("Container failed to run sensors.") self._stopped = True return FAILURE_EXIT_CODE self._stopped = True - LOG.error('Process container stopped.') + LOG.error("Process container stopped.") exit_code = self._exit_code or SUCCESS_EXIT_CODE return exit_code @@ -179,23 +186,29 @@ def _poll_sensors_for_results(self, sensor_ids): if status is not None: # Dead process detected - LOG.info('Process for sensor %s has exited with code %s', sensor_id, status) + LOG.info( + "Process for sensor %s has exited with code %s", sensor_id, status + ) sensor = self._sensors[sensor_id] self._delete_sensor(sensor_id) - self._dispatch_trigger_for_sensor_exit(sensor=sensor, - exit_code=status) + self._dispatch_trigger_for_sensor_exit(sensor=sensor, exit_code=status) # Try to respawn a dead process (maybe it was a simple failure which can be # resolved with a restart) - concurrency.spawn(self._respawn_sensor, sensor_id=sensor_id, sensor=sensor, - exit_code=status) + concurrency.spawn( + self._respawn_sensor, + sensor_id=sensor_id, + sensor=sensor, + exit_code=status, + ) else: sensor_start_time = self._sensor_start_times[sensor_id] sensor_respawn_count = self._sensor_respawn_counts[sensor_id] - successfully_started = ((now - sensor_start_time) >= - SENSOR_SUCCESSFUL_START_THRESHOLD) + successfully_started = ( + now - sensor_start_time + ) >= SENSOR_SUCCESSFUL_START_THRESHOLD if successfully_started and sensor_respawn_count >= 1: # Sensor has been successfully running more than threshold seconds, clear the @@ -209,7 +222,7 @@ def stopped(self): return self._stopped def shutdown(self, force=False): - LOG.info('Container shutting down. Invoking cleanup on sensors.') + LOG.info("Container shutting down. Invoking cleanup on sensors.") self._stopped = True if force: @@ -221,7 +234,7 @@ def shutdown(self, force=False): for sensor_id in sensor_ids: self._stop_sensor_process(sensor_id=sensor_id, exit_timeout=exit_timeout) - LOG.info('All sensors are shut down.') + LOG.info("All sensors are shut down.") self._sensors = {} self._processes = {} @@ -235,11 +248,11 @@ def add_sensor(self, sensor): sensor_id = self._get_sensor_id(sensor=sensor) if sensor_id in self._sensors: - LOG.warning('Sensor %s already exists and running.', sensor_id) + LOG.warning("Sensor %s already exists and running.", sensor_id) return False self._spawn_sensor_process(sensor=sensor) - LOG.debug('Sensor %s started.', sensor_id) + LOG.debug("Sensor %s started.", sensor_id) self._sensors[sensor_id] = sensor return True @@ -252,11 +265,11 @@ def remove_sensor(self, sensor): sensor_id = self._get_sensor_id(sensor=sensor) if sensor_id not in self._sensors: - LOG.warning('Sensor %s isn\'t running in this container.', sensor_id) + LOG.warning("Sensor %s isn't running in this container.", sensor_id) return False self._stop_sensor_process(sensor_id=sensor_id) - LOG.debug('Sensor %s stopped.', sensor_id) + LOG.debug("Sensor %s stopped.", sensor_id) return True def _run_all_sensors(self): @@ -264,7 +277,7 @@ def _run_all_sensors(self): for sensor_id in sensor_ids: sensor_obj = self._sensors[sensor_id] - LOG.info('Running sensor %s', sensor_id) + LOG.info("Running sensor %s", sensor_id) try: self._spawn_sensor_process(sensor=sensor_obj) @@ -275,7 +288,7 @@ def _run_all_sensors(self): del self._sensors[sensor_id] continue - LOG.info('Sensor %s started' % sensor_id) + LOG.info("Sensor %s started" % sensor_id) def _spawn_sensor_process(self, sensor): """ @@ -285,45 +298,53 @@ def _spawn_sensor_process(self, sensor): belonging to the sensor pack. """ sensor_id = self._get_sensor_id(sensor=sensor) - pack_ref = sensor['pack'] + pack_ref = sensor["pack"] virtualenv_path = get_sandbox_virtualenv_path(pack=pack_ref) python_path = get_sandbox_python_binary_path(pack=pack_ref) if virtualenv_path and not os.path.isdir(virtualenv_path): - format_values = {'pack': sensor['pack'], 'virtualenv_path': virtualenv_path} + format_values = {"pack": sensor["pack"], "virtualenv_path": virtualenv_path} msg = PACK_VIRTUALENV_DOESNT_EXIST % format_values raise Exception(msg) - args = self._get_args_for_wrapper_script(python_binary=python_path, sensor=sensor) + args = self._get_args_for_wrapper_script( + python_binary=python_path, sensor=sensor + ) if self._enable_common_pack_libs: - pack_common_libs_path = get_pack_common_libs_path_for_pack_ref(pack_ref=pack_ref) + pack_common_libs_path = get_pack_common_libs_path_for_pack_ref( + pack_ref=pack_ref + ) else: pack_common_libs_path = None env = os.environ.copy() - sandbox_python_path = get_sandbox_python_path(inherit_from_parent=True, - inherit_parent_virtualenv=True) + sandbox_python_path = get_sandbox_python_path( + inherit_from_parent=True, inherit_parent_virtualenv=True + ) if self._enable_common_pack_libs and pack_common_libs_path: - env['PYTHONPATH'] = pack_common_libs_path + ':' + sandbox_python_path + env["PYTHONPATH"] = pack_common_libs_path + ":" + sandbox_python_path else: - env['PYTHONPATH'] = sandbox_python_path + env["PYTHONPATH"] = sandbox_python_path if self._create_token: # Include full api URL and API token specific to that sensor - LOG.debug('Creating temporary auth token for sensor %s' % (sensor['class_name'])) + LOG.debug( + "Creating temporary auth token for sensor %s" % (sensor["class_name"]) + ) ttl = cfg.CONF.auth.service_token_ttl metadata = { - 'service': 'sensors_container', - 'sensor_path': sensor['file_path'], - 'sensor_class': sensor['class_name'] + "service": "sensors_container", + "sensor_path": sensor["file_path"], + "sensor_class": sensor["class_name"], } - temporary_token = create_token(username='sensors_container', ttl=ttl, metadata=metadata, - service=True) + temporary_token = create_token( + username="sensors_container", ttl=ttl, metadata=metadata, service=True + ) env[API_URL_ENV_VARIABLE_NAME] = get_full_public_api_url() env[AUTH_TOKEN_ENV_VARIABLE_NAME] = temporary_token.token @@ -332,18 +353,27 @@ def _spawn_sensor_process(self, sensor): # TODO 2: Store metadata (wrapper process id) with the token and delete # tokens for old, dead processes on startup - cmd = ' '.join(args) + cmd = " ".join(args) LOG.debug('Running sensor subprocess (cmd="%s")', cmd) # TODO: Intercept stdout and stderr for aggregated logging purposes try: - process = subprocess.Popen(args=args, stdin=None, stdout=None, - stderr=None, shell=False, env=env, - preexec_fn=on_parent_exit('SIGTERM')) + process = subprocess.Popen( + args=args, + stdin=None, + stdout=None, + stderr=None, + shell=False, + env=env, + preexec_fn=on_parent_exit("SIGTERM"), + ) except Exception as e: - cmd = ' '.join(args) - message = ('Failed to spawn process for sensor %s ("%s"): %s' % - (sensor_id, cmd, six.text_type(e))) + cmd = " ".join(args) + message = 'Failed to spawn process for sensor %s ("%s"): %s' % ( + sensor_id, + cmd, + six.text_type(e), + ) raise Exception(message) self._processes[sensor_id] = process @@ -397,32 +427,35 @@ def _respawn_sensor(self, sensor_id, sensor, exit_code): """ Method for respawning a sensor which died with a non-zero exit code. """ - extra = {'sensor_id': sensor_id, 'sensor': sensor} + extra = {"sensor_id": sensor_id, "sensor": sensor} if self._single_sensor_mode: # In single sensor mode we want to exit immediately on failure - LOG.info('Not respawning a sensor since running in single sensor mode', - extra=extra) + LOG.info( + "Not respawning a sensor since running in single sensor mode", + extra=extra, + ) self._stopped = True self._exit_code = exit_code return if self._stopped: - LOG.debug('Stopped, not respawning a dead sensor', extra=extra) + LOG.debug("Stopped, not respawning a dead sensor", extra=extra) return - should_respawn = self._should_respawn_sensor(sensor_id=sensor_id, sensor=sensor, - exit_code=exit_code) + should_respawn = self._should_respawn_sensor( + sensor_id=sensor_id, sensor=sensor, exit_code=exit_code + ) if not should_respawn: - LOG.debug('Not respawning a dead sensor', extra=extra) + LOG.debug("Not respawning a dead sensor", extra=extra) return - LOG.debug('Respawning dead sensor', extra=extra) + LOG.debug("Respawning dead sensor", extra=extra) self._sensor_respawn_counts[sensor_id] += 1 - sleep_delay = (SENSOR_RESPAWN_DELAY * self._sensor_respawn_counts[sensor_id]) + sleep_delay = SENSOR_RESPAWN_DELAY * self._sensor_respawn_counts[sensor_id] concurrency.sleep(sleep_delay) try: @@ -443,7 +476,7 @@ def _should_respawn_sensor(self, sensor_id, sensor, exit_code): respawn_count = self._sensor_respawn_counts[sensor_id] if respawn_count >= SENSOR_MAX_RESPAWN_COUNTS: - LOG.debug('Sensor has already been respawned max times, giving up') + LOG.debug("Sensor has already been respawned max times, giving up") return False return True @@ -460,23 +493,23 @@ def _get_args_for_wrapper_script(self, python_binary, sensor): :rtype: ``list`` """ - trigger_type_refs = sensor['trigger_types'] or [] - trigger_type_refs = ','.join(trigger_type_refs) + trigger_type_refs = sensor["trigger_types"] or [] + trigger_type_refs = ",".join(trigger_type_refs) parent_args = json.dumps(sys.argv[1:]) args = [ python_binary, self._wrapper_script_path, - '--pack=%s' % (sensor['pack']), - '--file-path=%s' % (sensor['file_path']), - '--class-name=%s' % (sensor['class_name']), - '--trigger-type-refs=%s' % (trigger_type_refs), - '--parent-args=%s' % (parent_args) + "--pack=%s" % (sensor["pack"]), + "--file-path=%s" % (sensor["file_path"]), + "--class-name=%s" % (sensor["class_name"]), + "--trigger-type-refs=%s" % (trigger_type_refs), + "--parent-args=%s" % (parent_args), ] - if sensor['poll_interval']: - args.append('--poll-interval=%s' % (sensor['poll_interval'])) + if sensor["poll_interval"]: + args.append("--poll-interval=%s" % (sensor["poll_interval"])) return args @@ -486,32 +519,28 @@ def _get_sensor_id(self, sensor): :type sensor: ``dict`` """ - sensor_id = sensor['ref'] + sensor_id = sensor["ref"] return sensor_id def _dispatch_trigger_for_sensor_spawn(self, sensor, process, cmd): trigger = ResourceReference.to_string_reference( - name=SENSOR_SPAWN_TRIGGER['name'], - pack=SENSOR_SPAWN_TRIGGER['pack']) + name=SENSOR_SPAWN_TRIGGER["name"], pack=SENSOR_SPAWN_TRIGGER["pack"] + ) now = int(time.time()) payload = { - 'id': sensor['class_name'], - 'timestamp': now, - 'pid': process.pid, - 'cmd': cmd + "id": sensor["class_name"], + "timestamp": now, + "pid": process.pid, + "cmd": cmd, } self._dispatcher.dispatch(trigger, payload=payload) def _dispatch_trigger_for_sensor_exit(self, sensor, exit_code): trigger = ResourceReference.to_string_reference( - name=SENSOR_EXIT_TRIGGER['name'], - pack=SENSOR_EXIT_TRIGGER['pack']) + name=SENSOR_EXIT_TRIGGER["name"], pack=SENSOR_EXIT_TRIGGER["pack"] + ) now = int(time.time()) - payload = { - 'id': sensor['class_name'], - 'timestamp': now, - 'exit_code': exit_code - } + payload = {"id": sensor["class_name"], "timestamp": now, "exit_code": exit_code} self._dispatcher.dispatch(trigger, payload=payload) def _delete_sensor(self, sensor_id): diff --git a/st2reactor/st2reactor/container/sensor_wrapper.py b/st2reactor/st2reactor/container/sensor_wrapper.py index 56a37707d21..c605b472911 100644 --- a/st2reactor/st2reactor/container/sensor_wrapper.py +++ b/st2reactor/st2reactor/container/sensor_wrapper.py @@ -25,6 +25,7 @@ # for details. from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -51,10 +52,7 @@ from st2common.services.datastore import SensorDatastoreService from st2common.util.monkey_patch import use_select_poll_workaround -__all__ = [ - 'SensorWrapper', - 'SensorService' -] +__all__ = ["SensorWrapper", "SensorService"] use_select_poll_workaround(nose_only=False) @@ -69,12 +67,15 @@ def __init__(self, sensor_wrapper): self._sensor_wrapper = sensor_wrapper self._logger = self._sensor_wrapper._logger - self._trigger_dispatcher_service = TriggerDispatcherService(logger=sensor_wrapper._logger) + self._trigger_dispatcher_service = TriggerDispatcherService( + logger=sensor_wrapper._logger + ) self._datastore_service = SensorDatastoreService( logger=self._logger, pack_name=self._sensor_wrapper._pack, class_name=self._sensor_wrapper._class_name, - api_username='sensor_service') + api_username="sensor_service", + ) self._client = None @@ -86,7 +87,7 @@ def get_logger(self, name): """ Retrieve an instance of a logger to be used by the sensor class. """ - logger_name = '%s.%s' % (self._sensor_wrapper._logger.name, name) + logger_name = "%s.%s" % (self._sensor_wrapper._logger.name, name) logger = logging.getLogger(logger_name) logger.propagate = True @@ -105,9 +106,12 @@ def get_user_info(self): def dispatch(self, trigger, payload=None, trace_tag=None): # Provided by the parent BaseTriggerDispatcherService class - return self._trigger_dispatcher_service.dispatch(trigger=trigger, payload=payload, - trace_tag=trace_tag, - throw_on_validation_error=False) + return self._trigger_dispatcher_service.dispatch( + trigger=trigger, + payload=payload, + trace_tag=trace_tag, + throw_on_validation_error=False, + ) def dispatch_with_context(self, trigger, payload=None, trace_context=None): """ @@ -123,10 +127,12 @@ def dispatch_with_context(self, trigger, payload=None, trace_context=None): :type trace_context: ``st2common.api.models.api.trace.TraceContext`` """ # Provided by the parent BaseTriggerDispatcherService class - return self._trigger_dispatcher_service.dispatch_with_context(trigger=trigger, + return self._trigger_dispatcher_service.dispatch_with_context( + trigger=trigger, payload=payload, trace_context=trace_context, - throw_on_validation_error=False) + throw_on_validation_error=False, + ) ################################## # Methods for datastore management @@ -136,20 +142,31 @@ def list_values(self, local=True, prefix=None): return self.datastore_service.list_values(local=local, prefix=prefix) def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False): - return self.datastore_service.get_value(name=name, local=local, scope=scope, - decrypt=decrypt) + return self.datastore_service.get_value( + name=name, local=local, scope=scope, decrypt=decrypt + ) - def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False): - return self.datastore_service.set_value(name=name, value=value, ttl=ttl, local=local, - scope=scope, encrypt=encrypt) + def set_value( + self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False + ): + return self.datastore_service.set_value( + name=name, value=value, ttl=ttl, local=local, scope=scope, encrypt=encrypt + ) def delete_value(self, name, local=True, scope=SYSTEM_SCOPE): return self.datastore_service.delete_value(name=name, local=local, scope=scope) class SensorWrapper(object): - def __init__(self, pack, file_path, class_name, trigger_types, - poll_interval=None, parent_args=None): + def __init__( + self, + pack, + file_path, + class_name, + trigger_types, + poll_interval=None, + parent_args=None, + ): """ :param pack: Name of the pack this sensor belongs to. :type pack: ``str`` @@ -185,32 +202,48 @@ def __init__(self, pack, file_path, class_name, trigger_types, pass # 2. Establish DB connection - username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None - password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None - db_setup_with_retry(cfg.CONF.database.db_name, cfg.CONF.database.host, - cfg.CONF.database.port, username=username, password=password, - ssl=cfg.CONF.database.ssl, ssl_keyfile=cfg.CONF.database.ssl_keyfile, - ssl_certfile=cfg.CONF.database.ssl_certfile, - ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs, - ssl_ca_certs=cfg.CONF.database.ssl_ca_certs, - authentication_mechanism=cfg.CONF.database.authentication_mechanism, - ssl_match_hostname=cfg.CONF.database.ssl_match_hostname) + username = ( + cfg.CONF.database.username + if hasattr(cfg.CONF.database, "username") + else None + ) + password = ( + cfg.CONF.database.password + if hasattr(cfg.CONF.database, "password") + else None + ) + db_setup_with_retry( + cfg.CONF.database.db_name, + cfg.CONF.database.host, + cfg.CONF.database.port, + username=username, + password=password, + ssl=cfg.CONF.database.ssl, + ssl_keyfile=cfg.CONF.database.ssl_keyfile, + ssl_certfile=cfg.CONF.database.ssl_certfile, + ssl_cert_reqs=cfg.CONF.database.ssl_cert_reqs, + ssl_ca_certs=cfg.CONF.database.ssl_ca_certs, + authentication_mechanism=cfg.CONF.database.authentication_mechanism, + ssl_match_hostname=cfg.CONF.database.ssl_match_hostname, + ) # 3. Instantiate the watcher - self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger, - update_handler=self._handle_update_trigger, - delete_handler=self._handle_delete_trigger, - trigger_types=self._trigger_types, - queue_suffix='sensorwrapper_%s_%s' % - (self._pack, self._class_name), - exclusive=True) + self._trigger_watcher = TriggerWatcher( + create_handler=self._handle_create_trigger, + update_handler=self._handle_update_trigger, + delete_handler=self._handle_delete_trigger, + trigger_types=self._trigger_types, + queue_suffix="sensorwrapper_%s_%s" % (self._pack, self._class_name), + exclusive=True, + ) # 4. Set up logging - self._logger = logging.getLogger('SensorWrapper.%s.%s' % - (self._pack, self._class_name)) + self._logger = logging.getLogger( + "SensorWrapper.%s.%s" % (self._pack, self._class_name) + ) logging.setup(cfg.CONF.sensorcontainer.logging) - if '--debug' in parent_args: + if "--debug" in parent_args: set_log_level_for_all_loggers() else: # NOTE: statsd logger logs everything by default under INFO so we ignore those log @@ -223,16 +256,17 @@ def run(self): atexit.register(self.stop) self._trigger_watcher.start() - self._logger.info('Watcher started') + self._logger.info("Watcher started") - self._logger.info('Running sensor initialization code') + self._logger.info("Running sensor initialization code") self._sensor_instance.setup() if self._poll_interval: - message = ('Running sensor in active mode (poll interval=%ss)' % - (self._poll_interval)) + message = "Running sensor in active mode (poll interval=%ss)" % ( + self._poll_interval + ) else: - message = 'Running sensor in passive mode' + message = "Running sensor in passive mode" self._logger.info(message) @@ -240,18 +274,20 @@ def run(self): self._sensor_instance.run() except Exception as e: # Include traceback - msg = ('Sensor "%s" run method raised an exception: %s.' % - (self._class_name, six.text_type(e))) + msg = 'Sensor "%s" run method raised an exception: %s.' % ( + self._class_name, + six.text_type(e), + ) self._logger.warn(msg, exc_info=True) raise Exception(msg) def stop(self): # Stop watcher - self._logger.info('Stopping trigger watcher') + self._logger.info("Stopping trigger watcher") self._trigger_watcher.stop() # Run sensor cleanup code - self._logger.info('Invoking cleanup on sensor') + self._logger.info("Invoking cleanup on sensor") self._sensor_instance.cleanup() ############################################## @@ -259,16 +295,18 @@ def stop(self): ############################################## def _handle_create_trigger(self, trigger): - self._logger.debug('Calling sensor "add_trigger" method (trigger.type=%s)' % - (trigger.type)) + self._logger.debug( + 'Calling sensor "add_trigger" method (trigger.type=%s)' % (trigger.type) + ) self._trigger_names[str(trigger.id)] = trigger trigger = self._sanitize_trigger(trigger=trigger) self._sensor_instance.add_trigger(trigger=trigger) def _handle_update_trigger(self, trigger): - self._logger.debug('Calling sensor "update_trigger" method (trigger.type=%s)' % - (trigger.type)) + self._logger.debug( + 'Calling sensor "update_trigger" method (trigger.type=%s)' % (trigger.type) + ) self._trigger_names[str(trigger.id)] = trigger trigger = self._sanitize_trigger(trigger=trigger) @@ -279,8 +317,9 @@ def _handle_delete_trigger(self, trigger): if trigger_id not in self._trigger_names: return - self._logger.debug('Calling sensor "remove_trigger" method (trigger.type=%s)' % - (trigger.type)) + self._logger.debug( + 'Calling sensor "remove_trigger" method (trigger.type=%s)' % (trigger.type) + ) del self._trigger_names[trigger_id] trigger = self._sanitize_trigger(trigger=trigger) @@ -294,35 +333,45 @@ def _get_sensor_instance(self): module_name, _ = os.path.splitext(filename) try: - sensor_class = loader.register_plugin_class(base_class=Sensor, - file_path=self._file_path, - class_name=self._class_name) + sensor_class = loader.register_plugin_class( + base_class=Sensor, + file_path=self._file_path, + class_name=self._class_name, + ) except Exception as e: tb_msg = traceback.format_exc() - msg = ('Failed to load sensor class from file "%s" (sensor file most likely doesn\'t ' - 'exist or contains invalid syntax): %s' % (self._file_path, six.text_type(e))) - msg += '\n\n' + tb_msg + msg = ( + 'Failed to load sensor class from file "%s" (sensor file most likely doesn\'t ' + "exist or contains invalid syntax): %s" + % (self._file_path, six.text_type(e)) + ) + msg += "\n\n" + tb_msg exc_cls = type(e) raise exc_cls(msg) if not sensor_class: - raise ValueError('Sensor module is missing a class with name "%s"' % - (self._class_name)) + raise ValueError( + 'Sensor module is missing a class with name "%s"' % (self._class_name) + ) sensor_class_kwargs = {} - sensor_class_kwargs['sensor_service'] = SensorService(sensor_wrapper=self) + sensor_class_kwargs["sensor_service"] = SensorService(sensor_wrapper=self) sensor_config = self._get_sensor_config() - sensor_class_kwargs['config'] = sensor_config + sensor_class_kwargs["config"] = sensor_config if self._poll_interval and issubclass(sensor_class, PollingSensor): - sensor_class_kwargs['poll_interval'] = self._poll_interval + sensor_class_kwargs["poll_interval"] = self._poll_interval try: sensor_instance = sensor_class(**sensor_class_kwargs) except Exception: - self._logger.exception('Failed to instantiate "%s" sensor class' % (self._class_name)) - raise Exception('Failed to instantiate "%s" sensor class' % (self._class_name)) + self._logger.exception( + 'Failed to instantiate "%s" sensor class' % (self._class_name) + ) + raise Exception( + 'Failed to instantiate "%s" sensor class' % (self._class_name) + ) return sensor_instance @@ -342,31 +391,43 @@ def _sanitize_trigger(self, trigger): return sanitized -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Sensor runner wrapper') - parser.add_argument('--pack', required=True, - help='Name of the pack this sensor belongs to') - parser.add_argument('--file-path', required=True, - help='Path to the sensor module') - parser.add_argument('--class-name', required=True, - help='Name of the sensor class') - parser.add_argument('--trigger-type-refs', required=False, - help='Comma delimited string of trigger type references') - parser.add_argument('--poll-interval', type=int, default=None, required=False, - help='Sensor poll interval') - parser.add_argument('--parent-args', required=False, - help='Command line arguments passed to the parent process') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Sensor runner wrapper") + parser.add_argument( + "--pack", required=True, help="Name of the pack this sensor belongs to" + ) + parser.add_argument("--file-path", required=True, help="Path to the sensor module") + parser.add_argument("--class-name", required=True, help="Name of the sensor class") + parser.add_argument( + "--trigger-type-refs", + required=False, + help="Comma delimited string of trigger type references", + ) + parser.add_argument( + "--poll-interval", + type=int, + default=None, + required=False, + help="Sensor poll interval", + ) + parser.add_argument( + "--parent-args", + required=False, + help="Command line arguments passed to the parent process", + ) args = parser.parse_args() trigger_types = args.trigger_type_refs - trigger_types = trigger_types.split(',') if trigger_types else [] + trigger_types = trigger_types.split(",") if trigger_types else [] parent_args = json.loads(args.parent_args) if args.parent_args else [] assert isinstance(parent_args, list) - obj = SensorWrapper(pack=args.pack, - file_path=args.file_path, - class_name=args.class_name, - trigger_types=trigger_types, - poll_interval=args.poll_interval, - parent_args=parent_args) + obj = SensorWrapper( + pack=args.pack, + file_path=args.file_path, + class_name=args.class_name, + trigger_types=trigger_types, + poll_interval=args.poll_interval, + parent_args=parent_args, + ) obj.run() diff --git a/st2reactor/st2reactor/container/utils.py b/st2reactor/st2reactor/container/utils.py index a156d209b05..6b059046274 100644 --- a/st2reactor/st2reactor/container/utils.py +++ b/st2reactor/st2reactor/container/utils.py @@ -22,10 +22,12 @@ from st2common.persistence.trigger import TriggerInstance from st2common.services.triggers import get_trigger_db_by_ref_or_dict -LOG = logging.getLogger('st2reactor.sensor.container_utils') +LOG = logging.getLogger("st2reactor.sensor.container_utils") -def create_trigger_instance(trigger, payload, occurrence_time, raise_on_no_trigger=False): +def create_trigger_instance( + trigger, payload, occurrence_time, raise_on_no_trigger=False +): """ This creates a trigger instance object given trigger and payload. Trigger can be just a string reference (pack.name) or a ``dict`` containing 'id' or @@ -40,9 +42,9 @@ def create_trigger_instance(trigger, payload, occurrence_time, raise_on_no_trigg trigger_db = get_trigger_db_by_ref_or_dict(trigger=trigger) if not trigger_db: - LOG.debug('No trigger in db for %s', trigger) + LOG.debug("No trigger in db for %s", trigger) if raise_on_no_trigger: - raise StackStormDBObjectNotFoundError('Trigger not found for %s' % trigger) + raise StackStormDBObjectNotFoundError("Trigger not found for %s" % trigger) return None trigger_ref = trigger_db.get_reference().ref diff --git a/st2reactor/st2reactor/garbage_collector/base.py b/st2reactor/st2reactor/garbage_collector/base.py index 32614586778..bb963e8e51c 100644 --- a/st2reactor/st2reactor/garbage_collector/base.py +++ b/st2reactor/st2reactor/garbage_collector/base.py @@ -42,16 +42,17 @@ from st2common.garbage_collection.inquiries import purge_inquiries from st2common.garbage_collection.trigger_instances import purge_trigger_instances -__all__ = [ - 'GarbageCollectorService' -] +__all__ = ["GarbageCollectorService"] LOG = logging.getLogger(__name__) class GarbageCollectorService(object): - def __init__(self, collection_interval=DEFAULT_COLLECTION_INTERVAL, - sleep_delay=DEFAULT_SLEEP_DELAY): + def __init__( + self, + collection_interval=DEFAULT_COLLECTION_INTERVAL, + sleep_delay=DEFAULT_SLEEP_DELAY, + ): """ :param collection_interval: How often to check database for old data and perform garbage collection. @@ -64,7 +65,9 @@ def __init__(self, collection_interval=DEFAULT_COLLECTION_INTERVAL, self._collection_interval = collection_interval self._action_executions_ttl = cfg.CONF.garbagecollector.action_executions_ttl - self._action_executions_output_ttl = cfg.CONF.garbagecollector.action_executions_output_ttl + self._action_executions_output_ttl = ( + cfg.CONF.garbagecollector.action_executions_output_ttl + ) self._trigger_instances_ttl = cfg.CONF.garbagecollector.trigger_instances_ttl self._purge_inquiries = cfg.CONF.garbagecollector.purge_inquiries self._workflow_execution_max_idle = cfg.CONF.workflow_engine.gc_max_idle_sec @@ -91,7 +94,7 @@ def run(self): self._running = False return SUCCESS_EXIT_CODE except Exception as e: - LOG.exception('Exception in the garbage collector: %s' % (six.text_type(e))) + LOG.exception("Exception in the garbage collector: %s" % (six.text_type(e))) self._running = False return FAILURE_EXIT_CODE @@ -101,7 +104,7 @@ def _register_signal_handlers(self): signal.signal(signal.SIGUSR2, self.handle_sigusr2) def handle_sigusr2(self, signal_number, stack_frame): - LOG.info('Forcing garbage collection...') + LOG.info("Forcing garbage collection...") self._perform_garbage_collection() def shutdown(self): @@ -111,61 +114,88 @@ def _main_loop(self): while self._running: self._perform_garbage_collection() - LOG.info('Sleeping for %s seconds before next garbage collection...' % - (self._collection_interval)) + LOG.info( + "Sleeping for %s seconds before next garbage collection..." + % (self._collection_interval) + ) concurrency.sleep(self._collection_interval) def _validate_ttl_values(self): """ Validate that a user has supplied reasonable TTL values. """ - if self._action_executions_ttl and self._action_executions_ttl < MINIMUM_TTL_DAYS: - raise ValueError('Minimum possible TTL for action_executions_ttl in days is %s' % - (MINIMUM_TTL_DAYS)) - - if self._trigger_instances_ttl and self._trigger_instances_ttl < MINIMUM_TTL_DAYS: - raise ValueError('Minimum possible TTL for trigger_instances_ttl in days is %s' % - (MINIMUM_TTL_DAYS)) - - if self._action_executions_output_ttl and \ - self._action_executions_output_ttl < MINIMUM_TTL_DAYS_EXECUTION_OUTPUT: - raise ValueError(('Minimum possible TTL for action_executions_output_ttl in days ' - 'is %s') % (MINIMUM_TTL_DAYS_EXECUTION_OUTPUT)) + if ( + self._action_executions_ttl + and self._action_executions_ttl < MINIMUM_TTL_DAYS + ): + raise ValueError( + "Minimum possible TTL for action_executions_ttl in days is %s" + % (MINIMUM_TTL_DAYS) + ) + + if ( + self._trigger_instances_ttl + and self._trigger_instances_ttl < MINIMUM_TTL_DAYS + ): + raise ValueError( + "Minimum possible TTL for trigger_instances_ttl in days is %s" + % (MINIMUM_TTL_DAYS) + ) + + if ( + self._action_executions_output_ttl + and self._action_executions_output_ttl < MINIMUM_TTL_DAYS_EXECUTION_OUTPUT + ): + raise ValueError( + ( + "Minimum possible TTL for action_executions_output_ttl in days " + "is %s" + ) + % (MINIMUM_TTL_DAYS_EXECUTION_OUTPUT) + ) def _perform_garbage_collection(self): - LOG.info('Performing garbage collection...') + LOG.info("Performing garbage collection...") proc_message = "Performing garbage collection for %s." skip_message = "Skipping garbage collection for %s since it's not configured." # Note: We sleep for a bit between garbage collection of each object type to prevent busy # waiting - obj_type = 'action executions' - if self._action_executions_ttl and self._action_executions_ttl >= MINIMUM_TTL_DAYS: + obj_type = "action executions" + if ( + self._action_executions_ttl + and self._action_executions_ttl >= MINIMUM_TTL_DAYS + ): LOG.info(proc_message, obj_type) self._purge_action_executions() concurrency.sleep(self._sleep_delay) else: LOG.debug(skip_message, obj_type) - obj_type = 'action executions output' - if self._action_executions_output_ttl and \ - self._action_executions_output_ttl >= MINIMUM_TTL_DAYS_EXECUTION_OUTPUT: + obj_type = "action executions output" + if ( + self._action_executions_output_ttl + and self._action_executions_output_ttl >= MINIMUM_TTL_DAYS_EXECUTION_OUTPUT + ): LOG.info(proc_message, obj_type) self._purge_action_executions_output() concurrency.sleep(self._sleep_delay) else: LOG.debug(skip_message, obj_type) - obj_type = 'trigger instances' - if self._trigger_instances_ttl and self._trigger_instances_ttl >= MINIMUM_TTL_DAYS: + obj_type = "trigger instances" + if ( + self._trigger_instances_ttl + and self._trigger_instances_ttl >= MINIMUM_TTL_DAYS + ): LOG.info(proc_message, obj_type) self._purge_trigger_instances() concurrency.sleep(self._sleep_delay) else: LOG.debug(skip_message, obj_type) - obj_type = 'inquiries' + obj_type = "inquiries" if self._purge_inquiries: LOG.info(proc_message, obj_type) self._timeout_inquiries() @@ -173,7 +203,7 @@ def _perform_garbage_collection(self): else: LOG.debug(skip_message, obj_type) - obj_type = 'orphaned workflow executions' + obj_type = "orphaned workflow executions" if self._workflow_execution_max_idle > 0: LOG.info(proc_message, obj_type) self._purge_orphaned_workflow_executions() @@ -187,41 +217,53 @@ def _purge_action_executions(self): the criteria defined in the config. """ utc_now = get_datetime_utc_now() - timestamp = (utc_now - datetime.timedelta(days=self._action_executions_ttl)) + timestamp = utc_now - datetime.timedelta(days=self._action_executions_ttl) # Another sanity check to make sure we don't delete new executions if timestamp > (utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS)): - raise ValueError('Calculated timestamp would violate the minimum TTL constraint') + raise ValueError( + "Calculated timestamp would violate the minimum TTL constraint" + ) timestamp_str = isotime.format(dt=timestamp) - LOG.info('Deleting action executions older than: %s' % (timestamp_str)) + LOG.info("Deleting action executions older than: %s" % (timestamp_str)) assert timestamp < utc_now try: purge_executions(logger=LOG, timestamp=timestamp) except Exception as e: - LOG.exception('Failed to delete executions: %s' % (six.text_type(e))) + LOG.exception("Failed to delete executions: %s" % (six.text_type(e))) return True def _purge_action_executions_output(self): utc_now = get_datetime_utc_now() - timestamp = (utc_now - datetime.timedelta(days=self._action_executions_output_ttl)) + timestamp = utc_now - datetime.timedelta( + days=self._action_executions_output_ttl + ) # Another sanity check to make sure we don't delete new objects - if timestamp > (utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS_EXECUTION_OUTPUT)): - raise ValueError('Calculated timestamp would violate the minimum TTL constraint') + if timestamp > ( + utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS_EXECUTION_OUTPUT) + ): + raise ValueError( + "Calculated timestamp would violate the minimum TTL constraint" + ) timestamp_str = isotime.format(dt=timestamp) - LOG.info('Deleting action executions output objects older than: %s' % (timestamp_str)) + LOG.info( + "Deleting action executions output objects older than: %s" % (timestamp_str) + ) assert timestamp < utc_now try: purge_execution_output_objects(logger=LOG, timestamp=timestamp) except Exception as e: - LOG.exception('Failed to delete execution output objects: %s' % (six.text_type(e))) + LOG.exception( + "Failed to delete execution output objects: %s" % (six.text_type(e)) + ) return True @@ -230,31 +272,32 @@ def _purge_trigger_instances(self): Purge trigger instances which match the criteria defined in the config. """ utc_now = get_datetime_utc_now() - timestamp = (utc_now - datetime.timedelta(days=self._trigger_instances_ttl)) + timestamp = utc_now - datetime.timedelta(days=self._trigger_instances_ttl) # Another sanity check to make sure we don't delete new executions if timestamp > (utc_now - datetime.timedelta(days=MINIMUM_TTL_DAYS)): - raise ValueError('Calculated timestamp would violate the minimum TTL constraint') + raise ValueError( + "Calculated timestamp would violate the minimum TTL constraint" + ) timestamp_str = isotime.format(dt=timestamp) - LOG.info('Deleting trigger instances older than: %s' % (timestamp_str)) + LOG.info("Deleting trigger instances older than: %s" % (timestamp_str)) assert timestamp < utc_now try: purge_trigger_instances(logger=LOG, timestamp=timestamp) except Exception as e: - LOG.exception('Failed to trigger instances: %s' % (six.text_type(e))) + LOG.exception("Failed to trigger instances: %s" % (six.text_type(e))) return True def _timeout_inquiries(self): - """Mark Inquiries as "timeout" that have exceeded their TTL - """ + """Mark Inquiries as "timeout" that have exceeded their TTL""" try: purge_inquiries(logger=LOG) except Exception as e: - LOG.exception('Failed to purge inquiries: %s' % (six.text_type(e))) + LOG.exception("Failed to purge inquiries: %s" % (six.text_type(e))) return True @@ -265,6 +308,8 @@ def _purge_orphaned_workflow_executions(self): try: purge_orphaned_workflow_executions(logger=LOG) except Exception as e: - LOG.exception('Failed to purge orphaned workflow executions: %s' % (six.text_type(e))) + LOG.exception( + "Failed to purge orphaned workflow executions: %s" % (six.text_type(e)) + ) return True diff --git a/st2reactor/st2reactor/garbage_collector/config.py b/st2reactor/st2reactor/garbage_collector/config.py index 19cf53362e9..9a0faf0dec0 100644 --- a/st2reactor/st2reactor/garbage_collector/config.py +++ b/st2reactor/st2reactor/garbage_collector/config.py @@ -29,8 +29,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -49,48 +52,62 @@ def _register_common_opts(): def _register_garbage_collector_opts(): logging_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.garbagecollector.conf', - help='Location of the logging configuration file.') + "logging", + default="/etc/st2/logging.garbagecollector.conf", + help="Location of the logging configuration file.", + ) ] - CONF.register_opts(logging_opts, group='garbagecollector') + CONF.register_opts(logging_opts, group="garbagecollector") common_opts = [ cfg.IntOpt( - 'collection_interval', default=DEFAULT_COLLECTION_INTERVAL, - help='How often to check database for old data and perform garbage collection.'), + "collection_interval", + default=DEFAULT_COLLECTION_INTERVAL, + help="How often to check database for old data and perform garbage collection.", + ), cfg.FloatOpt( - 'sleep_delay', default=DEFAULT_SLEEP_DELAY, - help='How long to wait / sleep (in seconds) between ' - 'collection of different object types.') + "sleep_delay", + default=DEFAULT_SLEEP_DELAY, + help="How long to wait / sleep (in seconds) between " + "collection of different object types.", + ), ] - CONF.register_opts(common_opts, group='garbagecollector') + CONF.register_opts(common_opts, group="garbagecollector") ttl_opts = [ cfg.IntOpt( - 'action_executions_ttl', default=None, - help='Action executions and related objects (live actions, action output ' - 'objects) older than this value (days) will be automatically deleted.'), + "action_executions_ttl", + default=None, + help="Action executions and related objects (live actions, action output " + "objects) older than this value (days) will be automatically deleted.", + ), cfg.IntOpt( - 'action_executions_output_ttl', default=7, - help='Action execution output objects (ones generated by action output ' - 'streaming) older than this value (days) will be automatically deleted.'), + "action_executions_output_ttl", + default=7, + help="Action execution output objects (ones generated by action output " + "streaming) older than this value (days) will be automatically deleted.", + ), cfg.IntOpt( - 'trigger_instances_ttl', default=None, - help='Trigger instances older than this value (days) will be automatically deleted.') + "trigger_instances_ttl", + default=None, + help="Trigger instances older than this value (days) will be automatically deleted.", + ), ] - CONF.register_opts(ttl_opts, group='garbagecollector') + CONF.register_opts(ttl_opts, group="garbagecollector") inquiry_opts = [ cfg.BoolOpt( - 'purge_inquiries', default=False, - help='Set to True to perform garbage collection on Inquiries (based on ' - 'the TTL value per Inquiry)') + "purge_inquiries", + default=False, + help="Set to True to perform garbage collection on Inquiries (based on " + "the TTL value per Inquiry)", + ) ] - CONF.register_opts(inquiry_opts, group='garbagecollector') + CONF.register_opts(inquiry_opts, group="garbagecollector") register_opts() diff --git a/st2reactor/st2reactor/rules/config.py b/st2reactor/st2reactor/rules/config.py index 004c45b8708..637ef4e4573 100644 --- a/st2reactor/st2reactor/rules/config.py +++ b/st2reactor/st2reactor/rules/config.py @@ -27,8 +27,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -47,11 +50,13 @@ def _register_common_opts(): def _register_rules_engine_opts(): logging_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.rulesengine.conf', - help='Location of the logging configuration file.') + "logging", + default="/etc/st2/logging.rulesengine.conf", + help="Location of the logging configuration file.", + ) ] - CONF.register_opts(logging_opts, group='rulesengine') + CONF.register_opts(logging_opts, group="rulesengine") register_opts() diff --git a/st2reactor/st2reactor/rules/enforcer.py b/st2reactor/st2reactor/rules/enforcer.py index 594f157482c..4d34b86ce2d 100644 --- a/st2reactor/st2reactor/rules/enforcer.py +++ b/st2reactor/st2reactor/rules/enforcer.py @@ -40,15 +40,15 @@ from st2common.exceptions import param as param_exc from st2common.exceptions import apivalidation as validation_exc -__all__ = [ - 'RuleEnforcer' -] +__all__ = ["RuleEnforcer"] -LOG = logging.getLogger('st2reactor.ruleenforcement.enforce') +LOG = logging.getLogger("st2reactor.ruleenforcement.enforce") -EXEC_KICKED_OFF_STATES = [action_constants.LIVEACTION_STATUS_SCHEDULED, - action_constants.LIVEACTION_STATUS_REQUESTED] +EXEC_KICKED_OFF_STATES = [ + action_constants.LIVEACTION_STATUS_SCHEDULED, + action_constants.LIVEACTION_STATUS_REQUESTED, +] class RuleEnforcer(object): @@ -58,95 +58,117 @@ def __init__(self, trigger_instance, rule): def get_action_execution_context(self, action_db, trace_context=None): context = { - 'trigger_instance': reference.get_ref_from_model(self.trigger_instance), - 'rule': reference.get_ref_from_model(self.rule), - 'user': get_system_username(), - 'pack': action_db.pack, + "trigger_instance": reference.get_ref_from_model(self.trigger_instance), + "rule": reference.get_ref_from_model(self.rule), + "user": get_system_username(), + "pack": action_db.pack, } if trace_context is not None: context[TRACE_CONTEXT] = trace_context # Additional non-action / global context - additional_context = { - TRIGGER_PAYLOAD_PREFIX: self.trigger_instance.payload - } + additional_context = {TRIGGER_PAYLOAD_PREFIX: self.trigger_instance.payload} return context, additional_context - def get_resolved_parameters(self, action_db, runnertype_db, params, context=None, - additional_contexts=None): + def get_resolved_parameters( + self, action_db, runnertype_db, params, context=None, additional_contexts=None + ): resolved_params = param_utils.render_live_params( runner_parameters=runnertype_db.runner_parameters, action_parameters=action_db.parameters, params=params, action_context=context, - additional_contexts=additional_contexts) + additional_contexts=additional_contexts, + ) return resolved_params def enforce(self): - rule_spec = {'ref': self.rule.ref, 'id': str(self.rule.id), 'uid': self.rule.uid} - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(self.trigger_instance.id), - rule=rule_spec) - extra = { - 'trigger_instance_db': self.trigger_instance, - 'rule_db': self.rule + rule_spec = { + "ref": self.rule.ref, + "id": str(self.rule.id), + "uid": self.rule.uid, } + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(self.trigger_instance.id), rule=rule_spec + ) + extra = {"trigger_instance_db": self.trigger_instance, "rule_db": self.rule} execution_db = None try: execution_db = self._do_enforce() # pylint: disable=no-member enforcement_db.execution_id = str(execution_db.id) enforcement_db.status = RULE_ENFORCEMENT_STATUS_SUCCEEDED - extra['execution_db'] = execution_db + extra["execution_db"] = execution_db except Exception as e: # Record the failure reason in the RuleEnforcement. enforcement_db.status = RULE_ENFORCEMENT_STATUS_FAILED enforcement_db.failure_reason = six.text_type(e) - LOG.exception('Failed kicking off execution for rule %s.', self.rule, extra=extra) + LOG.exception( + "Failed kicking off execution for rule %s.", self.rule, extra=extra + ) finally: self._update_enforcement(enforcement_db) # pylint: disable=no-member if not execution_db or execution_db.status not in EXEC_KICKED_OFF_STATES: - LOG.audit('Rule enforcement failed. Execution of Action %s failed. ' - 'TriggerInstance: %s and Rule: %s', - self.rule.action.ref, self.trigger_instance, self.rule, - extra=extra) + LOG.audit( + "Rule enforcement failed. Execution of Action %s failed. " + "TriggerInstance: %s and Rule: %s", + self.rule.action.ref, + self.trigger_instance, + self.rule, + extra=extra, + ) else: - LOG.audit('Rule enforced. Execution %s, TriggerInstance %s and Rule %s.', - execution_db, self.trigger_instance, self.rule, extra=extra) + LOG.audit( + "Rule enforced. Execution %s, TriggerInstance %s and Rule %s.", + execution_db, + self.trigger_instance, + self.rule, + extra=extra, + ) return execution_db def _do_enforce(self): # TODO: Refactor this to avoid additional lookup in cast_params - action_ref = self.rule.action['ref'] + action_ref = self.rule.action["ref"] # Verify action referenced in the rule exists in the database action_db = action_utils.get_action_by_ref(action_ref) if not action_db: raise ValueError('Action "%s" doesn\'t exist' % (action_ref)) - runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_utils.get_runnertype_by_name( + action_db.runner_type["name"] + ) params = self.rule.action.parameters - LOG.info('Invoking action %s for trigger_instance %s with params %s.', - self.rule.action.ref, self.trigger_instance.id, - json.dumps(params)) + LOG.info( + "Invoking action %s for trigger_instance %s with params %s.", + self.rule.action.ref, + self.trigger_instance.id, + json.dumps(params), + ) # update trace before invoking the action. trace_context = self._update_trace() - LOG.debug('Updated trace %s with rule %s.', trace_context, self.rule.id) + LOG.debug("Updated trace %s with rule %s.", trace_context, self.rule.id) context, additional_contexts = self.get_action_execution_context( - action_db=action_db, - trace_context=trace_context) + action_db=action_db, trace_context=trace_context + ) - return self._invoke_action(action_db=action_db, runnertype_db=runnertype_db, params=params, - context=context, - additional_contexts=additional_contexts) + return self._invoke_action( + action_db=action_db, + runnertype_db=runnertype_db, + params=params, + context=context, + additional_contexts=additional_contexts, + ) def _update_trace(self): """ @@ -154,9 +176,13 @@ def _update_trace(self): """ trace_db = None try: - trace_db = trace_service.get_trace_db_by_trigger_instance(self.trigger_instance) + trace_db = trace_service.get_trace_db_by_trigger_instance( + self.trigger_instance + ) except: - LOG.exception('No Trace found for TriggerInstance %s.', self.trigger_instance.id) + LOG.exception( + "No Trace found for TriggerInstance %s.", self.trigger_instance.id + ) return None # This would signify some sort of coding error so assert. @@ -165,19 +191,23 @@ def _update_trace(self): trace_db = trace_service.add_or_update_given_trace_db( trace_db=trace_db, rules=[ - trace_service.get_trace_component_for_rule(self.rule, self.trigger_instance) - ]) + trace_service.get_trace_component_for_rule( + self.rule, self.trigger_instance + ) + ], + ) return vars(TraceContext(id_=str(trace_db.id), trace_tag=trace_db.trace_tag)) def _update_enforcement(self, enforcement_db): try: RuleEnforcement.add_or_update(enforcement_db) except: - extra = {'enforcement_db': enforcement_db} - LOG.exception('Failed writing enforcement model to db.', extra=extra) + extra = {"enforcement_db": enforcement_db} + LOG.exception("Failed writing enforcement model to db.", extra=extra) - def _invoke_action(self, action_db, runnertype_db, params, context=None, - additional_contexts=None): + def _invoke_action( + self, action_db, runnertype_db, params, context=None, additional_contexts=None + ): """ Schedule an action execution. @@ -189,9 +219,13 @@ def _invoke_action(self, action_db, runnertype_db, params, context=None, :rtype: :class:`LiveActionDB` on successful scheduling, None otherwise. """ action_ref = action_db.ref - runnertype_db = action_utils.get_runnertype_by_name(action_db.runner_type['name']) + runnertype_db = action_utils.get_runnertype_by_name( + action_db.runner_type["name"] + ) - liveaction_db = LiveActionDB(action=action_ref, context=context, parameters=params) + liveaction_db = LiveActionDB( + action=action_ref, context=context, parameters=params + ) try: liveaction_db.parameters = self.get_resolved_parameters( @@ -199,7 +233,8 @@ def _invoke_action(self, action_db, runnertype_db, params, context=None, action_db=action_db, params=liveaction_db.parameters, context=liveaction_db.context, - additional_contexts=additional_contexts) + additional_contexts=additional_contexts, + ) except param_exc.ParamException as e: # We still need to create a request, so liveaction_db is assigned an ID liveaction_db, execution_db = action_service.create_request(liveaction_db) @@ -209,8 +244,11 @@ def _invoke_action(self, action_db, runnertype_db, params, context=None, action_service.update_status( liveaction=liveaction_db, new_status=action_constants.LIVEACTION_STATUS_FAILED, - result={'error': six.text_type(e), - 'traceback': ''.join(traceback.format_tb(tb, 20))}) + result={ + "error": six.text_type(e), + "traceback": "".join(traceback.format_tb(tb, 20)), + }, + ) # Might be a good idea to return the actual ActionExecution rather than bubble up # the exception. diff --git a/st2reactor/st2reactor/rules/engine.py b/st2reactor/st2reactor/rules/engine.py index 1d50d01c9e3..453a0457da8 100644 --- a/st2reactor/st2reactor/rules/engine.py +++ b/st2reactor/st2reactor/rules/engine.py @@ -21,11 +21,9 @@ from st2reactor.rules.matcher import RulesMatcher from st2common.metrics.base import get_driver -LOG = logging.getLogger('st2reactor.rules.RulesEngine') +LOG = logging.getLogger("st2reactor.rules.RulesEngine") -__all__ = [ - 'RulesEngine' -] +__all__ = ["RulesEngine"] class RulesEngine(object): @@ -40,7 +38,10 @@ def handle_trigger_instance(self, trigger_instance): # Enforce the rules. self.enforce_rules(enforcers) else: - LOG.info('No matching rules found for trigger instance %s.', trigger_instance['id']) + LOG.info( + "No matching rules found for trigger instance %s.", + trigger_instance["id"], + ) def get_matching_rules_for_trigger(self, trigger_instance): trigger = trigger_instance.trigger @@ -48,23 +49,34 @@ def get_matching_rules_for_trigger(self, trigger_instance): trigger_db = get_trigger_db_by_ref(trigger_instance.trigger) if not trigger_db: - LOG.error('No matching trigger found in db for trigger instance %s.', trigger_instance) + LOG.error( + "No matching trigger found in db for trigger instance %s.", + trigger_instance, + ) return None rules = get_rules_given_trigger(trigger=trigger) - LOG.info('Found %d rules defined for trigger %s', len(rules), - trigger_db.get_reference().ref) + LOG.info( + "Found %d rules defined for trigger %s", + len(rules), + trigger_db.get_reference().ref, + ) if len(rules) < 1: return rules - matcher = RulesMatcher(trigger_instance=trigger_instance, - trigger=trigger_db, rules=rules) + matcher = RulesMatcher( + trigger_instance=trigger_instance, trigger=trigger_db, rules=rules + ) matching_rules = matcher.get_matching_rules() - LOG.info('Matched %s rule(s) for trigger_instance %s (trigger=%s)', len(matching_rules), - trigger_instance['id'], trigger_db.ref) + LOG.info( + "Matched %s rule(s) for trigger_instance %s (trigger=%s)", + len(matching_rules), + trigger_instance["id"], + trigger_db.ref, + ) return matching_rules def create_rule_enforcers(self, trigger_instance, matching_rules): @@ -78,8 +90,8 @@ def create_rule_enforcers(self, trigger_instance, matching_rules): enforcers = [] for matching_rule in matching_rules: - metrics_driver.inc_counter('rule.matched') - metrics_driver.inc_counter('rule.%s.matched' % (matching_rule.ref)) + metrics_driver.inc_counter("rule.matched") + metrics_driver.inc_counter("rule.%s.matched" % (matching_rule.ref)) enforcers.append(RuleEnforcer(trigger_instance, matching_rule)) return enforcers @@ -89,4 +101,4 @@ def enforce_rules(self, enforcers): try: enforcer.enforce() # Should this happen in an eventlet pool? except: - LOG.exception('Exception enforcing rule %s.', enforcer.rule) + LOG.exception("Exception enforcing rule %s.", enforcer.rule) diff --git a/st2reactor/st2reactor/rules/filter.py b/st2reactor/st2reactor/rules/filter.py index 1c675381983..700d072c31c 100644 --- a/st2reactor/st2reactor/rules/filter.py +++ b/st2reactor/st2reactor/rules/filter.py @@ -31,12 +31,10 @@ from st2common.util.payload import PayloadLookup from st2common.util.templating import render_template_with_system_context -__all__ = [ - 'RuleFilter' -] +__all__ = ["RuleFilter"] -LOG = logging.getLogger('st2reactor.ruleenforcement.filter') +LOG = logging.getLogger("st2reactor.ruleenforcement.filter") class RuleFilter(object): @@ -58,9 +56,9 @@ def __init__(self, trigger_instance, trigger, rule, extra_info=False): # Base context used with a logger self._base_logger_context = { - 'rule': self.rule, - 'trigger': self.trigger, - 'trigger_instance': self.trigger_instance + "rule": self.rule, + "trigger": self.trigger, + "trigger_instance": self.trigger_instance, } def filter(self): @@ -69,12 +67,18 @@ def filter(self): :rtype: ``bool`` """ - LOG.info('Validating rule %s for %s.', self.rule.ref, self.trigger['name'], - extra=self._base_logger_context) + LOG.info( + "Validating rule %s for %s.", + self.rule.ref, + self.trigger["name"], + extra=self._base_logger_context, + ) if not self.rule.enabled: if self.extra_info: - LOG.info('Validation failed for rule %s as it is disabled.', self.rule.ref) + LOG.info( + "Validation failed for rule %s as it is disabled.", self.rule.ref + ) return False criteria = self.rule.criteria @@ -85,52 +89,66 @@ def filter(self): payload_lookup = PayloadLookup(self.trigger_instance.payload) - LOG.debug('Trigger payload: %s', self.trigger_instance.payload, - extra=self._base_logger_context) + LOG.debug( + "Trigger payload: %s", + self.trigger_instance.payload, + extra=self._base_logger_context, + ) for (criterion_k, criterion_v) in six.iteritems(criteria): - is_rule_applicable, payload_value, criterion_pattern = self._check_criterion( - criterion_k, - criterion_v, - payload_lookup - ) + ( + is_rule_applicable, + payload_value, + criterion_pattern, + ) = self._check_criterion(criterion_k, criterion_v, payload_lookup) if not is_rule_applicable: if self.extra_info: - criteria_extra_info = '\n'.join([ - ' key: %s' % criterion_k, - ' pattern: %s' % criterion_pattern, - ' type: %s' % criterion_v['type'], - ' payload: %s' % payload_value - ]) - LOG.info('Validation for rule %s failed on criteria -\n%s', self.rule.ref, - criteria_extra_info, - extra=self._base_logger_context) + criteria_extra_info = "\n".join( + [ + " key: %s" % criterion_k, + " pattern: %s" % criterion_pattern, + " type: %s" % criterion_v["type"], + " payload: %s" % payload_value, + ] + ) + LOG.info( + "Validation for rule %s failed on criteria -\n%s", + self.rule.ref, + criteria_extra_info, + extra=self._base_logger_context, + ) break if not is_rule_applicable: - LOG.debug('Rule %s not applicable for %s.', self.rule.id, self.trigger['name'], - extra=self._base_logger_context) + LOG.debug( + "Rule %s not applicable for %s.", + self.rule.id, + self.trigger["name"], + extra=self._base_logger_context, + ) return is_rule_applicable def _check_criterion(self, criterion_k, criterion_v, payload_lookup): - if 'type' not in criterion_v: + if "type" not in criterion_v: # Comparison operator type not specified, can't perform a comparison return (False, None, None) - criteria_operator = criterion_v['type'] - criteria_condition = criterion_v.get('condition', None) - criteria_pattern = criterion_v.get('pattern', None) + criteria_operator = criterion_v["type"] + criteria_condition = criterion_v.get("condition", None) + criteria_pattern = criterion_v.get("pattern", None) # Render the pattern (it can contain a jinja expressions) try: criteria_pattern = self._render_criteria_pattern( criteria_pattern=criteria_pattern, - criteria_context=payload_lookup.context + criteria_context=payload_lookup.context, ) except Exception as e: - msg = ('Failed to render pattern value "%s" for key "%s"' % (criteria_pattern, - criterion_k)) + msg = 'Failed to render pattern value "%s" for key "%s"' % ( + criteria_pattern, + criterion_k, + ) LOG.exception(msg, extra=self._base_logger_context) self._create_rule_enforcement(failure_reason=msg, exc=e) @@ -144,7 +162,7 @@ def _check_criterion(self, criterion_k, criterion_v, payload_lookup): else: payload_value = None except Exception as e: - msg = ('Failed transforming criteria key %s' % criterion_k) + msg = "Failed transforming criteria key %s" % criterion_k LOG.exception(msg, extra=self._base_logger_context) self._create_rule_enforcement(failure_reason=msg, exc=e) @@ -154,13 +172,18 @@ def _check_criterion(self, criterion_k, criterion_v, payload_lookup): try: if criteria_operator == criteria_operators.SEARCH: - result = op_func(value=payload_value, criteria_pattern=criteria_pattern, - criteria_condition=criteria_condition, - check_function=self._bool_criterion) + result = op_func( + value=payload_value, + criteria_pattern=criteria_pattern, + criteria_condition=criteria_condition, + check_function=self._bool_criterion, + ) else: result = op_func(value=payload_value, criteria_pattern=criteria_pattern) except Exception as e: - msg = ('There might be a problem with the criteria in rule %s' % (self.rule.ref)) + msg = "There might be a problem with the criteria in rule %s" % ( + self.rule.ref + ) LOG.exception(msg, extra=self._base_logger_context) self._create_rule_enforcement(failure_reason=msg, exc=e) @@ -185,9 +208,9 @@ def _render_criteria_pattern(self, criteria_pattern, criteria_context): return criteria_pattern LOG.debug( - 'Rendering criteria pattern (%s) with context: %s', + "Rendering criteria pattern (%s) with context: %s", criteria_pattern, - criteria_context + criteria_context, ) to_complex = False @@ -197,30 +220,24 @@ def _render_criteria_pattern(self, criteria_pattern, criteria_context): if len(re.findall(MATCH_CRITERIA, criteria_pattern)) > 0: LOG.debug("Rendering Complex") complex_criteria_pattern = re.sub( - MATCH_CRITERIA, r'\1\2 | to_complex\3', - criteria_pattern + MATCH_CRITERIA, r"\1\2 | to_complex\3", criteria_pattern ) try: criteria_rendered = render_template_with_system_context( - value=complex_criteria_pattern, - context=criteria_context + value=complex_criteria_pattern, context=criteria_context ) criteria_rendered = json.loads(criteria_rendered) to_complex = True except ValueError as error: - LOG.debug('Criteria pattern not valid JSON: %s', error) + LOG.debug("Criteria pattern not valid JSON: %s", error) if not to_complex: criteria_rendered = render_template_with_system_context( - value=criteria_pattern, - context=criteria_context + value=criteria_pattern, context=criteria_context ) - LOG.debug( - 'Rendered criteria pattern: %s', - criteria_rendered - ) + LOG.debug("Rendered criteria pattern: %s", criteria_rendered) return criteria_rendered @@ -231,19 +248,32 @@ def _create_rule_enforcement(self, failure_reason, exc): Without that, only way for users to find out about those failes matches is by inspecting the logs. """ - failure_reason = ('Failed to match rule "%s" against trigger instance "%s": %s: %s' % - (self.rule.ref, str(self.trigger_instance.id), failure_reason, str(exc))) - rule_spec = {'ref': self.rule.ref, 'id': str(self.rule.id), 'uid': self.rule.uid} - enforcement_db = RuleEnforcementDB(trigger_instance_id=str(self.trigger_instance.id), - rule=rule_spec, - failure_reason=failure_reason, - status=RULE_ENFORCEMENT_STATUS_FAILED) + failure_reason = ( + 'Failed to match rule "%s" against trigger instance "%s": %s: %s' + % ( + self.rule.ref, + str(self.trigger_instance.id), + failure_reason, + str(exc), + ) + ) + rule_spec = { + "ref": self.rule.ref, + "id": str(self.rule.id), + "uid": self.rule.uid, + } + enforcement_db = RuleEnforcementDB( + trigger_instance_id=str(self.trigger_instance.id), + rule=rule_spec, + failure_reason=failure_reason, + status=RULE_ENFORCEMENT_STATUS_FAILED, + ) try: RuleEnforcement.add_or_update(enforcement_db) except: - extra = {'enforcement_db': enforcement_db} - LOG.exception('Failed writing enforcement model to db.', extra=extra) + extra = {"enforcement_db": enforcement_db} + LOG.exception("Failed writing enforcement model to db.", extra=extra) return enforcement_db @@ -253,6 +283,7 @@ class SecondPassRuleFilter(RuleFilter): Special filter that handles all second pass rules. For not these are only backstop rules i.e. those that can match when no other rule has matched. """ + def __init__(self, trigger_instance, trigger, rule, first_pass_matched): """ :param trigger_instance: TriggerInstance DB object. @@ -277,4 +308,4 @@ def filter(self): return super(SecondPassRuleFilter, self).filter() def _is_backstop_rule(self): - return self.rule.type['ref'] == RULE_TYPE_BACKSTOP + return self.rule.type["ref"] == RULE_TYPE_BACKSTOP diff --git a/st2reactor/st2reactor/rules/matcher.py b/st2reactor/st2reactor/rules/matcher.py index b2ed1989455..4b3a8a2483d 100644 --- a/st2reactor/st2reactor/rules/matcher.py +++ b/st2reactor/st2reactor/rules/matcher.py @@ -18,7 +18,7 @@ from st2common.constants.rules import RULE_TYPE_BACKSTOP from st2reactor.rules.filter import RuleFilter, SecondPassRuleFilter -LOG = logging.getLogger('st2reactor.rules.RulesMatcher') +LOG = logging.getLogger("st2reactor.rules.RulesMatcher") class RulesMatcher(object): @@ -31,25 +31,44 @@ def __init__(self, trigger_instance, trigger, rules, extra_info=False): def get_matching_rules(self): first_pass, second_pass = self._split_rules_into_passes() # first pass - rule_filters = [RuleFilter(trigger_instance=self.trigger_instance, - trigger=self.trigger, - rule=rule, - extra_info=self.extra_info) - for rule in first_pass] - matched_rules = [rule_filter.rule for rule_filter in rule_filters if rule_filter.filter()] - LOG.debug('[1st_pass] %d rule(s) found to enforce for %s.', len(matched_rules), - self.trigger['name']) + rule_filters = [ + RuleFilter( + trigger_instance=self.trigger_instance, + trigger=self.trigger, + rule=rule, + extra_info=self.extra_info, + ) + for rule in first_pass + ] + matched_rules = [ + rule_filter.rule for rule_filter in rule_filters if rule_filter.filter() + ] + LOG.debug( + "[1st_pass] %d rule(s) found to enforce for %s.", + len(matched_rules), + self.trigger["name"], + ) # second pass - rule_filters = [SecondPassRuleFilter(self.trigger_instance, self.trigger, rule, - matched_rules) - for rule in second_pass] - matched_in_second_pass = [rule_filter.rule for rule_filter in rule_filters - if rule_filter.filter()] - LOG.debug('[2nd_pass] %d rule(s) found to enforce for %s.', len(matched_in_second_pass), - self.trigger['name']) + rule_filters = [ + SecondPassRuleFilter( + self.trigger_instance, self.trigger, rule, matched_rules + ) + for rule in second_pass + ] + matched_in_second_pass = [ + rule_filter.rule for rule_filter in rule_filters if rule_filter.filter() + ] + LOG.debug( + "[2nd_pass] %d rule(s) found to enforce for %s.", + len(matched_in_second_pass), + self.trigger["name"], + ) matched_rules.extend(matched_in_second_pass) - LOG.info('%d rule(s) found to enforce for %s.', len(matched_rules), - self.trigger['name']) + LOG.info( + "%d rule(s) found to enforce for %s.", + len(matched_rules), + self.trigger["name"], + ) return matched_rules def _split_rules_into_passes(self): @@ -68,4 +87,4 @@ def _split_rules_into_passes(self): return first_pass, second_pass def _is_first_pass_rule(self, rule): - return rule.type['ref'] != RULE_TYPE_BACKSTOP + return rule.type["ref"] != RULE_TYPE_BACKSTOP diff --git a/st2reactor/st2reactor/rules/tester.py b/st2reactor/st2reactor/rules/tester.py index 790148d82de..da4e3572c5f 100644 --- a/st2reactor/st2reactor/rules/tester.py +++ b/st2reactor/st2reactor/rules/tester.py @@ -32,16 +32,19 @@ from st2reactor.rules.enforcer import RuleEnforcer from st2reactor.rules.matcher import RulesMatcher -__all__ = [ - 'RuleTester' -] +__all__ = ["RuleTester"] LOG = logging.getLogger(__name__) class RuleTester(object): - def __init__(self, rule_file_path=None, rule_ref=None, trigger_instance_file_path=None, - trigger_instance_id=None): + def __init__( + self, + rule_file_path=None, + rule_ref=None, + trigger_instance_file_path=None, + trigger_instance_id=None, + ): """ :param rule_file_path: Path to the file containing rule definition. :type rule_file_path: ``str`` @@ -69,13 +72,20 @@ def evaluate(self): # The trigger check needs to be performed here as that is not performed # by RulesMatcher. if rule_db.trigger != trigger_db.ref: - LOG.info('rule.trigger "%s" and trigger.ref "%s" do not match.', - rule_db.trigger, trigger_db.ref) + LOG.info( + 'rule.trigger "%s" and trigger.ref "%s" do not match.', + rule_db.trigger, + trigger_db.ref, + ) return False # Check if rule matches criteria. - matcher = RulesMatcher(trigger_instance=trigger_instance_db, trigger=trigger_db, - rules=[rule_db], extra_info=True) + matcher = RulesMatcher( + trigger_instance=trigger_instance_db, + trigger=trigger_db, + rules=[rule_db], + extra_info=True, + ) matching_rules = matcher.get_matching_rules() # Rule does not match so early exit. @@ -91,69 +101,86 @@ def evaluate(self): action_db.parameters = {} params = rule_db.action.parameters # pylint: disable=no-member - context, additional_contexts = enforcer.get_action_execution_context(action_db=action_db, - trace_context=None) + context, additional_contexts = enforcer.get_action_execution_context( + action_db=action_db, trace_context=None + ) # Note: We only return partially resolved parameters. # To be able to return all parameters we would need access to corresponding ActionDB, # RunnerTypeDB and ConfigDB object, but this would add a dependency on the database and the # tool is meant to be used standalone. try: - params = enforcer.get_resolved_parameters(action_db=action_db, - runnertype_db=runner_type_db, - params=params, - context=context, - additional_contexts=additional_contexts) - - LOG.info('Action parameters resolved to:') + params = enforcer.get_resolved_parameters( + action_db=action_db, + runnertype_db=runner_type_db, + params=params, + context=context, + additional_contexts=additional_contexts, + ) + + LOG.info("Action parameters resolved to:") for param in six.iteritems(params): - LOG.info('\t%s: %s', param[0], param[1]) + LOG.info("\t%s: %s", param[0], param[1]) return True except (UndefinedError, ValueError) as e: - LOG.error('Failed to resolve parameters\n\tOriginal error : %s', six.text_type(e)) + LOG.error( + "Failed to resolve parameters\n\tOriginal error : %s", six.text_type(e) + ) return False except: - LOG.exception('Failed to resolve parameters.') + LOG.exception("Failed to resolve parameters.") return False def _get_rule_db(self): if self._rule_file_path: return self._get_rule_db_from_file( - file_path=os.path.realpath(self._rule_file_path)) + file_path=os.path.realpath(self._rule_file_path) + ) elif self._rule_ref: return Rule.get_by_ref(self._rule_ref) - raise ValueError('One of _rule_file_path or _rule_ref should be specified.') + raise ValueError("One of _rule_file_path or _rule_ref should be specified.") def _get_trigger_instance_db(self): if self._trigger_instance_file_path: return self._get_trigger_instance_db_from_file( - file_path=os.path.realpath(self._trigger_instance_file_path)) + file_path=os.path.realpath(self._trigger_instance_file_path) + ) elif self._trigger_instance_id: trigger_instance_db = TriggerInstance.get_by_id(self._trigger_instance_id) trigger_db = Trigger.get_by_ref(trigger_instance_db.trigger) return trigger_instance_db, trigger_db - raise ValueError('One of _trigger_instance_file_path or' - '_trigger_instance_id should be specified.') + raise ValueError( + "One of _trigger_instance_file_path or" + "_trigger_instance_id should be specified." + ) def _get_rule_db_from_file(self, file_path): data = self._meta_loader.load(file_path=file_path) - pack = data.get('pack', 'unknown') - name = data.get('name', 'unknown') - trigger = data['trigger']['type'] - criteria = data.get('criteria', None) - action = data.get('action', {}) - - rule_db = RuleDB(pack=pack, name=name, trigger=trigger, criteria=criteria, action=action, - enabled=True) - rule_db.id = 'rule_tester_rule' + pack = data.get("pack", "unknown") + name = data.get("name", "unknown") + trigger = data["trigger"]["type"] + criteria = data.get("criteria", None) + action = data.get("action", {}) + + rule_db = RuleDB( + pack=pack, + name=name, + trigger=trigger, + criteria=criteria, + action=action, + enabled=True, + ) + rule_db.id = "rule_tester_rule" return rule_db def _get_trigger_instance_db_from_file(self, file_path): data = self._meta_loader.load(file_path=file_path) instance = TriggerInstanceDB(**data) - instance.id = 'rule_tester_instance' + instance.id = "rule_tester_instance" - trigger_ref = ResourceReference.from_string_reference(instance['trigger']) - trigger_db = TriggerDB(pack=trigger_ref.pack, name=trigger_ref.name, type=trigger_ref.ref) + trigger_ref = ResourceReference.from_string_reference(instance["trigger"]) + trigger_db = TriggerDB( + pack=trigger_ref.pack, name=trigger_ref.name, type=trigger_ref.ref + ) return instance, trigger_db diff --git a/st2reactor/st2reactor/rules/worker.py b/st2reactor/st2reactor/rules/worker.py index 7dbe4a59e1b..53e636a346b 100644 --- a/st2reactor/st2reactor/rules/worker.py +++ b/st2reactor/st2reactor/rules/worker.py @@ -41,12 +41,12 @@ def __init__(self, connection, queues): self.rules_engine = RulesEngine() def pre_ack_process(self, message): - ''' + """ TriggerInstance from message is create prior to acknowledging the message. This gets us a way to not acknowledge messages. - ''' - trigger = message['trigger'] - payload = message['payload'] + """ + trigger = message["trigger"] + payload = message["payload"] # Accomodate for not being able to create a TrigegrInstance if a TriggerDB # is not found. @@ -54,16 +54,19 @@ def pre_ack_process(self, message): trigger, payload or {}, date_utils.get_datetime_utc_now(), - raise_on_no_trigger=True) + raise_on_no_trigger=True, + ) return self._compose_pre_ack_process_response(trigger_instance, message) def process(self, pre_ack_response): - trigger_instance, message = self._decompose_pre_ack_process_response(pre_ack_response) + trigger_instance, message = self._decompose_pre_ack_process_response( + pre_ack_response + ) if not trigger_instance: - raise ValueError('No trigger_instance provided for processing.') + raise ValueError("No trigger_instance provided for processing.") - get_driver().inc_counter('trigger.%s.processed' % (trigger_instance.trigger)) + get_driver().inc_counter("trigger.%s.processed" % (trigger_instance.trigger)) try: # Use trace_context from the message and if not found create a new context @@ -71,34 +74,39 @@ def process(self, pre_ack_response): trace_context = message.get(TRACE_CONTEXT, None) if not trace_context: trace_context = { - TRACE_ID: 'trigger_instance-%s' % str(trigger_instance.id) + TRACE_ID: "trigger_instance-%s" % str(trigger_instance.id) } # add a trace or update an existing trace with trigger_instance trace_service.add_or_update_given_trace_context( trace_context=trace_context, trigger_instances=[ - trace_service.get_trace_component_for_trigger_instance(trigger_instance) - ] + trace_service.get_trace_component_for_trigger_instance( + trigger_instance + ) + ], ) container_utils.update_trigger_instance_status( - trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING) + trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING + ) - with CounterWithTimer(key='rule.processed'): - with Timer(key='trigger.%s.processed' % (trigger_instance.trigger)): + with CounterWithTimer(key="rule.processed"): + with Timer(key="trigger.%s.processed" % (trigger_instance.trigger)): self.rules_engine.handle_trigger_instance(trigger_instance) container_utils.update_trigger_instance_status( - trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSED) + trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSED + ) except: # TODO : Capture the reason for failure. container_utils.update_trigger_instance_status( - trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING_FAILED) + trigger_instance, trigger_constants.TRIGGER_INSTANCE_PROCESSING_FAILED + ) # This could be a large message but at least in case of an exception # we get to see more context. # Beyond this point code cannot really handle the exception anyway so # eating up the exception. - LOG.exception('Failed to handle trigger_instance %s.', trigger_instance) + LOG.exception("Failed to handle trigger_instance %s.", trigger_instance) return @staticmethod @@ -106,14 +114,14 @@ def _compose_pre_ack_process_response(trigger_instance, message): """ Codify response of the pre_ack_process method. """ - return {'trigger_instance': trigger_instance, 'message': message} + return {"trigger_instance": trigger_instance, "message": message} @staticmethod def _decompose_pre_ack_process_response(response): """ Break-down response of pre_ack_process into constituents for simpler consumption. """ - return response.get('trigger_instance', None), response.get('message', None) + return response.get("trigger_instance", None), response.get("message", None) def get_worker(): diff --git a/st2reactor/st2reactor/sensor/base.py b/st2reactor/st2reactor/sensor/base.py index f7fce2460b5..a8309ba292b 100644 --- a/st2reactor/st2reactor/sensor/base.py +++ b/st2reactor/st2reactor/sensor/base.py @@ -21,10 +21,7 @@ from st2common.util import concurrency -__all__ = [ - 'Sensor', - 'PollingSensor' -] +__all__ = ["Sensor", "PollingSensor"] @six.add_metaclass(abc.ABCMeta) @@ -107,7 +104,9 @@ class PollingSensor(BaseSensor): """ def __init__(self, sensor_service, config=None, poll_interval=5): - super(PollingSensor, self).__init__(sensor_service=sensor_service, config=config) + super(PollingSensor, self).__init__( + sensor_service=sensor_service, config=config + ) self._poll_interval = poll_interval @abc.abstractmethod diff --git a/st2reactor/st2reactor/sensor/config.py b/st2reactor/st2reactor/sensor/config.py index 981ddd9b8f8..8126bdbc9f6 100644 --- a/st2reactor/st2reactor/sensor/config.py +++ b/st2reactor/st2reactor/sensor/config.py @@ -26,8 +26,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(ignore_errors=False): @@ -46,48 +49,62 @@ def _register_common_opts(ignore_errors=False): def _register_sensor_container_opts(ignore_errors=False): logging_opts = [ cfg.StrOpt( - 'logging', default='/etc/st2/logging.sensorcontainer.conf', - help='location of the logging.conf file') + "logging", + default="/etc/st2/logging.sensorcontainer.conf", + help="location of the logging.conf file", + ) ] - st2cfg.do_register_opts(logging_opts, group='sensorcontainer', ignore_errors=ignore_errors) + st2cfg.do_register_opts( + logging_opts, group="sensorcontainer", ignore_errors=ignore_errors + ) # Partitioning options partition_opts = [ cfg.StrOpt( - 'sensor_node_name', default='sensornode1', - help='name of the sensor node.'), + "sensor_node_name", default="sensornode1", help="name of the sensor node." + ), cfg.Opt( - 'partition_provider', + "partition_provider", type=types.Dict(value_type=types.String()), - default={'name': DEFAULT_PARTITION_LOADER}, - help='Provider of sensor node partition config.') + default={"name": DEFAULT_PARTITION_LOADER}, + help="Provider of sensor node partition config.", + ), ] - st2cfg.do_register_opts(partition_opts, group='sensorcontainer', ignore_errors=ignore_errors) + st2cfg.do_register_opts( + partition_opts, group="sensorcontainer", ignore_errors=ignore_errors + ) # Other options other_opts = [ cfg.BoolOpt( - 'single_sensor_mode', default=False, - help='Run in a single sensor mode where parent process exits when a sensor crashes / ' - 'dies. This is useful in environments where partitioning, sensor process life ' - 'cycle and failover is handled by a 3rd party service such as kubernetes.') + "single_sensor_mode", + default=False, + help="Run in a single sensor mode where parent process exits when a sensor crashes / " + "dies. This is useful in environments where partitioning, sensor process life " + "cycle and failover is handled by a 3rd party service such as kubernetes.", + ) ] - st2cfg.do_register_opts(other_opts, group='sensorcontainer', ignore_errors=ignore_errors) + st2cfg.do_register_opts( + other_opts, group="sensorcontainer", ignore_errors=ignore_errors + ) # CLI options cli_opts = [ cfg.StrOpt( - 'sensor-ref', - help='Only run sensor with the provided reference. Value is of the form ' - '. (e.g. linux.FileWatchSensor).'), + "sensor-ref", + help="Only run sensor with the provided reference. Value is of the form " + ". (e.g. linux.FileWatchSensor).", + ), cfg.BoolOpt( - 'single-sensor-mode', default=False, - help='Run in a single sensor mode where parent process exits when a sensor crashes / ' - 'dies. This is useful in environments where partitioning, sensor process life ' - 'cycle and failover is handled by a 3rd party service such as kubernetes.') + "single-sensor-mode", + default=False, + help="Run in a single sensor mode where parent process exits when a sensor crashes / " + "dies. This is useful in environments where partitioning, sensor process life " + "cycle and failover is handled by a 3rd party service such as kubernetes.", + ), ] st2cfg.do_register_cli_opts(cli_opts, ignore_errors=ignore_errors) diff --git a/st2reactor/st2reactor/timer/base.py b/st2reactor/st2reactor/timer/base.py index ed99d90e776..723d362066a 100644 --- a/st2reactor/st2reactor/timer/base.py +++ b/st2reactor/st2reactor/timer/base.py @@ -41,17 +41,20 @@ class St2Timer(object): """ A timer interface that uses APScheduler 3.0. """ + def __init__(self, local_timezone=None): self._timezone = local_timezone self._scheduler = BlockingScheduler(timezone=self._timezone) self._jobs = {} self._trigger_types = list(TIMER_TRIGGER_TYPES.keys()) - self._trigger_watcher = TriggerWatcher(create_handler=self._handle_create_trigger, - update_handler=self._handle_update_trigger, - delete_handler=self._handle_delete_trigger, - trigger_types=self._trigger_types, - queue_suffix=self.__class__.__name__, - exclusive=True) + self._trigger_watcher = TriggerWatcher( + create_handler=self._handle_create_trigger, + update_handler=self._handle_update_trigger, + delete_handler=self._handle_delete_trigger, + trigger_types=self._trigger_types, + queue_suffix=self.__class__.__name__, + exclusive=True, + ) self._trigger_dispatcher = TriggerDispatcher(LOG) def start(self): @@ -70,89 +73,109 @@ def update_trigger(self, trigger): self.add_trigger(trigger) def remove_trigger(self, trigger): - trigger_id = trigger['id'] + trigger_id = trigger["id"] try: job_id = self._jobs[trigger_id] except KeyError: - LOG.info('Job not found: %s', trigger_id) + LOG.info("Job not found: %s", trigger_id) return self._scheduler.remove_job(job_id) del self._jobs[trigger_id] def _add_job_to_scheduler(self, trigger): - trigger_type_ref = trigger['type'] + trigger_type_ref = trigger["type"] trigger_type = TIMER_TRIGGER_TYPES[trigger_type_ref] try: - util_schema.validate(instance=trigger['parameters'], - schema=trigger_type['parameters_schema'], - cls=util_schema.CustomValidator, - use_default=True, - allow_default_none=True) + util_schema.validate( + instance=trigger["parameters"], + schema=trigger_type["parameters_schema"], + cls=util_schema.CustomValidator, + use_default=True, + allow_default_none=True, + ) except jsonschema.ValidationError as e: - LOG.error('Exception scheduling timer: %s, %s', - trigger['parameters'], e, exc_info=True) + LOG.error( + "Exception scheduling timer: %s, %s", + trigger["parameters"], + e, + exc_info=True, + ) raise # Or should we just return? - time_spec = trigger['parameters'] - time_zone = aps_utils.astimezone(trigger['parameters'].get('timezone')) + time_spec = trigger["parameters"] + time_zone = aps_utils.astimezone(trigger["parameters"].get("timezone")) time_type = None - if trigger_type['name'] == 'st2.IntervalTimer': - unit = time_spec.get('unit', None) - value = time_spec.get('delta', None) - time_type = IntervalTrigger(**{unit: value, 'timezone': time_zone}) - elif trigger_type['name'] == 'st2.DateTimer': + if trigger_type["name"] == "st2.IntervalTimer": + unit = time_spec.get("unit", None) + value = time_spec.get("delta", None) + time_type = IntervalTrigger(**{unit: value, "timezone": time_zone}) + elif trigger_type["name"] == "st2.DateTimer": # Raises an exception if date string isn't a valid one. - dat = date_parser.parse(time_spec.get('date', None)) + dat = date_parser.parse(time_spec.get("date", None)) time_type = DateTrigger(dat, timezone=time_zone) - elif trigger_type['name'] == 'st2.CronTimer': + elif trigger_type["name"] == "st2.CronTimer": cron = time_spec.copy() - cron['timezone'] = time_zone + cron["timezone"] = time_zone time_type = CronTrigger(**cron) utc_now = date_utils.get_datetime_utc_now() - if hasattr(time_type, 'run_date') and utc_now > time_type.run_date: - LOG.warning('Not scheduling expired timer: %s : %s', - trigger['parameters'], time_type.run_date) + if hasattr(time_type, "run_date") and utc_now > time_type.run_date: + LOG.warning( + "Not scheduling expired timer: %s : %s", + trigger["parameters"], + time_type.run_date, + ) else: self._add_job(trigger, time_type) return time_type def _add_job(self, trigger, time_type, replace=True): try: - job = self._scheduler.add_job(self._emit_trigger_instance, - trigger=time_type, - args=[trigger], - replace_existing=replace) - LOG.info('Job %s scheduled.', job.id) - self._jobs[trigger['id']] = job.id + job = self._scheduler.add_job( + self._emit_trigger_instance, + trigger=time_type, + args=[trigger], + replace_existing=replace, + ) + LOG.info("Job %s scheduled.", job.id) + self._jobs[trigger["id"]] = job.id except Exception as e: - LOG.error('Exception scheduling timer: %s, %s', - trigger['parameters'], e, exc_info=True) + LOG.error( + "Exception scheduling timer: %s, %s", + trigger["parameters"], + e, + exc_info=True, + ) def _emit_trigger_instance(self, trigger): utc_now = date_utils.get_datetime_utc_now() # debug logging is reasonable for this one. A high resolution timer will end up # trashing standard logs. - LOG.debug('Timer fired at: %s. Trigger: %s', str(utc_now), trigger) + LOG.debug("Timer fired at: %s. Trigger: %s", str(utc_now), trigger) payload = { - 'executed_at': str(utc_now), - 'schedule': trigger['parameters'].get('time') + "executed_at": str(utc_now), + "schedule": trigger["parameters"].get("time"), } - trace_context = TraceContext(trace_tag='%s-%s' % (self._get_trigger_type_name(trigger), - trigger.get('name', uuid.uuid4().hex))) + trace_context = TraceContext( + trace_tag="%s-%s" + % ( + self._get_trigger_type_name(trigger), + trigger.get("name", uuid.uuid4().hex), + ) + ) self._trigger_dispatcher.dispatch(trigger, payload, trace_context=trace_context) def _get_trigger_type_name(self, trigger): - trigger_type_ref = trigger['type'] + trigger_type_ref = trigger["type"] trigger_type = TIMER_TRIGGER_TYPES[trigger_type_ref] - return trigger_type['name'] + return trigger_type["name"] def _register_timer_trigger_types(self): return trigger_services.add_trigger_models(list(TIMER_TRIGGER_TYPES.values())) diff --git a/st2reactor/st2reactor/timer/config.py b/st2reactor/st2reactor/timer/config.py index db180f85dd0..bbc1020cb94 100644 --- a/st2reactor/st2reactor/timer/config.py +++ b/st2reactor/st2reactor/timer/config.py @@ -25,8 +25,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): diff --git a/st2reactor/tests/integration/test_garbage_collector.py b/st2reactor/tests/integration/test_garbage_collector.py index 1de1e9c529c..5b0f890ac32 100644 --- a/st2reactor/tests/integration/test_garbage_collector.py +++ b/st2reactor/tests/integration/test_garbage_collector.py @@ -37,33 +37,28 @@ from st2tests.fixturesloader import FixturesLoader from six.moves import range -__all__ = [ - 'GarbageCollectorServiceTestCase' -] +__all__ = ["GarbageCollectorServiceTestCase"] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf') +ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf") ST2_CONFIG_PATH = os.path.abspath(ST2_CONFIG_PATH) -INQUIRY_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests2.conf') +INQUIRY_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests2.conf") INQUIRY_CONFIG_PATH = os.path.abspath(INQUIRY_CONFIG_PATH) PYTHON_BINARY = sys.executable -BINARY = os.path.join(BASE_DIR, '../../../st2reactor/bin/st2garbagecollector') +BINARY = os.path.join(BASE_DIR, "../../../st2reactor/bin/st2garbagecollector") BINARY = os.path.abspath(BINARY) -CMD = [PYTHON_BINARY, BINARY, '--config-file', ST2_CONFIG_PATH] -CMD_INQUIRY = [PYTHON_BINARY, BINARY, '--config-file', INQUIRY_CONFIG_PATH] +CMD = [PYTHON_BINARY, BINARY, "--config-file", ST2_CONFIG_PATH] +CMD_INQUIRY = [PYTHON_BINARY, BINARY, "--config-file", INQUIRY_CONFIG_PATH] -TEST_FIXTURES = { - 'runners': ['inquirer.yaml'], - 'actions': ['ask.yaml'] -} +TEST_FIXTURES = {"runners": ["inquirer.yaml"], "actions": ["ask.yaml"]} -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" class GarbageCollectorServiceTestCase(IntegrationTestCase, CleanDbTestCase): @@ -75,7 +70,8 @@ def setUp(self): super(GarbageCollectorServiceTestCase, self).setUp() self.models = FixturesLoader().save_fixtures_to_db( - fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES) + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_FIXTURES + ) def test_garbage_collection(self): now = date_utils.get_datetime_utc_now() @@ -85,102 +81,125 @@ def test_garbage_collection(self): # config old_executions_count = 15 ttl_days = 30 # > 20 - timestamp = (now - datetime.timedelta(days=ttl_days)) + timestamp = now - datetime.timedelta(days=ttl_days) for index in range(0, old_executions_count): - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) ActionExecution.add_or_update(action_execution_db) - stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout') + stdout_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout", + ) ActionExecutionOutput.add_or_update(stdout_db) - stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stderr', - data='stderr') + stderr_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stderr", + data="stderr", + ) ActionExecutionOutput.add_or_update(stderr_db) # Insert come mock ActionExecutionDB objects with start_timestamp > TTL defined in the # config new_executions_count = 5 ttl_days = 2 # < 20 - timestamp = (now - datetime.timedelta(days=ttl_days)) + timestamp = now - datetime.timedelta(days=ttl_days) for index in range(0, new_executions_count): - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) ActionExecution.add_or_update(action_execution_db) - stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout') + stdout_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout", + ) ActionExecutionOutput.add_or_update(stdout_db) - stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stderr', - data='stderr') + stderr_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stderr", + data="stderr", + ) ActionExecutionOutput.add_or_update(stderr_db) # Insert some mock output objects where start_timestamp > action_executions_output_ttl new_output_count = 5 ttl_days = 15 # > 10 and < 20 - timestamp = (now - datetime.timedelta(days=ttl_days)) + timestamp = now - datetime.timedelta(days=ttl_days) for index in range(0, new_output_count): - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) ActionExecution.add_or_update(action_execution_db) - stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout') + stdout_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout", + ) ActionExecutionOutput.add_or_update(stdout_db) - stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stderr', - data='stderr') + stderr_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stderr", + data="stderr", + ) ActionExecutionOutput.add_or_update(stderr_db) execs = ActionExecution.get_all() - self.assertEqual(len(execs), - (old_executions_count + new_executions_count + new_output_count)) - - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') - self.assertEqual(len(stdout_dbs), - (old_executions_count + new_executions_count + new_output_count)) - - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') - self.assertEqual(len(stderr_dbs), - (old_executions_count + new_executions_count + new_output_count)) + self.assertEqual( + len(execs), (old_executions_count + new_executions_count + new_output_count) + ) + + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") + self.assertEqual( + len(stdout_dbs), + (old_executions_count + new_executions_count + new_output_count), + ) + + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") + self.assertEqual( + len(stderr_dbs), + (old_executions_count + new_executions_count + new_output_count), + ) # Start garbage collector process = self._start_garbage_collector() @@ -196,10 +215,10 @@ def test_garbage_collection(self): # Collection for output objects older than 10 days is also enabled, so those objects # should be deleted as well - stdout_dbs = ActionExecutionOutput.query(output_type='stdout') + stdout_dbs = ActionExecutionOutput.query(output_type="stdout") self.assertEqual(len(stdout_dbs), (new_executions_count)) - stderr_dbs = ActionExecutionOutput.query(output_type='stderr') + stderr_dbs = ActionExecutionOutput.query(output_type="stderr") self.assertEqual(len(stderr_dbs), (new_executions_count)) def test_inquiry_garbage_collection(self): @@ -207,28 +226,28 @@ def test_inquiry_garbage_collection(self): # Insert some mock Inquiries with start_timestamp > TTL old_inquiry_count = 15 - timestamp = (now - datetime.timedelta(minutes=3)) + timestamp = now - datetime.timedelta(minutes=3) for index in range(0, old_inquiry_count): self._create_inquiry(ttl=2, timestamp=timestamp) # Insert some mock Inquiries with TTL set to a "disabled" value disabled_inquiry_count = 3 - timestamp = (now - datetime.timedelta(minutes=3)) + timestamp = now - datetime.timedelta(minutes=3) for index in range(0, disabled_inquiry_count): self._create_inquiry(ttl=0, timestamp=timestamp) # Insert some mock Inquiries with start_timestamp < TTL new_inquiry_count = 5 - timestamp = (now - datetime.timedelta(minutes=3)) + timestamp = now - datetime.timedelta(minutes=3) for index in range(0, new_inquiry_count): self._create_inquiry(ttl=15, timestamp=timestamp) - filters = { - 'status': action_constants.LIVEACTION_STATUS_PENDING - } + filters = {"status": action_constants.LIVEACTION_STATUS_PENDING} inquiries = list(ActionExecution.query(**filters)) - self.assertEqual(len(inquiries), - (old_inquiry_count + new_inquiry_count + disabled_inquiry_count)) + self.assertEqual( + len(inquiries), + (old_inquiry_count + new_inquiry_count + disabled_inquiry_count), + ) # Start garbage collector process = self._start_garbage_collector() @@ -243,18 +262,25 @@ def test_inquiry_garbage_collection(self): self.assertEqual(len(inquiries), new_inquiry_count + disabled_inquiry_count) def _create_inquiry(self, ttl, timestamp): - action_db = self.models['actions']['ask.yaml'] + action_db = self.models["actions"]["ask.yaml"] liveaction_db = LiveActionDB() liveaction_db.status = action_constants.LIVEACTION_STATUS_PENDING liveaction_db.start_timestamp = timestamp - liveaction_db.action = ResourceReference(name=action_db.name, pack=action_db.pack).ref - liveaction_db.result = {'ttl': ttl} + liveaction_db.action = ResourceReference( + name=action_db.name, pack=action_db.pack + ).ref + liveaction_db.result = {"ttl": ttl} liveaction_db = LiveAction.add_or_update(liveaction_db) executions.create_execution_object(liveaction_db) def _start_garbage_collector(self): subprocess = concurrency.get_subprocess_module() - process = subprocess.Popen(CMD_INQUIRY, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - shell=False, preexec_fn=os.setsid) + process = subprocess.Popen( + CMD_INQUIRY, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + preexec_fn=os.setsid, + ) self.add_process(process=process) return process diff --git a/st2reactor/tests/integration/test_rules_engine.py b/st2reactor/tests/integration/test_rules_engine.py index 669a88797f8..05ebce5e9ea 100644 --- a/st2reactor/tests/integration/test_rules_engine.py +++ b/st2reactor/tests/integration/test_rules_engine.py @@ -26,18 +26,16 @@ from st2tests.base import IntegrationTestCase from st2tests.base import CleanDbTestCase -__all__ = [ - 'TimersEngineServiceEnableDisableTestCase' -] +__all__ = ["TimersEngineServiceEnableDisableTestCase"] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf') +ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf") ST2_CONFIG_PATH = os.path.abspath(ST2_CONFIG_PATH) PYTHON_BINARY = sys.executable -BINARY = os.path.join(BASE_DIR, '../../../st2reactor/bin/st2timersengine') +BINARY = os.path.join(BASE_DIR, "../../../st2reactor/bin/st2timersengine") BINARY = os.path.abspath(BINARY) -CMD = [PYTHON_BINARY, BINARY, '--config-file'] +CMD = [PYTHON_BINARY, BINARY, "--config-file"] class TimersEngineServiceEnableDisableTestCase(IntegrationTestCase, CleanDbTestCase): @@ -46,7 +44,7 @@ def setUp(self): config_text = open(ST2_CONFIG_PATH).read() self.cfg_fd, self.cfg_path = tempfile.mkstemp() - with open(self.cfg_path, 'w') as f: + with open(self.cfg_path, "w") as f: f.write(config_text) self.cmd = [] self.cmd.extend(CMD) @@ -65,7 +63,7 @@ def test_timer_enable_implicit(self): process = self._start_times_engine(cmd=self.cmd) lines = 0 while lines < 100: - line = process.stdout.readline().decode('utf-8') + line = process.stdout.readline().decode("utf-8") lines += 1 sys.stdout.write(line) @@ -78,12 +76,15 @@ def test_timer_enable_implicit(self): self.remove_process(process=process) if not seen_line: - raise AssertionError('Didn\'t see "%s" log line in timer output' % - (TIMER_ENABLED_LOG_LINE)) + raise AssertionError( + 'Didn\'t see "%s" log line in timer output' % (TIMER_ENABLED_LOG_LINE) + ) def test_timer_enable_explicit(self): - self._append_to_cfg_file(cfg_path=self.cfg_path, - content='\n[timersengine]\nenable = True\n[timer]\nenable = True') + self._append_to_cfg_file( + cfg_path=self.cfg_path, + content="\n[timersengine]\nenable = True\n[timer]\nenable = True", + ) process = None seen_line = False @@ -91,7 +92,7 @@ def test_timer_enable_explicit(self): process = self._start_times_engine(cmd=self.cmd) lines = 0 while lines < 100: - line = process.stdout.readline().decode('utf-8') + line = process.stdout.readline().decode("utf-8") lines += 1 sys.stdout.write(line) @@ -104,12 +105,15 @@ def test_timer_enable_explicit(self): self.remove_process(process=process) if not seen_line: - raise AssertionError('Didn\'t see "%s" log line in timer output' % - (TIMER_ENABLED_LOG_LINE)) + raise AssertionError( + 'Didn\'t see "%s" log line in timer output' % (TIMER_ENABLED_LOG_LINE) + ) def test_timer_disable_explicit(self): - self._append_to_cfg_file(cfg_path=self.cfg_path, - content='\n[timersengine]\nenable = False\n[timer]\nenable = False') + self._append_to_cfg_file( + cfg_path=self.cfg_path, + content="\n[timersengine]\nenable = False\n[timer]\nenable = False", + ) process = None seen_line = False @@ -117,7 +121,7 @@ def test_timer_disable_explicit(self): process = self._start_times_engine(cmd=self.cmd) lines = 0 while lines < 100: - line = process.stdout.readline().decode('utf-8') + line = process.stdout.readline().decode("utf-8") lines += 1 sys.stdout.write(line) @@ -130,18 +134,24 @@ def test_timer_disable_explicit(self): self.remove_process(process=process) if not seen_line: - raise AssertionError('Didn\'t see "%s" log line in timer output' % - (TIMER_DISABLED_LOG_LINE)) + raise AssertionError( + 'Didn\'t see "%s" log line in timer output' % (TIMER_DISABLED_LOG_LINE) + ) def _start_times_engine(self, cmd): subprocess = concurrency.get_subprocess_module() - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - shell=False, preexec_fn=os.setsid) + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + preexec_fn=os.setsid, + ) self.add_process(process=process) return process def _append_to_cfg_file(self, cfg_path, content): - with open(cfg_path, 'a') as f: + with open(cfg_path, "a") as f: f.write(content) def _remove_tempfile(self, fd, path): diff --git a/st2reactor/tests/integration/test_sensor_container.py b/st2reactor/tests/integration/test_sensor_container.py index 7971e361063..41eb3307bc9 100644 --- a/st2reactor/tests/integration/test_sensor_container.py +++ b/st2reactor/tests/integration/test_sensor_container.py @@ -30,28 +30,26 @@ from st2common.bootstrap.sensorsregistrar import register_sensors from st2tests.base import IntegrationTestCase -__all__ = [ - 'SensorContainerTestCase' -] +__all__ = ["SensorContainerTestCase"] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -ST2_CONFIG_PATH = os.path.join(BASE_DIR, '../../../conf/st2.tests.conf') +ST2_CONFIG_PATH = os.path.join(BASE_DIR, "../../../conf/st2.tests.conf") ST2_CONFIG_PATH = os.path.abspath(ST2_CONFIG_PATH) PYTHON_BINARY = sys.executable -BINARY = os.path.join(BASE_DIR, '../../../st2reactor/bin/st2sensorcontainer') +BINARY = os.path.join(BASE_DIR, "../../../st2reactor/bin/st2sensorcontainer") BINARY = os.path.abspath(BINARY) -PACKS_BASE_PATH = os.path.abspath(os.path.join(BASE_DIR, '../../../contrib')) +PACKS_BASE_PATH = os.path.abspath(os.path.join(BASE_DIR, "../../../contrib")) DEFAULT_CMD = [ PYTHON_BINARY, BINARY, - '--config-file', + "--config-file", ST2_CONFIG_PATH, - '--sensor-ref=examples.SamplePollingSensor' + "--sensor-ref=examples.SamplePollingSensor", ] @@ -69,11 +67,24 @@ def setUpClass(cls): st2tests.config.parse_args() - username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None - password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None + username = ( + cfg.CONF.database.username + if hasattr(cfg.CONF.database, "username") + else None + ) + password = ( + cfg.CONF.database.password + if hasattr(cfg.CONF.database, "password") + else None + ) cls.db_connection = db_setup( - cfg.CONF.database.db_name, cfg.CONF.database.host, cfg.CONF.database.port, - username=username, password=password, ensure_indexes=False) + cfg.CONF.database.db_name, + cfg.CONF.database.host, + cfg.CONF.database.port, + username=username, + password=password, + ensure_indexes=False, + ) # NOTE: We need to perform this patching because test fixtures are located outside of the # packs base paths directory. This will never happen outside the context of test fixtures. @@ -83,11 +94,17 @@ def setUpClass(cls): register_sensors(packs_base_paths=[PACKS_BASE_PATH], use_pack_cache=False) # Create virtualenv for examples pack - virtualenv_path = '/tmp/virtualenvs/examples' + virtualenv_path = "/tmp/virtualenvs/examples" - run_command(cmd=['rm', '-rf', virtualenv_path]) + run_command(cmd=["rm", "-rf", virtualenv_path]) - cmd = ['virtualenv', '--system-site-packages', '--python', PYTHON_BINARY, virtualenv_path] + cmd = [ + "virtualenv", + "--system-site-packages", + "--python", + PYTHON_BINARY, + virtualenv_path, + ] run_command(cmd=cmd) def test_child_processes_are_killed_on_sigint(self): @@ -169,7 +186,13 @@ def test_child_processes_are_killed_on_sigkill(self): def test_single_sensor_mode(self): # 1. --sensor-ref not provided - cmd = [PYTHON_BINARY, BINARY, '--config-file', ST2_CONFIG_PATH, '--single-sensor-mode'] + cmd = [ + PYTHON_BINARY, + BINARY, + "--config-file", + ST2_CONFIG_PATH, + "--single-sensor-mode", + ] process = self._start_sensor_container(cmd=cmd) pp = psutil.Process(process.pid) @@ -178,14 +201,24 @@ def test_single_sensor_mode(self): concurrency.sleep(4) stdout = process.stdout.read() - self.assertTrue((b'--sensor-ref argument must be provided when running in single sensor ' - b'mode') in stdout) + self.assertTrue( + ( + b"--sensor-ref argument must be provided when running in single sensor " + b"mode" + ) + in stdout + ) self.assertProcessExited(proc=pp) self.remove_process(process=process) # 2. sensor ref provided - cmd = [BINARY, '--config-file', ST2_CONFIG_PATH, '--single-sensor-mode', - '--sensor-ref=examples.SampleSensorExit'] + cmd = [ + BINARY, + "--config-file", + ST2_CONFIG_PATH, + "--single-sensor-mode", + "--sensor-ref=examples.SampleSensorExit", + ] process = self._start_sensor_container(cmd=cmd) pp = psutil.Process(process.pid) @@ -196,9 +229,11 @@ def test_single_sensor_mode(self): # Container should exit and not respawn a sensor in single sensor mode stdout = process.stdout.read() - self.assertTrue(b'Process for sensor examples.SampleSensorExit has exited with code 110') - self.assertTrue(b'Not respawning a sensor since running in single sensor mode') - self.assertTrue(b'Process container quit with exit_code 110.') + self.assertTrue( + b"Process for sensor examples.SampleSensorExit has exited with code 110" + ) + self.assertTrue(b"Not respawning a sensor since running in single sensor mode") + self.assertTrue(b"Process container quit with exit_code 110.") concurrency.sleep(2) self.assertProcessExited(proc=pp) @@ -207,7 +242,12 @@ def test_single_sensor_mode(self): def _start_sensor_container(self, cmd=DEFAULT_CMD): subprocess = concurrency.get_subprocess_module() - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - shell=False, preexec_fn=os.setsid) + process = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=False, + preexec_fn=os.setsid, + ) self.add_process(process=process) return process diff --git a/st2reactor/tests/integration/test_sensor_watcher.py b/st2reactor/tests/integration/test_sensor_watcher.py index 9727da92a68..6caee09c7fd 100644 --- a/st2reactor/tests/integration/test_sensor_watcher.py +++ b/st2reactor/tests/integration/test_sensor_watcher.py @@ -22,19 +22,15 @@ from st2common.services.sensor_watcher import SensorWatcher from st2tests.base import IntegrationTestCase -__all__ = [ - 'SensorWatcherTestCase' -] +__all__ = ["SensorWatcherTestCase"] class SensorWatcherTestCase(IntegrationTestCase): - @classmethod def setUpClass(cls): super(SensorWatcherTestCase, cls).setUpClass() def test_sensor_watch_queue_gets_deleted_on_stop(self): - def create_handler(sensor_db): pass @@ -44,25 +40,32 @@ def update_handler(sensor_db): def delete_handler(sensor_db): pass - sensor_watcher = SensorWatcher(create_handler, update_handler, delete_handler, - queue_suffix='covfefe') + sensor_watcher = SensorWatcher( + create_handler, update_handler, delete_handler, queue_suffix="covfefe" + ) sensor_watcher.start() - sw_queues = self._get_sensor_watcher_amqp_queues(queue_name='st2.sensor.watch.covfefe') + sw_queues = self._get_sensor_watcher_amqp_queues( + queue_name="st2.sensor.watch.covfefe" + ) start = monotonic() done = False while not done: concurrency.sleep(0.01) - sw_queues = self._get_sensor_watcher_amqp_queues(queue_name='st2.sensor.watch.covfefe') + sw_queues = self._get_sensor_watcher_amqp_queues( + queue_name="st2.sensor.watch.covfefe" + ) done = len(sw_queues) > 0 or ((monotonic() - start) < 5) sensor_watcher.stop() - sw_queues = self._get_sensor_watcher_amqp_queues(queue_name='st2.sensor.watch.covfefe') + sw_queues = self._get_sensor_watcher_amqp_queues( + queue_name="st2.sensor.watch.covfefe" + ) self.assertTrue(len(sw_queues) == 0) def _list_amqp_queues(self): - rabbit_client = Client('localhost:15672', 'guest', 'guest') - queues = [q['name'] for q in rabbit_client.get_queues()] + rabbit_client = Client("localhost:15672", "guest", "guest") + queues = [q["name"] for q in rabbit_client.get_queues()] return queues def _get_sensor_watcher_amqp_queues(self, queue_name): diff --git a/st2reactor/tests/unit/test_container_utils.py b/st2reactor/tests/unit/test_container_utils.py index d8c14bf1d5f..24d297ba7d2 100644 --- a/st2reactor/tests/unit/test_container_utils.py +++ b/st2reactor/tests/unit/test_container_utils.py @@ -23,20 +23,25 @@ from st2tests.base import CleanDbTestCase -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class ContainerUtilsTest(CleanDbTestCase): def setUp(self): super(ContainerUtilsTest, self).setUp() # Insert mock TriggerDB - trigger_db = TriggerDB(name='name1', pack='pack1', type='type1', - parameters={'a': 1, 'b': '2', 'c': 'foo'}) + trigger_db = TriggerDB( + name="name1", + pack="pack1", + type="type1", + parameters={"a": 1, "b": "2", "c": "foo"}, + ) self.trigger_db = Trigger.add_or_update(trigger_db) def test_create_trigger_instance_invalid_trigger(self): - trigger_instance = 'dummy_pack.footrigger' - instance = create_trigger_instance(trigger=trigger_instance, payload={}, - occurrence_time=None) + trigger_instance = "dummy_pack.footrigger" + instance = create_trigger_instance( + trigger=trigger_instance, payload={}, occurrence_time=None + ) self.assertIsNone(instance) def test_create_trigger_instance_success(self): @@ -46,34 +51,40 @@ def test_create_trigger_instance_success(self): occurrence_time = None # TriggerDB look up by id - trigger = {'id': self.trigger_db.id} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) - self.assertEqual(trigger_instance_db.trigger, 'pack1.name1') + trigger = {"id": self.trigger_db.id} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) + self.assertEqual(trigger_instance_db.trigger, "pack1.name1") # Object doesn't exist (invalid id) - trigger = {'id': '5776aa2b0640fd2991b15987'} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) + trigger = {"id": "5776aa2b0640fd2991b15987"} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) self.assertEqual(trigger_instance_db, None) # TriggerDB look up by uid - trigger = {'uid': self.trigger_db.uid} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) - self.assertEqual(trigger_instance_db.trigger, 'pack1.name1') + trigger = {"uid": self.trigger_db.uid} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) + self.assertEqual(trigger_instance_db.trigger, "pack1.name1") - trigger = {'uid': 'invaliduid'} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) + trigger = {"uid": "invaliduid"} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) self.assertEqual(trigger_instance_db, None) # TriggerDB look up by type and parameters (last resort) - trigger = {'type': 'pack1.name1', 'parameters': self.trigger_db.parameters} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) + trigger = {"type": "pack1.name1", "parameters": self.trigger_db.parameters} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) - trigger = {'type': 'pack1.name1', 'parameters': {}} - trigger_instance_db = create_trigger_instance(trigger=trigger, payload=payload, - occurrence_time=occurrence_time) + trigger = {"type": "pack1.name1", "parameters": {}} + trigger_instance_db = create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) self.assertEqual(trigger_instance_db, None) diff --git a/st2reactor/tests/unit/test_enforce.py b/st2reactor/tests/unit/test_enforce.py index 174216dbb44..4b282305bd2 100644 --- a/st2reactor/tests/unit/test_enforce.py +++ b/st2reactor/tests/unit/test_enforce.py @@ -38,62 +38,68 @@ from st2tests import DbTestCase from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'RuleEnforcerTestCase', - 'RuleEnforcerDataTransformationTestCase' -] +__all__ = ["RuleEnforcerTestCase", "RuleEnforcerDataTransformationTestCase"] -PACK = 'generic' +PACK = "generic" FIXTURES_1 = { - 'runners': ['testrunner1.yaml', 'testrunner2.yaml'], - 'actions': ['action1.yaml', 'a2.yaml', 'a2_default_value.yaml'], - 'triggertypes': ['triggertype1.yaml'], - 'triggers': ['trigger1.yaml'], - 'traces': ['trace_for_test_enforce.yaml', 'trace_for_test_enforce_2.yaml', - 'trace_for_test_enforce_3.yaml'] + "runners": ["testrunner1.yaml", "testrunner2.yaml"], + "actions": ["action1.yaml", "a2.yaml", "a2_default_value.yaml"], + "triggertypes": ["triggertype1.yaml"], + "triggers": ["trigger1.yaml"], + "traces": [ + "trace_for_test_enforce.yaml", + "trace_for_test_enforce_2.yaml", + "trace_for_test_enforce_3.yaml", + ], } FIXTURES_2 = { - 'rules': [ - 'rule1.yaml', - 'rule2.yaml', - 'rule_use_none_filter.yaml', - 'rule_none_no_use_none_filter.yaml', - 'rule_action_default_value.yaml', - 'rule_action_default_value_overridden.yaml', - 'rule_action_default_value_render_fail.yaml' + "rules": [ + "rule1.yaml", + "rule2.yaml", + "rule_use_none_filter.yaml", + "rule_none_no_use_none_filter.yaml", + "rule_action_default_value.yaml", + "rule_action_default_value_overridden.yaml", + "rule_action_default_value_render_fail.yaml", ] } MOCK_TRIGGER_INSTANCE = TriggerInstanceDB() -MOCK_TRIGGER_INSTANCE.id = 'triggerinstance-test' -MOCK_TRIGGER_INSTANCE.payload = {'t1_p': 't1_p_v'} +MOCK_TRIGGER_INSTANCE.id = "triggerinstance-test" +MOCK_TRIGGER_INSTANCE.payload = {"t1_p": "t1_p_v"} MOCK_TRIGGER_INSTANCE.occurrence_time = date_utils.get_datetime_utc_now() MOCK_TRIGGER_INSTANCE_2 = TriggerInstanceDB() -MOCK_TRIGGER_INSTANCE_2.id = 'triggerinstance-test2' -MOCK_TRIGGER_INSTANCE_2.payload = {'t1_p': None} +MOCK_TRIGGER_INSTANCE_2.id = "triggerinstance-test2" +MOCK_TRIGGER_INSTANCE_2.payload = {"t1_p": None} MOCK_TRIGGER_INSTANCE_2.occurrence_time = date_utils.get_datetime_utc_now() MOCK_TRIGGER_INSTANCE_3 = TriggerInstanceDB() -MOCK_TRIGGER_INSTANCE_3.id = 'triggerinstance-test3' -MOCK_TRIGGER_INSTANCE_3.payload = {'t1_p': None, 't2_p': 'value2'} +MOCK_TRIGGER_INSTANCE_3.id = "triggerinstance-test3" +MOCK_TRIGGER_INSTANCE_3.payload = {"t1_p": None, "t2_p": "value2"} MOCK_TRIGGER_INSTANCE_3.occurrence_time = date_utils.get_datetime_utc_now() -MOCK_TRIGGER_INSTANCE_PAYLOAD = {'k1': 'v1', 'k2': 'v2', 'k3': 3, 'k4': True, - 'k5': {'foo': 'bar'}, 'k6': [1, 3]} +MOCK_TRIGGER_INSTANCE_PAYLOAD = { + "k1": "v1", + "k2": "v2", + "k3": 3, + "k4": True, + "k5": {"foo": "bar"}, + "k6": [1, 3], +} MOCK_TRIGGER_INSTANCE_4 = TriggerInstanceDB() -MOCK_TRIGGER_INSTANCE_4.id = 'triggerinstance-test4' +MOCK_TRIGGER_INSTANCE_4.id = "triggerinstance-test4" MOCK_TRIGGER_INSTANCE_4.payload = MOCK_TRIGGER_INSTANCE_PAYLOAD MOCK_TRIGGER_INSTANCE_4.occurrence_time = date_utils.get_datetime_utc_now() MOCK_LIVEACTION = LiveActionDB() -MOCK_LIVEACTION.id = 'liveaction-test-1.id' -MOCK_LIVEACTION.status = 'requested' +MOCK_LIVEACTION.id = "liveaction-test-1.id" +MOCK_LIVEACTION.status = "requested" MOCK_EXECUTION = ActionExecutionDB() -MOCK_EXECUTION.id = 'exec-test-1.id' -MOCK_EXECUTION.status = 'requested' +MOCK_EXECUTION.id = "exec-test-1.id" +MOCK_EXECUTION.status = "requested" FAILURE_REASON = "fail!" @@ -111,11 +117,16 @@ def setUpClass(cls): # Create TriggerTypes before creation of Rule to avoid failure. Rule requires the # Trigger and therefore TriggerType to be created prior to rule creation. cls.models = FixturesLoader().save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_1) - cls.models.update(FixturesLoader().save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_2)) + fixtures_pack=PACK, fixtures_dict=FIXTURES_1 + ) + cls.models.update( + FixturesLoader().save_fixtures_to_db( + fixtures_pack=PACK, fixtures_dict=FIXTURES_2 + ) + ) MOCK_TRIGGER_INSTANCE.trigger = reference.get_ref_from_model( - cls.models['triggers']['trigger1.yaml']) + cls.models["triggers"]["trigger1.yaml"] + ) def setUp(self): super(BaseRuleEnforcerTestCase, self).setUp() @@ -124,335 +135,445 @@ def setUp(self): class RuleEnforcerTestCase(BaseRuleEnforcerTestCase): - - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) def test_ruleenforcement_occurs(self): - enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule1.yaml']) + enforcer = RuleEnforcer( + MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule1.yaml"] + ) execution_db = enforcer.enforce() self.assertIsNotNone(execution_db) - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) def test_ruleenforcement_casts(self): - enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule2.yaml']) + enforcer = RuleEnforcer( + MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule2.yaml"] + ) execution_db = enforcer.enforce() self.assertIsNotNone(execution_db) self.assertTrue(action_service.request.called) - self.assertIsInstance(action_service.request.call_args[0][0].parameters['objtype'], dict) - - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) + self.assertIsInstance( + action_service.request.call_args[0][0].parameters["objtype"], dict + ) + + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) def test_ruleenforcement_create_on_success(self): - enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule2.yaml']) + enforcer = RuleEnforcer( + MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule2.yaml"] + ) execution_db = enforcer.enforce() self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, - self.models['rules']['rule2.yaml'].ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) - - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, + self.models["rules"]["rule2.yaml"].ref, + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) + + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) def test_rule_enforcement_create_rule_none_param_casting(self): mock_trigger_instance = MOCK_TRIGGER_INSTANCE_2 # 1. Non None value, should be serialized as regular string - mock_trigger_instance.payload = {'t1_p': 'somevalue'} + mock_trigger_instance.payload = {"t1_p": "somevalue"} def mock_cast_string(x): - assert x == 'somevalue' + assert x == "somevalue" return casts._cast_string(x) - casts.CASTS['string'] = mock_cast_string - enforcer = RuleEnforcer(mock_trigger_instance, - self.models['rules']['rule_use_none_filter.yaml']) + casts.CASTS["string"] = mock_cast_string + + enforcer = RuleEnforcer( + mock_trigger_instance, self.models["rules"]["rule_use_none_filter.yaml"] + ) execution_db = enforcer.enforce() # Verify value has been serialized correctly call_args = action_service.request.call_args[0] live_action_db = call_args[0] - self.assertEqual(live_action_db.parameters['actionstr'], 'somevalue') + self.assertEqual(live_action_db.parameters["actionstr"], "somevalue") self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, - self.models['rules']['rule_use_none_filter.yaml'].ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, + self.models["rules"]["rule_use_none_filter.yaml"].ref, + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) # 2. Verify that None type from trigger instance is correctly serialized to # None when using "use_none" Jinja filter when invoking an action - mock_trigger_instance.payload = {'t1_p': None} + mock_trigger_instance.payload = {"t1_p": None} def mock_cast_string(x): assert x == data.NONE_MAGIC_VALUE return casts._cast_string(x) - casts.CASTS['string'] = mock_cast_string - enforcer = RuleEnforcer(mock_trigger_instance, - self.models['rules']['rule_use_none_filter.yaml']) + casts.CASTS["string"] = mock_cast_string + + enforcer = RuleEnforcer( + mock_trigger_instance, self.models["rules"]["rule_use_none_filter.yaml"] + ) execution_db = enforcer.enforce() # Verify None has been correctly serialized to None call_args = action_service.request.call_args[0] live_action_db = call_args[0] - self.assertEqual(live_action_db.parameters['actionstr'], None) + self.assertEqual(live_action_db.parameters["actionstr"], None) self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, - self.models['rules']['rule_use_none_filter.yaml'].ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, + self.models["rules"]["rule_use_none_filter.yaml"].ref, + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) - casts.CASTS['string'] = casts._cast_string + casts.CASTS["string"] = casts._cast_string # 3. Parameter value is a compound string one of which values is None, but "use_none" # filter is not used mock_trigger_instance = MOCK_TRIGGER_INSTANCE_3 - mock_trigger_instance.payload = {'t1_p': None, 't2_p': 'value2'} + mock_trigger_instance.payload = {"t1_p": None, "t2_p": "value2"} - enforcer = RuleEnforcer(mock_trigger_instance, - self.models['rules']['rule_none_no_use_none_filter.yaml']) + enforcer = RuleEnforcer( + mock_trigger_instance, + self.models["rules"]["rule_none_no_use_none_filter.yaml"], + ) execution_db = enforcer.enforce() # Verify None has been correctly serialized to None call_args = action_service.request.call_args[0] live_action_db = call_args[0] - self.assertEqual(live_action_db.parameters['actionstr'], 'None-value2') + self.assertEqual(live_action_db.parameters["actionstr"], "None-value2") self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, - self.models['rules']['rule_none_no_use_none_filter.yaml'].ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) - - casts.CASTS['string'] = casts._cast_string - - @mock.patch.object(action_service, 'request', mock.MagicMock( - side_effect=ValueError(FAILURE_REASON))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, + self.models["rules"]["rule_none_no_use_none_filter.yaml"].ref, + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) + + casts.CASTS["string"] = casts._cast_string + + @mock.patch.object( + action_service, + "request", + mock.MagicMock(side_effect=ValueError(FAILURE_REASON)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) def test_ruleenforcement_create_on_fail(self): - enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, self.models['rules']['rule1.yaml']) + enforcer = RuleEnforcer( + MOCK_TRIGGER_INSTANCE, self.models["rules"]["rule1.yaml"] + ) execution_db = enforcer.enforce() self.assertIsNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].failure_reason, - FAILURE_REASON) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_FAILED) - - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) - @mock.patch('st2common.util.param.get_config', - mock.Mock(return_value={'arrtype_value': ['one 1', 'two 2', 'three 3']})) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].failure_reason, FAILURE_REASON + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_FAILED, + ) + + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) + @mock.patch( + "st2common.util.param.get_config", + mock.Mock(return_value={"arrtype_value": ["one 1", "two 2", "three 3"]}), + ) def test_action_default_jinja_parameter_value_is_rendered(self): # Verify that a default action parameter which is a Jinja variable is correctly rendered - rule = self.models['rules']['rule_action_default_value.yaml'] + rule = self.models["rules"]["rule_action_default_value.yaml"] enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, rule) execution_db = enforcer.enforce() self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) call_parameters = action_service.request.call_args[0][0].parameters - self.assertEqual(call_parameters['objtype'], {'t1_p': 't1_p_v'}) - self.assertEqual(call_parameters['strtype'], 't1_p_v') - self.assertEqual(call_parameters['arrtype'], ['one 1', 'two 2', 'three 3']) + self.assertEqual(call_parameters["objtype"], {"t1_p": "t1_p_v"}) + self.assertEqual(call_parameters["strtype"], "t1_p_v") + self.assertEqual(call_parameters["arrtype"], ["one 1", "two 2", "three 3"]) - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) def test_action_default_jinja_parameter_value_overridden_in_rule(self): # Verify that it works correctly if default parameter value is overridden in rule - rule = self.models['rules']['rule_action_default_value_overridden.yaml'] + rule = self.models["rules"]["rule_action_default_value_overridden.yaml"] enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, rule) execution_db = enforcer.enforce() self.assertIsNotNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_SUCCEEDED) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_SUCCEEDED, + ) call_parameters = action_service.request.call_args[0][0].parameters - self.assertEqual(call_parameters['objtype'], {'t1_p': 't1_p_v'}) - self.assertEqual(call_parameters['strtype'], 't1_p_v') - self.assertEqual(call_parameters['arrtype'], ['override 1', 'override 2']) - - @mock.patch.object(action_service, 'request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(action_service, 'create_request', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(action_service, 'update_status', mock.MagicMock( - return_value=(MOCK_LIVEACTION, MOCK_EXECUTION))) - @mock.patch.object(RuleEnforcement, 'add_or_update', mock.MagicMock()) + self.assertEqual(call_parameters["objtype"], {"t1_p": "t1_p_v"}) + self.assertEqual(call_parameters["strtype"], "t1_p_v") + self.assertEqual(call_parameters["arrtype"], ["override 1", "override 2"]) + + @mock.patch.object( + action_service, + "request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object( + action_service, + "create_request", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object( + action_service, + "update_status", + mock.MagicMock(return_value=(MOCK_LIVEACTION, MOCK_EXECUTION)), + ) + @mock.patch.object(RuleEnforcement, "add_or_update", mock.MagicMock()) def test_action_default_jinja_parameter_value_render_fail(self): # Action parameter render failure should result in a failed execution - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] enforcer = RuleEnforcer(MOCK_TRIGGER_INSTANCE, rule) execution_db = enforcer.enforce() self.assertIsNone(execution_db) self.assertTrue(RuleEnforcement.add_or_update.called) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].status, - RULE_ENFORCEMENT_STATUS_FAILED) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].rule.ref, rule.ref + ) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].status, + RULE_ENFORCEMENT_STATUS_FAILED, + ) self.assertFalse(action_service.request.called) self.assertTrue(action_service.create_request.called) - self.assertEqual(action_service.create_request.call_args[0][0].action, - 'wolfpack.a2_default_value') + self.assertEqual( + action_service.create_request.call_args[0][0].action, + "wolfpack.a2_default_value", + ) self.assertTrue(action_service.update_status.called) - self.assertEqual(action_service.update_status.call_args[1]['new_status'], - action_constants.LIVEACTION_STATUS_FAILED) + self.assertEqual( + action_service.update_status.call_args[1]["new_status"], + action_constants.LIVEACTION_STATUS_FAILED, + ) - expected_msg = ('Failed to render parameter "arrtype": \'dict object\' has no ' - 'attribute \'arrtype_value\'') + expected_msg = ( + "Failed to render parameter \"arrtype\": 'dict object' has no " + "attribute 'arrtype_value'" + ) - result = action_service.update_status.call_args[1]['result'] - self.assertEqual(result['error'], expected_msg) + result = action_service.update_status.call_args[1]["result"] + self.assertEqual(result["error"], expected_msg) - self.assertEqual(RuleEnforcement.add_or_update.call_args[0][0].failure_reason, - expected_msg) + self.assertEqual( + RuleEnforcement.add_or_update.call_args[0][0].failure_reason, expected_msg + ) class RuleEnforcerDataTransformationTestCase(BaseRuleEnforcerTestCase): - def test_payload_data_transform(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] - params = {'ip1': '{{trigger.k1}}-static', - 'ip2': '{{trigger.k2}} static'} + params = {"ip1": "{{trigger.k1}}-static", "ip2": "{{trigger.k2}} static"} - expected_params = {'ip1': 'v1-static', 'ip2': 'v2 static'} + expected_params = {"ip1": "v1-static", "ip2": "v2 static"} - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) def test_payload_transforms_int_type(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] - params = {'int': 666} - expected_params = {'int': 666} + params = {"int": 666} + expected_params = {"int": 666} - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) def test_payload_transforms_bool_type(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] runner_type_db = mock.Mock() runner_type_db.runner_parameters = {} action_db = mock.Mock() action_db.parameters = {} - params = {'bool': True} - expected_params = {'bool': True} + params = {"bool": True} + expected_params = {"bool": True} - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) def test_payload_transforms_complex_type(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] runner_type_db = mock.Mock() runner_type_db.runner_parameters = {} action_db = mock.Mock() action_db.parameters = {} - params = {'complex_dict': {'bool': True, 'int': 666, 'str': '{{trigger.k1}}-string'}} - expected_params = {'complex_dict': {'bool': True, 'int': 666, 'str': 'v1-string'}} + params = { + "complex_dict": {"bool": True, "int": 666, "str": "{{trigger.k1}}-string"} + } + expected_params = { + "complex_dict": {"bool": True, "int": 666, "str": "v1-string"} + } - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) - params = {'simple_list': [1, 2, 3]} - expected_params = {'simple_list': [1, 2, 3]} + params = {"simple_list": [1, 2, 3]} + expected_params = {"simple_list": [1, 2, 3]} - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) def test_hypenated_payload_transform(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] - payload = {'headers': {'hypenated-header': 'dont-care'}, 'k2': 'v2'} + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] + payload = {"headers": {"hypenated-header": "dont-care"}, "k2": "v2"} MOCK_TRIGGER_INSTANCE_4.payload = payload - params = {'ip1': '{{trigger.headers[\'hypenated-header\']}}-static', - 'ip2': '{{trigger.k2}} static'} - expected_params = {'ip1': 'dont-care-static', 'ip2': 'v2 static'} - - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + params = { + "ip1": "{{trigger.headers['hypenated-header']}}-static", + "ip2": "{{trigger.k2}} static", + } + expected_params = {"ip1": "dont-care-static", "ip2": "v2 static"} + + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) def test_system_transform(self): - rule = self.models['rules']['rule_action_default_value_render_fail.yaml'] + rule = self.models["rules"]["rule_action_default_value_render_fail.yaml"] runner_type_db = mock.Mock() runner_type_db.runner_parameters = {} action_db = mock.Mock() action_db.parameters = {} - k5 = KeyValuePair.add_or_update(KeyValuePairDB(name='k5', value='v5')) - k6 = KeyValuePair.add_or_update(KeyValuePairDB(name='k6', value='v6')) - k7 = KeyValuePair.add_or_update(KeyValuePairDB(name='k7', value='v7')) - k8 = KeyValuePair.add_or_update(KeyValuePairDB(name='k8', value='v8', - scope=FULL_SYSTEM_SCOPE)) + k5 = KeyValuePair.add_or_update(KeyValuePairDB(name="k5", value="v5")) + k6 = KeyValuePair.add_or_update(KeyValuePairDB(name="k6", value="v6")) + k7 = KeyValuePair.add_or_update(KeyValuePairDB(name="k7", value="v7")) + k8 = KeyValuePair.add_or_update( + KeyValuePairDB(name="k8", value="v8", scope=FULL_SYSTEM_SCOPE) + ) - params = {'ip5': '{{trigger.k2}}-static', - 'ip6': '{{st2kv.system.k6}}-static', - 'ip7': '{{st2kv.system.k7}}-static'} - expected_params = {'ip5': 'v2-static', - 'ip6': 'v6-static', - 'ip7': 'v7-static'} + params = { + "ip5": "{{trigger.k2}}-static", + "ip6": "{{st2kv.system.k6}}-static", + "ip7": "{{st2kv.system.k7}}-static", + } + expected_params = {"ip5": "v2-static", "ip6": "v6-static", "ip7": "v7-static"} try: - self.assertResolvedParamsMatchExpected(rule=rule, - trigger_instance=MOCK_TRIGGER_INSTANCE_4, - params=params, - expected_params=expected_params) + self.assertResolvedParamsMatchExpected( + rule=rule, + trigger_instance=MOCK_TRIGGER_INSTANCE_4, + params=params, + expected_params=expected_params, + ) finally: KeyValuePair.delete(k5) KeyValuePair.delete(k6) KeyValuePair.delete(k7) KeyValuePair.delete(k8) - def assertResolvedParamsMatchExpected(self, rule, trigger_instance, params, expected_params): + def assertResolvedParamsMatchExpected( + self, rule, trigger_instance, params, expected_params + ): runner_type_db = mock.Mock() runner_type_db.runner_parameters = {} action_db = mock.Mock() action_db.parameters = {} enforcer = RuleEnforcer(trigger_instance, rule) - context, additional_contexts = enforcer.get_action_execution_context(action_db=action_db) + context, additional_contexts = enforcer.get_action_execution_context( + action_db=action_db + ) - resolved_params = enforcer.get_resolved_parameters(action_db=action_db, + resolved_params = enforcer.get_resolved_parameters( + action_db=action_db, runnertype_db=runner_type_db, params=params, context=context, - additional_contexts=additional_contexts) + additional_contexts=additional_contexts, + ) self.assertEqual(resolved_params, expected_params) diff --git a/st2reactor/tests/unit/test_filter.py b/st2reactor/tests/unit/test_filter.py index 4df7ef23604..d1e42eaecef 100644 --- a/st2reactor/tests/unit/test_filter.py +++ b/st2reactor/tests/unit/test_filter.py @@ -27,57 +27,71 @@ from st2tests import DbTestCase -MOCK_TRIGGER = TriggerDB(pack='dummy_pack_1', name='trigger-test.name', type='system.test') +MOCK_TRIGGER = TriggerDB( + pack="dummy_pack_1", name="trigger-test.name", type="system.test" +) MOCK_TRIGGER_INSTANCE = TriggerInstanceDB( trigger=MOCK_TRIGGER.get_reference().ref, occurrence_time=date_utils.get_datetime_utc_now(), payload={ - 'p1': 'v1', - 'p2': 'preYYYpost', - 'bool': True, - 'int': 1, - 'float': 0.8, - 'list': ['v1', True, 1], - 'recursive_list': [ + "p1": "v1", + "p2": "preYYYpost", + "bool": True, + "int": 1, + "float": 0.8, + "list": ["v1", True, 1], + "recursive_list": [ { - 'field_name': "Status", - 'to_value': "Approved", - }, { - 'field_name': "Signed off by", - 'to_value': "Stanley", - } + "field_name": "Status", + "to_value": "Approved", + }, + { + "field_name": "Signed off by", + "to_value": "Stanley", + }, ], - } + }, ) -MOCK_ACTION = ActionDB(id=bson.ObjectId(), pack='wolfpack', name='action-test-1.name') +MOCK_ACTION = ActionDB(id=bson.ObjectId(), pack="wolfpack", name="action-test-1.name") -MOCK_RULE_1 = RuleDB(id=bson.ObjectId(), pack='wolfpack', name='some1', - trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER), - criteria={}, action=ActionExecutionSpecDB(ref="somepack.someaction")) +MOCK_RULE_1 = RuleDB( + id=bson.ObjectId(), + pack="wolfpack", + name="some1", + trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER), + criteria={}, + action=ActionExecutionSpecDB(ref="somepack.someaction"), +) -MOCK_RULE_2 = RuleDB(id=bson.ObjectId(), pack='wolfpack', name='some2', - trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER), - criteria={}, action=ActionExecutionSpecDB(ref="somepack.someaction")) +MOCK_RULE_2 = RuleDB( + id=bson.ObjectId(), + pack="wolfpack", + name="some2", + trigger=reference.get_str_resource_ref_from_model(MOCK_TRIGGER), + criteria={}, + action=ActionExecutionSpecDB(ref="somepack.someaction"), +) -@mock.patch.object(reference, 'get_model_by_resource_ref', - mock.MagicMock(return_value=MOCK_TRIGGER)) +@mock.patch.object( + reference, "get_model_by_resource_ref", mock.MagicMock(return_value=MOCK_TRIGGER) +) class FilterTest(DbTestCase): def test_empty_criteria(self): rule = MOCK_RULE_1 rule.criteria = {} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have failed.') + self.assertTrue(f.filter(), "equals check should have failed.") def test_empty_payload(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': 'v1'}} + rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "v1"}} trigger_instance = copy.deepcopy(MOCK_TRIGGER_INSTANCE) trigger_instance.payload = None f = RuleFilter(trigger_instance, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'equals check should have failed.') + self.assertFalse(f.filter(), "equals check should have failed.") def test_empty_criteria_and_empty_payload(self): rule = MOCK_RULE_1 @@ -85,234 +99,247 @@ def test_empty_criteria_and_empty_payload(self): trigger_instance = copy.deepcopy(MOCK_TRIGGER_INSTANCE) trigger_instance.payload = None f = RuleFilter(trigger_instance, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have failed.') + self.assertTrue(f.filter(), "equals check should have failed.") def test_search_operator_pass_any_criteria(self): rule = MOCK_RULE_1 rule.criteria = { - 'trigger.recursive_list': { - 'type': 'search', - 'condition': 'any', - 'pattern': { - 'item.field_name': { - 'type': 'equals', - 'pattern': 'Status', + "trigger.recursive_list": { + "type": "search", + "condition": "any", + "pattern": { + "item.field_name": { + "type": "equals", + "pattern": "Status", }, - 'item.to_value': { - 'type': 'equals', - 'pattern': 'Approved' - } - } + "item.to_value": {"type": "equals", "pattern": "Approved"}, + }, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'Failed evaluation') + self.assertTrue(f.filter(), "Failed evaluation") def test_search_operator_fail_any_criteria(self): rule = MOCK_RULE_1 rule.criteria = { - 'trigger.recursive_list': { - 'type': 'search', - 'condition': 'any', - 'pattern': { - 'item.field_name': { - 'type': 'equals', - 'pattern': 'Status', + "trigger.recursive_list": { + "type": "search", + "condition": "any", + "pattern": { + "item.field_name": { + "type": "equals", + "pattern": "Status", }, - 'item.to_value': { - 'type': 'equals', - 'pattern': 'Denied', - } - } + "item.to_value": { + "type": "equals", + "pattern": "Denied", + }, + }, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'Passed evaluation') + self.assertFalse(f.filter(), "Passed evaluation") def test_search_operator_pass_all_criteria(self): rule = MOCK_RULE_1 rule.criteria = { - 'trigger.recursive_list': { - 'type': 'search', - 'condition': 'all', - 'pattern': { - 'item.field_name': { - 'type': 'startswith', - 'pattern': 'S', + "trigger.recursive_list": { + "type": "search", + "condition": "all", + "pattern": { + "item.field_name": { + "type": "startswith", + "pattern": "S", } - } + }, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'Failed evaluation') + self.assertTrue(f.filter(), "Failed evaluation") def test_search_operator_fail_all_criteria(self): rule = MOCK_RULE_1 rule.criteria = { - 'trigger.recursive_list': { - 'type': 'search', - 'condition': 'all', - 'pattern': { - 'item.field_name': { - 'type': 'equals', - 'pattern': 'Status', + "trigger.recursive_list": { + "type": "search", + "condition": "all", + "pattern": { + "item.field_name": { + "type": "equals", + "pattern": "Status", }, - 'item.to_value': { - 'type': 'equals', - 'pattern': 'Denied', - } - } + "item.to_value": { + "type": "equals", + "pattern": "Denied", + }, + }, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'Passed evaluation') + self.assertFalse(f.filter(), "Passed evaluation") def test_matchregex_operator_pass_criteria(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'matchregex', 'pattern': 'v1$'}} + rule.criteria = {"trigger.p1": {"type": "matchregex", "pattern": "v1$"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'Failed to pass evaluation.') + self.assertTrue(f.filter(), "Failed to pass evaluation.") def test_matchregex_operator_fail_criteria(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'matchregex', 'pattern': 'v$'}} + rule.criteria = {"trigger.p1": {"type": "matchregex", "pattern": "v$"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'regex check should have failed.') + self.assertFalse(f.filter(), "regex check should have failed.") def test_equals_operator_pass_criteria(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': 'v1'}} + rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "v1"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': '{{trigger.p1}}'}} + rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "{{trigger.p1}}"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 rule.criteria = { - 'trigger.p1': { - 'type': 'equals', - 'pattern': "{{'%s' % trigger.p1 if trigger.int}}" + "trigger.p1": { + "type": "equals", + "pattern": "{{'%s' % trigger.p1 if trigger.int}}", } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") # Test our filter works if proper JSON is returned from user pattern rule = MOCK_RULE_1 rule.criteria = { - 'trigger.list': { - 'type': 'equals', - 'pattern': """ + "trigger.list": { + "type": "equals", + "pattern": """ [ {% for item in trigger.list %} {{item}}{% if not loop.last %},{% endif %} {% endfor %} ] - """ + """, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") def test_equals_operator_fail_criteria(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': 'v'}} + rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "v"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'equals check should have failed.') + self.assertFalse(f.filter(), "equals check should have failed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.p1': {'type': 'equals', 'pattern': '{{trigger.p2}}'}} + rule.criteria = {"trigger.p1": {"type": "equals", "pattern": "{{trigger.p2}}"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'equals check should have failed.') + self.assertFalse(f.filter(), "equals check should have failed.") rule = MOCK_RULE_1 rule.criteria = { - 'trigger.list': { - 'type': 'equals', - 'pattern': """ + "trigger.list": { + "type": "equals", + "pattern": """ [ {% for item in trigger.list %} {{item}} {% endfor %} ] - """ + """, } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'equals check should have failed.') + self.assertFalse(f.filter(), "equals check should have failed.") def test_equals_bool_value(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.bool': {'type': 'equals', 'pattern': True}} + rule.criteria = {"trigger.bool": {"type": "equals", "pattern": True}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.bool': {'type': 'equals', 'pattern': '{{trigger.bool}}'}} + rule.criteria = { + "trigger.bool": {"type": "equals", "pattern": "{{trigger.bool}}"} + } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.bool': {'type': 'equals', 'pattern': '{{ trigger.bool }}'}} + rule.criteria = { + "trigger.bool": {"type": "equals", "pattern": "{{ trigger.bool }}"} + } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") def test_equals_int_value(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.int': {'type': 'equals', 'pattern': 1}} + rule.criteria = {"trigger.int": {"type": "equals", "pattern": 1}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.int': {'type': 'equals', 'pattern': '{{trigger.int}}'}} + rule.criteria = { + "trigger.int": {"type": "equals", "pattern": "{{trigger.int}}"} + } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") def test_equals_float_value(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.float': {'type': 'equals', 'pattern': 0.8}} + rule.criteria = {"trigger.float": {"type": "equals", "pattern": 0.8}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") rule = MOCK_RULE_1 - rule.criteria = {'trigger.float': {'type': 'equals', 'pattern': '{{trigger.float}}'}} + rule.criteria = { + "trigger.float": {"type": "equals", "pattern": "{{trigger.float}}"} + } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'equals check should have passed.') + self.assertTrue(f.filter(), "equals check should have passed.") def test_exists(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.float': {'type': 'exists'}} + rule.criteria = {"trigger.float": {"type": "exists"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), '"float" key exists in trigger. Should return true.') - rule.criteria = {'trigger.floattt': {'type': 'exists'}} + self.assertTrue( + f.filter(), '"float" key exists in trigger. Should return true.' + ) + rule.criteria = {"trigger.floattt": {"type": "exists"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), '"floattt" key ain\'t exist in trigger. Should return false.') + self.assertFalse( + f.filter(), '"floattt" key ain\'t exist in trigger. Should return false.' + ) def test_nexists(self): rule = MOCK_RULE_1 - rule.criteria = {'trigger.float': {'type': 'nexists'}} + rule.criteria = {"trigger.float": {"type": "nexists"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), '"float" key exists in trigger. Should return false.') - rule.criteria = {'trigger.floattt': {'type': 'nexists'}} + self.assertFalse( + f.filter(), '"float" key exists in trigger. Should return false.' + ) + rule.criteria = {"trigger.floattt": {"type": "nexists"}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), '"floattt" key ain\'t exist in trigger. Should return true.') + self.assertTrue( + f.filter(), '"floattt" key ain\'t exist in trigger. Should return true.' + ) def test_gt_lt_falsy_pattern(self): # Make sure that the falsy value (number 0) is handled correctly rule = MOCK_RULE_1 - rule.criteria = {'trigger.int': {'type': 'gt', 'pattern': 0}} + rule.criteria = {"trigger.int": {"type": "gt", "pattern": 0}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertTrue(f.filter(), 'trigger value is gt than 0 but didn\'t match') + self.assertTrue(f.filter(), "trigger value is gt than 0 but didn't match") - rule.criteria = {'trigger.int': {'type': 'lt', 'pattern': 0}} + rule.criteria = {"trigger.int": {"type": "lt", "pattern": 0}} f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) - self.assertFalse(f.filter(), 'trigger value is gt than 0 but didn\'t fail') + self.assertFalse(f.filter(), "trigger value is gt than 0 but didn't fail") - @mock.patch('st2common.util.templating.KeyValueLookup') + @mock.patch("st2common.util.templating.KeyValueLookup") def test_criteria_pattern_references_a_datastore_item(self, mock_KeyValueLookup): class MockResultLookup(object): pass @@ -323,22 +350,24 @@ class MockSystemLookup(object): rule = MOCK_RULE_2 # Using a variable in pattern, referencing an inexistent datastore value - rule.criteria = {'trigger.p1': { - 'type': 'equals', - 'pattern': '{{ st2kv.system.inexistent_value }}'} + rule.criteria = { + "trigger.p1": { + "type": "equals", + "pattern": "{{ st2kv.system.inexistent_value }}", + } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) self.assertFalse(f.filter()) # Using a variable in pattern, referencing an existing value which doesn't match mock_result = MockSystemLookup() - mock_result.test_value_1 = 'non matching' + mock_result.test_value_1 = "non matching" mock_KeyValueLookup.return_value = mock_result rule.criteria = { - 'trigger.p1': { - 'type': 'equals', - 'pattern': '{{ st2kv.system.test_value_1 }}' + "trigger.p1": { + "type": "equals", + "pattern": "{{ st2kv.system.test_value_1 }}", } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) @@ -346,13 +375,13 @@ class MockSystemLookup(object): # Using a variable in pattern, referencing an existing value which does match mock_result = MockSystemLookup() - mock_result.test_value_2 = 'v1' + mock_result.test_value_2 = "v1" mock_KeyValueLookup.return_value = mock_result rule.criteria = { - 'trigger.p1': { - 'type': 'equals', - 'pattern': '{{ st2kv.system.test_value_2 }}' + "trigger.p1": { + "type": "equals", + "pattern": "{{ st2kv.system.test_value_2 }}", } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) @@ -360,13 +389,13 @@ class MockSystemLookup(object): # Using a variable in pattern, referencing an existing value which matches partially mock_result = MockSystemLookup() - mock_result.test_value_3 = 'YYY' + mock_result.test_value_3 = "YYY" mock_KeyValueLookup.return_value = mock_result rule.criteria = { - 'trigger.p2': { - 'type': 'equals', - 'pattern': '{{ st2kv.system.test_value_3 }}' + "trigger.p2": { + "type": "equals", + "pattern": "{{ st2kv.system.test_value_3 }}", } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) @@ -374,13 +403,13 @@ class MockSystemLookup(object): # Using a variable in pattern, referencing an existing value which matches partially mock_result = MockSystemLookup() - mock_result.test_value_3 = 'YYY' + mock_result.test_value_3 = "YYY" mock_KeyValueLookup.return_value = mock_result rule.criteria = { - 'trigger.p2': { - 'type': 'equals', - 'pattern': 'pre{{ st2kv.system.test_value_3 }}post' + "trigger.p2": { + "type": "equals", + "pattern": "pre{{ st2kv.system.test_value_3 }}post", } } f = RuleFilter(MOCK_TRIGGER_INSTANCE, MOCK_TRIGGER, rule) diff --git a/st2reactor/tests/unit/test_garbage_collector.py b/st2reactor/tests/unit/test_garbage_collector.py index 31442e8eb30..93de6b25d00 100644 --- a/st2reactor/tests/unit/test_garbage_collector.py +++ b/st2reactor/tests/unit/test_garbage_collector.py @@ -21,43 +21,48 @@ from oslo_config import cfg import st2tests.config as tests_config + tests_config.parse_args() from st2reactor.garbage_collector import base as garbage_collector class GarbageCollectorServiceTest(unittest.TestCase): - def tearDown(self): # Reset gc_max_idle_sec with a value of 1 to reenable for other tests. - cfg.CONF.set_override('gc_max_idle_sec', 1, group='workflow_engine') + cfg.CONF.set_override("gc_max_idle_sec", 1, group="workflow_engine") super(GarbageCollectorServiceTest, self).tearDown() @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_action_executions', - mock.MagicMock(return_value=None)) + "_purge_action_executions", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_action_executions_output', - mock.MagicMock(return_value=None)) + "_purge_action_executions_output", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_trigger_instances', - mock.MagicMock(return_value=None)) + "_purge_trigger_instances", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_timeout_inquiries', - mock.MagicMock(return_value=None)) + "_timeout_inquiries", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_orphaned_workflow_executions', - mock.MagicMock(return_value=None)) + "_purge_orphaned_workflow_executions", + mock.MagicMock(return_value=None), + ) def test_orphaned_workflow_executions_gc_enabled(self): # Mock the default value of gc_max_idle_sec with a value >= 1 to enable. The config # gc_max_idle_sec is assigned to _workflow_execution_max_idle which gc checks to see # whether to run the routine. - cfg.CONF.set_override('gc_max_idle_sec', 1, group='workflow_engine') + cfg.CONF.set_override("gc_max_idle_sec", 1, group="workflow_engine") # Run the garbage collection. gc = garbage_collector.GarbageCollectorService(sleep_delay=0) @@ -70,29 +75,34 @@ def test_orphaned_workflow_executions_gc_enabled(self): @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_action_executions', - mock.MagicMock(return_value=None)) + "_purge_action_executions", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_action_executions_output', - mock.MagicMock(return_value=None)) + "_purge_action_executions_output", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_trigger_instances', - mock.MagicMock(return_value=None)) + "_purge_trigger_instances", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_timeout_inquiries', - mock.MagicMock(return_value=None)) + "_timeout_inquiries", + mock.MagicMock(return_value=None), + ) @mock.patch.object( garbage_collector.GarbageCollectorService, - '_purge_orphaned_workflow_executions', - mock.MagicMock(return_value=None)) + "_purge_orphaned_workflow_executions", + mock.MagicMock(return_value=None), + ) def test_orphaned_workflow_executions_gc_disabled(self): # Mock the default value of gc_max_idle_sec with a value of 0 to disable. The config # gc_max_idle_sec is assigned to _workflow_execution_max_idle which gc checks to see # whether to run the routine. - cfg.CONF.set_override('gc_max_idle_sec', 0, group='workflow_engine') + cfg.CONF.set_override("gc_max_idle_sec", 0, group="workflow_engine") # Run the garbage collection. gc = garbage_collector.GarbageCollectorService(sleep_delay=0) diff --git a/st2reactor/tests/unit/test_hash_partitioner.py b/st2reactor/tests/unit/test_hash_partitioner.py index 4412c07b97a..12e522a10c2 100644 --- a/st2reactor/tests/unit/test_hash_partitioner.py +++ b/st2reactor/tests/unit/test_hash_partitioner.py @@ -22,10 +22,8 @@ from st2tests import DbTestCase from st2tests.fixturesloader import FixturesLoader -PACK = 'generic' -FIXTURES_1 = { - 'sensors': ['sensor1.yaml', 'sensor2.yaml', 'sensor3.yaml'] -} +PACK = "generic" +FIXTURES_1 = {"sensors": ["sensor1.yaml", "sensor2.yaml", "sensor3.yaml"]} class HashPartitionerTest(DbTestCase): @@ -38,39 +36,42 @@ def setUpClass(cls): # Create TriggerTypes before creation of Rule to avoid failure. Rule requires the # Trigger and therefore TriggerType to be created prior to rule creation. cls.models = FixturesLoader().save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_1) + fixtures_pack=PACK, fixtures_dict=FIXTURES_1 + ) config.parse_args() def test_full_range_hash_partitioner(self): - partitioner = HashPartitioner('node1', 'MIN..MAX') + partitioner = HashPartitioner("node1", "MIN..MAX") sensors = partitioner.get_sensors() - self.assertEqual(len(sensors), 3, 'Expected all sensors') + self.assertEqual(len(sensors), 3, "Expected all sensors") def test_multi_range_hash_partitioner(self): range_third = int(Range.RANGE_MAX_VALUE / 3) range_two_third = range_third * 2 - hash_ranges = \ - 'MIN..{range_third}|{range_third}..{range_two_third}|{range_two_third}..MAX'.format( - range_third=range_third, range_two_third=range_two_third) - partitioner = HashPartitioner('node1', hash_ranges) + hash_ranges = "MIN..{range_third}|{range_third}..{range_two_third}|{range_two_third}..MAX".format( + range_third=range_third, range_two_third=range_two_third + ) + partitioner = HashPartitioner("node1", hash_ranges) sensors = partitioner.get_sensors() - self.assertEqual(len(sensors), 3, 'Expected all sensors') + self.assertEqual(len(sensors), 3, "Expected all sensors") def test_split_range_hash_partitioner(self): range_mid = int(Range.RANGE_MAX_VALUE / 2) - partitioner = HashPartitioner('node1', 'MIN..%s' % range_mid) + partitioner = HashPartitioner("node1", "MIN..%s" % range_mid) sensors1 = partitioner.get_sensors() - partitioner = HashPartitioner('node2', '%s..MAX' % range_mid) + partitioner = HashPartitioner("node2", "%s..MAX" % range_mid) sensors2 = partitioner.get_sensors() - self.assertEqual(len(sensors1) + len(sensors2), 3, 'Expected all sensors') + self.assertEqual(len(sensors1) + len(sensors2), 3, "Expected all sensors") def test_hash_effectiveness(self): range_third = int(Range.RANGE_MAX_VALUE / 3) - partitioner1 = HashPartitioner('node1', 'MIN..%s' % range_third) - partitioner2 = HashPartitioner('node2', '%s..%s' % (range_third, range_third + range_third)) - partitioner3 = HashPartitioner('node2', '%s..MAX' % (range_third + range_third)) + partitioner1 = HashPartitioner("node1", "MIN..%s" % range_third) + partitioner2 = HashPartitioner( + "node2", "%s..%s" % (range_third, range_third + range_third) + ) + partitioner3 = HashPartitioner("node2", "%s..MAX" % (range_third + range_third)) refs_count = 1000 @@ -89,15 +90,21 @@ def test_hash_effectiveness(self): if partitioner3._is_in_hash_range(ref): p3_count += 1 - self.assertEqual(p1_count + p2_count + p3_count, refs_count, - 'Sum should equal all sensors.') + self.assertEqual( + p1_count + p2_count + p3_count, refs_count, "Sum should equal all sensors." + ) # Test effectiveness by checking if the sd is within 20% of mean mean = refs_count / 3 - variance = float((p1_count - mean)**2 + (p1_count - mean)**2 + (p3_count - mean)**2) / 3 + variance = ( + float( + (p1_count - mean) ** 2 + (p1_count - mean) ** 2 + (p3_count - mean) ** 2 + ) + / 3 + ) sd = math.sqrt(variance) - self.assertTrue(sd / mean <= 0.2, 'Some values deviate too much from the mean.') + self.assertTrue(sd / mean <= 0.2, "Some values deviate too much from the mean.") def _generate_refs(self, count=10): random_word_count = int(math.sqrt(count)) + 1 @@ -105,7 +112,7 @@ def _generate_refs(self, count=10): x_index = 0 y_index = 0 while count > 0: - yield '%s.%s' % (words[x_index], words[y_index]) + yield "%s.%s" % (words[x_index], words[y_index]) if y_index < len(words) - 1: y_index += 1 else: diff --git a/st2reactor/tests/unit/test_partitioners.py b/st2reactor/tests/unit/test_partitioners.py index 00e7681cc97..8c4213ec5b2 100644 --- a/st2reactor/tests/unit/test_partitioners.py +++ b/st2reactor/tests/unit/test_partitioners.py @@ -16,8 +16,11 @@ from __future__ import absolute_import from oslo_config import cfg -from st2common.constants.sensors import KVSTORE_PARTITION_LOADER, FILE_PARTITION_LOADER, \ - HASH_PARTITION_LOADER +from st2common.constants.sensors import ( + KVSTORE_PARTITION_LOADER, + FILE_PARTITION_LOADER, + HASH_PARTITION_LOADER, +) from st2common.models.db.keyvalue import KeyValuePairDB from st2common.persistence.keyvalue import KeyValuePair from st2reactor.container.partitioner_lookup import get_sensors_partitioner @@ -26,10 +29,8 @@ from st2tests import DbTestCase from st2tests.fixturesloader import FixturesLoader -PACK = 'generic' -FIXTURES_1 = { - 'sensors': ['sensor1.yaml', 'sensor2.yaml', 'sensor3.yaml'] -} +PACK = "generic" +FIXTURES_1 = {"sensors": ["sensor1.yaml", "sensor2.yaml", "sensor3.yaml"]} class PartitionerTest(DbTestCase): @@ -42,76 +43,91 @@ def setUpClass(cls): # Create TriggerTypes before creation of Rule to avoid failure. Rule requires the # Trigger and therefore TriggerType to be created prior to rule creation. cls.models = FixturesLoader().save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_1) + fixtures_pack=PACK, fixtures_dict=FIXTURES_1 + ) config.parse_args() def test_default_partitioner(self): provider = get_sensors_partitioner() sensors = provider.get_sensors() - self.assertEqual(len(sensors), len(FIXTURES_1['sensors']), - 'Failed to provider all sensors') + self.assertEqual( + len(sensors), len(FIXTURES_1["sensors"]), "Failed to provider all sensors" + ) - sensor1 = self.models['sensors']['sensor1.yaml'] + sensor1 = self.models["sensors"]["sensor1.yaml"] self.assertTrue(provider.is_sensor_owner(sensor1)) def test_kvstore_partitioner(self): - cfg.CONF.set_override(name='partition_provider', - override={'name': KVSTORE_PARTITION_LOADER}, - group='sensorcontainer') - kvp = KeyValuePairDB(**{'name': 'sensornode1.sensor_partition', - 'value': 'generic.Sensor1, generic.Sensor2'}) + cfg.CONF.set_override( + name="partition_provider", + override={"name": KVSTORE_PARTITION_LOADER}, + group="sensorcontainer", + ) + kvp = KeyValuePairDB( + **{ + "name": "sensornode1.sensor_partition", + "value": "generic.Sensor1, generic.Sensor2", + } + ) KeyValuePair.add_or_update(kvp, publish=False, dispatch_trigger=False) provider = get_sensors_partitioner() sensors = provider.get_sensors() - self.assertEqual(len(sensors), len(kvp.value.split(','))) + self.assertEqual(len(sensors), len(kvp.value.split(","))) - sensor1 = self.models['sensors']['sensor1.yaml'] + sensor1 = self.models["sensors"]["sensor1.yaml"] self.assertTrue(provider.is_sensor_owner(sensor1)) - sensor3 = self.models['sensors']['sensor3.yaml'] + sensor3 = self.models["sensors"]["sensor3.yaml"] self.assertFalse(provider.is_sensor_owner(sensor3)) def test_file_partitioner(self): partition_file = FixturesLoader().get_fixture_file_path_abs( - fixtures_pack=PACK, fixtures_type='sensors', fixture_name='partition_file.yaml') - cfg.CONF.set_override(name='partition_provider', - override={'name': FILE_PARTITION_LOADER, - 'partition_file': partition_file}, - group='sensorcontainer') + fixtures_pack=PACK, + fixtures_type="sensors", + fixture_name="partition_file.yaml", + ) + cfg.CONF.set_override( + name="partition_provider", + override={"name": FILE_PARTITION_LOADER, "partition_file": partition_file}, + group="sensorcontainer", + ) provider = get_sensors_partitioner() sensors = provider.get_sensors() self.assertEqual(len(sensors), 2) - sensor1 = self.models['sensors']['sensor1.yaml'] + sensor1 = self.models["sensors"]["sensor1.yaml"] self.assertTrue(provider.is_sensor_owner(sensor1)) - sensor3 = self.models['sensors']['sensor3.yaml'] + sensor3 = self.models["sensors"]["sensor3.yaml"] self.assertFalse(provider.is_sensor_owner(sensor3)) def test_hash_partitioner(self): # no specific partitioner testing here for that see test_hash_partitioner.py # This test is to make sure the wiring and some basics work - cfg.CONF.set_override(name='partition_provider', - override={'name': HASH_PARTITION_LOADER, - 'hash_ranges': '%s..%s' % (Range.RANGE_MIN_ENUM, - Range.RANGE_MAX_ENUM)}, - group='sensorcontainer') + cfg.CONF.set_override( + name="partition_provider", + override={ + "name": HASH_PARTITION_LOADER, + "hash_ranges": "%s..%s" % (Range.RANGE_MIN_ENUM, Range.RANGE_MAX_ENUM), + }, + group="sensorcontainer", + ) provider = get_sensors_partitioner() sensors = provider.get_sensors() self.assertEqual(len(sensors), 3) - sensor1 = self.models['sensors']['sensor1.yaml'] + sensor1 = self.models["sensors"]["sensor1.yaml"] self.assertTrue(provider.is_sensor_owner(sensor1)) - sensor2 = self.models['sensors']['sensor2.yaml'] + sensor2 = self.models["sensors"]["sensor2.yaml"] self.assertTrue(provider.is_sensor_owner(sensor2)) - sensor3 = self.models['sensors']['sensor3.yaml'] + sensor3 = self.models["sensors"]["sensor3.yaml"] self.assertTrue(provider.is_sensor_owner(sensor3)) diff --git a/st2reactor/tests/unit/test_process_container.py b/st2reactor/tests/unit/test_process_container.py index 10ad700b8af..d1bcfdfe643 100644 --- a/st2reactor/tests/unit/test_process_container.py +++ b/st2reactor/tests/unit/test_process_container.py @@ -17,7 +17,7 @@ import os import time -from mock import (MagicMock, Mock, patch) +from mock import MagicMock, Mock, patch import unittest2 from st2reactor.container.process_container import ProcessSensorContainer @@ -26,14 +26,18 @@ from st2common.persistence.pack import Pack import st2tests.config as tests_config + tests_config.parse_args() -MOCK_PACK_DB = PackDB(ref='wolfpack', name='wolf pack', description='', - path='/opt/stackstorm/packs/wolfpack/') +MOCK_PACK_DB = PackDB( + ref="wolfpack", + name="wolf pack", + description="", + path="/opt/stackstorm/packs/wolfpack/", +) class ProcessContainerTests(unittest2.TestCase): - def test_no_sensors_dont_quit(self): process_container = ProcessSensorContainer(None, poll_interval=0.1) process_container_thread = concurrency.spawn(process_container.run) @@ -43,113 +47,133 @@ def test_no_sensors_dont_quit(self): process_container.shutdown() process_container_thread.kill() - @patch.object(ProcessSensorContainer, '_get_sensor_id', - MagicMock(return_value='wolfpack.StupidSensor')) - @patch.object(ProcessSensorContainer, '_dispatch_trigger_for_sensor_spawn', - MagicMock(return_value=None)) - @patch.object(Pack, 'get_by_ref', MagicMock(return_value=MOCK_PACK_DB)) - @patch.object(os.path, 'isdir', MagicMock(return_value=True)) - @patch('subprocess.Popen') - @patch('st2reactor.container.process_container.create_token') - def test_common_lib_path_in_pythonpath_env_var(self, mock_create_token, mock_subproc_popen): + @patch.object( + ProcessSensorContainer, + "_get_sensor_id", + MagicMock(return_value="wolfpack.StupidSensor"), + ) + @patch.object( + ProcessSensorContainer, + "_dispatch_trigger_for_sensor_spawn", + MagicMock(return_value=None), + ) + @patch.object(Pack, "get_by_ref", MagicMock(return_value=MOCK_PACK_DB)) + @patch.object(os.path, "isdir", MagicMock(return_value=True)) + @patch("subprocess.Popen") + @patch("st2reactor.container.process_container.create_token") + def test_common_lib_path_in_pythonpath_env_var( + self, mock_create_token, mock_subproc_popen + ): process_mock = Mock() - attrs = {'communicate.return_value': ('output', 'error')} + attrs = {"communicate.return_value": ("output", "error")} process_mock.configure_mock(**attrs) mock_subproc_popen.return_value = process_mock mock_create_token = Mock() - mock_create_token.return_value = 'WHOLETTHEDOGSOUT' + mock_create_token.return_value = "WHOLETTHEDOGSOUT" mock_dispatcher = Mock() - process_container = ProcessSensorContainer(None, poll_interval=0.1, - dispatcher=mock_dispatcher) + process_container = ProcessSensorContainer( + None, poll_interval=0.1, dispatcher=mock_dispatcher + ) sensor = { - 'class_name': 'wolfpack.StupidSensor', - 'ref': 'wolfpack.StupidSensor', - 'id': '567890', - 'trigger_types': ['some_trigga'], - 'pack': 'wolfpack', - 'file_path': '/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py', - 'poll_interval': 5 + "class_name": "wolfpack.StupidSensor", + "ref": "wolfpack.StupidSensor", + "id": "567890", + "trigger_types": ["some_trigga"], + "pack": "wolfpack", + "file_path": "/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py", + "poll_interval": 5, } process_container._enable_common_pack_libs = True - process_container._sensors = {'pack.StupidSensor': sensor} + process_container._sensors = {"pack.StupidSensor": sensor} process_container._spawn_sensor_process(sensor) _, call_kwargs = mock_subproc_popen.call_args - actual_env = call_kwargs['env'] - self.assertIn('PYTHONPATH', actual_env) - pack_common_lib_path = '/opt/stackstorm/packs/wolfpack/lib' - self.assertIn(pack_common_lib_path, actual_env['PYTHONPATH']) - - @patch.object(ProcessSensorContainer, '_get_sensor_id', - MagicMock(return_value='wolfpack.StupidSensor')) - @patch.object(ProcessSensorContainer, '_dispatch_trigger_for_sensor_spawn', - MagicMock(return_value=None)) - @patch.object(Pack, 'get_by_ref', MagicMock(return_value=MOCK_PACK_DB)) - @patch.object(os.path, 'isdir', MagicMock(return_value=True)) - @patch('subprocess.Popen') - @patch('st2reactor.container.process_container.create_token') - def test_common_lib_path_not_in_pythonpath_env_var(self, mock_create_token, mock_subproc_popen): + actual_env = call_kwargs["env"] + self.assertIn("PYTHONPATH", actual_env) + pack_common_lib_path = "/opt/stackstorm/packs/wolfpack/lib" + self.assertIn(pack_common_lib_path, actual_env["PYTHONPATH"]) + + @patch.object( + ProcessSensorContainer, + "_get_sensor_id", + MagicMock(return_value="wolfpack.StupidSensor"), + ) + @patch.object( + ProcessSensorContainer, + "_dispatch_trigger_for_sensor_spawn", + MagicMock(return_value=None), + ) + @patch.object(Pack, "get_by_ref", MagicMock(return_value=MOCK_PACK_DB)) + @patch.object(os.path, "isdir", MagicMock(return_value=True)) + @patch("subprocess.Popen") + @patch("st2reactor.container.process_container.create_token") + def test_common_lib_path_not_in_pythonpath_env_var( + self, mock_create_token, mock_subproc_popen + ): process_mock = Mock() - attrs = {'communicate.return_value': ('output', 'error')} + attrs = {"communicate.return_value": ("output", "error")} process_mock.configure_mock(**attrs) mock_subproc_popen.return_value = process_mock mock_create_token = Mock() - mock_create_token.return_value = 'WHOLETTHEDOGSOUT' + mock_create_token.return_value = "WHOLETTHEDOGSOUT" mock_dispatcher = Mock() - process_container = ProcessSensorContainer(None, poll_interval=0.1, - dispatcher=mock_dispatcher) + process_container = ProcessSensorContainer( + None, poll_interval=0.1, dispatcher=mock_dispatcher + ) sensor = { - 'class_name': 'wolfpack.StupidSensor', - 'ref': 'wolfpack.StupidSensor', - 'id': '567890', - 'trigger_types': ['some_trigga'], - 'pack': 'wolfpack', - 'file_path': '/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py', - 'poll_interval': 5 + "class_name": "wolfpack.StupidSensor", + "ref": "wolfpack.StupidSensor", + "id": "567890", + "trigger_types": ["some_trigga"], + "pack": "wolfpack", + "file_path": "/opt/stackstorm/packs/wolfpack/sensors/stupid_sensor.py", + "poll_interval": 5, } process_container._enable_common_pack_libs = False - process_container._sensors = {'pack.StupidSensor': sensor} + process_container._sensors = {"pack.StupidSensor": sensor} process_container._spawn_sensor_process(sensor) _, call_kwargs = mock_subproc_popen.call_args - actual_env = call_kwargs['env'] - self.assertIn('PYTHONPATH', actual_env) - pack_common_lib_path = '/opt/stackstorm/packs/wolfpack/lib' - self.assertNotIn(pack_common_lib_path, actual_env['PYTHONPATH']) + actual_env = call_kwargs["env"] + self.assertIn("PYTHONPATH", actual_env) + pack_common_lib_path = "/opt/stackstorm/packs/wolfpack/lib" + self.assertNotIn(pack_common_lib_path, actual_env["PYTHONPATH"]) - @patch.object(time, 'time', MagicMock(return_value=1439441533)) + @patch.object(time, "time", MagicMock(return_value=1439441533)) def test_dispatch_triggers_on_spawn_exit(self): mock_dispatcher = Mock() - process_container = ProcessSensorContainer(None, poll_interval=0.1, - dispatcher=mock_dispatcher) - sensor = { - 'class_name': 'pack.StupidSensor' - } + process_container = ProcessSensorContainer( + None, poll_interval=0.1, dispatcher=mock_dispatcher + ) + sensor = {"class_name": "pack.StupidSensor"} process = Mock() - process_attrs = {'pid': 1234} + process_attrs = {"pid": 1234} process.configure_mock(**process_attrs) - cmd = 'sensor_wrapper.py --class-name pack.StupidSensor' + cmd = "sensor_wrapper.py --class-name pack.StupidSensor" process_container._dispatch_trigger_for_sensor_spawn(sensor, process, cmd) mock_dispatcher.dispatch.assert_called_with( - 'core.st2.sensor.process_spawn', + "core.st2.sensor.process_spawn", payload={ - 'timestamp': 1439441533, - 'cmd': 'sensor_wrapper.py --class-name pack.StupidSensor', - 'pid': 1234, - 'id': 'pack.StupidSensor'}) + "timestamp": 1439441533, + "cmd": "sensor_wrapper.py --class-name pack.StupidSensor", + "pid": 1234, + "id": "pack.StupidSensor", + }, + ) process_container._dispatch_trigger_for_sensor_exit(sensor, 1) mock_dispatcher.dispatch.assert_called_with( - 'core.st2.sensor.process_exit', + "core.st2.sensor.process_exit", payload={ - 'id': 'pack.StupidSensor', - 'timestamp': 1439441533, - 'exit_code': 1 - }) + "id": "pack.StupidSensor", + "timestamp": 1439441533, + "exit_code": 1, + }, + ) diff --git a/st2reactor/tests/unit/test_rule_engine.py b/st2reactor/tests/unit/test_rule_engine.py index 39b1627268a..2f70a2a9d79 100644 --- a/st2reactor/tests/unit/test_rule_engine.py +++ b/st2reactor/tests/unit/test_rule_engine.py @@ -18,9 +18,9 @@ from mongoengine import NotUniqueError from st2common.models.api.rule import RuleAPI -from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB) +from st2common.models.db.trigger import TriggerDB, TriggerTypeDB from st2common.persistence.rule import Rule -from st2common.persistence.trigger import (TriggerType, Trigger) +from st2common.persistence.trigger import TriggerType, Trigger from st2common.util import date as date_utils import st2reactor.container.utils as container_utils from st2reactor.rules.enforcer import RuleEnforcer @@ -29,30 +29,29 @@ class RuleEngineTest(DbTestCase): - @classmethod def setUpClass(cls): super(RuleEngineTest, cls).setUpClass() RuleEngineTest._setup_test_models() - @mock.patch.object(RuleEnforcer, 'enforce', mock.MagicMock(return_value=True)) + @mock.patch.object(RuleEnforcer, "enforce", mock.MagicMock(return_value=True)) def test_handle_trigger_instances(self): trigger_instance_1 = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger1', - {'k1': 't1_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger1", + {"k1": "t1_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) trigger_instance_2 = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger1', - {'k1': 't1_p_v', 'k2': 'v2', 'k3': 'v3'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger1", + {"k1": "t1_p_v", "k2": "v2", "k3": "v3"}, + date_utils.get_datetime_utc_now(), ) trigger_instance_3 = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger2', - {'k1': 't1_p_v', 'k2': 'v2', 'k3': 'v3'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger2", + {"k1": "t1_p_v", "k2": "v2", "k3": "v3"}, + date_utils.get_datetime_utc_now(), ) instances = [trigger_instance_1, trigger_instance_2, trigger_instance_3] rules_engine = RulesEngine() @@ -60,32 +59,36 @@ def test_handle_trigger_instances(self): rules_engine.handle_trigger_instance(instance) def test_create_trigger_instance_for_trigger_with_params(self): - trigger = {'type': 'dummy_pack_1.st2.test.trigger4', 'parameters': {'url': 'sample'}} - payload = {'k1': 't1_p_v', 'k2': 'v2', 'k3': 'v3'} + trigger = { + "type": "dummy_pack_1.st2.test.trigger4", + "parameters": {"url": "sample"}, + } + payload = {"k1": "t1_p_v", "k2": "v2", "k3": "v3"} occurrence_time = date_utils.get_datetime_utc_now() - trigger_instance = container_utils.create_trigger_instance(trigger=trigger, - payload=payload, - occurrence_time=occurrence_time) + trigger_instance = container_utils.create_trigger_instance( + trigger=trigger, payload=payload, occurrence_time=occurrence_time + ) self.assertTrue(trigger_instance) - self.assertEqual(trigger_instance.trigger, trigger['type']) + self.assertEqual(trigger_instance.trigger, trigger["type"]) self.assertEqual(trigger_instance.payload, payload) def test_get_matching_rules_filters_disabled_rules(self): trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger1', - {'k1': 't1_p_v', 'k2': 'v2'}, date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger1", + {"k1": "t1_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) rules_engine = RulesEngine() matching_rules = rules_engine.get_matching_rules_for_trigger(trigger_instance) - expected_rules = ['st2.test.rule2'] + expected_rules = ["st2.test.rule2"] for rule in matching_rules: self.assertIn(rule.name, expected_rules) def test_handle_trigger_instance_no_rules(self): trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger3', - {'k1': 't1_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger3", + {"k1": "t1_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) rules_engine = RulesEngine() rules_engine.handle_trigger_instance(trigger_instance) # should not throw. @@ -96,14 +99,26 @@ def _setup_test_models(cls): RuleEngineTest._setup_sample_rules() @classmethod - def _setup_sample_triggers(self, names=['st2.test.trigger1', 'st2.test.trigger2', - 'st2.test.trigger3', 'st2.test.trigger4']): + def _setup_sample_triggers( + self, + names=[ + "st2.test.trigger1", + "st2.test.trigger2", + "st2.test.trigger3", + "st2.test.trigger4", + ], + ): trigger_dbs = [] for name in names: trigtype = None try: - trigtype = TriggerTypeDB(pack='dummy_pack_1', name=name, description='', - payload_schema={}, parameters_schema={}) + trigtype = TriggerTypeDB( + pack="dummy_pack_1", + name=name, + description="", + payload_schema={}, + parameters_schema={}, + ) try: trigtype = TriggerType.get_by_name(name) except: @@ -111,11 +126,15 @@ def _setup_sample_triggers(self, names=['st2.test.trigger1', 'st2.test.trigger2' except NotUniqueError: pass - created = TriggerDB(pack='dummy_pack_1', name=name, description='', - type=trigtype.get_reference().ref) + created = TriggerDB( + pack="dummy_pack_1", + name=name, + description="", + type=trigtype.get_reference().ref, + ) - if name in ['st2.test.trigger4']: - created.parameters = {'url': 'sample'} + if name in ["st2.test.trigger4"]: + created.parameters = {"url": "sample"} else: created.parameters = {} @@ -130,55 +149,40 @@ def _setup_sample_rules(self): # Rules for st2.test.trigger1 RULE_1 = { - 'enabled': True, - 'name': 'st2.test.rule1', - 'pack': 'sixpack', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' - }, - 'criteria': { - 'k1': { # Missing prefix 'trigger'. This rule won't match. - 'pattern': 't1_p_v', - 'type': 'equals' + "enabled": True, + "name": "st2.test.rule1", + "pack": "sixpack", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": { + "k1": { # Missing prefix 'trigger'. This rule won't match. + "pattern": "t1_p_v", + "type": "equals", } }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } rule_api = RuleAPI(**RULE_1) rule_db = RuleAPI.to_model(rule_api) rule_db = Rule.add_or_update(rule_db) rules.append(rule_db) - RULE_2 = { # Rule should match. - 'enabled': True, - 'name': 'st2.test.rule2', - 'pack': 'sixpack', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' - }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't1_p_v', - 'type': 'equals' - } - }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } + RULE_2 = { # Rule should match. + "enabled": True, + "name": "st2.test.rule2", + "pack": "sixpack", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } rule_api = RuleAPI(**RULE_2) rule_db = RuleAPI.to_model(rule_api) @@ -186,27 +190,17 @@ def _setup_sample_rules(self): rules.append(rule_db) RULE_3 = { - 'enabled': False, # Disabled rule shouldn't match. - 'name': 'st2.test.rule3', - 'pack': 'sixpack', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' - }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't1_p_v', - 'type': 'equals' - } - }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } + "enabled": False, # Disabled rule shouldn't match. + "name": "st2.test.rule3", + "pack": "sixpack", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } rule_api = RuleAPI(**RULE_3) rule_db = RuleAPI.to_model(rule_api) @@ -215,27 +209,17 @@ def _setup_sample_rules(self): # Rules for st2.test.trigger2 RULE_4 = { - 'enabled': True, - 'name': 'st2.test.rule4', - 'pack': 'sixpack', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger2' - }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't1_p_v', - 'type': 'equals' - } - }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } + "enabled": True, + "name": "st2.test.rule4", + "pack": "sixpack", + "trigger": {"type": "dummy_pack_1.st2.test.trigger2"}, + "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } rule_api = RuleAPI(**RULE_4) rule_db = RuleAPI.to_model(rule_api) diff --git a/st2reactor/tests/unit/test_rule_matcher.py b/st2reactor/tests/unit/test_rule_matcher.py index a5680fa0942..46cc0846623 100644 --- a/st2reactor/tests/unit/test_rule_matcher.py +++ b/st2reactor/tests/unit/test_rule_matcher.py @@ -19,9 +19,9 @@ import mock from st2common.models.api.rule import RuleAPI -from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB) +from st2common.models.db.trigger import TriggerDB, TriggerTypeDB from st2common.persistence.rule import Rule -from st2common.persistence.trigger import (TriggerType, Trigger) +from st2common.persistence.trigger import TriggerType, Trigger from st2common.services.triggers import get_trigger_db_by_ref from st2common.util import date as date_utils import st2reactor.container.utils as container_utils @@ -33,106 +33,68 @@ from st2tests.base import CleanDbTestCase from st2tests.fixturesloader import FixturesLoader -__all__ = [ - 'RuleMatcherTestCase', - 'BackstopRuleMatcherTestCase' -] +__all__ = ["RuleMatcherTestCase", "BackstopRuleMatcherTestCase"] # Mock rules RULE_1 = { - 'enabled': True, - 'name': 'st2.test.rule1', - 'pack': 'yoyohoneysingh', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' - }, - 'criteria': { - 'k1': { # Missing prefix 'trigger'. This rule won't match. - 'pattern': 't1_p_v', - 'type': 'equals' + "enabled": True, + "name": "st2.test.rule1", + "pack": "yoyohoneysingh", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": { + "k1": { # Missing prefix 'trigger'. This rule won't match. + "pattern": "t1_p_v", + "type": "equals", } }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } -RULE_2 = { # Rule should match. - 'enabled': True, - 'name': 'st2.test.rule2', - 'pack': 'yoyohoneysingh', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' - }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't1_p_v', - 'type': 'equals' - } +RULE_2 = { # Rule should match. + "enabled": True, + "name": "st2.test.rule2", + "pack": "yoyohoneysingh", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } - }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } RULE_3 = { - 'enabled': False, # Disabled rule shouldn't match. - 'name': 'st2.test.rule3', - 'pack': 'yoyohoneysingh', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger1' + "enabled": False, # Disabled rule shouldn't match. + "name": "st2.test.rule3", + "pack": "yoyohoneysingh", + "trigger": {"type": "dummy_pack_1.st2.test.trigger1"}, + "criteria": {"trigger.k1": {"pattern": "t1_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't1_p_v', - 'type': 'equals' - } - }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } - }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } -RULE_4 = { # Rule should match. - 'enabled': True, - 'name': 'st2.test.rule4', - 'pack': 'yoyohoneysingh', - 'trigger': { - 'type': 'dummy_pack_1.st2.test.trigger4' +RULE_4 = { # Rule should match. + "enabled": True, + "name": "st2.test.rule4", + "pack": "yoyohoneysingh", + "trigger": {"type": "dummy_pack_1.st2.test.trigger4"}, + "criteria": {"trigger.k1": {"pattern": "t2_p_v", "type": "equals"}}, + "action": { + "ref": "sixpack.st2.test.action", + "parameters": {"ip2": "{{rule.k1}}", "ip1": "{{trigger.t1_p}}"}, }, - 'criteria': { - 'trigger.k1': { - 'pattern': 't2_p_v', - 'type': 'equals' - } - }, - 'action': { - 'ref': 'sixpack.st2.test.action', - 'parameters': { - 'ip2': '{{rule.k1}}', - 'ip1': '{{trigger.t1_p}}' - } - }, - 'id': '23', - 'description': '' + "id": "23", + "description": "", } @@ -140,15 +102,15 @@ class RuleMatcherTestCase(CleanDbTestCase): rules = [] def test_get_matching_rules(self): - self._setup_sample_trigger('st2.test.trigger1') + self._setup_sample_trigger("st2.test.trigger1") rule_db_1 = self._setup_sample_rule(RULE_1) rule_db_2 = self._setup_sample_rule(RULE_2) rule_db_3 = self._setup_sample_rule(RULE_3) rules = [rule_db_1, rule_db_2, rule_db_3] trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger1', - {'k1': 't1_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger1", + {"k1": "t1_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) trigger = get_trigger_db_by_ref(trigger_instance.trigger) @@ -159,17 +121,22 @@ def test_get_matching_rules(self): def test_trigger_instance_payload_with_special_values(self): # Test a rule where TriggerInstance payload contains a dot (".") and $ - self._setup_sample_trigger('st2.test.trigger1') - self._setup_sample_trigger('st2.test.trigger2') + self._setup_sample_trigger("st2.test.trigger1") + self._setup_sample_trigger("st2.test.trigger2") rule_db_1 = self._setup_sample_rule(RULE_1) rule_db_2 = self._setup_sample_rule(RULE_2) rule_db_3 = self._setup_sample_rule(RULE_3) rules = [rule_db_1, rule_db_2, rule_db_3] trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger2', - {'k1': 't1_p_v', 'k2.k2': 'v2', 'k3.more.nested.deep': 'some.value', - 'k4.even.more.nested$': 'foo', 'yep$aaa': 'b'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger2", + { + "k1": "t1_p_v", + "k2.k2": "v2", + "k3.more.nested.deep": "some.value", + "k4.even.more.nested$": "foo", + "yep$aaa": "b", + }, + date_utils.get_datetime_utc_now(), ) trigger = get_trigger_db_by_ref(trigger_instance.trigger) @@ -178,20 +145,22 @@ def test_trigger_instance_payload_with_special_values(self): self.assertIsNotNone(matching_rules) self.assertEqual(len(matching_rules), 1) - @mock.patch('st2reactor.rules.matcher.RuleFilter._render_criteria_pattern', - mock.Mock(side_effect=Exception('exception in _render_criteria_pattern'))) + @mock.patch( + "st2reactor.rules.matcher.RuleFilter._render_criteria_pattern", + mock.Mock(side_effect=Exception("exception in _render_criteria_pattern")), + ) def test_rule_enforcement_is_created_on_exception_1(self): # 1. Exception in _render_criteria_pattern rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(rule_enforcement_dbs, []) - self._setup_sample_trigger('st2.test.trigger4') + self._setup_sample_trigger("st2.test.trigger4") rule_4_db = self._setup_sample_rule(RULE_4) rules = [rule_4_db] trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger4', - {'k1': 't2_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger4", + {"k1": "t2_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) trigger = get_trigger_db_by_ref(trigger_instance.trigger) @@ -203,29 +172,35 @@ def test_rule_enforcement_is_created_on_exception_1(self): rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(len(rule_enforcement_dbs), 1) - expected_failure = ('Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' - 'instance "%s": Failed to render pattern value "t2_p_v" for key ' - '"trigger.k1": exception in _render_criteria_pattern' % - (str(trigger_instance.id))) + expected_failure = ( + 'Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' + 'instance "%s": Failed to render pattern value "t2_p_v" for key ' + '"trigger.k1": exception in _render_criteria_pattern' + % (str(trigger_instance.id)) + ) self.assertEqual(rule_enforcement_dbs[0].failure_reason, expected_failure) - self.assertEqual(rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id)) - self.assertEqual(rule_enforcement_dbs[0].rule['id'], str(rule_4_db.id)) + self.assertEqual( + rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id) + ) + self.assertEqual(rule_enforcement_dbs[0].rule["id"], str(rule_4_db.id)) self.assertEqual(rule_enforcement_dbs[0].status, RULE_ENFORCEMENT_STATUS_FAILED) - @mock.patch('st2reactor.rules.filter.PayloadLookup.get_value', - mock.Mock(side_effect=Exception('exception in get_value'))) + @mock.patch( + "st2reactor.rules.filter.PayloadLookup.get_value", + mock.Mock(side_effect=Exception("exception in get_value")), + ) def test_rule_enforcement_is_created_on_exception_2(self): # 1. Exception in payload_lookup.get_value rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(rule_enforcement_dbs, []) - self._setup_sample_trigger('st2.test.trigger4') + self._setup_sample_trigger("st2.test.trigger4") rule_4_db = self._setup_sample_rule(RULE_4) rules = [rule_4_db] trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger4', - {'k1': 't2_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger4", + {"k1": "t2_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) trigger = get_trigger_db_by_ref(trigger_instance.trigger) @@ -237,28 +212,34 @@ def test_rule_enforcement_is_created_on_exception_2(self): rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(len(rule_enforcement_dbs), 1) - expected_failure = ('Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' - 'instance "%s": Failed transforming criteria key trigger.k1: ' - 'exception in get_value' % (str(trigger_instance.id))) + expected_failure = ( + 'Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' + 'instance "%s": Failed transforming criteria key trigger.k1: ' + "exception in get_value" % (str(trigger_instance.id)) + ) self.assertEqual(rule_enforcement_dbs[0].failure_reason, expected_failure) - self.assertEqual(rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id)) - self.assertEqual(rule_enforcement_dbs[0].rule['id'], str(rule_4_db.id)) + self.assertEqual( + rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id) + ) + self.assertEqual(rule_enforcement_dbs[0].rule["id"], str(rule_4_db.id)) self.assertEqual(rule_enforcement_dbs[0].status, RULE_ENFORCEMENT_STATUS_FAILED) - @mock.patch('st2common.operators.get_operator', - mock.Mock(return_value=mock.Mock(side_effect=Exception('exception in equals')))) + @mock.patch( + "st2common.operators.get_operator", + mock.Mock(return_value=mock.Mock(side_effect=Exception("exception in equals"))), + ) def test_rule_enforcement_is_created_on_exception_3(self): # 1. Exception in payload_lookup.get_value rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(rule_enforcement_dbs, []) - self._setup_sample_trigger('st2.test.trigger4') + self._setup_sample_trigger("st2.test.trigger4") rule_4_db = self._setup_sample_rule(RULE_4) rules = [rule_4_db] trigger_instance = container_utils.create_trigger_instance( - 'dummy_pack_1.st2.test.trigger4', - {'k1': 't2_p_v', 'k2': 'v2'}, - date_utils.get_datetime_utc_now() + "dummy_pack_1.st2.test.trigger4", + {"k1": "t2_p_v", "k2": "v2"}, + date_utils.get_datetime_utc_now(), ) trigger = get_trigger_db_by_ref(trigger_instance.trigger) @@ -270,22 +251,31 @@ def test_rule_enforcement_is_created_on_exception_3(self): rule_enforcement_dbs = list(RuleEnforcement.get_all()) self.assertEqual(len(rule_enforcement_dbs), 1) - expected_failure = ('Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' - 'instance "%s": There might be a problem with the criteria in rule ' - 'yoyohoneysingh.st2.test.rule4: exception in equals' % - (str(trigger_instance.id))) + expected_failure = ( + 'Failed to match rule "yoyohoneysingh.st2.test.rule4" against trigger ' + 'instance "%s": There might be a problem with the criteria in rule ' + "yoyohoneysingh.st2.test.rule4: exception in equals" + % (str(trigger_instance.id)) + ) self.assertEqual(rule_enforcement_dbs[0].failure_reason, expected_failure) - self.assertEqual(rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id)) - self.assertEqual(rule_enforcement_dbs[0].rule['id'], str(rule_4_db.id)) + self.assertEqual( + rule_enforcement_dbs[0].trigger_instance_id, str(trigger_instance.id) + ) + self.assertEqual(rule_enforcement_dbs[0].rule["id"], str(rule_4_db.id)) self.assertEqual(rule_enforcement_dbs[0].status, RULE_ENFORCEMENT_STATUS_FAILED) def _setup_sample_trigger(self, name): - trigtype = TriggerTypeDB(name=name, pack='dummy_pack_1', payload_schema={}, - parameters_schema={}) + trigtype = TriggerTypeDB( + name=name, pack="dummy_pack_1", payload_schema={}, parameters_schema={} + ) TriggerType.add_or_update(trigtype) - created = TriggerDB(name=name, pack='dummy_pack_1', type=trigtype.get_reference().ref, - parameters={}) + created = TriggerDB( + name=name, + pack="dummy_pack_1", + type=trigtype.get_reference().ref, + parameters={}, + ) Trigger.add_or_update(created) def _setup_sample_rule(self, rule): @@ -295,14 +285,12 @@ def _setup_sample_rule(self, rule): return rule_db -PACK = 'backstop' +PACK = "backstop" FIXTURES_TRIGGERS = { - 'triggertypes': ['triggertype1.yaml'], - 'triggers': ['trigger1.yaml'] -} -FIXTURES_RULES = { - 'rules': ['backstop.yaml', 'success.yaml', 'fail.yaml'] + "triggertypes": ["triggertype1.yaml"], + "triggers": ["trigger1.yaml"], } +FIXTURES_RULES = {"rules": ["backstop.yaml", "success.yaml", "fail.yaml"]} class BackstopRuleMatcherTestCase(DbTestCase): @@ -315,33 +303,41 @@ def setUpClass(cls): # Create TriggerTypes before creation of Rule to avoid failure. Rule requires the # Trigger and therefore TriggerType to be created prior to rule creation. cls.models = fixturesloader.save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_TRIGGERS) - cls.models.update(fixturesloader.save_fixtures_to_db( - fixtures_pack=PACK, fixtures_dict=FIXTURES_RULES)) + fixtures_pack=PACK, fixtures_dict=FIXTURES_TRIGGERS + ) + cls.models.update( + fixturesloader.save_fixtures_to_db( + fixtures_pack=PACK, fixtures_dict=FIXTURES_RULES + ) + ) def test_backstop_ignore(self): trigger_instance = container_utils.create_trigger_instance( - self.models['triggers']['trigger1.yaml'].ref, - {'k1': 'v1'}, - date_utils.get_datetime_utc_now() + self.models["triggers"]["trigger1.yaml"].ref, + {"k1": "v1"}, + date_utils.get_datetime_utc_now(), ) - trigger = self.models['triggers']['trigger1.yaml'] - rules = [rule for rule in six.itervalues(self.models['rules'])] + trigger = self.models["triggers"]["trigger1.yaml"] + rules = [rule for rule in six.itervalues(self.models["rules"])] rules_matcher = RulesMatcher(trigger_instance, trigger, rules) matching_rules = rules_matcher.get_matching_rules() self.assertEqual(len(matching_rules), 1) - self.assertEqual(matching_rules[0].id, self.models['rules']['success.yaml'].id) + self.assertEqual(matching_rules[0].id, self.models["rules"]["success.yaml"].id) def test_backstop_apply(self): trigger_instance = container_utils.create_trigger_instance( - self.models['triggers']['trigger1.yaml'].ref, - {'k1': 'v1'}, - date_utils.get_datetime_utc_now() + self.models["triggers"]["trigger1.yaml"].ref, + {"k1": "v1"}, + date_utils.get_datetime_utc_now(), ) - trigger = self.models['triggers']['trigger1.yaml'] - success_rule = self.models['rules']['success.yaml'] - rules = [rule for rule in six.itervalues(self.models['rules']) if rule != success_rule] + trigger = self.models["triggers"]["trigger1.yaml"] + success_rule = self.models["rules"]["success.yaml"] + rules = [ + rule + for rule in six.itervalues(self.models["rules"]) + if rule != success_rule + ] rules_matcher = RulesMatcher(trigger_instance, trigger, rules) matching_rules = rules_matcher.get_matching_rules() self.assertEqual(len(matching_rules), 1) - self.assertEqual(matching_rules[0].id, self.models['rules']['backstop.yaml'].id) + self.assertEqual(matching_rules[0].id, self.models["rules"]["backstop.yaml"].id) diff --git a/st2reactor/tests/unit/test_sensor_and_rule_registration.py b/st2reactor/tests/unit/test_sensor_and_rule_registration.py index 3f54e97c738..50075690e9c 100644 --- a/st2reactor/tests/unit/test_sensor_and_rule_registration.py +++ b/st2reactor/tests/unit/test_sensor_and_rule_registration.py @@ -27,22 +27,20 @@ from st2common.bootstrap.sensorsregistrar import SensorsRegistrar from st2common.bootstrap.rulesregistrar import RulesRegistrar -__all__ = [ - 'SensorRegistrationTestCase', - 'RuleRegistrationTestCase' -] +__all__ = ["SensorRegistrationTestCase", "RuleRegistrationTestCase"] CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) -PACKS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../fixtures/packs')) +PACKS_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../fixtures/packs")) # NOTE: We need to perform this patching because test fixtures are located outside of the packs # base paths directory. This will never happen outside the context of test fixtures. -@mock.patch('st2common.content.utils.get_pack_base_path', - mock.Mock(return_value=os.path.join(PACKS_DIR, 'pack_with_sensor'))) +@mock.patch( + "st2common.content.utils.get_pack_base_path", + mock.Mock(return_value=os.path.join(PACKS_DIR, "pack_with_sensor")), +) class SensorRegistrationTestCase(DbTestCase): - - @mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) + @mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) def test_register_sensors(self): # Verify DB is empty at the beginning self.assertEqual(len(SensorType.get_all()), 0) @@ -61,29 +59,33 @@ def test_register_sensors(self): self.assertEqual(len(trigger_type_dbs), 2) self.assertEqual(len(trigger_dbs), 2) - self.assertEqual(sensor_dbs[0].name, 'TestSensor') + self.assertEqual(sensor_dbs[0].name, "TestSensor") self.assertEqual(sensor_dbs[0].poll_interval, 10) self.assertTrue(sensor_dbs[0].enabled) - self.assertEqual(sensor_dbs[0].metadata_file, 'sensors/test_sensor_1.yaml') + self.assertEqual(sensor_dbs[0].metadata_file, "sensors/test_sensor_1.yaml") - self.assertEqual(sensor_dbs[1].name, 'TestSensorDisabled') + self.assertEqual(sensor_dbs[1].name, "TestSensorDisabled") self.assertEqual(sensor_dbs[1].poll_interval, 10) self.assertFalse(sensor_dbs[1].enabled) - self.assertEqual(sensor_dbs[1].metadata_file, 'sensors/test_sensor_2.yaml') + self.assertEqual(sensor_dbs[1].metadata_file, "sensors/test_sensor_2.yaml") - self.assertEqual(trigger_type_dbs[0].name, 'trigger_type_1') - self.assertEqual(trigger_type_dbs[0].pack, 'pack_with_sensor') + self.assertEqual(trigger_type_dbs[0].name, "trigger_type_1") + self.assertEqual(trigger_type_dbs[0].pack, "pack_with_sensor") self.assertEqual(len(trigger_type_dbs[0].tags), 0) - self.assertEqual(trigger_type_dbs[1].name, 'trigger_type_2') - self.assertEqual(trigger_type_dbs[1].pack, 'pack_with_sensor') + self.assertEqual(trigger_type_dbs[1].name, "trigger_type_2") + self.assertEqual(trigger_type_dbs[1].pack, "pack_with_sensor") self.assertEqual(len(trigger_type_dbs[1].tags), 2) - self.assertEqual(trigger_type_dbs[1].tags[0].name, 'tag1name') - self.assertEqual(trigger_type_dbs[1].tags[0].value, 'tag1 value') + self.assertEqual(trigger_type_dbs[1].tags[0].name, "tag1name") + self.assertEqual(trigger_type_dbs[1].tags[0].value, "tag1 value") # Triggered which are registered via sensors have metadata_file pointing to the sensor # definition file - self.assertEqual(trigger_type_dbs[0].metadata_file, 'sensors/test_sensor_1.yaml') - self.assertEqual(trigger_type_dbs[1].metadata_file, 'sensors/test_sensor_1.yaml') + self.assertEqual( + trigger_type_dbs[0].metadata_file, "sensors/test_sensor_1.yaml" + ) + self.assertEqual( + trigger_type_dbs[1].metadata_file, "sensors/test_sensor_1.yaml" + ) # Verify second call to registration doesn't create a duplicate objects registrar.register_from_packs(base_dirs=[PACKS_DIR]) @@ -96,13 +98,13 @@ def test_register_sensors(self): self.assertEqual(len(trigger_type_dbs), 2) self.assertEqual(len(trigger_dbs), 2) - self.assertEqual(sensor_dbs[0].name, 'TestSensor') + self.assertEqual(sensor_dbs[0].name, "TestSensor") self.assertEqual(sensor_dbs[0].poll_interval, 10) - self.assertEqual(trigger_type_dbs[0].name, 'trigger_type_1') - self.assertEqual(trigger_type_dbs[0].pack, 'pack_with_sensor') - self.assertEqual(trigger_type_dbs[1].name, 'trigger_type_2') - self.assertEqual(trigger_type_dbs[1].pack, 'pack_with_sensor') + self.assertEqual(trigger_type_dbs[0].name, "trigger_type_1") + self.assertEqual(trigger_type_dbs[0].pack, "pack_with_sensor") + self.assertEqual(trigger_type_dbs[1].name, "trigger_type_2") + self.assertEqual(trigger_type_dbs[1].pack, "pack_with_sensor") # Verify sensor and trigger data is updated on registration original_load = registrar._meta_loader.load @@ -110,9 +112,10 @@ def test_register_sensors(self): def mock_load(*args, **kwargs): # Update poll_interval and trigger_type_2 description data = original_load(*args, **kwargs) - data['poll_interval'] = 50 - data['trigger_types'][1]['description'] = 'test 2' + data["poll_interval"] = 50 + data["trigger_types"][1]["description"] = "test 2" return data + registrar._meta_loader.load = mock_load registrar.register_from_packs(base_dirs=[PACKS_DIR]) @@ -125,20 +128,22 @@ def mock_load(*args, **kwargs): self.assertEqual(len(trigger_type_dbs), 2) self.assertEqual(len(trigger_dbs), 2) - self.assertEqual(sensor_dbs[0].name, 'TestSensor') + self.assertEqual(sensor_dbs[0].name, "TestSensor") self.assertEqual(sensor_dbs[0].poll_interval, 50) - self.assertEqual(trigger_type_dbs[0].name, 'trigger_type_1') - self.assertEqual(trigger_type_dbs[0].pack, 'pack_with_sensor') - self.assertEqual(trigger_type_dbs[1].name, 'trigger_type_2') - self.assertEqual(trigger_type_dbs[1].pack, 'pack_with_sensor') - self.assertEqual(trigger_type_dbs[1].description, 'test 2') + self.assertEqual(trigger_type_dbs[0].name, "trigger_type_1") + self.assertEqual(trigger_type_dbs[0].pack, "pack_with_sensor") + self.assertEqual(trigger_type_dbs[1].name, "trigger_type_2") + self.assertEqual(trigger_type_dbs[1].pack, "pack_with_sensor") + self.assertEqual(trigger_type_dbs[1].description, "test 2") # NOTE: We need to perform this patching because test fixtures are located outside of the packs # base paths directory. This will never happen outside the context of test fixtures. -@mock.patch('st2common.content.utils.get_pack_base_path', - mock.Mock(return_value=os.path.join(PACKS_DIR, 'pack_with_rules'))) +@mock.patch( + "st2common.content.utils.get_pack_base_path", + mock.Mock(return_value=os.path.join(PACKS_DIR, "pack_with_rules")), +) class RuleRegistrationTestCase(DbTestCase): def test_register_rules(self): # Verify DB is empty at the beginning @@ -154,8 +159,8 @@ def test_register_rules(self): self.assertEqual(len(rule_dbs), 2) self.assertEqual(len(trigger_dbs), 1) - self.assertEqual(rule_dbs[0].name, 'sample.with_the_same_timer') - self.assertEqual(rule_dbs[1].name, 'sample.with_timer') + self.assertEqual(rule_dbs[0].name, "sample.with_the_same_timer") + self.assertEqual(rule_dbs[1].name, "sample.with_timer") self.assertIsNotNone(trigger_dbs[0].name) # Verify second register call updates existing models diff --git a/st2reactor/tests/unit/test_sensor_service.py b/st2reactor/tests/unit/test_sensor_service.py index 2064c25ee32..9d1e245e104 100644 --- a/st2reactor/tests/unit/test_sensor_service.py +++ b/st2reactor/tests/unit/test_sensor_service.py @@ -23,22 +23,20 @@ from st2common.constants.keyvalue import SYSTEM_SCOPE from st2common.constants.keyvalue import USER_SCOPE -__all__ = [ - 'SensorServiceTestCase' -] +__all__ = ["SensorServiceTestCase"] # This trigger has schema that uses all property types TEST_SCHEMA = { - 'type': 'object', - 'additionalProperties': False, - 'properties': { - 'age': {'type': 'integer'}, - 'name': {'type': 'string', 'required': True}, - 'address': {'type': 'string', 'default': '-'}, - 'career': {'type': 'array'}, - 'married': {'type': 'boolean'}, - 'awards': {'type': 'object'}, - 'income': {'anyOf': [{'type': 'integer'}, {'type': 'string'}]}, + "type": "object", + "additionalProperties": False, + "properties": { + "age": {"type": "integer"}, + "name": {"type": "string", "required": True}, + "address": {"type": "string", "default": "-"}, + "career": {"type": "array"}, + "married": {"type": "boolean"}, + "awards": {"type": "object"}, + "income": {"anyOf": [{"type": "integer"}, {"type": "string"}]}, }, } @@ -60,8 +58,9 @@ def side_effect(trigger, payload, trace_context): self.sensor_service = SensorService(mock.MagicMock()) self.sensor_service._trigger_dispatcher_service._dispatcher = mock.Mock() - self.sensor_service._trigger_dispatcher_service._dispatcher.dispatch = \ + self.sensor_service._trigger_dispatcher_service._dispatcher.dispatch = ( mock.MagicMock(side_effect=side_effect) + ) self._dispatched_count = 0 # Previously, cfg.CONF.system.validate_trigger_payload was set to False explicitly @@ -73,55 +72,65 @@ def tearDown(self): # Replace original configured value for payload validation cfg.CONF.system.validate_trigger_payload = self.validate_trigger_payload - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_success_valid_payload_validation_enabled(self): cfg.CONF.system.validate_trigger_payload = True # define a valid payload payload = { - 'name': 'John Doe', - 'age': 25, - 'career': ['foo, Inc.', 'bar, Inc.'], - 'married': True, - 'awards': {'2016': ['hoge prize', 'fuga prize']}, - 'income': 50000 + "name": "John Doe", + "age": 25, + "career": ["foo, Inc.", "bar, Inc."], + "married": True, + "awards": {"2016": ["hoge prize", "fuga prize"]}, + "income": 50000, } # dispatching a trigger - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) # This assumed that the target tirgger dispatched self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) - @mock.patch('st2common.services.triggers.get_trigger_db_by_ref', - mock.MagicMock(return_value=TriggerDBMock(type='trigger-type-ref'))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) + @mock.patch( + "st2common.services.triggers.get_trigger_db_by_ref", + mock.MagicMock(return_value=TriggerDBMock(type="trigger-type-ref")), + ) def test_dispatch_success_with_validation_enabled_trigger_reference(self): # Test a scenario where a Trigger ref and not TriggerType ref is provided cfg.CONF.system.validate_trigger_payload = True # define a valid payload payload = { - 'name': 'John Doe', - 'age': 25, - 'career': ['foo, Inc.', 'bar, Inc.'], - 'married': True, - 'awards': {'2016': ['hoge prize', 'fuga prize']}, - 'income': 50000 + "name": "John Doe", + "age": 25, + "career": ["foo, Inc.", "bar, Inc."], + "married": True, + "awards": {"2016": ["hoge prize", "fuga prize"]}, + "income": 50000, } self.assertEqual(self._dispatched_count, 0) # dispatching a trigger - self.sensor_service.dispatch('pack.86582f21-1fbc-44ea-88cb-0cd2b610e93b', payload) + self.sensor_service.dispatch( + "pack.86582f21-1fbc-44ea-88cb-0cd2b610e93b", payload + ) # This assumed that the target tirgger dispatched self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_success_with_validation_disabled_and_invalid_payload(self): """ Tests that an invalid payload still results in dispatch success with default config @@ -143,29 +152,31 @@ def test_dispatch_success_with_validation_disabled_and_invalid_payload(self): # define a invalid payload (the type of 'age' is incorrect) payload = { - 'name': 'John Doe', - 'age': '25', + "name": "John Doe", + "age": "25", } - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) # The default config is to disable validation. So, we want to make sure # the dispatch actually went through. self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_failure_caused_by_incorrect_type(self): # define a invalid payload (the type of 'age' is incorrect) payload = { - 'name': 'John Doe', - 'age': '25', + "name": "John Doe", + "age": "25", } # set config to stop dispatching when the payload comply with target trigger_type cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) # This assumed that the target trigger isn't dispatched self.assertEqual(self._dispatched_count, 0) @@ -173,120 +184,130 @@ def test_dispatch_failure_caused_by_incorrect_type(self): # reset config to permit force dispatching cfg.CONF.system.validate_trigger_payload = False - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_failure_caused_by_lack_of_required_parameter(self): # define a invalid payload (lack of required property) payload = { - 'age': 25, + "age": 25, } cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 0) # reset config to permit force dispatching cfg.CONF.system.validate_trigger_payload = False - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_failure_caused_by_extra_parameter(self): # define a invalid payload ('hobby' is extra) payload = { - 'name': 'John Doe', - 'hobby': 'programming', + "name": "John Doe", + "hobby": "programming", } cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 0) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_success_with_multiple_type_value(self): payload = { - 'name': 'John Doe', - 'income': 1234, + "name": "John Doe", + "income": 1234, } cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) # reset payload which can have different type - payload['income'] = 'secret' + payload["income"] = "secret" - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 2) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA))) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock(TEST_SCHEMA)), + ) def test_dispatch_success_with_null(self): payload = { - 'name': 'John Doe', - 'age': None, + "name": "John Doe", + "age": None, } cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('trigger-name', payload) + self.sensor_service.dispatch("trigger-name", payload) self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=TriggerTypeDBMock())) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=TriggerTypeDBMock()), + ) def test_dispatch_success_without_payload_schema(self): # the case trigger has no property - self.sensor_service.dispatch('trigger-name', {}) + self.sensor_service.dispatch("trigger-name", {}) self.assertEqual(self._dispatched_count, 1) - @mock.patch('st2common.services.triggers.get_trigger_type_db', - mock.MagicMock(return_value=None)) + @mock.patch( + "st2common.services.triggers.get_trigger_type_db", + mock.MagicMock(return_value=None), + ) def test_dispatch_trigger_type_not_in_db_should_not_dispatch(self): cfg.CONF.system.validate_trigger_payload = True - self.sensor_service.dispatch('not-in-database-ref', {}) + self.sensor_service.dispatch("not-in-database-ref", {}) self.assertEqual(self._dispatched_count, 0) def test_datastore_methods(self): self.sensor_service._datastore_service = mock.Mock() # Verify methods take encrypt, decrypt and scope arguments - self.sensor_service.get_value(name='foo1', scope=SYSTEM_SCOPE, decrypt=True) + self.sensor_service.get_value(name="foo1", scope=SYSTEM_SCOPE, decrypt=True) call_kwargs = self.sensor_service.datastore_service.get_value.call_args[1] expected_kwargs = { - 'name': 'foo1', - 'local': True, - 'scope': SYSTEM_SCOPE, - 'decrypt': True + "name": "foo1", + "local": True, + "scope": SYSTEM_SCOPE, + "decrypt": True, } self.assertEqual(call_kwargs, expected_kwargs) - self.sensor_service.set_value(name='foo2', value='bar', scope=USER_SCOPE, encrypt=True) + self.sensor_service.set_value( + name="foo2", value="bar", scope=USER_SCOPE, encrypt=True + ) call_kwargs = self.sensor_service.datastore_service.set_value.call_args[1] expected_kwargs = { - 'name': 'foo2', - 'value': 'bar', - 'ttl': None, - 'local': True, - 'scope': USER_SCOPE, - 'encrypt': True + "name": "foo2", + "value": "bar", + "ttl": None, + "local": True, + "scope": USER_SCOPE, + "encrypt": True, } self.assertEqual(call_kwargs, expected_kwargs) - self.sensor_service.delete_value(name='foo3', scope=USER_SCOPE) + self.sensor_service.delete_value(name="foo3", scope=USER_SCOPE) call_kwargs = self.sensor_service.datastore_service.delete_value.call_args[1] - expected_kwargs = { - 'name': 'foo3', - 'local': True, - 'scope': USER_SCOPE - } + expected_kwargs = {"name": "foo3", "local": True, "scope": USER_SCOPE} self.assertEqual(call_kwargs, expected_kwargs) diff --git a/st2reactor/tests/unit/test_sensor_wrapper.py b/st2reactor/tests/unit/test_sensor_wrapper.py index 735e0e545bf..b2d637812d7 100644 --- a/st2reactor/tests/unit/test_sensor_wrapper.py +++ b/st2reactor/tests/unit/test_sensor_wrapper.py @@ -16,6 +16,7 @@ from __future__ import absolute_import from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -33,11 +34,9 @@ from st2reactor.sensor.base import Sensor, PollingSensor CURRENT_DIR = os.path.abspath(os.path.dirname(__file__)) -RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, '../resources')) +RESOURCES_DIR = os.path.abspath(os.path.join(CURRENT_DIR, "../resources")) -__all__ = [ - 'SensorWrapperTestCase' -] +__all__ = ["SensorWrapperTestCase"] class SensorWrapperTestCase(unittest2.TestCase): @@ -47,27 +46,33 @@ def setUpClass(cls): tests_config.parse_args() def test_sensor_instance_has_sensor_service(self): - file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] - - wrapper = SensorWrapper(pack='core', file_path=file_path, - class_name='TestSensor', - trigger_types=trigger_types, - parent_args=parent_args) - self.assertIsNotNone(getattr(wrapper._sensor_instance, 'sensor_service', None)) - self.assertIsNotNone(getattr(wrapper._sensor_instance, 'config', None)) + file_path = os.path.join(RESOURCES_DIR, "test_sensor.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] + + wrapper = SensorWrapper( + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) + self.assertIsNotNone(getattr(wrapper._sensor_instance, "sensor_service", None)) + self.assertIsNotNone(getattr(wrapper._sensor_instance, "config", None)) def test_trigger_cud_event_handlers(self): - trigger_id = '57861fcb0640fd1524e577c0' - file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] - - wrapper = SensorWrapper(pack='core', file_path=file_path, - class_name='TestSensor', - trigger_types=trigger_types, - parent_args=parent_args) + trigger_id = "57861fcb0640fd1524e577c0" + file_path = os.path.join(RESOURCES_DIR, "test_sensor.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] + + wrapper = SensorWrapper( + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) self.assertEqual(wrapper._trigger_names, {}) @@ -78,7 +83,9 @@ def test_trigger_cud_event_handlers(self): # Call create handler with a trigger which refers to this sensor self.assertEqual(wrapper._sensor_instance.add_trigger.call_count, 0) - trigger = TriggerDB(id=trigger_id, name='test', pack='dummy', type=trigger_types[0]) + trigger = TriggerDB( + id=trigger_id, name="test", pack="dummy", type=trigger_types[0] + ) wrapper._handle_create_trigger(trigger=trigger) self.assertEqual(wrapper._trigger_names, {trigger_id: trigger}) self.assertEqual(wrapper._sensor_instance.add_trigger.call_count, 1) @@ -86,7 +93,9 @@ def test_trigger_cud_event_handlers(self): # Validate that update handler updates the trigger_names self.assertEqual(wrapper._sensor_instance.update_trigger.call_count, 0) - trigger = TriggerDB(id=trigger_id, name='test', pack='dummy', type=trigger_types[0]) + trigger = TriggerDB( + id=trigger_id, name="test", pack="dummy", type=trigger_types[0] + ) wrapper._handle_update_trigger(trigger=trigger) self.assertEqual(wrapper._trigger_names, {trigger_id: trigger}) self.assertEqual(wrapper._sensor_instance.update_trigger.call_count, 1) @@ -94,70 +103,97 @@ def test_trigger_cud_event_handlers(self): # Validate that delete handler deletes the trigger from trigger_names self.assertEqual(wrapper._sensor_instance.remove_trigger.call_count, 0) - trigger = TriggerDB(id=trigger_id, name='test', pack='dummy', type=trigger_types[0]) + trigger = TriggerDB( + id=trigger_id, name="test", pack="dummy", type=trigger_types[0] + ) wrapper._handle_delete_trigger(trigger=trigger) self.assertEqual(wrapper._trigger_names, {}) self.assertEqual(wrapper._sensor_instance.remove_trigger.call_count, 1) def test_sensor_creation_passive(self): - file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] - - wrapper = SensorWrapper(pack='core', file_path=file_path, - class_name='TestSensor', - trigger_types=trigger_types, - parent_args=parent_args) + file_path = os.path.join(RESOURCES_DIR, "test_sensor.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] + + wrapper = SensorWrapper( + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) self.assertIsInstance(wrapper._sensor_instance, Sensor) self.assertIsNotNone(wrapper._sensor_instance) def test_sensor_creation_active(self): - file_path = os.path.join(RESOURCES_DIR, 'test_sensor.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] + file_path = os.path.join(RESOURCES_DIR, "test_sensor.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] poll_interval = 10 - wrapper = SensorWrapper(pack='core', file_path=file_path, - class_name='TestPollingSensor', - trigger_types=trigger_types, - parent_args=parent_args, - poll_interval=poll_interval) + wrapper = SensorWrapper( + pack="core", + file_path=file_path, + class_name="TestPollingSensor", + trigger_types=trigger_types, + parent_args=parent_args, + poll_interval=poll_interval, + ) self.assertIsNotNone(wrapper._sensor_instance) self.assertIsInstance(wrapper._sensor_instance, PollingSensor) self.assertEqual(wrapper._sensor_instance._poll_interval, poll_interval) def test_sensor_init_fails_file_doesnt_exist(self): - file_path = os.path.join(RESOURCES_DIR, 'test_sensor_doesnt_exist.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] - - expected_msg = 'Failed to load sensor class from file.*? No such file or directory' - self.assertRaisesRegexp(IOError, expected_msg, SensorWrapper, - pack='core', file_path=file_path, - class_name='TestSensor', - trigger_types=trigger_types, - parent_args=parent_args) + file_path = os.path.join(RESOURCES_DIR, "test_sensor_doesnt_exist.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] + + expected_msg = ( + "Failed to load sensor class from file.*? No such file or directory" + ) + self.assertRaisesRegexp( + IOError, + expected_msg, + SensorWrapper, + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) def test_sensor_init_fails_sensor_code_contains_typo(self): - file_path = os.path.join(RESOURCES_DIR, 'test_sensor_with_typo.py') - trigger_types = ['trigger1', 'trigger2'] - parent_args = ['--config-file', TESTS_CONFIG_PATH] - - expected_msg = 'Failed to load sensor class from file.*? \'typobar\' is not defined' - self.assertRaisesRegexp(NameError, expected_msg, SensorWrapper, - pack='core', file_path=file_path, - class_name='TestSensor', - trigger_types=trigger_types, - parent_args=parent_args) + file_path = os.path.join(RESOURCES_DIR, "test_sensor_with_typo.py") + trigger_types = ["trigger1", "trigger2"] + parent_args = ["--config-file", TESTS_CONFIG_PATH] + + expected_msg = ( + "Failed to load sensor class from file.*? 'typobar' is not defined" + ) + self.assertRaisesRegexp( + NameError, + expected_msg, + SensorWrapper, + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) # Verify error message also contains traceback try: - SensorWrapper(pack='core', file_path=file_path, class_name='TestSensor', - trigger_types=trigger_types, parent_args=parent_args) + SensorWrapper( + pack="core", + file_path=file_path, + class_name="TestSensor", + trigger_types=trigger_types, + parent_args=parent_args, + ) except NameError as e: - self.assertIn('Traceback (most recent call last)', six.text_type(e)) - self.assertIn('line 20, in ', six.text_type(e)) + self.assertIn("Traceback (most recent call last)", six.text_type(e)) + self.assertIn("line 20, in ", six.text_type(e)) else: - self.fail('NameError not thrown') + self.fail("NameError not thrown") def test_sensor_wrapper_poll_method_still_works(self): # Verify that sensor wrapper correctly applied select.poll() eventlet workaround so code @@ -167,5 +203,5 @@ def test_sensor_wrapper_poll_method_still_works(self): import select self.assertTrue(eventlet.patcher.is_monkey_patched(select)) - self.assertTrue(select != eventlet.patcher.original('select')) + self.assertTrue(select != eventlet.patcher.original("select")) self.assertTrue(select.poll()) diff --git a/st2reactor/tests/unit/test_tester.py b/st2reactor/tests/unit/test_tester.py index f1f1b01886f..60cd6919b87 100644 --- a/st2reactor/tests/unit/test_tester.py +++ b/st2reactor/tests/unit/test_tester.py @@ -25,65 +25,77 @@ BASE_PATH = os.path.dirname(os.path.abspath(__file__)) -FIXTURES_PACK = 'generic' +FIXTURES_PACK = "generic" TEST_MODELS_TRIGGERS = { - 'triggertypes': ['triggertype1.yaml', 'triggertype2.yaml'], - 'triggers': ['trigger1.yaml', 'trigger2.yaml'], - 'triggerinstances': ['trigger_instance_1.yaml', 'trigger_instance_2.yaml'] + "triggertypes": ["triggertype1.yaml", "triggertype2.yaml"], + "triggers": ["trigger1.yaml", "trigger2.yaml"], + "triggerinstances": ["trigger_instance_1.yaml", "trigger_instance_2.yaml"], } -TEST_MODELS_RULES = { - 'rules': ['rule1.yaml'] -} +TEST_MODELS_RULES = {"rules": ["rule1.yaml"]} -TEST_MODELS_ACTIONS = { - 'actions': ['action1.yaml'] -} +TEST_MODELS_ACTIONS = {"actions": ["action1.yaml"]} -@mock.patch.object(PoolPublisher, 'publish', mock.MagicMock()) +@mock.patch.object(PoolPublisher, "publish", mock.MagicMock()) class RuleTesterTestCase(CleanDbTestCase): def test_matching_trigger_from_file(self): - FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_ACTIONS) - rule_file_path = os.path.join(BASE_PATH, '../fixtures/rule.yaml') - trigger_instance_file_path = os.path.join(BASE_PATH, '../fixtures/trigger_instance_1.yaml') - tester = RuleTester(rule_file_path=rule_file_path, - trigger_instance_file_path=trigger_instance_file_path) + FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_ACTIONS + ) + rule_file_path = os.path.join(BASE_PATH, "../fixtures/rule.yaml") + trigger_instance_file_path = os.path.join( + BASE_PATH, "../fixtures/trigger_instance_1.yaml" + ) + tester = RuleTester( + rule_file_path=rule_file_path, + trigger_instance_file_path=trigger_instance_file_path, + ) matching = tester.evaluate() self.assertTrue(matching) def test_non_matching_trigger_from_file(self): - rule_file_path = os.path.join(BASE_PATH, '../fixtures/rule.yaml') - trigger_instance_file_path = os.path.join(BASE_PATH, '../fixtures/trigger_instance_2.yaml') - tester = RuleTester(rule_file_path=rule_file_path, - trigger_instance_file_path=trigger_instance_file_path) + rule_file_path = os.path.join(BASE_PATH, "../fixtures/rule.yaml") + trigger_instance_file_path = os.path.join( + BASE_PATH, "../fixtures/trigger_instance_2.yaml" + ) + tester = RuleTester( + rule_file_path=rule_file_path, + trigger_instance_file_path=trigger_instance_file_path, + ) matching = tester.evaluate() self.assertFalse(matching) def test_matching_trigger_from_db(self): - FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_ACTIONS) - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_TRIGGERS) - trigger_instance_db = models['triggerinstances']['trigger_instance_2.yaml'] - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_RULES) - rule_db = models['rules']['rule1.yaml'] - tester = RuleTester(rule_ref=rule_db.ref, - trigger_instance_id=str(trigger_instance_db.id)) + FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_ACTIONS + ) + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_TRIGGERS + ) + trigger_instance_db = models["triggerinstances"]["trigger_instance_2.yaml"] + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_RULES + ) + rule_db = models["rules"]["rule1.yaml"] + tester = RuleTester( + rule_ref=rule_db.ref, trigger_instance_id=str(trigger_instance_db.id) + ) matching = tester.evaluate() self.assertTrue(matching) def test_non_matching_trigger_from_db(self): - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_TRIGGERS) - trigger_instance_db = models['triggerinstances']['trigger_instance_1.yaml'] - models = FixturesLoader().save_fixtures_to_db(fixtures_pack=FIXTURES_PACK, - fixtures_dict=TEST_MODELS_RULES) - rule_db = models['rules']['rule1.yaml'] - tester = RuleTester(rule_ref=rule_db.ref, - trigger_instance_id=str(trigger_instance_db.id)) + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_TRIGGERS + ) + trigger_instance_db = models["triggerinstances"]["trigger_instance_1.yaml"] + models = FixturesLoader().save_fixtures_to_db( + fixtures_pack=FIXTURES_PACK, fixtures_dict=TEST_MODELS_RULES + ) + rule_db = models["rules"]["rule1.yaml"] + tester = RuleTester( + rule_ref=rule_db.ref, trigger_instance_id=str(trigger_instance_db.id) + ) matching = tester.evaluate() self.assertFalse(matching) diff --git a/st2reactor/tests/unit/test_timer.py b/st2reactor/tests/unit/test_timer.py index 861d74349e1..f4311d18d83 100644 --- a/st2reactor/tests/unit/test_timer.py +++ b/st2reactor/tests/unit/test_timer.py @@ -60,9 +60,14 @@ def test_existing_rules_are_loaded_on_start(self): # Add a dummy timer Trigger object type_ = list(TIMER_TRIGGER_TYPES.keys())[0] - parameters = {'unit': 'seconds', 'delta': 1000} - trigger_db = TriggerDB(id=bson.ObjectId(), name='test_trigger_1', pack='dummy', - type=type_, parameters=parameters) + parameters = {"unit": "seconds", "delta": 1000} + trigger_db = TriggerDB( + id=bson.ObjectId(), + name="test_trigger_1", + pack="dummy", + type=type_, + parameters=parameters, + ) trigger_db = Trigger.add_or_update(trigger_db) # Verify object has been added @@ -74,7 +79,7 @@ def test_existing_rules_are_loaded_on_start(self): # Verify handlers are called timer._handle_create_trigger.assert_called_with(trigger_db) - @mock.patch('st2common.transport.reactor.TriggerDispatcher.dispatch') + @mock.patch("st2common.transport.reactor.TriggerDispatcher.dispatch") def test_timer_trace_tag_creation(self, dispatch_mock): timer = St2Timer() timer._scheduler = mock.Mock() @@ -82,11 +87,14 @@ def test_timer_trace_tag_creation(self, dispatch_mock): # Add a dummy timer Trigger object type_ = list(TIMER_TRIGGER_TYPES.keys())[0] - parameters = {'unit': 'seconds', 'delta': 1} - trigger_db = TriggerDB(name='test_trigger_1', pack='dummy', type=type_, - parameters=parameters) + parameters = {"unit": "seconds", "delta": 1} + trigger_db = TriggerDB( + name="test_trigger_1", pack="dummy", type=type_, parameters=parameters + ) timer.add_trigger(trigger_db) timer._emit_trigger_instance(trigger=trigger_db.to_serializable_dict()) - self.assertEqual(dispatch_mock.call_args[1]['trace_context'].trace_tag, - '%s-%s' % (TIMER_TRIGGER_TYPES[type_]['name'], trigger_db.name)) + self.assertEqual( + dispatch_mock.call_args[1]["trace_context"].trace_tag, + "%s-%s" % (TIMER_TRIGGER_TYPES[type_]["name"], trigger_db.name), + ) diff --git a/st2stream/dist_utils.py b/st2stream/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/st2stream/dist_utils.py +++ b/st2stream/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2stream/setup.py b/st2stream/setup.py index af6b302f5d2..f34692affcb 100644 --- a/st2stream/setup.py +++ b/st2stream/setup.py @@ -22,9 +22,9 @@ from dist_utils import apply_vagrant_workaround from st2stream import __version__ -ST2_COMPONENT = 'st2stream' +ST2_COMPONENT = "st2stream" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -32,18 +32,18 @@ setup( name=ST2_COMPONENT, version=__version__, - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']), - scripts=[ - 'bin/st2stream' - ] + packages=find_packages(exclude=["setuptools", "tests"]), + scripts=["bin/st2stream"], ) diff --git a/st2stream/st2stream/__init__.py b/st2stream/st2stream/__init__.py index bbe290db9a7..e6d3f15e0bd 100644 --- a/st2stream/st2stream/__init__.py +++ b/st2stream/st2stream/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2stream/st2stream/app.py b/st2stream/st2stream/app.py index 73d32eb4cfe..0932dfcc27d 100644 --- a/st2stream/st2stream/app.py +++ b/st2stream/st2stream/app.py @@ -43,9 +43,9 @@ def setup_app(config={}): - LOG.info('Creating st2stream: %s as OpenAPI app.', VERSION_STRING) + LOG.info("Creating st2stream: %s as OpenAPI app.", VERSION_STRING) - is_gunicorn = config.get('is_gunicorn', False) + is_gunicorn = config.get("is_gunicorn", False) if is_gunicorn: # Note: We need to perform monkey patching in the worker. If we do it in # the master process (gunicorn_config.py), it breaks tons of things @@ -54,30 +54,33 @@ def setup_app(config={}): st2stream_config.register_opts() capabilities = { - 'name': 'stream', - 'listen_host': cfg.CONF.stream.host, - 'listen_port': cfg.CONF.stream.port, - 'type': 'active' + "name": "stream", + "listen_host": cfg.CONF.stream.host, + "listen_port": cfg.CONF.stream.port, + "type": "active", } # This should be called in gunicorn case because we only want # workers to connect to db, rabbbitmq etc. In standalone HTTP # server case, this setup would have already occurred. - common_setup(service='stream', config=st2stream_config, setup_db=True, - register_mq_exchanges=True, - register_signal_handlers=True, - register_internal_trigger_types=False, - run_migrations=False, - service_registry=True, - capabilities=capabilities, - config_args=config.get('config_args', None)) + common_setup( + service="stream", + config=st2stream_config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=False, + run_migrations=False, + service_registry=True, + capabilities=capabilities, + config_args=config.get("config_args", None), + ) - router = Router(debug=cfg.CONF.stream.debug, auth=cfg.CONF.auth.enable, - is_gunicorn=is_gunicorn) + router = Router( + debug=cfg.CONF.stream.debug, auth=cfg.CONF.auth.enable, is_gunicorn=is_gunicorn + ) - spec = spec_loader.load_spec('st2common', 'openapi.yaml.j2') - transforms = { - '^/stream/v1/': ['/', '/v1/'] - } + spec = spec_loader.load_spec("st2common", "openapi.yaml.j2") + transforms = {"^/stream/v1/": ["/", "/v1/"]} router.add_spec(spec, transforms=transforms) app = router.as_wsgi @@ -87,8 +90,8 @@ def setup_app(config={}): app = ErrorHandlingMiddleware(app) app = CorsMiddleware(app) app = LoggingMiddleware(app, router) - app = ResponseInstrumentationMiddleware(app, router, service_name='stream') + app = ResponseInstrumentationMiddleware(app, router, service_name="stream") app = RequestIDMiddleware(app) - app = RequestInstrumentationMiddleware(app, router, service_name='stream') + app = RequestInstrumentationMiddleware(app, router, service_name="stream") return app diff --git a/st2stream/st2stream/cmd/__init__.py b/st2stream/st2stream/cmd/__init__.py index 4d6cd0332d3..85b1f07d71a 100644 --- a/st2stream/st2stream/cmd/__init__.py +++ b/st2stream/st2stream/cmd/__init__.py @@ -15,4 +15,4 @@ from st2stream.cmd import api -__all__ = ['api'] +__all__ = ["api"] diff --git a/st2stream/st2stream/cmd/api.py b/st2stream/st2stream/cmd/api.py index cc1eec7d17c..b4ce963ea58 100644 --- a/st2stream/st2stream/cmd/api.py +++ b/st2stream/st2stream/cmd/api.py @@ -14,6 +14,7 @@ # limitations under the License. from st2common.util.monkey_patch import monkey_patch + monkey_patch() import os @@ -30,20 +31,20 @@ from st2common.util.wsgi import shutdown_server_kill_pending_requests from st2stream.signal_handlers import register_stream_signal_handlers from st2stream import config + config.register_opts() from st2stream import app -__all__ = [ - 'main' -] +__all__ = ["main"] eventlet.monkey_patch( os=True, select=True, socket=True, - thread=False if '--use-debugger' in sys.argv else True, - time=True) + thread=False if "--use-debugger" in sys.argv else True, + time=True, +) LOG = logging.getLogger(__name__) @@ -53,29 +54,43 @@ def _setup(): capabilities = { - 'name': 'stream', - 'listen_host': cfg.CONF.stream.host, - 'listen_port': cfg.CONF.stream.port, - 'type': 'active' + "name": "stream", + "listen_host": cfg.CONF.stream.host, + "listen_port": cfg.CONF.stream.port, + "type": "active", } - common_setup(service='stream', config=config, setup_db=True, register_mq_exchanges=True, - register_signal_handlers=True, register_internal_trigger_types=False, - run_migrations=False, service_registry=True, capabilities=capabilities) + common_setup( + service="stream", + config=config, + setup_db=True, + register_mq_exchanges=True, + register_signal_handlers=True, + register_internal_trigger_types=False, + run_migrations=False, + service_registry=True, + capabilities=capabilities, + ) def _run_server(): host = cfg.CONF.stream.host port = cfg.CONF.stream.port - LOG.info('(PID=%s) ST2 Stream API is serving on http://%s:%s.', os.getpid(), host, port) + LOG.info( + "(PID=%s) ST2 Stream API is serving on http://%s:%s.", os.getpid(), host, port + ) max_pool_size = eventlet.wsgi.DEFAULT_MAX_SIMULTANEOUS_REQUESTS worker_pool = eventlet.GreenPool(max_pool_size) sock = eventlet.listen((host, port)) def queue_shutdown(signal_number, stack_frame): - eventlet.spawn_n(shutdown_server_kill_pending_requests, sock=sock, - worker_pool=worker_pool, wait_time=WSGI_SERVER_REQUEST_SHUTDOWN_TIME) + eventlet.spawn_n( + shutdown_server_kill_pending_requests, + sock=sock, + worker_pool=worker_pool, + wait_time=WSGI_SERVER_REQUEST_SHUTDOWN_TIME, + ) # We register a custom SIGINT handler which allows us to kill long running active requests. # Note: Eventually we will support draining (waiting for short-running requests), but we @@ -97,12 +112,12 @@ def main(): except SystemExit as exit_code: sys.exit(exit_code) except KeyboardInterrupt: - listener = get_listener_if_set(name='stream') + listener = get_listener_if_set(name="stream") if listener: listener.shutdown() except Exception: - LOG.exception('(PID=%s) ST2 Stream API quit due to exception.', os.getpid()) + LOG.exception("(PID=%s) ST2 Stream API quit due to exception.", os.getpid()) return 1 finally: _teardown() diff --git a/st2stream/st2stream/config.py b/st2stream/st2stream/config.py index fe068dc0b24..bc117b556aa 100644 --- a/st2stream/st2stream/config.py +++ b/st2stream/st2stream/config.py @@ -32,8 +32,11 @@ def parse_args(args=None): - cfg.CONF(args=args, version=VERSION_STRING, - default_config_files=[DEFAULT_CONFIG_FILE_PATH]) + cfg.CONF( + args=args, + version=VERSION_STRING, + default_config_files=[DEFAULT_CONFIG_FILE_PATH], + ) def register_opts(): @@ -54,17 +57,15 @@ def _register_app_opts(): # config since they are also used outside st2stream api_opts = [ cfg.StrOpt( - 'host', default='127.0.0.1', - help='StackStorm stream API server host'), - cfg.IntOpt( - 'port', default=9102, - help='StackStorm API stream, server port'), - cfg.BoolOpt( - 'debug', default=False, - help='Specify to enable debug mode.'), + "host", default="127.0.0.1", help="StackStorm stream API server host" + ), + cfg.IntOpt("port", default=9102, help="StackStorm API stream, server port"), + cfg.BoolOpt("debug", default=False, help="Specify to enable debug mode."), cfg.StrOpt( - 'logging', default='/etc/st2/logging.stream.conf', - help='location of the logging.conf file') + "logging", + default="/etc/st2/logging.stream.conf", + help="location of the logging.conf file", + ), ] - CONF.register_opts(api_opts, group='stream') + CONF.register_opts(api_opts, group="stream") diff --git a/st2stream/st2stream/controllers/v1/executions.py b/st2stream/st2stream/controllers/v1/executions.py index 379491e9785..70023b87454 100644 --- a/st2stream/st2stream/controllers/v1/executions.py +++ b/st2stream/st2stream/controllers/v1/executions.py @@ -30,47 +30,46 @@ from st2common.rbac.types import PermissionType from st2common.stream.listener import get_listener -__all__ = [ - 'ActionExecutionOutputStreamController' -] +__all__ = ["ActionExecutionOutputStreamController"] LOG = logging.getLogger(__name__) # Event which is returned when no more data will be produced on this stream endpoint before closing # the connection. -NO_MORE_DATA_EVENT = 'event: EOF\ndata: \'\'\n\n' +NO_MORE_DATA_EVENT = "event: EOF\ndata: ''\n\n" class ActionExecutionOutputStreamController(ResourceController): model = ActionExecutionAPI access = ActionExecution - supported_filters = { - 'output_type': 'output_type' - } + supported_filters = {"output_type": "output_type"} CLOSE_STREAM_LIVEACTION_STATES = action_constants.LIVEACTION_COMPLETED_STATES + [ action_constants.LIVEACTION_STATUS_PAUSING, - action_constants.LIVEACTION_STATUS_RESUMING + action_constants.LIVEACTION_STATUS_RESUMING, ] - def get_one(self, id, output_type='all', requester_user=None): + def get_one(self, id, output_type="all", requester_user=None): # Special case for id == "last" - if id == 'last': - execution_db = ActionExecution.query().order_by('-id').limit(1).first() + if id == "last": + execution_db = ActionExecution.query().order_by("-id").limit(1).first() if not execution_db: - raise ValueError('No executions found in the database') + raise ValueError("No executions found in the database") id = str(execution_db.id) - execution_db = self._get_one_by_id(id=id, requester_user=requester_user, - permission_type=PermissionType.EXECUTION_VIEW) + execution_db = self._get_one_by_id( + id=id, + requester_user=requester_user, + permission_type=PermissionType.EXECUTION_VIEW, + ) execution_id = str(execution_db.id) query_filters = {} - if output_type and output_type != 'all': - query_filters['output_type'] = output_type + if output_type and output_type != "all": + query_filters["output_type"] = output_type def format_output_object(output_db_or_api): if isinstance(output_db_or_api, ActionExecutionOutputDB): @@ -78,25 +77,27 @@ def format_output_object(output_db_or_api): elif isinstance(output_db_or_api, ActionExecutionOutputAPI): data = output_db_or_api else: - raise ValueError('Unsupported format: %s' % (type(output_db_or_api))) + raise ValueError("Unsupported format: %s" % (type(output_db_or_api))) - event = 'st2.execution.output__create' - result = 'event: %s\ndata: %s\n\n' % (event, json_encode(data, indent=None)) + event = "st2.execution.output__create" + result = "event: %s\ndata: %s\n\n" % (event, json_encode(data, indent=None)) return result def existing_output_iter(): # Consume and return all of the existing lines - output_dbs = ActionExecutionOutput.query(execution_id=execution_id, **query_filters) + output_dbs = ActionExecutionOutput.query( + execution_id=execution_id, **query_filters + ) # Note: We return all at once instead of yield line by line to avoid multiple socket # writes and to achieve better performance output = [format_output_object(output_db) for output_db in output_dbs] - output = ''.join(output) - yield six.binary_type(output.encode('utf-8')) + output = "".join(output) + yield six.binary_type(output.encode("utf-8")) def new_output_iter(): def noop_gen(): - yield six.binary_type(NO_MORE_DATA_EVENT.encode('utf-8')) + yield six.binary_type(NO_MORE_DATA_EVENT.encode("utf-8")) # Bail out if execution has already completed / been paused if execution_db.status in self.CLOSE_STREAM_LIVEACTION_STATES: @@ -104,7 +105,9 @@ def noop_gen(): # Wait for and return any new line which may come in execution_ids = [execution_id] - listener = get_listener(name='execution_output') # pylint: disable=no-member + listener = get_listener( + name="execution_output" + ) # pylint: disable=no-member gen = listener.generator(execution_ids=execution_ids) def format(gen): @@ -117,28 +120,37 @@ def format(gen): # Note: gunicorn wsgi handler expect bytes, not unicode # pylint: disable=no-member if isinstance(model_api, ActionExecutionOutputAPI): - if output_type and output_type != 'all' and \ - model_api.output_type != output_type: + if ( + output_type + and output_type != "all" + and model_api.output_type != output_type + ): continue - output = format_output_object(model_api).encode('utf-8') + output = format_output_object(model_api).encode("utf-8") yield six.binary_type(output) elif isinstance(model_api, ActionExecutionAPI): if model_api.status in self.CLOSE_STREAM_LIVEACTION_STATES: - yield six.binary_type(NO_MORE_DATA_EVENT.encode('utf-8')) + yield six.binary_type( + NO_MORE_DATA_EVENT.encode("utf-8") + ) break else: - LOG.debug('Unrecognized message type: %s' % (model_api)) + LOG.debug("Unrecognized message type: %s" % (model_api)) gen = format(gen) return gen def make_response(): app_iter = itertools.chain(existing_output_iter(), new_output_iter()) - res = Response(headerlist=[("X-Accel-Buffering", "no"), - ('Cache-Control', 'no-cache'), - ("Content-Type", "text/event-stream; charset=UTF-8")], - app_iter=app_iter) + res = Response( + headerlist=[ + ("X-Accel-Buffering", "no"), + ("Cache-Control", "no-cache"), + ("Content-Type", "text/event-stream; charset=UTF-8"), + ], + app_iter=app_iter, + ) return res res = make_response() diff --git a/st2stream/st2stream/controllers/v1/root.py b/st2stream/st2stream/controllers/v1/root.py index c9873127a63..2b9178f7853 100644 --- a/st2stream/st2stream/controllers/v1/root.py +++ b/st2stream/st2stream/controllers/v1/root.py @@ -15,9 +15,7 @@ from st2stream.controllers.v1.stream import StreamController -__all__ = [ - 'RootController' -] +__all__ = ["RootController"] class RootController(object): diff --git a/st2stream/st2stream/controllers/v1/stream.py b/st2stream/st2stream/controllers/v1/stream.py index 19c7d71b1d3..f6995c3300c 100644 --- a/st2stream/st2stream/controllers/v1/stream.py +++ b/st2stream/st2stream/controllers/v1/stream.py @@ -21,58 +21,70 @@ from st2common.util.jsonify import json_encode from st2common.stream.listener import get_listener -__all__ = [ - 'StreamController' -] +__all__ = ["StreamController"] LOG = logging.getLogger(__name__) DEFAULT_EVENTS_WHITELIST = [ - 'st2.announcement__*', - - 'st2.execution__create', - 'st2.execution__update', - 'st2.execution__delete', - - 'st2.liveaction__create', - 'st2.liveaction__update', - 'st2.liveaction__delete', + "st2.announcement__*", + "st2.execution__create", + "st2.execution__update", + "st2.execution__delete", + "st2.liveaction__create", + "st2.liveaction__update", + "st2.liveaction__delete", ] def format(gen): - message = '''event: %s\ndata: %s\n\n''' + message = """event: %s\ndata: %s\n\n""" for pack in gen: if not pack: # Note: gunicorn wsgi handler expect bytes, not unicode - yield six.binary_type(b'\n') + yield six.binary_type(b"\n") else: (event, body) = pack # Note: gunicorn wsgi handler expect bytes, not unicode - yield six.binary_type((message % (event, json_encode(body, - indent=None))).encode('utf-8')) + yield six.binary_type( + (message % (event, json_encode(body, indent=None))).encode("utf-8") + ) class StreamController(object): - def get_all(self, end_execution_id=None, end_event=None, - events=None, action_refs=None, execution_ids=None, requester_user=None): + def get_all( + self, + end_execution_id=None, + end_event=None, + events=None, + action_refs=None, + execution_ids=None, + requester_user=None, + ): events = events if events else DEFAULT_EVENTS_WHITELIST action_refs = action_refs if action_refs else None execution_ids = execution_ids if execution_ids else None def make_response(): - listener = get_listener(name='stream') - app_iter = format(listener.generator(events=events, - action_refs=action_refs, - end_event=end_event, - end_statuses=action_constants.LIVEACTION_COMPLETED_STATES, - end_execution_id=end_execution_id, - execution_ids=execution_ids)) - res = Response(headerlist=[("X-Accel-Buffering", "no"), - ('Cache-Control', 'no-cache'), - ("Content-Type", "text/event-stream; charset=UTF-8")], - app_iter=app_iter) + listener = get_listener(name="stream") + app_iter = format( + listener.generator( + events=events, + action_refs=action_refs, + end_event=end_event, + end_statuses=action_constants.LIVEACTION_COMPLETED_STATES, + end_execution_id=end_execution_id, + execution_ids=execution_ids, + ) + ) + res = Response( + headerlist=[ + ("X-Accel-Buffering", "no"), + ("Cache-Control", "no-cache"), + ("Content-Type", "text/event-stream; charset=UTF-8"), + ], + app_iter=app_iter, + ) return res stream = make_response() diff --git a/st2stream/st2stream/signal_handlers.py b/st2stream/st2stream/signal_handlers.py index 56bc06450a1..b292d8b67b3 100644 --- a/st2stream/st2stream/signal_handlers.py +++ b/st2stream/st2stream/signal_handlers.py @@ -15,9 +15,7 @@ import signal -__all__ = [ - 'register_stream_signal_handlers' -] +__all__ = ["register_stream_signal_handlers"] def register_stream_signal_handlers(handler_func): diff --git a/st2stream/st2stream/wsgi.py b/st2stream/st2stream/wsgi.py index c177572ba1b..14d847e2a1e 100644 --- a/st2stream/st2stream/wsgi.py +++ b/st2stream/st2stream/wsgi.py @@ -18,8 +18,11 @@ from st2stream import app config = { - 'is_gunicorn': True, - 'config_args': ['--config-file', os.environ.get('ST2_CONFIG_PATH', '/etc/st2/st2.conf')] + "is_gunicorn": True, + "config_args": [ + "--config-file", + os.environ.get("ST2_CONFIG_PATH", "/etc/st2/st2.conf"), + ], } application = app.setup_app(config) diff --git a/st2stream/tests/unit/controllers/v1/base.py b/st2stream/tests/unit/controllers/v1/base.py index 24a59a5cd02..4f6e2ca336f 100644 --- a/st2stream/tests/unit/controllers/v1/base.py +++ b/st2stream/tests/unit/controllers/v1/base.py @@ -16,9 +16,7 @@ from st2stream import app from st2tests.api import BaseFunctionalTest -__all__ = [ - 'FunctionalTest' -] +__all__ = ["FunctionalTest"] class FunctionalTest(BaseFunctionalTest): diff --git a/st2stream/tests/unit/controllers/v1/test_stream.py b/st2stream/tests/unit/controllers/v1/test_stream.py index 7ff7e62f3db..c67f3e27824 100644 --- a/st2stream/tests/unit/controllers/v1/test_stream.py +++ b/st2stream/tests/unit/controllers/v1/test_stream.py @@ -34,88 +34,72 @@ RUNNER_TYPE_1 = { - 'description': '', - 'enabled': True, - 'name': 'local-shell-cmd', - 'runner_module': 'local_runner', - 'runner_parameters': {} + "description": "", + "enabled": True, + "name": "local-shell-cmd", + "runner_module": "local_runner", + "runner_parameters": {}, } ACTION_1 = { - 'name': 'st2.dummy.action1', - 'description': 'test description', - 'enabled': True, - 'entry_point': '/tmp/test/action1.sh', - 'pack': 'sixpack', - 'runner_type': 'local-shell-cmd', - 'parameters': { - 'a': { - 'type': 'string', - 'default': 'abc' - }, - 'b': { - 'type': 'number', - 'default': 123 - }, - 'c': { - 'type': 'number', - 'default': 123, - 'immutable': True - }, - 'd': { - 'type': 'string', - 'secret': True - } - } + "name": "st2.dummy.action1", + "description": "test description", + "enabled": True, + "entry_point": "/tmp/test/action1.sh", + "pack": "sixpack", + "runner_type": "local-shell-cmd", + "parameters": { + "a": {"type": "string", "default": "abc"}, + "b": {"type": "number", "default": 123}, + "c": {"type": "number", "default": 123, "immutable": True}, + "d": {"type": "string", "secret": True}, + }, } LIVE_ACTION_1 = { - 'action': 'sixpack.st2.dummy.action1', - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'uname -a', - 'd': SUPER_SECRET_PARAMETER - } + "action": "sixpack.st2.dummy.action1", + "parameters": { + "hosts": "localhost", + "cmd": "uname -a", + "d": SUPER_SECRET_PARAMETER, + }, } EXECUTION_1 = { - 'id': '598dbf0c0640fd54bffc688b', - 'action': { - 'ref': 'sixpack.st2.dummy.action1' + "id": "598dbf0c0640fd54bffc688b", + "action": {"ref": "sixpack.st2.dummy.action1"}, + "parameters": { + "hosts": "localhost", + "cmd": "uname -a", + "d": SUPER_SECRET_PARAMETER, }, - 'parameters': { - 'hosts': 'localhost', - 'cmd': 'uname -a', - 'd': SUPER_SECRET_PARAMETER - } } STDOUT_1 = { - 'execution_id': '598dbf0c0640fd54bffc688b', - 'action_ref': 'dummy.action1', - 'output_type': 'stdout' + "execution_id": "598dbf0c0640fd54bffc688b", + "action_ref": "dummy.action1", + "output_type": "stdout", } STDERR_1 = { - 'execution_id': '598dbf0c0640fd54bffc688b', - 'action_ref': 'dummy.action1', - 'output_type': 'stderr' + "execution_id": "598dbf0c0640fd54bffc688b", + "action_ref": "dummy.action1", + "output_type": "stderr", } class META(object): delivery_info = {} - def __init__(self, exchange='some', routing_key='thing'): - self.delivery_info['exchange'] = exchange - self.delivery_info['routing_key'] = routing_key + def __init__(self, exchange="some", routing_key="thing"): + self.delivery_info["exchange"] = exchange + self.delivery_info["routing_key"] = routing_key def ack(self): pass class TestStreamController(FunctionalTest): - @classmethod def setUpClass(cls): super(TestStreamController, cls).setUpClass() @@ -126,33 +110,35 @@ def setUpClass(cls): instance = ActionAPI(**ACTION_1) Action.add_or_update(ActionAPI.to_model(instance)) - @mock.patch.object(st2common.stream.listener, 'listen', mock.Mock()) - @mock.patch('st2stream.controllers.v1.stream.DEFAULT_EVENTS_WHITELIST', None) + @mock.patch.object(st2common.stream.listener, "listen", mock.Mock()) + @mock.patch("st2stream.controllers.v1.stream.DEFAULT_EVENTS_WHITELIST", None) def test_get_all(self): resp = stream.StreamController().get_all() - self.assertEqual(resp._status, '200 OK') - self.assertIn(('Content-Type', 'text/event-stream; charset=UTF-8'), resp._headerlist) + self.assertEqual(resp._status, "200 OK") + self.assertIn( + ("Content-Type", "text/event-stream; charset=UTF-8"), resp._headerlist + ) - listener = st2common.stream.listener.get_listener(name='stream') + listener = st2common.stream.listener.get_listener(name="stream") process = listener.processor(LiveActionAPI) message = None for message in resp._app_iter: - message = message.decode('utf-8') - if message != '\n': + message = message.decode("utf-8") + if message != "\n": break process(LiveActionDB(**LIVE_ACTION_1), META()) - self.assertIn('event: some__thing', message) + self.assertIn("event: some__thing", message) self.assertIn('data: {"', message) self.assertNotIn(SUPER_SECRET_PARAMETER, message) - @mock.patch.object(st2common.stream.listener, 'listen', mock.Mock()) + @mock.patch.object(st2common.stream.listener, "listen", mock.Mock()) def test_get_all_with_filters(self): - cfg.CONF.set_override(name='heartbeat', group='stream', override=0.1) + cfg.CONF.set_override(name="heartbeat", group="stream", override=0.1) - listener = st2common.stream.listener.get_listener(name='stream') + listener = st2common.stream.listener.get_listener(name="stream") process_execution = listener.processor(ActionExecutionAPI) process_liveaction = listener.processor(LiveActionAPI) process_output = listener.processor(ActionExecutionOutputAPI) @@ -164,50 +150,50 @@ def test_get_all_with_filters(self): output_api_stderr = ActionExecutionOutputDB(**STDERR_1) def dispatch_and_handle_mock_data(resp): - received_messages_data = '' + received_messages_data = "" for index, message in enumerate(resp._app_iter): if message.strip(): - received_messages_data += message.decode('utf-8') + received_messages_data += message.decode("utf-8") # Dispatch some mock events if index == 0: - meta = META('st2.execution', 'create') + meta = META("st2.execution", "create") process_execution(execution_api, meta) elif index == 1: - meta = META('st2.execution', 'update') + meta = META("st2.execution", "update") process_execution(execution_api, meta) elif index == 2: - meta = META('st2.execution', 'delete') + meta = META("st2.execution", "delete") process_execution(execution_api, meta) elif index == 3: - meta = META('st2.liveaction', 'create') + meta = META("st2.liveaction", "create") process_liveaction(liveaction_api, meta) elif index == 4: - meta = META('st2.liveaction', 'create') + meta = META("st2.liveaction", "create") process_liveaction(liveaction_api, meta) elif index == 5: - meta = META('st2.liveaction', 'delete') + meta = META("st2.liveaction", "delete") process_liveaction(liveaction_api, meta) elif index == 6: - meta = META('st2.liveaction', 'delete') + meta = META("st2.liveaction", "delete") process_liveaction(liveaction_api, meta) elif index == 7: - meta = META('st2.announcement', 'chatops') + meta = META("st2.announcement", "chatops") process_no_api_model({}, meta) elif index == 8: - meta = META('st2.execution.output', 'create') + meta = META("st2.execution.output", "create") process_output(output_api_stdout, meta) elif index == 9: - meta = META('st2.execution.output', 'create') + meta = META("st2.execution.output", "create") process_output(output_api_stderr, meta) elif index == 10: - meta = META('st2.announcement', 'errbot') + meta = META("st2.announcement", "errbot") process_no_api_model({}, meta) else: break - received_messages = received_messages_data.split('\n\n') + received_messages = received_messages_data.split("\n\n") received_messages = [message for message in received_messages if message] return received_messages @@ -217,10 +203,10 @@ def dispatch_and_handle_mock_data(resp): received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 9) - self.assertIn('st2.execution__create', received_messages[0]) - self.assertIn('st2.liveaction__delete', received_messages[5]) - self.assertIn('st2.announcement__chatops', received_messages[7]) - self.assertIn('st2.announcement__errbot', received_messages[8]) + self.assertIn("st2.execution__create", received_messages[0]) + self.assertIn("st2.liveaction__delete", received_messages[5]) + self.assertIn("st2.announcement__chatops", received_messages[7]) + self.assertIn("st2.announcement__errbot", received_messages[8]) # 1. ?events= filter # No filter provided - all messages should be received @@ -229,79 +215,79 @@ def dispatch_and_handle_mock_data(resp): received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 11) - self.assertIn('st2.execution__create', received_messages[0]) - self.assertIn('st2.announcement__chatops', received_messages[7]) - self.assertIn('st2.execution.output__create', received_messages[8]) - self.assertIn('st2.execution.output__create', received_messages[9]) - self.assertIn('st2.announcement__errbot', received_messages[10]) + self.assertIn("st2.execution__create", received_messages[0]) + self.assertIn("st2.announcement__chatops", received_messages[7]) + self.assertIn("st2.execution.output__create", received_messages[8]) + self.assertIn("st2.execution.output__create", received_messages[9]) + self.assertIn("st2.announcement__errbot", received_messages[10]) # Filter provided, only three messages should be received - events = ['st2.execution__create', 'st2.liveaction__delete'] + events = ["st2.execution__create", "st2.liveaction__delete"] resp = stream.StreamController().get_all(events=events) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 3) - self.assertIn('st2.execution__create', received_messages[0]) - self.assertIn('st2.liveaction__delete', received_messages[1]) - self.assertIn('st2.liveaction__delete', received_messages[2]) + self.assertIn("st2.execution__create", received_messages[0]) + self.assertIn("st2.liveaction__delete", received_messages[1]) + self.assertIn("st2.liveaction__delete", received_messages[2]) # Filter provided, only three messages should be received - events = ['st2.liveaction__create', 'st2.liveaction__delete'] + events = ["st2.liveaction__create", "st2.liveaction__delete"] resp = stream.StreamController().get_all(events=events) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 4) - self.assertIn('st2.liveaction__create', received_messages[0]) - self.assertIn('st2.liveaction__create', received_messages[1]) - self.assertIn('st2.liveaction__delete', received_messages[2]) - self.assertIn('st2.liveaction__delete', received_messages[3]) + self.assertIn("st2.liveaction__create", received_messages[0]) + self.assertIn("st2.liveaction__create", received_messages[1]) + self.assertIn("st2.liveaction__delete", received_messages[2]) + self.assertIn("st2.liveaction__delete", received_messages[3]) # Glob filter - events = ['st2.announcement__*'] + events = ["st2.announcement__*"] resp = stream.StreamController().get_all(events=events) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 2) - self.assertIn('st2.announcement__chatops', received_messages[0]) - self.assertIn('st2.announcement__errbot', received_messages[1]) + self.assertIn("st2.announcement__chatops", received_messages[0]) + self.assertIn("st2.announcement__errbot", received_messages[1]) # Filter provided - events = ['st2.execution.output__create'] + events = ["st2.execution.output__create"] resp = stream.StreamController().get_all(events=events) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 2) - self.assertIn('st2.execution.output__create', received_messages[0]) - self.assertIn('st2.execution.output__create', received_messages[1]) + self.assertIn("st2.execution.output__create", received_messages[0]) + self.assertIn("st2.execution.output__create", received_messages[1]) # Filter provided, invalid , no message should be received - events = ['invalid1', 'invalid2'] + events = ["invalid1", "invalid2"] resp = stream.StreamController().get_all(events=events) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 0) # 2. ?action_refs= filter - action_refs = ['invalid1', 'invalid2'] + action_refs = ["invalid1", "invalid2"] resp = stream.StreamController().get_all(action_refs=action_refs) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 0) - action_refs = ['dummy.action1'] + action_refs = ["dummy.action1"] resp = stream.StreamController().get_all(action_refs=action_refs) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 2) # 3. ?execution_ids= filter - execution_ids = ['invalid1', 'invalid2'] + execution_ids = ["invalid1", "invalid2"] resp = stream.StreamController().get_all(execution_ids=execution_ids) received_messages = dispatch_and_handle_mock_data(resp) self.assertEqual(len(received_messages), 0) - execution_ids = [EXECUTION_1['id']] + execution_ids = [EXECUTION_1["id"]] resp = stream.StreamController().get_all(execution_ids=execution_ids) received_messages = dispatch_and_handle_mock_data(resp) diff --git a/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py b/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py index deb76b4e979..d14dd029e83 100644 --- a/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py +++ b/st2stream/tests/unit/controllers/v1/test_stream_execution_output.py @@ -30,50 +30,54 @@ from .base import FunctionalTest -__all__ = [ - 'ActionExecutionOutputStreamControllerTestCase' -] +__all__ = ["ActionExecutionOutputStreamControllerTestCase"] class ActionExecutionOutputStreamControllerTestCase(FunctionalTest): def test_get_one_id_last_no_executions_in_the_database(self): ActionExecution.query().delete() - resp = self.app.get('/v1/executions/last/output', expect_errors=True) + resp = self.app.get("/v1/executions/last/output", expect_errors=True) self.assertEqual(resp.status_int, http_client.BAD_REQUEST) - self.assertEqual(resp.json['faultstring'], 'No executions found in the database') + self.assertEqual( + resp.json["faultstring"], "No executions found in the database" + ) def test_get_output_running_execution(self): # Retrieve lister instance to avoid race with listener connection not being established # early enough for tests to pass. # NOTE: This only affects tests where listeners are not pre-initialized. - listener = get_listener(name='execution_output') + listener = get_listener(name="execution_output") eventlet.sleep(1.0) # Test the execution output API endpoint for execution which is running (blocking) status = action_constants.LIVEACTION_STATUS_RUNNING timestamp = date_utils.get_datetime_utc_now() - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) action_execution_db = ActionExecution.add_or_update(action_execution_db) - output_params = dict(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout before start\n') + output_params = dict( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout before start\n", + ) # Insert mock output object output_db = ActionExecutionOutputDB(**output_params) ActionExecutionOutput.add_or_update(output_db, publish=False) def insert_mock_data(): - output_params['data'] = 'stdout mid 1\n' + output_params["data"] = "stdout mid 1\n" output_db = ActionExecutionOutputDB(**output_params) ActionExecutionOutput.add_or_update(output_db) @@ -81,7 +85,7 @@ def insert_mock_data(): # spawn an eventlet which eventually finishes the action. def publish_action_finished(action_execution_db): # Insert mock output object - output_params['data'] = 'stdout pre finish 1\n' + output_params["data"] = "stdout pre finish 1\n" output_db = ActionExecutionOutputDB(**output_params) ActionExecutionOutput.add_or_update(output_db) @@ -96,28 +100,32 @@ def publish_action_finished(action_execution_db): # Retrieve data while execution is running - endpoint return new data once it's available # and block until the execution finishes - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) events = self._parse_response(resp.text) self.assertEqual(len(events), 4) - self.assertEqual(events[0][1]['data'], 'stdout before start\n') - self.assertEqual(events[1][1]['data'], 'stdout mid 1\n') - self.assertEqual(events[2][1]['data'], 'stdout pre finish 1\n') - self.assertEqual(events[3][0], 'EOF') + self.assertEqual(events[0][1]["data"], "stdout before start\n") + self.assertEqual(events[1][1]["data"], "stdout mid 1\n") + self.assertEqual(events[2][1]["data"], "stdout pre finish 1\n") + self.assertEqual(events[3][0], "EOF") # Once the execution is in completed state, existing output should be returned immediately - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) events = self._parse_response(resp.text) self.assertEqual(len(events), 4) - self.assertEqual(events[0][1]['data'], 'stdout before start\n') - self.assertEqual(events[1][1]['data'], 'stdout mid 1\n') - self.assertEqual(events[2][1]['data'], 'stdout pre finish 1\n') - self.assertEqual(events[3][0], 'EOF') + self.assertEqual(events[0][1]["data"], "stdout before start\n") + self.assertEqual(events[1][1]["data"], "stdout mid 1\n") + self.assertEqual(events[2][1]["data"], "stdout pre finish 1\n") + self.assertEqual(events[3][0], "EOF") listener.shutdown() @@ -127,49 +135,57 @@ def test_get_output_finished_execution(self): # Insert mock execution and output objects status = action_constants.LIVEACTION_STATUS_SUCCEEDED timestamp = date_utils.get_datetime_utc_now() - action_execution_db = ActionExecutionDB(start_timestamp=timestamp, - end_timestamp=timestamp, - status=status, - action={'ref': 'core.local'}, - runner={'name': 'local-shell-cmd'}, - liveaction={'ref': 'foo'}) + action_execution_db = ActionExecutionDB( + start_timestamp=timestamp, + end_timestamp=timestamp, + status=status, + action={"ref": "core.local"}, + runner={"name": "local-shell-cmd"}, + liveaction={"ref": "foo"}, + ) action_execution_db = ActionExecution.add_or_update(action_execution_db) for i in range(1, 6): - stdout_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stdout', - data='stdout %s\n' % (i)) + stdout_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stdout", + data="stdout %s\n" % (i), + ) ActionExecutionOutput.add_or_update(stdout_db) for i in range(10, 15): - stderr_db = ActionExecutionOutputDB(execution_id=str(action_execution_db.id), - action_ref='core.local', - runner_ref='dummy', - timestamp=timestamp, - output_type='stderr', - data='stderr %s\n' % (i)) + stderr_db = ActionExecutionOutputDB( + execution_id=str(action_execution_db.id), + action_ref="core.local", + runner_ref="dummy", + timestamp=timestamp, + output_type="stderr", + data="stderr %s\n" % (i), + ) ActionExecutionOutput.add_or_update(stderr_db) - resp = self.app.get('/v1/executions/%s/output' % (str(action_execution_db.id)), - expect_errors=False) + resp = self.app.get( + "/v1/executions/%s/output" % (str(action_execution_db.id)), + expect_errors=False, + ) self.assertEqual(resp.status_int, 200) events = self._parse_response(resp.text) self.assertEqual(len(events), 11) - self.assertEqual(events[0][1]['data'], 'stdout 1\n') - self.assertEqual(events[9][1]['data'], 'stderr 14\n') - self.assertEqual(events[10][0], 'EOF') + self.assertEqual(events[0][1]["data"], "stdout 1\n") + self.assertEqual(events[9][1]["data"], "stderr 14\n") + self.assertEqual(events[10][0], "EOF") # Verify "last" short-hand id works - resp = self.app.get('/v1/executions/last/output', expect_errors=False) + resp = self.app.get("/v1/executions/last/output", expect_errors=False) self.assertEqual(resp.status_int, 200) events = self._parse_response(resp.text) self.assertEqual(len(events), 11) - self.assertEqual(events[10][0], 'EOF') + self.assertEqual(events[10][0], "EOF") def _parse_response(self, response): """ @@ -177,12 +193,12 @@ def _parse_response(self, response): """ events = [] - lines = response.strip().split('\n') + lines = response.strip().split("\n") for index, line in enumerate(lines): - if 'data:' in line: + if "data:" in line: e_line = lines[index - 1] - event_name = e_line[e_line.find('event: ') + len('event:'):].strip() - event_data = line[line.find('data: ') + len('data :'):].strip() + event_name = e_line[e_line.find("event: ") + len("event:") :].strip() + event_data = line[line.find("data: ") + len("data :") :].strip() event_data = json.loads(event_data) if len(event_data) > 2 else {} events.append((event_name, event_data)) diff --git a/st2tests/dist_utils.py b/st2tests/dist_utils.py index a6f62c8cc2a..2f2043cf29f 100644 --- a/st2tests/dist_utils.py +++ b/st2tests/dist_utils.py @@ -43,17 +43,17 @@ if PY3: text_type = str else: - text_type = unicode # noqa # pylint: disable=E0602 + text_type = unicode # noqa # pylint: disable=E0602 -GET_PIP = 'curl https://bootstrap.pypa.io/get-pip.py | python' +GET_PIP = "curl https://bootstrap.pypa.io/get-pip.py | python" __all__ = [ - 'check_pip_is_installed', - 'check_pip_version', - 'fetch_requirements', - 'apply_vagrant_workaround', - 'get_version_string', - 'parse_version_string' + "check_pip_is_installed", + "check_pip_version", + "fetch_requirements", + "apply_vagrant_workaround", + "get_version_string", + "parse_version_string", ] @@ -64,15 +64,15 @@ def check_pip_is_installed(): try: import pip # NOQA except ImportError as e: - print('Failed to import pip: %s' % (text_type(e))) - print('') - print('Download pip:\n%s' % (GET_PIP)) + print("Failed to import pip: %s" % (text_type(e))) + print("") + print("Download pip:\n%s" % (GET_PIP)) sys.exit(1) return True -def check_pip_version(min_version='6.0.0'): +def check_pip_version(min_version="6.0.0"): """ Ensure that a minimum supported version of pip is installed. """ @@ -81,10 +81,12 @@ def check_pip_version(min_version='6.0.0'): import pip if StrictVersion(pip.__version__) < StrictVersion(min_version): - print("Upgrade pip, your version '{0}' " - "is outdated. Minimum required version is '{1}':\n{2}".format(pip.__version__, - min_version, - GET_PIP)) + print( + "Upgrade pip, your version '{0}' " + "is outdated. Minimum required version is '{1}':\n{2}".format( + pip.__version__, min_version, GET_PIP + ) + ) sys.exit(1) return True @@ -98,30 +100,32 @@ def fetch_requirements(requirements_file_path): reqs = [] def _get_link(line): - vcs_prefixes = ['git+', 'svn+', 'hg+', 'bzr+'] + vcs_prefixes = ["git+", "svn+", "hg+", "bzr+"] for vcs_prefix in vcs_prefixes: - if line.startswith(vcs_prefix) or line.startswith('-e %s' % (vcs_prefix)): - req_name = re.findall('.*#egg=(.+)([&|@]).*$', line) + if line.startswith(vcs_prefix) or line.startswith("-e %s" % (vcs_prefix)): + req_name = re.findall(".*#egg=(.+)([&|@]).*$", line) if not req_name: - req_name = re.findall('.*#egg=(.+?)$', line) + req_name = re.findall(".*#egg=(.+?)$", line) else: req_name = req_name[0] if not req_name: - raise ValueError('Line "%s" is missing "#egg="' % (line)) + raise ValueError( + 'Line "%s" is missing "#egg="' % (line) + ) - link = line.replace('-e ', '').strip() + link = line.replace("-e ", "").strip() return link, req_name[0] return None, None - with open(requirements_file_path, 'r') as fp: + with open(requirements_file_path, "r") as fp: for line in fp.readlines(): line = line.strip() - if line.startswith('#') or not line: + if line.startswith("#") or not line: continue link, req_name = _get_link(line=line) @@ -131,8 +135,8 @@ def _get_link(line): else: req_name = line - if ';' in req_name: - req_name = req_name.split(';')[0].strip() + if ";" in req_name: + req_name = req_name.split(";")[0].strip() reqs.append(req_name) @@ -146,7 +150,7 @@ def apply_vagrant_workaround(): Note: Without this workaround, setup.py sdist will fail when running inside a shared directory (nfs / virtualbox shared folders). """ - if os.environ.get('USER', None) == 'vagrant': + if os.environ.get("USER", None) == "vagrant": del os.link @@ -155,14 +159,13 @@ def get_version_string(init_file): Read __version__ string for an init file. """ - with open(init_file, 'r') as fp: + with open(init_file, "r") as fp: content = fp.read() - version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", - content, re.M) + version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", content, re.M) if version_match: return version_match.group(1) - raise RuntimeError('Unable to find version string in %s.' % (init_file)) + raise RuntimeError("Unable to find version string in %s." % (init_file)) # alias for get_version_string diff --git a/st2tests/integration/orquesta/base.py b/st2tests/integration/orquesta/base.py index 52e2277e4cf..f5f13cce04b 100644 --- a/st2tests/integration/orquesta/base.py +++ b/st2tests/integration/orquesta/base.py @@ -30,7 +30,7 @@ LIVEACTION_LAUNCHED_STATUSES = [ action_constants.LIVEACTION_STATUS_REQUESTED, action_constants.LIVEACTION_STATUS_SCHEDULED, - action_constants.LIVEACTION_STATUS_RUNNING + action_constants.LIVEACTION_STATUS_RUNNING, ] DEFAULT_WAIT_FIXED = 500 @@ -42,10 +42,9 @@ def retry_on_exceptions(exc): class WorkflowControlTestCaseMixin(object): - def _create_temp_file(self): _, temp_file_path = tempfile.mkstemp() - os.chmod(temp_file_path, 0o755) # nosec + os.chmod(temp_file_path, 0o755) # nosec return temp_file_path def _delete_temp_file(self, temp_file_path): @@ -57,18 +56,23 @@ def _delete_temp_file(self, temp_file_path): class TestWorkflowExecution(unittest2.TestCase): - @classmethod def setUpClass(cls): - cls.st2client = st2.Client(base_url='http://127.0.0.1') + cls.st2client = st2.Client(base_url="http://127.0.0.1") - def _execute_workflow(self, action, parameters=None, execute_async=True, - expected_status=None, expected_result=None): + def _execute_workflow( + self, + action, + parameters=None, + execute_async=True, + expected_status=None, + expected_result=None, + ): ex = models.LiveAction(action=action, parameters=(parameters or {})) ex = self.st2client.executions.create(ex) self.assertIsNotNone(ex.id) - self.assertEqual(ex.action['ref'], action) + self.assertEqual(ex.action["ref"], action) self.assertIn(ex.status, LIVEACTION_LAUNCHED_STATUSES) if execute_async: @@ -88,14 +92,16 @@ def _execute_workflow(self, action, parameters=None, execute_async=True, @retrying.retry( retry_on_exception=retry_on_exceptions, - wait_fixed=DEFAULT_WAIT_FIXED, stop_max_delay=DEFAULT_STOP_MAX_DELAY) + wait_fixed=DEFAULT_WAIT_FIXED, + stop_max_delay=DEFAULT_STOP_MAX_DELAY, + ) def _wait_for_state(self, ex, states): if isinstance(states, six.string_types): states = [states] for state in states: if state not in action_constants.LIVEACTION_STATUSES: - raise ValueError('Status %s is not valid.' % state) + raise ValueError("Status %s is not valid." % state) try: ex = self.st2client.executions.get_by_id(ex.id) @@ -104,8 +110,7 @@ def _wait_for_state(self, ex, states): if ex.status in action_constants.LIVEACTION_COMPLETED_STATES: raise Exception( 'Execution is in completed state "%s" and ' - 'does not match expected state(s). %s' % - (ex.status, ex.result) + "does not match expected state(s). %s" % (ex.status, ex.result) ) else: raise @@ -117,13 +122,16 @@ def _get_children(self, ex): @retrying.retry( retry_on_exception=retry_on_exceptions, - wait_fixed=DEFAULT_WAIT_FIXED, stop_max_delay=DEFAULT_STOP_MAX_DELAY) + wait_fixed=DEFAULT_WAIT_FIXED, + stop_max_delay=DEFAULT_STOP_MAX_DELAY, + ) def _wait_for_task(self, ex, task, status=None, num_task_exs=1): ex = self.st2client.executions.get_by_id(ex.id) task_exs = [ - task_ex for task_ex in self._get_children(ex) - if task_ex.context.get('orquesta', {}).get('task_name', '') == task + task_ex + for task_ex in self._get_children(ex) + if task_ex.context.get("orquesta", {}).get("task_name", "") == task ] try: @@ -131,8 +139,9 @@ def _wait_for_task(self, ex, task, status=None, num_task_exs=1): except: if ex.status in action_constants.LIVEACTION_COMPLETED_STATES: raise Exception( - 'Execution is in completed state and does not match expected number of ' - 'tasks. Expected: %s Actual: %s' % (str(num_task_exs), str(len(task_exs))) + "Execution is in completed state and does not match expected number of " + "tasks. Expected: %s Actual: %s" + % (str(num_task_exs), str(len(task_exs))) ) else: raise @@ -143,7 +152,7 @@ def _wait_for_task(self, ex, task, status=None, num_task_exs=1): except: if ex.status in action_constants.LIVEACTION_COMPLETED_STATES: raise Exception( - 'Execution is in completed state and not all tasks ' + "Execution is in completed state and not all tasks " 'match expected status "%s".' % status ) else: @@ -153,17 +162,19 @@ def _wait_for_task(self, ex, task, status=None, num_task_exs=1): @retrying.retry( retry_on_exception=retry_on_exceptions, - wait_fixed=DEFAULT_WAIT_FIXED, stop_max_delay=DEFAULT_STOP_MAX_DELAY) + wait_fixed=DEFAULT_WAIT_FIXED, + stop_max_delay=DEFAULT_STOP_MAX_DELAY, + ) def _wait_for_completion(self, ex): ex = self._wait_for_state(ex, action_constants.LIVEACTION_COMPLETED_STATES) try: - self.assertTrue(hasattr(ex, 'result')) + self.assertTrue(hasattr(ex, "result")) except: if ex.status in action_constants.LIVEACTION_COMPLETED_STATES: raise Exception( - 'Execution is in completed state and does not ' - 'contain expected result.' + "Execution is in completed state and does not " + "contain expected result." ) else: raise diff --git a/st2tests/integration/orquesta/test_performance.py b/st2tests/integration/orquesta/test_performance.py index e68ecc7f5fc..899b3090f96 100644 --- a/st2tests/integration/orquesta/test_performance.py +++ b/st2tests/integration/orquesta/test_performance.py @@ -27,34 +27,35 @@ class WiringTest(base.TestWorkflowExecution): - def test_concurrent_load(self): load_count = 3 delay_poll = load_count * 5 - wf_name = 'examples.orquesta-mock-create-vm' - wf_input = {'vm_name': 'demo1', 'meta': {'demo1.itests.org': '10.3.41.99'}} + wf_name = "examples.orquesta-mock-create-vm" + wf_input = {"vm_name": "demo1", "meta": {"demo1.itests.org": "10.3.41.99"}} exs = [self._execute_workflow(wf_name, wf_input) for i in range(load_count)] eventlet.sleep(delay_poll) for ex in exs: e = self._wait_for_completion(ex) - self.assertEqual(e.status, ac_const.LIVEACTION_STATUS_SUCCEEDED, json.dumps(e.result)) - self.assertIn('output', e.result) - self.assertIn('vm_id', e.result['output']) + self.assertEqual( + e.status, ac_const.LIVEACTION_STATUS_SUCCEEDED, json.dumps(e.result) + ) + self.assertIn("output", e.result) + self.assertIn("vm_id", e.result["output"]) def test_with_items_load(self): - wf_name = 'examples.orquesta-with-items-concurrency' + wf_name = "examples.orquesta-with-items-concurrency" num_items = 10 concurrency = 10 members = [str(i).zfill(5) for i in range(0, num_items)] - wf_input = {'members': members, 'concurrency': concurrency} + wf_input = {"members": members, "concurrency": concurrency} - message = '%s, resistance is futile!' - expected_output = {'items': [message % i for i in members]} - expected_result = {'output': expected_output} + message = "%s, resistance is futile!" + expected_output = {"items": [message % i for i in members]} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) diff --git a/st2tests/integration/orquesta/test_wiring.py b/st2tests/integration/orquesta/test_wiring.py index f542c0d779f..3e07d7b3fe4 100644 --- a/st2tests/integration/orquesta/test_wiring.py +++ b/st2tests/integration/orquesta/test_wiring.py @@ -23,13 +23,12 @@ class WiringTest(base.TestWorkflowExecution): - def test_sequential(self): - wf_name = 'examples.orquesta-sequential' - wf_input = {'name': 'Thanos'} + wf_name = "examples.orquesta-sequential" + wf_input = {"name": "Thanos"} - expected_output = {'greeting': 'Thanos, All your base are belong to us!'} - expected_result = {'output': expected_output} + expected_output = {"greeting": "Thanos, All your base are belong to us!"} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -38,18 +37,18 @@ def test_sequential(self): self.assertDictEqual(ex.result, expected_result) def test_join(self): - wf_name = 'examples.orquesta-join' + wf_name = "examples.orquesta-join" expected_output = { - 'messages': [ - 'Fee fi fo fum', - 'I smell the blood of an English man', - 'Be alive, or be he dead', - 'I\'ll grind his bones to make my bread' + "messages": [ + "Fee fi fo fum", + "I smell the blood of an English man", + "Be alive, or be he dead", + "I'll grind his bones to make my bread", ] } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) @@ -58,10 +57,10 @@ def test_join(self): self.assertDictEqual(ex.result, expected_result) def test_cycle(self): - wf_name = 'examples.orquesta-rollback-retry' + wf_name = "examples.orquesta-rollback-retry" expected_output = None - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) @@ -70,12 +69,12 @@ def test_cycle(self): self.assertDictEqual(ex.result, expected_result) def test_action_less(self): - wf_name = 'examples.orquesta-test-action-less-tasks' - wf_input = {'name': 'Thanos'} + wf_name = "examples.orquesta-test-action-less-tasks" + wf_input = {"name": "Thanos"} - message = 'Thanos, All your base are belong to us!' - expected_output = {'greeting': message.upper()} - expected_result = {'output': expected_output} + message = "Thanos, All your base are belong to us!" + expected_output = {"greeting": message.upper()} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -84,73 +83,72 @@ def test_action_less(self): self.assertDictEqual(ex.result, expected_result) def test_st2_runtime_context(self): - wf_name = 'examples.orquesta-st2-ctx' + wf_name = "examples.orquesta-st2-ctx" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) - expected_output = {'callback': 'http://127.0.0.1:9101/v1/executions/%s' % str(ex.id)} - expected_result = {'output': expected_output} + expected_output = { + "callback": "http://127.0.0.1:9101/v1/executions/%s" % str(ex.id) + } + expected_result = {"output": expected_output} self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertDictEqual(ex.result, expected_result) def test_subworkflow(self): - wf_name = 'examples.orquesta-subworkflow' + wf_name = "examples.orquesta-subworkflow" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self._wait_for_task(ex, 'start', ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(ex, "start", ac_const.LIVEACTION_STATUS_SUCCEEDED) - t2_ex = self._wait_for_task(ex, 'subworkflow', ac_const.LIVEACTION_STATUS_SUCCEEDED)[0] - self._wait_for_task(t2_ex, 'task1', ac_const.LIVEACTION_STATUS_SUCCEEDED) - self._wait_for_task(t2_ex, 'task2', ac_const.LIVEACTION_STATUS_SUCCEEDED) - self._wait_for_task(t2_ex, 'task3', ac_const.LIVEACTION_STATUS_SUCCEEDED) + t2_ex = self._wait_for_task( + ex, "subworkflow", ac_const.LIVEACTION_STATUS_SUCCEEDED + )[0] + self._wait_for_task(t2_ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(t2_ex, "task2", ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(t2_ex, "task3", ac_const.LIVEACTION_STATUS_SUCCEEDED) - self._wait_for_task(ex, 'finish', ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(ex, "finish", ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_output_on_error(self): - wf_name = 'examples.orquesta-output-on-error' + wf_name = "examples.orquesta-output-on-error" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) - expected_output = { - 'progress': 25 - } + expected_output = {"progress": 25} expected_errors = [ { - 'type': 'error', - 'task_id': 'task2', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "type": "error", + "task_id": "task2", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, } ] - expected_result = { - 'errors': expected_errors, - 'output': expected_output - } + expected_result = {"errors": expected_errors, "output": expected_output} self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) self.assertDictEqual(ex.result, expected_result) def test_config_context_renders(self): config_value = "Testing" - wf_name = 'examples.render_config_context' + wf_name = "examples.render_config_context" - expected_output = {'context_value': config_value} - expected_result = {'output': expected_output} + expected_output = {"context_value": config_value} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) @@ -159,21 +157,21 @@ def test_config_context_renders(self): self.assertDictEqual(ex.result, expected_result) def test_field_escaping(self): - wf_name = 'examples.orquesta-test-field-escaping' + wf_name = "examples.orquesta-test-field-escaping" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) expected_output = { - 'wf.hostname.with.periods': { - 'hostname.domain.tld': 'vars.value.with.periods', - 'hostname2.domain.tld': { - 'stdout': 'vars.nested.value.with.periods', + "wf.hostname.with.periods": { + "hostname.domain.tld": "vars.value.with.periods", + "hostname2.domain.tld": { + "stdout": "vars.nested.value.with.periods", }, }, - 'wf.output.with.periods': 'vars.nested.value.with.periods', + "wf.output.with.periods": "vars.nested.value.with.periods", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) self.assertDictEqual(ex.result, expected_result) diff --git a/st2tests/integration/orquesta/test_wiring_cancel.py b/st2tests/integration/orquesta/test_wiring_cancel.py index ff9d0d378f3..0e4edaf9184 100644 --- a/st2tests/integration/orquesta/test_wiring_cancel.py +++ b/st2tests/integration/orquesta/test_wiring_cancel.py @@ -22,7 +22,9 @@ from st2common.constants import action as ac_const -class CancellationWiringTest(base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin): +class CancellationWiringTest( + base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin +): temp_file_path = None @@ -44,9 +46,9 @@ def test_cancellation(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path, 'message': 'foobar'} - ex = self._execute_workflow('examples.orquesta-test-cancel', params) - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + params = {"tempfile": path, "message": "foobar"} + ex = self._execute_workflow("examples.orquesta-test-cancel", params) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the workflow before the temp file is created. The workflow will be paused # but task1 will still be running to allow for graceful exit. @@ -63,7 +65,7 @@ def test_cancellation(self): ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) # Task is completed successfully for graceful exit. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED) # Get the updated execution with task result. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) @@ -74,15 +76,15 @@ def test_task_cancellation(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path, 'message': 'foobar'} - ex = self._execute_workflow('examples.orquesta-test-cancel', params) - task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + params = {"tempfile": path, "message": "foobar"} + ex = self._execute_workflow("examples.orquesta-test-cancel", params) + task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the task execution. self.st2client.executions.delete(task_exs[0]) # Wait for the task and parent workflow to be canceled. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_CANCELED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_CANCELED) # Get the updated execution with task result. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) @@ -93,10 +95,10 @@ def test_cancellation_cascade_down_to_subworkflow(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path, 'message': 'foobar'} - action_ref = 'examples.orquesta-test-cancel-subworkflow' + params = {"tempfile": path, "message": "foobar"} + action_ref = "examples.orquesta-test-cancel-subworkflow" ex = self._execute_workflow(action_ref, params) - task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) subwf_ex = task_exs[0] # Cancel the workflow before the temp file is deleted. The workflow will be canceled @@ -123,10 +125,10 @@ def test_cancellation_cascade_up_from_subworkflow(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path, 'message': 'foobar'} - action_ref = 'examples.orquesta-test-cancel-subworkflow' + params = {"tempfile": path, "message": "foobar"} + action_ref = "examples.orquesta-test-cancel-subworkflow" ex = self._execute_workflow(action_ref, params) - task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) subwf_ex = task_exs[0] # Cancel the workflow before the temp file is deleted. The workflow will be canceled @@ -155,12 +157,12 @@ def test_cancellation_cascade_up_to_workflow_with_other_subworkflow(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path, 'file2': path} - action_ref = 'examples.orquesta-test-cancel-subworkflows' + params = {"file1": path, "file2": path} + action_ref = "examples.orquesta-test-cancel-subworkflows" ex = self._execute_workflow(action_ref, params) - task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) subwf_ex_1 = task_exs[0] - task_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + task_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) subwf_ex_2 = task_exs[0] # Cancel the workflow before the temp file is deleted. The workflow will be canceled @@ -168,19 +170,27 @@ def test_cancellation_cascade_up_to_workflow_with_other_subworkflow(self): self.st2client.executions.delete(subwf_ex_1) # Assert subworkflow is canceling. - subwf_ex_1 = self._wait_for_state(subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELING) + subwf_ex_1 = self._wait_for_state( + subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELING + ) # Assert main workflow and the other subworkflow is canceling. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELING) - subwf_ex_2 = self._wait_for_state(subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELING) + subwf_ex_2 = self._wait_for_state( + subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELING + ) # Delete the temporary file. os.remove(path) self.assertFalse(os.path.exists(path)) # Assert subworkflows are canceled. - subwf_ex_1 = self._wait_for_state(subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELED) - subwf_ex_2 = self._wait_for_state(subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELED) + subwf_ex_1 = self._wait_for_state( + subwf_ex_1, ac_const.LIVEACTION_STATUS_CANCELED + ) + subwf_ex_2 = self._wait_for_state( + subwf_ex_2, ac_const.LIVEACTION_STATUS_CANCELED + ) # Assert main workflow is canceled. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) diff --git a/st2tests/integration/orquesta/test_wiring_data_flow.py b/st2tests/integration/orquesta/test_wiring_data_flow.py index a9569cf693d..ed5fbfa23a1 100644 --- a/st2tests/integration/orquesta/test_wiring_data_flow.py +++ b/st2tests/integration/orquesta/test_wiring_data_flow.py @@ -27,13 +27,12 @@ class WiringTest(base.TestWorkflowExecution): - def test_data_flow(self): - wf_name = 'examples.orquesta-data-flow' - wf_input = {'a1': 'fee fi fo fum'} + wf_name = "examples.orquesta-data-flow" + wf_input = {"a1": "fee fi fo fum"} - expected_output = {'a5': wf_input['a1'], 'b5': wf_input['a1']} - expected_result = {'output': expected_output} + expected_output = {"a5": wf_input["a1"], "b5": wf_input["a1"]} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -42,15 +41,15 @@ def test_data_flow(self): self.assertDictEqual(ex.result, expected_result) def test_data_flow_unicode(self): - wf_name = 'examples.orquesta-data-flow' - wf_input = {'a1': '床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉'} + wf_name = "examples.orquesta-data-flow" + wf_input = {"a1": "床前明月光 疑是地上霜 舉頭望明月 低頭思故鄉"} expected_output = { - 'a5': wf_input['a1'].decode('utf-8') if six.PY2 else wf_input['a1'], - 'b5': wf_input['a1'].decode('utf-8') if six.PY2 else wf_input['a1'] + "a5": wf_input["a1"].decode("utf-8") if six.PY2 else wf_input["a1"], + "b5": wf_input["a1"].decode("utf-8") if six.PY2 else wf_input["a1"], } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -59,16 +58,15 @@ def test_data_flow_unicode(self): self.assertDictEqual(ex.result, expected_result) def test_data_flow_unicode_concat_with_ascii(self): - wf_name = 'examples.orquesta-sequential' - wf_input = {'name': '薩諾斯'} + wf_name = "examples.orquesta-sequential" + wf_input = {"name": "薩諾斯"} expected_output = { - 'greeting': '%s, All your base are belong to us!' % ( - wf_input['name'].decode('utf-8') if six.PY2 else wf_input['name'] - ) + "greeting": "%s, All your base are belong to us!" + % (wf_input["name"].decode("utf-8") if six.PY2 else wf_input["name"]) } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -77,15 +75,17 @@ def test_data_flow_unicode_concat_with_ascii(self): self.assertDictEqual(ex.result, expected_result) def test_data_flow_big_data_size(self): - wf_name = 'examples.orquesta-data-flow' + wf_name = "examples.orquesta-data-flow" data_length = 100000 - data = ''.join(random.choice(string.ascii_lowercase) for _ in range(data_length)) + data = "".join( + random.choice(string.ascii_lowercase) for _ in range(data_length) + ) - wf_input = {'a1': data} + wf_input = {"a1": data} - expected_output = {'a5': wf_input['a1'], 'b5': wf_input['a1']} - expected_result = {'output': expected_output} + expected_output = {"a5": wf_input["a1"], "b5": wf_input["a1"]} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) diff --git a/st2tests/integration/orquesta/test_wiring_delay.py b/st2tests/integration/orquesta/test_wiring_delay.py index f825475479f..32b923b923a 100644 --- a/st2tests/integration/orquesta/test_wiring_delay.py +++ b/st2tests/integration/orquesta/test_wiring_delay.py @@ -23,13 +23,12 @@ class TaskDelayWiringTest(base.TestWorkflowExecution): - def test_task_delay(self): - wf_name = 'examples.orquesta-delay' - wf_input = {'name': 'Thanos', 'delay': 1} + wf_name = "examples.orquesta-delay" + wf_input = {"name": "Thanos", "delay": 1} - expected_output = {'greeting': 'Thanos, All your base are belong to us!'} - expected_result = {'output': expected_output} + expected_output = {"greeting": "Thanos, All your base are belong to us!"} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -38,12 +37,12 @@ def test_task_delay(self): self.assertDictEqual(ex.result, expected_result) def test_task_delay_workflow_cancellation(self): - wf_name = 'examples.orquesta-delay' - wf_input = {'name': 'Thanos', 'delay': 300} + wf_name = "examples.orquesta-delay" + wf_input = {"name": "Thanos", "delay": 300} # Launch workflow and task1 should be delayed. ex = self._execute_workflow(wf_name, wf_input) - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_DELAYED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_DELAYED) # Cancel the workflow before the temp file is created. The workflow will be paused # but task1 will still be running to allow for graceful exit. @@ -53,24 +52,24 @@ def test_task_delay_workflow_cancellation(self): ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) # Task execution should be canceled. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_CANCELED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_CANCELED) # Get the updated execution with task result. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) def test_task_delay_task_cancellation(self): - wf_name = 'examples.orquesta-delay' - wf_input = {'name': 'Thanos', 'delay': 300} + wf_name = "examples.orquesta-delay" + wf_input = {"name": "Thanos", "delay": 300} # Launch workflow and task1 should be delayed. ex = self._execute_workflow(wf_name, wf_input) - task_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_DELAYED) + task_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_DELAYED) # Cancel the task execution. self.st2client.executions.delete(task_exs[0]) # Wait for the task and parent workflow to be canceled. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_CANCELED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_CANCELED) # Get the updated execution with task result. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) diff --git a/st2tests/integration/orquesta/test_wiring_error_handling.py b/st2tests/integration/orquesta/test_wiring_error_handling.py index f3c9f87fdda..130a68c7c5c 100644 --- a/st2tests/integration/orquesta/test_wiring_error_handling.py +++ b/st2tests/integration/orquesta/test_wiring_error_handling.py @@ -22,236 +22,235 @@ class ErrorHandlingTest(base.TestWorkflowExecution): - def test_inspection_error(self): expected_errors = [ { - 'type': 'content', - 'message': 'The action "std.noop" is not registered in the database.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action', - 'spec_path': 'tasks.task3.action' + "type": "content", + "message": 'The action "std.noop" is not registered in the database.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action", + "spec_path": "tasks.task3.action", }, { - 'type': 'context', - 'language': 'yaql', - 'expression': '<% ctx().foobar %>', - 'message': 'Variable "foobar" is referenced before assignment.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input', - 'spec_path': 'tasks.task1.input', + "type": "context", + "language": "yaql", + "expression": "<% ctx().foobar %>", + "message": 'Variable "foobar" is referenced before assignment.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input", + "spec_path": "tasks.task1.input", }, { - 'type': 'expression', - 'language': 'yaql', - 'expression': '<% <% succeeded() %>', - 'message': ( - 'Parse error: unexpected \'<\' at ' - 'position 0 of expression \'<% succeeded()\'' + "type": "expression", + "language": "yaql", + "expression": "<% <% succeeded() %>", + "message": ( + "Parse error: unexpected '<' at " + "position 0 of expression '<% succeeded()'" ), - 'schema_path': ( - r'properties.tasks.patternProperties.^\w+$.' - 'properties.next.items.properties.when' + "schema_path": ( + r"properties.tasks.patternProperties.^\w+$." + "properties.next.items.properties.when" ), - 'spec_path': 'tasks.task2.next[0].when' + "spec_path": "tasks.task2.next[0].when", }, { - 'type': 'syntax', - 'message': ( - '[{\'cmd\': \'echo <% ctx().macro %>\'}] is ' - 'not valid under any of the given schemas' + "type": "syntax", + "message": ( + "[{'cmd': 'echo <% ctx().macro %>'}] is " + "not valid under any of the given schemas" ), - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input.oneOf', - 'spec_path': 'tasks.task2.input' - } + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input.oneOf", + "spec_path": "tasks.task2.input", + }, ] - ex = self._execute_workflow('examples.orquesta-fail-inspection') + ex = self._execute_workflow("examples.orquesta-fail-inspection") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_input_error(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(8).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(8).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - ex = self._execute_workflow('examples.orquesta-fail-input-rendering') + ex = self._execute_workflow("examples.orquesta-fail-input-rendering") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_vars_error(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(8).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(8).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - ex = self._execute_workflow('examples.orquesta-fail-vars-rendering') + ex = self._execute_workflow("examples.orquesta-fail-vars-rendering") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_start_task_error(self): self.maxDiff = None expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% ctx().name.value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% ctx().name.value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' ), - 'task_id': 'task1', - 'route': 0 + "task_id": "task1", + "route": 0, }, { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to resolve key \'greeting\' ' - 'in expression \'<% ctx().greeting %>\' from context.' - ) - } + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to resolve key 'greeting' " + "in expression '<% ctx().greeting %>' from context." + ), + }, ] - ex = self._execute_workflow('examples.orquesta-fail-start-task') + ex = self._execute_workflow("examples.orquesta-fail-start-task") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_task_transition_error(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to resolve key \'value\' ' - 'in expression \'<% succeeded() and result().value %>\' from context.' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to resolve key 'value' " + "in expression '<% succeeded() and result().value %>' from context." ), - 'task_transition_id': 'task2__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "task2__t0", + "task_id": "task1", + "route": 0, } ] - expected_output = { - 'greeting': None - } + expected_output = {"greeting": None} - ex = self._execute_workflow('examples.orquesta-fail-task-transition') + ex = self._execute_workflow("examples.orquesta-fail-task-transition") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output}) + self.assertDictEqual( + ex.result, {"errors": expected_errors, "output": expected_output} + ) def test_task_publish_error(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to resolve key \'value\' ' - 'in expression \'<% result().value %>\' from context.' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to resolve key 'value' " + "in expression '<% result().value %>' from context." ), - 'task_transition_id': 'task2__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "task2__t0", + "task_id": "task1", + "route": 0, } ] - expected_output = { - 'greeting': None - } + expected_output = {"greeting": None} - ex = self._execute_workflow('examples.orquesta-fail-task-publish') + ex = self._execute_workflow("examples.orquesta-fail-task-publish") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output}) + self.assertDictEqual( + ex.result, {"errors": expected_errors, "output": expected_output} + ) def test_output_error(self): expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% abs(8).value %>\'. NoFunctionRegisteredException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% abs(8).value %>'. NoFunctionRegisteredException: " 'Unknown function "#property#value"' - ) + ), } ] - ex = self._execute_workflow('examples.orquesta-fail-output-rendering') + ex = self._execute_workflow("examples.orquesta-fail-output-rendering") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_task_content_errors(self): expected_errors = [ { - 'type': 'content', - 'message': 'The action reference "echo" is not formatted correctly.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action', - 'spec_path': 'tasks.task1.action' + "type": "content", + "message": 'The action reference "echo" is not formatted correctly.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action", + "spec_path": "tasks.task1.action", }, { - 'type': 'content', - 'message': 'The action "core.echoz" is not registered in the database.', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.action', - 'spec_path': 'tasks.task2.action' + "type": "content", + "message": 'The action "core.echoz" is not registered in the database.', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.action", + "spec_path": "tasks.task2.action", }, { - 'type': 'content', - 'message': 'Action "core.echo" is missing required input "message".', - 'schema_path': r'properties.tasks.patternProperties.^\w+$.properties.input', - 'spec_path': 'tasks.task3.input' + "type": "content", + "message": 'Action "core.echo" is missing required input "message".', + "schema_path": r"properties.tasks.patternProperties.^\w+$.properties.input", + "spec_path": "tasks.task3.input", }, { - 'type': 'content', - 'message': 'Action "core.echo" has unexpected input "messages".', - 'schema_path': ( - r'properties.tasks.patternProperties.^\w+$.properties.input.' - r'patternProperties.^\w+$' + "type": "content", + "message": 'Action "core.echo" has unexpected input "messages".', + "schema_path": ( + r"properties.tasks.patternProperties.^\w+$.properties.input." + r"patternProperties.^\w+$" ), - 'spec_path': 'tasks.task3.input.messages' - } + "spec_path": "tasks.task3.input.messages", + }, ] - ex = self._execute_workflow('examples.orquesta-fail-inspection-task-contents') + ex = self._execute_workflow("examples.orquesta-fail-inspection-task-contents") ex = self._wait_for_completion(ex) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_remediate_then_fail(self): expected_errors = [ { - 'task_id': 'task1', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "task_id": "task1", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, }, { - 'task_id': 'fail', - 'type': 'error', - 'message': 'Execution failed. See result for details.' - } + "task_id": "fail", + "type": "error", + "message": "Execution failed. See result for details.", + }, ] - ex = self._execute_workflow('examples.orquesta-remediate-then-fail') + ex = self._execute_workflow("examples.orquesta-remediate-then-fail") ex = self._wait_for_completion(ex) # Assert that the log task is executed. @@ -261,93 +260,95 @@ def test_remediate_then_fail(self): # tasks is reached (With some hard limit) before failing eventlet.sleep(2) - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED) - self._wait_for_task(ex, 'log', ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED) + self._wait_for_task(ex, "log", ac_const.LIVEACTION_STATUS_SUCCEEDED) # Assert workflow status and result. self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': None}) + self.assertDictEqual(ex.result, {"errors": expected_errors, "output": None}) def test_fail_manually(self): expected_errors = [ { - 'task_id': 'task1', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "task_id": "task1", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, }, { - 'task_id': 'fail', - 'type': 'error', - 'message': 'Execution failed. See result for details.' - } + "task_id": "fail", + "type": "error", + "message": "Execution failed. See result for details.", + }, ] - expected_output = { - 'message': '$%#&@#$!!!' - } + expected_output = {"message": "$%#&@#$!!!"} - wf_input = {'cmd': 'exit 1'} - ex = self._execute_workflow('examples.orquesta-error-handling-fail-manually', wf_input) + wf_input = {"cmd": "exit 1"} + ex = self._execute_workflow( + "examples.orquesta-error-handling-fail-manually", wf_input + ) ex = self._wait_for_completion(ex) # Assert task status. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED) - self._wait_for_task(ex, 'task3', ac_const.LIVEACTION_STATUS_SUCCEEDED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED) + self._wait_for_task(ex, "task3", ac_const.LIVEACTION_STATUS_SUCCEEDED) # Assert workflow status and result. self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output}) + self.assertDictEqual( + ex.result, {"errors": expected_errors, "output": expected_output} + ) def test_fail_continue(self): expected_errors = [ { - 'task_id': 'task1', - 'type': 'error', - 'message': 'Execution failed. See result for details.', - 'result': { - 'failed': True, - 'return_code': 1, - 'stderr': '', - 'stdout': '', - 'succeeded': False - } + "task_id": "task1", + "type": "error", + "message": "Execution failed. See result for details.", + "result": { + "failed": True, + "return_code": 1, + "stderr": "", + "stdout": "", + "succeeded": False, + }, } ] - expected_output = { - 'message': '$%#&@#$!!!' - } + expected_output = {"message": "$%#&@#$!!!"} - wf_input = {'cmd': 'exit 1'} - ex = self._execute_workflow('examples.orquesta-error-handling-continue', wf_input) + wf_input = {"cmd": "exit 1"} + ex = self._execute_workflow( + "examples.orquesta-error-handling-continue", wf_input + ) ex = self._wait_for_completion(ex) # Assert task status. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED) # Assert workflow status and result. self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) - self.assertDictEqual(ex.result, {'errors': expected_errors, 'output': expected_output}) + self.assertDictEqual( + ex.result, {"errors": expected_errors, "output": expected_output} + ) def test_fail_noop(self): - expected_output = { - 'message': '$%#&@#$!!!' - } + expected_output = {"message": "$%#&@#$!!!"} - wf_input = {'cmd': 'exit 1'} - ex = self._execute_workflow('examples.orquesta-error-handling-noop', wf_input) + wf_input = {"cmd": "exit 1"} + ex = self._execute_workflow("examples.orquesta-error-handling-noop", wf_input) ex = self._wait_for_completion(ex) # Assert task status. - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_FAILED) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_FAILED) # Assert workflow status and result. self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertDictEqual(ex.result, {'output': expected_output}) + self.assertDictEqual(ex.result, {"output": expected_output}) diff --git a/st2tests/integration/orquesta/test_wiring_functions.py b/st2tests/integration/orquesta/test_wiring_functions.py index 91da108d391..538bf9ddd7d 100644 --- a/st2tests/integration/orquesta/test_wiring_functions.py +++ b/st2tests/integration/orquesta/test_wiring_functions.py @@ -19,165 +19,174 @@ class FunctionsWiringTest(base.TestWorkflowExecution): - def test_data_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-data-functions' + wf_name = "examples.orquesta-test-yaql-data-functions" expected_output = { - 'data_json_str_1': '{"foo": {"bar": "foobar"}}', - 'data_json_str_2': '{"foo": {"bar": "foobar"}}', - 'data_json_str_3': '{"foo": {"bar": "foobar"}}', - 'data_json_obj_1': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_2': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_3': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_4': {'foo': {'bar': 'foobar'}}, - 'data_yaml_str_1': 'foo:\n bar: foobar\n', - 'data_yaml_str_2': 'foo:\n bar: foobar\n', - 'data_query_1': ['foobar'], - 'data_none_str': '%*****__%NONE%__*****%', - 'data_str': 'foobar' + "data_json_str_1": '{"foo": {"bar": "foobar"}}', + "data_json_str_2": '{"foo": {"bar": "foobar"}}', + "data_json_str_3": '{"foo": {"bar": "foobar"}}', + "data_json_obj_1": {"foo": {"bar": "foobar"}}, + "data_json_obj_2": {"foo": {"bar": "foobar"}}, + "data_json_obj_3": {"foo": {"bar": "foobar"}}, + "data_json_obj_4": {"foo": {"bar": "foobar"}}, + "data_yaml_str_1": "foo:\n bar: foobar\n", + "data_yaml_str_2": "foo:\n bar: foobar\n", + "data_query_1": ["foobar"], + "data_none_str": "%*****__%NONE%__*****%", + "data_str": "foobar", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_data_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-data-functions' + wf_name = "examples.orquesta-test-jinja-data-functions" expected_output = { - 'data_json_str_1': '{"foo": {"bar": "foobar"}}', - 'data_json_str_2': '{"foo": {"bar": "foobar"}}', - 'data_json_str_3': '{"foo": {"bar": "foobar"}}', - 'data_json_obj_1': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_2': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_3': {'foo': {'bar': 'foobar'}}, - 'data_json_obj_4': {'foo': {'bar': 'foobar'}}, - 'data_yaml_str_1': 'foo:\n bar: foobar\n', - 'data_yaml_str_2': 'foo:\n bar: foobar\n', - 'data_query_1': ['foobar'], - 'data_pipe_str_1': '{"foo": {"bar": "foobar"}}', - 'data_none_str': '%*****__%NONE%__*****%', - 'data_str': 'foobar', - 'data_list_str': '- a: 1\n b: 2\n- x: 3\n y: 4\n' + "data_json_str_1": '{"foo": {"bar": "foobar"}}', + "data_json_str_2": '{"foo": {"bar": "foobar"}}', + "data_json_str_3": '{"foo": {"bar": "foobar"}}', + "data_json_obj_1": {"foo": {"bar": "foobar"}}, + "data_json_obj_2": {"foo": {"bar": "foobar"}}, + "data_json_obj_3": {"foo": {"bar": "foobar"}}, + "data_json_obj_4": {"foo": {"bar": "foobar"}}, + "data_yaml_str_1": "foo:\n bar: foobar\n", + "data_yaml_str_2": "foo:\n bar: foobar\n", + "data_query_1": ["foobar"], + "data_pipe_str_1": '{"foo": {"bar": "foobar"}}', + "data_none_str": "%*****__%NONE%__*****%", + "data_str": "foobar", + "data_list_str": "- a: 1\n b: 2\n- x: 3\n y: 4\n", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_path_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-path-functions' + wf_name = "examples.orquesta-test-yaql-path-functions" - expected_output = { - 'basename': 'file.txt', - 'dirname': '/path/to/some' - } + expected_output = {"basename": "file.txt", "dirname": "/path/to/some"} - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_path_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-path-functions' + wf_name = "examples.orquesta-test-jinja-path-functions" - expected_output = { - 'basename': 'file.txt', - 'dirname': '/path/to/some' - } + expected_output = {"basename": "file.txt", "dirname": "/path/to/some"} - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_regex_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-regex-functions' + wf_name = "examples.orquesta-test-yaql-regex-functions" expected_output = { - 'match': True, - 'replace': 'wxyz', - 'search': True, - 'substring': '668 Infinite Dr' + "match": True, + "replace": "wxyz", + "search": True, + "substring": "668 Infinite Dr", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_regex_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-regex-functions' + wf_name = "examples.orquesta-test-jinja-regex-functions" expected_output = { - 'match': True, - 'replace': 'wxyz', - 'search': True, - 'substring': '668 Infinite Dr' + "match": True, + "replace": "wxyz", + "search": True, + "substring": "668 Infinite Dr", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_time_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-time-functions' + wf_name = "examples.orquesta-test-yaql-time-functions" - expected_output = { - 'time': '3h25m45s' - } + expected_output = {"time": "3h25m45s"} - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_time_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-time-functions' + wf_name = "examples.orquesta-test-jinja-time-functions" - expected_output = { - 'time': '3h25m45s' - } + expected_output = {"time": "3h25m45s"} - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_version_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-version-functions' + wf_name = "examples.orquesta-test-yaql-version-functions" expected_output = { - 'compare_equal': 0, - 'compare_more_than': -1, - 'compare_less_than': 1, - 'equal': True, - 'more_than': False, - 'less_than': False, - 'match': True, - 'bump_major': '1.0.0', - 'bump_minor': '0.11.0', - 'bump_patch': '0.10.1', - 'strip_patch': '0.10' + "compare_equal": 0, + "compare_more_than": -1, + "compare_less_than": 1, + "equal": True, + "more_than": False, + "less_than": False, + "match": True, + "bump_major": "1.0.0", + "bump_minor": "0.11.0", + "bump_patch": "0.10.1", + "strip_patch": "0.10", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_version_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-version-functions' + wf_name = "examples.orquesta-test-jinja-version-functions" expected_output = { - 'compare_equal': 0, - 'compare_more_than': -1, - 'compare_less_than': 1, - 'equal': True, - 'more_than': False, - 'less_than': False, - 'match': True, - 'bump_major': '1.0.0', - 'bump_minor': '0.11.0', - 'bump_patch': '0.10.1', - 'strip_patch': '0.10' + "compare_equal": 0, + "compare_more_than": -1, + "compare_less_than": 1, + "equal": True, + "more_than": False, + "less_than": False, + "match": True, + "bump_major": "1.0.0", + "bump_minor": "0.11.0", + "bump_patch": "0.10.1", + "strip_patch": "0.10", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) diff --git a/st2tests/integration/orquesta/test_wiring_functions_st2kv.py b/st2tests/integration/orquesta/test_wiring_functions_st2kv.py index d02b8594c4f..e4384c72cdb 100644 --- a/st2tests/integration/orquesta/test_wiring_functions_st2kv.py +++ b/st2tests/integration/orquesta/test_wiring_functions_st2kv.py @@ -21,90 +21,76 @@ class DatastoreFunctionTest(base.TestWorkflowExecution): @classmethod - def set_kvp(cls, name, value, scope='system', secret=False): + def set_kvp(cls, name, value, scope="system", secret=False): kvp = models.KeyValuePair( - id=name, - name=name, - value=value, - scope=scope, - secret=secret + id=name, name=name, value=value, scope=scope, secret=secret ) cls.st2client.keys.update(kvp) @classmethod - def del_kvp(cls, name, scope='system'): - kvp = models.KeyValuePair( - id=name, - name=name, - scope=scope - ) + def del_kvp(cls, name, scope="system"): + kvp = models.KeyValuePair(id=name, name=name, scope=scope) cls.st2client.keys.delete(kvp) def test_st2kv_system_scope(self): - key = 'lakshmi' - value = 'kanahansnasnasdlsajks' + key = "lakshmi" + value = "kanahansnasnasdlsajks" self.set_kvp(key, value) - wf_name = 'examples.orquesta-st2kv' - wf_input = {'key_name': 'system.%s' % key} + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": "system.%s" % key} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) self.del_kvp(key) def test_st2kv_user_scope(self): - key = 'winson' - value = 'SoDiamondEng' + key = "winson" + value = "SoDiamondEng" - self.set_kvp(key, value, 'user') - wf_name = 'examples.orquesta-st2kv' - wf_input = {'key_name': key} + self.set_kvp(key, value, "user") + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": key} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) # self.del_kvp(key) def test_st2kv_decrypt(self): - key = 'kami' - value = 'eggplant' + key = "kami" + value = "eggplant" self.set_kvp(key, value, secret=True) - wf_name = 'examples.orquesta-st2kv' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True - } + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": "system.%s" % key, "decrypt": True} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value', output.result['output']) - self.assertEqual(value, output.result['output']['value']) + self.assertIn("output", output.result) + self.assertIn("value", output.result["output"]) + self.assertEqual(value, output.result["output"]["value"]) self.del_kvp(key) def test_st2kv_nonexistent(self): - key = 'matt' + key = "matt" - wf_name = 'examples.orquesta-st2kv' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True - } + wf_name = "examples.orquesta-st2kv" + wf_input = {"key_name": "system.%s" % key, "decrypt": True} execution = self._execute_workflow(wf_name, wf_input) @@ -112,69 +98,71 @@ def test_st2kv_nonexistent(self): self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_FAILED) - expected_error = 'The key "%s" does not exist in the StackStorm datastore.' % key + expected_error = ( + 'The key "%s" does not exist in the StackStorm datastore.' % key + ) - self.assertIn(expected_error, output.result['errors'][0]['message']) + self.assertIn(expected_error, output.result["errors"][0]["message"]) def test_st2kv_default_value(self): - key = 'matt' + key = "matt" - wf_name = 'examples.orquesta-st2kv-default' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True, - 'default': 'stone' - } + wf_name = "examples.orquesta-st2kv-default" + wf_input = {"key_name": "system.%s" % key, "decrypt": True, "default": "stone"} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value_from_yaql', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_yaql']) - self.assertIn('value_from_jinja', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_jinja']) + self.assertIn("output", output.result) + self.assertIn("value_from_yaql", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_yaql"] + ) + self.assertIn("value_from_jinja", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_jinja"] + ) def test_st2kv_default_value_with_empty_string(self): - key = 'matt' + key = "matt" - wf_name = 'examples.orquesta-st2kv-default' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True, - 'default': '' - } + wf_name = "examples.orquesta-st2kv-default" + wf_input = {"key_name": "system.%s" % key, "decrypt": True, "default": ""} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value_from_yaql', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_yaql']) - self.assertIn('value_from_jinja', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_jinja']) + self.assertIn("output", output.result) + self.assertIn("value_from_yaql", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_yaql"] + ) + self.assertIn("value_from_jinja", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_jinja"] + ) def test_st2kv_default_value_with_null(self): - key = 'matt' + key = "matt" - wf_name = 'examples.orquesta-st2kv-default' - wf_input = { - 'key_name': 'system.%s' % key, - 'decrypt': True, - 'default': None - } + wf_name = "examples.orquesta-st2kv-default" + wf_input = {"key_name": "system.%s" % key, "decrypt": True, "default": None} execution = self._execute_workflow(wf_name, wf_input) output = self._wait_for_completion(execution) self.assertEqual(output.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) - self.assertIn('output', output.result) - self.assertIn('value_from_yaql', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_yaql']) - self.assertIn('value_from_jinja', output.result['output']) - self.assertEqual(wf_input['default'], output.result['output']['value_from_jinja']) + self.assertIn("output", output.result) + self.assertIn("value_from_yaql", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_yaql"] + ) + self.assertIn("value_from_jinja", output.result["output"]) + self.assertEqual( + wf_input["default"], output.result["output"]["value_from_jinja"] + ) diff --git a/st2tests/integration/orquesta/test_wiring_functions_task.py b/st2tests/integration/orquesta/test_wiring_functions_task.py index 990b86752c8..35d002c885d 100644 --- a/st2tests/integration/orquesta/test_wiring_functions_task.py +++ b/st2tests/integration/orquesta/test_wiring_functions_task.py @@ -21,91 +21,94 @@ class FunctionsWiringTest(base.TestWorkflowExecution): - def test_task_functions_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-task-functions' + wf_name = "examples.orquesta-test-yaql-task-functions" expected_output = { - 'last_task4_result': 'False', - 'task9__1__parent': 'task8__1', - 'task9__2__parent': 'task8__2', - 'that_task_by_name': 'task1', - 'this_task_by_name': 'task1', - 'this_task_no_arg': 'task1' + "last_task4_result": "False", + "task9__1__parent": "task8__1", + "task9__2__parent": "task8__2", + "that_task_by_name": "task1", + "this_task_by_name": "task1", + "this_task_no_arg": "task1", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_task_functions_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-task-functions' + wf_name = "examples.orquesta-test-jinja-task-functions" expected_output = { - 'last_task4_result': 'False', - 'task9__1__parent': 'task8__1', - 'task9__2__parent': 'task8__2', - 'that_task_by_name': 'task1', - 'this_task_by_name': 'task1', - 'this_task_no_arg': 'task1' + "last_task4_result": "False", + "task9__1__parent": "task8__1", + "task9__2__parent": "task8__2", + "that_task_by_name": "task1", + "this_task_by_name": "task1", + "this_task_no_arg": "task1", } - expected_result = {'output': expected_output} + expected_result = {"output": expected_output} - self._execute_workflow(wf_name, execute_async=False, expected_result=expected_result) + self._execute_workflow( + wf_name, execute_async=False, expected_result=expected_result + ) def test_task_nonexistent_in_yaql(self): - wf_name = 'examples.orquesta-test-yaql-task-nonexistent' + wf_name = "examples.orquesta-test-yaql-task-nonexistent" expected_output = None expected_errors = [ { - 'type': 'error', - 'message': ( - 'YaqlEvaluationException: Unable to evaluate expression ' - '\'<% task("task0") %>\'. ExpressionEvaluationException: ' + "type": "error", + "message": ( + "YaqlEvaluationException: Unable to evaluate expression " + "'<% task(\"task0\") %>'. ExpressionEvaluationException: " 'Unable to find task execution for "task0".' ), - 'task_transition_id': 'continue__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "continue__t0", + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': expected_output, 'errors': expected_errors} + expected_result = {"output": expected_output, "errors": expected_errors} self._execute_workflow( wf_name, execute_async=False, expected_status=action_constants.LIVEACTION_STATUS_FAILED, - expected_result=expected_result + expected_result=expected_result, ) def test_task_nonexistent_in_jinja(self): - wf_name = 'examples.orquesta-test-jinja-task-nonexistent' + wf_name = "examples.orquesta-test-jinja-task-nonexistent" expected_output = None expected_errors = [ { - 'type': 'error', - 'message': ( - 'JinjaEvaluationException: Unable to evaluate expression ' - '\'{{ task("task0") }}\'. ExpressionEvaluationException: ' + "type": "error", + "message": ( + "JinjaEvaluationException: Unable to evaluate expression " + "'{{ task(\"task0\") }}'. ExpressionEvaluationException: " 'Unable to find task execution for "task0".' ), - 'task_transition_id': 'continue__t0', - 'task_id': 'task1', - 'route': 0 + "task_transition_id": "continue__t0", + "task_id": "task1", + "route": 0, } ] - expected_result = {'output': expected_output, 'errors': expected_errors} + expected_result = {"output": expected_output, "errors": expected_errors} self._execute_workflow( wf_name, execute_async=False, expected_status=action_constants.LIVEACTION_STATUS_FAILED, - expected_result=expected_result + expected_result=expected_result, ) diff --git a/st2tests/integration/orquesta/test_wiring_inquiry.py b/st2tests/integration/orquesta/test_wiring_inquiry.py index 71d0ed9e96c..688929c0415 100644 --- a/st2tests/integration/orquesta/test_wiring_inquiry.py +++ b/st2tests/integration/orquesta/test_wiring_inquiry.py @@ -23,75 +23,88 @@ class InquiryWiringTest(base.TestWorkflowExecution): - def test_basic_inquiry(self): # Launch the workflow. The workflow will paused at the pending task. - ex = self._execute_workflow('examples.orquesta-ask-basic') + ex = self._execute_workflow("examples.orquesta-ask-basic") ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) # Respond to the inquiry. - ac_exs = self._wait_for_task(ex, 'get_approval', ac_const.LIVEACTION_STATUS_PENDING) - self.st2client.inquiries.respond(ac_exs[0].id, {'approved': True}) + ac_exs = self._wait_for_task( + ex, "get_approval", ac_const.LIVEACTION_STATUS_PENDING + ) + self.st2client.inquiries.respond(ac_exs[0].id, {"approved": True}) # Wait for completion. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_consecutive_inquiries(self): # Launch the workflow. The workflow will paused at the pending task. - ex = self._execute_workflow('examples.orquesta-ask-consecutive') + ex = self._execute_workflow("examples.orquesta-ask-consecutive") ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) # Respond to the first inquiry. - t1_ac_exs = self._wait_for_task(ex, 'get_approval', ac_const.LIVEACTION_STATUS_PENDING) - self.st2client.inquiries.respond(t1_ac_exs[0].id, {'approved': True}) + t1_ac_exs = self._wait_for_task( + ex, "get_approval", ac_const.LIVEACTION_STATUS_PENDING + ) + self.st2client.inquiries.respond(t1_ac_exs[0].id, {"approved": True}) # Wait for the workflow to pause again. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) # Respond to the second inquiry. - t2_ac_exs = self._wait_for_task(ex, 'get_confirmation', ac_const.LIVEACTION_STATUS_PENDING) - self.st2client.inquiries.respond(t2_ac_exs[0].id, {'approved': True}) + t2_ac_exs = self._wait_for_task( + ex, "get_confirmation", ac_const.LIVEACTION_STATUS_PENDING + ) + self.st2client.inquiries.respond(t2_ac_exs[0].id, {"approved": True}) # Wait for completion. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_parallel_inquiries(self): # Launch the workflow. The workflow will paused at the pending task. - ex = self._execute_workflow('examples.orquesta-ask-parallel') + ex = self._execute_workflow("examples.orquesta-ask-parallel") ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) # Respond to the first inquiry. - t1_ac_exs = self._wait_for_task(ex, 'ask_jack', ac_const.LIVEACTION_STATUS_PENDING) - self.st2client.inquiries.respond(t1_ac_exs[0].id, {'approved': True}) - t1_ac_exs = self._wait_for_task(ex, 'ask_jack', ac_const.LIVEACTION_STATUS_SUCCEEDED) + t1_ac_exs = self._wait_for_task( + ex, "ask_jack", ac_const.LIVEACTION_STATUS_PENDING + ) + self.st2client.inquiries.respond(t1_ac_exs[0].id, {"approved": True}) + t1_ac_exs = self._wait_for_task( + ex, "ask_jack", ac_const.LIVEACTION_STATUS_SUCCEEDED + ) # Allow some time for the first inquiry to get processed. eventlet.sleep(1) # Respond to the second inquiry. - t2_ac_exs = self._wait_for_task(ex, 'ask_jill', ac_const.LIVEACTION_STATUS_PENDING) - self.st2client.inquiries.respond(t2_ac_exs[0].id, {'approved': True}) - t2_ac_exs = self._wait_for_task(ex, 'ask_jill', ac_const.LIVEACTION_STATUS_SUCCEEDED) + t2_ac_exs = self._wait_for_task( + ex, "ask_jill", ac_const.LIVEACTION_STATUS_PENDING + ) + self.st2client.inquiries.respond(t2_ac_exs[0].id, {"approved": True}) + t2_ac_exs = self._wait_for_task( + ex, "ask_jill", ac_const.LIVEACTION_STATUS_SUCCEEDED + ) # Wait for completion. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_nested_inquiry(self): # Launch the workflow. The workflow will paused at the pending task. - ex = self._execute_workflow('examples.orquesta-ask-nested') + ex = self._execute_workflow("examples.orquesta-ask-nested") ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) # Get the action execution of the subworkflow - ac_exs = self._wait_for_task(ex, 'get_approval', ac_const.LIVEACTION_STATUS_PAUSED) + ac_exs = self._wait_for_task( + ex, "get_approval", ac_const.LIVEACTION_STATUS_PAUSED + ) # Respond to the inquiry in the subworkflow. t2_t2_ac_exs = self._wait_for_task( - ac_exs[0], - 'get_approval', - ac_const.LIVEACTION_STATUS_PENDING + ac_exs[0], "get_approval", ac_const.LIVEACTION_STATUS_PENDING ) - self.st2client.inquiries.respond(t2_t2_ac_exs[0].id, {'approved': True}) + self.st2client.inquiries.respond(t2_t2_ac_exs[0].id, {"approved": True}) # Wait for completion. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) diff --git a/st2tests/integration/orquesta/test_wiring_pause_and_resume.py b/st2tests/integration/orquesta/test_wiring_pause_and_resume.py index 52eca1490f8..9779ee26b88 100644 --- a/st2tests/integration/orquesta/test_wiring_pause_and_resume.py +++ b/st2tests/integration/orquesta/test_wiring_pause_and_resume.py @@ -22,7 +22,9 @@ from st2common.constants import action as ac_const -class PauseResumeWiringTest(base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin): +class PauseResumeWiringTest( + base.TestWorkflowExecution, base.WorkflowControlTestCaseMixin +): temp_file_path_x = None temp_file_path_y = None @@ -47,9 +49,9 @@ def test_pause_and_resume(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-pause', params) - self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-pause", params) + self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) # Cancel the workflow before the temp file is deleted. The workflow will be paused # but task1 will still be running to allow for graceful exit. @@ -77,10 +79,10 @@ def test_pause_and_resume_cascade_to_subworkflow(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflow', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflow", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + tk_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the workflow before the temp file is deleted. The workflow will be paused # but task1 will still be running to allow for graceful exit. @@ -113,11 +115,11 @@ def test_pause_and_resume_cascade_to_subworkflows(self): self.assertTrue(os.path.exists(path2)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path1, 'file2': path2} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params) + params = {"file1": path1, "file2": path2} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) - tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) + tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the workflow before the temp files are deleted. The workflow will be paused # but task1 will still be running to allow for graceful exit. @@ -150,8 +152,12 @@ def test_pause_and_resume_cascade_to_subworkflows(self): ex = self.st2client.executions.resume(ex.id) # Wait for completion. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_pause_and_resume_cascade_from_subworkflow(self): @@ -160,10 +166,10 @@ def test_pause_and_resume_cascade_from_subworkflow(self): self.assertTrue(os.path.exists(path)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflow', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflow", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) + tk_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow before the temp file is deleted. The task will be # paused but workflow will still be running. @@ -188,7 +194,9 @@ def test_pause_and_resume_cascade_from_subworkflow(self): tk_ac_ex = self._wait_for_state(tk_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_paused(self): + def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_paused( + self, + ): # Temp files are created during test setup. Ensure the temp files exist. path1 = self.temp_file_path_x self.assertTrue(os.path.exists(path1)) @@ -196,11 +204,11 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_pau self.assertTrue(os.path.exists(path2)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path1, 'file2': path2} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params) + params = {"file1": path1, "file2": path2} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) - tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) + tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow before the temp file is deleted. The task will be # paused but workflow and the other subworkflow will still be running. @@ -228,17 +236,25 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_when_workflow_pau # The workflow will now be paused because no other task is running. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_PAUSED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) # Resume the subworkflow. tk1_ac_ex = self.st2client.executions.resume(tk1_ac_ex.id) # Wait for completion. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_running(self): + def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_running( + self, + ): # Temp files are created during test setup. Ensure the temp files exist. path1 = self.temp_file_path_x self.assertTrue(os.path.exists(path1)) @@ -246,11 +262,11 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_ru self.assertTrue(os.path.exists(path2)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path1, 'file2': path2} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params) + params = {"file1": path1, "file2": path2} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) - tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) + tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow before the temp file is deleted. The task will be # paused but workflow and the other subworkflow will still be running. @@ -276,7 +292,9 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_ru # The subworkflow will succeed while the other subworkflow is still running. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_RUNNING) # Delete the temporary file for the other subworkflow. @@ -284,8 +302,12 @@ def test_pause_from_1_of_2_subworkflows_and_resume_subworkflow_while_workflow_ru self.assertFalse(os.path.exists(path2)) # Wait for completion. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self): @@ -296,11 +318,11 @@ def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self): self.assertTrue(os.path.exists(path2)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path1, 'file2': path2} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params) + params = {"file1": path1, "file2": path2} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) - tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) + tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow before the temp file is deleted. The task will be # paused but workflow and the other subworkflow will still be running. @@ -336,7 +358,9 @@ def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self): tk1_ac_ex = self.st2client.executions.resume(tk1_ac_ex.id) # The subworkflow will succeed while the other subworkflow is still paused. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_PAUSED) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_PAUSED) @@ -344,8 +368,12 @@ def test_pause_from_all_subworkflows_and_resume_from_subworkflows(self): tk2_ac_ex = self.st2client.executions.resume(tk2_ac_ex.id) # Wait for completion. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_pause_from_all_subworkflows_and_resume_from_parent_workflow(self): @@ -356,11 +384,11 @@ def test_pause_from_all_subworkflows_and_resume_from_parent_workflow(self): self.assertTrue(os.path.exists(path2)) # Launch the workflow. The workflow will wait for the temp file to be deleted. - params = {'file1': path1, 'file2': path2} - ex = self._execute_workflow('examples.orquesta-test-pause-subworkflows', params) + params = {"file1": path1, "file2": path2} + ex = self._execute_workflow("examples.orquesta-test-pause-subworkflows", params) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_RUNNING) - tk1_exs = self._wait_for_task(ex, 'task1', ac_const.LIVEACTION_STATUS_RUNNING) - tk2_exs = self._wait_for_task(ex, 'task2', ac_const.LIVEACTION_STATUS_RUNNING) + tk1_exs = self._wait_for_task(ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING) + tk2_exs = self._wait_for_task(ex, "task2", ac_const.LIVEACTION_STATUS_RUNNING) # Pause the subworkflow before the temp file is deleted. The task will be # paused but workflow and the other subworkflow will still be running. @@ -396,6 +424,10 @@ def test_pause_from_all_subworkflows_and_resume_from_parent_workflow(self): ex = self.st2client.executions.resume(ex.id) # Wait for completion. - tk1_ac_ex = self._wait_for_state(tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) - tk2_ac_ex = self._wait_for_state(tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) + tk1_ac_ex = self._wait_for_state( + tk1_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) + tk2_ac_ex = self._wait_for_state( + tk2_ac_ex, ac_const.LIVEACTION_STATUS_SUCCEEDED + ) ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) diff --git a/st2tests/integration/orquesta/test_wiring_rerun.py b/st2tests/integration/orquesta/test_wiring_rerun.py index b7a6de0efe1..2fafee76e40 100644 --- a/st2tests/integration/orquesta/test_wiring_rerun.py +++ b/st2tests/integration/orquesta/test_wiring_rerun.py @@ -43,106 +43,104 @@ def tearDown(self): def test_rerun_workflow(self): path = self.temp_dir_path - with open(path, 'w') as f: - f.write('1') + with open(path, "w") as f: + f.write("1") - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-rerun', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-rerun", params) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED) orig_st2_ex_id = ex.id - orig_wf_ex_id = ex.context['workflow_execution'] + orig_wf_ex_id = ex.context["workflow_execution"] - with open(path, 'w') as f: - f.write('0') + with open(path, "w") as f: + f.write("0") ex = self.st2client.executions.re_run(orig_st2_ex_id) self.assertNotEqual(ex.id, orig_st2_ex_id) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertNotEqual(ex.context['workflow_execution'], orig_wf_ex_id) + self.assertNotEqual(ex.context["workflow_execution"], orig_wf_ex_id) def test_rerun_task(self): path = self.temp_dir_path - with open(path, 'w') as f: - f.write('1') + with open(path, "w") as f: + f.write("1") - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-rerun', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-rerun", params) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED) orig_st2_ex_id = ex.id - orig_wf_ex_id = ex.context['workflow_execution'] + orig_wf_ex_id = ex.context["workflow_execution"] - with open(path, 'w') as f: - f.write('0') + with open(path, "w") as f: + f.write("0") - ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=['task2']) + ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=["task2"]) self.assertNotEqual(ex.id, orig_st2_ex_id) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id) + self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id) def test_rerun_task_of_workflow_already_succeeded(self): path = self.temp_dir_path - with open(path, 'w') as f: - f.write('0') + with open(path, "w") as f: + f.write("0") - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-rerun', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-rerun", params) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) orig_st2_ex_id = ex.id - orig_wf_ex_id = ex.context['workflow_execution'] + orig_wf_ex_id = ex.context["workflow_execution"] - ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=['task2']) + ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=["task2"]) self.assertNotEqual(ex.id, orig_st2_ex_id) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id) + self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id) def test_rerun_and_reset_with_items_task(self): path = self.temp_dir_path - with open(path, 'w') as f: - f.write('1') + with open(path, "w") as f: + f.write("1") - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-rerun-with-items', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-rerun-with-items", params) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED) orig_st2_ex_id = ex.id - orig_wf_ex_id = ex.context['workflow_execution'] + orig_wf_ex_id = ex.context["workflow_execution"] - with open(path, 'w') as f: - f.write('0') + with open(path, "w") as f: + f.write("0") - ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=['task1']) + ex = self.st2client.executions.re_run(orig_st2_ex_id, tasks=["task1"]) self.assertNotEqual(ex.id, orig_st2_ex_id) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id) + self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id) - children = self.st2client.executions.get_property(ex.id, 'children') + children = self.st2client.executions.get_property(ex.id, "children") self.assertEqual(len(children), 4) def test_rerun_and_resume_with_items_task(self): path = self.temp_dir_path - with open(path, 'w') as f: - f.write('1') + with open(path, "w") as f: + f.write("1") - params = {'tempfile': path} - ex = self._execute_workflow('examples.orquesta-test-rerun-with-items', params) + params = {"tempfile": path} + ex = self._execute_workflow("examples.orquesta-test-rerun-with-items", params) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_FAILED) orig_st2_ex_id = ex.id - orig_wf_ex_id = ex.context['workflow_execution'] + orig_wf_ex_id = ex.context["workflow_execution"] - with open(path, 'w') as f: - f.write('0') + with open(path, "w") as f: + f.write("0") ex = self.st2client.executions.re_run( - orig_st2_ex_id, - tasks=['task1'], - no_reset=['task1'] + orig_st2_ex_id, tasks=["task1"], no_reset=["task1"] ) self.assertNotEqual(ex.id, orig_st2_ex_id) ex = self._wait_for_state(ex, action_constants.LIVEACTION_STATUS_SUCCEEDED) - self.assertEqual(ex.context['workflow_execution'], orig_wf_ex_id) + self.assertEqual(ex.context["workflow_execution"], orig_wf_ex_id) - children = self.st2client.executions.get_property(ex.id, 'children') + children = self.st2client.executions.get_property(ex.id, "children") self.assertEqual(len(children), 2) diff --git a/st2tests/integration/orquesta/test_wiring_task_retry.py b/st2tests/integration/orquesta/test_wiring_task_retry.py index c8d3bd18898..7bb7f3f2580 100644 --- a/st2tests/integration/orquesta/test_wiring_task_retry.py +++ b/st2tests/integration/orquesta/test_wiring_task_retry.py @@ -23,9 +23,8 @@ class TaskRetryWiringTest(base.TestWorkflowExecution): - def test_task_retry(self): - wf_name = 'examples.orquesta-task-retry' + wf_name = "examples.orquesta-task-retry" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) @@ -34,14 +33,15 @@ def test_task_retry(self): # Assert there are retries for the task. task_exs = [ - task_ex for task_ex in self._get_children(ex) - if task_ex.context.get('orquesta', {}).get('task_name', '') == 'check' + task_ex + for task_ex in self._get_children(ex) + if task_ex.context.get("orquesta", {}).get("task_name", "") == "check" ] self.assertGreater(len(task_exs), 1) def test_task_retry_exhausted(self): - wf_name = 'examples.orquesta-task-retry-exhausted' + wf_name = "examples.orquesta-task-retry-exhausted" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) @@ -51,16 +51,18 @@ def test_task_retry_exhausted(self): # Assert the task has exhausted the number of retries task_exs = [ - task_ex for task_ex in self._get_children(ex) - if task_ex.context.get('orquesta', {}).get('task_name', '') == 'check' + task_ex + for task_ex in self._get_children(ex) + if task_ex.context.get("orquesta", {}).get("task_name", "") == "check" ] - self.assertListEqual(['failed'] * 3, [task_ex.status for task_ex in task_exs]) + self.assertListEqual(["failed"] * 3, [task_ex.status for task_ex in task_exs]) # Assert the task following the retry task is not run. task_exs = [ - task_ex for task_ex in self._get_children(ex) - if task_ex.context.get('orquesta', {}).get('task_name', '') == 'delete' + task_ex + for task_ex in self._get_children(ex) + if task_ex.context.get("orquesta", {}).get("task_name", "") == "delete" ] self.assertEqual(len(task_exs), 0) diff --git a/st2tests/integration/orquesta/test_wiring_with_items.py b/st2tests/integration/orquesta/test_wiring_with_items.py index b80e04e7025..0bf83f1bf1a 100644 --- a/st2tests/integration/orquesta/test_wiring_with_items.py +++ b/st2tests/integration/orquesta/test_wiring_with_items.py @@ -40,14 +40,14 @@ def tearDown(self): super(WithItemsWiringTest, self).tearDown() def test_with_items(self): - wf_name = 'examples.orquesta-with-items' + wf_name = "examples.orquesta-with-items" - members = ['Lakshmi', 'Lindsay', 'Tomaz', 'Matt', 'Drew'] - wf_input = {'members': members} + members = ["Lakshmi", "Lindsay", "Tomaz", "Matt", "Drew"] + wf_input = {"members": members} - message = '%s, resistance is futile!' - expected_output = {'items': [message % i for i in members]} - expected_result = {'output': expected_output} + message = "%s, resistance is futile!" + expected_output = {"items": [message % i for i in members]} + expected_result = {"output": expected_output} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_completion(ex) @@ -56,17 +56,17 @@ def test_with_items(self): self.assertDictEqual(ex.result, expected_result) def test_with_items_failure(self): - wf_name = 'examples.orquesta-test-with-items-failure' + wf_name = "examples.orquesta-test-with-items-failure" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) - self._wait_for_task(ex, 'task1', num_task_exs=10) + self._wait_for_task(ex, "task1", num_task_exs=10) self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_FAILED) def test_with_items_concurrency(self): - wf_name = 'examples.orquesta-test-with-items' + wf_name = "examples.orquesta-test-with-items" concurrency = 2 num_items = 5 @@ -74,22 +74,22 @@ def test_with_items_concurrency(self): for i in range(0, num_items): _, f = tempfile.mkstemp() - os.chmod(f, 0o755) # nosec + os.chmod(f, 0o755) # nosec self.tempfiles.append(f) - wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency} + wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING]) - self._wait_for_task(ex, 'task1', num_task_exs=2) + self._wait_for_task(ex, "task1", num_task_exs=2) os.remove(self.tempfiles[0]) os.remove(self.tempfiles[1]) - self._wait_for_task(ex, 'task1', num_task_exs=4) + self._wait_for_task(ex, "task1", num_task_exs=4) os.remove(self.tempfiles[2]) os.remove(self.tempfiles[3]) - self._wait_for_task(ex, 'task1', num_task_exs=5) + self._wait_for_task(ex, "task1", num_task_exs=5) os.remove(self.tempfiles[4]) ex = self._wait_for_completion(ex) @@ -97,7 +97,7 @@ def test_with_items_concurrency(self): self.assertEqual(ex.status, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_with_items_cancellation(self): - wf_name = 'examples.orquesta-test-with-items' + wf_name = "examples.orquesta-test-with-items" concurrency = 2 num_items = 2 @@ -105,19 +105,16 @@ def test_with_items_cancellation(self): for i in range(0, num_items): _, f = tempfile.mkstemp() - os.chmod(f, 0o755) # nosec + os.chmod(f, 0o755) # nosec self.tempfiles.append(f) - wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency} + wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING]) # Wait for action executions to run. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_RUNNING, - num_task_exs=concurrency + ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING, num_task_exs=concurrency ) # Cancel the workflow execution. @@ -133,17 +130,14 @@ def test_with_items_cancellation(self): # Task is completed successfully for graceful exit. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_SUCCEEDED, - num_task_exs=concurrency + ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=concurrency ) # Wait for the ex to be canceled. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) def test_with_items_concurrency_cancellation(self): - wf_name = 'examples.orquesta-test-with-items' + wf_name = "examples.orquesta-test-with-items" concurrency = 2 num_items = 4 @@ -151,19 +145,16 @@ def test_with_items_concurrency_cancellation(self): for i in range(0, num_items): _, f = tempfile.mkstemp() - os.chmod(f, 0o755) # nosec + os.chmod(f, 0o755) # nosec self.tempfiles.append(f) - wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency} + wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING]) # Wait for action executions to run. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_RUNNING, - num_task_exs=concurrency + ex, "task1", ac_const.LIVEACTION_STATUS_RUNNING, num_task_exs=concurrency ) # Cancel the workflow execution. @@ -180,27 +171,24 @@ def test_with_items_concurrency_cancellation(self): # Task is completed successfully for graceful exit. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_SUCCEEDED, - num_task_exs=concurrency + ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=concurrency ) # Wait for the ex to be canceled. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_CANCELED) def test_with_items_pause_and_resume(self): - wf_name = 'examples.orquesta-test-with-items' + wf_name = "examples.orquesta-test-with-items" num_items = 2 self.tempfiles = [] for i in range(0, num_items): _, f = tempfile.mkstemp() - os.chmod(f, 0o755) # nosec + os.chmod(f, 0o755) # nosec self.tempfiles.append(f) - wf_input = {'tempfiles': self.tempfiles} + wf_input = {"tempfiles": self.tempfiles} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING]) @@ -217,10 +205,7 @@ def test_with_items_pause_and_resume(self): # Wait for action executions for task to succeed. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_SUCCEEDED, - num_task_exs=num_items + ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=num_items ) # Wait for the workflow execution to pause. @@ -233,7 +218,7 @@ def test_with_items_pause_and_resume(self): ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_with_items_concurrency_pause_and_resume(self): - wf_name = 'examples.orquesta-test-with-items' + wf_name = "examples.orquesta-test-with-items" concurrency = 2 num_items = 4 @@ -241,10 +226,10 @@ def test_with_items_concurrency_pause_and_resume(self): for i in range(0, num_items): _, f = tempfile.mkstemp() - os.chmod(f, 0o755) # nosec + os.chmod(f, 0o755) # nosec self.tempfiles.append(f) - wf_input = {'tempfiles': self.tempfiles, 'concurrency': concurrency} + wf_input = {"tempfiles": self.tempfiles, "concurrency": concurrency} ex = self._execute_workflow(wf_name, wf_input) ex = self._wait_for_state(ex, [ac_const.LIVEACTION_STATUS_RUNNING]) @@ -261,10 +246,7 @@ def test_with_items_concurrency_pause_and_resume(self): # Wait for action executions for task to succeed. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_SUCCEEDED, - num_task_exs=concurrency + ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=concurrency ) # Wait for the workflow execution to pause. @@ -280,17 +262,14 @@ def test_with_items_concurrency_pause_and_resume(self): # Wait for action executions for task to succeed. self._wait_for_task( - ex, - 'task1', - ac_const.LIVEACTION_STATUS_SUCCEEDED, - num_task_exs=num_items + ex, "task1", ac_const.LIVEACTION_STATUS_SUCCEEDED, num_task_exs=num_items ) # Wait for completion. ex = self._wait_for_state(ex, ac_const.LIVEACTION_STATUS_SUCCEEDED) def test_subworkflow_empty_with_items(self): - wf_name = 'examples.orquesta-test-subworkflow-empty-with-items' + wf_name = "examples.orquesta-test-subworkflow-empty-with-items" ex = self._execute_workflow(wf_name) ex = self._wait_for_completion(ex) diff --git a/st2tests/setup.py b/st2tests/setup.py index 3d5947be042..f5e17bb3a3a 100644 --- a/st2tests/setup.py +++ b/st2tests/setup.py @@ -23,10 +23,10 @@ from dist_utils import apply_vagrant_workaround from dist_utils import get_version_string -ST2_COMPONENT = 'st2tests' +ST2_COMPONENT = "st2tests" BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -REQUIREMENTS_FILE = os.path.join(BASE_DIR, 'requirements.txt') -INIT_FILE = os.path.join(BASE_DIR, 'st2tests/__init__.py') +REQUIREMENTS_FILE = os.path.join(BASE_DIR, "requirements.txt") +INIT_FILE = os.path.join(BASE_DIR, "st2tests/__init__.py") install_reqs, dep_links = fetch_requirements(REQUIREMENTS_FILE) @@ -39,15 +39,17 @@ setup( name=ST2_COMPONENT, version=get_version_string(INIT_FILE), - description='{} StackStorm event-driven automation platform component'.format(ST2_COMPONENT), - author='StackStorm', - author_email='info@stackstorm.com', - license='Apache License (2.0)', - url='https://stackstorm.com/', + description="{} StackStorm event-driven automation platform component".format( + ST2_COMPONENT + ), + author="StackStorm", + author_email="info@stackstorm.com", + license="Apache License (2.0)", + url="https://stackstorm.com/", install_requires=install_reqs, dependency_links=dep_links, test_suite=ST2_COMPONENT, zip_safe=False, include_package_data=True, - packages=find_packages(exclude=['setuptools', 'tests']) + packages=find_packages(exclude=["setuptools", "tests"]), ) diff --git a/st2tests/st2tests/__init__.py b/st2tests/st2tests/__init__.py index 594f0e2ae1e..d087d05d2d8 100644 --- a/st2tests/st2tests/__init__.py +++ b/st2tests/st2tests/__init__.py @@ -23,11 +23,11 @@ __all__ = [ - 'EventletTestCase', - 'DbTestCase', - 'ExecutionDbTestCase', - 'DbModelTestCase', - 'WorkflowTestCase' + "EventletTestCase", + "DbTestCase", + "ExecutionDbTestCase", + "DbModelTestCase", + "WorkflowTestCase", ] -__version__ = '3.4dev' +__version__ = "3.4dev" diff --git a/st2tests/st2tests/action_aliases.py b/st2tests/st2tests/action_aliases.py index 301fd9a20f0..88f02f96429 100644 --- a/st2tests/st2tests/action_aliases.py +++ b/st2tests/st2tests/action_aliases.py @@ -25,13 +25,13 @@ from st2common.util.pack import get_pack_ref_from_metadata from st2common.exceptions.content import ParseException from st2common.bootstrap.aliasesregistrar import AliasesRegistrar -from st2common.models.utils.action_alias_utils import extract_parameters_for_action_alias_db +from st2common.models.utils.action_alias_utils import ( + extract_parameters_for_action_alias_db, +) from st2common.models.utils.action_alias_utils import extract_parameters from st2tests.pack_resource import BasePackResourceTestCase -__all__ = [ - 'BaseActionAliasTestCase' -] +__all__ = ["BaseActionAliasTestCase"] class BaseActionAliasTestCase(BasePackResourceTestCase): @@ -48,7 +48,9 @@ def setUp(self): if not self.action_alias_name: raise ValueError('"action_alias_name" class attribute needs to be provided') - self.action_alias_db = self._get_action_alias_db_by_name(name=self.action_alias_name) + self.action_alias_db = self._get_action_alias_db_by_name( + name=self.action_alias_name + ) def assertCommandMatchesExactlyOneFormatString(self, format_strings, command): """ @@ -58,19 +60,22 @@ def assertCommandMatchesExactlyOneFormatString(self, format_strings, command): for format_string in format_strings: try: - extract_parameters(format_str=format_string, - param_stream=command) + extract_parameters(format_str=format_string, param_stream=command) except ParseException: continue matched_format_strings.append(format_string) if len(matched_format_strings) == 0: - msg = ('Command "%s" didn\'t match any of the provided format strings' % (command)) + msg = 'Command "%s" didn\'t match any of the provided format strings' % ( + command + ) raise AssertionError(msg) elif len(matched_format_strings) > 1: - msg = ('Command "%s" matched multiple format strings: %s' % - (command, ', '.join(matched_format_strings))) + msg = 'Command "%s" matched multiple format strings: %s' % ( + command, + ", ".join(matched_format_strings), + ) raise AssertionError(msg) def assertExtractedParametersMatch(self, format_string, command, parameters): @@ -83,11 +88,14 @@ def assertExtractedParametersMatch(self, format_string, command, parameters): extracted_params = extract_parameters_for_action_alias_db( action_alias_db=self.action_alias_db, format_str=format_string, - param_stream=command) + param_stream=command, + ) if extracted_params != parameters: - msg = ('Extracted parameters from command string "%s" against format string "%s"' - ' didn\'t match the provided parameters: ' % (command, format_string)) + msg = ( + 'Extracted parameters from command string "%s" against format string "%s"' + " didn't match the provided parameters: " % (command, format_string) + ) # Note: We intercept the exception so we can can include diff for the dictionaries try: @@ -117,13 +125,14 @@ def _get_action_alias_db_by_name(self, name): pack_loader = ContentPackLoader() registrar = AliasesRegistrar(use_pack_cache=False) - aliases_path = pack_loader.get_content_from_pack(pack_dir=base_pack_path, - content_type='aliases') + aliases_path = pack_loader.get_content_from_pack( + pack_dir=base_pack_path, content_type="aliases" + ) aliases = registrar._get_aliases_from_pack(aliases_dir=aliases_path) for alias_path in aliases: - action_alias_db = registrar._get_action_alias_db(pack=pack, - action_alias=alias_path, - ignore_metadata_file_error=True) + action_alias_db = registrar._get_action_alias_db( + pack=pack, action_alias=alias_path, ignore_metadata_file_error=True + ) if action_alias_db.name == name: return action_alias_db diff --git a/st2tests/st2tests/actions.py b/st2tests/st2tests/actions.py index f6026bc8bd5..9caec9bca92 100644 --- a/st2tests/st2tests/actions.py +++ b/st2tests/st2tests/actions.py @@ -19,9 +19,7 @@ from st2tests.mocks.action import MockActionService from st2tests.pack_resource import BasePackResourceTestCase -__all__ = [ - 'BaseActionTestCase' -] +__all__ = ["BaseActionTestCase"] class BaseActionTestCase(BasePackResourceTestCase): @@ -35,7 +33,7 @@ def setUp(self): super(BaseActionTestCase, self).setUp() class_name = self.action_cls.__name__ - action_wrapper = MockActionWrapper(pack='tests', class_name=class_name) + action_wrapper = MockActionWrapper(pack="tests", class_name=class_name) self.action_service = MockActionService(action_wrapper=action_wrapper) def get_action_instance(self, config=None): @@ -43,7 +41,9 @@ def get_action_instance(self, config=None): Retrieve instance of the action class. """ # pylint: disable=not-callable - instance = get_action_class_instance(action_cls=self.action_cls, - config=config, - action_service=self.action_service) + instance = get_action_class_instance( + action_cls=self.action_cls, + config=config, + action_service=self.action_service, + ) return instance diff --git a/st2tests/st2tests/api.py b/st2tests/st2tests/api.py index 3b48df737aa..7000ddd9a11 100644 --- a/st2tests/st2tests/api.py +++ b/st2tests/st2tests/api.py @@ -34,19 +34,19 @@ from st2tests import config as tests_config __all__ = [ - 'BaseFunctionalTest', - - 'FunctionalTest', - 'APIControllerWithIncludeAndExcludeFilterTestCase', - 'BaseInquiryControllerTestCase', - - 'FakeResponse', - 'TestApp' + "BaseFunctionalTest", + "FunctionalTest", + "APIControllerWithIncludeAndExcludeFilterTestCase", + "BaseInquiryControllerTestCase", + "FakeResponse", + "TestApp", ] -SUPER_SECRET_PARAMETER = 'SUPER_SECRET_PARAMETER_THAT_SHOULD_NEVER_APPEAR_IN_RESPONSES_OR_LOGS' -ANOTHER_SUPER_SECRET_PARAMETER = 'ANOTHER_SUPER_SECRET_PARAMETER_TO_TEST_OVERRIDING' +SUPER_SECRET_PARAMETER = ( + "SUPER_SECRET_PARAMETER_THAT_SHOULD_NEVER_APPEAR_IN_RESPONSES_OR_LOGS" +) +ANOTHER_SUPER_SECRET_PARAMETER = "ANOTHER_SUPER_SECRET_PARAMETER_TO_TEST_OVERRIDING" class ResponseValidationError(ValueError): @@ -61,32 +61,37 @@ class TestApp(webtest.TestApp): def do_request(self, req, **kwargs): self.cookiejar.clear() - if req.environ['REQUEST_METHOD'] != 'OPTIONS': + if req.environ["REQUEST_METHOD"] != "OPTIONS": # Making sure endpoint handles OPTIONS method properly - self.options(req.environ['PATH_INFO']) + self.options(req.environ["PATH_INFO"]) res = super(TestApp, self).do_request(req, **kwargs) - if res.headers.get('Warning', None): - raise ResponseValidationError('Endpoint produced invalid response. Make sure the ' - 'response matches OpenAPI scheme for the endpoint.') + if res.headers.get("Warning", None): + raise ResponseValidationError( + "Endpoint produced invalid response. Make sure the " + "response matches OpenAPI scheme for the endpoint." + ) - if not kwargs.get('expect_errors', None): + if not kwargs.get("expect_errors", None): try: body = res.body except AssertionError as e: - if 'Iterator read after closed' in six.text_type(e): - body = b'' + if "Iterator read after closed" in six.text_type(e): + body = b"" else: raise e - if six.b(SUPER_SECRET_PARAMETER) in body or \ - six.b(ANOTHER_SUPER_SECRET_PARAMETER) in body: - raise ResponseLeakError('Endpoint response contains secret parameter. ' - 'Find the leak.') + if ( + six.b(SUPER_SECRET_PARAMETER) in body + or six.b(ANOTHER_SUPER_SECRET_PARAMETER) in body + ): + raise ResponseLeakError( + "Endpoint response contains secret parameter. " "Find the leak." + ) - if 'Access-Control-Allow-Origin' not in res.headers: - raise ResponseValidationError('Response missing a required CORS header') + if "Access-Control-Allow-Origin" not in res.headers: + raise ResponseValidationError("Response missing a required CORS header") return res @@ -113,19 +118,19 @@ def tearDown(self): super(BaseFunctionalTest, self).tearDown() # Reset mock context for API requests - if getattr(self, 'request_context_mock', None): + if getattr(self, "request_context_mock", None): self.request_context_mock.stop() - if hasattr(Router, 'mock_context'): - del(Router.mock_context) + if hasattr(Router, "mock_context"): + del Router.mock_context @classmethod def _do_setUpClass(cls): tests_config.parse_args() - cfg.CONF.set_default('enable', cls.enable_auth, group='auth') + cfg.CONF.set_default("enable", cls.enable_auth, group="auth") - cfg.CONF.set_override(name='enable', override=False, group='rbac') + cfg.CONF.set_override(name="enable", override=False, group="rbac") # TODO(manas) : register action types here for now. RunnerType registration can be moved # to posting to /runnertypes but that implies implementing POST. @@ -142,11 +147,8 @@ def use_user(self, user_db): raise ValueError('"user_db" is mandatory') mock_context = { - 'user': user_db, - 'auth_info': { - 'method': 'authentication token', - 'location': 'header' - } + "user": user_db, + "auth_info": {"method": "authentication token", "location": "header"}, } self.request_context_mock = mock.PropertyMock(return_value=mock_context) Router.mock_context = self.request_context_mock @@ -184,40 +186,48 @@ class APIControllerWithIncludeAndExcludeFilterTestCase(object): # True if those tests are running with rbac enabled rbac_enabled = False - def test_get_all_exclude_attributes_and_include_attributes_are_mutually_exclusive(self): + def test_get_all_exclude_attributes_and_include_attributes_are_mutually_exclusive( + self, + ): if self.rbac_enabled: - self.use_user(self.users['admin']) + self.use_user(self.users["admin"]) - url = self.get_all_path + '?include_attributes=id&exclude_attributes=id' + url = self.get_all_path + "?include_attributes=id&exclude_attributes=id" resp = self.app.get(url, expect_errors=True) self.assertEqual(resp.status_int, 400) - expected_msg = ('exclude.*? and include.*? arguments are mutually exclusive. ' - 'You need to provide either one or another, but not both.') - self.assertRegexpMatches(resp.json['faultstring'], expected_msg) + expected_msg = ( + "exclude.*? and include.*? arguments are mutually exclusive. " + "You need to provide either one or another, but not both." + ) + self.assertRegexpMatches(resp.json["faultstring"], expected_msg) def test_get_all_invalid_exclude_and_include_parameter(self): if self.rbac_enabled: - self.use_user(self.users['admin']) + self.use_user(self.users["admin"]) # 1. Invalid exclude_attributes field - url = self.get_all_path + '?exclude_attributes=invalid_field' + url = self.get_all_path + "?exclude_attributes=invalid_field" resp = self.app.get(url, expect_errors=True) - expected_msg = ('Invalid or unsupported exclude attribute specified: .*invalid_field.*') + expected_msg = ( + "Invalid or unsupported exclude attribute specified: .*invalid_field.*" + ) self.assertEqual(resp.status_int, 400) - self.assertRegexpMatches(resp.json['faultstring'], expected_msg) + self.assertRegexpMatches(resp.json["faultstring"], expected_msg) # 2. Invalid include_attributes field - url = self.get_all_path + '?include_attributes=invalid_field' + url = self.get_all_path + "?include_attributes=invalid_field" resp = self.app.get(url, expect_errors=True) - expected_msg = ('Invalid or unsupported include attribute specified: .*invalid_field.*') + expected_msg = ( + "Invalid or unsupported include attribute specified: .*invalid_field.*" + ) self.assertEqual(resp.status_int, 400) - self.assertRegexpMatches(resp.json['faultstring'], expected_msg) + self.assertRegexpMatches(resp.json["faultstring"], expected_msg) def test_get_all_include_attributes_filter(self): if self.rbac_enabled: - self.use_user(self.users['admin']) + self.use_user(self.users["admin"]) mandatory_include_fields = self.controller_cls.mandatory_include_fields_response @@ -226,8 +236,10 @@ def test_get_all_include_attributes_filter(self): object_ids = self._insert_mock_models() # Valid include attribute - mandatory field which should always be included - resp = self.app.get('%s?include_attributes=%s' % (self.get_all_path, - mandatory_include_fields[0])) + resp = self.app.get( + "%s?include_attributes=%s" + % (self.get_all_path, mandatory_include_fields[0]) + ) self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) >= 1) @@ -245,7 +257,9 @@ def test_get_all_include_attributes_filter(self): include_field = self.include_attribute_field_name assert include_field not in mandatory_include_fields - resp = self.app.get('%s?include_attributes=%s' % (self.get_all_path, include_field)) + resp = self.app.get( + "%s?include_attributes=%s" % (self.get_all_path, include_field) + ) self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) >= 1) @@ -263,7 +277,7 @@ def test_get_all_include_attributes_filter(self): def test_get_all_exclude_attributes_filter(self): if self.rbac_enabled: - self.use_user(self.users['admin']) + self.use_user(self.users["admin"]) # Create any resources needed by those tests (if not already created inside setUp / # setUpClass) @@ -285,8 +299,9 @@ def test_get_all_exclude_attributes_filter(self): # 2. Verify attribute is excluded when filter is provided exclude_attribute = self.exclude_attribute_field_name - resp = self.app.get('%s?exclude_attributes=%s' % (self.get_all_path, - exclude_attribute)) + resp = self.app.get( + "%s?exclude_attributes=%s" % (self.get_all_path, exclude_attribute) + ) self.assertEqual(resp.status_int, 200) self.assertTrue(len(resp.json) >= 1) @@ -300,8 +315,8 @@ def test_get_all_exclude_attributes_filter(self): def assertResponseObjectContainsField(self, resp_item, field): # Handle "." and nested fields - if '.' in field: - split = field.split('.') + if "." in field: + split = field.split(".") for index, field_part in enumerate(split): self.assertIn(field_part, resp_item) @@ -336,7 +351,6 @@ def _do_delete(self, object_id): class FakeResponse(object): - def __init__(self, text, status_code, reason): self.text = text self.status_code = status_code @@ -354,24 +368,27 @@ class BaseActionExecutionControllerTestCase(object): @staticmethod def _get_actionexecution_id(resp): - return resp.json['id'] + return resp.json["id"] @staticmethod def _get_liveaction_id(resp): - return resp.json['liveaction']['id'] + return resp.json["liveaction"]["id"] def _do_get_one(self, actionexecution_id, *args, **kwargs): - return self.app.get('/v1/executions/%s' % actionexecution_id, *args, **kwargs) + return self.app.get("/v1/executions/%s" % actionexecution_id, *args, **kwargs) def _do_post(self, liveaction, *args, **kwargs): - return self.app.post_json('/v1/executions', liveaction, *args, **kwargs) + return self.app.post_json("/v1/executions", liveaction, *args, **kwargs) def _do_delete(self, actionexecution_id, expect_errors=False): - return self.app.delete('/v1/executions/%s' % actionexecution_id, - expect_errors=expect_errors) + return self.app.delete( + "/v1/executions/%s" % actionexecution_id, expect_errors=expect_errors + ) def _do_put(self, actionexecution_id, updates, *args, **kwargs): - return self.app.put_json('/v1/executions/%s' % actionexecution_id, updates, *args, **kwargs) + return self.app.put_json( + "/v1/executions/%s" % actionexecution_id, updates, *args, **kwargs + ) class BaseInquiryControllerTestCase(BaseFunctionalTest, CleanDbTestCase): @@ -380,6 +397,7 @@ class BaseInquiryControllerTestCase(BaseFunctionalTest, CleanDbTestCase): Inherits from CleanDbTestCase to preserve atomicity between tests """ + from st2api import app enable_auth = False @@ -387,26 +405,27 @@ class BaseInquiryControllerTestCase(BaseFunctionalTest, CleanDbTestCase): @staticmethod def _get_inquiry_id(resp): - return resp.json['id'] + return resp.json["id"] def _do_get_execution(self, actionexecution_id, *args, **kwargs): - return self.app.get('/v1/executions/%s' % actionexecution_id, *args, **kwargs) + return self.app.get("/v1/executions/%s" % actionexecution_id, *args, **kwargs) def _do_get_one(self, inquiry_id, *args, **kwargs): - return self.app.get('/v1/inquiries/%s' % inquiry_id, *args, **kwargs) + return self.app.get("/v1/inquiries/%s" % inquiry_id, *args, **kwargs) def _do_get_all(self, limit=50, *args, **kwargs): - return self.app.get('/v1/inquiries/?limit=%s' % limit, *args, **kwargs) + return self.app.get("/v1/inquiries/?limit=%s" % limit, *args, **kwargs) def _do_respond(self, inquiry_id, response, *args, **kwargs): - payload = { - "id": inquiry_id, - "response": response - } - return self.app.put_json('/v1/inquiries/%s' % inquiry_id, payload, *args, **kwargs) + payload = {"id": inquiry_id, "response": response} + return self.app.put_json( + "/v1/inquiries/%s" % inquiry_id, payload, *args, **kwargs + ) - def _do_create_inquiry(self, liveaction, result, status='pending', *args, **kwargs): - post_resp = self.app.post_json('/v1/executions', liveaction, *args, **kwargs) + def _do_create_inquiry(self, liveaction, result, status="pending", *args, **kwargs): + post_resp = self.app.post_json("/v1/executions", liveaction, *args, **kwargs) inquiry_id = self._get_inquiry_id(post_resp) - updates = {'status': status, 'result': result} - return self.app.put_json('/v1/executions/%s' % inquiry_id, updates, *args, **kwargs) + updates = {"status": status, "result": result} + return self.app.put_json( + "/v1/executions/%s" % inquiry_id, updates, *args, **kwargs + ) diff --git a/st2tests/st2tests/base.py b/st2tests/st2tests/base.py index 75a8f7ce025..4a4964763d5 100644 --- a/st2tests/st2tests/base.py +++ b/st2tests/st2tests/base.py @@ -19,6 +19,7 @@ # NOTE: We need to perform monkeypatch before importing ssl module otherwise tests will fail. # See https://github.com/StackStorm/st2/pull/4834 for details from st2common.util.monkey_patch import monkey_patch + monkey_patch() try: @@ -50,6 +51,7 @@ # parse_args when BaseDbTestCase runs class setup. If that is removed, unit tests # will failed due to conflict with duplicate DB keys. import st2tests.config as tests_config + tests_config.parse_args() from st2common.util.api import get_full_public_api_url @@ -95,26 +97,23 @@ __all__ = [ - 'EventletTestCase', - 'DbTestCase', - 'DbModelTestCase', - 'CleanDbTestCase', - 'CleanFilesTestCase', - 'IntegrationTestCase', - 'RunnerTestCase', - 'ExecutionDbTestCase', - 'WorkflowTestCase', - + "EventletTestCase", + "DbTestCase", + "DbModelTestCase", + "CleanDbTestCase", + "CleanFilesTestCase", + "IntegrationTestCase", + "RunnerTestCase", + "ExecutionDbTestCase", + "WorkflowTestCase", # Pack test classes - 'BaseSensorTestCase', - 'BaseActionTestCase', - 'BaseActionAliasTestCase', - - 'get_fixtures_path', - 'get_resources_path', - - 'blocking_eventlet_spawn', - 'make_mock_stream_readline' + "BaseSensorTestCase", + "BaseActionTestCase", + "BaseActionAliasTestCase", + "get_fixtures_path", + "get_resources_path", + "blocking_eventlet_spawn", + "make_mock_stream_readline", ] BASE_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -135,7 +134,7 @@ ALL_MODELS.extend(rule_enforcement_model.MODELS) BASE_DIR = os.path.dirname(os.path.abspath(__file__)) -TESTS_CONFIG_PATH = os.path.join(BASE_DIR, '../conf/st2.conf') +TESTS_CONFIG_PATH = os.path.join(BASE_DIR, "../conf/st2.conf") class RunnerTestCase(unittest2.TestCase): @@ -148,17 +147,15 @@ def assertCommonSt2EnvVarsAvailableInEnv(self, env): """ for var_name in COMMON_ACTION_ENV_VARIABLES: self.assertIn(var_name, env) - self.assertEqual(env['ST2_ACTION_API_URL'], get_full_public_api_url()) + self.assertEqual(env["ST2_ACTION_API_URL"], get_full_public_api_url()) self.assertIsNotNone(env[AUTH_TOKEN_ENV_VARIABLE_NAME]) def loader(self, path): - """ Load the runner config - """ + """Load the runner config""" return self.meta_loader.load(path) class BaseTestCase(TestCase): - @classmethod def _register_packs(self): """ @@ -173,7 +170,9 @@ def _register_pack_configs(self, validate_configs=False): """ Register all the packs inside the fixtures directory. """ - registrar = ConfigsRegistrar(use_pack_cache=False, validate_configs=validate_configs) + registrar = ConfigsRegistrar( + use_pack_cache=False, validate_configs=validate_configs + ) registrar.register_from_packs(base_dirs=get_packs_base_paths()) @@ -189,18 +188,14 @@ def setUpClass(cls): os=True, select=True, socket=True, - thread=False if '--use-debugger' in sys.argv else True, - time=True + thread=False if "--use-debugger" in sys.argv else True, + time=True, ) @classmethod def tearDownClass(cls): eventlet.monkey_patch( - os=False, - select=False, - socket=False, - thread=False, - time=False + os=False, select=False, socket=False, thread=False, time=False ) @@ -222,17 +217,29 @@ def setUpClass(cls): tests_config.parse_args() if cls.DISPLAY_LOG_MESSAGES: - config_path = os.path.join(BASE_DIR, '../conf/logging.conf') - logging.config.fileConfig(config_path, - disable_existing_loggers=False) + config_path = os.path.join(BASE_DIR, "../conf/logging.conf") + logging.config.fileConfig(config_path, disable_existing_loggers=False) @classmethod def _establish_connection_and_re_create_db(cls): - username = cfg.CONF.database.username if hasattr(cfg.CONF.database, 'username') else None - password = cfg.CONF.database.password if hasattr(cfg.CONF.database, 'password') else None + username = ( + cfg.CONF.database.username + if hasattr(cfg.CONF.database, "username") + else None + ) + password = ( + cfg.CONF.database.password + if hasattr(cfg.CONF.database, "password") + else None + ) cls.db_connection = db_setup( - cfg.CONF.database.db_name, cfg.CONF.database.host, cfg.CONF.database.port, - username=username, password=password, ensure_indexes=False) + cfg.CONF.database.db_name, + cfg.CONF.database.host, + cfg.CONF.database.port, + username=username, + password=password, + ensure_indexes=False, + ) cls._drop_collections() cls.db_connection.drop_database(cfg.CONF.database.db_name) @@ -242,12 +249,17 @@ def _establish_connection_and_re_create_db(cls): # NOTE: This is only needed in distributed scenarios (production deployments) where # multiple services can start up at the same time and race conditions are possible. if cls.ensure_indexes: - if len(cls.ensure_indexes_models) == 0 or len(cls.ensure_indexes_models) > 1: - msg = ('Ensuring indexes for all the models, this could significantly slow down ' - 'the tests') - print('#' * len(msg), file=sys.stderr) + if ( + len(cls.ensure_indexes_models) == 0 + or len(cls.ensure_indexes_models) > 1 + ): + msg = ( + "Ensuring indexes for all the models, this could significantly slow down " + "the tests" + ) + print("#" * len(msg), file=sys.stderr) print(msg, file=sys.stderr) - print('#' * len(msg), file=sys.stderr) + print("#" * len(msg), file=sys.stderr) db_ensure_indexes(cls.ensure_indexes_models) @@ -319,19 +331,19 @@ def run(self, result=None): class ExecutionDbTestCase(DbTestCase): - """" + """ " Base test class for tests which test various execution related code paths. This class offers some utility methods for waiting on execution status, etc. """ ensure_indexes = True - ensure_indexes_models = [ - ActionExecutionSchedulingQueueItemDB - ] + ensure_indexes_models = [ActionExecutionSchedulingQueueItemDB] - def _wait_on_status(self, liveaction_db, status, retries=300, delay=0.1, raise_exc=True): - assert isinstance(status, six.string_types), '%s is not of text type' % (status) + def _wait_on_status( + self, liveaction_db, status, retries=300, delay=0.1, raise_exc=True + ): + assert isinstance(status, six.string_types), "%s is not of text type" % (status) for _ in range(0, retries): eventlet.sleep(delay) @@ -344,8 +356,12 @@ def _wait_on_status(self, liveaction_db, status, retries=300, delay=0.1, raise_e return liveaction_db - def _wait_on_statuses(self, liveaction_db, statuses, retries=300, delay=0.1, raise_exc=True): - assert isinstance(statuses, (list, tuple)), '%s is not of list type' % (statuses) + def _wait_on_statuses( + self, liveaction_db, statuses, retries=300, delay=0.1, raise_exc=True + ): + assert isinstance(statuses, (list, tuple)), "%s is not of list type" % ( + statuses + ) for _ in range(0, retries): eventlet.sleep(delay) @@ -358,7 +374,9 @@ def _wait_on_statuses(self, liveaction_db, statuses, retries=300, delay=0.1, rai return liveaction_db - def _wait_on_ac_ex_status(self, execution_db, status, retries=300, delay=0.1, raise_exc=True): + def _wait_on_ac_ex_status( + self, execution_db, status, retries=300, delay=0.1, raise_exc=True + ): for _ in range(0, retries): eventlet.sleep(delay) execution_db = ex_db_access.ActionExecution.get_by_id(str(execution_db.id)) @@ -370,7 +388,9 @@ def _wait_on_ac_ex_status(self, execution_db, status, retries=300, delay=0.1, ra return execution_db - def _wait_on_call_count(self, mocked, expected_count, retries=100, delay=0.1, raise_exc=True): + def _wait_on_call_count( + self, mocked, expected_count, retries=100, delay=0.1, raise_exc=True + ): for _ in range(0, retries): eventlet.sleep(delay) if mocked.call_count == expected_count: @@ -395,12 +415,14 @@ def setUpClass(cls): def _assert_fields_equal(self, a, b, exclude=None): exclude = exclude or [] - fields = {k: v for k, v in six.iteritems(self.db_type._fields) if k not in exclude} + fields = { + k: v for k, v in six.iteritems(self.db_type._fields) if k not in exclude + } assert_funcs = { - 'mongoengine.fields.DictField': self.assertDictEqual, - 'mongoengine.fields.ListField': self.assertListEqual, - 'mongoengine.fields.SortedListField': self.assertListEqual + "mongoengine.fields.DictField": self.assertDictEqual, + "mongoengine.fields.ListField": self.assertListEqual, + "mongoengine.fields.SortedListField": self.assertListEqual, } for k, v in six.iteritems(fields): @@ -410,10 +432,7 @@ def _assert_fields_equal(self, a, b, exclude=None): def _assert_values_equal(self, a, values=None): values = values or {} - assert_funcs = { - 'dict': self.assertDictEqual, - 'list': self.assertListEqual - } + assert_funcs = {"dict": self.assertDictEqual, "list": self.assertListEqual} for k, v in six.iteritems(values): assert_func = assert_funcs.get(type(v).__name__, self.assertEqual) @@ -421,7 +440,7 @@ def _assert_values_equal(self, a, values=None): def _assert_crud(self, instance, defaults=None, updates=None): # Assert instance is not already in the database. - self.assertIsNone(getattr(instance, 'id', None)) + self.assertIsNone(getattr(instance, "id", None)) # Assert default values are assigned. self._assert_values_equal(instance, values=defaults) @@ -429,7 +448,7 @@ def _assert_crud(self, instance, defaults=None, updates=None): # Assert instance is created in the datbaase. saved = self.access_type.add_or_update(instance) self.assertIsNotNone(saved.id) - self._assert_fields_equal(instance, saved, exclude=['id']) + self._assert_fields_equal(instance, saved, exclude=["id"]) retrieved = self.access_type.get_by_id(saved.id) self._assert_fields_equal(saved, retrieved) @@ -443,22 +462,23 @@ def _assert_crud(self, instance, defaults=None, updates=None): # Assert instance is deleted from the database. retrieved = self.access_type.get_by_id(instance.id) retrieved.delete() - self.assertRaises(StackStormDBObjectNotFoundError, - self.access_type.get_by_id, instance.id) + self.assertRaises( + StackStormDBObjectNotFoundError, self.access_type.get_by_id, instance.id + ) def _assert_unique_key_constraint(self, instance): # Assert instance is not already in the database. - self.assertIsNone(getattr(instance, 'id', None)) + self.assertIsNone(getattr(instance, "id", None)) # Assert instance is created in the datbaase. saved = self.access_type.add_or_update(instance) self.assertIsNotNone(saved.id) # Assert exception is thrown if try to create same instance again. - delattr(instance, 'id') - self.assertRaises(StackStormDBObjectConflictError, - self.access_type.add_or_update, - instance) + delattr(instance, "id") + self.assertRaises( + StackStormDBObjectConflictError, self.access_type.add_or_update, instance + ) class CleanDbTestCase(BaseDbTestCase): @@ -486,6 +506,7 @@ class CleanFilesTestCase(TestCase): """ Base test class which deletes specified files and directories on setUp and `tearDown. """ + to_delete_files = [] to_delete_directories = [] @@ -555,8 +576,8 @@ def tearDown(self): stderr = None print('Process "%s"' % (process.pid)) - print('Stdout: %s' % (stdout)) - print('Stderr: %s' % (stderr)) + print("Stdout: %s" % (stdout)) + print("Stderr: %s" % (stderr)) def add_process(self, process): """ @@ -578,7 +599,7 @@ def assertProcessIsRunning(self, process): has succesfuly started and is running. """ if not process: - raise ValueError('process is None') + raise ValueError("process is None") return_code = process.poll() @@ -586,24 +607,27 @@ def assertProcessIsRunning(self, process): if process.stdout: stdout = process.stdout.read() else: - stdout = '' + stdout = "" if process.stderr: stderr = process.stderr.read() else: - stderr = '' + stderr = "" - msg = ('Process exited with code=%s.\nStdout:\n%s\n\nStderr:\n%s' % - (return_code, stdout, stderr)) + msg = "Process exited with code=%s.\nStdout:\n%s\n\nStderr:\n%s" % ( + return_code, + stdout, + stderr, + ) self.fail(msg) def assertProcessExited(self, proc): try: status = proc.status() except psutil.NoSuchProcess: - status = 'exited' + status = "exited" - if status not in ['exited', 'zombie']: + if status not in ["exited", "zombie"]: self.fail('Process with pid "%s" is still running' % (proc.pid)) @@ -613,49 +637,49 @@ class WorkflowTestCase(ExecutionDbTestCase): """ def get_wf_fixture_meta_data(self, fixture_pack_path, wf_meta_file_name): - wf_meta_file_path = fixture_pack_path + '/actions/' + wf_meta_file_name + wf_meta_file_path = fixture_pack_path + "/actions/" + wf_meta_file_name wf_meta_content = loader.load_meta_file(wf_meta_file_path) - wf_name = wf_meta_content['pack'] + '.' + wf_meta_content['name'] + wf_name = wf_meta_content["pack"] + "." + wf_meta_content["name"] return { - 'file_name': wf_meta_file_name, - 'file_path': wf_meta_file_path, - 'content': wf_meta_content, - 'name': wf_name + "file_name": wf_meta_file_name, + "file_path": wf_meta_file_path, + "content": wf_meta_content, + "name": wf_name, } def get_wf_def(self, test_pack_path, wf_meta): - rel_wf_def_path = wf_meta['content']['entry_point'] - abs_wf_def_path = os.path.join(test_pack_path, 'actions', rel_wf_def_path) + rel_wf_def_path = wf_meta["content"]["entry_point"] + abs_wf_def_path = os.path.join(test_pack_path, "actions", rel_wf_def_path) - with open(abs_wf_def_path, 'r') as def_file: + with open(abs_wf_def_path, "r") as def_file: return def_file.read() def mock_st2_context(self, ac_ex_db, context=None): st2_ctx = { - 'st2': { - 'api_url': api_util.get_full_public_api_url(), - 'action_execution_id': str(ac_ex_db.id), - 'user': 'stanley', - 'action': ac_ex_db.action['ref'], - 'runner': ac_ex_db.runner['name'] + "st2": { + "api_url": api_util.get_full_public_api_url(), + "action_execution_id": str(ac_ex_db.id), + "user": "stanley", + "action": ac_ex_db.action["ref"], + "runner": ac_ex_db.runner["name"], } } if context: - st2_ctx['parent'] = context + st2_ctx["parent"] = context return st2_ctx def prep_wf_ex(self, wf_ex_db): data = { - 'spec': wf_ex_db.spec, - 'graph': wf_ex_db.graph, - 'input': wf_ex_db.input, - 'context': wf_ex_db.context, - 'state': wf_ex_db.state, - 'output': wf_ex_db.output, - 'errors': wf_ex_db.errors + "spec": wf_ex_db.spec, + "graph": wf_ex_db.graph, + "input": wf_ex_db.input, + "context": wf_ex_db.context, + "state": wf_ex_db.state, + "output": wf_ex_db.output, + "errors": wf_ex_db.errors, } conductor = conducting.WorkflowConductor.deserialize(data) @@ -663,7 +687,7 @@ def prep_wf_ex(self, wf_ex_db): for task in conductor.get_next_tasks(): ac_ex_event = events.ActionExecutionEvent(wf_statuses.RUNNING) - conductor.update_task_state(task['id'], task['route'], ac_ex_event) + conductor.update_task_state(task["id"], task["route"], ac_ex_event) wf_ex_db.status = conductor.get_workflow_status() wf_ex_db.state = conductor.workflow_state.serialize() @@ -672,7 +696,9 @@ def prep_wf_ex(self, wf_ex_db): return wf_ex_db def get_task_ex(self, task_id, route): - task_ex_dbs = wf_db_access.TaskExecution.query(task_id=task_id, task_route=route) + task_ex_dbs = wf_db_access.TaskExecution.query( + task_id=task_id, task_route=route + ) self.assertGreater(len(task_ex_dbs), 0) return task_ex_dbs[0] @@ -686,21 +712,29 @@ def get_action_ex(self, task_ex_id): self.assertEqual(len(ac_ex_dbs), 1) return ac_ex_dbs[0] - def run_workflow_step(self, wf_ex_db, task_id, route, ctx=None, - expected_ac_ex_db_status=ac_const.LIVEACTION_STATUS_SUCCEEDED, - expected_tk_ex_db_status=wf_statuses.SUCCEEDED): - spec_module = specs_loader.get_spec_module(wf_ex_db.spec['catalog']) + def run_workflow_step( + self, + wf_ex_db, + task_id, + route, + ctx=None, + expected_ac_ex_db_status=ac_const.LIVEACTION_STATUS_SUCCEEDED, + expected_tk_ex_db_status=wf_statuses.SUCCEEDED, + ): + spec_module = specs_loader.get_spec_module(wf_ex_db.spec["catalog"]) wf_spec = spec_module.WorkflowSpec.deserialize(wf_ex_db.spec) - st2_ctx = {'execution_id': wf_ex_db.action_execution} + st2_ctx = {"execution_id": wf_ex_db.action_execution} task_spec = wf_spec.tasks.get_task(task_id) - task_actions = [{'action': task_spec.action, 'input': getattr(task_spec, 'input', {})}] + task_actions = [ + {"action": task_spec.action, "input": getattr(task_spec, "input", {})} + ] task_req = { - 'id': task_id, - 'route': route, - 'spec': task_spec, - 'ctx': ctx or {}, - 'actions': task_actions + "id": task_id, + "route": route, + "spec": task_spec, + "ctx": ctx or {}, + "actions": task_actions, } task_ex_db = wf_svc.request_task_execution(wf_ex_db, st2_ctx, task_req) @@ -712,10 +746,12 @@ def run_workflow_step(self, wf_ex_db, task_id, route, ctx=None, self.assertEqual(task_ex_db.status, expected_tk_ex_db_status) def sort_workflow_errors(self, errors): - return sorted(errors, key=lambda x: x.get('task_id', None)) + return sorted(errors, key=lambda x: x.get("task_id", None)) def assert_task_not_started(self, task_id, route): - task_ex_dbs = wf_db_access.TaskExecution.query(task_id=task_id, task_route=route) + task_ex_dbs = wf_db_access.TaskExecution.query( + task_id=task_id, task_route=route + ) self.assertEqual(len(task_ex_dbs), 0) def assert_task_running(self, task_id, route): @@ -734,7 +770,6 @@ def assert_workflow_completed(self, wf_ex_id, status=None): class FakeResponse(object): - def __init__(self, text, status_code, reason): self.text = text self.status_code = status_code @@ -748,11 +783,11 @@ def raise_for_status(self): def get_fixtures_path(): - return os.path.join(os.path.dirname(__file__), 'fixtures') + return os.path.join(os.path.dirname(__file__), "fixtures") def get_resources_path(): - return os.path.join(os.path.dirname(__file__), 'resources') + return os.path.join(os.path.dirname(__file__), "resources") def blocking_eventlet_spawn(func, *args, **kwargs): diff --git a/st2tests/st2tests/config.py b/st2tests/st2tests/config.py index 7fa4ad7b6ef..b1403578391 100644 --- a/st2tests/st2tests/config.py +++ b/st2tests/st2tests/config.py @@ -77,57 +77,66 @@ def _register_config_opts(): def _override_db_opts(): - CONF.set_override(name='db_name', override='st2-test', group='database') - CONF.set_override(name='host', override='127.0.0.1', group='database') + CONF.set_override(name="db_name", override="st2-test", group="database") + CONF.set_override(name="host", override="127.0.0.1", group="database") def _override_common_opts(): packs_base_path = get_fixtures_packs_base_path() - CONF.set_override(name='base_path', override=packs_base_path, group='system') - CONF.set_override(name='validate_output_schema', override=True, group='system') - CONF.set_override(name='system_packs_base_path', override=packs_base_path, group='content') - CONF.set_override(name='packs_base_paths', override=packs_base_path, group='content') - CONF.set_override(name='api_url', override='http://127.0.0.1', group='auth') - CONF.set_override(name='mask_secrets', override=True, group='log') - CONF.set_override(name='stream_output', override=False, group='actionrunner') + CONF.set_override(name="base_path", override=packs_base_path, group="system") + CONF.set_override(name="validate_output_schema", override=True, group="system") + CONF.set_override( + name="system_packs_base_path", override=packs_base_path, group="content" + ) + CONF.set_override( + name="packs_base_paths", override=packs_base_path, group="content" + ) + CONF.set_override(name="api_url", override="http://127.0.0.1", group="auth") + CONF.set_override(name="mask_secrets", override=True, group="log") + CONF.set_override(name="stream_output", override=False, group="actionrunner") def _override_api_opts(): - CONF.set_override(name='allow_origin', override=['http://127.0.0.1:3000', 'http://dev'], - group='api') + CONF.set_override( + name="allow_origin", + override=["http://127.0.0.1:3000", "http://dev"], + group="api", + ) def _override_keyvalue_opts(): current_file_path = os.path.dirname(__file__) - rel_st2_base_path = os.path.join(current_file_path, '../..') + rel_st2_base_path = os.path.join(current_file_path, "../..") abs_st2_base_path = os.path.abspath(rel_st2_base_path) - rel_enc_key_path = 'st2tests/conf/st2_kvstore_tests.crypto.key.json' + rel_enc_key_path = "st2tests/conf/st2_kvstore_tests.crypto.key.json" ovr_enc_key_path = os.path.join(abs_st2_base_path, rel_enc_key_path) - CONF.set_override(name='encryption_key_path', override=ovr_enc_key_path, group='keyvalue') + CONF.set_override( + name="encryption_key_path", override=ovr_enc_key_path, group="keyvalue" + ) def _override_scheduler_opts(): - CONF.set_override(name='sleep_interval', group='scheduler', override=0.01) + CONF.set_override(name="sleep_interval", group="scheduler", override=0.01) def _override_coordinator_opts(noop=False): - driver = None if noop else 'zake://' - CONF.set_override(name='url', override=driver, group='coordination') - CONF.set_override(name='lock_timeout', override=1, group='coordination') + driver = None if noop else "zake://" + CONF.set_override(name="url", override=driver, group="coordination") + CONF.set_override(name="lock_timeout", override=1, group="coordination") def _override_workflow_engine_opts(): - cfg.CONF.set_override('retry_stop_max_msec', 500, group='workflow_engine') - cfg.CONF.set_override('retry_wait_fixed_msec', 100, group='workflow_engine') - cfg.CONF.set_override('retry_max_jitter_msec', 100, group='workflow_engine') - cfg.CONF.set_override('gc_max_idle_sec', 1, group='workflow_engine') + cfg.CONF.set_override("retry_stop_max_msec", 500, group="workflow_engine") + cfg.CONF.set_override("retry_wait_fixed_msec", 100, group="workflow_engine") + cfg.CONF.set_override("retry_max_jitter_msec", 100, group="workflow_engine") + cfg.CONF.set_override("gc_max_idle_sec", 1, group="workflow_engine") def _register_common_opts(): try: common_config.register_opts(ignore_errors=True) except: - LOG.exception('Common config registration failed.') + LOG.exception("Common config registration failed.") def _register_api_opts(): @@ -135,225 +144,292 @@ def _register_api_opts(): # Brittle! pecan_opts = [ cfg.StrOpt( - 'root', default='st2api.controllers.root.RootController', - help='Pecan root controller'), - cfg.StrOpt('template_path', default='%(confdir)s/st2api/st2api/templates'), - cfg.ListOpt('modules', default=['st2api']), - cfg.BoolOpt('debug', default=True), - cfg.BoolOpt('auth_enable', default=True), - cfg.DictOpt('errors', default={404: '/error/404', '__force_dict__': True}) + "root", + default="st2api.controllers.root.RootController", + help="Pecan root controller", + ), + cfg.StrOpt("template_path", default="%(confdir)s/st2api/st2api/templates"), + cfg.ListOpt("modules", default=["st2api"]), + cfg.BoolOpt("debug", default=True), + cfg.BoolOpt("auth_enable", default=True), + cfg.DictOpt("errors", default={404: "/error/404", "__force_dict__": True}), ] - _register_opts(pecan_opts, group='api_pecan') + _register_opts(pecan_opts, group="api_pecan") api_opts = [ - cfg.BoolOpt('debug', default=True), + cfg.BoolOpt("debug", default=True), cfg.IntOpt( - 'max_page_size', default=100, - help='Maximum limit (page size) argument which can be specified by the user in a query ' - 'string. If a larger value is provided, it will default to this value.') + "max_page_size", + default=100, + help="Maximum limit (page size) argument which can be specified by the user in a query " + "string. If a larger value is provided, it will default to this value.", + ), ] - _register_opts(api_opts, group='api') + _register_opts(api_opts, group="api") messaging_opts = [ cfg.StrOpt( - 'url', default='amqp://guest:guest@127.0.0.1:5672//', - help='URL of the messaging server.'), + "url", + default="amqp://guest:guest@127.0.0.1:5672//", + help="URL of the messaging server.", + ), cfg.ListOpt( - 'cluster_urls', default=[], - help='URL of all the nodes in a messaging service cluster.'), + "cluster_urls", + default=[], + help="URL of all the nodes in a messaging service cluster.", + ), cfg.IntOpt( - 'connection_retries', default=10, - help='How many times should we retry connection before failing.'), + "connection_retries", + default=10, + help="How many times should we retry connection before failing.", + ), cfg.IntOpt( - 'connection_retry_wait', default=10000, - help='How long should we wait between connection retries.'), + "connection_retry_wait", + default=10000, + help="How long should we wait between connection retries.", + ), cfg.BoolOpt( - 'ssl', default=False, - help='Use SSL / TLS to connect to the messaging server. Same as ' - 'appending "?ssl=true" at the end of the connection URL string.'), + "ssl", + default=False, + help="Use SSL / TLS to connect to the messaging server. Same as " + 'appending "?ssl=true" at the end of the connection URL string.', + ), cfg.StrOpt( - 'ssl_keyfile', default=None, - help='Private keyfile used to identify the local connection against RabbitMQ.'), + "ssl_keyfile", + default=None, + help="Private keyfile used to identify the local connection against RabbitMQ.", + ), cfg.StrOpt( - 'ssl_certfile', default=None, - help='Certificate file used to identify the local connection (client).'), + "ssl_certfile", + default=None, + help="Certificate file used to identify the local connection (client).", + ), cfg.StrOpt( - 'ssl_cert_reqs', default=None, choices='none, optional, required', - help='Specifies whether a certificate is required from the other side of the ' - 'connection, and whether it will be validated if provided.'), + "ssl_cert_reqs", + default=None, + choices="none, optional, required", + help="Specifies whether a certificate is required from the other side of the " + "connection, and whether it will be validated if provided.", + ), cfg.StrOpt( - 'ssl_ca_certs', default=None, - help='ca_certs file contains a set of concatenated CA certificates, which are ' - 'used to validate certificates passed from RabbitMQ.'), + "ssl_ca_certs", + default=None, + help="ca_certs file contains a set of concatenated CA certificates, which are " + "used to validate certificates passed from RabbitMQ.", + ), cfg.StrOpt( - 'login_method', default=None, - help='Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).') + "login_method", + default=None, + help="Login method to use (AMQPLAIN, PLAIN, EXTERNAL, etc.).", + ), ] - _register_opts(messaging_opts, group='messaging') + _register_opts(messaging_opts, group="messaging") ssh_runner_opts = [ cfg.StrOpt( - 'remote_dir', default='/tmp', - help='Location of the script on the remote filesystem.'), + "remote_dir", + default="/tmp", + help="Location of the script on the remote filesystem.", + ), cfg.BoolOpt( - 'allow_partial_failure', default=False, - help='How partial success of actions run on multiple nodes should be treated.'), + "allow_partial_failure", + default=False, + help="How partial success of actions run on multiple nodes should be treated.", + ), cfg.BoolOpt( - 'use_ssh_config', default=False, - help='Use the .ssh/config file. Useful to override ports etc.') + "use_ssh_config", + default=False, + help="Use the .ssh/config file. Useful to override ports etc.", + ), ] - _register_opts(ssh_runner_opts, group='ssh_runner') + _register_opts(ssh_runner_opts, group="ssh_runner") def _register_stream_opts(): stream_opts = [ cfg.IntOpt( - 'heartbeat', default=25, - help='Send empty message every N seconds to keep connection open'), - cfg.BoolOpt( - 'debug', default=False, - help='Specify to enable debug mode.'), + "heartbeat", + default=25, + help="Send empty message every N seconds to keep connection open", + ), + cfg.BoolOpt("debug", default=False, help="Specify to enable debug mode."), ] - _register_opts(stream_opts, group='stream') + _register_opts(stream_opts, group="stream") def _register_auth_opts(): auth_opts = [ - cfg.StrOpt('host', default='127.0.0.1'), - cfg.IntOpt('port', default=9100), - cfg.BoolOpt('use_ssl', default=False), - cfg.StrOpt('mode', default='proxy'), - cfg.StrOpt('backend', default='flat_file'), - cfg.StrOpt('backend_kwargs', default=None), - cfg.StrOpt('logging', default='conf/logging.conf'), - cfg.IntOpt('token_ttl', default=86400, help='Access token ttl in seconds.'), - cfg.BoolOpt('sso', default=True), - cfg.StrOpt('sso_backend', default='noop'), - cfg.StrOpt('sso_backend_kwargs', default=None), - cfg.BoolOpt('debug', default=True) + cfg.StrOpt("host", default="127.0.0.1"), + cfg.IntOpt("port", default=9100), + cfg.BoolOpt("use_ssl", default=False), + cfg.StrOpt("mode", default="proxy"), + cfg.StrOpt("backend", default="flat_file"), + cfg.StrOpt("backend_kwargs", default=None), + cfg.StrOpt("logging", default="conf/logging.conf"), + cfg.IntOpt("token_ttl", default=86400, help="Access token ttl in seconds."), + cfg.BoolOpt("sso", default=True), + cfg.StrOpt("sso_backend", default="noop"), + cfg.StrOpt("sso_backend_kwargs", default=None), + cfg.BoolOpt("debug", default=True), ] - _register_opts(auth_opts, group='auth') + _register_opts(auth_opts, group="auth") def _register_action_sensor_opts(): action_sensor_opts = [ cfg.BoolOpt( - 'enable', default=True, - help='Whether to enable or disable the ability to post a trigger on action.'), + "enable", + default=True, + help="Whether to enable or disable the ability to post a trigger on action.", + ), cfg.StrOpt( - 'triggers_base_url', default='http://127.0.0.1:9101/v1/triggertypes/', - help='URL for action sensor to post TriggerType.'), + "triggers_base_url", + default="http://127.0.0.1:9101/v1/triggertypes/", + help="URL for action sensor to post TriggerType.", + ), cfg.IntOpt( - 'request_timeout', default=1, - help='Timeout value of all httprequests made by action sensor.'), + "request_timeout", + default=1, + help="Timeout value of all httprequests made by action sensor.", + ), cfg.IntOpt( - 'max_attempts', default=10, - help='No. of times to retry registration.'), + "max_attempts", default=10, help="No. of times to retry registration." + ), cfg.IntOpt( - 'retry_wait', default=1, - help='Amount of time to wait prior to retrying a request.') + "retry_wait", + default=1, + help="Amount of time to wait prior to retrying a request.", + ), ] - _register_opts(action_sensor_opts, group='action_sensor') + _register_opts(action_sensor_opts, group="action_sensor") def _register_ssh_runner_opts(): ssh_runner_opts = [ cfg.BoolOpt( - 'use_ssh_config', default=False, - help='Use the .ssh/config file. Useful to override ports etc.'), + "use_ssh_config", + default=False, + help="Use the .ssh/config file. Useful to override ports etc.", + ), cfg.StrOpt( - 'remote_dir', default='/tmp', - help='Location of the script on the remote filesystem.'), + "remote_dir", + default="/tmp", + help="Location of the script on the remote filesystem.", + ), cfg.BoolOpt( - 'allow_partial_failure', default=False, - help='How partial success of actions run on multiple nodes should be treated.'), + "allow_partial_failure", + default=False, + help="How partial success of actions run on multiple nodes should be treated.", + ), cfg.IntOpt( - 'max_parallel_actions', default=50, - help='Max number of parallel remote SSH actions that should be run. ' - 'Works only with Paramiko SSH runner.'), + "max_parallel_actions", + default=50, + help="Max number of parallel remote SSH actions that should be run. " + "Works only with Paramiko SSH runner.", + ), ] - _register_opts(ssh_runner_opts, group='ssh_runner') + _register_opts(ssh_runner_opts, group="ssh_runner") def _register_scheduler_opts(): scheduler_opts = [ cfg.FloatOpt( - 'execution_scheduling_timeout_threshold_min', default=1, - help='How long GC to search back in minutes for orphaned scheduled actions'), + "execution_scheduling_timeout_threshold_min", + default=1, + help="How long GC to search back in minutes for orphaned scheduled actions", + ), cfg.IntOpt( - 'pool_size', default=10, - help='The size of the pool used by the scheduler for scheduling executions.'), + "pool_size", + default=10, + help="The size of the pool used by the scheduler for scheduling executions.", + ), cfg.FloatOpt( - 'sleep_interval', default=0.01, - help='How long to sleep between each action scheduler main loop run interval (in ms).'), + "sleep_interval", + default=0.01, + help="How long to sleep between each action scheduler main loop run interval (in ms).", + ), cfg.FloatOpt( - 'gc_interval', default=5, - help='How often to look for zombie executions before rescheduling them (in ms).'), + "gc_interval", + default=5, + help="How often to look for zombie executions before rescheduling them (in ms).", + ), cfg.IntOpt( - 'retry_max_attempt', default=3, - help='The maximum number of attempts that the scheduler retries on error.'), + "retry_max_attempt", + default=3, + help="The maximum number of attempts that the scheduler retries on error.", + ), cfg.IntOpt( - 'retry_wait_msec', default=100, - help='The number of milliseconds to wait in between retries.') + "retry_wait_msec", + default=100, + help="The number of milliseconds to wait in between retries.", + ), ] - _register_opts(scheduler_opts, group='scheduler') + _register_opts(scheduler_opts, group="scheduler") def _register_exporter_opts(): exporter_opts = [ cfg.StrOpt( - 'dump_dir', default='/opt/stackstorm/exports/', - help='Directory to dump data to.') + "dump_dir", + default="/opt/stackstorm/exports/", + help="Directory to dump data to.", + ) ] - _register_opts(exporter_opts, group='exporter') + _register_opts(exporter_opts, group="exporter") def _register_sensor_container_opts(): partition_opts = [ cfg.StrOpt( - 'sensor_node_name', default='sensornode1', - help='name of the sensor node.'), + "sensor_node_name", default="sensornode1", help="name of the sensor node." + ), cfg.Opt( - 'partition_provider', + "partition_provider", type=types.Dict(value_type=types.String()), - default={'name': DEFAULT_PARTITION_LOADER}, - help='Provider of sensor node partition config.') + default={"name": DEFAULT_PARTITION_LOADER}, + help="Provider of sensor node partition config.", + ), ] - _register_opts(partition_opts, group='sensorcontainer') + _register_opts(partition_opts, group="sensorcontainer") # Other options other_opts = [ cfg.BoolOpt( - 'single_sensor_mode', default=False, - help='Run in a single sensor mode where parent process exits when a sensor crashes / ' - 'dies. This is useful in environments where partitioning, sensor process life ' - 'cycle and failover is handled by a 3rd party service such as kubernetes.') + "single_sensor_mode", + default=False, + help="Run in a single sensor mode where parent process exits when a sensor crashes / " + "dies. This is useful in environments where partitioning, sensor process life " + "cycle and failover is handled by a 3rd party service such as kubernetes.", + ) ] - _register_opts(other_opts, group='sensorcontainer') + _register_opts(other_opts, group="sensorcontainer") # CLI options cli_opts = [ cfg.StrOpt( - 'sensor-ref', - help='Only run sensor with the provided reference. Value is of the form ' - '. (e.g. linux.FileWatchSensor).'), + "sensor-ref", + help="Only run sensor with the provided reference. Value is of the form " + ". (e.g. linux.FileWatchSensor).", + ), cfg.BoolOpt( - 'single-sensor-mode', default=False, - help='Run in a single sensor mode where parent process exits when a sensor crashes / ' - 'dies. This is useful in environments where partitioning, sensor process life ' - 'cycle and failover is handled by a 3rd party service such as kubernetes.') + "single-sensor-mode", + default=False, + help="Run in a single sensor mode where parent process exits when a sensor crashes / " + "dies. This is useful in environments where partitioning, sensor process life " + "cycle and failover is handled by a 3rd party service such as kubernetes.", + ), ] _register_cli_opts(cli_opts) @@ -362,40 +438,52 @@ def _register_sensor_container_opts(): def _register_garbage_collector_opts(): common_opts = [ cfg.IntOpt( - 'collection_interval', default=DEFAULT_COLLECTION_INTERVAL, - help='How often to check database for old data and perform garbage collection.'), + "collection_interval", + default=DEFAULT_COLLECTION_INTERVAL, + help="How often to check database for old data and perform garbage collection.", + ), cfg.FloatOpt( - 'sleep_delay', default=DEFAULT_SLEEP_DELAY, - help='How long to wait / sleep (in seconds) between ' - 'collection of different object types.') + "sleep_delay", + default=DEFAULT_SLEEP_DELAY, + help="How long to wait / sleep (in seconds) between " + "collection of different object types.", + ), ] - _register_opts(common_opts, group='garbagecollector') + _register_opts(common_opts, group="garbagecollector") ttl_opts = [ cfg.IntOpt( - 'action_executions_ttl', default=None, - help='Action executions and related objects (live actions, action output ' - 'objects) older than this value (days) will be automatically deleted.'), + "action_executions_ttl", + default=None, + help="Action executions and related objects (live actions, action output " + "objects) older than this value (days) will be automatically deleted.", + ), cfg.IntOpt( - 'action_executions_output_ttl', default=7, - help='Action execution output objects (ones generated by action output ' - 'streaming) older than this value (days) will be automatically deleted.'), + "action_executions_output_ttl", + default=7, + help="Action execution output objects (ones generated by action output " + "streaming) older than this value (days) will be automatically deleted.", + ), cfg.IntOpt( - 'trigger_instances_ttl', default=None, - help='Trigger instances older than this value (days) will be automatically deleted.') + "trigger_instances_ttl", + default=None, + help="Trigger instances older than this value (days) will be automatically deleted.", + ), ] - _register_opts(ttl_opts, group='garbagecollector') + _register_opts(ttl_opts, group="garbagecollector") inquiry_opts = [ cfg.BoolOpt( - 'purge_inquiries', default=False, - help='Set to True to perform garbage collection on Inquiries (based on ' - 'the TTL value per Inquiry)') + "purge_inquiries", + default=False, + help="Set to True to perform garbage collection on Inquiries (based on " + "the TTL value per Inquiry)", + ) ] - _register_opts(inquiry_opts, group='garbagecollector') + _register_opts(inquiry_opts, group="garbagecollector") def _register_opts(opts, group=None): diff --git a/st2tests/st2tests/fixtures/history_views/__init__.py b/st2tests/st2tests/fixtures/history_views/__init__.py index dd42395788b..24567ead6e8 100644 --- a/st2tests/st2tests/fixtures/history_views/__init__.py +++ b/st2tests/st2tests/fixtures/history_views/__init__.py @@ -21,12 +21,12 @@ PATH = os.path.join(os.path.dirname(os.path.realpath(__file__))) -FILES = glob.glob('%s/*.yaml' % PATH) +FILES = glob.glob("%s/*.yaml" % PATH) ARTIFACTS = {} for f in FILES: f_name = os.path.split(f)[1] name = six.text_type(os.path.splitext(f_name)[0]) - with open(f, 'r') as fd: + with open(f, "r") as fd: ARTIFACTS[name] = yaml.safe_load(fd) diff --git a/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py b/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py index b5184b586c9..5b2cc19cc02 100755 --- a/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py +++ b/st2tests/st2tests/fixtures/localrunner_pack/actions/text_gen.py @@ -32,16 +32,16 @@ def print_random_chars(chars=1000, selection=ascii_letters + string.digits): s = [] for _ in range(chars - 1): s.append(random.choice(selection)) - s.append('@') - print(''.join(s)) + s.append("@") + print("".join(s)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--chars', type=int, metavar='N', default=10) + parser.add_argument("--chars", type=int, metavar="N", default=10) args = parser.parse_args() print_random_chars(args.chars) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py b/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py index 57f3f6eea31..40875e1182e 100644 --- a/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py +++ b/st2tests/st2tests/fixtures/packs/dummy_pack_7/actions/render_config_context.py @@ -17,6 +17,5 @@ class PrintPythonVersionAction(Action): - def run(self, value1): return {"context_value": value1} diff --git a/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py b/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py index ef42e25d15e..acd28326276 100644 --- a/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py +++ b/st2tests/st2tests/fixtures/packs/dummy_pack_9/actions/invalid_syntax.py @@ -14,8 +14,8 @@ # limitations under the License. from __future__ import absolute_import -from invalid import Invalid # noqa +from invalid import Invalid # noqa -class Foo(): +class Foo: pass diff --git a/st2tests/st2tests/fixtures/packs/executions/__init__.py b/st2tests/st2tests/fixtures/packs/executions/__init__.py index 3faa0a81adb..ef9bf26a3fc 100644 --- a/st2tests/st2tests/fixtures/packs/executions/__init__.py +++ b/st2tests/st2tests/fixtures/packs/executions/__init__.py @@ -22,17 +22,17 @@ PATH = os.path.dirname(os.path.realpath(__file__)) -FILES = glob.glob('%s/*.yaml' % PATH) +FILES = glob.glob("%s/*.yaml" % PATH) ARTIFACTS = {} for f in FILES: f_name = os.path.split(f)[1] name = six.text_type(os.path.splitext(f_name)[0]) - with open(f, 'r') as fd: + with open(f, "r") as fd: ARTIFACTS[name] = yaml.safe_load(fd) if isinstance(ARTIFACTS[name], dict): - ARTIFACTS[name][u'id'] = six.text_type(bson.ObjectId()) + ARTIFACTS[name]["id"] = six.text_type(bson.ObjectId()) elif isinstance(ARTIFACTS[name], list): for item in ARTIFACTS[name]: - item[u'id'] = six.text_type(bson.ObjectId()) + item["id"] = six.text_type(bson.ObjectId()) diff --git a/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py b/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py index 04092029038..31258fae4ec 100644 --- a/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py +++ b/st2tests/st2tests/fixtures/packs/runners/test_async_runner/test_async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import AsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class AsyncTestRunner(AsyncActionRunner): def __init__(self): - super(AsyncTestRunner, self).__init__(runner_id='1') + super(AsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py b/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py index 435f7eb9b68..c48bb9aa675 100644 --- a/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py +++ b/st2tests/st2tests/fixtures/packs/runners/test_polling_async_runner/test_polling_async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import PollingAsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class PollingAsyncTestRunner(PollingAsyncActionRunner): def __init__(self): - super(PollingAsyncTestRunner, self).__init__(runner_id='1') + super(PollingAsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py b/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py index fe50e37ae5b..5d18a77ccb7 100644 --- a/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py +++ b/st2tests/st2tests/fixtures/packs/test_library_dependencies/actions/get_library_path.py @@ -15,9 +15,7 @@ from st2actions.runners.pythonrunner import Action -__all__ = [ - 'GetLibraryPathAction' -] +__all__ = ["GetLibraryPathAction"] class GetLibraryPathAction(Action): diff --git a/st2tests/st2tests/fixturesloader.py b/st2tests/st2tests/fixturesloader.py index df9f2cef7f7..dd1446153ea 100644 --- a/st2tests/st2tests/fixturesloader.py +++ b/st2tests/st2tests/fixturesloader.py @@ -21,16 +21,21 @@ from st2common.content.loader import MetaLoader -from st2common.models.api.action import (ActionAPI, LiveActionAPI, ActionExecutionStateAPI, - RunnerTypeAPI, ActionAliasAPI) +from st2common.models.api.action import ( + ActionAPI, + LiveActionAPI, + ActionExecutionStateAPI, + RunnerTypeAPI, + ActionAliasAPI, +) from st2common.models.api.auth import ApiKeyAPI, UserAPI -from st2common.models.api.execution import (ActionExecutionAPI) -from st2common.models.api.policy import (PolicyTypeAPI, PolicyAPI) -from st2common.models.api.rule import (RuleAPI) +from st2common.models.api.execution import ActionExecutionAPI +from st2common.models.api.policy import PolicyTypeAPI, PolicyAPI +from st2common.models.api.rule import RuleAPI from st2common.models.api.rule_enforcement import RuleEnforcementAPI from st2common.models.api.sensor import SensorTypeAPI from st2common.models.api.trace import TraceAPI -from st2common.models.api.trigger import (TriggerAPI, TriggerTypeAPI, TriggerInstanceAPI) +from st2common.models.api.trigger import TriggerAPI, TriggerTypeAPI, TriggerInstanceAPI from st2common.models.db.action import ActionDB from st2common.models.db.actionalias import ActionAliasDB @@ -38,13 +43,13 @@ from st2common.models.db.liveaction import LiveActionDB from st2common.models.db.executionstate import ActionExecutionStateDB from st2common.models.db.runner import RunnerTypeDB -from st2common.models.db.execution import (ActionExecutionDB) -from st2common.models.db.policy import (PolicyTypeDB, PolicyDB) +from st2common.models.db.execution import ActionExecutionDB +from st2common.models.db.policy import PolicyTypeDB, PolicyDB from st2common.models.db.rule import RuleDB from st2common.models.db.rule_enforcement import RuleEnforcementDB from st2common.models.db.sensor import SensorTypeDB from st2common.models.db.trace import TraceDB -from st2common.models.db.trigger import (TriggerDB, TriggerTypeDB, TriggerInstanceDB) +from st2common.models.db.trigger import TriggerDB, TriggerTypeDB, TriggerInstanceDB from st2common.persistence.action import Action from st2common.persistence.actionalias import ActionAlias from st2common.persistence.execution import ActionExecution @@ -52,107 +57,125 @@ from st2common.persistence.auth import ApiKey, User from st2common.persistence.liveaction import LiveAction from st2common.persistence.runner import RunnerType -from st2common.persistence.policy import (PolicyType, Policy) +from st2common.persistence.policy import PolicyType, Policy from st2common.persistence.rule import Rule from st2common.persistence.rule_enforcement import RuleEnforcement from st2common.persistence.sensor import SensorType from st2common.persistence.trace import Trace -from st2common.persistence.trigger import (Trigger, TriggerType, TriggerInstance) - - -ALLOWED_DB_FIXTURES = ['actions', 'actionstates', 'aliases', 'executions', 'liveactions', - 'policies', 'policytypes', 'rules', 'runners', 'sensors', - 'triggertypes', 'triggers', 'triggerinstances', 'traces', 'apikeys', - 'users', 'enforcements'] +from st2common.persistence.trigger import Trigger, TriggerType, TriggerInstance + + +ALLOWED_DB_FIXTURES = [ + "actions", + "actionstates", + "aliases", + "executions", + "liveactions", + "policies", + "policytypes", + "rules", + "runners", + "sensors", + "triggertypes", + "triggers", + "triggerinstances", + "traces", + "apikeys", + "users", + "enforcements", +] ALLOWED_FIXTURES = copy.copy(ALLOWED_DB_FIXTURES) -ALLOWED_FIXTURES.extend(['actionchains', 'workflows']) +ALLOWED_FIXTURES.extend(["actionchains", "workflows"]) FIXTURE_DB_MODEL = { - 'actions': ActionDB, - 'aliases': ActionAliasDB, - 'actionstates': ActionExecutionStateDB, - 'apikeys': ApiKeyDB, - 'enforcements': RuleEnforcementDB, - 'executions': ActionExecutionDB, - 'liveactions': LiveActionDB, - 'policies': PolicyDB, - 'policytypes': PolicyTypeDB, - 'rules': RuleDB, - 'runners': RunnerTypeDB, - 'sensors': SensorTypeDB, - 'traces': TraceDB, - 'triggertypes': TriggerTypeDB, - 'triggers': TriggerDB, - 'triggerinstances': TriggerInstanceDB, - 'users': UserDB + "actions": ActionDB, + "aliases": ActionAliasDB, + "actionstates": ActionExecutionStateDB, + "apikeys": ApiKeyDB, + "enforcements": RuleEnforcementDB, + "executions": ActionExecutionDB, + "liveactions": LiveActionDB, + "policies": PolicyDB, + "policytypes": PolicyTypeDB, + "rules": RuleDB, + "runners": RunnerTypeDB, + "sensors": SensorTypeDB, + "traces": TraceDB, + "triggertypes": TriggerTypeDB, + "triggers": TriggerDB, + "triggerinstances": TriggerInstanceDB, + "users": UserDB, } FIXTURE_API_MODEL = { - 'actions': ActionAPI, - 'aliases': ActionAliasAPI, - 'actionstates': ActionExecutionStateAPI, - 'apikeys': ApiKeyAPI, - 'enforcements': RuleEnforcementAPI, - 'executions': ActionExecutionAPI, - 'liveactions': LiveActionAPI, - 'policies': PolicyAPI, - 'policytypes': PolicyTypeAPI, - 'rules': RuleAPI, - 'runners': RunnerTypeAPI, - 'sensors': SensorTypeAPI, - 'traces': TraceAPI, - 'triggertypes': TriggerTypeAPI, - 'triggers': TriggerAPI, - 'triggerinstances': TriggerInstanceAPI, - 'users': UserAPI + "actions": ActionAPI, + "aliases": ActionAliasAPI, + "actionstates": ActionExecutionStateAPI, + "apikeys": ApiKeyAPI, + "enforcements": RuleEnforcementAPI, + "executions": ActionExecutionAPI, + "liveactions": LiveActionAPI, + "policies": PolicyAPI, + "policytypes": PolicyTypeAPI, + "rules": RuleAPI, + "runners": RunnerTypeAPI, + "sensors": SensorTypeAPI, + "traces": TraceAPI, + "triggertypes": TriggerTypeAPI, + "triggers": TriggerAPI, + "triggerinstances": TriggerInstanceAPI, + "users": UserAPI, } FIXTURE_PERSISTENCE_MODEL = { - 'actions': Action, - 'aliases': ActionAlias, - 'actionstates': ActionExecutionState, - 'apikeys': ApiKey, - 'enforcements': RuleEnforcement, - 'executions': ActionExecution, - 'liveactions': LiveAction, - 'policies': Policy, - 'policytypes': PolicyType, - 'rules': Rule, - 'runners': RunnerType, - 'sensors': SensorType, - 'traces': Trace, - 'triggertypes': TriggerType, - 'triggers': Trigger, - 'triggerinstances': TriggerInstance, - 'users': User + "actions": Action, + "aliases": ActionAlias, + "actionstates": ActionExecutionState, + "apikeys": ApiKey, + "enforcements": RuleEnforcement, + "executions": ActionExecution, + "liveactions": LiveAction, + "policies": Policy, + "policytypes": PolicyType, + "rules": Rule, + "runners": RunnerType, + "sensors": SensorType, + "traces": Trace, + "triggertypes": TriggerType, + "triggers": Trigger, + "triggerinstances": TriggerInstance, + "users": User, } GIT_SUBMODULES_NOT_CHECKED_OUT_ERROR = """ Git submodule "%s" is not checked out. Make sure to run "git submodule update --init --recursive" in the repository root directory to check out all the submodules. -""".replace('\n', '').strip() +""".replace( + "\n", "" +).strip() def get_fixtures_base_path(): - return os.path.join(os.path.dirname(__file__), 'fixtures') + return os.path.join(os.path.dirname(__file__), "fixtures") def get_fixtures_packs_base_path(): - return os.path.join(os.path.dirname(__file__), 'fixtures/packs') + return os.path.join(os.path.dirname(__file__), "fixtures/packs") def get_resources_base_path(): - return os.path.join(os.path.dirname(__file__), 'resources') + return os.path.join(os.path.dirname(__file__), "resources") class FixturesLoader(object): def __init__(self): self.meta_loader = MetaLoader() - def save_fixtures_to_db(self, fixtures_pack='generic', fixtures_dict=None, - use_object_ids=False): + def save_fixtures_to_db( + self, fixtures_pack="generic", fixtures_dict=None, use_object_ids=False + ): """ Loads fixtures specified in fixtures_dict into the database and returns DB models for the fixtures. @@ -193,17 +216,22 @@ def save_fixtures_to_db(self, fixtures_pack='generic', fixtures_dict=None, for fixture in fixtures: # Guard against copy and type and similar typos if fixture in loaded_fixtures: - msg = 'Fixture "%s" is specified twice, probably a typo.' % (fixture) + msg = 'Fixture "%s" is specified twice, probably a typo.' % ( + fixture + ) raise ValueError(msg) fixture_dict = self.meta_loader.load( - self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture)) + self._get_fixture_file_path_abs( + fixtures_pack_path, fixture_type, fixture + ) + ) api_model = API_MODEL(**fixture_dict) db_model = API_MODEL.to_model(api_model) # Make sure we also set and use object id if that functionality is used - if use_object_ids and 'id' in fixture_dict: - db_model.id = fixture_dict['id'] + if use_object_ids and "id" in fixture_dict: + db_model.id = fixture_dict["id"] db_model = PERSISTENCE_MODEL.add_or_update(db_model) loaded_fixtures[fixture] = db_model @@ -212,7 +240,7 @@ def save_fixtures_to_db(self, fixtures_pack='generic', fixtures_dict=None, return db_models - def load_fixtures(self, fixtures_pack='generic', fixtures_dict=None): + def load_fixtures(self, fixtures_pack="generic", fixtures_dict=None): """ Loads fixtures specified in fixtures_dict. We simply want to load the meta into dict objects. @@ -241,13 +269,16 @@ def load_fixtures(self, fixtures_pack='generic', fixtures_dict=None): loaded_fixtures = {} for fixture in fixtures: fixture_dict = self.meta_loader.load( - self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture)) + self._get_fixture_file_path_abs( + fixtures_pack_path, fixture_type, fixture + ) + ) loaded_fixtures[fixture] = fixture_dict all_fixtures[fixture_type] = loaded_fixtures return all_fixtures - def load_models(self, fixtures_pack='generic', fixtures_dict=None): + def load_models(self, fixtures_pack="generic", fixtures_dict=None): """ Loads fixtures specified in fixtures_dict as db models. This method must be used for fixtures that have associated DB models. We simply want to load the @@ -281,7 +312,10 @@ def load_models(self, fixtures_pack='generic', fixtures_dict=None): loaded_models = {} for fixture in fixtures: fixture_dict = self.meta_loader.load( - self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture)) + self._get_fixture_file_path_abs( + fixtures_pack_path, fixture_type, fixture + ) + ) api_model = API_MODEL(**fixture_dict) db_model = API_MODEL.to_model(api_model) loaded_models[fixture] = db_model @@ -289,8 +323,9 @@ def load_models(self, fixtures_pack='generic', fixtures_dict=None): return all_fixtures - def delete_fixtures_from_db(self, fixtures_pack='generic', fixtures_dict=None, - raise_on_fail=False): + def delete_fixtures_from_db( + self, fixtures_pack="generic", fixtures_dict=None, raise_on_fail=False + ): """ Deletes fixtures specified in fixtures_dict from the database. @@ -320,7 +355,10 @@ def delete_fixtures_from_db(self, fixtures_pack='generic', fixtures_dict=None, PERSISTENCE_MODEL = FIXTURE_PERSISTENCE_MODEL.get(fixture_type, None) for fixture in fixtures: fixture_dict = self.meta_loader.load( - self._get_fixture_file_path_abs(fixtures_pack_path, fixture_type, fixture)) + self._get_fixture_file_path_abs( + fixtures_pack_path, fixture_type, fixture + ) + ) # Note that when we have a reference mechanism consistent for # every model, we can just do a get and delete the object. Until # then, this model conversions are necessary. @@ -362,28 +400,36 @@ def _validate_fixtures_pack(self, fixtures_pack): fixtures_pack_path = self._get_fixtures_pack_path(fixtures_pack) if not self._is_fixture_pack_exists(fixtures_pack_path): - raise Exception('Fixtures pack not found ' + - 'in fixtures path %s.' % get_fixtures_base_path()) + raise Exception( + "Fixtures pack not found " + + "in fixtures path %s." % get_fixtures_base_path() + ) return fixtures_pack_path def _validate_fixture_dict(self, fixtures_dict, allowed=ALLOWED_FIXTURES): fixture_types = list(fixtures_dict.keys()) for fixture_type in fixture_types: if fixture_type not in allowed: - raise Exception('Disallowed fixture type: %s. Valid fixture types are: %s' % ( - fixture_type, ", ".join(allowed))) + raise Exception( + "Disallowed fixture type: %s. Valid fixture types are: %s" + % (fixture_type, ", ".join(allowed)) + ) def _is_fixture_pack_exists(self, fixtures_pack_path): return os.path.exists(fixtures_pack_path) - def _get_fixture_file_path_abs(self, fixtures_pack_path, fixtures_type, fixture_name): + def _get_fixture_file_path_abs( + self, fixtures_pack_path, fixtures_type, fixture_name + ): return os.path.join(fixtures_pack_path, fixtures_type, fixture_name) def _get_fixtures_pack_path(self, fixtures_pack_name): return os.path.join(get_fixtures_base_path(), fixtures_pack_name) def get_fixture_file_path_abs(self, fixtures_pack, fixtures_type, fixture_name): - return os.path.join(get_fixtures_base_path(), fixtures_pack, fixtures_type, fixture_name) + return os.path.join( + get_fixtures_base_path(), fixtures_pack, fixtures_type, fixture_name + ) def assert_submodules_are_checked_out(): @@ -392,9 +438,9 @@ def assert_submodules_are_checked_out(): root of the directory and that the "st2tests/st2tests/fixtures/packs/test" git repo submodule used by the tests is checked out. """ - pack_path = os.path.join(get_fixtures_packs_base_path(), 'test_content_version/') + pack_path = os.path.join(get_fixtures_packs_base_path(), "test_content_version/") pack_path = os.path.abspath(pack_path) - submodule_git_dir_or_file_path = os.path.join(pack_path, '.git') + submodule_git_dir_or_file_path = os.path.join(pack_path, ".git") # NOTE: In newer versions of git, that .git is a file and not a directory if not os.path.exists(submodule_git_dir_or_file_path): diff --git a/st2tests/st2tests/http.py b/st2tests/st2tests/http.py index 4dce56f45aa..e14672d0010 100644 --- a/st2tests/st2tests/http.py +++ b/st2tests/st2tests/http.py @@ -18,7 +18,6 @@ class FakeResponse(object): - def __init__(self, text, status_code, reason): self.text = text self.status_code = status_code diff --git a/st2tests/st2tests/mocks/action.py b/st2tests/st2tests/mocks/action.py index f09d5a7f8d6..ec8f7842b8c 100644 --- a/st2tests/st2tests/mocks/action.py +++ b/st2tests/st2tests/mocks/action.py @@ -25,10 +25,7 @@ from python_runner.python_action_wrapper import ActionService from st2tests.mocks.datastore import MockDatastoreService -__all__ = [ - 'MockActionWrapper', - 'MockActionService' -] +__all__ = ["MockActionWrapper", "MockActionService"] class MockActionWrapper(object): @@ -49,9 +46,11 @@ def __init__(self, action_wrapper): # We use a Mock class so use can assert logger was called with particular arguments self._logger = Mock(spec=RootLogger) - self._datastore_service = MockDatastoreService(logger=self._logger, - pack_name=self._action_wrapper._pack, - class_name=self._action_wrapper._class_name) + self._datastore_service = MockDatastoreService( + logger=self._logger, + pack_name=self._action_wrapper._pack, + class_name=self._action_wrapper._class_name, + ) @property def datastore_service(self): diff --git a/st2tests/st2tests/mocks/auth.py b/st2tests/st2tests/mocks/auth.py index e0624aca42d..6f322959cce 100644 --- a/st2tests/st2tests/mocks/auth.py +++ b/st2tests/st2tests/mocks/auth.py @@ -18,24 +18,18 @@ from st2auth.backends.base import BaseAuthenticationBackend # auser:apassword in b64 -DUMMY_CREDS = 'YXVzZXI6YXBhc3N3b3Jk' +DUMMY_CREDS = "YXVzZXI6YXBhc3N3b3Jk" -__all__ = [ - 'DUMMY_CREDS', - - 'MockAuthBackend', - 'MockRequest', - - 'get_mock_backend' -] +__all__ = ["DUMMY_CREDS", "MockAuthBackend", "MockRequest", "get_mock_backend"] class MockAuthBackend(BaseAuthenticationBackend): groups = [] def authenticate(self, username, password): - return ((username == 'auser' and password == 'apassword') or - (username == 'username' and password == 'password:password')) + return (username == "auser" and password == "apassword") or ( + username == "username" and password == "password:password" + ) def get_user(self, username): return username @@ -44,7 +38,7 @@ def get_user_groups(self, username): return self.groups -class MockRequest(): +class MockRequest: def __init__(self, ttl): self.ttl = ttl diff --git a/st2tests/st2tests/mocks/datastore.py b/st2tests/st2tests/mocks/datastore.py index fe8156bf9ed..0282a18ffdf 100644 --- a/st2tests/st2tests/mocks/datastore.py +++ b/st2tests/st2tests/mocks/datastore.py @@ -22,9 +22,7 @@ from st2common.services.datastore import BaseDatastoreService from st2client.models.keyvalue import KeyValuePair -__all__ = [ - 'MockDatastoreService' -] +__all__ = ["MockDatastoreService"] class MockDatastoreService(BaseDatastoreService): @@ -35,7 +33,7 @@ class MockDatastoreService(BaseDatastoreService): def __init__(self, logger, pack_name, class_name, api_username=None): self._pack_name = pack_name self._class_name = class_name - self._username = api_username or 'admin' + self._username = api_username or "admin" # Holds mock KeyValuePair objects # Key is a KeyValuePair name and value is the KeyValuePair object @@ -53,18 +51,9 @@ def get_user_info(self): :rtype: ``dict`` """ result = { - 'username': self._username, - 'rbac': { - 'is_admin': True, - 'enabled': True, - 'roles': [ - 'admin' - ] - }, - 'authentication': { - 'method': 'authentication token', - 'location': 'header' - } + "username": self._username, + "rbac": {"is_admin": True, "enabled": True, "roles": ["admin"]}, + "authentication": {"method": "authentication token", "location": "header"}, } return result @@ -101,12 +90,16 @@ def get_value(self, name, local=True, scope=SYSTEM_SCOPE, decrypt=False): kvp = self._datastore_items[name] return kvp.value - def set_value(self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False): + def set_value( + self, name, value, ttl=None, local=True, scope=SYSTEM_SCOPE, encrypt=False + ): """ Store a value in a dictionary which is local to this class. """ if ttl: - raise ValueError('MockDatastoreService.set_value doesn\'t support "ttl" argument') + raise ValueError( + 'MockDatastoreService.set_value doesn\'t support "ttl" argument' + ) name = self._get_full_key_name(name=name, local=local) diff --git a/st2tests/st2tests/mocks/execution.py b/st2tests/st2tests/mocks/execution.py index 1fdf8a42629..00e3c8ef111 100644 --- a/st2tests/st2tests/mocks/execution.py +++ b/st2tests/st2tests/mocks/execution.py @@ -21,13 +21,10 @@ from st2common.models.db.execution import ActionExecutionDB -__all__ = [ - 'MockExecutionPublisher' -] +__all__ = ["MockExecutionPublisher"] class MockExecutionPublisher(object): - @classmethod def publish_update(cls, payload): try: @@ -39,7 +36,6 @@ def publish_update(cls, payload): class MockExecutionPublisherNonBlocking(object): - @classmethod def publish_update(cls, payload): try: diff --git a/st2tests/st2tests/mocks/liveaction.py b/st2tests/st2tests/mocks/liveaction.py index 753224d9eab..2b329e6b252 100644 --- a/st2tests/st2tests/mocks/liveaction.py +++ b/st2tests/st2tests/mocks/liveaction.py @@ -26,14 +26,10 @@ from st2common.constants import action as action_constants from st2common.models.db.liveaction import LiveActionDB -__all__ = [ - 'MockLiveActionPublisher', - 'MockLiveActionPublisherNonBlocking' -] +__all__ = ["MockLiveActionPublisher", "MockLiveActionPublisherNonBlocking"] class MockLiveActionPublisher(object): - @classmethod def process(cls, payload): ex_req = scheduling.get_scheduler_entrypoint().process(payload) @@ -106,7 +102,6 @@ def wait_all(cls): class MockLiveActionPublisherSchedulingQueueOnly(object): - @classmethod def process(cls, payload): scheduling.get_scheduler_entrypoint().process(payload) diff --git a/st2tests/st2tests/mocks/runners/async_runner.py b/st2tests/st2tests/mocks/runners/async_runner.py index 04092029038..31258fae4ec 100644 --- a/st2tests/st2tests/mocks/runners/async_runner.py +++ b/st2tests/st2tests/mocks/runners/async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import AsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class AsyncTestRunner(AsyncActionRunner): def __init__(self): - super(AsyncTestRunner, self).__init__(runner_id='1') + super(AsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2tests/st2tests/mocks/runners/polling_async_runner.py b/st2tests/st2tests/mocks/runners/polling_async_runner.py index 435f7eb9b68..c48bb9aa675 100644 --- a/st2tests/st2tests/mocks/runners/polling_async_runner.py +++ b/st2tests/st2tests/mocks/runners/polling_async_runner.py @@ -14,15 +14,16 @@ # limitations under the License. from __future__ import absolute_import + try: import simplejson as json except: import json from st2common.runners.base import PollingAsyncActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_RUNNING) +from st2common.constants.action import LIVEACTION_STATUS_RUNNING -RAISE_PROPERTY = 'raise' +RAISE_PROPERTY = "raise" def get_runner(): @@ -31,7 +32,7 @@ def get_runner(): class PollingAsyncTestRunner(PollingAsyncActionRunner): def __init__(self): - super(PollingAsyncTestRunner, self).__init__(runner_id='1') + super(PollingAsyncTestRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False self.post_run_called = False @@ -43,14 +44,11 @@ def run(self, action_params): self.run_called = True result = {} if self.runner_parameters.get(RAISE_PROPERTY, False): - raise Exception('Raise required.') + raise Exception("Raise required.") else: - result = { - 'ran': True, - 'action_params': action_params - } + result = {"ran": True, "action_params": action_params} - return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {'id': 'foo'}) + return (LIVEACTION_STATUS_RUNNING, json.dumps(result), {"id": "foo"}) def post_run(self, status, result): self.post_run_called = True diff --git a/st2tests/st2tests/mocks/runners/runner.py b/st2tests/st2tests/mocks/runners/runner.py index 40d07516c66..b89b75b712f 100644 --- a/st2tests/st2tests/mocks/runners/runner.py +++ b/st2tests/st2tests/mocks/runners/runner.py @@ -17,12 +17,9 @@ import json from st2common.runners.base import ActionRunner -from st2common.constants.action import (LIVEACTION_STATUS_SUCCEEDED) +from st2common.constants.action import LIVEACTION_STATUS_SUCCEEDED -__all__ = [ - 'get_runner', - 'MockActionRunner' -] +__all__ = ["get_runner", "MockActionRunner"] def get_runner(config=None): @@ -31,7 +28,7 @@ def get_runner(config=None): class MockActionRunner(ActionRunner): def __init__(self): - super(MockActionRunner, self).__init__(runner_id='1') + super(MockActionRunner, self).__init__(runner_id="1") self.pre_run_called = False self.run_called = False @@ -45,22 +42,15 @@ def run(self, action_params): self.run_called = True result = {} - if self.runner_parameters.get('raise', False): - raise Exception('Raise required.') + if self.runner_parameters.get("raise", False): + raise Exception("Raise required.") - default_result = { - 'ran': True, - 'action_params': action_params - } - default_context = { - 'third_party_system': { - 'ref_id': '1234' - } - } + default_result = {"ran": True, "action_params": action_params} + default_context = {"third_party_system": {"ref_id": "1234"}} - status = self.runner_parameters.get('mock_status', LIVEACTION_STATUS_SUCCEEDED) - result = self.runner_parameters.get('mock_result', default_result) - context = self.runner_parameters.get('mock_context', default_context) + status = self.runner_parameters.get("mock_status", LIVEACTION_STATUS_SUCCEEDED) + result = self.runner_parameters.get("mock_result", default_result) + context = self.runner_parameters.get("mock_context", default_context) return (status, json.dumps(result), context) diff --git a/st2tests/st2tests/mocks/sensor.py b/st2tests/st2tests/mocks/sensor.py index 1f06787b141..c65825786cd 100644 --- a/st2tests/st2tests/mocks/sensor.py +++ b/st2tests/st2tests/mocks/sensor.py @@ -27,10 +27,7 @@ from st2reactor.container.sensor_wrapper import SensorService from st2tests.mocks.datastore import MockDatastoreService -__all__ = [ - 'MockSensorWrapper', - 'MockSensorService' -] +__all__ = ["MockSensorWrapper", "MockSensorService"] class MockSensorWrapper(object): @@ -54,9 +51,11 @@ def __init__(self, sensor_wrapper): # Holds a list of triggers which were dispatched self.dispatched_triggers = [] - self._datastore_service = MockDatastoreService(logger=self._logger, - pack_name=self._sensor_wrapper._pack, - class_name=self._sensor_wrapper._class_name) + self._datastore_service = MockDatastoreService( + logger=self._logger, + pack_name=self._sensor_wrapper._pack, + class_name=self._sensor_wrapper._class_name, + ) @property def datastore_service(self): @@ -74,14 +73,11 @@ def get_logger(self, name): def dispatch(self, trigger, payload=None, trace_tag=None): trace_context = TraceContext(trace_tag=trace_tag) if trace_tag else None - return self.dispatch_with_context(trigger=trigger, payload=payload, - trace_context=trace_context) + return self.dispatch_with_context( + trigger=trigger, payload=payload, trace_context=trace_context + ) def dispatch_with_context(self, trigger, payload=None, trace_context=None): - item = { - 'trigger': trigger, - 'payload': payload, - 'trace_context': trace_context - } + item = {"trigger": trigger, "payload": payload, "trace_context": trace_context} self.dispatched_triggers.append(item) return item diff --git a/st2tests/st2tests/mocks/workflow.py b/st2tests/st2tests/mocks/workflow.py index ef50b66389c..051bf5cb835 100644 --- a/st2tests/st2tests/mocks/workflow.py +++ b/st2tests/st2tests/mocks/workflow.py @@ -23,13 +23,10 @@ from st2common.models.db import workflow as wf_ex_db -__all__ = [ - 'MockWorkflowExecutionPublisher' -] +__all__ = ["MockWorkflowExecutionPublisher"] class MockWorkflowExecutionPublisher(object): - @classmethod def publish_create(cls, payload): try: diff --git a/st2tests/st2tests/pack_resource.py b/st2tests/st2tests/pack_resource.py index 7d51d742197..51f58992185 100644 --- a/st2tests/st2tests/pack_resource.py +++ b/st2tests/st2tests/pack_resource.py @@ -19,9 +19,7 @@ from unittest2 import TestCase -__all__ = [ - 'BasePackResourceTestCase' -] +__all__ = ["BasePackResourceTestCase"] class BasePackResourceTestCase(TestCase): @@ -39,16 +37,16 @@ def get_fixture_content(self, fixture_path): :type fixture_path: ``str`` """ base_pack_path = self._get_base_pack_path() - fixtures_path = os.path.join(base_pack_path, 'tests/fixtures/') + fixtures_path = os.path.join(base_pack_path, "tests/fixtures/") fixture_path = os.path.join(fixtures_path, fixture_path) - with open(fixture_path, 'r') as fp: + with open(fixture_path, "r") as fp: content = fp.read() return content def _get_base_pack_path(self): test_file_path = inspect.getfile(self.__class__) - base_pack_path = os.path.join(os.path.dirname(test_file_path), '..') + base_pack_path = os.path.join(os.path.dirname(test_file_path), "..") base_pack_path = os.path.abspath(base_pack_path) return base_pack_path diff --git a/st2tests/st2tests/policies/concurrency.py b/st2tests/st2tests/policies/concurrency.py index e7494a134b4..ecd6ffb51a7 100644 --- a/st2tests/st2tests/policies/concurrency.py +++ b/st2tests/st2tests/policies/concurrency.py @@ -20,11 +20,12 @@ class FakeConcurrencyApplicator(BaseConcurrencyApplicator): - def __init__(self, policy_ref, policy_type, *args, **kwargs): - super(FakeConcurrencyApplicator, self).__init__(policy_ref=policy_ref, - policy_type=policy_type, - threshold=kwargs.get('threshold', 0)) + super(FakeConcurrencyApplicator, self).__init__( + policy_ref=policy_ref, + policy_type=policy_type, + threshold=kwargs.get("threshold", 0), + ) def get_threshold(self): return self.threshold @@ -35,7 +36,8 @@ def apply_before(self, target): target = action_utils.update_liveaction_status( status=action_constants.LIVEACTION_STATUS_CANCELED, liveaction_id=target.id, - publish=False) + publish=False, + ) return target diff --git a/st2tests/st2tests/policies/mock_exception.py b/st2tests/st2tests/policies/mock_exception.py index 298a8cb7bb1..673eccbb549 100644 --- a/st2tests/st2tests/policies/mock_exception.py +++ b/st2tests/st2tests/policies/mock_exception.py @@ -18,9 +18,8 @@ class RaiseExceptionApplicator(base.ResourcePolicyApplicator): - def apply_before(self, target): - raise Exception('For honor!!!!') + raise Exception("For honor!!!!") def apply_after(self, target): return target diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py b/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py index 4994db82b35..6a124573e9c 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/echoer.py @@ -18,4 +18,4 @@ class Echoer(Action): def run(self, action_input): - return {'action_input': action_input} + return {"action_input": action_input} diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py b/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py index 926a56c73f5..2597f811ad1 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/non_simple_type.py @@ -18,14 +18,10 @@ class Test(object): - foo = 'bar' + foo = "bar" class NonSimpleTypeAction(Action): def run(self): - result = [ - {'a': '1'}, - {'c': 2, 'h': 3}, - {'e': Test()} - ] + result = [{"a": "1"}, {"c": 2, "h": 3}, {"e": Test()}] return result diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py b/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py index 3034e0352a7..cacb89d0053 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/pascal_row.py @@ -30,35 +30,39 @@ def run(self, **kwargs): except Exception: pass - self.logger.info('test info log message') - self.logger.debug('test debug log message') - self.logger.error('test error log message') + self.logger.info("test info log message") + self.logger.debug("test debug log message") + self.logger.error("test error log message") return PascalRowAction._compute_pascal_row(**kwargs) @staticmethod def _compute_pascal_row(row_index=0): - print('Pascal row action') + print("Pascal row action") - if row_index == 'a': - return False, 'This is suppose to fail don\'t worry!!' - elif row_index == 'b': + if row_index == "a": + return False, "This is suppose to fail don't worry!!" + elif row_index == "b": return None - elif row_index == 'complex_type': + elif row_index == "complex_type": result = PascalRowAction() return (False, result) - elif row_index == 'c': + elif row_index == "c": return False, None - elif row_index == 'd': - return 'succeeded', [1, 2, 3, 4] - elif row_index == 'e': + elif row_index == "d": + return "succeeded", [1, 2, 3, 4] + elif row_index == "e": return [1, 2] elif row_index == 5: - return [math.factorial(row_index) / - (math.factorial(i) * math.factorial(row_index - i)) - for i in range(row_index + 1)] - elif row_index == 'f': - raise ValueError('Duplicate traceback test') + return [ + math.factorial(row_index) + / (math.factorial(i) * math.factorial(row_index - i)) + for i in range(row_index + 1) + ] + elif row_index == "f": + raise ValueError("Duplicate traceback test") else: - return True, [math.factorial(row_index) / - (math.factorial(i) * math.factorial(row_index - i)) - for i in range(row_index + 1)] + return True, [ + math.factorial(row_index) + / (math.factorial(i) * math.factorial(row_index - i)) + for i in range(row_index + 1) + ] diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py b/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py index f1f888f069a..0bf6856145d 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/print_config_item_doesnt_exist.py @@ -22,5 +22,5 @@ class PrintConfigItemAction(Action): def run(self): print(self.config) # Verify .get() still works - print(self.config.get('item1', 'default_value')) - print(self.config['key']) + print(self.config.get("item1", "default_value")) + print(self.config["key"]) diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py b/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py index 06c0d2f30a3..9838e5bfb62 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/print_to_stdout_and_stderr.py @@ -24,7 +24,7 @@ class PrintToStdoutAndStderrAction(Action): def run(self, stdout_count=3, stderr_count=3): for index in range(0, stdout_count): - sys.stdout.write('stdout line %s\n' % (index)) + sys.stdout.write("stdout line %s\n" % (index)) for index in range(0, stderr_count): - sys.stderr.write('stderr line %s\n' % (index)) + sys.stderr.write("stderr line %s\n" % (index)) diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py b/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py index 717549347b0..ffe7b69b3bc 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/python_paths.py @@ -22,5 +22,5 @@ class PythonPathsAction(Action): def run(self): - print('sys.path: %s' % (sys.path)) - print('PYTHONPATH: %s' % (os.environ.get('PYTHONPATH'))) + print("sys.path: %s" % (sys.path)) + print("PYTHONPATH: %s" % (os.environ.get("PYTHONPATH"))) diff --git a/st2tests/st2tests/resources/packs/pythonactions/actions/test.py b/st2tests/st2tests/resources/packs/pythonactions/actions/test.py index d95939990f0..eeed54fbb09 100644 --- a/st2tests/st2tests/resources/packs/pythonactions/actions/test.py +++ b/st2tests/st2tests/resources/packs/pythonactions/actions/test.py @@ -22,4 +22,4 @@ class TestAction(Action): def run(self): - return 'test action' + return "test action" diff --git a/st2tests/st2tests/sensors.py b/st2tests/st2tests/sensors.py index 0b6f31e6b6b..52c0451f45f 100644 --- a/st2tests/st2tests/sensors.py +++ b/st2tests/st2tests/sensors.py @@ -18,9 +18,7 @@ from st2tests.mocks.sensor import MockSensorService from st2tests.pack_resource import BasePackResourceTestCase -__all__ = [ - 'BaseSensorTestCase' -] +__all__ = ["BaseSensorTestCase"] class BaseSensorTestCase(BasePackResourceTestCase): @@ -37,22 +35,20 @@ def setUp(self): super(BaseSensorTestCase, self).setUp() class_name = self.sensor_cls.__name__ - sensor_wrapper = MockSensorWrapper(pack='tests', class_name=class_name) + sensor_wrapper = MockSensorWrapper(pack="tests", class_name=class_name) self.sensor_service = MockSensorService(sensor_wrapper=sensor_wrapper) def get_sensor_instance(self, config=None, poll_interval=None): """ Retrieve instance of the sensor class. """ - kwargs = { - 'sensor_service': self.sensor_service - } + kwargs = {"sensor_service": self.sensor_service} if config: - kwargs['config'] = config + kwargs["config"] = config if poll_interval is not None: - kwargs['poll_interval'] = poll_interval + kwargs["poll_interval"] = poll_interval instance = self.sensor_cls(**kwargs) # pylint: disable=not-callable return instance @@ -79,15 +75,15 @@ def assertTriggerDispatched(self, trigger, payload=None, trace_context=None): """ dispatched_triggers = self.get_dispatched_triggers() for item in dispatched_triggers: - trigger_matches = (item['trigger'] == trigger) + trigger_matches = item["trigger"] == trigger if payload: - payload_matches = (item['payload'] == payload) + payload_matches = item["payload"] == payload else: payload_matches = True if trace_context: - trace_context_matches = (item['trace_context'] == trace_context) + trace_context_matches = item["trace_context"] == trace_context else: trace_context_matches = True diff --git a/st2tests/testpacks/checks/actions/checks/check_loadavg.py b/st2tests/testpacks/checks/actions/checks/check_loadavg.py index 4a568348326..9439679df35 100755 --- a/st2tests/testpacks/checks/actions/checks/check_loadavg.py +++ b/st2tests/testpacks/checks/actions/checks/check_loadavg.py @@ -23,40 +23,40 @@ def print_load_avg(args): period = args[1] - loadavg_file = '/proc/loadavg' - cpuinfo_file = '/proc/cpuinfo' + loadavg_file = "/proc/loadavg" + cpuinfo_file = "/proc/cpuinfo" cpus = 0 try: - fh = open(loadavg_file, 'r') + fh = open(loadavg_file, "r") load = fh.readline().split()[0:3] fh.close() except: - sys.stderr.write('Error opening %s\n' % loadavg_file) + sys.stderr.write("Error opening %s\n" % loadavg_file) sys.exit(2) try: - fh = open(cpuinfo_file, 'r') + fh = open(cpuinfo_file, "r") for line in fh: - if 'processor' in line: + if "processor" in line: cpus += 1 fh.close() except: - sys.stderr.write('Error opeing %s\n' % cpuinfo_file) + sys.stderr.write("Error opeing %s\n" % cpuinfo_file) - one_min = '1 min load/core: %s' % str(float(load[0]) / cpus) - five_min = '5 min load/core: %s' % str(float(load[1]) / cpus) - fifteen_min = '15 min load/core: %s' % str(float(load[2]) / cpus) + one_min = "1 min load/core: %s" % str(float(load[0]) / cpus) + five_min = "5 min load/core: %s" % str(float(load[1]) / cpus) + fifteen_min = "15 min load/core: %s" % str(float(load[2]) / cpus) - if period == '1' or period == 'one': + if period == "1" or period == "one": print(one_min) - elif period == '5' or period == 'five': + elif period == "5" or period == "five": print(five_min) - elif period == '15' or period == 'fifteen': + elif period == "15" or period == "fifteen": print(fifteen_min) else: print(one_min + " " + five_min + " " + fifteen_min) -if __name__ == '__main__': +if __name__ == "__main__": print_load_avg(sys.argv) diff --git a/tools/config_gen.py b/tools/config_gen.py index e705161ea37..e0004d04e1b 100755 --- a/tools/config_gen.py +++ b/tools/config_gen.py @@ -24,57 +24,57 @@ from oslo_config import cfg -CONFIGS = ['st2actions.config', - 'st2actions.scheduler.config', - 'st2actions.notifier.config', - 'st2actions.workflows.config', - 'st2api.config', - 'st2stream.config', - 'st2auth.config', - 'st2common.config', - 'st2exporter.config', - 'st2reactor.rules.config', - 'st2reactor.sensor.config', - 'st2reactor.timer.config', - 'st2reactor.garbage_collector.config'] - -SKIP_GROUPS = ['api_pecan', 'rbac', 'results_tracker'] +CONFIGS = [ + "st2actions.config", + "st2actions.scheduler.config", + "st2actions.notifier.config", + "st2actions.workflows.config", + "st2api.config", + "st2stream.config", + "st2auth.config", + "st2common.config", + "st2exporter.config", + "st2reactor.rules.config", + "st2reactor.sensor.config", + "st2reactor.timer.config", + "st2reactor.garbage_collector.config", +] + +SKIP_GROUPS = ["api_pecan", "rbac", "results_tracker"] # We group auth options together to make it a bit more clear what applies where AUTH_OPTIONS = { - 'common': [ - 'enable', - 'mode', - 'logging', - 'api_url', - 'token_ttl', - 'service_token_ttl', - 'sso', - 'sso_backend', - 'sso_backend_kwargs', - 'debug' + "common": [ + "enable", + "mode", + "logging", + "api_url", + "token_ttl", + "service_token_ttl", + "sso", + "sso_backend", + "sso_backend_kwargs", + "debug", + ], + "standalone": [ + "host", + "port", + "use_ssl", + "cert", + "key", + "backend", + "backend_kwargs", ], - 'standalone': [ - 'host', - 'port', - 'use_ssl', - 'cert', - 'key', - 'backend', - 'backend_kwargs' - ] } # Some of the config values change depending on the environment where this script is ran so we # set them to static values to ensure consistent and stable output STATIC_OPTION_VALUES = { - 'actionrunner': { - 'virtualenv_binary': '/usr/bin/virtualenv', - 'python_binary': '/usr/bin/python', + "actionrunner": { + "virtualenv_binary": "/usr/bin/virtualenv", + "python_binary": "/usr/bin/python", }, - 'webui': { - 'webui_base_url': 'https://localhost' - } + "webui": {"webui_base_url": "https://localhost"}, } COMMON_AUTH_OPTIONS_COMMENT = """ @@ -112,22 +112,28 @@ def _clear_config(): def _read_group(opt_group): all_options = list(opt_group._opts.values()) - if opt_group.name == 'auth': + if opt_group.name == "auth": print(COMMON_AUTH_OPTIONS_COMMENT) - print('') - common_options = [option for option in all_options if option['opt'].name in - AUTH_OPTIONS['common']] + print("") + common_options = [ + option + for option in all_options + if option["opt"].name in AUTH_OPTIONS["common"] + ] _print_options(opt_group=opt_group, options=common_options) - print('') + print("") print(STANDALONE_AUTH_OPTIONS_COMMENT) - print('') - standalone_options = [option for option in all_options if option['opt'].name in - AUTH_OPTIONS['standalone']] + print("") + standalone_options = [ + option + for option in all_options + if option["opt"].name in AUTH_OPTIONS["standalone"] + ] _print_options(opt_group=opt_group, options=standalone_options) if len(common_options) + len(standalone_options) != len(all_options): - msg = ('Not all options are declared in AUTH_OPTIONS dict, please update it') + msg = "Not all options are declared in AUTH_OPTIONS dict, please update it" raise Exception(msg) else: options = all_options @@ -137,33 +143,35 @@ def _read_group(opt_group): def _read_groups(opt_groups): opt_groups = collections.OrderedDict(sorted(opt_groups.items())) for name, opt_group in six.iteritems(opt_groups): - print('[%s]' % name) + print("[%s]" % name) _read_group(opt_group) - print('') + print("") def _print_options(opt_group, options): - for opt in sorted(options, key=lambda x: x['opt'].name): - opt = opt['opt'] + for opt in sorted(options, key=lambda x: x["opt"].name): + opt = opt["opt"] # Special case for options which could change during this script run - static_option_value = STATIC_OPTION_VALUES.get(opt_group.name, {}).get(opt.name, None) + static_option_value = STATIC_OPTION_VALUES.get(opt_group.name, {}).get( + opt.name, None + ) if static_option_value: opt.default = static_option_value # Special handling for list options if isinstance(opt, cfg.ListOpt): if opt.default: - value = ','.join(opt.default) + value = ",".join(opt.default) else: - value = '' + value = "" - value += ' # comma separated list allowed here.' + value += " # comma separated list allowed here." else: value = opt.default - print('# %s' % opt.help) - print('%s = %s' % (opt.name, value)) + print("# %s" % opt.help) + print("%s = %s" % (opt.name, value)) def main(args): @@ -176,5 +184,5 @@ def main(args): _read_groups(opt_groups) -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv) diff --git a/tools/diff-db-disk.py b/tools/diff-db-disk.py index ec09a767094..a9e65d72ea9 100755 --- a/tools/diff-db-disk.py +++ b/tools/diff-db-disk.py @@ -47,20 +47,20 @@ from st2common.persistence.action import Action registrar = ResourceRegistrar() -registrar.ALLOWED_EXTENSIONS = ['.yaml', '.yml', '.json'] +registrar.ALLOWED_EXTENSIONS = [".yaml", ".yml", ".json"] meta_loader = MetaLoader() API_MODELS_ARTIFACT_TYPES = { - 'actions': ActionAPI, - 'sensors': SensorTypeAPI, - 'rules': RuleAPI + "actions": ActionAPI, + "sensors": SensorTypeAPI, + "rules": RuleAPI, } API_MODELS_PERSISTENT_MODELS = { Action: ActionAPI, SensorType: SensorTypeAPI, - Rule: RuleAPI + Rule: RuleAPI, } @@ -77,13 +77,15 @@ def _get_api_models_from_db(persistence_model, pack_dir=None): filters = {} if pack_dir: pack_name = os.path.basename(os.path.normpath(pack_dir)) - filters = {'pack': pack_name} + filters = {"pack": pack_name} models = persistence_model.query(**filters) models_dict = {} for model in models: - model_pack = getattr(model, 'pack', None) or DEFAULT_PACK_NAME - model_ref = ResourceReference.to_string_reference(name=model.name, pack=model_pack) - if getattr(model, 'id', None): + model_pack = getattr(model, "pack", None) or DEFAULT_PACK_NAME + model_ref = ResourceReference.to_string_reference( + name=model.name, pack=model_pack + ) + if getattr(model, "id", None): del model.id API_MODEL = API_MODELS_PERSISTENT_MODELS[persistence_model] models_dict[model_ref] = API_MODEL.from_model(model) @@ -107,15 +109,14 @@ def _get_api_models_from_disk(artifact_type, pack_dir=None): artifacts_paths = registrar.get_resources_from_pack(pack_path) for artifact_path in artifacts_paths: artifact = meta_loader.load(artifact_path) - if artifact_type == 'sensors': + if artifact_type == "sensors": sensors_dir = os.path.dirname(artifact_path) - sensor_file_path = os.path.join(sensors_dir, artifact['entry_point']) - artifact['artifact_uri'] = 'file://' + sensor_file_path - name = artifact.get('name', None) or artifact.get('class_name', None) - if not artifact.get('pack', None): - artifact['pack'] = pack_name - ref = ResourceReference.to_string_reference(name=name, - pack=pack_name) + sensor_file_path = os.path.join(sensors_dir, artifact["entry_point"]) + artifact["artifact_uri"] = "file://" + sensor_file_path + name = artifact.get("name", None) or artifact.get("class_name", None) + if not artifact.get("pack", None): + artifact["pack"] = pack_name + ref = ResourceReference.to_string_reference(name=name, pack=pack_name) API_MODEL = API_MODELS_ARTIFACT_TYPES[artifact_type] # Following conversions are required because we add some fields with # default values in db model. If we don't do these conversions, @@ -128,42 +129,49 @@ def _get_api_models_from_disk(artifact_type, pack_dir=None): return artifacts_dict -def _content_diff(artifact_type=None, artifact_in_disk=None, artifact_in_db=None, - verbose=False): +def _content_diff( + artifact_type=None, artifact_in_disk=None, artifact_in_db=None, verbose=False +): artifact_in_disk_str = json.dumps( - artifact_in_disk.__json__(), sort_keys=True, - indent=4, separators=(',', ': ') + artifact_in_disk.__json__(), sort_keys=True, indent=4, separators=(",", ": ") ) artifact_in_db_str = json.dumps( - artifact_in_db.__json__(), sort_keys=True, - indent=4, separators=(',', ': ') + artifact_in_db.__json__(), sort_keys=True, indent=4, separators=(",", ": ") + ) + diffs = difflib.context_diff( + artifact_in_db_str.splitlines(), + artifact_in_disk_str.splitlines(), + fromfile="DB contents", + tofile="Disk contents", ) - diffs = difflib.context_diff(artifact_in_db_str.splitlines(), - artifact_in_disk_str.splitlines(), - fromfile='DB contents', tofile='Disk contents') printed = False for diff in diffs: if not printed: - identifier = getattr(artifact_in_db, 'ref', getattr(artifact_in_db, 'name')) - print('%s %s in db differs from what is in disk.' % (artifact_type.upper(), - identifier)) + identifier = getattr(artifact_in_db, "ref", getattr(artifact_in_db, "name")) + print( + "%s %s in db differs from what is in disk." + % (artifact_type.upper(), identifier) + ) printed = True print(diff) if verbose: - print('\n\nOriginal contents:') - print('===================\n') - print('Artifact in db:\n\n%s\n\n' % artifact_in_db_str) - print('Artifact in disk:\n\n%s\n\n' % artifact_in_disk_str) + print("\n\nOriginal contents:") + print("===================\n") + print("Artifact in db:\n\n%s\n\n" % artifact_in_db_str) + print("Artifact in disk:\n\n%s\n\n" % artifact_in_disk_str) -def _diff(persistence_model, artifact_type, pack_dir=None, verbose=True, - content_diff=True): +def _diff( + persistence_model, artifact_type, pack_dir=None, verbose=True, content_diff=True +): artifacts_in_db_dict = _get_api_models_from_db(persistence_model, pack_dir=pack_dir) artifacts_in_disk_dict = _get_api_models_from_disk(artifact_type, pack_dir=pack_dir) # print(artifacts_in_disk_dict) - all_artifacts = set(list(artifacts_in_db_dict.keys()) + list(artifacts_in_disk_dict.keys())) + all_artifacts = set( + list(artifacts_in_db_dict.keys()) + list(artifacts_in_disk_dict.keys()) + ) for artifact in all_artifacts: artifact_in_db = artifacts_in_db_dict.get(artifact, None) @@ -172,76 +180,96 @@ def _diff(persistence_model, artifact_type, pack_dir=None, verbose=True, artifact_in_db_pretty_json = None if verbose: - print('******************************************************************************') - print('Checking if artifact %s is present in both disk and db.' % artifact) + print( + "******************************************************************************" + ) + print("Checking if artifact %s is present in both disk and db." % artifact) if not artifact_in_db: - print('##############################################################################') - print('%s %s in disk not available in db.' % (artifact_type.upper(), artifact)) + print( + "##############################################################################" + ) + print( + "%s %s in disk not available in db." % (artifact_type.upper(), artifact) + ) artifact_in_disk_pretty_json = json.dumps( - artifact_in_disk.__json__(), sort_keys=True, - indent=4, separators=(',', ': ') + artifact_in_disk.__json__(), + sort_keys=True, + indent=4, + separators=(",", ": "), ) if verbose: - print('File contents: \n') + print("File contents: \n") print(artifact_in_disk_pretty_json) continue if not artifact_in_disk: - print('##############################################################################') - print('%s %s in db not available in disk.' % (artifact_type.upper(), artifact)) + print( + "##############################################################################" + ) + print( + "%s %s in db not available in disk." % (artifact_type.upper(), artifact) + ) artifact_in_db_pretty_json = json.dumps( - artifact_in_db.__json__(), sort_keys=True, - indent=4, separators=(',', ': ') + artifact_in_db.__json__(), + sort_keys=True, + indent=4, + separators=(",", ": "), ) if verbose: - print('DB contents: \n') + print("DB contents: \n") print(artifact_in_db_pretty_json) continue if verbose: - print('Artifact %s exists in both disk and db.' % artifact) + print("Artifact %s exists in both disk and db." % artifact) if content_diff: if verbose: - print('Performing content diff for artifact %s.' % artifact) + print("Performing content diff for artifact %s." % artifact) - _content_diff(artifact_type=artifact_type, - artifact_in_disk=artifact_in_disk, - artifact_in_db=artifact_in_db, - verbose=verbose) + _content_diff( + artifact_type=artifact_type, + artifact_in_disk=artifact_in_disk, + artifact_in_db=artifact_in_db, + verbose=verbose, + ) def _diff_actions(pack_dir=None, verbose=False, content_diff=True): - _diff(Action, 'actions', pack_dir=pack_dir, - verbose=verbose, content_diff=content_diff) + _diff( + Action, "actions", pack_dir=pack_dir, verbose=verbose, content_diff=content_diff + ) def _diff_sensors(pack_dir=None, verbose=False, content_diff=True): - _diff(SensorType, 'sensors', pack_dir=pack_dir, - verbose=verbose, content_diff=content_diff) + _diff( + SensorType, + "sensors", + pack_dir=pack_dir, + verbose=verbose, + content_diff=content_diff, + ) def _diff_rules(pack_dir=None, verbose=True, content_diff=True): - _diff(Rule, 'rules', pack_dir=pack_dir, - verbose=verbose, content_diff=content_diff) + _diff(Rule, "rules", pack_dir=pack_dir, verbose=verbose, content_diff=content_diff) def main(): monkey_patch() cli_opts = [ - cfg.BoolOpt('sensors', default=False, - help='diff sensor alone.'), - cfg.BoolOpt('actions', default=False, - help='diff actions alone.'), - cfg.BoolOpt('rules', default=False, - help='diff rules alone.'), - cfg.BoolOpt('all', default=False, - help='diff sensors, actions and rules.'), - cfg.BoolOpt('verbose', default=False), - cfg.BoolOpt('simple', default=False, - help='In simple mode, tool only tells you if content is missing.' + - 'It doesn\'t show you content diff between disk and db.'), - cfg.StrOpt('pack-dir', default=None, help='Path to specific pack to diff.') + cfg.BoolOpt("sensors", default=False, help="diff sensor alone."), + cfg.BoolOpt("actions", default=False, help="diff actions alone."), + cfg.BoolOpt("rules", default=False, help="diff rules alone."), + cfg.BoolOpt("all", default=False, help="diff sensors, actions and rules."), + cfg.BoolOpt("verbose", default=False), + cfg.BoolOpt( + "simple", + default=False, + help="In simple mode, tool only tells you if content is missing." + + "It doesn't show you content diff between disk and db.", + ), + cfg.StrOpt("pack-dir", default=None, help="Path to specific pack to diff."), ] do_register_cli_opts(cli_opts) config.parse_args() @@ -254,23 +282,35 @@ def main(): content_diff = not cfg.CONF.simple if cfg.CONF.all: - _diff_sensors(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) - _diff_actions(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) - _diff_rules(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) + _diff_sensors( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) + _diff_actions( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) + _diff_rules( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) return if cfg.CONF.sensors: - _diff_sensors(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) + _diff_sensors( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) if cfg.CONF.actions: - _diff_actions(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) + _diff_actions( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) if cfg.CONF.rules: - _diff_rules(pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff) + _diff_rules( + pack_dir=pack_dir, verbose=cfg.CONF.verbose, content_diff=content_diff + ) # Disconnect from db. db_teardown() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/direct_queue_publisher.py b/tools/direct_queue_publisher.py index bc012420859..0da7dd0b081 100755 --- a/tools/direct_queue_publisher.py +++ b/tools/direct_queue_publisher.py @@ -22,26 +22,27 @@ def main(queue, payload): - connection = pika.BlockingConnection(pika.ConnectionParameters( - host='localhost', - credentials=pika.credentials.PlainCredentials(username='guest', password='guest'))) + connection = pika.BlockingConnection( + pika.ConnectionParameters( + host="localhost", + credentials=pika.credentials.PlainCredentials( + username="guest", password="guest" + ), + ) + ) channel = connection.channel() channel.queue_declare(queue=queue, durable=True) - channel.basic_publish(exchange='', - routing_key=queue, - body=payload) + channel.basic_publish(exchange="", routing_key=queue, body=payload) print("Sent %s" % payload) connection.close() -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Direct queue publisher') - parser.add_argument('--queue', required=True, - help='Routing key to use') - parser.add_argument('--payload', required=True, - help='Message payload') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Direct queue publisher") + parser.add_argument("--queue", required=True, help="Routing key to use") + parser.add_argument("--payload", required=True, help="Message payload") args = parser.parse_args() main(queue=args.queue, payload=args.payload) diff --git a/tools/enumerate-runners.py b/tools/enumerate-runners.py index 96104074114..9cae10cd18a 100755 --- a/tools/enumerate-runners.py +++ b/tools/enumerate-runners.py @@ -20,15 +20,18 @@ from st2common.runners import get_backend_driver from st2common import config + config.parse_args() runner_names = get_available_backends() -print('Available / installed action runners:') +print("Available / installed action runners:") for name in runner_names: runner_driver = get_backend_driver(name) runner_instance = runner_driver.get_runner() runner_metadata = runner_driver.get_metadata() - print('- %s (runner_module=%s,cls=%s)' % (name, runner_metadata['runner_module'], - runner_instance.__class__)) + print( + "- %s (runner_module=%s,cls=%s)" + % (name, runner_metadata["runner_module"], runner_instance.__class__) + ) diff --git a/tools/json2yaml.py b/tools/json2yaml.py index 29959949e84..5aecb3711e7 100755 --- a/tools/json2yaml.py +++ b/tools/json2yaml.py @@ -21,6 +21,7 @@ from __future__ import absolute_import import argparse import fnmatch + try: import simplejson as json except ImportError: @@ -33,7 +34,7 @@ PRINT = pprint.pprint -YAML_HEADER = '---' +YAML_HEADER = "---" def get_files_matching_pattern(dir_, pattern): @@ -47,47 +48,47 @@ def get_files_matching_pattern(dir_, pattern): def json_2_yaml_convert(filename): data = None try: - with open(filename, 'r') as json_file: + with open(filename, "r") as json_file: data = json.load(json_file) except: - PRINT('Failed on {}'.format(filename)) + PRINT("Failed on {}".format(filename)) traceback.print_exc() - return (filename, '') - new_filename = os.path.splitext(filename)[0] + '.yaml' - with open(new_filename, 'w') as yaml_file: - yaml_file.write(YAML_HEADER + '\n') + return (filename, "") + new_filename = os.path.splitext(filename)[0] + ".yaml" + with open(new_filename, "w") as yaml_file: + yaml_file.write(YAML_HEADER + "\n") yaml_file.write(yaml.safe_dump(data, default_flow_style=False)) return (filename, new_filename) def git_rm(filename): try: - subprocess.check_call(['git', 'rm', filename]) + subprocess.check_call(["git", "rm", filename]) except subprocess.CalledProcessError: - PRINT('Failed to git rm {}'.format(filename)) + PRINT("Failed to git rm {}".format(filename)) traceback.print_exc() return (False, filename) return (True, filename) def main(dir_, skip_convert): - files = get_files_matching_pattern(dir_, '*.json') + files = get_files_matching_pattern(dir_, "*.json") if skip_convert: PRINT(files) return results = [json_2_yaml_convert(filename) for filename in files] - PRINT('*** conversion done ***') - PRINT(['converted {} to {}'.format(result[0], result[1]) for result in results]) + PRINT("*** conversion done ***") + PRINT(["converted {} to {}".format(result[0], result[1]) for result in results]) results = [git_rm(filename) for filename, new_filename in results if new_filename] - PRINT('*** git rm done ***') + PRINT("*** git rm done ***") -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='json2yaml converter.') - parser.add_argument('--dir', '-d', required=True, - help='The dir to look for json.') - parser.add_argument('--skipconvert', '-s', action='store_true', - help='Skip conversion') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="json2yaml converter.") + parser.add_argument("--dir", "-d", required=True, help="The dir to look for json.") + parser.add_argument( + "--skipconvert", "-s", action="store_true", help="Skip conversion" + ) args = parser.parse_args() main(dir_=args.dir, skip_convert=args.skipconvert) diff --git a/tools/list_group_members.py b/tools/list_group_members.py index e811eabd005..9cf575b62e3 100755 --- a/tools/list_group_members.py +++ b/tools/list_group_members.py @@ -31,24 +31,26 @@ def main(group_id=None): if not group_id: group_ids = list(coordinator.get_groups().get()) - group_ids = [item.decode('utf-8') for item in group_ids] + group_ids = [item.decode("utf-8") for item in group_ids] - print('Available groups (%s):' % (len(group_ids))) + print("Available groups (%s):" % (len(group_ids))) for group_id in group_ids: - print(' - %s' % (group_id)) - print('') + print(" - %s" % (group_id)) + print("") else: group_ids = [group_id] for group_id in group_ids: member_ids = list(coordinator.get_members(group_id).get()) - member_ids = [member_id.decode('utf-8') for member_id in member_ids] + member_ids = [member_id.decode("utf-8") for member_id in member_ids] print('Members in group "%s" (%s):' % (group_id, len(member_ids))) for member_id in member_ids: - capabilities = coordinator.get_member_capabilities(group_id, member_id).get() - print(' - %s (capabilities=%s)' % (member_id, str(capabilities))) + capabilities = coordinator.get_member_capabilities( + group_id, member_id + ).get() + print(" - %s (capabilities=%s)" % (member_id, str(capabilities))) def do_register_cli_opts(opts, ignore_errors=False): @@ -60,11 +62,13 @@ def do_register_cli_opts(opts, ignore_errors=False): raise -if __name__ == '__main__': +if __name__ == "__main__": cli_opts = [ - cfg.StrOpt('group-id', default=None, - help='If provided, only list members for that group.'), - + cfg.StrOpt( + "group-id", + default=None, + help="If provided, only list members for that group.", + ), ] do_register_cli_opts(cli_opts) config.parse_args() diff --git a/tools/log_watcher.py b/tools/log_watcher.py index cafcb4efec2..b16af95cc11 100755 --- a/tools/log_watcher.py +++ b/tools/log_watcher.py @@ -27,25 +27,9 @@ LOG_ALERT_PERCENT = 5 # default. -EVILS = [ - 'info', - 'debug', - 'warning', - 'exception', - 'error', - 'audit' -] - -LOG_VARS = [ - 'LOG', - 'Log', - 'log', - 'LOGGER', - 'Logger', - 'logger', - 'logging', - 'LOGGING' -] +EVILS = ["info", "debug", "warning", "exception", "error", "audit"] + +LOG_VARS = ["LOG", "Log", "log", "LOGGER", "Logger", "logger", "logging", "LOGGING"] FILE_LOG_COUNT = collections.defaultdict() FILE_LINE_COUNT = collections.defaultdict() @@ -55,25 +39,25 @@ def _parse_args(args): global LOG_ALERT_PERCENT params = {} if len(args) > 1: - params['alert_percent'] = args[1] + params["alert_percent"] = args[1] LOG_ALERT_PERCENT = int(args[1]) return params def _skip_file(filename): - if filename.startswith('.') or filename.startswith('_'): + if filename.startswith(".") or filename.startswith("_"): return True def _get_files(dir_path): if not os.path.exists(dir_path): - print('Directory %s doesn\'t exist.' % dir_path) + print("Directory %s doesn't exist." % dir_path) files = [] - exclude = set(['virtualenv', 'build', '.tox']) + exclude = set(["virtualenv", "build", ".tox"]) for root, dirnames, filenames in os.walk(dir_path): dirnames[:] = [d for d in dirnames if d not in exclude] - for filename in fnmatch.filter(filenames, '*.py'): + for filename in fnmatch.filter(filenames, "*.py"): if not _skip_file(filename): files.append(os.path.join(root, filename)) return files @@ -84,7 +68,7 @@ def _build_regex(): regex_strings = {} regexes = {} for level in EVILS: - regex_string = '|'.join([r'\.'.join([log, level]) for log in LOG_VARS]) + regex_string = "|".join([r"\.".join([log, level]) for log in LOG_VARS]) regex_strings[level] = regex_string # print('Level: %s, regex_string: %s' % (level, regex_strings[level])) regexes[level] = re.compile(regex_strings[level]) @@ -98,7 +82,7 @@ def _regex_match(line, regexes): def _build_str_matchers(): match_strings = {} for level in EVILS: - match_strings[level] = ['.'.join([log, level]) for log in LOG_VARS] + match_strings[level] = [".".join([log, level]) for log in LOG_VARS] return match_strings @@ -107,8 +91,10 @@ def _get_log_count_dict(): def _alert(fil, lines, logs, logs_level): - print('WARNING: Too many logs!!!: File: %s, total lines: %d, log lines: %d, percent: %f, ' - 'logs: %s' % (fil, lines, logs, float(logs) / lines * 100, logs_level)) + print( + "WARNING: Too many logs!!!: File: %s, total lines: %d, log lines: %d, percent: %f, " + "logs: %s" % (fil, lines, logs, float(logs) / lines * 100, logs_level) + ) def _match(line, match_strings): @@ -117,7 +103,7 @@ def _match(line, match_strings): if line.startswith(match_string): # print('Line: %s, match: %s' % (line, match_string)) return True, level, line - return False, 'UNKNOWN', line + return False, "UNKNOWN", line def _detect_log_lines(fil, matchers): @@ -148,23 +134,45 @@ def _post_process(file_dir): if total_log_count > 0: if float(total_log_count) / lines * 100 > LOG_ALERT_PERCENT: if file_dir in fil: - fil = fil[len(file_dir) + 1:] - alerts.append([fil, lines, total_log_count, float(total_log_count) / lines * 100, - log_lines_count_level['audit'], - log_lines_count_level['exception'], - log_lines_count_level['error'], - log_lines_count_level['warning'], - log_lines_count_level['info'], - log_lines_count_level['debug']]) + fil = fil[len(file_dir) + 1 :] + alerts.append( + [ + fil, + lines, + total_log_count, + float(total_log_count) / lines * 100, + log_lines_count_level["audit"], + log_lines_count_level["exception"], + log_lines_count_level["error"], + log_lines_count_level["warning"], + log_lines_count_level["info"], + log_lines_count_level["debug"], + ] + ) # sort by percent alerts.sort(key=lambda alert: alert[3], reverse=True) - print(tabulate(alerts, headers=['File', 'Lines', 'Logs', 'Percent', 'adt', 'exc', 'err', 'wrn', - 'inf', 'dbg'])) + print( + tabulate( + alerts, + headers=[ + "File", + "Lines", + "Logs", + "Percent", + "adt", + "exc", + "err", + "wrn", + "inf", + "dbg", + ], + ) + ) def main(args): params = _parse_args(args) - file_dir = params.get('dir', os.getcwd()) + file_dir = params.get("dir", os.getcwd()) files = _get_files(file_dir) matchers = _build_str_matchers() for f in files: @@ -172,5 +180,5 @@ def main(args): _post_process(file_dir) -if __name__ == '__main__': +if __name__ == "__main__": main(sys.argv) diff --git a/tools/migrate_messaging_setup.py b/tools/migrate_messaging_setup.py index 095af26e0d4..3fea8cab83f 100755 --- a/tools/migrate_messaging_setup.py +++ b/tools/migrate_messaging_setup.py @@ -36,11 +36,13 @@ class Migrate_0_13_x_to_1_1_0(object): # changes or changes in durability proeprties. OLD_QS = [ # Name changed in 1.1 - reactor.get_trigger_cud_queue('st2.trigger.watch.timers', routing_key='#'), + reactor.get_trigger_cud_queue("st2.trigger.watch.timers", routing_key="#"), # Split to multiple queues in 1.1 - reactor.get_trigger_cud_queue('st2.trigger.watch.sensorwrapper', routing_key='#'), + reactor.get_trigger_cud_queue( + "st2.trigger.watch.sensorwrapper", routing_key="#" + ), # Name changed in 1.1 - reactor.get_trigger_cud_queue('st2.trigger.watch.webhooks', routing_key='#') + reactor.get_trigger_cud_queue("st2.trigger.watch.webhooks", routing_key="#"), ] def migrate(self): @@ -53,7 +55,7 @@ def _cleanup_old_queues(self): try: bound_q.delete() except: - print('Failed to delete %s.' % q.name) + print("Failed to delete %s." % q.name) traceback.print_exc() @@ -62,10 +64,10 @@ def main(): migrator = Migrate_0_13_x_to_1_1_0() migrator.migrate() except: - print('Messaging setup migration failed.') + print("Messaging setup migration failed.") traceback.print_exc() -if __name__ == '__main__': +if __name__ == "__main__": config.parse_args(args={}) main() diff --git a/tools/migrate_rules_to_include_pack.py b/tools/migrate_rules_to_include_pack.py index 8afd3faa157..1acdd26383e 100755 --- a/tools/migrate_rules_to_include_pack.py +++ b/tools/migrate_rules_to_include_pack.py @@ -31,8 +31,11 @@ class Migration(object): - class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin, - stormbase.ContentPackResourceMixin): + class RuleDB( + stormbase.StormFoundationDB, + stormbase.TagsMixin, + stormbase.ContentPackResourceMixin, + ): """Specifies the action to invoke on the occurrence of a Trigger. It also includes the transformation to perform to match the impedance between the payload of a TriggerInstance and input of a action. @@ -43,22 +46,23 @@ class RuleDB(stormbase.StormFoundationDB, stormbase.TagsMixin, status: enabled or disabled. If disabled occurrence of the trigger does not lead to execution of a action and vice-versa. """ + name = me.StringField(required=True) ref = me.StringField(required=True) description = me.StringField() pack = me.StringField( - required=False, - help_text='Name of the content pack.', - unique_with='name') + required=False, help_text="Name of the content pack.", unique_with="name" + ) trigger = me.StringField() criteria = stormbase.EscapedDictField() action = me.EmbeddedDocumentField(ActionExecutionSpecDB) - enabled = me.BooleanField(required=True, default=True, - help_text=u'Flag indicating whether the rule is enabled.') + enabled = me.BooleanField( + required=True, + default=True, + help_text="Flag indicating whether the rule is enabled.", + ) - meta = { - 'indexes': stormbase.TagsMixin.get_indexes() - } + meta = {"indexes": stormbase.TagsMixin.get_indexes()} # specialized access objects @@ -76,15 +80,17 @@ class RuleDB(stormbase.StormBaseDB, stormbase.TagsMixin): status: enabled or disabled. If disabled occurrence of the trigger does not lead to execution of a action and vice-versa. """ + trigger = me.StringField() criteria = stormbase.EscapedDictField() action = me.EmbeddedDocumentField(ActionExecutionSpecDB) - enabled = me.BooleanField(required=True, default=True, - help_text=u'Flag indicating whether the rule is enabled.') + enabled = me.BooleanField( + required=True, + default=True, + help_text="Flag indicating whether the rule is enabled.", + ) - meta = { - 'indexes': stormbase.TagsMixin.get_indexes() - } + meta = {"indexes": stormbase.TagsMixin.get_indexes()} rule_access_without_pack = MongoDBAccess(RuleDB) @@ -100,7 +106,7 @@ def _get_impl(cls): @classmethod def _get_by_object(cls, object): # For Rule name is unique. - name = getattr(object, 'name', '') + name = getattr(object, "name", "") return cls.get_by_name(name) @@ -126,13 +132,14 @@ def migrate_rules(): action=rule.action, enabled=rule.enabled, pack=DEFAULT_PACK_NAME, - ref=ResourceReference.to_string_reference(pack=DEFAULT_PACK_NAME, - name=rule.name) + ref=ResourceReference.to_string_reference( + pack=DEFAULT_PACK_NAME, name=rule.name + ), ) - print('Migrating rule: %s to rule: %s' % (rule.name, rule_with_pack.ref)) + print("Migrating rule: %s to rule: %s" % (rule.name, rule_with_pack.ref)) RuleWithPack.add_or_update(rule_with_pack) except Exception as e: - print('Migration failed. %s' % six.text_type(e)) + print("Migration failed. %s" % six.text_type(e)) def main(): @@ -148,5 +155,5 @@ def main(): db_teardown() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/migrate_triggers_to_include_ref_count.py b/tools/migrate_triggers_to_include_ref_count.py index 3e8f1b79f07..af98a00a07a 100755 --- a/tools/migrate_triggers_to_include_ref_count.py +++ b/tools/migrate_triggers_to_include_ref_count.py @@ -27,7 +27,6 @@ class TriggerMigrator(object): - def _get_trigger_with_parameters(self): """ All TriggerDB that has a parameter. @@ -38,7 +37,7 @@ def _get_rules_for_trigger(self, trigger_ref): """ All rules that reference the supplied trigger_ref. """ - return Rule.get_all(**{'trigger': trigger_ref}) + return Rule.get_all(**{"trigger": trigger_ref}) def _update_trigger_ref_count(self, trigger_db, ref_count): """ @@ -56,7 +55,7 @@ def migrate(self): trigger_ref = trigger_db.get_reference().ref rules = self._get_rules_for_trigger(trigger_ref=trigger_ref) ref_count = len(rules) - print('Updating Trigger %s to ref_count %s' % (trigger_ref, ref_count)) + print("Updating Trigger %s to ref_count %s" % (trigger_ref, ref_count)) self._update_trigger_ref_count(trigger_db=trigger_db, ref_count=ref_count) @@ -76,5 +75,5 @@ def main(): teartown() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/queue_consumer.py b/tools/queue_consumer.py index bf19cbf1d55..69de164fbd8 100755 --- a/tools/queue_consumer.py +++ b/tools/queue_consumer.py @@ -37,28 +37,31 @@ def __init__(self, connection, queue): self.queue = queue def get_consumers(self, Consumer, channel): - return [Consumer(queues=[self.queue], - accept=['pickle'], - callbacks=[self.process_task])] + return [ + Consumer( + queues=[self.queue], accept=["pickle"], callbacks=[self.process_task] + ) + ] def process_task(self, body, message): - print('===================================================') - print('Received message') - print('message.properties:') + print("===================================================") + print("Received message") + print("message.properties:") pprint(message.properties) - print('message.delivery_info:') + print("message.delivery_info:") pprint(message.delivery_info) - print('body:') + print("body:") pprint(body) - print('===================================================') + print("===================================================") message.ack() -def main(queue, exchange, routing_key='#'): - exchange = Exchange(exchange, type='topic') - queue = Queue(name=queue, exchange=exchange, routing_key=routing_key, - auto_delete=True) +def main(queue, exchange, routing_key="#"): + exchange = Exchange(exchange, type="topic") + queue = Queue( + name=queue, exchange=exchange, routing_key=routing_key, auto_delete=True + ) with transport_utils.get_connection() as connection: connection.connect() @@ -66,13 +69,11 @@ def main(queue, exchange, routing_key='#'): watcher.run() -if __name__ == '__main__': +if __name__ == "__main__": config.parse_args(args={}) - parser = argparse.ArgumentParser(description='Queue consumer') - parser.add_argument('--exchange', required=True, - help='Exchange to listen on') - parser.add_argument('--routing-key', default='#', - help='Routing key') + parser = argparse.ArgumentParser(description="Queue consumer") + parser.add_argument("--exchange", required=True, help="Exchange to listen on") + parser.add_argument("--routing-key", default="#", help="Routing key") args = parser.parse_args() queue_name = args.exchange + str(random.randint(1, 10000)) diff --git a/tools/queue_producer.py b/tools/queue_producer.py index c9360886768..01476a26b81 100755 --- a/tools/queue_producer.py +++ b/tools/queue_producer.py @@ -30,22 +30,20 @@ def main(exchange, routing_key, payload): - exchange = Exchange(exchange, type='topic') + exchange = Exchange(exchange, type="topic") publisher = PoolPublisher() publisher.publish(payload=payload, exchange=exchange, routing_key=routing_key) eventlet.sleep(0.5) -if __name__ == '__main__': +if __name__ == "__main__": config.parse_args(args={}) - parser = argparse.ArgumentParser(description='Queue producer') - parser.add_argument('--exchange', required=True, - help='Exchange to publish the message to') - parser.add_argument('--routing-key', required=True, - help='Routing key to use') - parser.add_argument('--payload', required=True, - help='Message payload') + parser = argparse.ArgumentParser(description="Queue producer") + parser.add_argument( + "--exchange", required=True, help="Exchange to publish the message to" + ) + parser.add_argument("--routing-key", required=True, help="Routing key to use") + parser.add_argument("--payload", required=True, help="Message payload") args = parser.parse_args() - main(exchange=args.exchange, routing_key=args.routing_key, - payload=args.payload) + main(exchange=args.exchange, routing_key=args.routing_key, payload=args.payload) diff --git a/tools/st2-analyze-links.py b/tools/st2-analyze-links.py index 4daeeafa44c..f66c158dea4 100644 --- a/tools/st2-analyze-links.py +++ b/tools/st2-analyze-links.py @@ -44,8 +44,10 @@ try: from graphviz import Digraph except ImportError: - msg = ('Missing "graphviz" dependency. You can install it using pip: \n' - 'pip install graphviz') + msg = ( + 'Missing "graphviz" dependency. You can install it using pip: \n' + "pip install graphviz" + ) raise ImportError(msg) @@ -59,18 +61,20 @@ def do_register_cli_opts(opts, ignore_errors=False): class RuleLink(object): - def __init__(self, source_action_ref, rule_ref, dest_action_ref): self._source_action_ref = source_action_ref self._rule_ref = rule_ref self._dest_action_ref = dest_action_ref def __str__(self): - return '(%s -> %s -> %s)' % (self._source_action_ref, self._rule_ref, self._dest_action_ref) + return "(%s -> %s -> %s)" % ( + self._source_action_ref, + self._rule_ref, + self._dest_action_ref, + ) class LinksAnalyzer(object): - def __init__(self): self._rule_link_by_action_ref = {} self._rules = {} @@ -81,25 +85,30 @@ def analyze(self, root_action_ref, link_tigger_ref): for rule in rules: source_action_ref = self._get_source_action_ref(rule) if not source_action_ref: - print('No source_action_ref for rule %s' % rule.ref) + print("No source_action_ref for rule %s" % rule.ref) continue rule_links = self._rules.get(source_action_ref, None) if rule_links is None: rule_links = [] self._rules[source_action_ref] = rule_links - rule_links.append(RuleLink(source_action_ref=source_action_ref, rule_ref=rule.ref, - dest_action_ref=rule.action.ref)) + rule_links.append( + RuleLink( + source_action_ref=source_action_ref, + rule_ref=rule.ref, + dest_action_ref=rule.action.ref, + ) + ) analyzed = self._do_analyze(action_ref=root_action_ref) for (depth, rule_link) in analyzed: - print('%s%s' % (' ' * depth, rule_link)) + print("%s%s" % (" " * depth, rule_link)) return analyzed def _get_source_action_ref(self, rule): criteria = rule.criteria - source_action_ref = criteria.get('trigger.action_name', None) + source_action_ref = criteria.get("trigger.action_name", None) if not source_action_ref: - source_action_ref = criteria.get('trigger.action_ref', None) - return source_action_ref['pattern'] if source_action_ref else None + source_action_ref = criteria.get("trigger.action_ref", None) + return source_action_ref["pattern"] if source_action_ref else None def _do_analyze(self, action_ref, rule_links=None, processed=None, depth=0): if processed is None: @@ -111,24 +120,32 @@ def _do_analyze(self, action_ref, rule_links=None, processed=None, depth=0): rule_links.append((depth, rule_link)) if rule_link._dest_action_ref in processed: continue - self._do_analyze(rule_link._dest_action_ref, rule_links=rule_links, - processed=processed, depth=depth + 1) + self._do_analyze( + rule_link._dest_action_ref, + rule_links=rule_links, + processed=processed, + depth=depth + 1, + ) return rule_links class Grapher(object): def generate_graph(self, rule_links, out_file): - graph_label = 'Rule based visualizer' + graph_label = "Rule based visualizer" graph_attr = { - 'rankdir': 'TD', - 'labelloc': 't', - 'fontsize': '15', - 'label': graph_label + "rankdir": "TD", + "labelloc": "t", + "fontsize": "15", + "label": graph_label, } node_attr = {} - dot = Digraph(comment='Rule based links visualization', - node_attr=node_attr, graph_attr=graph_attr, format='png') + dot = Digraph( + comment="Rule based links visualization", + node_attr=node_attr, + graph_attr=graph_attr, + format="png", + ) nodes = set() for _, rule_link in rule_links: @@ -139,10 +156,14 @@ def generate_graph(self, rule_links, out_file): if rule_link._dest_action_ref not in nodes: nodes.add(rule_link._dest_action_ref) dot.node(rule_link._dest_action_ref, rule_link._dest_action_ref) - dot.edge(rule_link._source_action_ref, rule_link._dest_action_ref, constraint='true', - label=rule_link._rule_ref) + dot.edge( + rule_link._source_action_ref, + rule_link._dest_action_ref, + constraint="true", + label=rule_link._rule_ref, + ) output_path = os.path.join(os.getcwd(), out_file) - dot.format = 'png' + dot.format = "png" dot.render(output_path) @@ -150,11 +171,13 @@ def main(): monkey_patch() cli_opts = [ - cfg.StrOpt('action_ref', default=None, - help='Root action to begin analysis.'), - cfg.StrOpt('link_trigger_ref', default='core.st2.generic.actiontrigger', - help='Root action to begin analysis.'), - cfg.StrOpt('out_file', default='pipeline') + cfg.StrOpt("action_ref", default=None, help="Root action to begin analysis."), + cfg.StrOpt( + "link_trigger_ref", + default="core.st2.generic.actiontrigger", + help="Root action to begin analysis.", + ), + cfg.StrOpt("out_file", default="pipeline"), ] do_register_cli_opts(cli_opts) config.parse_args() @@ -163,5 +186,5 @@ def main(): Grapher().generate_graph(rule_links, cfg.CONF.out_file) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/st2-inject-trigger-instances.py b/tools/st2-inject-trigger-instances.py index a20ba8bcbb1..79b0f18e254 100755 --- a/tools/st2-inject-trigger-instances.py +++ b/tools/st2-inject-trigger-instances.py @@ -49,7 +49,9 @@ def do_register_cli_opts(opts, ignore_errors=False): raise -def _inject_instances(trigger, rate_per_trigger, duration, payload=None, max_throughput=False): +def _inject_instances( + trigger, rate_per_trigger, duration, payload=None, max_throughput=False +): payload = payload or {} start = date_utils.get_datetime_utc_now() @@ -72,37 +74,54 @@ def _inject_instances(trigger, rate_per_trigger, duration, payload=None, max_thr actual_rate = int(count / elapsed) - print('%s: Emitted %d triggers in %d seconds (actual rate=%s triggers / second)' % - (trigger, count, elapsed, actual_rate)) + print( + "%s: Emitted %d triggers in %d seconds (actual rate=%s triggers / second)" + % (trigger, count, elapsed, actual_rate) + ) # NOTE: Due to the overhead of dispatcher.dispatch call, we allow for 10% of deviation from # requested rate before warning if rate_per_trigger and (actual_rate < (rate_per_trigger * 0.9)): - print('') - print('Warning, requested rate was %s triggers / second, but only achieved %s ' - 'triggers / second' % (rate_per_trigger, actual_rate)) - print('Too increase the throuput you will likely need to run multiple instances of ' - 'this script in parallel.') + print("") + print( + "Warning, requested rate was %s triggers / second, but only achieved %s " + "triggers / second" % (rate_per_trigger, actual_rate) + ) + print( + "Too increase the throuput you will likely need to run multiple instances of " + "this script in parallel." + ) def main(): monkey_patch() cli_opts = [ - cfg.IntOpt('rate', default=100, - help='Rate of trigger injection measured in instances in per sec.' + - ' Assumes a default exponential distribution in time so arrival is poisson.'), - cfg.ListOpt('triggers', required=False, - help='List of triggers for which instances should be fired.' + - ' Uniform distribution will be followed if there is more than one' + - 'trigger.'), - cfg.StrOpt('schema_file', default=None, - help='Path to schema file defining trigger and payload.'), - cfg.IntOpt('duration', default=60, - help='Duration of stress test in seconds.'), - cfg.BoolOpt('max-throughput', default=False, - help='If True, "rate" argument will be ignored and this script will try to ' - 'saturize the CPU and achieve max utilization.') + cfg.IntOpt( + "rate", + default=100, + help="Rate of trigger injection measured in instances in per sec." + + " Assumes a default exponential distribution in time so arrival is poisson.", + ), + cfg.ListOpt( + "triggers", + required=False, + help="List of triggers for which instances should be fired." + + " Uniform distribution will be followed if there is more than one" + + "trigger.", + ), + cfg.StrOpt( + "schema_file", + default=None, + help="Path to schema file defining trigger and payload.", + ), + cfg.IntOpt("duration", default=60, help="Duration of stress test in seconds."), + cfg.BoolOpt( + "max-throughput", + default=False, + help='If True, "rate" argument will be ignored and this script will try to ' + "saturize the CPU and achieve max utilization.", + ), ] do_register_cli_opts(cli_opts) config.parse_args() @@ -112,15 +131,20 @@ def main(): trigger_payload_schema = {} if not triggers: - if (cfg.CONF.schema_file is None or cfg.CONF.schema_file == '' or - not os.path.exists(cfg.CONF.schema_file)): - print('Either "triggers" need to be provided or a schema file containing' + - ' triggers should be provided.') + if ( + cfg.CONF.schema_file is None + or cfg.CONF.schema_file == "" + or not os.path.exists(cfg.CONF.schema_file) + ): + print( + 'Either "triggers" need to be provided or a schema file containing' + + " triggers should be provided." + ) return with open(cfg.CONF.schema_file) as fd: trigger_payload_schema = yaml.safe_load(fd) triggers = list(trigger_payload_schema.keys()) - print('Triggers=%s' % triggers) + print("Triggers=%s" % triggers) rate = cfg.CONF.rate rate_per_trigger = int(rate / len(triggers)) @@ -135,11 +159,17 @@ def main(): for trigger in triggers: payload = trigger_payload_schema.get(trigger, {}) - dispatcher_pool.spawn(_inject_instances, trigger, rate_per_trigger, duration, - payload=payload, max_throughput=max_throughput) + dispatcher_pool.spawn( + _inject_instances, + trigger, + rate_per_trigger, + duration, + payload=payload, + max_throughput=max_throughput, + ) eventlet.sleep(random.uniform(0, 1)) dispatcher_pool.waitall() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/visualize_action_chain.py b/tools/visualize_action_chain.py index 9981bd956ce..c6742c460d6 100755 --- a/tools/visualize_action_chain.py +++ b/tools/visualize_action_chain.py @@ -26,8 +26,10 @@ try: from graphviz import Digraph except ImportError: - msg = ('Missing "graphviz" dependency. You can install it using pip: \n' - 'pip install graphviz') + msg = ( + 'Missing "graphviz" dependency. You can install it using pip: \n' + "pip install graphviz" + ) raise ImportError(msg) from st2common.content.loader import MetaLoader @@ -41,25 +43,29 @@ def main(metadata_path, output_path, print_source=False): meta_loader = MetaLoader() data = meta_loader.load(metadata_path) - action_name = data['name'] - entry_point = data['entry_point'] + action_name = data["name"] + entry_point = data["entry_point"] workflow_metadata_path = os.path.join(metadata_dir, entry_point) chainspec = meta_loader.load(workflow_metadata_path) - chain_holder = ChainHolder(chainspec, 'workflow') + chain_holder = ChainHolder(chainspec, "workflow") - graph_label = '%s action-chain workflow visualization' % (action_name) + graph_label = "%s action-chain workflow visualization" % (action_name) graph_attr = { - 'rankdir': 'TD', - 'labelloc': 't', - 'fontsize': '15', - 'label': graph_label + "rankdir": "TD", + "labelloc": "t", + "fontsize": "15", + "label": graph_label, } node_attr = {} - dot = Digraph(comment='Action chain work-flow visualization', - node_attr=node_attr, graph_attr=graph_attr, format='png') + dot = Digraph( + comment="Action chain work-flow visualization", + node_attr=node_attr, + graph_attr=graph_attr, + format="png", + ) # dot.body.extend(['rankdir=TD', 'size="10,5"']) # Add all nodes @@ -74,23 +80,35 @@ def main(metadata_path, output_path, print_source=False): nodes = [node] while nodes: previous_node = nodes.pop() - success_node = chain_holder.get_next_node(curr_node_name=previous_node.name, - condition='on-success') - failure_node = chain_holder.get_next_node(curr_node_name=previous_node.name, - condition='on-failure') + success_node = chain_holder.get_next_node( + curr_node_name=previous_node.name, condition="on-success" + ) + failure_node = chain_holder.get_next_node( + curr_node_name=previous_node.name, condition="on-failure" + ) # Add success node (if any) if success_node: - dot.edge(previous_node.name, success_node.name, constraint='true', - color='green', label='on success') + dot.edge( + previous_node.name, + success_node.name, + constraint="true", + color="green", + label="on success", + ) if success_node.name not in processed_nodes: nodes.append(success_node) processed_nodes.add(success_node.name) # Add failure node (if any) if failure_node: - dot.edge(previous_node.name, failure_node.name, constraint='true', - color='red', label='on failure') + dot.edge( + previous_node.name, + failure_node.name, + constraint="true", + color="red", + label="on failure", + ) if failure_node.name not in processed_nodes: nodes.append(failure_node) processed_nodes.add(failure_node.name) @@ -103,21 +121,36 @@ def main(metadata_path, output_path, print_source=False): else: output_path = output_path or os.path.join(os.getcwd(), action_name) - dot.format = 'png' + dot.format = "png" dot.render(output_path) - print('Graph saved at %s' % (output_path + '.png')) - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Action chain visualization') - parser.add_argument('--metadata-path', action='store', required=True, - help='Path to the workflow action metadata file') - parser.add_argument('--output-path', action='store', required=False, - help='Output directory for the generated image') - parser.add_argument('--print-source', action='store_true', default=False, - help='Print graphviz source code to the stdout') + print("Graph saved at %s" % (output_path + ".png")) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Action chain visualization") + parser.add_argument( + "--metadata-path", + action="store", + required=True, + help="Path to the workflow action metadata file", + ) + parser.add_argument( + "--output-path", + action="store", + required=False, + help="Output directory for the generated image", + ) + parser.add_argument( + "--print-source", + action="store_true", + default=False, + help="Print graphviz source code to the stdout", + ) args = parser.parse_args() - main(metadata_path=args.metadata_path, output_path=args.output_path, - print_source=args.print_source) + main( + metadata_path=args.metadata_path, + output_path=args.output_path, + print_source=args.print_source, + ) From 3277415eabfc7c481c9ff350a38a9d8be4857b3f Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Wed, 17 Feb 2021 22:44:24 +0100 Subject: [PATCH 03/25] Update black config, update .flake8 config so we ignore rules which conflict with black. --- lint-configs/python/.flake8 | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/lint-configs/python/.flake8 b/lint-configs/python/.flake8 index 4edeebe1621..271a9a21e66 100644 --- a/lint-configs/python/.flake8 +++ b/lint-configs/python/.flake8 @@ -5,7 +5,11 @@ enable-extensions = L101,L102 # We ignore some rules which conflict with black # E203 - whitespace before ':' - in direct conflict with black rule # W503 line break before binary operator - in direct conflict with black rule -ignore = E128,E402,E722,W504,E203,W503 +# E501 is line length limit +# https://black.readthedocs.io/en/stable/the_black_code_style.html#line-length +# We don't really need line length rule since black formatting takes care of +# that. +ignore = E128,E402,E722,W504,E501,E203,W503 exclude=*.egg/*,build,dist # Configuration for flake8-copyright extension From 4c526389b9beea8f34da64e0827dfc475ffc7156 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Wed, 17 Feb 2021 23:42:55 +0100 Subject: [PATCH 04/25] Fix typo. --- Makefile | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index abf89f450b4..67a4cc6b571 100644 --- a/Makefile +++ b/Makefile @@ -331,7 +331,7 @@ schemasgen: requirements .schemasgen black: requirements .black-check .PHONY: .black-check -.black: +.black-check: @echo @echo "================== black-check ====================" @echo @@ -349,8 +349,7 @@ black: requirements .black-check echo "==========================================================="; \ . $(VIRTUALENV_DIR)/bin/activate ; black --check --config pyproject.toml $$component/ || exit 1; \ done - # Python pack management actions - . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml contrib/* || exit 1; + . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml contrib/ || exit 1; . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml scripts/*.py || exit 1; . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml tools/*.py || exit 1; . $(VIRTUALENV_DIR)/bin/activate; black --check --config pyproject.toml pylint_plugins/*.py || exit 1; From 705d585a46c8568933b0b10ccdaf1ea2b409fdac Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 00:03:17 +0100 Subject: [PATCH 05/25] Add pre-commit config which runs various lint tools on the modified / added files. --- .pre-commit-config.yaml | 42 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000000..76a3c8363fa --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,42 @@ +# pre-commit hook which runs all the various lint checks + black auto formatting on the added +# files. +# This hook relies on development virtual environment being present in virtualenv/. +default_language_version: + python: python3.6 + +exclude: '(build|dist)' +repos: + - repo: local + hooks: + - id: black + name: black + entry: ./virtualenv/bin/python -m black --config pyproject.toml + language: script + types: [file, python] + - repo: local + hooks: + - id: flake8 + name: flake8 + entry: ./virtualenv/bin/python -m flake8 --config lint-configs/python/.flake8 + language: script + types: [file, python] + - repo: local + hooks: + - id: pylint + name: pylint + entry: ./virtualenv/bin/python -m pylint -E --rcfile=./lint-configs/python/.pylintrc + language: script + types: [file, python] + - repo: local + hooks: + - id: bandit + name: bandit + entry: ./virtualenv/bin/python -m bandit -lll -x build,dist + language: script + types: [file, python] + # - repo: https://github.com/pre-commit/pre-commit-hooks + # rev: v2.5.0 + # hooks: + # - id: trailing-whitespace + # - id: check-yaml + # exclude: (^openapi|fixtures) From 0720c0a50843d930ef498e3512f625dc2888e758 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 00:19:32 +0100 Subject: [PATCH 06/25] Update pre-commit config to also run trialing whitespace and check yaml syntax check. --- .pre-commit-config.yaml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 76a3c8363fa..f4496a22781 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,9 +34,9 @@ repos: entry: ./virtualenv/bin/python -m bandit -lll -x build,dist language: script types: [file, python] - # - repo: https://github.com/pre-commit/pre-commit-hooks - # rev: v2.5.0 - # hooks: - # - id: trailing-whitespace - # - id: check-yaml - # exclude: (^openapi|fixtures) + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.5.0 + hooks: + - id: trailing-whitespace + - id: check-yaml + exclude: (openapi\.yaml) From e7945275e117898f346ea0cf0f078fe3ee28ca5f Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 00:22:52 +0100 Subject: [PATCH 07/25] Also run trailing whitespace + yaml checks as part of Make targets and CI. --- Makefile | 15 +++++++++++++-- test-requirements.txt | 1 + 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 67a4cc6b571..8923cb7a5c7 100644 --- a/Makefile +++ b/Makefile @@ -382,6 +382,17 @@ black: requirements .black-format . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml tools/*.py || exit 1; . $(VIRTUALENV_DIR)/bin/activate; black --config pyproject.toml pylint_plugins/*.py || exit 1; +.PHONY: pre-commit-checks +black: requirements .pre-commit-checks + +# Ensure all files contain no trailing whitespace + that all YAML files are valid. +.PHONY: .pre-commit-checks +.pre-commit-checks: + @echo + @echo "================== pre-commit-checks ====================" + @echo + pre-commit run trailing-whitespace --all --show-diff-on-failure + pre-commit run check-yaml --all --show-diff-on-failure .PHONY: lint-api-spec lint-api-spec: requirements .lint-api-spec @@ -474,7 +485,7 @@ bandit: requirements .bandit lint: requirements .lint .PHONY: .lint -.lint: .generate-api-spec .flake8 .pylint .st2client-dependencies-check .st2common-circular-dependencies-check .rst-check .st2client-install-check +.lint: .generate-api-spec .black-check .pre-commit-checks .flake8 .pylint .st2client-dependencies-check .st2common-circular-dependencies-check .rst-check .st2client-install-check .PHONY: clean clean: .cleanpycs @@ -1035,7 +1046,7 @@ debs: ci: ci-checks ci-unit ci-integration ci-packs-tests .PHONY: ci-checks -ci-checks: .generated-files-check .black-check .pylint .flake8 check-requirements check-sdist-requirements .st2client-dependencies-check .st2common-circular-dependencies-check circle-lint-api-spec .rst-check .st2client-install-check check-python-packages +ci-checks: .generated-files-check .black-check .pre-commit-checks .pylint .flake8 check-requirements check-sdist-requirements .st2client-dependencies-check .st2common-circular-dependencies-check circle-lint-api-spec .rst-check .st2client-install-check check-python-packages .PHONY: .rst-check .rst-check: diff --git a/test-requirements.txt b/test-requirements.txt index b1909e45351..c004342bc8d 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -6,6 +6,7 @@ astroid==2.4.2 pylint==2.6.0 pylint-plugin-utils>=0.4 black==20.8b1 +pre-commit==2.1.0 bandit==1.5.1 ipython<6.0.0 isort>=4.2.5 From 969793f1fdbdd2c228e59ab112189166530d2680 Mon Sep 17 00:00:00 2001 From: StackStorm CodeFormat Date: Thu, 18 Feb 2021 00:23:20 +0100 Subject: [PATCH 08/25] Remove trailing whitespace from all the files. --- .circleci/config.yml | 2 +- CHANGELOG.rst | 2 +- OWNERS.md | 2 +- README.md | 12 ++++---- conf/st2.conf.sample | 4 +-- conf/st2.dev.conf | 2 +- contrib/core/CHANGES.md | 2 +- contrib/examples/actions/forloop_chain.yaml | 2 +- .../actions/forloop_push_github_repos.yaml | 2 +- .../actions/orquesta-mock-create-vm.yaml | 2 +- .../actions/workflows/orquesta-delay.yaml | 2 +- .../orquesta-error-handling-continue.yaml | 2 +- ...orquesta-error-handling-fail-manually.yaml | 2 +- .../orquesta-error-handling-noop.yaml | 2 +- .../workflows/orquesta-fail-manually.yaml | 2 +- .../actions/workflows/orquesta-join.yaml | 2 +- .../orquesta-remediate-then-fail.yaml | 2 +- .../workflows/orquesta-rollback-retry.yaml | 2 +- .../workflows/orquesta-sequential.yaml | 2 +- .../orquesta-with-items-concurrency.yaml | 2 +- .../workflows/orquesta-with-items.yaml | 2 +- .../tests/orquesta-fail-input-rendering.yaml | 2 +- ...rquesta-fail-inspection-task-contents.yaml | 2 +- .../tests/orquesta-fail-output-rendering.yaml | 2 +- .../tests/orquesta-fail-start-task.yaml | 2 +- .../tests/orquesta-fail-task-publish.yaml | 2 +- .../tests/orquesta-fail-task-transition.yaml | 2 +- .../tests/orquesta-fail-vars-rendering.yaml | 2 +- .../tests/orquesta-test-pause-resume.yaml | 2 +- .../workflows/tests/orquesta-test-rerun.yaml | 2 +- .../tests/orquesta-test-with-items.yaml | 2 +- contrib/linux/README.md | 2 +- contrib/linux/sensors/README.md | 4 +-- contrib/packs/actions/install.meta.yaml | 2 +- contrib/packs/actions/setup_virtualenv.yaml | 2 +- .../inquirer_runner/runner.yaml | 2 +- dev_docs/Troubleshooting_Guide.rst | 28 +++++++++---------- st2client/Makefile | 2 +- .../fixtures/execution_double_backslash.txt | 4 +-- .../tests/fixtures/execution_get_default.txt | 4 +-- .../fixtures/execution_get_has_schema.txt | 4 +-- .../fixtures/execution_unescape_newline.txt | 2 +- .../tests/fixtures/execution_unicode.txt | 4 +-- .../tests/fixtures/execution_unicode_py3.txt | 4 +-- st2common/bin/st2-run-pack-tests | 2 +- st2common/st2common/openapi.yaml | 6 ++-- st2reactor/Makefile | 2 +- .../fixtures/generic/runners/inquirer.yaml | 2 +- .../test_pause_resume_with_init_vars.yaml | 2 +- .../fixtures/packs/dummy_pack_20/pack.yaml | 2 +- .../workflows/jinja-version-functions.yaml | 2 +- ...low-default-value-from-action-context.yaml | 2 +- ...ow-source-channel-from-action-context.yaml | 2 +- .../workflows/yaql-version-functions.yaml | 2 +- .../checks/actions/check_loadavg.yaml | 4 +-- .../testpacks/errorcheck/actions/exit-code.sh | 2 +- tox.ini | 2 +- 57 files changed, 85 insertions(+), 85 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bcbacbe3bd6..bde4a90784d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,5 @@ # Setup in CircleCI account the following ENV variables: -# PACKAGECLOUD_ORGANIZATION (default: stackstorm) +# PACKAGECLOUD_ORGANIZATION (default: stackstorm) # PACKAGECLOUD_TOKEN version: 2 jobs: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d0460f4ea2a..a8b52fb6740 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,7 +27,7 @@ Changed * Improve the st2-self-check script to echo to stderr and exit if it isn't run with a ST2_AUTH_TOKEN or ST2_API_KEY environment variable. (improvement) #5068 -* Added timeout parameter for packs.install action to help with long running installs that exceed the +* Added timeout parameter for packs.install action to help with long running installs that exceed the default timeout of 600 sec which is defined by the python_script action runner (improvement) #5084 Contributed by @hnanchahal diff --git a/OWNERS.md b/OWNERS.md index dfb8fb87bc2..501e7b8f144 100644 --- a/OWNERS.md +++ b/OWNERS.md @@ -74,7 +74,7 @@ Thank you, Friends! * Johan Dahlberg ([@johandahlberg](https://github.com/johandahlberg)) - Using st2 for Bioinformatics/Science project, providing feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Johan Hermansson ([@johanherman](https://github.com/johanherman)) - Using st2 for Bioinformatics/Science project, feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Lakshmi Kannan ([@lakshmi-kannan](https://github.com/lakshmi-kannan)) - early Stormer. Initial Core platform architecture, scalability, reliability, Team Leadership during the project hard times. -* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. +* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. * Manas Kelshikar ([@manasdk](https://github.com/manasdk)) - ex Stormer. Developed (well) early core platform features. * Vineesh Jain ([@VineeshJain](https://github.com/VineeshJain)) - ex Stormer. Community, Tests, Core, QA. * Warren Van Winckel ([@warrenvw](https://github.com/warrenvw)) - ex Stormer. Docker, Kubernetes, Vagrant, Infrastructure. diff --git a/README.md b/README.md index 4d84895bbd6..b22e908d5cd 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ [![Build Status](https://github.com/StackStorm/st2/workflows/ci-checks/badge.svg?branch=master)](https://github.com/StackStorm/st2/actions?query=branch%3Amaster) [![Travis Integration Tests Status](https://travis-ci.org/StackStorm/st2.svg?branch=master)](https://travis-ci.org/StackStorm/st2) -[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) -[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) -![Python 3.6](https://img.shields.io/badge/python-3.6-blue) -[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) -[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) +[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) +[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) +![Python 3.6](https://img.shields.io/badge/python-3.6-blue) +[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) +[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) [![Forum](https://img.shields.io/discourse/https/forum.stackstorm.com/posts.svg)](https://forum.stackstorm.com/) --- diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index 758b743e75f..488939eb55e 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -2,7 +2,7 @@ # Note: This file is automatically generated using tools/config_gen.py - DO NOT UPDATE MANUALLY [action_sensor] -# List of execution statuses for which a trigger will be emitted. +# List of execution statuses for which a trigger will be emitted. emit_when = succeeded,failed,timeout,canceled,abandoned # comma separated list allowed here. # Whether to enable or disable the ability to post a trigger on action. enable = True @@ -170,7 +170,7 @@ trigger_instances_ttl = None # Allow encryption of values in key value stored qualified as "secret". enable_encryption = True # Location of the symmetric encryption key for encrypting values in kvstore. This key should be in JSON and should've been generated using st2-generate-symmetric-crypto-key tool. -encryption_key_path = +encryption_key_path = [log] # Exclusion list of loggers to omit. diff --git a/conf/st2.dev.conf b/conf/st2.dev.conf index 2357b082634..29078016d00 100644 --- a/conf/st2.dev.conf +++ b/conf/st2.dev.conf @@ -83,7 +83,7 @@ protocol = udp # - redis # - etcd3 # - etcd3gw -# Keep in mind that zake driver works in process so it won't work when testing cross process +# Keep in mind that zake driver works in process so it won't work when testing cross process # / cross server functionality #url = redis://localhost #url = kazoo://localhost diff --git a/contrib/core/CHANGES.md b/contrib/core/CHANGES.md index b9c04efa88d..c0b1692b039 100644 --- a/contrib/core/CHANGES.md +++ b/contrib/core/CHANGES.md @@ -1,5 +1,5 @@ # Changelog - + ## 0.3.1 * Add new ``core.uuid`` action for generating type 1 and type 4 UUIDs. diff --git a/contrib/examples/actions/forloop_chain.yaml b/contrib/examples/actions/forloop_chain.yaml index f226eae4202..86ead5303a6 100644 --- a/contrib/examples/actions/forloop_chain.yaml +++ b/contrib/examples/actions/forloop_chain.yaml @@ -6,7 +6,7 @@ entry_point: "chains/forloop_chain.yaml" enabled: true parameters: github_organization_url: - type: "string" + type: "string" description: "Organization url to parse data from" default: "https://github.com/StackStorm-Exchange" required: false diff --git a/contrib/examples/actions/forloop_push_github_repos.yaml b/contrib/examples/actions/forloop_push_github_repos.yaml index 3ff06eabc33..878772636aa 100644 --- a/contrib/examples/actions/forloop_push_github_repos.yaml +++ b/contrib/examples/actions/forloop_push_github_repos.yaml @@ -5,7 +5,7 @@ description: "Action to push the data to an external service" enabled: true entry_point: "pythonactions/forloop_push_github_repos.py" parameters: - data_to_push: + data_to_push: type: "object" description: "Dictonary of the data to be pushed" required: true diff --git a/contrib/examples/actions/orquesta-mock-create-vm.yaml b/contrib/examples/actions/orquesta-mock-create-vm.yaml index 85e774a7024..35c5ab26d81 100644 --- a/contrib/examples/actions/orquesta-mock-create-vm.yaml +++ b/contrib/examples/actions/orquesta-mock-create-vm.yaml @@ -15,7 +15,7 @@ parameters: required: true type: string ip: - default: "10.1.23.99" + default: "10.1.23.99" required: true type: string meta: diff --git a/contrib/examples/actions/workflows/orquesta-delay.yaml b/contrib/examples/actions/workflows/orquesta-delay.yaml index a0793f8bf6d..82a131712c4 100644 --- a/contrib/examples/actions/workflows/orquesta-delay.yaml +++ b/contrib/examples/actions/workflows/orquesta-delay.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml index 5d9c6f22a0a..80047d2e5ed 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with continue. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml index da9179b5edb..4e3dfa38c2c 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with remediation and explicit fail. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml index 61b14a3c11d..e949dc37420 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with noop to ignore error. input: diff --git a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml index 936db68ff3e..b86d8ef25bc 100644 --- a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-join.yaml b/contrib/examples/actions/workflows/orquesta-join.yaml index eaf09fed66a..a247423948d 100644 --- a/contrib/examples/actions/workflows/orquesta-join.yaml +++ b/contrib/examples/actions/workflows/orquesta-join.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrate branching and join. vars: diff --git a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml index 936db68ff3e..b86d8ef25bc 100644 --- a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml +++ b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml index 0d80b0dbcb5..a1f203fb095 100644 --- a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml +++ b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: > A sample workflow that demonstrates how to handle rollback and retry on error. In this example, the workflow will loop until the file /tmp/done exists. A parallel task will wait for some time diff --git a/contrib/examples/actions/workflows/orquesta-sequential.yaml b/contrib/examples/actions/workflows/orquesta-sequential.yaml index 3a03409d36d..404681a3698 100644 --- a/contrib/examples/actions/workflows/orquesta-sequential.yaml +++ b/contrib/examples/actions/workflows/orquesta-sequential.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml index e20b9078988..6bcbb82c583 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items and concurrent processing. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items.yaml b/contrib/examples/actions/workflows/orquesta-with-items.yaml index 6a2cc4af494..5833e270510 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml index ce935f62f7d..907a18e8bfe 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating input. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml index c0322d025e8..a8be5311807 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow with inspection error(s). input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml index dd1e5164411..003ab8b69db 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating output. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml index a0deab1d8f3..0c23ee6a82e 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error in the rendering of the starting task. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml index 0887d4a7beb..149fb93b97b 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that fails on publish during task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml index 8fd2a94d8a6..4d4d9e5f392 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml index 403728100ab..4ddd9867557 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating vars. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml index 7123727cc31..285bf972d7a 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml @@ -19,4 +19,4 @@ tasks: task2: action: core.local input: - cmd: 'echo "<% $.var1 %>"' + cmd: 'echo "<% $.var1 %>"' diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml index 11eb22a721a..3a4b20cee02 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml @@ -4,7 +4,7 @@ description: A sample workflow used to test the rerun feature. input: - tempfile - + tasks: task1: action: core.noop diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml index 8af6899b595..6e24c0ec411 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow for testing with items and concurrency. input: diff --git a/contrib/linux/README.md b/contrib/linux/README.md index 33d872cf868..e2b9f09d44c 100644 --- a/contrib/linux/README.md +++ b/contrib/linux/README.md @@ -55,4 +55,4 @@ Example trigger payload: ## Troubleshooting -* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file +* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file diff --git a/contrib/linux/sensors/README.md b/contrib/linux/sensors/README.md index 7924e91e17a..084fcad6a6e 100644 --- a/contrib/linux/sensors/README.md +++ b/contrib/linux/sensors/README.md @@ -1,6 +1,6 @@ ## NOTICE -File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. +File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. An example rule to supply a file path is as follows: @@ -25,5 +25,5 @@ action: ``` -Trigger ``linux.file_watch.line`` still emits the same payload as it used to. +Trigger ``linux.file_watch.line`` still emits the same payload as it used to. Just the way to provide the file_path to tail has changed. diff --git a/contrib/packs/actions/install.meta.yaml b/contrib/packs/actions/install.meta.yaml index 1b8d0d572a1..191accd1c30 100644 --- a/contrib/packs/actions/install.meta.yaml +++ b/contrib/packs/actions/install.meta.yaml @@ -35,6 +35,6 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/packs/actions/setup_virtualenv.yaml b/contrib/packs/actions/setup_virtualenv.yaml index 18d1b3df157..47091705f3e 100644 --- a/contrib/packs/actions/setup_virtualenv.yaml +++ b/contrib/packs/actions/setup_virtualenv.yaml @@ -27,5 +27,5 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml index 26711df8500..60d79a5b740 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml +++ b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml @@ -23,7 +23,7 @@ roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/dev_docs/Troubleshooting_Guide.rst b/dev_docs/Troubleshooting_Guide.rst index 4e1c1f22d21..f61cedcba4e 100644 --- a/dev_docs/Troubleshooting_Guide.rst +++ b/dev_docs/Troubleshooting_Guide.rst @@ -28,7 +28,7 @@ Troubleshooting Guide $ sudo netstat -tupln | grep 910 tcp 0 0 0.0.0.0:9100 0.0.0.0:* LISTEN 32420/python tcp 0 0 0.0.0.0:9102 0.0.0.0:* LISTEN 32403/python - + As we can see from above output port ``9101`` is not even up. To verify this let us try another command: .. code:: bash @@ -36,10 +36,10 @@ As we can see from above output port ``9101`` is not even up. To verify this let $ ps auxww | grep st2 | grep 910 vagrant 32420 0.2 1.5 79228 31364 pts/10 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2auth.wsgi:application -k eventlet -b 0.0.0.0:9100 --workers 1 - vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 + vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 vagrant 32403 0.2 1.5 79228 31364 pts/3 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2stream.wsgi:application -k eventlet -b 0.0.0.0:9102 --workers 1 - + - This suggests that the API process crashed, we can verify that by running ``screen -ls``.:: .. code:: bash @@ -51,19 +51,19 @@ As we can see from above output port ``9101`` is not even up. To verify this let 15767.st2-sensorcontainer (04/26/2016 06:39:10 PM) (Detached) 15762.st2-stream (04/26/2016 06:39:10 PM) (Detached) 3 Sockets in /var/run/screen/S-vagrant. - -- Now let us check the logs for any errors: + +- Now let us check the logs for any errors: .. code:: bash tail logs/st2api.log - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d - (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', - 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d + (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', + 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. - Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, - 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. + Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, + 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) 2016-04-26 18:27:15,605 140317722756912 DEBUG base [-] Conflict while trying to save in DB. Traceback (most recent call last): @@ -94,7 +94,7 @@ As we can see from above output port ``9101`` is not even up. To verify this let NotUniqueError: Could not save document (E11000 duplicate key error index: st2.role_d_b.$name_1 dup key: { : "system_admin" }) 2016-04-26 18:27:15,676 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/Grammar.txt 2016-04-26 18:27:15,693 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/PatternGrammar.txt - + - To figure out whats wrong let us dig down further. Activate the virtualenv in st2 and run following command : .. code:: bash @@ -108,7 +108,7 @@ The above mentioned command will give out logs, we may find some error in the en File "/home/vagrant/git/st2/st2common/st2common/models/api/keyvalue.py", line 19, in from keyczar.keys import AesKey ImportError: No module named keyczar.keys - + So the problem is : module keyczar is missing. This module can be downloaded using following command: *Solution:* @@ -116,7 +116,7 @@ So the problem is : module keyczar is missing. This module can be downloaded usi .. code:: bash (virtualenv) $ pip install python-keyczar - + This should fix the issue. Now deactivate the virtual env and run ``tools/launchdev.sh restart`` diff --git a/st2client/Makefile b/st2client/Makefile index 9d6cf70a660..e17db7e4f65 100644 --- a/st2client/Makefile +++ b/st2client/Makefile @@ -9,7 +9,7 @@ RELEASE=1 COMPONENTS := st2client .PHONY: rpm -rpm: +rpm: $(PY3) setup.py bdist_rpm --python=$(PY3) mkdir -p $(RPM_ROOT)/RPMS/noarch cp dist/$(COMPONENTS)*noarch.rpm $(RPM_ROOT)/RPMS/noarch/$(COMPONENTS)-$(VER)-$(RELEASE).noarch.rpm diff --git a/st2client/tests/fixtures/execution_double_backslash.txt b/st2client/tests/fixtures/execution_double_backslash.txt index 5437c7add4b..efb21dc7e34 100644 --- a/st2client/tests/fixtures/execution_double_backslash.txt +++ b/st2client/tests/fixtures/execution_double_backslash.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde333 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: echo 'C:\Users\ADMINI~1\AppData\Local\Temp\jking_vmware20_test' status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2client/tests/fixtures/execution_get_default.txt b/st2client/tests/fixtures/execution_get_default.txt index c29c2c221d5..4dea32224a4 100644 --- a/st2client/tests/fixtures/execution_get_default.txt +++ b/st2client/tests/fixtures/execution_get_default.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde398 action.ref: core.ping context.user: stanley -parameters: +parameters: cmd: 127.0.0.1 3 status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_get_has_schema.txt b/st2client/tests/fixtures/execution_get_has_schema.txt index c29c2c221d5..4dea32224a4 100644 --- a/st2client/tests/fixtures/execution_get_has_schema.txt +++ b/st2client/tests/fixtures/execution_get_has_schema.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde398 action.ref: core.ping context.user: stanley -parameters: +parameters: cmd: 127.0.0.1 3 status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_unescape_newline.txt b/st2client/tests/fixtures/execution_unescape_newline.txt index 4abac251a5a..a0b0624e554 100644 --- a/st2client/tests/fixtures/execution_unescape_newline.txt +++ b/st2client/tests/fixtures/execution_unescape_newline.txt @@ -5,7 +5,7 @@ parameters: None status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_unicode.txt b/st2client/tests/fixtures/execution_unicode.txt index 7b7491d3b71..54a9ccc2540 100644 --- a/st2client/tests/fixtures/execution_unicode.txt +++ b/st2client/tests/fixtures/execution_unicode.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde321 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: "echo '‡'" status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2client/tests/fixtures/execution_unicode_py3.txt b/st2client/tests/fixtures/execution_unicode_py3.txt index 0db50aa746e..0e69f4eff41 100644 --- a/st2client/tests/fixtures/execution_unicode_py3.txt +++ b/st2client/tests/fixtures/execution_unicode_py3.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde321 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: "echo '\u2021'" status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2common/bin/st2-run-pack-tests b/st2common/bin/st2-run-pack-tests index bed28267602..9f7c2306ab0 100755 --- a/st2common/bin/st2-run-pack-tests +++ b/st2common/bin/st2-run-pack-tests @@ -322,7 +322,7 @@ if [ "${ENABLE_COVERAGE}" = true ]; then # Base options to enable test coverage reporting # --with-coverage : enables coverage reporting - # --cover-erase : removes old coverage reports before starting + # --cover-erase : removes old coverage reports before starting NOSE_OPTS+=(--with-coverage --cover-erase) # Now, by default nosetests reports test coverage for every module found diff --git a/st2common/st2common/openapi.yaml b/st2common/st2common/openapi.yaml index f1a116c3d6d..ecab78f5ad4 100644 --- a/st2common/st2common/openapi.yaml +++ b/st2common/st2common/openapi.yaml @@ -8,7 +8,7 @@ info: version: "1.0.0" title: StackStorm API description: | - + ## Welcome Welcome to the StackStorm API Reference documentation! You can use the StackStorm API to integrate StackStorm with 3rd-party systems and custom applications. Example integrations include writing your own self-service user portal, or integrating with other orquestation systems. @@ -197,7 +197,7 @@ info: Join our [Slack Community](https://stackstorm.com/community-signup) to get help from the engineering team and fellow users. You can also create issues against the main [StackStorm GitHub repo](https://github.com/StackStorm/st2/issues), or the [st2apidocs repo](https://github.com/StackStorm/st2apidocs) for documentation-specific issues. We also recommend reviewing the main [StackStorm documentation](https://docs.stackstorm.com/). - + paths: /api/v1/: @@ -1477,7 +1477,7 @@ paths: /api/v1/keys: get: operationId: st2api.controllers.v1.keyvalue:key_value_pair_controller.get_all - x-permissions: + x-permissions: description: Returns a list of all key value pairs. parameters: - name: prefix diff --git a/st2reactor/Makefile b/st2reactor/Makefile index cd3eb75a3ee..232abed4dd5 100644 --- a/st2reactor/Makefile +++ b/st2reactor/Makefile @@ -7,7 +7,7 @@ VER=0.4.0 COMPONENTS := st2reactor .PHONY: rpm -rpm: +rpm: pushd ~ && rpmdev-setuptree && popd tar --transform=s~^~$(COMPONENTS)-$(VER)/~ -czf $(RPM_SOURCES_DIR)/$(COMPONENTS).tar.gz bin conf $(COMPONENTS) cp packaging/rpm/$(COMPONENTS).spec $(RPM_SPECS_DIR)/ diff --git a/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml b/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml index 421262c52bf..f49903a7e92 100644 --- a/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml +++ b/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml @@ -24,7 +24,7 @@ runner_parameters: roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml b/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml index dda36fb3025..72114dd49e1 100644 --- a/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml +++ b/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml @@ -8,7 +8,7 @@ chain: cmd: "while [ -e '{{tempfile}}' ]; do sleep 0.1; done" timeout: 180 publish: - var1: "{{var1|upper}}" + var1: "{{var1|upper}}" on-success: task2 - name: task2 diff --git a/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml b/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml index c99661eb7c7..a843fb64d4c 100644 --- a/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml +++ b/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml @@ -9,6 +9,6 @@ attribute1: value1 attribute2: value2 attribute3: value3 attribute: 4 -some: +some: - "feature" - value diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml index 0d77ef8aeff..579bc515bb3 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml @@ -14,7 +14,7 @@ output: - more_than: '{{ version_more_than("0.9.0", "0.10.0") }}' - less_than: '{{ version_less_than("0.10.0", "0.9.0") }}' - match: '{{ version_match("0.10.1", ">0.10.0") }}' - - bump_major: '{{ version_bump_major("0.10.0") }}' + - bump_major: '{{ version_bump_major("0.10.0") }}' - bump_minor: '{{ version_bump_minor("0.10.0") }}' - bump_patch: '{{ version_bump_patch("0.10.0") }}' - strip_patch: '{{ version_strip_patch("0.10.0") }}' diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml index 90e3ac78ade..3301ab423a4 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml @@ -3,7 +3,7 @@ version: 1.0 description: A sample workflow that calls another subworkflow. output: - - msg: <% task(task1).result.output.msg %> + - msg: <% task(task1).result.output.msg %> tasks: task1: diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml index eedc5b8c3e3..7a6bd62fa54 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml @@ -3,7 +3,7 @@ version: 1.0 description: A sample workflow that calls another subworkflow. output: - - msg: <% task(task1).result.output.msg %> + - msg: <% task(task1).result.output.msg %> tasks: task1: diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml index cce350c46c6..7bda9cc83b6 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml @@ -14,7 +14,7 @@ output: - more_than: '<% version_more_than("0.9.0", "0.10.0") %>' - less_than: '<% version_less_than("0.10.0", "0.9.0") %>' - match: '<% version_match("0.10.1", ">0.10.0") %>' - - bump_major: '<% version_bump_major("0.10.0") %>' + - bump_major: '<% version_bump_major("0.10.0") %>' - bump_minor: '<% version_bump_minor("0.10.0") %>' - bump_patch: '<% version_bump_patch("0.10.0") %>' - strip_patch: '<% version_strip_patch("0.10.0") %>' diff --git a/st2tests/testpacks/checks/actions/check_loadavg.yaml b/st2tests/testpacks/checks/actions/check_loadavg.yaml index ac38037d6c8..06abc652278 100644 --- a/st2tests/testpacks/checks/actions/check_loadavg.yaml +++ b/st2tests/testpacks/checks/actions/check_loadavg.yaml @@ -4,8 +4,8 @@ description: "Check CPU Load Average on a Host" enabled: true entry_point: "checks/check_loadavg.py" - parameters: - period: + parameters: + period: type: "string" description: "Time period for load avg: 5,10,15 minutes, or 'all'" default: "all" diff --git a/st2tests/testpacks/errorcheck/actions/exit-code.sh b/st2tests/testpacks/errorcheck/actions/exit-code.sh index 5320dc2f363..2e6eadf6a2c 100755 --- a/st2tests/testpacks/errorcheck/actions/exit-code.sh +++ b/st2tests/testpacks/errorcheck/actions/exit-code.sh @@ -6,4 +6,4 @@ if [ -n "$1" ] exit_code=$1 fi -exit $exit_code +exit $exit_code diff --git a/tox.ini b/tox.ini index 451ceee8e1e..de40b858789 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,7 @@ commands = [testenv:py36-integration] basepython = python3.6 -setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner +setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner VIRTUALENV_DIR = {envdir} passenv = NOSE_WITH_TIMER TRAVIS ST2_CI install_command = pip install -U --force-reinstall {opts} {packages} From e4cdc0584deb6034d65163b586cfd59ecd4c1f47 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 00:25:58 +0100 Subject: [PATCH 09/25] Add .git-blame-ignore-rev file. --- .git-blame-ignore-rev | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .git-blame-ignore-rev diff --git a/.git-blame-ignore-rev b/.git-blame-ignore-rev new file mode 100644 index 00000000000..2e9f4011b20 --- /dev/null +++ b/.git-blame-ignore-rev @@ -0,0 +1,5 @@ +# Code was auto formatted using black +8496bb2407b969f0937431992172b98b545f6756 + +# Files were auto formatted to remove trailing whitespace +969793f1fdbdd2c228e59ab112189166530d2680 From 6051539231e4a8d1c547695f072882ed8cf384d3 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 10:52:46 +0100 Subject: [PATCH 10/25] Revert "Remove trailing whitespace from all the files." This reverts commit 969793f1fdbdd2c228e59ab112189166530d2680. --- .circleci/config.yml | 2 +- CHANGELOG.rst | 2 +- OWNERS.md | 2 +- README.md | 12 ++++---- conf/st2.conf.sample | 4 +-- conf/st2.dev.conf | 2 +- contrib/core/CHANGES.md | 2 +- contrib/examples/actions/forloop_chain.yaml | 2 +- .../actions/forloop_push_github_repos.yaml | 2 +- .../actions/orquesta-mock-create-vm.yaml | 2 +- .../actions/workflows/orquesta-delay.yaml | 2 +- .../orquesta-error-handling-continue.yaml | 2 +- ...orquesta-error-handling-fail-manually.yaml | 2 +- .../orquesta-error-handling-noop.yaml | 2 +- .../workflows/orquesta-fail-manually.yaml | 2 +- .../actions/workflows/orquesta-join.yaml | 2 +- .../orquesta-remediate-then-fail.yaml | 2 +- .../workflows/orquesta-rollback-retry.yaml | 2 +- .../workflows/orquesta-sequential.yaml | 2 +- .../orquesta-with-items-concurrency.yaml | 2 +- .../workflows/orquesta-with-items.yaml | 2 +- .../tests/orquesta-fail-input-rendering.yaml | 2 +- ...rquesta-fail-inspection-task-contents.yaml | 2 +- .../tests/orquesta-fail-output-rendering.yaml | 2 +- .../tests/orquesta-fail-start-task.yaml | 2 +- .../tests/orquesta-fail-task-publish.yaml | 2 +- .../tests/orquesta-fail-task-transition.yaml | 2 +- .../tests/orquesta-fail-vars-rendering.yaml | 2 +- .../tests/orquesta-test-pause-resume.yaml | 2 +- .../workflows/tests/orquesta-test-rerun.yaml | 2 +- .../tests/orquesta-test-with-items.yaml | 2 +- contrib/linux/README.md | 2 +- contrib/linux/sensors/README.md | 4 +-- contrib/packs/actions/install.meta.yaml | 2 +- contrib/packs/actions/setup_virtualenv.yaml | 2 +- .../inquirer_runner/runner.yaml | 2 +- dev_docs/Troubleshooting_Guide.rst | 28 +++++++++---------- st2client/Makefile | 2 +- .../fixtures/execution_double_backslash.txt | 4 +-- .../tests/fixtures/execution_get_default.txt | 4 +-- .../fixtures/execution_get_has_schema.txt | 4 +-- .../fixtures/execution_unescape_newline.txt | 2 +- .../tests/fixtures/execution_unicode.txt | 4 +-- .../tests/fixtures/execution_unicode_py3.txt | 4 +-- st2common/bin/st2-run-pack-tests | 2 +- st2common/st2common/openapi.yaml | 6 ++-- st2reactor/Makefile | 2 +- .../fixtures/generic/runners/inquirer.yaml | 2 +- .../test_pause_resume_with_init_vars.yaml | 2 +- .../fixtures/packs/dummy_pack_20/pack.yaml | 2 +- .../workflows/jinja-version-functions.yaml | 2 +- ...low-default-value-from-action-context.yaml | 2 +- ...ow-source-channel-from-action-context.yaml | 2 +- .../workflows/yaql-version-functions.yaml | 2 +- .../checks/actions/check_loadavg.yaml | 4 +-- .../testpacks/errorcheck/actions/exit-code.sh | 2 +- tox.ini | 2 +- 57 files changed, 85 insertions(+), 85 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bde4a90784d..bcbacbe3bd6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,5 @@ # Setup in CircleCI account the following ENV variables: -# PACKAGECLOUD_ORGANIZATION (default: stackstorm) +# PACKAGECLOUD_ORGANIZATION (default: stackstorm) # PACKAGECLOUD_TOKEN version: 2 jobs: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a8b52fb6740..d0460f4ea2a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,7 +27,7 @@ Changed * Improve the st2-self-check script to echo to stderr and exit if it isn't run with a ST2_AUTH_TOKEN or ST2_API_KEY environment variable. (improvement) #5068 -* Added timeout parameter for packs.install action to help with long running installs that exceed the +* Added timeout parameter for packs.install action to help with long running installs that exceed the default timeout of 600 sec which is defined by the python_script action runner (improvement) #5084 Contributed by @hnanchahal diff --git a/OWNERS.md b/OWNERS.md index 501e7b8f144..dfb8fb87bc2 100644 --- a/OWNERS.md +++ b/OWNERS.md @@ -74,7 +74,7 @@ Thank you, Friends! * Johan Dahlberg ([@johandahlberg](https://github.com/johandahlberg)) - Using st2 for Bioinformatics/Science project, providing feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Johan Hermansson ([@johanherman](https://github.com/johanherman)) - Using st2 for Bioinformatics/Science project, feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Lakshmi Kannan ([@lakshmi-kannan](https://github.com/lakshmi-kannan)) - early Stormer. Initial Core platform architecture, scalability, reliability, Team Leadership during the project hard times. -* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. +* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. * Manas Kelshikar ([@manasdk](https://github.com/manasdk)) - ex Stormer. Developed (well) early core platform features. * Vineesh Jain ([@VineeshJain](https://github.com/VineeshJain)) - ex Stormer. Community, Tests, Core, QA. * Warren Van Winckel ([@warrenvw](https://github.com/warrenvw)) - ex Stormer. Docker, Kubernetes, Vagrant, Infrastructure. diff --git a/README.md b/README.md index b22e908d5cd..4d84895bbd6 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ [![Build Status](https://github.com/StackStorm/st2/workflows/ci-checks/badge.svg?branch=master)](https://github.com/StackStorm/st2/actions?query=branch%3Amaster) [![Travis Integration Tests Status](https://travis-ci.org/StackStorm/st2.svg?branch=master)](https://travis-ci.org/StackStorm/st2) -[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) -[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) -![Python 3.6](https://img.shields.io/badge/python-3.6-blue) -[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) -[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) +[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) +[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) +![Python 3.6](https://img.shields.io/badge/python-3.6-blue) +[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) +[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) [![Forum](https://img.shields.io/discourse/https/forum.stackstorm.com/posts.svg)](https://forum.stackstorm.com/) --- diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index 488939eb55e..758b743e75f 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -2,7 +2,7 @@ # Note: This file is automatically generated using tools/config_gen.py - DO NOT UPDATE MANUALLY [action_sensor] -# List of execution statuses for which a trigger will be emitted. +# List of execution statuses for which a trigger will be emitted. emit_when = succeeded,failed,timeout,canceled,abandoned # comma separated list allowed here. # Whether to enable or disable the ability to post a trigger on action. enable = True @@ -170,7 +170,7 @@ trigger_instances_ttl = None # Allow encryption of values in key value stored qualified as "secret". enable_encryption = True # Location of the symmetric encryption key for encrypting values in kvstore. This key should be in JSON and should've been generated using st2-generate-symmetric-crypto-key tool. -encryption_key_path = +encryption_key_path = [log] # Exclusion list of loggers to omit. diff --git a/conf/st2.dev.conf b/conf/st2.dev.conf index 29078016d00..2357b082634 100644 --- a/conf/st2.dev.conf +++ b/conf/st2.dev.conf @@ -83,7 +83,7 @@ protocol = udp # - redis # - etcd3 # - etcd3gw -# Keep in mind that zake driver works in process so it won't work when testing cross process +# Keep in mind that zake driver works in process so it won't work when testing cross process # / cross server functionality #url = redis://localhost #url = kazoo://localhost diff --git a/contrib/core/CHANGES.md b/contrib/core/CHANGES.md index c0b1692b039..b9c04efa88d 100644 --- a/contrib/core/CHANGES.md +++ b/contrib/core/CHANGES.md @@ -1,5 +1,5 @@ # Changelog - + ## 0.3.1 * Add new ``core.uuid`` action for generating type 1 and type 4 UUIDs. diff --git a/contrib/examples/actions/forloop_chain.yaml b/contrib/examples/actions/forloop_chain.yaml index 86ead5303a6..f226eae4202 100644 --- a/contrib/examples/actions/forloop_chain.yaml +++ b/contrib/examples/actions/forloop_chain.yaml @@ -6,7 +6,7 @@ entry_point: "chains/forloop_chain.yaml" enabled: true parameters: github_organization_url: - type: "string" + type: "string" description: "Organization url to parse data from" default: "https://github.com/StackStorm-Exchange" required: false diff --git a/contrib/examples/actions/forloop_push_github_repos.yaml b/contrib/examples/actions/forloop_push_github_repos.yaml index 878772636aa..3ff06eabc33 100644 --- a/contrib/examples/actions/forloop_push_github_repos.yaml +++ b/contrib/examples/actions/forloop_push_github_repos.yaml @@ -5,7 +5,7 @@ description: "Action to push the data to an external service" enabled: true entry_point: "pythonactions/forloop_push_github_repos.py" parameters: - data_to_push: + data_to_push: type: "object" description: "Dictonary of the data to be pushed" required: true diff --git a/contrib/examples/actions/orquesta-mock-create-vm.yaml b/contrib/examples/actions/orquesta-mock-create-vm.yaml index 35c5ab26d81..85e774a7024 100644 --- a/contrib/examples/actions/orquesta-mock-create-vm.yaml +++ b/contrib/examples/actions/orquesta-mock-create-vm.yaml @@ -15,7 +15,7 @@ parameters: required: true type: string ip: - default: "10.1.23.99" + default: "10.1.23.99" required: true type: string meta: diff --git a/contrib/examples/actions/workflows/orquesta-delay.yaml b/contrib/examples/actions/workflows/orquesta-delay.yaml index 82a131712c4..a0793f8bf6d 100644 --- a/contrib/examples/actions/workflows/orquesta-delay.yaml +++ b/contrib/examples/actions/workflows/orquesta-delay.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml index 80047d2e5ed..5d9c6f22a0a 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with continue. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml index 4e3dfa38c2c..da9179b5edb 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with remediation and explicit fail. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml index e949dc37420..61b14a3c11d 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with noop to ignore error. input: diff --git a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml index b86d8ef25bc..936db68ff3e 100644 --- a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-join.yaml b/contrib/examples/actions/workflows/orquesta-join.yaml index a247423948d..eaf09fed66a 100644 --- a/contrib/examples/actions/workflows/orquesta-join.yaml +++ b/contrib/examples/actions/workflows/orquesta-join.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrate branching and join. vars: diff --git a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml index b86d8ef25bc..936db68ff3e 100644 --- a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml +++ b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml index a1f203fb095..0d80b0dbcb5 100644 --- a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml +++ b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: > A sample workflow that demonstrates how to handle rollback and retry on error. In this example, the workflow will loop until the file /tmp/done exists. A parallel task will wait for some time diff --git a/contrib/examples/actions/workflows/orquesta-sequential.yaml b/contrib/examples/actions/workflows/orquesta-sequential.yaml index 404681a3698..3a03409d36d 100644 --- a/contrib/examples/actions/workflows/orquesta-sequential.yaml +++ b/contrib/examples/actions/workflows/orquesta-sequential.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml index 6bcbb82c583..e20b9078988 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items and concurrent processing. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items.yaml b/contrib/examples/actions/workflows/orquesta-with-items.yaml index 5833e270510..6a2cc4af494 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml index 907a18e8bfe..ce935f62f7d 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating input. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml index a8be5311807..c0322d025e8 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow with inspection error(s). input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml index 003ab8b69db..dd1e5164411 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating output. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml index 0c23ee6a82e..a0deab1d8f3 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error in the rendering of the starting task. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml index 149fb93b97b..0887d4a7beb 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that fails on publish during task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml index 4d4d9e5f392..8fd2a94d8a6 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml index 4ddd9867557..403728100ab 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating vars. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml index 285bf972d7a..7123727cc31 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml @@ -19,4 +19,4 @@ tasks: task2: action: core.local input: - cmd: 'echo "<% $.var1 %>"' + cmd: 'echo "<% $.var1 %>"' diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml index 3a4b20cee02..11eb22a721a 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml @@ -4,7 +4,7 @@ description: A sample workflow used to test the rerun feature. input: - tempfile - + tasks: task1: action: core.noop diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml index 6e24c0ec411..8af6899b595 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow for testing with items and concurrency. input: diff --git a/contrib/linux/README.md b/contrib/linux/README.md index e2b9f09d44c..33d872cf868 100644 --- a/contrib/linux/README.md +++ b/contrib/linux/README.md @@ -55,4 +55,4 @@ Example trigger payload: ## Troubleshooting -* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file +* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file diff --git a/contrib/linux/sensors/README.md b/contrib/linux/sensors/README.md index 084fcad6a6e..7924e91e17a 100644 --- a/contrib/linux/sensors/README.md +++ b/contrib/linux/sensors/README.md @@ -1,6 +1,6 @@ ## NOTICE -File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. +File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. An example rule to supply a file path is as follows: @@ -25,5 +25,5 @@ action: ``` -Trigger ``linux.file_watch.line`` still emits the same payload as it used to. +Trigger ``linux.file_watch.line`` still emits the same payload as it used to. Just the way to provide the file_path to tail has changed. diff --git a/contrib/packs/actions/install.meta.yaml b/contrib/packs/actions/install.meta.yaml index 191accd1c30..1b8d0d572a1 100644 --- a/contrib/packs/actions/install.meta.yaml +++ b/contrib/packs/actions/install.meta.yaml @@ -35,6 +35,6 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/packs/actions/setup_virtualenv.yaml b/contrib/packs/actions/setup_virtualenv.yaml index 47091705f3e..18d1b3df157 100644 --- a/contrib/packs/actions/setup_virtualenv.yaml +++ b/contrib/packs/actions/setup_virtualenv.yaml @@ -27,5 +27,5 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml index 60d79a5b740..26711df8500 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml +++ b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml @@ -23,7 +23,7 @@ roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/dev_docs/Troubleshooting_Guide.rst b/dev_docs/Troubleshooting_Guide.rst index f61cedcba4e..4e1c1f22d21 100644 --- a/dev_docs/Troubleshooting_Guide.rst +++ b/dev_docs/Troubleshooting_Guide.rst @@ -28,7 +28,7 @@ Troubleshooting Guide $ sudo netstat -tupln | grep 910 tcp 0 0 0.0.0.0:9100 0.0.0.0:* LISTEN 32420/python tcp 0 0 0.0.0.0:9102 0.0.0.0:* LISTEN 32403/python - + As we can see from above output port ``9101`` is not even up. To verify this let us try another command: .. code:: bash @@ -36,10 +36,10 @@ As we can see from above output port ``9101`` is not even up. To verify this let $ ps auxww | grep st2 | grep 910 vagrant 32420 0.2 1.5 79228 31364 pts/10 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2auth.wsgi:application -k eventlet -b 0.0.0.0:9100 --workers 1 - vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 + vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 vagrant 32403 0.2 1.5 79228 31364 pts/3 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2stream.wsgi:application -k eventlet -b 0.0.0.0:9102 --workers 1 - + - This suggests that the API process crashed, we can verify that by running ``screen -ls``.:: .. code:: bash @@ -51,19 +51,19 @@ As we can see from above output port ``9101`` is not even up. To verify this let 15767.st2-sensorcontainer (04/26/2016 06:39:10 PM) (Detached) 15762.st2-stream (04/26/2016 06:39:10 PM) (Detached) 3 Sockets in /var/run/screen/S-vagrant. - -- Now let us check the logs for any errors: + +- Now let us check the logs for any errors: .. code:: bash tail logs/st2api.log - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d - (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', - 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d + (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', + 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. - Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, - 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. + Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, + 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) 2016-04-26 18:27:15,605 140317722756912 DEBUG base [-] Conflict while trying to save in DB. Traceback (most recent call last): @@ -94,7 +94,7 @@ As we can see from above output port ``9101`` is not even up. To verify this let NotUniqueError: Could not save document (E11000 duplicate key error index: st2.role_d_b.$name_1 dup key: { : "system_admin" }) 2016-04-26 18:27:15,676 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/Grammar.txt 2016-04-26 18:27:15,693 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/PatternGrammar.txt - + - To figure out whats wrong let us dig down further. Activate the virtualenv in st2 and run following command : .. code:: bash @@ -108,7 +108,7 @@ The above mentioned command will give out logs, we may find some error in the en File "/home/vagrant/git/st2/st2common/st2common/models/api/keyvalue.py", line 19, in from keyczar.keys import AesKey ImportError: No module named keyczar.keys - + So the problem is : module keyczar is missing. This module can be downloaded using following command: *Solution:* @@ -116,7 +116,7 @@ So the problem is : module keyczar is missing. This module can be downloaded usi .. code:: bash (virtualenv) $ pip install python-keyczar - + This should fix the issue. Now deactivate the virtual env and run ``tools/launchdev.sh restart`` diff --git a/st2client/Makefile b/st2client/Makefile index e17db7e4f65..9d6cf70a660 100644 --- a/st2client/Makefile +++ b/st2client/Makefile @@ -9,7 +9,7 @@ RELEASE=1 COMPONENTS := st2client .PHONY: rpm -rpm: +rpm: $(PY3) setup.py bdist_rpm --python=$(PY3) mkdir -p $(RPM_ROOT)/RPMS/noarch cp dist/$(COMPONENTS)*noarch.rpm $(RPM_ROOT)/RPMS/noarch/$(COMPONENTS)-$(VER)-$(RELEASE).noarch.rpm diff --git a/st2client/tests/fixtures/execution_double_backslash.txt b/st2client/tests/fixtures/execution_double_backslash.txt index efb21dc7e34..5437c7add4b 100644 --- a/st2client/tests/fixtures/execution_double_backslash.txt +++ b/st2client/tests/fixtures/execution_double_backslash.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde333 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: echo 'C:\Users\ADMINI~1\AppData\Local\Temp\jking_vmware20_test' status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2client/tests/fixtures/execution_get_default.txt b/st2client/tests/fixtures/execution_get_default.txt index 4dea32224a4..c29c2c221d5 100644 --- a/st2client/tests/fixtures/execution_get_default.txt +++ b/st2client/tests/fixtures/execution_get_default.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde398 action.ref: core.ping context.user: stanley -parameters: +parameters: cmd: 127.0.0.1 3 status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_get_has_schema.txt b/st2client/tests/fixtures/execution_get_has_schema.txt index 4dea32224a4..c29c2c221d5 100644 --- a/st2client/tests/fixtures/execution_get_has_schema.txt +++ b/st2client/tests/fixtures/execution_get_has_schema.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde398 action.ref: core.ping context.user: stanley -parameters: +parameters: cmd: 127.0.0.1 3 status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_unescape_newline.txt b/st2client/tests/fixtures/execution_unescape_newline.txt index a0b0624e554..4abac251a5a 100644 --- a/st2client/tests/fixtures/execution_unescape_newline.txt +++ b/st2client/tests/fixtures/execution_unescape_newline.txt @@ -5,7 +5,7 @@ parameters: None status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: localhost: failed: false return_code: 0 diff --git a/st2client/tests/fixtures/execution_unicode.txt b/st2client/tests/fixtures/execution_unicode.txt index 54a9ccc2540..7b7491d3b71 100644 --- a/st2client/tests/fixtures/execution_unicode.txt +++ b/st2client/tests/fixtures/execution_unicode.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde321 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: "echo '‡'" status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2client/tests/fixtures/execution_unicode_py3.txt b/st2client/tests/fixtures/execution_unicode_py3.txt index 0e69f4eff41..0db50aa746e 100644 --- a/st2client/tests/fixtures/execution_unicode_py3.txt +++ b/st2client/tests/fixtures/execution_unicode_py3.txt @@ -1,12 +1,12 @@ id: 547e19561e2e2417d3dde321 action.ref: core.local context.user: stanley -parameters: +parameters: cmd: "echo '\u2021'" status: succeeded (1s elapsed) start_timestamp: Tue, 02 Dec 2014 19:56:06 UTC end_timestamp: Tue, 02 Dec 2014 19:56:07 UTC -result: +result: failed: false return_code: 0 stderr: '' diff --git a/st2common/bin/st2-run-pack-tests b/st2common/bin/st2-run-pack-tests index 9f7c2306ab0..bed28267602 100755 --- a/st2common/bin/st2-run-pack-tests +++ b/st2common/bin/st2-run-pack-tests @@ -322,7 +322,7 @@ if [ "${ENABLE_COVERAGE}" = true ]; then # Base options to enable test coverage reporting # --with-coverage : enables coverage reporting - # --cover-erase : removes old coverage reports before starting + # --cover-erase : removes old coverage reports before starting NOSE_OPTS+=(--with-coverage --cover-erase) # Now, by default nosetests reports test coverage for every module found diff --git a/st2common/st2common/openapi.yaml b/st2common/st2common/openapi.yaml index ecab78f5ad4..f1a116c3d6d 100644 --- a/st2common/st2common/openapi.yaml +++ b/st2common/st2common/openapi.yaml @@ -8,7 +8,7 @@ info: version: "1.0.0" title: StackStorm API description: | - + ## Welcome Welcome to the StackStorm API Reference documentation! You can use the StackStorm API to integrate StackStorm with 3rd-party systems and custom applications. Example integrations include writing your own self-service user portal, or integrating with other orquestation systems. @@ -197,7 +197,7 @@ info: Join our [Slack Community](https://stackstorm.com/community-signup) to get help from the engineering team and fellow users. You can also create issues against the main [StackStorm GitHub repo](https://github.com/StackStorm/st2/issues), or the [st2apidocs repo](https://github.com/StackStorm/st2apidocs) for documentation-specific issues. We also recommend reviewing the main [StackStorm documentation](https://docs.stackstorm.com/). - + paths: /api/v1/: @@ -1477,7 +1477,7 @@ paths: /api/v1/keys: get: operationId: st2api.controllers.v1.keyvalue:key_value_pair_controller.get_all - x-permissions: + x-permissions: description: Returns a list of all key value pairs. parameters: - name: prefix diff --git a/st2reactor/Makefile b/st2reactor/Makefile index 232abed4dd5..cd3eb75a3ee 100644 --- a/st2reactor/Makefile +++ b/st2reactor/Makefile @@ -7,7 +7,7 @@ VER=0.4.0 COMPONENTS := st2reactor .PHONY: rpm -rpm: +rpm: pushd ~ && rpmdev-setuptree && popd tar --transform=s~^~$(COMPONENTS)-$(VER)/~ -czf $(RPM_SOURCES_DIR)/$(COMPONENTS).tar.gz bin conf $(COMPONENTS) cp packaging/rpm/$(COMPONENTS).spec $(RPM_SPECS_DIR)/ diff --git a/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml b/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml index f49903a7e92..421262c52bf 100644 --- a/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml +++ b/st2tests/st2tests/fixtures/generic/runners/inquirer.yaml @@ -24,7 +24,7 @@ runner_parameters: roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml b/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml index 72114dd49e1..dda36fb3025 100644 --- a/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml +++ b/st2tests/st2tests/fixtures/packs/action_chain_tests/actions/chains/test_pause_resume_with_init_vars.yaml @@ -8,7 +8,7 @@ chain: cmd: "while [ -e '{{tempfile}}' ]; do sleep 0.1; done" timeout: 180 publish: - var1: "{{var1|upper}}" + var1: "{{var1|upper}}" on-success: task2 - name: task2 diff --git a/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml b/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml index a843fb64d4c..c99661eb7c7 100644 --- a/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml +++ b/st2tests/st2tests/fixtures/packs/dummy_pack_20/pack.yaml @@ -9,6 +9,6 @@ attribute1: value1 attribute2: value2 attribute3: value3 attribute: 4 -some: +some: - "feature" - value diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml index 579bc515bb3..0d77ef8aeff 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/jinja-version-functions.yaml @@ -14,7 +14,7 @@ output: - more_than: '{{ version_more_than("0.9.0", "0.10.0") }}' - less_than: '{{ version_less_than("0.10.0", "0.9.0") }}' - match: '{{ version_match("0.10.1", ">0.10.0") }}' - - bump_major: '{{ version_bump_major("0.10.0") }}' + - bump_major: '{{ version_bump_major("0.10.0") }}' - bump_minor: '{{ version_bump_minor("0.10.0") }}' - bump_patch: '{{ version_bump_patch("0.10.0") }}' - strip_patch: '{{ version_strip_patch("0.10.0") }}' diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml index 3301ab423a4..90e3ac78ade 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-default-value-from-action-context.yaml @@ -3,7 +3,7 @@ version: 1.0 description: A sample workflow that calls another subworkflow. output: - - msg: <% task(task1).result.output.msg %> + - msg: <% task(task1).result.output.msg %> tasks: task1: diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml index 7a6bd62fa54..eedc5b8c3e3 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/subworkflow-source-channel-from-action-context.yaml @@ -3,7 +3,7 @@ version: 1.0 description: A sample workflow that calls another subworkflow. output: - - msg: <% task(task1).result.output.msg %> + - msg: <% task(task1).result.output.msg %> tasks: task1: diff --git a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml index 7bda9cc83b6..cce350c46c6 100644 --- a/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml +++ b/st2tests/st2tests/fixtures/packs/orquesta_tests/actions/workflows/yaql-version-functions.yaml @@ -14,7 +14,7 @@ output: - more_than: '<% version_more_than("0.9.0", "0.10.0") %>' - less_than: '<% version_less_than("0.10.0", "0.9.0") %>' - match: '<% version_match("0.10.1", ">0.10.0") %>' - - bump_major: '<% version_bump_major("0.10.0") %>' + - bump_major: '<% version_bump_major("0.10.0") %>' - bump_minor: '<% version_bump_minor("0.10.0") %>' - bump_patch: '<% version_bump_patch("0.10.0") %>' - strip_patch: '<% version_strip_patch("0.10.0") %>' diff --git a/st2tests/testpacks/checks/actions/check_loadavg.yaml b/st2tests/testpacks/checks/actions/check_loadavg.yaml index 06abc652278..ac38037d6c8 100644 --- a/st2tests/testpacks/checks/actions/check_loadavg.yaml +++ b/st2tests/testpacks/checks/actions/check_loadavg.yaml @@ -4,8 +4,8 @@ description: "Check CPU Load Average on a Host" enabled: true entry_point: "checks/check_loadavg.py" - parameters: - period: + parameters: + period: type: "string" description: "Time period for load avg: 5,10,15 minutes, or 'all'" default: "all" diff --git a/st2tests/testpacks/errorcheck/actions/exit-code.sh b/st2tests/testpacks/errorcheck/actions/exit-code.sh index 2e6eadf6a2c..5320dc2f363 100755 --- a/st2tests/testpacks/errorcheck/actions/exit-code.sh +++ b/st2tests/testpacks/errorcheck/actions/exit-code.sh @@ -6,4 +6,4 @@ if [ -n "$1" ] exit_code=$1 fi -exit $exit_code +exit $exit_code diff --git a/tox.ini b/tox.ini index de40b858789..451ceee8e1e 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,7 @@ commands = [testenv:py36-integration] basepython = python3.6 -setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner +setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner VIRTUALENV_DIR = {envdir} passenv = NOSE_WITH_TIMER TRAVIS ST2_CI install_command = pip install -U --force-reinstall {opts} {packages} From 33d9efd711ddd3626ef611e867c81a0a5084739c Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 11:00:37 +0100 Subject: [PATCH 11/25] Exclude test fixture files from trailing whitespace hook since it breaks some tests. --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4496a22781..9ccc28e3226 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,5 +38,5 @@ repos: rev: v2.5.0 hooks: - id: trailing-whitespace + exclude: (^st2common/st2common/openapi.yaml|^st2client/tests/fixtures|^st2tests/st2tests/fixtures) - id: check-yaml - exclude: (openapi\.yaml) From 223a7ade496cbe0bb3f26b529d6d9d1c0f69a96c Mon Sep 17 00:00:00 2001 From: StackStorm CodeFormat Date: Thu, 18 Feb 2021 11:04:22 +0100 Subject: [PATCH 12/25] Remove trailing whitespace. --- .circleci/config.yml | 2 +- CHANGELOG.rst | 2 +- OWNERS.md | 2 +- README.md | 12 ++++---- conf/st2.conf.sample | 4 +-- conf/st2.dev.conf | 2 +- contrib/core/CHANGES.md | 2 +- contrib/examples/actions/forloop_chain.yaml | 2 +- .../actions/forloop_push_github_repos.yaml | 2 +- .../actions/orquesta-mock-create-vm.yaml | 2 +- .../actions/workflows/orquesta-delay.yaml | 2 +- .../orquesta-error-handling-continue.yaml | 2 +- ...orquesta-error-handling-fail-manually.yaml | 2 +- .../orquesta-error-handling-noop.yaml | 2 +- .../workflows/orquesta-fail-manually.yaml | 2 +- .../actions/workflows/orquesta-join.yaml | 2 +- .../orquesta-remediate-then-fail.yaml | 2 +- .../workflows/orquesta-rollback-retry.yaml | 2 +- .../workflows/orquesta-sequential.yaml | 2 +- .../orquesta-with-items-concurrency.yaml | 2 +- .../workflows/orquesta-with-items.yaml | 2 +- .../tests/orquesta-fail-input-rendering.yaml | 2 +- ...rquesta-fail-inspection-task-contents.yaml | 2 +- .../tests/orquesta-fail-output-rendering.yaml | 2 +- .../tests/orquesta-fail-start-task.yaml | 2 +- .../tests/orquesta-fail-task-publish.yaml | 2 +- .../tests/orquesta-fail-task-transition.yaml | 2 +- .../tests/orquesta-fail-vars-rendering.yaml | 2 +- .../tests/orquesta-test-pause-resume.yaml | 2 +- .../workflows/tests/orquesta-test-rerun.yaml | 2 +- .../tests/orquesta-test-with-items.yaml | 2 +- contrib/linux/README.md | 2 +- contrib/linux/sensors/README.md | 4 +-- contrib/packs/actions/install.meta.yaml | 2 +- contrib/packs/actions/setup_virtualenv.yaml | 2 +- .../inquirer_runner/runner.yaml | 2 +- dev_docs/Troubleshooting_Guide.rst | 28 +++++++++---------- st2client/Makefile | 2 +- st2common/bin/st2-run-pack-tests | 2 +- st2reactor/Makefile | 2 +- .../checks/actions/check_loadavg.yaml | 4 +-- .../testpacks/errorcheck/actions/exit-code.sh | 2 +- tox.ini | 2 +- 43 files changed, 64 insertions(+), 64 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bcbacbe3bd6..bde4a90784d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,5 @@ # Setup in CircleCI account the following ENV variables: -# PACKAGECLOUD_ORGANIZATION (default: stackstorm) +# PACKAGECLOUD_ORGANIZATION (default: stackstorm) # PACKAGECLOUD_TOKEN version: 2 jobs: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d0460f4ea2a..a8b52fb6740 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,7 +27,7 @@ Changed * Improve the st2-self-check script to echo to stderr and exit if it isn't run with a ST2_AUTH_TOKEN or ST2_API_KEY environment variable. (improvement) #5068 -* Added timeout parameter for packs.install action to help with long running installs that exceed the +* Added timeout parameter for packs.install action to help with long running installs that exceed the default timeout of 600 sec which is defined by the python_script action runner (improvement) #5084 Contributed by @hnanchahal diff --git a/OWNERS.md b/OWNERS.md index dfb8fb87bc2..501e7b8f144 100644 --- a/OWNERS.md +++ b/OWNERS.md @@ -74,7 +74,7 @@ Thank you, Friends! * Johan Dahlberg ([@johandahlberg](https://github.com/johandahlberg)) - Using st2 for Bioinformatics/Science project, providing feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Johan Hermansson ([@johanherman](https://github.com/johanherman)) - Using st2 for Bioinformatics/Science project, feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Lakshmi Kannan ([@lakshmi-kannan](https://github.com/lakshmi-kannan)) - early Stormer. Initial Core platform architecture, scalability, reliability, Team Leadership during the project hard times. -* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. +* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. * Manas Kelshikar ([@manasdk](https://github.com/manasdk)) - ex Stormer. Developed (well) early core platform features. * Vineesh Jain ([@VineeshJain](https://github.com/VineeshJain)) - ex Stormer. Community, Tests, Core, QA. * Warren Van Winckel ([@warrenvw](https://github.com/warrenvw)) - ex Stormer. Docker, Kubernetes, Vagrant, Infrastructure. diff --git a/README.md b/README.md index 4d84895bbd6..b22e908d5cd 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ [![Build Status](https://github.com/StackStorm/st2/workflows/ci-checks/badge.svg?branch=master)](https://github.com/StackStorm/st2/actions?query=branch%3Amaster) [![Travis Integration Tests Status](https://travis-ci.org/StackStorm/st2.svg?branch=master)](https://travis-ci.org/StackStorm/st2) -[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) -[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) -![Python 3.6](https://img.shields.io/badge/python-3.6-blue) -[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) -[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) +[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) +[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) +![Python 3.6](https://img.shields.io/badge/python-3.6-blue) +[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) +[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) [![Forum](https://img.shields.io/discourse/https/forum.stackstorm.com/posts.svg)](https://forum.stackstorm.com/) --- diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index 758b743e75f..488939eb55e 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -2,7 +2,7 @@ # Note: This file is automatically generated using tools/config_gen.py - DO NOT UPDATE MANUALLY [action_sensor] -# List of execution statuses for which a trigger will be emitted. +# List of execution statuses for which a trigger will be emitted. emit_when = succeeded,failed,timeout,canceled,abandoned # comma separated list allowed here. # Whether to enable or disable the ability to post a trigger on action. enable = True @@ -170,7 +170,7 @@ trigger_instances_ttl = None # Allow encryption of values in key value stored qualified as "secret". enable_encryption = True # Location of the symmetric encryption key for encrypting values in kvstore. This key should be in JSON and should've been generated using st2-generate-symmetric-crypto-key tool. -encryption_key_path = +encryption_key_path = [log] # Exclusion list of loggers to omit. diff --git a/conf/st2.dev.conf b/conf/st2.dev.conf index 2357b082634..29078016d00 100644 --- a/conf/st2.dev.conf +++ b/conf/st2.dev.conf @@ -83,7 +83,7 @@ protocol = udp # - redis # - etcd3 # - etcd3gw -# Keep in mind that zake driver works in process so it won't work when testing cross process +# Keep in mind that zake driver works in process so it won't work when testing cross process # / cross server functionality #url = redis://localhost #url = kazoo://localhost diff --git a/contrib/core/CHANGES.md b/contrib/core/CHANGES.md index b9c04efa88d..c0b1692b039 100644 --- a/contrib/core/CHANGES.md +++ b/contrib/core/CHANGES.md @@ -1,5 +1,5 @@ # Changelog - + ## 0.3.1 * Add new ``core.uuid`` action for generating type 1 and type 4 UUIDs. diff --git a/contrib/examples/actions/forloop_chain.yaml b/contrib/examples/actions/forloop_chain.yaml index f226eae4202..86ead5303a6 100644 --- a/contrib/examples/actions/forloop_chain.yaml +++ b/contrib/examples/actions/forloop_chain.yaml @@ -6,7 +6,7 @@ entry_point: "chains/forloop_chain.yaml" enabled: true parameters: github_organization_url: - type: "string" + type: "string" description: "Organization url to parse data from" default: "https://github.com/StackStorm-Exchange" required: false diff --git a/contrib/examples/actions/forloop_push_github_repos.yaml b/contrib/examples/actions/forloop_push_github_repos.yaml index 3ff06eabc33..878772636aa 100644 --- a/contrib/examples/actions/forloop_push_github_repos.yaml +++ b/contrib/examples/actions/forloop_push_github_repos.yaml @@ -5,7 +5,7 @@ description: "Action to push the data to an external service" enabled: true entry_point: "pythonactions/forloop_push_github_repos.py" parameters: - data_to_push: + data_to_push: type: "object" description: "Dictonary of the data to be pushed" required: true diff --git a/contrib/examples/actions/orquesta-mock-create-vm.yaml b/contrib/examples/actions/orquesta-mock-create-vm.yaml index 85e774a7024..35c5ab26d81 100644 --- a/contrib/examples/actions/orquesta-mock-create-vm.yaml +++ b/contrib/examples/actions/orquesta-mock-create-vm.yaml @@ -15,7 +15,7 @@ parameters: required: true type: string ip: - default: "10.1.23.99" + default: "10.1.23.99" required: true type: string meta: diff --git a/contrib/examples/actions/workflows/orquesta-delay.yaml b/contrib/examples/actions/workflows/orquesta-delay.yaml index a0793f8bf6d..82a131712c4 100644 --- a/contrib/examples/actions/workflows/orquesta-delay.yaml +++ b/contrib/examples/actions/workflows/orquesta-delay.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml index 5d9c6f22a0a..80047d2e5ed 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with continue. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml index da9179b5edb..4e3dfa38c2c 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with remediation and explicit fail. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml index 61b14a3c11d..e949dc37420 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with noop to ignore error. input: diff --git a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml index 936db68ff3e..b86d8ef25bc 100644 --- a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-join.yaml b/contrib/examples/actions/workflows/orquesta-join.yaml index eaf09fed66a..a247423948d 100644 --- a/contrib/examples/actions/workflows/orquesta-join.yaml +++ b/contrib/examples/actions/workflows/orquesta-join.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrate branching and join. vars: diff --git a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml index 936db68ff3e..b86d8ef25bc 100644 --- a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml +++ b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml index 0d80b0dbcb5..a1f203fb095 100644 --- a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml +++ b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: > A sample workflow that demonstrates how to handle rollback and retry on error. In this example, the workflow will loop until the file /tmp/done exists. A parallel task will wait for some time diff --git a/contrib/examples/actions/workflows/orquesta-sequential.yaml b/contrib/examples/actions/workflows/orquesta-sequential.yaml index 3a03409d36d..404681a3698 100644 --- a/contrib/examples/actions/workflows/orquesta-sequential.yaml +++ b/contrib/examples/actions/workflows/orquesta-sequential.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml index e20b9078988..6bcbb82c583 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items and concurrent processing. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items.yaml b/contrib/examples/actions/workflows/orquesta-with-items.yaml index 6a2cc4af494..5833e270510 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml index ce935f62f7d..907a18e8bfe 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating input. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml index c0322d025e8..a8be5311807 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow with inspection error(s). input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml index dd1e5164411..003ab8b69db 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating output. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml index a0deab1d8f3..0c23ee6a82e 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error in the rendering of the starting task. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml index 0887d4a7beb..149fb93b97b 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that fails on publish during task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml index 8fd2a94d8a6..4d4d9e5f392 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml index 403728100ab..4ddd9867557 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating vars. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml index 7123727cc31..285bf972d7a 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml @@ -19,4 +19,4 @@ tasks: task2: action: core.local input: - cmd: 'echo "<% $.var1 %>"' + cmd: 'echo "<% $.var1 %>"' diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml index 11eb22a721a..3a4b20cee02 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml @@ -4,7 +4,7 @@ description: A sample workflow used to test the rerun feature. input: - tempfile - + tasks: task1: action: core.noop diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml index 8af6899b595..6e24c0ec411 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow for testing with items and concurrency. input: diff --git a/contrib/linux/README.md b/contrib/linux/README.md index 33d872cf868..e2b9f09d44c 100644 --- a/contrib/linux/README.md +++ b/contrib/linux/README.md @@ -55,4 +55,4 @@ Example trigger payload: ## Troubleshooting -* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file +* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file diff --git a/contrib/linux/sensors/README.md b/contrib/linux/sensors/README.md index 7924e91e17a..084fcad6a6e 100644 --- a/contrib/linux/sensors/README.md +++ b/contrib/linux/sensors/README.md @@ -1,6 +1,6 @@ ## NOTICE -File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. +File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. An example rule to supply a file path is as follows: @@ -25,5 +25,5 @@ action: ``` -Trigger ``linux.file_watch.line`` still emits the same payload as it used to. +Trigger ``linux.file_watch.line`` still emits the same payload as it used to. Just the way to provide the file_path to tail has changed. diff --git a/contrib/packs/actions/install.meta.yaml b/contrib/packs/actions/install.meta.yaml index 1b8d0d572a1..191accd1c30 100644 --- a/contrib/packs/actions/install.meta.yaml +++ b/contrib/packs/actions/install.meta.yaml @@ -35,6 +35,6 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/packs/actions/setup_virtualenv.yaml b/contrib/packs/actions/setup_virtualenv.yaml index 18d1b3df157..47091705f3e 100644 --- a/contrib/packs/actions/setup_virtualenv.yaml +++ b/contrib/packs/actions/setup_virtualenv.yaml @@ -27,5 +27,5 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml index 26711df8500..60d79a5b740 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml +++ b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml @@ -23,7 +23,7 @@ roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/dev_docs/Troubleshooting_Guide.rst b/dev_docs/Troubleshooting_Guide.rst index 4e1c1f22d21..f61cedcba4e 100644 --- a/dev_docs/Troubleshooting_Guide.rst +++ b/dev_docs/Troubleshooting_Guide.rst @@ -28,7 +28,7 @@ Troubleshooting Guide $ sudo netstat -tupln | grep 910 tcp 0 0 0.0.0.0:9100 0.0.0.0:* LISTEN 32420/python tcp 0 0 0.0.0.0:9102 0.0.0.0:* LISTEN 32403/python - + As we can see from above output port ``9101`` is not even up. To verify this let us try another command: .. code:: bash @@ -36,10 +36,10 @@ As we can see from above output port ``9101`` is not even up. To verify this let $ ps auxww | grep st2 | grep 910 vagrant 32420 0.2 1.5 79228 31364 pts/10 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2auth.wsgi:application -k eventlet -b 0.0.0.0:9100 --workers 1 - vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 + vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 vagrant 32403 0.2 1.5 79228 31364 pts/3 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2stream.wsgi:application -k eventlet -b 0.0.0.0:9102 --workers 1 - + - This suggests that the API process crashed, we can verify that by running ``screen -ls``.:: .. code:: bash @@ -51,19 +51,19 @@ As we can see from above output port ``9101`` is not even up. To verify this let 15767.st2-sensorcontainer (04/26/2016 06:39:10 PM) (Detached) 15762.st2-stream (04/26/2016 06:39:10 PM) (Detached) 3 Sockets in /var/run/screen/S-vagrant. - -- Now let us check the logs for any errors: + +- Now let us check the logs for any errors: .. code:: bash tail logs/st2api.log - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d - (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', - 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d + (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', + 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. - Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, - 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. + Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, + 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) 2016-04-26 18:27:15,605 140317722756912 DEBUG base [-] Conflict while trying to save in DB. Traceback (most recent call last): @@ -94,7 +94,7 @@ As we can see from above output port ``9101`` is not even up. To verify this let NotUniqueError: Could not save document (E11000 duplicate key error index: st2.role_d_b.$name_1 dup key: { : "system_admin" }) 2016-04-26 18:27:15,676 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/Grammar.txt 2016-04-26 18:27:15,693 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/PatternGrammar.txt - + - To figure out whats wrong let us dig down further. Activate the virtualenv in st2 and run following command : .. code:: bash @@ -108,7 +108,7 @@ The above mentioned command will give out logs, we may find some error in the en File "/home/vagrant/git/st2/st2common/st2common/models/api/keyvalue.py", line 19, in from keyczar.keys import AesKey ImportError: No module named keyczar.keys - + So the problem is : module keyczar is missing. This module can be downloaded using following command: *Solution:* @@ -116,7 +116,7 @@ So the problem is : module keyczar is missing. This module can be downloaded usi .. code:: bash (virtualenv) $ pip install python-keyczar - + This should fix the issue. Now deactivate the virtual env and run ``tools/launchdev.sh restart`` diff --git a/st2client/Makefile b/st2client/Makefile index 9d6cf70a660..e17db7e4f65 100644 --- a/st2client/Makefile +++ b/st2client/Makefile @@ -9,7 +9,7 @@ RELEASE=1 COMPONENTS := st2client .PHONY: rpm -rpm: +rpm: $(PY3) setup.py bdist_rpm --python=$(PY3) mkdir -p $(RPM_ROOT)/RPMS/noarch cp dist/$(COMPONENTS)*noarch.rpm $(RPM_ROOT)/RPMS/noarch/$(COMPONENTS)-$(VER)-$(RELEASE).noarch.rpm diff --git a/st2common/bin/st2-run-pack-tests b/st2common/bin/st2-run-pack-tests index bed28267602..9f7c2306ab0 100755 --- a/st2common/bin/st2-run-pack-tests +++ b/st2common/bin/st2-run-pack-tests @@ -322,7 +322,7 @@ if [ "${ENABLE_COVERAGE}" = true ]; then # Base options to enable test coverage reporting # --with-coverage : enables coverage reporting - # --cover-erase : removes old coverage reports before starting + # --cover-erase : removes old coverage reports before starting NOSE_OPTS+=(--with-coverage --cover-erase) # Now, by default nosetests reports test coverage for every module found diff --git a/st2reactor/Makefile b/st2reactor/Makefile index cd3eb75a3ee..232abed4dd5 100644 --- a/st2reactor/Makefile +++ b/st2reactor/Makefile @@ -7,7 +7,7 @@ VER=0.4.0 COMPONENTS := st2reactor .PHONY: rpm -rpm: +rpm: pushd ~ && rpmdev-setuptree && popd tar --transform=s~^~$(COMPONENTS)-$(VER)/~ -czf $(RPM_SOURCES_DIR)/$(COMPONENTS).tar.gz bin conf $(COMPONENTS) cp packaging/rpm/$(COMPONENTS).spec $(RPM_SPECS_DIR)/ diff --git a/st2tests/testpacks/checks/actions/check_loadavg.yaml b/st2tests/testpacks/checks/actions/check_loadavg.yaml index ac38037d6c8..06abc652278 100644 --- a/st2tests/testpacks/checks/actions/check_loadavg.yaml +++ b/st2tests/testpacks/checks/actions/check_loadavg.yaml @@ -4,8 +4,8 @@ description: "Check CPU Load Average on a Host" enabled: true entry_point: "checks/check_loadavg.py" - parameters: - period: + parameters: + period: type: "string" description: "Time period for load avg: 5,10,15 minutes, or 'all'" default: "all" diff --git a/st2tests/testpacks/errorcheck/actions/exit-code.sh b/st2tests/testpacks/errorcheck/actions/exit-code.sh index 5320dc2f363..2e6eadf6a2c 100755 --- a/st2tests/testpacks/errorcheck/actions/exit-code.sh +++ b/st2tests/testpacks/errorcheck/actions/exit-code.sh @@ -6,4 +6,4 @@ if [ -n "$1" ] exit_code=$1 fi -exit $exit_code +exit $exit_code diff --git a/tox.ini b/tox.ini index 451ceee8e1e..de40b858789 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,7 @@ commands = [testenv:py36-integration] basepython = python3.6 -setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner +setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner VIRTUALENV_DIR = {envdir} passenv = NOSE_WITH_TIMER TRAVIS ST2_CI install_command = pip install -U --force-reinstall {opts} {packages} From eba4abe51ccb624f96a2fde8b9804d69fecda995 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 12:46:53 +0100 Subject: [PATCH 13/25] Also don't re-format config files. --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9ccc28e3226..c539e0b8548 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,5 +38,5 @@ repos: rev: v2.5.0 hooks: - id: trailing-whitespace - exclude: (^st2common/st2common/openapi.yaml|^st2client/tests/fixtures|^st2tests/st2tests/fixtures) + exclude: (^conf/|^st2common/st2common/openapi.yaml|^st2client/tests/fixtures|^st2tests/st2tests/fixtures) - id: check-yaml From 514bd279cc68d7dbaee33df57d7d94dc2183ee5b Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 12:47:10 +0100 Subject: [PATCH 14/25] Revert "Remove trailing whitespace." This reverts commit 223a7ade496cbe0bb3f26b529d6d9d1c0f69a96c. --- .circleci/config.yml | 2 +- CHANGELOG.rst | 2 +- OWNERS.md | 2 +- README.md | 12 ++++---- conf/st2.conf.sample | 4 +-- conf/st2.dev.conf | 2 +- contrib/core/CHANGES.md | 2 +- contrib/examples/actions/forloop_chain.yaml | 2 +- .../actions/forloop_push_github_repos.yaml | 2 +- .../actions/orquesta-mock-create-vm.yaml | 2 +- .../actions/workflows/orquesta-delay.yaml | 2 +- .../orquesta-error-handling-continue.yaml | 2 +- ...orquesta-error-handling-fail-manually.yaml | 2 +- .../orquesta-error-handling-noop.yaml | 2 +- .../workflows/orquesta-fail-manually.yaml | 2 +- .../actions/workflows/orquesta-join.yaml | 2 +- .../orquesta-remediate-then-fail.yaml | 2 +- .../workflows/orquesta-rollback-retry.yaml | 2 +- .../workflows/orquesta-sequential.yaml | 2 +- .../orquesta-with-items-concurrency.yaml | 2 +- .../workflows/orquesta-with-items.yaml | 2 +- .../tests/orquesta-fail-input-rendering.yaml | 2 +- ...rquesta-fail-inspection-task-contents.yaml | 2 +- .../tests/orquesta-fail-output-rendering.yaml | 2 +- .../tests/orquesta-fail-start-task.yaml | 2 +- .../tests/orquesta-fail-task-publish.yaml | 2 +- .../tests/orquesta-fail-task-transition.yaml | 2 +- .../tests/orquesta-fail-vars-rendering.yaml | 2 +- .../tests/orquesta-test-pause-resume.yaml | 2 +- .../workflows/tests/orquesta-test-rerun.yaml | 2 +- .../tests/orquesta-test-with-items.yaml | 2 +- contrib/linux/README.md | 2 +- contrib/linux/sensors/README.md | 4 +-- contrib/packs/actions/install.meta.yaml | 2 +- contrib/packs/actions/setup_virtualenv.yaml | 2 +- .../inquirer_runner/runner.yaml | 2 +- dev_docs/Troubleshooting_Guide.rst | 28 +++++++++---------- st2client/Makefile | 2 +- st2common/bin/st2-run-pack-tests | 2 +- st2reactor/Makefile | 2 +- .../checks/actions/check_loadavg.yaml | 4 +-- .../testpacks/errorcheck/actions/exit-code.sh | 2 +- tox.ini | 2 +- 43 files changed, 64 insertions(+), 64 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bde4a90784d..bcbacbe3bd6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,5 @@ # Setup in CircleCI account the following ENV variables: -# PACKAGECLOUD_ORGANIZATION (default: stackstorm) +# PACKAGECLOUD_ORGANIZATION (default: stackstorm) # PACKAGECLOUD_TOKEN version: 2 jobs: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a8b52fb6740..d0460f4ea2a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,7 +27,7 @@ Changed * Improve the st2-self-check script to echo to stderr and exit if it isn't run with a ST2_AUTH_TOKEN or ST2_API_KEY environment variable. (improvement) #5068 -* Added timeout parameter for packs.install action to help with long running installs that exceed the +* Added timeout parameter for packs.install action to help with long running installs that exceed the default timeout of 600 sec which is defined by the python_script action runner (improvement) #5084 Contributed by @hnanchahal diff --git a/OWNERS.md b/OWNERS.md index 501e7b8f144..dfb8fb87bc2 100644 --- a/OWNERS.md +++ b/OWNERS.md @@ -74,7 +74,7 @@ Thank you, Friends! * Johan Dahlberg ([@johandahlberg](https://github.com/johandahlberg)) - Using st2 for Bioinformatics/Science project, providing feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Johan Hermansson ([@johanherman](https://github.com/johanherman)) - Using st2 for Bioinformatics/Science project, feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Lakshmi Kannan ([@lakshmi-kannan](https://github.com/lakshmi-kannan)) - early Stormer. Initial Core platform architecture, scalability, reliability, Team Leadership during the project hard times. -* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. +* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. * Manas Kelshikar ([@manasdk](https://github.com/manasdk)) - ex Stormer. Developed (well) early core platform features. * Vineesh Jain ([@VineeshJain](https://github.com/VineeshJain)) - ex Stormer. Community, Tests, Core, QA. * Warren Van Winckel ([@warrenvw](https://github.com/warrenvw)) - ex Stormer. Docker, Kubernetes, Vagrant, Infrastructure. diff --git a/README.md b/README.md index b22e908d5cd..4d84895bbd6 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ [![Build Status](https://github.com/StackStorm/st2/workflows/ci-checks/badge.svg?branch=master)](https://github.com/StackStorm/st2/actions?query=branch%3Amaster) [![Travis Integration Tests Status](https://travis-ci.org/StackStorm/st2.svg?branch=master)](https://travis-ci.org/StackStorm/st2) -[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) -[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) -![Python 3.6](https://img.shields.io/badge/python-3.6-blue) -[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) -[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) +[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) +[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) +![Python 3.6](https://img.shields.io/badge/python-3.6-blue) +[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) +[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) [![Forum](https://img.shields.io/discourse/https/forum.stackstorm.com/posts.svg)](https://forum.stackstorm.com/) --- diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index 488939eb55e..758b743e75f 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -2,7 +2,7 @@ # Note: This file is automatically generated using tools/config_gen.py - DO NOT UPDATE MANUALLY [action_sensor] -# List of execution statuses for which a trigger will be emitted. +# List of execution statuses for which a trigger will be emitted. emit_when = succeeded,failed,timeout,canceled,abandoned # comma separated list allowed here. # Whether to enable or disable the ability to post a trigger on action. enable = True @@ -170,7 +170,7 @@ trigger_instances_ttl = None # Allow encryption of values in key value stored qualified as "secret". enable_encryption = True # Location of the symmetric encryption key for encrypting values in kvstore. This key should be in JSON and should've been generated using st2-generate-symmetric-crypto-key tool. -encryption_key_path = +encryption_key_path = [log] # Exclusion list of loggers to omit. diff --git a/conf/st2.dev.conf b/conf/st2.dev.conf index 29078016d00..2357b082634 100644 --- a/conf/st2.dev.conf +++ b/conf/st2.dev.conf @@ -83,7 +83,7 @@ protocol = udp # - redis # - etcd3 # - etcd3gw -# Keep in mind that zake driver works in process so it won't work when testing cross process +# Keep in mind that zake driver works in process so it won't work when testing cross process # / cross server functionality #url = redis://localhost #url = kazoo://localhost diff --git a/contrib/core/CHANGES.md b/contrib/core/CHANGES.md index c0b1692b039..b9c04efa88d 100644 --- a/contrib/core/CHANGES.md +++ b/contrib/core/CHANGES.md @@ -1,5 +1,5 @@ # Changelog - + ## 0.3.1 * Add new ``core.uuid`` action for generating type 1 and type 4 UUIDs. diff --git a/contrib/examples/actions/forloop_chain.yaml b/contrib/examples/actions/forloop_chain.yaml index 86ead5303a6..f226eae4202 100644 --- a/contrib/examples/actions/forloop_chain.yaml +++ b/contrib/examples/actions/forloop_chain.yaml @@ -6,7 +6,7 @@ entry_point: "chains/forloop_chain.yaml" enabled: true parameters: github_organization_url: - type: "string" + type: "string" description: "Organization url to parse data from" default: "https://github.com/StackStorm-Exchange" required: false diff --git a/contrib/examples/actions/forloop_push_github_repos.yaml b/contrib/examples/actions/forloop_push_github_repos.yaml index 878772636aa..3ff06eabc33 100644 --- a/contrib/examples/actions/forloop_push_github_repos.yaml +++ b/contrib/examples/actions/forloop_push_github_repos.yaml @@ -5,7 +5,7 @@ description: "Action to push the data to an external service" enabled: true entry_point: "pythonactions/forloop_push_github_repos.py" parameters: - data_to_push: + data_to_push: type: "object" description: "Dictonary of the data to be pushed" required: true diff --git a/contrib/examples/actions/orquesta-mock-create-vm.yaml b/contrib/examples/actions/orquesta-mock-create-vm.yaml index 35c5ab26d81..85e774a7024 100644 --- a/contrib/examples/actions/orquesta-mock-create-vm.yaml +++ b/contrib/examples/actions/orquesta-mock-create-vm.yaml @@ -15,7 +15,7 @@ parameters: required: true type: string ip: - default: "10.1.23.99" + default: "10.1.23.99" required: true type: string meta: diff --git a/contrib/examples/actions/workflows/orquesta-delay.yaml b/contrib/examples/actions/workflows/orquesta-delay.yaml index 82a131712c4..a0793f8bf6d 100644 --- a/contrib/examples/actions/workflows/orquesta-delay.yaml +++ b/contrib/examples/actions/workflows/orquesta-delay.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml index 80047d2e5ed..5d9c6f22a0a 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with continue. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml index 4e3dfa38c2c..da9179b5edb 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with remediation and explicit fail. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml index e949dc37420..61b14a3c11d 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with noop to ignore error. input: diff --git a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml index b86d8ef25bc..936db68ff3e 100644 --- a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-join.yaml b/contrib/examples/actions/workflows/orquesta-join.yaml index a247423948d..eaf09fed66a 100644 --- a/contrib/examples/actions/workflows/orquesta-join.yaml +++ b/contrib/examples/actions/workflows/orquesta-join.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrate branching and join. vars: diff --git a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml index b86d8ef25bc..936db68ff3e 100644 --- a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml +++ b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml index a1f203fb095..0d80b0dbcb5 100644 --- a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml +++ b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: > A sample workflow that demonstrates how to handle rollback and retry on error. In this example, the workflow will loop until the file /tmp/done exists. A parallel task will wait for some time diff --git a/contrib/examples/actions/workflows/orquesta-sequential.yaml b/contrib/examples/actions/workflows/orquesta-sequential.yaml index 404681a3698..3a03409d36d 100644 --- a/contrib/examples/actions/workflows/orquesta-sequential.yaml +++ b/contrib/examples/actions/workflows/orquesta-sequential.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml index 6bcbb82c583..e20b9078988 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items and concurrent processing. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items.yaml b/contrib/examples/actions/workflows/orquesta-with-items.yaml index 5833e270510..6a2cc4af494 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml index 907a18e8bfe..ce935f62f7d 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating input. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml index a8be5311807..c0322d025e8 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow with inspection error(s). input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml index 003ab8b69db..dd1e5164411 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating output. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml index 0c23ee6a82e..a0deab1d8f3 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error in the rendering of the starting task. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml index 149fb93b97b..0887d4a7beb 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that fails on publish during task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml index 4d4d9e5f392..8fd2a94d8a6 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml index 4ddd9867557..403728100ab 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating vars. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml index 285bf972d7a..7123727cc31 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml @@ -19,4 +19,4 @@ tasks: task2: action: core.local input: - cmd: 'echo "<% $.var1 %>"' + cmd: 'echo "<% $.var1 %>"' diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml index 3a4b20cee02..11eb22a721a 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml @@ -4,7 +4,7 @@ description: A sample workflow used to test the rerun feature. input: - tempfile - + tasks: task1: action: core.noop diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml index 6e24c0ec411..8af6899b595 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow for testing with items and concurrency. input: diff --git a/contrib/linux/README.md b/contrib/linux/README.md index e2b9f09d44c..33d872cf868 100644 --- a/contrib/linux/README.md +++ b/contrib/linux/README.md @@ -55,4 +55,4 @@ Example trigger payload: ## Troubleshooting -* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file +* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file diff --git a/contrib/linux/sensors/README.md b/contrib/linux/sensors/README.md index 084fcad6a6e..7924e91e17a 100644 --- a/contrib/linux/sensors/README.md +++ b/contrib/linux/sensors/README.md @@ -1,6 +1,6 @@ ## NOTICE -File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. +File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. An example rule to supply a file path is as follows: @@ -25,5 +25,5 @@ action: ``` -Trigger ``linux.file_watch.line`` still emits the same payload as it used to. +Trigger ``linux.file_watch.line`` still emits the same payload as it used to. Just the way to provide the file_path to tail has changed. diff --git a/contrib/packs/actions/install.meta.yaml b/contrib/packs/actions/install.meta.yaml index 191accd1c30..1b8d0d572a1 100644 --- a/contrib/packs/actions/install.meta.yaml +++ b/contrib/packs/actions/install.meta.yaml @@ -35,6 +35,6 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/packs/actions/setup_virtualenv.yaml b/contrib/packs/actions/setup_virtualenv.yaml index 47091705f3e..18d1b3df157 100644 --- a/contrib/packs/actions/setup_virtualenv.yaml +++ b/contrib/packs/actions/setup_virtualenv.yaml @@ -27,5 +27,5 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml index 60d79a5b740..26711df8500 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml +++ b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml @@ -23,7 +23,7 @@ roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/dev_docs/Troubleshooting_Guide.rst b/dev_docs/Troubleshooting_Guide.rst index f61cedcba4e..4e1c1f22d21 100644 --- a/dev_docs/Troubleshooting_Guide.rst +++ b/dev_docs/Troubleshooting_Guide.rst @@ -28,7 +28,7 @@ Troubleshooting Guide $ sudo netstat -tupln | grep 910 tcp 0 0 0.0.0.0:9100 0.0.0.0:* LISTEN 32420/python tcp 0 0 0.0.0.0:9102 0.0.0.0:* LISTEN 32403/python - + As we can see from above output port ``9101`` is not even up. To verify this let us try another command: .. code:: bash @@ -36,10 +36,10 @@ As we can see from above output port ``9101`` is not even up. To verify this let $ ps auxww | grep st2 | grep 910 vagrant 32420 0.2 1.5 79228 31364 pts/10 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2auth.wsgi:application -k eventlet -b 0.0.0.0:9100 --workers 1 - vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 + vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 vagrant 32403 0.2 1.5 79228 31364 pts/3 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2stream.wsgi:application -k eventlet -b 0.0.0.0:9102 --workers 1 - + - This suggests that the API process crashed, we can verify that by running ``screen -ls``.:: .. code:: bash @@ -51,19 +51,19 @@ As we can see from above output port ``9101`` is not even up. To verify this let 15767.st2-sensorcontainer (04/26/2016 06:39:10 PM) (Detached) 15762.st2-stream (04/26/2016 06:39:10 PM) (Detached) 3 Sockets in /var/run/screen/S-vagrant. - -- Now let us check the logs for any errors: + +- Now let us check the logs for any errors: .. code:: bash tail logs/st2api.log - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d - (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', - 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d + (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', + 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. - Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, - 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. + Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, + 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) 2016-04-26 18:27:15,605 140317722756912 DEBUG base [-] Conflict while trying to save in DB. Traceback (most recent call last): @@ -94,7 +94,7 @@ As we can see from above output port ``9101`` is not even up. To verify this let NotUniqueError: Could not save document (E11000 duplicate key error index: st2.role_d_b.$name_1 dup key: { : "system_admin" }) 2016-04-26 18:27:15,676 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/Grammar.txt 2016-04-26 18:27:15,693 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/PatternGrammar.txt - + - To figure out whats wrong let us dig down further. Activate the virtualenv in st2 and run following command : .. code:: bash @@ -108,7 +108,7 @@ The above mentioned command will give out logs, we may find some error in the en File "/home/vagrant/git/st2/st2common/st2common/models/api/keyvalue.py", line 19, in from keyczar.keys import AesKey ImportError: No module named keyczar.keys - + So the problem is : module keyczar is missing. This module can be downloaded using following command: *Solution:* @@ -116,7 +116,7 @@ So the problem is : module keyczar is missing. This module can be downloaded usi .. code:: bash (virtualenv) $ pip install python-keyczar - + This should fix the issue. Now deactivate the virtual env and run ``tools/launchdev.sh restart`` diff --git a/st2client/Makefile b/st2client/Makefile index e17db7e4f65..9d6cf70a660 100644 --- a/st2client/Makefile +++ b/st2client/Makefile @@ -9,7 +9,7 @@ RELEASE=1 COMPONENTS := st2client .PHONY: rpm -rpm: +rpm: $(PY3) setup.py bdist_rpm --python=$(PY3) mkdir -p $(RPM_ROOT)/RPMS/noarch cp dist/$(COMPONENTS)*noarch.rpm $(RPM_ROOT)/RPMS/noarch/$(COMPONENTS)-$(VER)-$(RELEASE).noarch.rpm diff --git a/st2common/bin/st2-run-pack-tests b/st2common/bin/st2-run-pack-tests index 9f7c2306ab0..bed28267602 100755 --- a/st2common/bin/st2-run-pack-tests +++ b/st2common/bin/st2-run-pack-tests @@ -322,7 +322,7 @@ if [ "${ENABLE_COVERAGE}" = true ]; then # Base options to enable test coverage reporting # --with-coverage : enables coverage reporting - # --cover-erase : removes old coverage reports before starting + # --cover-erase : removes old coverage reports before starting NOSE_OPTS+=(--with-coverage --cover-erase) # Now, by default nosetests reports test coverage for every module found diff --git a/st2reactor/Makefile b/st2reactor/Makefile index 232abed4dd5..cd3eb75a3ee 100644 --- a/st2reactor/Makefile +++ b/st2reactor/Makefile @@ -7,7 +7,7 @@ VER=0.4.0 COMPONENTS := st2reactor .PHONY: rpm -rpm: +rpm: pushd ~ && rpmdev-setuptree && popd tar --transform=s~^~$(COMPONENTS)-$(VER)/~ -czf $(RPM_SOURCES_DIR)/$(COMPONENTS).tar.gz bin conf $(COMPONENTS) cp packaging/rpm/$(COMPONENTS).spec $(RPM_SPECS_DIR)/ diff --git a/st2tests/testpacks/checks/actions/check_loadavg.yaml b/st2tests/testpacks/checks/actions/check_loadavg.yaml index 06abc652278..ac38037d6c8 100644 --- a/st2tests/testpacks/checks/actions/check_loadavg.yaml +++ b/st2tests/testpacks/checks/actions/check_loadavg.yaml @@ -4,8 +4,8 @@ description: "Check CPU Load Average on a Host" enabled: true entry_point: "checks/check_loadavg.py" - parameters: - period: + parameters: + period: type: "string" description: "Time period for load avg: 5,10,15 minutes, or 'all'" default: "all" diff --git a/st2tests/testpacks/errorcheck/actions/exit-code.sh b/st2tests/testpacks/errorcheck/actions/exit-code.sh index 2e6eadf6a2c..5320dc2f363 100755 --- a/st2tests/testpacks/errorcheck/actions/exit-code.sh +++ b/st2tests/testpacks/errorcheck/actions/exit-code.sh @@ -6,4 +6,4 @@ if [ -n "$1" ] exit_code=$1 fi -exit $exit_code +exit $exit_code diff --git a/tox.ini b/tox.ini index de40b858789..451ceee8e1e 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,7 @@ commands = [testenv:py36-integration] basepython = python3.6 -setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner +setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner VIRTUALENV_DIR = {envdir} passenv = NOSE_WITH_TIMER TRAVIS ST2_CI install_command = pip install -U --force-reinstall {opts} {packages} From 56101b8481a330a05e7fe668d762ca9ba1c386ac Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 13:00:06 +0100 Subject: [PATCH 15/25] Make sure sample config doesn't contain trailing whitespace. --- conf/st2.conf.sample | 4 ++-- tools/config_gen.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/conf/st2.conf.sample b/conf/st2.conf.sample index 758b743e75f..488939eb55e 100644 --- a/conf/st2.conf.sample +++ b/conf/st2.conf.sample @@ -2,7 +2,7 @@ # Note: This file is automatically generated using tools/config_gen.py - DO NOT UPDATE MANUALLY [action_sensor] -# List of execution statuses for which a trigger will be emitted. +# List of execution statuses for which a trigger will be emitted. emit_when = succeeded,failed,timeout,canceled,abandoned # comma separated list allowed here. # Whether to enable or disable the ability to post a trigger on action. enable = True @@ -170,7 +170,7 @@ trigger_instances_ttl = None # Allow encryption of values in key value stored qualified as "secret". enable_encryption = True # Location of the symmetric encryption key for encrypting values in kvstore. This key should be in JSON and should've been generated using st2-generate-symmetric-crypto-key tool. -encryption_key_path = +encryption_key_path = [log] # Exclusion list of loggers to omit. diff --git a/tools/config_gen.py b/tools/config_gen.py index e0004d04e1b..309bdf608fd 100755 --- a/tools/config_gen.py +++ b/tools/config_gen.py @@ -170,8 +170,8 @@ def _print_options(opt_group, options): else: value = opt.default - print("# %s" % opt.help) - print("%s = %s" % (opt.name, value)) + print(("# %s" % opt.help).strip()) + print(("%s = %s" % (opt.name, value)).strip()) def main(args): From 100fbdb45d24d5829906f1e5e1a9fc1b398a7bf2 Mon Sep 17 00:00:00 2001 From: StackStorm CodeFormat Date: Thu, 18 Feb 2021 13:00:18 +0100 Subject: [PATCH 16/25] Remove trailing whitespace. --- .circleci/config.yml | 2 +- CHANGELOG.rst | 2 +- OWNERS.md | 2 +- README.md | 12 ++++---- contrib/core/CHANGES.md | 2 +- contrib/examples/actions/forloop_chain.yaml | 2 +- .../actions/forloop_push_github_repos.yaml | 2 +- .../actions/orquesta-mock-create-vm.yaml | 2 +- .../actions/workflows/orquesta-delay.yaml | 2 +- .../orquesta-error-handling-continue.yaml | 2 +- ...orquesta-error-handling-fail-manually.yaml | 2 +- .../orquesta-error-handling-noop.yaml | 2 +- .../workflows/orquesta-fail-manually.yaml | 2 +- .../actions/workflows/orquesta-join.yaml | 2 +- .../orquesta-remediate-then-fail.yaml | 2 +- .../workflows/orquesta-rollback-retry.yaml | 2 +- .../workflows/orquesta-sequential.yaml | 2 +- .../orquesta-with-items-concurrency.yaml | 2 +- .../workflows/orquesta-with-items.yaml | 2 +- .../tests/orquesta-fail-input-rendering.yaml | 2 +- ...rquesta-fail-inspection-task-contents.yaml | 2 +- .../tests/orquesta-fail-output-rendering.yaml | 2 +- .../tests/orquesta-fail-start-task.yaml | 2 +- .../tests/orquesta-fail-task-publish.yaml | 2 +- .../tests/orquesta-fail-task-transition.yaml | 2 +- .../tests/orquesta-fail-vars-rendering.yaml | 2 +- .../tests/orquesta-test-pause-resume.yaml | 2 +- .../workflows/tests/orquesta-test-rerun.yaml | 2 +- .../tests/orquesta-test-with-items.yaml | 2 +- contrib/linux/README.md | 2 +- contrib/linux/sensors/README.md | 4 +-- contrib/packs/actions/install.meta.yaml | 2 +- contrib/packs/actions/setup_virtualenv.yaml | 2 +- .../inquirer_runner/runner.yaml | 2 +- dev_docs/Troubleshooting_Guide.rst | 28 +++++++++---------- st2client/Makefile | 2 +- st2common/bin/st2-run-pack-tests | 2 +- st2reactor/Makefile | 2 +- .../checks/actions/check_loadavg.yaml | 4 +-- .../testpacks/errorcheck/actions/exit-code.sh | 2 +- tox.ini | 2 +- 41 files changed, 61 insertions(+), 61 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index bcbacbe3bd6..bde4a90784d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,5 @@ # Setup in CircleCI account the following ENV variables: -# PACKAGECLOUD_ORGANIZATION (default: stackstorm) +# PACKAGECLOUD_ORGANIZATION (default: stackstorm) # PACKAGECLOUD_TOKEN version: 2 jobs: diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d0460f4ea2a..a8b52fb6740 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -27,7 +27,7 @@ Changed * Improve the st2-self-check script to echo to stderr and exit if it isn't run with a ST2_AUTH_TOKEN or ST2_API_KEY environment variable. (improvement) #5068 -* Added timeout parameter for packs.install action to help with long running installs that exceed the +* Added timeout parameter for packs.install action to help with long running installs that exceed the default timeout of 600 sec which is defined by the python_script action runner (improvement) #5084 Contributed by @hnanchahal diff --git a/OWNERS.md b/OWNERS.md index dfb8fb87bc2..501e7b8f144 100644 --- a/OWNERS.md +++ b/OWNERS.md @@ -74,7 +74,7 @@ Thank you, Friends! * Johan Dahlberg ([@johandahlberg](https://github.com/johandahlberg)) - Using st2 for Bioinformatics/Science project, providing feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Johan Hermansson ([@johanherman](https://github.com/johanherman)) - Using st2 for Bioinformatics/Science project, feedback & contributions in Ansible, Community, Workflows. [Case Study](https://stackstorm.com/case-study-scilifelab/). * Lakshmi Kannan ([@lakshmi-kannan](https://github.com/lakshmi-kannan)) - early Stormer. Initial Core platform architecture, scalability, reliability, Team Leadership during the project hard times. -* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. +* Lindsay Hill ([@LindsayHill](https://github.com/LindsayHill)) - ex StackStorm product manager that made a significant impact building an ecosystem we see today. * Manas Kelshikar ([@manasdk](https://github.com/manasdk)) - ex Stormer. Developed (well) early core platform features. * Vineesh Jain ([@VineeshJain](https://github.com/VineeshJain)) - ex Stormer. Community, Tests, Core, QA. * Warren Van Winckel ([@warrenvw](https://github.com/warrenvw)) - ex Stormer. Docker, Kubernetes, Vagrant, Infrastructure. diff --git a/README.md b/README.md index 4d84895bbd6..b22e908d5cd 100644 --- a/README.md +++ b/README.md @@ -4,12 +4,12 @@ [![Build Status](https://github.com/StackStorm/st2/workflows/ci-checks/badge.svg?branch=master)](https://github.com/StackStorm/st2/actions?query=branch%3Amaster) [![Travis Integration Tests Status](https://travis-ci.org/StackStorm/st2.svg?branch=master)](https://travis-ci.org/StackStorm/st2) -[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) -[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) -[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) -![Python 3.6](https://img.shields.io/badge/python-3.6-blue) -[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) -[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) +[![Packages Build Status](https://circleci.com/gh/StackStorm/st2/tree/master.svg?style=shield)](https://circleci.com/gh/StackStorm/st2) +[![Codecov](https://codecov.io/github/StackStorm/st2/badge.svg?branch=master&service=github)](https://codecov.io/github/StackStorm/st2?branch=master) +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1833/badge)](https://bestpractices.coreinfrastructure.org/projects/1833) +![Python 3.6](https://img.shields.io/badge/python-3.6-blue) +[![Apache Licensed](https://img.shields.io/github/license/StackStorm/st2)](LICENSE) +[![Join our community Slack](https://img.shields.io/badge/slack-stackstorm-success.svg?logo=slack)](https://stackstorm.com/community-signup) [![Forum](https://img.shields.io/discourse/https/forum.stackstorm.com/posts.svg)](https://forum.stackstorm.com/) --- diff --git a/contrib/core/CHANGES.md b/contrib/core/CHANGES.md index b9c04efa88d..c0b1692b039 100644 --- a/contrib/core/CHANGES.md +++ b/contrib/core/CHANGES.md @@ -1,5 +1,5 @@ # Changelog - + ## 0.3.1 * Add new ``core.uuid`` action for generating type 1 and type 4 UUIDs. diff --git a/contrib/examples/actions/forloop_chain.yaml b/contrib/examples/actions/forloop_chain.yaml index f226eae4202..86ead5303a6 100644 --- a/contrib/examples/actions/forloop_chain.yaml +++ b/contrib/examples/actions/forloop_chain.yaml @@ -6,7 +6,7 @@ entry_point: "chains/forloop_chain.yaml" enabled: true parameters: github_organization_url: - type: "string" + type: "string" description: "Organization url to parse data from" default: "https://github.com/StackStorm-Exchange" required: false diff --git a/contrib/examples/actions/forloop_push_github_repos.yaml b/contrib/examples/actions/forloop_push_github_repos.yaml index 3ff06eabc33..878772636aa 100644 --- a/contrib/examples/actions/forloop_push_github_repos.yaml +++ b/contrib/examples/actions/forloop_push_github_repos.yaml @@ -5,7 +5,7 @@ description: "Action to push the data to an external service" enabled: true entry_point: "pythonactions/forloop_push_github_repos.py" parameters: - data_to_push: + data_to_push: type: "object" description: "Dictonary of the data to be pushed" required: true diff --git a/contrib/examples/actions/orquesta-mock-create-vm.yaml b/contrib/examples/actions/orquesta-mock-create-vm.yaml index 85e774a7024..35c5ab26d81 100644 --- a/contrib/examples/actions/orquesta-mock-create-vm.yaml +++ b/contrib/examples/actions/orquesta-mock-create-vm.yaml @@ -15,7 +15,7 @@ parameters: required: true type: string ip: - default: "10.1.23.99" + default: "10.1.23.99" required: true type: string meta: diff --git a/contrib/examples/actions/workflows/orquesta-delay.yaml b/contrib/examples/actions/workflows/orquesta-delay.yaml index a0793f8bf6d..82a131712c4 100644 --- a/contrib/examples/actions/workflows/orquesta-delay.yaml +++ b/contrib/examples/actions/workflows/orquesta-delay.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml index 5d9c6f22a0a..80047d2e5ed 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-continue.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with continue. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml index da9179b5edb..4e3dfa38c2c 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-fail-manually.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with remediation and explicit fail. input: diff --git a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml index 61b14a3c11d..e949dc37420 100644 --- a/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml +++ b/contrib/examples/actions/workflows/orquesta-error-handling-noop.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrates error handler with noop to ignore error. input: diff --git a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml index 936db68ff3e..b86d8ef25bc 100644 --- a/contrib/examples/actions/workflows/orquesta-fail-manually.yaml +++ b/contrib/examples/actions/workflows/orquesta-fail-manually.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-join.yaml b/contrib/examples/actions/workflows/orquesta-join.yaml index eaf09fed66a..a247423948d 100644 --- a/contrib/examples/actions/workflows/orquesta-join.yaml +++ b/contrib/examples/actions/workflows/orquesta-join.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that demonstrate branching and join. vars: diff --git a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml index 936db68ff3e..b86d8ef25bc 100644 --- a/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml +++ b/contrib/examples/actions/workflows/orquesta-remediate-then-fail.yaml @@ -11,7 +11,7 @@ tasks: - when: <% failed() %> publish: - task_name: <% task().task_name %> - - task_exit_code: <% task().result.stdout %> + - task_exit_code: <% task().result.stdout %> do: - log - fail diff --git a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml index 0d80b0dbcb5..a1f203fb095 100644 --- a/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml +++ b/contrib/examples/actions/workflows/orquesta-rollback-retry.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: > A sample workflow that demonstrates how to handle rollback and retry on error. In this example, the workflow will loop until the file /tmp/done exists. A parallel task will wait for some time diff --git a/contrib/examples/actions/workflows/orquesta-sequential.yaml b/contrib/examples/actions/workflows/orquesta-sequential.yaml index 3a03409d36d..404681a3698 100644 --- a/contrib/examples/actions/workflows/orquesta-sequential.yaml +++ b/contrib/examples/actions/workflows/orquesta-sequential.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml index e20b9078988..6bcbb82c583 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items-concurrency.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items and concurrent processing. input: diff --git a/contrib/examples/actions/workflows/orquesta-with-items.yaml b/contrib/examples/actions/workflows/orquesta-with-items.yaml index 6a2cc4af494..5833e270510 100644 --- a/contrib/examples/actions/workflows/orquesta-with-items.yaml +++ b/contrib/examples/actions/workflows/orquesta-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow demonstrating with items. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml index ce935f62f7d..907a18e8bfe 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-input-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating input. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml index c0322d025e8..a8be5311807 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-inspection-task-contents.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic sequential workflow with inspection error(s). input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml index dd1e5164411..003ab8b69db 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-output-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating output. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml index a0deab1d8f3..0c23ee6a82e 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-start-task.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error in the rendering of the starting task. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml index 0887d4a7beb..149fb93b97b 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-publish.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow that fails on publish during task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml index 8fd2a94d8a6..4d4d9e5f392 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-task-transition.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating task transition. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml index 403728100ab..4ddd9867557 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-fail-vars-rendering.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A basic workflow with error while evaluating vars. input: diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml index 7123727cc31..285bf972d7a 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-pause-resume.yaml @@ -19,4 +19,4 @@ tasks: task2: action: core.local input: - cmd: 'echo "<% $.var1 %>"' + cmd: 'echo "<% $.var1 %>"' diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml index 11eb22a721a..3a4b20cee02 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-rerun.yaml @@ -4,7 +4,7 @@ description: A sample workflow used to test the rerun feature. input: - tempfile - + tasks: task1: action: core.noop diff --git a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml index 8af6899b595..6e24c0ec411 100644 --- a/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml +++ b/contrib/examples/actions/workflows/tests/orquesta-test-with-items.yaml @@ -1,5 +1,5 @@ version: 1.0 - + description: A workflow for testing with items and concurrency. input: diff --git a/contrib/linux/README.md b/contrib/linux/README.md index 33d872cf868..e2b9f09d44c 100644 --- a/contrib/linux/README.md +++ b/contrib/linux/README.md @@ -55,4 +55,4 @@ Example trigger payload: ## Troubleshooting -* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file +* On CentOS7/RHEL7, dig is not installed by default. Run ``sudo yum install bind-utils`` to install. \ No newline at end of file diff --git a/contrib/linux/sensors/README.md b/contrib/linux/sensors/README.md index 7924e91e17a..084fcad6a6e 100644 --- a/contrib/linux/sensors/README.md +++ b/contrib/linux/sensors/README.md @@ -1,6 +1,6 @@ ## NOTICE -File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. +File watch sensor has been updated to use trigger with parameters supplied via a rule approach. Tailing a file path supplied via a config file is now deprecated. An example rule to supply a file path is as follows: @@ -25,5 +25,5 @@ action: ``` -Trigger ``linux.file_watch.line`` still emits the same payload as it used to. +Trigger ``linux.file_watch.line`` still emits the same payload as it used to. Just the way to provide the file_path to tail has changed. diff --git a/contrib/packs/actions/install.meta.yaml b/contrib/packs/actions/install.meta.yaml index 1b8d0d572a1..191accd1c30 100644 --- a/contrib/packs/actions/install.meta.yaml +++ b/contrib/packs/actions/install.meta.yaml @@ -35,6 +35,6 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/packs/actions/setup_virtualenv.yaml b/contrib/packs/actions/setup_virtualenv.yaml index 18d1b3df157..47091705f3e 100644 --- a/contrib/packs/actions/setup_virtualenv.yaml +++ b/contrib/packs/actions/setup_virtualenv.yaml @@ -27,5 +27,5 @@ timeout: default: 600 required: false - description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout + description: Action timeout in seconds. Action will get killed if it doesn't finish in timeout type: integer diff --git a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml index 26711df8500..60d79a5b740 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml +++ b/contrib/runners/inquirer_runner/inquirer_runner/runner.yaml @@ -23,7 +23,7 @@ roles: default: [] required: false - description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES + description: A list of roles that are permitted to respond to the action (if nothing provided, all are permitted) - REQUIRES RBAC FEATURES type: array users: default: [] diff --git a/dev_docs/Troubleshooting_Guide.rst b/dev_docs/Troubleshooting_Guide.rst index 4e1c1f22d21..f61cedcba4e 100644 --- a/dev_docs/Troubleshooting_Guide.rst +++ b/dev_docs/Troubleshooting_Guide.rst @@ -28,7 +28,7 @@ Troubleshooting Guide $ sudo netstat -tupln | grep 910 tcp 0 0 0.0.0.0:9100 0.0.0.0:* LISTEN 32420/python tcp 0 0 0.0.0.0:9102 0.0.0.0:* LISTEN 32403/python - + As we can see from above output port ``9101`` is not even up. To verify this let us try another command: .. code:: bash @@ -36,10 +36,10 @@ As we can see from above output port ``9101`` is not even up. To verify this let $ ps auxww | grep st2 | grep 910 vagrant 32420 0.2 1.5 79228 31364 pts/10 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2auth.wsgi:application -k eventlet -b 0.0.0.0:9100 --workers 1 - vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 + vagrant@ether git/st2 (master %) » ps auxww | grep st2 | grep 32403 vagrant 32403 0.2 1.5 79228 31364 pts/3 Ss+ 18:27 0:00 /home/vagrant/git/st2/virtualenv/bin/python ./virtualenv/bin/gunicorn st2stream.wsgi:application -k eventlet -b 0.0.0.0:9102 --workers 1 - + - This suggests that the API process crashed, we can verify that by running ``screen -ls``.:: .. code:: bash @@ -51,19 +51,19 @@ As we can see from above output port ``9101`` is not even up. To verify this let 15767.st2-sensorcontainer (04/26/2016 06:39:10 PM) (Detached) 15762.st2-stream (04/26/2016 06:39:10 PM) (Detached) 3 Sockets in /var/run/screen/S-vagrant. - -- Now let us check the logs for any errors: + +- Now let us check the logs for any errors: .. code:: bash tail logs/st2api.log - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d - (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', - 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger updated. Trigger.id=570e9704909a5030cf758e6d + (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, 'name': u'st2.sensor.process_exit', + 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) - 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. - Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, - 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', + 2016-04-26 18:27:15,603 140317722756912 AUDIT triggers [-] Trigger created for parameter-less TriggerType. + Trigger.id=570e9704909a5030cf758e6d (trigger_db={'description': None, 'parameters': {}, 'ref_count': 0, + 'name': u'st2.sensor.process_exit', 'uid': u'trigger:core:st2.sensor.process_exit:5f02f0889301fd7be1ac972c11bf3e7d', 'type': u'core.st2.sensor.process_exit', 'id': '570e9704909a5030cf758e6d', 'pack': u'core'}) 2016-04-26 18:27:15,605 140317722756912 DEBUG base [-] Conflict while trying to save in DB. Traceback (most recent call last): @@ -94,7 +94,7 @@ As we can see from above output port ``9101`` is not even up. To verify this let NotUniqueError: Could not save document (E11000 duplicate key error index: st2.role_d_b.$name_1 dup key: { : "system_admin" }) 2016-04-26 18:27:15,676 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/Grammar.txt 2016-04-26 18:27:15,693 140317722756912 INFO driver [-] Generating grammar tables from /usr/lib/python2.7/lib2to3/PatternGrammar.txt - + - To figure out whats wrong let us dig down further. Activate the virtualenv in st2 and run following command : .. code:: bash @@ -108,7 +108,7 @@ The above mentioned command will give out logs, we may find some error in the en File "/home/vagrant/git/st2/st2common/st2common/models/api/keyvalue.py", line 19, in from keyczar.keys import AesKey ImportError: No module named keyczar.keys - + So the problem is : module keyczar is missing. This module can be downloaded using following command: *Solution:* @@ -116,7 +116,7 @@ So the problem is : module keyczar is missing. This module can be downloaded usi .. code:: bash (virtualenv) $ pip install python-keyczar - + This should fix the issue. Now deactivate the virtual env and run ``tools/launchdev.sh restart`` diff --git a/st2client/Makefile b/st2client/Makefile index 9d6cf70a660..e17db7e4f65 100644 --- a/st2client/Makefile +++ b/st2client/Makefile @@ -9,7 +9,7 @@ RELEASE=1 COMPONENTS := st2client .PHONY: rpm -rpm: +rpm: $(PY3) setup.py bdist_rpm --python=$(PY3) mkdir -p $(RPM_ROOT)/RPMS/noarch cp dist/$(COMPONENTS)*noarch.rpm $(RPM_ROOT)/RPMS/noarch/$(COMPONENTS)-$(VER)-$(RELEASE).noarch.rpm diff --git a/st2common/bin/st2-run-pack-tests b/st2common/bin/st2-run-pack-tests index bed28267602..9f7c2306ab0 100755 --- a/st2common/bin/st2-run-pack-tests +++ b/st2common/bin/st2-run-pack-tests @@ -322,7 +322,7 @@ if [ "${ENABLE_COVERAGE}" = true ]; then # Base options to enable test coverage reporting # --with-coverage : enables coverage reporting - # --cover-erase : removes old coverage reports before starting + # --cover-erase : removes old coverage reports before starting NOSE_OPTS+=(--with-coverage --cover-erase) # Now, by default nosetests reports test coverage for every module found diff --git a/st2reactor/Makefile b/st2reactor/Makefile index cd3eb75a3ee..232abed4dd5 100644 --- a/st2reactor/Makefile +++ b/st2reactor/Makefile @@ -7,7 +7,7 @@ VER=0.4.0 COMPONENTS := st2reactor .PHONY: rpm -rpm: +rpm: pushd ~ && rpmdev-setuptree && popd tar --transform=s~^~$(COMPONENTS)-$(VER)/~ -czf $(RPM_SOURCES_DIR)/$(COMPONENTS).tar.gz bin conf $(COMPONENTS) cp packaging/rpm/$(COMPONENTS).spec $(RPM_SPECS_DIR)/ diff --git a/st2tests/testpacks/checks/actions/check_loadavg.yaml b/st2tests/testpacks/checks/actions/check_loadavg.yaml index ac38037d6c8..06abc652278 100644 --- a/st2tests/testpacks/checks/actions/check_loadavg.yaml +++ b/st2tests/testpacks/checks/actions/check_loadavg.yaml @@ -4,8 +4,8 @@ description: "Check CPU Load Average on a Host" enabled: true entry_point: "checks/check_loadavg.py" - parameters: - period: + parameters: + period: type: "string" description: "Time period for load avg: 5,10,15 minutes, or 'all'" default: "all" diff --git a/st2tests/testpacks/errorcheck/actions/exit-code.sh b/st2tests/testpacks/errorcheck/actions/exit-code.sh index 5320dc2f363..2e6eadf6a2c 100755 --- a/st2tests/testpacks/errorcheck/actions/exit-code.sh +++ b/st2tests/testpacks/errorcheck/actions/exit-code.sh @@ -6,4 +6,4 @@ if [ -n "$1" ] exit_code=$1 fi -exit $exit_code +exit $exit_code diff --git a/tox.ini b/tox.ini index 451ceee8e1e..de40b858789 100644 --- a/tox.ini +++ b/tox.ini @@ -71,7 +71,7 @@ commands = [testenv:py36-integration] basepython = python3.6 -setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner +setenv = PYTHONPATH = {toxinidir}/external:{toxinidir}/st2common:{toxinidir}/st2auth:{toxinidir}/st2api:{toxinidir}/st2actions:{toxinidir}/st2exporter:{toxinidir}/st2reactor:{toxinidir}/st2tests:{toxinidir}/contrib/runners/action_chain_runner:{toxinidir}/contrib/runners/local_runner:{toxinidir}/contrib/runners/python_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/noop_runner:{toxinidir}/contrib/runners/announcement_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/remote_runner:{toxinidir}/contrib/runners/orquesta_runner:{toxinidir}/contrib/runners/inquirer_runner:{toxinidir}/contrib/runners/http_runner:{toxinidir}/contrib/runners/winrm_runner VIRTUALENV_DIR = {envdir} passenv = NOSE_WITH_TIMER TRAVIS ST2_CI install_command = pip install -U --force-reinstall {opts} {packages} From 549bcd00750a2ac31181279c3fd30b3947b7f30b Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 16:26:07 +0100 Subject: [PATCH 17/25] Update black config so we don't try to reformat submodule. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 4d034829943..1889c6a5da7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ exclude = ''' | \.git | \.virtualenv | __pycache__ + | test_content_version )/ ) ''' From 00157676b47373142fec620d124718ff44671534 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 17:48:57 +0100 Subject: [PATCH 18/25] Update Makefile. --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 8923cb7a5c7..91439ae15bc 100644 --- a/Makefile +++ b/Makefile @@ -391,8 +391,8 @@ black: requirements .pre-commit-checks @echo @echo "================== pre-commit-checks ====================" @echo - pre-commit run trailing-whitespace --all --show-diff-on-failure - pre-commit run check-yaml --all --show-diff-on-failure + . $(VIRTUALENV_DIR)/bin/activate; pre-commit run trailing-whitespace --all --show-diff-on-failure + . $(VIRTUALENV_DIR)/bin/activate; pre-commit run check-yaml --all --show-diff-on-failure .PHONY: lint-api-spec lint-api-spec: requirements .lint-api-spec From efa46112f5091eb33eb18e26e91c1aaadb116533 Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Thu, 18 Feb 2021 18:39:52 +0100 Subject: [PATCH 19/25] Fix lint. --- st2common/st2common/models/api/rbac.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/st2common/st2common/models/api/rbac.py b/st2common/st2common/models/api/rbac.py index bd269ce3d68..ffaff754098 100644 --- a/st2common/st2common/models/api/rbac.py +++ b/st2common/st2common/models/api/rbac.py @@ -228,10 +228,9 @@ def validate(self, validate_role_exists=False): if validate_role_exists: # Validate that the referenced roles exist in the db rbac_service = get_rbac_backend().get_service_class() - rbac_service.validate_roles_exists( - role_names=self.roles - ) # pylint: disable=no-member - + # pylint: disable=no-member + rbac_service.validate_roles_exists(role_names=self.roles) + # pylint: enable=no-member return cleaned From 27a06f6ae8d4cb31c888cd1a15038f9bea87b5f0 Mon Sep 17 00:00:00 2001 From: stanley Date: Thu, 4 Mar 2021 08:05:20 +0000 Subject: [PATCH 20/25] Update version to 3.5dev --- .../runners/action_chain_runner/action_chain_runner/__init__.py | 2 +- .../runners/announcement_runner/announcement_runner/__init__.py | 2 +- contrib/runners/http_runner/http_runner/__init__.py | 2 +- contrib/runners/inquirer_runner/inquirer_runner/__init__.py | 2 +- contrib/runners/local_runner/local_runner/__init__.py | 2 +- contrib/runners/noop_runner/noop_runner/__init__.py | 2 +- contrib/runners/orquesta_runner/orquesta_runner/__init__.py | 2 +- contrib/runners/python_runner/python_runner/__init__.py | 2 +- contrib/runners/remote_runner/remote_runner/__init__.py | 2 +- contrib/runners/winrm_runner/winrm_runner/__init__.py | 2 +- st2actions/st2actions/__init__.py | 2 +- st2api/st2api/__init__.py | 2 +- st2auth/st2auth/__init__.py | 2 +- st2client/st2client/__init__.py | 2 +- st2common/st2common/__init__.py | 2 +- st2reactor/st2reactor/__init__.py | 2 +- st2stream/st2stream/__init__.py | 2 +- st2tests/st2tests/__init__.py | 2 +- 18 files changed, 18 insertions(+), 18 deletions(-) diff --git a/contrib/runners/action_chain_runner/action_chain_runner/__init__.py b/contrib/runners/action_chain_runner/action_chain_runner/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/contrib/runners/action_chain_runner/action_chain_runner/__init__.py +++ b/contrib/runners/action_chain_runner/action_chain_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/contrib/runners/announcement_runner/announcement_runner/__init__.py b/contrib/runners/announcement_runner/announcement_runner/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/contrib/runners/announcement_runner/announcement_runner/__init__.py +++ b/contrib/runners/announcement_runner/announcement_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/contrib/runners/http_runner/http_runner/__init__.py b/contrib/runners/http_runner/http_runner/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/contrib/runners/http_runner/http_runner/__init__.py +++ b/contrib/runners/http_runner/http_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/contrib/runners/inquirer_runner/inquirer_runner/__init__.py b/contrib/runners/inquirer_runner/inquirer_runner/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/contrib/runners/inquirer_runner/inquirer_runner/__init__.py +++ b/contrib/runners/inquirer_runner/inquirer_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/contrib/runners/local_runner/local_runner/__init__.py b/contrib/runners/local_runner/local_runner/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/contrib/runners/local_runner/local_runner/__init__.py +++ b/contrib/runners/local_runner/local_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/contrib/runners/noop_runner/noop_runner/__init__.py b/contrib/runners/noop_runner/noop_runner/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/contrib/runners/noop_runner/noop_runner/__init__.py +++ b/contrib/runners/noop_runner/noop_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/contrib/runners/orquesta_runner/orquesta_runner/__init__.py b/contrib/runners/orquesta_runner/orquesta_runner/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/contrib/runners/orquesta_runner/orquesta_runner/__init__.py +++ b/contrib/runners/orquesta_runner/orquesta_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/contrib/runners/python_runner/python_runner/__init__.py b/contrib/runners/python_runner/python_runner/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/contrib/runners/python_runner/python_runner/__init__.py +++ b/contrib/runners/python_runner/python_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/contrib/runners/remote_runner/remote_runner/__init__.py b/contrib/runners/remote_runner/remote_runner/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/contrib/runners/remote_runner/remote_runner/__init__.py +++ b/contrib/runners/remote_runner/remote_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/contrib/runners/winrm_runner/winrm_runner/__init__.py b/contrib/runners/winrm_runner/winrm_runner/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/contrib/runners/winrm_runner/winrm_runner/__init__.py +++ b/contrib/runners/winrm_runner/winrm_runner/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/st2actions/st2actions/__init__.py b/st2actions/st2actions/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/st2actions/st2actions/__init__.py +++ b/st2actions/st2actions/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/st2api/st2api/__init__.py b/st2api/st2api/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/st2api/st2api/__init__.py +++ b/st2api/st2api/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/st2auth/st2auth/__init__.py b/st2auth/st2auth/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/st2auth/st2auth/__init__.py +++ b/st2auth/st2auth/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/st2client/st2client/__init__.py b/st2client/st2client/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/st2client/st2client/__init__.py +++ b/st2client/st2client/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/st2common/st2common/__init__.py b/st2common/st2common/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/st2common/st2common/__init__.py +++ b/st2common/st2common/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/st2reactor/st2reactor/__init__.py b/st2reactor/st2reactor/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/st2reactor/st2reactor/__init__.py +++ b/st2reactor/st2reactor/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/st2stream/st2stream/__init__.py b/st2stream/st2stream/__init__.py index bbe290db9a7..ae0bd695f57 100644 --- a/st2stream/st2stream/__init__.py +++ b/st2stream/st2stream/__init__.py @@ -13,4 +13,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = '3.4dev' +__version__ = '3.5dev' diff --git a/st2tests/st2tests/__init__.py b/st2tests/st2tests/__init__.py index 594f0e2ae1e..3813790cca8 100644 --- a/st2tests/st2tests/__init__.py +++ b/st2tests/st2tests/__init__.py @@ -30,4 +30,4 @@ 'WorkflowTestCase' ] -__version__ = '3.4dev' +__version__ = '3.5dev' From 083b103649814425a782d0e9ee090e9087d5d4ed Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Sat, 6 Mar 2021 17:40:45 +0100 Subject: [PATCH 21/25] Fix typo. --- st2common/st2common/util/virtualenvs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/st2common/st2common/util/virtualenvs.py b/st2common/st2common/util/virtualenvs.py index b3635d15a53..62cfc99b52c 100644 --- a/st2common/st2common/util/virtualenvs.py +++ b/st2common/st2common/util/virtualenvs.py @@ -236,7 +236,7 @@ def remove_virtualenv(virtualenv_path, logger=None): logger.debug('Removing virtualenv in "%s"' % virtualenv_path) try: shutil.rmtree(virtualenv_path) - logger.debug("Virtualenv successfull removed.") + logger.debug("Virtualenv successfully removed.") except Exception as e: logger.error( 'Error while removing virtualenv at "%s": "%s"' % (virtualenv_path, e) From 370aa874c825f616a92123b75808a10a05f1f1bc Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Sat, 6 Mar 2021 17:42:27 +0100 Subject: [PATCH 22/25] Add changelog entry. --- CHANGELOG.rst | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 891719ccbb8..0fa0a562b80 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,13 @@ Changelog in development -------------- +Changed +~~~~~~~ + +* All the code has been refactored using black and black style is automatically enforced and + required for all the new code. (#5156) + + Contributed by @Kami. 3.4.0 - March 02, 2021 ---------------------- @@ -22,7 +29,8 @@ Added * Added st2-auth-ldap pip requirements for LDAP auth integartion. (new feature) #5082 Contributed by @hnanchahal -* Added --register-recreate-virtualenvs flag to st2ctl reload to recreate virtualenvs from scratch. (part of upgrade instructions) [#5167] +* Added --register-recreate-virtualenvs flag to st2ctl reload to recreate virtualenvs from scratch. + (part of upgrade instructions) [#5167] Contributed by @winem and @blag Changed From 5d07a5c6b456737c5032e0dca459afa5b45f30af Mon Sep 17 00:00:00 2001 From: Tomaz Muraus Date: Sat, 6 Mar 2021 17:54:56 +0100 Subject: [PATCH 23/25] Fix typo. --- st2common/tests/integration/test_register_content_script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/st2common/tests/integration/test_register_content_script.py b/st2common/tests/integration/test_register_content_script.py index cbc79831128..0a0dc8e7f17 100644 --- a/st2common/tests/integration/test_register_content_script.py +++ b/st2common/tests/integration/test_register_content_script.py @@ -183,5 +183,5 @@ def test_register_recreate_virtualenvs(self): self.assertIn('Setting up virtualenv for pack "dummy_pack_1"', stderr) self.assertIn("Setup virtualenv for 1 pack(s)", stderr) - self.assertIn("Virtualenv successfull removed.", stderr) + self.assertIn("Virtualenv successfully removed.", stderr) self.assertEqual(exit_code, 0) From 09deab7d2d76ffbfec98dae25453b84f3691116b Mon Sep 17 00:00:00 2001 From: Jacob Floyd Date: Wed, 17 Feb 2021 14:13:53 -0600 Subject: [PATCH 24/25] [BUGFIX] Use pip 20.0.2 to build virtualenvs We have already been using pip==20.0.2 for testing since st2 v3.2.0. virtualenv embeds the wheel for pip, so we need to adjust the pinned version of virtualenv to update which version of pip is used to create virtualenvs (eg for packs). This also clarifies the comment about why we're pinning virtualenv. --- fixed-requirements.txt | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fixed-requirements.txt b/fixed-requirements.txt index a98a90dc345..7e1e599eeed 100644 --- a/fixed-requirements.txt +++ b/fixed-requirements.txt @@ -46,8 +46,9 @@ six==1.13.0 sseclient-py==1.7 stevedore==1.30.1 tooz==2.8.0 -# Note: We use latest version of virtualenv which uses pip 19 -virtualenv==16.6.0 +# Note: virtualenv embeds the pip wheel, so pinning virtualenv also pins pip +# virtualenv==20.0.18 has pip==20.0.2 +virtualenv==20.0.18 webob==1.8.5 zake==0.2.2 # test requirements below From dcc2693e4b12fe30f2fc2f1db0c62030c6843a2a Mon Sep 17 00:00:00 2001 From: Jacob Floyd Date: Wed, 24 Feb 2021 08:48:27 -0600 Subject: [PATCH 25/25] limit pip-tools version in testing-requirements pip-tools 5.4 needs pip>=20.1, but we use 20.0.2. So, we need to lower prevent installing a pip-tools 5.4+ to avoid accidentally updating pip. --- test-requirements.txt | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test-requirements.txt b/test-requirements.txt index c004342bc8d..1125e4813fd 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -30,4 +30,6 @@ webtest==2.0.25 rstcheck>=3.3.1,<3.4 tox==3.14.1 pyrabbit -pip-tools # For pip-compile, to check for version conflicts +# pip-tools provides pip-compile: to check for version conflicts +# pip-tools 5.4 needs pip>=20.1, but we use 20.0.2 +pip-tools<5.4