diff --git a/sirbot/plugins/slack/endpoints.py b/sirbot/plugins/slack/endpoints.py index e366faa..609b944 100644 --- a/sirbot/plugins/slack/endpoints.py +++ b/sirbot/plugins/slack/endpoints.py @@ -4,9 +4,10 @@ import aiohttp.web from aiohttp.web import Response from slack.events import Event +from slack.sansio import validate_request_signature from slack.actions import Action from slack.commands import Command -from slack.exceptions import FailedVerification +from slack.exceptions import InvalidTimestamp, FailedVerification, InvalidSlackSignature LOG = logging.getLogger(__name__) @@ -16,15 +17,24 @@ async def incoming_event(request): payload = await request.json() LOG.log(5, "Incoming event payload: %s", payload) - if "challenge" in payload: - if payload["token"] == slack.verify: + if payload.get("type") == "url_verification": + if slack.signing_secret: + try: + validate_request_signature( + await request.read(), request.headers, slack.signing_secret + ) + return Response(body=payload["challenge"]) + except (InvalidSlackSignature, InvalidTimestamp): + return Response(status=500) + elif payload["token"] == slack.verify: return Response(body=payload["challenge"]) else: return Response(status=500) try: - event = Event.from_http(payload, verification_token=slack.verify) - except FailedVerification: + verification_token = await _validate_request(request, slack) + event = Event.from_http(payload, verification_token=verification_token) + except (FailedVerification, InvalidSlackSignature, InvalidTimestamp): return Response(status=401) if event["type"] == "message": @@ -81,8 +91,9 @@ async def incoming_command(request): payload = await request.post() try: - command = Command(payload, verification_token=slack.verify) - except FailedVerification: + verification_token = await _validate_request(request, slack) + command = Command(payload, verification_token=verification_token) + except (FailedVerification, InvalidSlackSignature, InvalidTimestamp): return Response(status=401) LOG.debug("Incoming command: %s", command) @@ -99,8 +110,9 @@ async def incoming_action(request): LOG.log(5, "Incoming action payload: %s", payload) try: - action = Action.from_http(payload, verification_token=slack.verify) - except FailedVerification: + verification_token = await _validate_request(request, slack) + action = Action.from_http(payload, verification_token=verification_token) + except (FailedVerification, InvalidSlackSignature, InvalidTimestamp): return Response(status=401) LOG.debug("Incoming action: %s", action) @@ -143,3 +155,13 @@ async def _wait_and_check_result(futures): return results[0] return Response(status=200) + + +async def _validate_request(request, slack): + if slack.signing_secret: + validate_request_signature( + await request.read(), request.headers, slack.signing_secret + ) + return None + else: + return slack.verify diff --git a/sirbot/plugins/slack/plugin.py b/sirbot/plugins/slack/plugin.py index c93437b..861af9b 100644 --- a/sirbot/plugins/slack/plugin.py +++ b/sirbot/plugins/slack/plugin.py @@ -24,10 +24,12 @@ class SlackPlugin: Args: token: slack authentication token (env var: `SLACK_TOKEN`). - verify: slack verification token (env var: `SLACK_VERIFY`). bot_id: bot id (env var: `SLACK_BOT_ID`). bot_user_id: user id of the bot (env var: `SLACK_BOT_USER_ID`). admins: list of slack admins user id (env var: `SLACK_ADMINS`). + verify: slack verification token (env var: `SLACK_VERIFY`). + signing_secret: slack signing secret key (env var: `SLACK_SIGNING_SECRET`). + (disables verification token if provided). **Variables**: * **api**: Slack client. Instance of :class:`slack.io.aiohttp.SlackAPI`. @@ -36,12 +38,24 @@ class SlackPlugin: __name__ = "slack" def __init__( - self, *, token=None, verify=None, bot_id=None, bot_user_id=None, admins=None + self, + *, + token=None, + bot_id=None, + bot_user_id=None, + admins=None, + verify=None, + signing_secret=None ): self.api = None self.token = token or os.environ["SLACK_TOKEN"] self.admins = admins or os.environ.get("SLACK_ADMINS", []) - self.verify = verify or os.environ["SLACK_VERIFY"] + if signing_secret or "SLACK_SIGNING_SECRET" in os.environ: + self.signing_secret = signing_secret or os.environ["SLACK_SIGNING_SECRET"] + self.verify = None + else: + self.verify = verify or os.environ["SLACK_VERIFY"] + self.signing_secret = None self.bot_id = bot_id or os.environ.get("SLACK_BOT_ID") self.bot_user_id = bot_user_id or os.environ.get("SLACK_BOT_USER_ID") self.handlers_option = {} diff --git a/tests/test_plugin_apscheduler.py b/tests/test_plugin_apscheduler.py index aea8785..3857d25 100644 --- a/tests/test_plugin_apscheduler.py +++ b/tests/test_plugin_apscheduler.py @@ -11,6 +11,6 @@ async def bot(): class TestPluginAPscheduler: - async def test_start(self, bot, test_server): - await test_server(bot) + async def test_start(self, bot, aiohttp_server): + await aiohttp_server(bot) assert isinstance(bot["plugins"]["scheduler"], APSchedulerPlugin) diff --git a/tests/test_plugin_github.py b/tests/test_plugin_github.py index 9d552b5..42288e5 100644 --- a/tests/test_plugin_github.py +++ b/tests/test_plugin_github.py @@ -24,28 +24,28 @@ async def event(request): class TestPluginGithub: - async def test_start(self, bot, test_server): - await test_server(bot) + async def test_start(self, bot, aiohttp_server): + await aiohttp_server(bot) assert isinstance(bot["plugins"]["github"], GithubPlugin) - async def test_incoming_event(self, bot, test_client, event): - client = await test_client(bot) + async def test_incoming_event(self, bot, aiohttp_client, event): + client = await aiohttp_client(bot) r = await client.post("/github", json=event[0], headers=event[1]) assert r.status == 200 - async def test_incoming_event_401(self, bot, test_client, event): + async def test_incoming_event_401(self, bot, aiohttp_client, event): bot["plugins"]["github"].verify = "wrongsupersecrettoken" - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/github", json=event[0], headers=event[1]) assert r.status == 401 - async def test_incoming_event_handler_error(self, bot, test_client, event): + async def test_incoming_event_handler_error(self, bot, aiohttp_client, event): async def handler(event, app): raise RuntimeError() bot["plugins"]["github"].router.add( handler, event[1]["X-GitHub-Event"], action=event[0]["action"] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/github", json=event[0], headers=event[1]) assert r.status == 500 diff --git a/tests/test_plugin_postgres.py b/tests/test_plugin_postgres.py index 154368e..bc28412 100644 --- a/tests/test_plugin_postgres.py +++ b/tests/test_plugin_postgres.py @@ -25,17 +25,17 @@ async def _teardown(self, bot): await pg_con.execute("""DROP SCHEMA IF EXISTS sirbot_test CASCADE""") await pg_con.execute("""DROP TABLE IF EXISTS metadata""") - async def test_start(self, bot, test_server): + async def test_start(self, bot, aiohttp_server): try: - await test_server(bot) + await aiohttp_server(bot) assert isinstance(bot["plugins"]["pg"], PgPlugin) finally: await self._teardown(bot) - async def test_no_migration(self, bot, test_server): + async def test_no_migration(self, bot, aiohttp_server): try: bot["plugins"]["pg"].sql_migration_directory = None - await test_server(bot) + await aiohttp_server(bot) with pytest.raises(asyncpg.exceptions.UndefinedTableError): async with bot["plugins"]["pg"].connection() as pg_con: @@ -43,10 +43,10 @@ async def test_no_migration(self, bot, test_server): finally: await self._teardown(bot) - async def test_initial_migration(self, bot, test_server): + async def test_initial_migration(self, bot, aiohttp_server): try: bot["plugins"]["pg"].version = "0.0.1" - await test_server(bot) + await aiohttp_server(bot) async with bot["plugins"]["pg"].connection() as pg_con: version = await pg_con.fetchval("""SELECT db_version FROM metadata""") count = await pg_con.fetchval( @@ -58,9 +58,9 @@ async def test_initial_migration(self, bot, test_server): finally: await self._teardown(bot) - async def test_migration_to_0_0_2(self, bot, test_server): + async def test_migration_to_0_0_2(self, bot, aiohttp_server): try: - await test_server(bot) + await aiohttp_server(bot) async with bot["plugins"]["pg"].connection() as pg_con: version = await pg_con.fetchval("""SELECT db_version FROM metadata""") @@ -82,10 +82,10 @@ async def test_migration_to_0_0_2(self, bot, test_server): finally: await self._teardown(bot) - async def test_no_migration_needed(self, bot, test_server): + async def test_no_migration_needed(self, bot, aiohttp_server): try: bot["plugins"]["pg"].version = "0.1.9" - await test_server(bot) + await aiohttp_server(bot) async with bot["plugins"]["pg"].connection() as pg_con: count_start = await pg_con.fetchval( @@ -104,12 +104,12 @@ async def test_no_migration_needed(self, bot, test_server): finally: await self._teardown(bot) - async def test_failed_migration(self, bot, test_server): + async def test_failed_migration(self, bot, aiohttp_server): try: bot["plugins"]["pg"].version = "0.2.0" with pytest.raises(asyncpg.exceptions.UndefinedColumnError): - await test_server(bot) + await aiohttp_server(bot) with pytest.raises(asyncpg.exceptions.UndefinedTableError): async with bot["plugins"]["pg"].connection() as pg_con: diff --git a/tests/test_plugin_readthedocs.py b/tests/test_plugin_readthedocs.py index f7ae578..93b9e0c 100644 --- a/tests/test_plugin_readthedocs.py +++ b/tests/test_plugin_readthedocs.py @@ -12,8 +12,8 @@ async def bot(): class TestPluginReadTheDocs: - async def test_start(self, bot, test_server): - await test_server(bot) + async def test_start(self, bot, aiohttp_server): + await aiohttp_server(bot) assert isinstance(bot["plugins"]["readthedocs"], RTDPlugin) async def test_register_project(self, bot): @@ -134,7 +134,7 @@ def handler_bis(): assert h[1] is handler_bis assert "test" in bot["plugins"]["readthedocs"]._projects - async def test_incoming(self, bot, test_client): + async def test_incoming(self, bot, aiohttp_client): async def handler(payload, app): assert payload == { "build": { @@ -147,7 +147,7 @@ async def handler(payload, app): } assert app is bot - client = await test_client(bot) + client = await aiohttp_client(bot) bot["plugins"]["readthedocs"].register_handler("sir-bot-a-lot", handler=handler) r = await client.post( @@ -164,11 +164,11 @@ async def handler(payload, app): ) assert r.status == 200 - async def test_incoming_handler_error(self, bot, test_client): + async def test_incoming_handler_error(self, bot, aiohttp_client): async def handler(payload, app): raise RuntimeError() - client = await test_client(bot) + client = await aiohttp_client(bot) bot["plugins"]["readthedocs"].register_handler("sir-bot-a-lot", handler=handler) r = await client.post( @@ -185,8 +185,8 @@ async def handler(payload, app): ) assert r.status == 500 - async def test_incoming_no_project(self, bot, test_client): - client = await test_client(bot) + async def test_incoming_no_project(self, bot, aiohttp_client): + client = await aiohttp_client(bot) r = await client.post( "/readthedocs", json={ @@ -201,8 +201,8 @@ async def test_incoming_no_project(self, bot, test_client): ) assert r.status == 400 - async def test_incoming_project_no_handler(self, bot, test_client): - client = await test_client(bot) + async def test_incoming_project_no_handler(self, bot, aiohttp_client): + client = await aiohttp_client(bot) bot["plugins"]["readthedocs"].register_project( "sir-bot-a-lot", build_url="https://example.com", jeton="aaaaaa" ) @@ -220,8 +220,8 @@ async def test_incoming_project_no_handler(self, bot, test_client): ) assert r.status == 200 - async def test_incoming_bad_json(self, bot, test_client): - client = await test_client(bot) + async def test_incoming_bad_json(self, bot, aiohttp_client): + client = await aiohttp_client(bot) r = await client.post("/readthedocs", json={"a": "b"}) assert r.status == 400 diff --git a/tests/test_plugin_slack.py b/tests/test_plugin_slack.py index efd0036..744a50b 100644 --- a/tests/test_plugin_slack.py +++ b/tests/test_plugin_slack.py @@ -1,6 +1,13 @@ import re +import hmac +import json +import time import asyncio +import hashlib +import urllib.parse +from typing import Dict, Tuple, Union, Optional from unittest import mock +from collections import MutableMapping import slack import pytest @@ -25,6 +32,49 @@ async def bot(): return b +@pytest.fixture +async def bot_signing(): + b = SirBot() + b.load_plugin( + SlackPlugin( + token="foo", + signing_secret="sharedsigningkey", + bot_user_id="baz", + bot_id="boo", + admins=["aaa", "bbb"], + ) + ) + return b + + +def _sign_body( + json_data: Optional[Dict] = None, + post_data: Optional[Dict] = None, + signing_secret: str = "sharedsigningkey", + timestamp: Optional[int] = None, +) -> Tuple[Dict[str, str], bytes]: + if json_data: + headers = {"content-type": "application/json"} + body = json.dumps(json_data).encode("utf-8") + elif post_data: + headers = {"content-type": "application/x-www-form-urlencoded"} + body = urllib.parse.urlencode(post_data).encode("utf-8") + else: + raise ValueError("Unknown type of data to sign") + if timestamp is None: + timestamp = int(time.time()) + headers["X-Slack-Request-Timestamp"] = str(timestamp) + headers["X-Slack-Signature"] = ( + "v0=" + + hmac.new( + signing_secret.encode("utf-8"), + f"""v0:{timestamp}:{body}""".encode("utf-8"), + digestmod=hashlib.sha256, + ).hexdigest() + ) + return headers, body + + @pytest.fixture def find_bot_id_query(): async def query(*args, **kwargs): @@ -64,7 +114,7 @@ async def query(*args, **kwargs): "is_restricted": False, "is_ultra_restricted": False, "is_bot": True, - "updated": 1502138686, + "updated": 1_502_138_686, "is_app_user": False, "has_2fa": False, }, @@ -74,8 +124,8 @@ async def query(*args, **kwargs): class TestPluginSlack: - async def test_start(self, bot, test_server): - await test_server(bot) + async def test_start(self, bot, aiohttp_server): + await aiohttp_server(bot) assert isinstance(bot["plugins"]["slack"], SlackPlugin) async def test_start_no_bot_user_id(self, caplog): @@ -193,55 +243,81 @@ def handler2(): bot["plugins"]["slack"].routers["action"]._routes["hello"]["*"][1][0] ) - async def test_find_bot_id(self, bot, test_server, find_bot_id_query): - await test_server(bot) + async def test_find_bot_id(self, bot, aiohttp_server, find_bot_id_query): + await aiohttp_server(bot) bot["plugins"]["slack"].api.query = find_bot_id_query await bot["plugins"]["slack"].find_bot_id(bot) assert bot["plugins"]["slack"].bot_id == "B00000000" - async def test_start_find_bot_id(self, test_server, find_bot_id_query): + async def test_start_find_bot_id(self, aiohttp_server, find_bot_id_query): bot = SirBot() bot.load_plugin(SlackPlugin(token="foo", verify="bar", bot_user_id="baz")) bot["plugins"]["slack"].api.query = find_bot_id_query - await test_server(bot) + await aiohttp_server(bot) assert bot["plugins"]["slack"].bot_id == "B00000000" class TestPluginSlackEndpoints: - async def test_incoming_event(self, bot, test_client, slack_event): - client = await test_client(bot) + async def test_incoming_event(self, bot, aiohttp_client, slack_event): + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_event) assert r.status == 200 - async def test_incoming_command(self, bot, test_client, slack_command): - client = await test_client(bot) + async def test_incoming_event_signed( + self, bot_signing, aiohttp_client, slack_event + ): + client = await aiohttp_client(bot_signing) + headers, body = _sign_body(json_data=slack_event) + r = await client.post("/slack/events", headers=headers, data=body) + assert r.status == 200 + + async def test_incoming_command(self, bot, aiohttp_client, slack_command): + client = await aiohttp_client(bot) r = await client.post("/slack/commands", data=slack_command) assert r.status == 200 - async def test_incoming_action(self, bot, test_client, slack_action): - client = await test_client(bot) + async def test_incoming_command_signed( + self, bot_signing, aiohttp_client, slack_command + ): + client = await aiohttp_client(bot_signing) + headers, body = _sign_body(post_data=slack_command) + r = await client.post("/slack/commands", headers=headers, data=body) + assert r.status == 200 + + async def test_incoming_action(self, bot, aiohttp_client, slack_action): + client = await aiohttp_client(bot) r = await client.post("/slack/actions", data=slack_action) assert r.status == 200 - async def test_incoming_event_wrong_token(self, bot, test_client, slack_event): + async def test_incoming_action_signed( + self, bot_signing, aiohttp_client, slack_action + ): + client = await aiohttp_client(bot_signing) + headers, body = _sign_body(post_data=slack_action) + r = await client.post("/slack/actions", headers=headers, data=body) + assert r.status == 200 + + async def test_incoming_event_wrong_token(self, bot, aiohttp_client, slack_event): bot["plugins"]["slack"].verify = "bar" - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_event) assert r.status == 401 - async def test_incoming_command_wrong_token(self, bot, test_client, slack_command): + async def test_incoming_command_wrong_token( + self, bot, aiohttp_client, slack_command + ): bot["plugins"]["slack"].verify = "bar" - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/commands", data=slack_command) assert r.status == 401 - async def test_incoming_action_wrong_token(self, bot, test_client, slack_action): + async def test_incoming_action_wrong_token(self, bot, aiohttp_client, slack_action): bot["plugins"]["slack"].verify = "bar" - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/actions", data=slack_action) assert r.status == 401 - async def test_incoming_event_error(self, bot, test_client, slack_event): + async def test_incoming_event_error(self, bot, aiohttp_client, slack_event): async def handler(*args, **kwargs): raise RuntimeError() @@ -252,11 +328,11 @@ async def handler(*args, **kwargs): return_value=[(handler, {"wait": True, "mention": False, "admin": False})] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_event) assert r.status == 500 - async def test_incoming_message_error(self, bot, test_client, slack_message): + async def test_incoming_message_error(self, bot, aiohttp_client, slack_message): async def handler(*args, **kwargs): raise RuntimeError() @@ -264,11 +340,11 @@ async def handler(*args, **kwargs): return_value=[(handler, {"wait": True, "mention": False, "admin": False})] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_message) assert r.status == 500 - async def test_incoming_command_error(self, bot, test_client, slack_command): + async def test_incoming_command_error(self, bot, aiohttp_client, slack_command): async def handler(*args, **kwargs): raise RuntimeError() @@ -276,11 +352,11 @@ async def handler(*args, **kwargs): return_value=[(handler, {"wait": True})] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/commands", data=slack_command) assert r.status == 500 - async def test_incoming_action_error(self, bot, test_client, slack_action): + async def test_incoming_action_error(self, bot, aiohttp_client, slack_action): async def handler(*args, **kwargs): raise RuntimeError() @@ -288,11 +364,11 @@ async def handler(*args, **kwargs): return_value=[(handler, {"wait": True})] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/actions", data=slack_action) assert r.status == 500 - async def test_incoming_event_handler_arg(self, bot, test_client, slack_event): + async def test_incoming_event_handler_arg(self, bot, aiohttp_client, slack_event): async def handler(event, app): assert app is bot assert isinstance(event, slack.events.Event) @@ -301,11 +377,13 @@ async def handler(event, app): return_value=[(handler, {"wait": True})] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_event) assert r.status == 200 - async def test_incoming_message_handler_arg(self, bot, test_client, slack_message): + async def test_incoming_message_handler_arg( + self, bot, aiohttp_client, slack_message + ): async def handler(event, app): assert app is bot assert isinstance(event, slack.events.Message) @@ -314,11 +392,13 @@ async def handler(event, app): return_value=[(handler, {"wait": True, "mention": False, "admin": False})] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_message) assert r.status == 200 - async def test_incoming_command_handler_arg(self, bot, test_client, slack_command): + async def test_incoming_command_handler_arg( + self, bot, aiohttp_client, slack_command + ): async def handler(command, app): assert app is bot assert isinstance(command, slack.commands.Command) @@ -327,11 +407,11 @@ async def handler(command, app): return_value=[(handler, {"wait": True})] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/commands", data=slack_command) assert r.status == 200 - async def test_incoming_action_handler_arg(self, bot, test_client, slack_action): + async def test_incoming_action_handler_arg(self, bot, aiohttp_client, slack_action): async def handler(action, app): assert app is bot assert isinstance(action, slack.actions.Action) @@ -340,106 +420,146 @@ async def handler(action, app): return_value=[(handler, {"wait": True})] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/actions", data=slack_action) assert r.status == 200 - async def test_event_challenge(self, bot, test_client): + async def test_event_challenge(self, bot, aiohttp_client): - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post( "/slack/events", - json={"token": "supersecuretoken", "challenge": "abcdefghij"}, + json={ + "token": "supersecuretoken", + "challenge": "abcdefghij", + "type": "url_verification", + }, + ) + data = await r.text() + assert r.status == 200 + assert data == "abcdefghij" + assert r.status == 200 + + async def test_event_challenge_signed(self, bot_signing, aiohttp_client): + + client = await aiohttp_client(bot_signing) + headers, body = _sign_body( + json_data={ + "token": "na", + "challenge": "abcdefghij", + "type": "url_verification", + } ) + r = await client.post("/slack/events", data=body, headers=headers) data = await r.text() assert r.status == 200 assert data == "abcdefghij" - async def test_event_challenge_wrong_token(self, bot, test_client): + async def test_event_challenge_wrong_token(self, bot, aiohttp_client): - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post( "/slack/events", - json={"token": "wrongsupersecuretoken", "challenge": "abcdefghij"}, + json={ + "token": "wrongsupersecuretoken", + "challenge": "abcdefghij", + "type": "url_verification", + }, ) assert r.status == 500 + async def test_event_challenge_signed_wrong(self, bot_signing, aiohttp_client): + + client = await aiohttp_client(bot_signing) + headers, body = _sign_body( + json_data={ + "token": "na", + "challenge": "abcdefghij", + "type": "url_verification", + }, + signing_secret="notsharedsigningkey", + ) + r = await client.post("/slack/events", data=body, headers=headers) + assert r.status == 500 + @pytest.mark.parametrize("slack_message", ("bot",), indirect=True) - async def test_message_from_bot(self, bot, test_client, slack_message): + async def test_message_from_bot(self, bot, aiohttp_client, slack_message): bot["plugins"]["slack"].bot_id = "B0AAA0A00" bot["plugins"]["slack"].routers["message"].dispatch = mock.MagicMock() - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_message) assert r.status == 200 assert bot["plugins"]["slack"].routers["message"].dispatch.call_count == 0 @pytest.mark.parametrize("slack_message", ("bot",), indirect=True) - async def test_message_from_other_bot(self, bot, test_client, slack_message): + async def test_message_from_other_bot(self, bot, aiohttp_client, slack_message): bot["plugins"]["slack"].bot_id = "B0AAA0A01" bot["plugins"]["slack"].routers["message"].dispatch = mock.MagicMock() - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_message) assert r.status == 200 assert bot["plugins"]["slack"].routers["message"].dispatch.call_count == 1 @pytest.mark.parametrize("slack_message", ("simple",), indirect=True) - async def test_admin_message_ok(self, bot, test_client, slack_message): + async def test_admin_message_ok(self, bot, aiohttp_client, slack_message): handler = asynctest.CoroutineMock() bot["plugins"]["slack"].admins = ["U000AA000"] bot["plugins"]["slack"].on_message("hello", handler, admin=True) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_message) assert r.status == 200 assert handler.call_count == 1 @pytest.mark.parametrize("slack_message", ("simple",), indirect=True) - async def test_admin_message_skip(self, bot, test_client, slack_message): + async def test_admin_message_skip(self, bot, aiohttp_client, slack_message): handler = asynctest.CoroutineMock() bot["plugins"]["slack"].admins = ["U000AA001"] bot["plugins"]["slack"].on_message("hello", handler, admin=True) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_message) assert r.status == 200 assert handler.call_count == 0 @pytest.mark.parametrize("slack_message", ("mention",), indirect=True) - async def test_message_mention_ok(self, bot, test_client, slack_message): + async def test_message_mention_ok(self, bot, aiohttp_client, slack_message): handler = asynctest.CoroutineMock() bot["plugins"]["slack"].bot_user_id = "U0AAA0A00" bot["plugins"]["slack"].on_message("hello world", handler, mention=True) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_message) assert r.status == 200 assert handler.call_count == 1 @pytest.mark.parametrize("slack_message", ("mention",), indirect=True) - async def test_message_mention_skip(self, bot, test_client, slack_message): + async def test_message_mention_skip(self, bot, aiohttp_client, slack_message): handler = asynctest.CoroutineMock() bot["plugins"]["slack"].bot_user_id = "U0AAA0A01" bot["plugins"]["slack"].on_message("hello world", handler, mention=True) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_message) assert r.status == 200 assert handler.call_count == 0 @pytest.mark.parametrize("slack_message", ("mention",), indirect=True) - async def test_message_mention_strip_bot(self, bot, test_client, slack_message): + async def test_message_mention_strip_bot(self, bot, aiohttp_client, slack_message): async def handler(message, app): assert message["text"] == "hello world" bot["plugins"]["slack"].bot_user_id = "U0AAA0A00" bot["plugins"]["slack"].on_message("hello world", handler, mention=True) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_message) assert r.status == 200 - async def test_event_handler_return_response(self, bot, test_client, slack_event): + async def test_event_handler_return_response( + self, bot, aiohttp_client, slack_event + ): async def handler(message, app): return json_response(data={"ok": True}, status=200) @@ -450,12 +570,14 @@ async def handler(message, app): return_value=[(handler, {"wait": True, "mention": False, "admin": False})] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_event) assert r.status == 200 assert (await r.json()) == {"ok": True} - async def test_action_handler_return_response(self, bot, test_client, slack_action): + async def test_action_handler_return_response( + self, bot, aiohttp_client, slack_action + ): async def handler(message, app): print("AAAAA") return json_response(data={"ok": True}, status=200) @@ -463,13 +585,13 @@ async def handler(message, app): bot["plugins"]["slack"].routers["action"].dispatch = mock.MagicMock( return_value=[(handler, {"wait": True})] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/actions", data=slack_action) assert r.status == 200 assert (await r.json()) == {"ok": True} async def test_command_handler_return_response( - self, bot, test_client, slack_command + self, bot, aiohttp_client, slack_command ): async def handler(message, app): return json_response(data={"ok": True}, status=200) @@ -477,12 +599,12 @@ async def handler(message, app): bot["plugins"]["slack"].routers["command"].dispatch = mock.MagicMock( return_value=[(handler, {"wait": True})] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/commands", data=slack_command) assert r.status == 200 assert (await r.json()) == {"ok": True} - async def test_handler_multiple_response(self, bot, test_client, slack_event): + async def test_handler_multiple_response(self, bot, aiohttp_client, slack_event): async def handler(message, app): return json_response(data={"ok": True}, status=200) @@ -499,12 +621,12 @@ async def handler2(message, app): ] ) - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_event) assert r.status == 200 assert (await r.text()) == "" - async def test_handler_no_wait(self, bot, test_client, slack_event): + async def test_handler_no_wait(self, bot, aiohttp_client, slack_event): global sentinel sentinel = False @@ -522,7 +644,7 @@ async def handler(message, app): assert not sentinel - client = await test_client(bot) + client = await aiohttp_client(bot) r = await client.post("/slack/events", json=slack_event) assert r.status == 200 diff --git a/tests/test_sirbot.py b/tests/test_sirbot.py index 4b47943..97cfadc 100644 --- a/tests/test_sirbot.py +++ b/tests/test_sirbot.py @@ -4,11 +4,11 @@ @pytest.mark.asyncio class TestSirBot: - async def test_bot(self, test_server): + async def test_bot(self, aiohttp_server): bot = SirBot() - await test_server(bot) + await aiohttp_server(bot) - async def test_load_plugin(self, test_server): + async def test_load_plugin(self, aiohttp_server): class MyPlugin: __name__ = "myplugin" @@ -22,7 +22,7 @@ def load(self, test_bot): bot.load_plugin(MyPlugin()) assert "myplugin" in bot.plugins assert isinstance(bot["plugins"]["myplugin"], MyPlugin) - await test_server(bot) + await aiohttp_server(bot) async def test_load_plugin_no_name(self): class MyPlugin: @@ -38,14 +38,14 @@ def load(self, test_bot): class TestEndpoints: - async def test_list_plugin_empty(self, test_client): + async def test_list_plugin_empty(self, aiohttp_client): bot = SirBot() - client = await test_client(bot) + client = await aiohttp_client(bot) rep = await client.get("/sirbot/plugins") data = await rep.json() assert data == {"plugins": []} - async def test_list_plugin(self, test_client): + async def test_list_plugin(self, aiohttp_client): class MyPlugin: __name__ = "myplugin" @@ -57,7 +57,7 @@ def load(self, test_bot): bot = SirBot() bot.load_plugin(MyPlugin()) - client = await test_client(bot) + client = await aiohttp_client(bot) rep = await client.get("/sirbot/plugins") data = await rep.json() assert data == {"plugins": ["myplugin"]}