Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Pass a Site into make_request #8757

Merged
merged 7 commits into from
Nov 16, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/8757.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Refactor test utilities for injecting HTTP requests.
13 changes: 7 additions & 6 deletions tests/app/test_frontend_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from synapse.app.generic_worker import GenericWorkerServer

from tests.server import make_request, render
from tests.unittest import HomeserverTestCase


Expand Down Expand Up @@ -55,10 +56,10 @@ def test_listen_http_with_presence_enabled(self):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"]
resource = site.resource.children[b"_matrix"].children[b"client"]

request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
render(request, resource, self.reactor)

# 400 + unrecognised, because nothing is registered
self.assertEqual(channel.code, 400)
Expand All @@ -77,10 +78,10 @@ def test_listen_http_with_presence_disabled(self):
# Grab the resource from the site that was told to listen
self.assertEqual(len(self.reactor.tcpServers), 1)
site = self.reactor.tcpServers[0][1]
self.resource = site.resource.children[b"_matrix"].children[b"client"]
resource = site.resource.children[b"_matrix"].children[b"client"]

request, channel = self.make_request("PUT", "presence/a/status")
self.render(request)
request, channel = make_request(self.reactor, site, "PUT", "presence/a/status")
render(request, resource, self.reactor)

# 401, because the stub servlet still checks authentication
self.assertEqual(channel.code, 401)
Expand Down
17 changes: 9 additions & 8 deletions tests/app/test_openid_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from synapse.app.homeserver import SynapseHomeServer
from synapse.config.server import parse_listener_def

from tests.server import make_request, render
from tests.unittest import HomeserverTestCase


Expand Down Expand Up @@ -66,16 +67,16 @@ def test_openid_listener(self, names, expectation):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"]
resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise

request, channel = self.make_request(
"GET", "/_matrix/federation/v1/openid/userinfo"
request, channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
self.render(request)
render(request, resource, self.reactor)

self.assertEqual(channel.code, 401)

Expand Down Expand Up @@ -115,15 +116,15 @@ def test_openid_listener(self, names, expectation):
# Grab the resource from the site that was told to listen
site = self.reactor.tcpServers[0][1]
try:
self.resource = site.resource.children[b"_matrix"].children[b"federation"]
resource = site.resource.children[b"_matrix"].children[b"federation"]
except KeyError:
if expectation == "no_resource":
return
raise

request, channel = self.make_request(
"GET", "/_matrix/federation/v1/openid/userinfo"
request, channel = make_request(
self.reactor, site, "GET", "/_matrix/federation/v1/openid/userinfo"
)
self.render(request)
render(request, resource, self.reactor)

self.assertEqual(channel.code, 401)
13 changes: 7 additions & 6 deletions tests/http/test_additional_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from synapse.http.additional_resource import AdditionalResource
from synapse.http.server import respond_with_json

from tests.server import FakeSite, make_request, render
from tests.unittest import HomeserverTestCase


Expand All @@ -43,20 +44,20 @@ class AdditionalResourceTests(HomeserverTestCase):

def test_async(self):
handler = _AsyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
resource = AdditionalResource(self.hs, handler)

request, channel = self.make_request("GET", "/")
self.render(request)
request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
render(request, resource, self.reactor)

self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_async"})

def test_sync(self):
handler = _SyncTestCustomEndpoint({}, None).handle_request
self.resource = AdditionalResource(self.hs, handler)
resource = AdditionalResource(self.hs, handler)

request, channel = self.make_request("GET", "/")
self.render(request)
request, channel = make_request(self.reactor, FakeSite(resource), "GET", "/")
render(request, resource, self.reactor)

self.assertEqual(request.code, 200)
self.assertEqual(channel.json_body, {"some_key": "some_value_sync"})
29 changes: 22 additions & 7 deletions tests/replication/test_client_reader_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel
from tests.server import FakeChannel, make_request

logger = logging.getLogger(__name__)

Expand All @@ -46,8 +46,11 @@ def test_register_single_worker(self):
"""Test that registration works when using a single client reader worker.
"""
worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]

request_1, channel_1 = self.make_request(
request_1, channel_1 = make_request(
self.reactor,
site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
Expand All @@ -59,8 +62,12 @@ def test_register_single_worker(self):
session = channel_1.json_body["session"]

# also complete the dummy auth
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
request_2, channel_2 = make_request(
self.reactor,
site,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs, request_2)
self.assertEqual(request_2.code, 200)
Expand All @@ -74,7 +81,10 @@ def test_register_multi_worker(self):
worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")

request_1, channel_1 = self.make_request(
site_1 = self._hs_to_site[worker_hs_1]
request_1, channel_1 = make_request(
self.reactor,
site_1,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
Expand All @@ -86,8 +96,13 @@ def test_register_multi_worker(self):
session = channel_1.json_body["session"]

# also complete the dummy auth
request_2, channel_2 = self.make_request(
"POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
site_2 = self._hs_to_site[worker_hs_2]
request_2, channel_2 = make_request(
self.reactor,
site_2,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
) # type: SynapseRequest, FakeChannel
self.render_on_worker(worker_hs_2, request_2)
self.assertEqual(request_2.code, 200)
Expand Down
10 changes: 6 additions & 4 deletions tests/replication/test_multi_media_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import FakeChannel, FakeTransport
from tests.server import FakeChannel, FakeSite, FakeTransport, make_request

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -67,14 +67,16 @@ def _get_media_req(
The channel for the *client* request and the *outbound* request for
the media which the caller should respond to.
"""

request, channel = self.make_request(
resource = hs.get_media_repository_resource().children[b"download"]
request, channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
"/{}/{}".format(target, media_id),
shorthand=False,
access_token=self.access_token,
)
request.render(hs.get_media_repository_resource().children[b"download"])
request.render(resource)
self.pump()

clients = self.reactor.tcpClients
Expand Down
42 changes: 32 additions & 10 deletions tests/replication/test_sharded_event_persister.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from synapse.rest.client.v2_alpha import sync

from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.utils import USE_POSTGRES_FOR_TESTS

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -148,6 +149,7 @@ def test_vector_clock_token(self):
sync_hs = self.make_worker_hs(
"synapse.app.generic_worker", {"worker_name": "sync"},
)
sync_hs_site = self._hs_to_site[sync_hs]

# Specially selected room IDs that get persisted on different workers.
room_id1 = "!foo:test"
Expand Down Expand Up @@ -178,7 +180,9 @@ def test_vector_clock_token(self):
)

# Do an initial sync so that we're up to date.
request, channel = self.make_request("GET", "/sync", access_token=access_token)
request, channel = make_request(
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
)
self.render_on_worker(sync_hs, request)
next_batch = channel.json_body["next_batch"]

Expand All @@ -203,8 +207,12 @@ def test_vector_clock_token(self):

# Check that syncing still gets the new event, despite the gap in the
# stream IDs.
request, channel = self.make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)

Expand All @@ -230,7 +238,9 @@ def test_vector_clock_token(self):
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"]

request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(vector_clock_token),
access_token=access_token,
Expand All @@ -254,8 +264,12 @@ def test_vector_clock_token(self):
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)

request, channel = self.make_request(
"GET", "/sync?since={}".format(next_batch), access_token=access_token
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/sync?since={}".format(next_batch),
access_token=access_token,
)
self.render_on_worker(sync_hs, request)

Expand All @@ -269,7 +283,9 @@ def test_vector_clock_token(self):
# Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion.
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id1, prev_batch1, vector_clock_token
Expand All @@ -281,7 +297,9 @@ def test_vector_clock_token(self):

# Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken.
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=b".format(
room_id2, prev_batch2, vector_clock_token
Expand All @@ -295,7 +313,9 @@ def test_vector_clock_token(self):
)

# Paginating forwards should give the same results
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id1, vector_clock_token, prev_batch1
Expand All @@ -305,7 +325,9 @@ def test_vector_clock_token(self):
self.render_on_worker(sync_hs, request)
self.assertListEqual([], channel.json_body["chunk"])

request, channel = self.make_request(
request, channel = make_request(
self.reactor,
sync_hs_site,
"GET",
"/rooms/{}/messages?from={}&to={}&dir=f".format(
room_id2, vector_clock_token, prev_batch2,
Expand Down
18 changes: 14 additions & 4 deletions tests/rest/admin/test_admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from synapse.rest.client.v2_alpha import groups

from tests import unittest
from tests.server import FakeSite, make_request


class VersionTestCase(unittest.HomeserverTestCase):
Expand Down Expand Up @@ -222,8 +223,13 @@ def write_to(r):

def _ensure_quarantined(self, admin_user_tok, server_and_media_id):
"""Ensure a piece of media is quarantined when trying to access it."""
request, channel = self.make_request(
"GET", server_and_media_id, shorthand=False, access_token=admin_user_tok,
request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
server_and_media_id,
shorthand=False,
access_token=admin_user_tok,
)
request.render(self.download_resource)
self.pump(1.0)
Expand Down Expand Up @@ -287,7 +293,9 @@ def test_quarantine_media_by_id(self):
server_name, media_id = server_name_and_media_id.split("/")

# Attempt to access the media
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
server_name_and_media_id,
shorthand=False,
Expand Down Expand Up @@ -462,7 +470,9 @@ def test_cannot_quarantine_safe_media(self):
self._ensure_quarantined(admin_user_tok, server_and_media_id_1)

# Attempt to access each piece of media
request, channel = self.make_request(
request, channel = make_request(
self.reactor,
FakeSite(self.download_resource),
"GET",
server_and_media_id_2,
shorthand=False,
Expand Down
Loading