Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Decouple synapse.api.auth_blocking.AuthBlocking from synapse.api.auth.Auth. #13021

Merged
merged 5 commits into from
Jun 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/13021.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Decouple `synapse.api.auth_blocking.AuthBlocking` from `synapse.api.auth.Auth`.
14 changes: 0 additions & 14 deletions synapse/api/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from twisted.web.server import Request

from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
from synapse.api.errors import (
AuthError,
Expand Down Expand Up @@ -67,8 +66,6 @@ def __init__(self, hs: "HomeServer"):
10000, "token_cache"
)

self._auth_blocking = AuthBlocking(self.hs)

self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips
self._track_puppeted_user_ips = hs.config.api.track_puppeted_user_ips
self._macaroon_secret_key = hs.config.key.macaroon_secret_key
Expand Down Expand Up @@ -711,14 +708,3 @@ async def check_user_in_room_or_world_readable(
"User %s not in room %s, and room previews are disabled"
% (user_id, room_id),
)

async def check_auth_blocking(
self,
user_id: Optional[str] = None,
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
) -> None:
await self._auth_blocking.check_auth_blocking(
user_id=user_id, threepid=threepid, user_type=user_type, requester=requester
)
5 changes: 3 additions & 2 deletions synapse/handlers/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ class AuthHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
self.clock = hs.get_clock()
self.checkers: Dict[str, UserInteractiveAuthChecker] = {}
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
Expand Down Expand Up @@ -985,7 +986,7 @@ async def create_access_token_for_user_id(
not is_appservice_ghost
or self.hs.config.appservice.track_appservice_user_ips
):
await self.auth.check_auth_blocking(user_id)
await self.auth_blocking.check_auth_blocking(user_id)

access_token = self.generate_access_token(target_user_id_obj)
await self.store.add_access_token_to_user(
Expand Down Expand Up @@ -1439,7 +1440,7 @@ async def validate_short_term_login_token(
except Exception:
raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)

await self.auth.check_auth_blocking(res.user_id)
await self.auth_blocking.check_auth_blocking(res.user_id)
return res

async def delete_access_token(self, access_token: str) -> None:
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ async def _expire_event(self, event_id: str) -> None:
class EventCreationHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
self._event_auth_handler = hs.get_event_auth_handler()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
Expand Down Expand Up @@ -605,7 +605,7 @@ async def create_event(
Returns:
Tuple of created event, Context
"""
await self.auth.check_auth_blocking(requester=requester)
await self.auth_blocking.check_auth_blocking(requester=requester)

if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version_id = event_dict["content"]["room_version"]
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.hs = hs
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
self._auth_handler = hs.get_auth_handler()
self.profile_handler = hs.get_profile_handler()
self.user_directory_handler = hs.get_user_directory_handler()
Expand Down Expand Up @@ -276,7 +277,7 @@ async def register_user(

# do not check_auth_blocking if the call is coming through the Admin API
if not by_admin:
await self.auth.check_auth_blocking(threepid=threepid)
await self.auth_blocking.check_auth_blocking(threepid=threepid)

if localpart is not None:
await self.check_username(localpart, guest_access_token=guest_access_token)
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
self.clock = hs.get_clock()
self.hs = hs
self.spam_checker = hs.get_spam_checker()
Expand Down Expand Up @@ -707,7 +708,7 @@ async def create_room(
"""
user_id = requester.user.to_string()

await self.auth.check_auth_blocking(requester=requester)
await self.auth_blocking.check_auth_blocking(requester=requester)

if (
self._server_notices_mxid is not None
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def __init__(self, hs: "HomeServer"):
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state

Expand Down Expand Up @@ -280,7 +280,7 @@ async def wait_for_sync_for_user(
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
await self.auth.check_auth_blocking(requester=requester)
await self.auth_blocking.check_auth_blocking(requester=requester)

res = await self.response_cache.wrap(
sync_config.request_key,
Expand Down
5 changes: 5 additions & 0 deletions synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from twisted.web.resource import Resource

from synapse.api.auth import Auth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter
from synapse.appservice.api import ApplicationServiceApi
Expand Down Expand Up @@ -379,6 +380,10 @@ def get_notifier(self) -> Notifier:
def get_auth(self) -> Auth:
return Auth(self)

@cache_in_self
def get_auth_blocking(self) -> AuthBlocking:
return AuthBlocking(self)

@cache_in_self
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
if self.config.tls.use_insecure_ssl_client_just_for_testing_do_not_use:
Expand Down
4 changes: 2 additions & 2 deletions synapse/server_notices/resource_limits_server_notices.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, hs: "HomeServer"):
self._server_notices_manager = hs.get_server_notices_manager()
self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._auth_blocking = hs.get_auth_blocking()
self._config = hs.config
self._resouce_limited = False
self._account_data_handler = hs.get_account_data_handler()
Expand Down Expand Up @@ -91,7 +91,7 @@ async def maybe_send_server_notice_to_user(self, user_id: str) -> None:
# Normally should always pass in user_id to check_auth_blocking
# if you have it, but in this case are checking what would happen
# to other users if they were to arrive.
await self._auth.check_auth_blocking()
await self._auth_blocking.check_auth_blocking()
except ResourceLimitError as e:
limit_msg = e.msg
limit_type = e.limit_type
Expand Down
42 changes: 27 additions & 15 deletions tests/api/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from twisted.test.proto_helpers import MemoryReactor

from synapse.api.auth import Auth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import UserTypes
from synapse.api.errors import (
AuthError,
Expand Down Expand Up @@ -49,7 +50,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):

# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
self.auth_blocking = self.auth._auth_blocking
self.auth_blocking = AuthBlocking(hs)

self.test_user = "@foo:bar"
self.test_token = b"_test_token_"
Expand Down Expand Up @@ -362,36 +363,41 @@ def test_blocking_mau(self):
small_number_of_users = 1

# Ensure no error thrown
self.get_success(self.auth.check_auth_blocking())
self.get_success(self.auth_blocking.check_auth_blocking())

self.auth_blocking._limit_usage_by_mau = True

self.store.get_monthly_active_count = simple_async_mock(lots_of_users)

e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
e = self.get_failure(
self.auth_blocking.check_auth_blocking(), ResourceLimitError
)
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)

# Ensure does not throw an error
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
self.get_success(self.auth.check_auth_blocking())
self.get_success(self.auth_blocking.check_auth_blocking())

def test_blocking_mau__depending_on_user_type(self):
self.auth_blocking._max_mau_value = 50
self.auth_blocking._limit_usage_by_mau = True

self.store.get_monthly_active_count = simple_async_mock(100)
# Support users allowed
self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
self.get_success(
self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
)
self.store.get_monthly_active_count = simple_async_mock(100)
# Bots not allowed
self.get_failure(
self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
ResourceLimitError,
)
self.store.get_monthly_active_count = simple_async_mock(100)
# Real users not allowed
self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)

def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
self.auth_blocking._max_mau_value = 50
Expand Down Expand Up @@ -419,7 +425,7 @@ def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
app_service=appservice,
authenticated_entity="@appservice:server",
)
self.get_success(self.auth.check_auth_blocking(requester=requester))
self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))

def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
self.auth_blocking._max_mau_value = 50
Expand Down Expand Up @@ -448,7 +454,8 @@ def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
authenticated_entity="@appservice:server",
)
self.get_failure(
self.auth.check_auth_blocking(requester=requester), ResourceLimitError
self.auth_blocking.check_auth_blocking(requester=requester),
ResourceLimitError,
)

def test_reserved_threepid(self):
Expand All @@ -459,18 +466,21 @@ def test_reserved_threepid(self):
unknown_threepid = {"medium": "email", "address": "[email protected]"}
self.auth_blocking._mau_limits_reserved_threepids = [threepid]

self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)

self.get_failure(
self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
self.auth_blocking.check_auth_blocking(threepid=unknown_threepid),
ResourceLimitError,
)

self.get_success(self.auth.check_auth_blocking(threepid=threepid))
self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))

def test_hs_disabled(self):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
e = self.get_failure(
self.auth_blocking.check_auth_blocking(), ResourceLimitError
)
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
Expand All @@ -485,7 +495,9 @@ def test_hs_disabled_no_server_notices_user(self):

self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
e = self.get_failure(
self.auth_blocking.check_auth_blocking(), ResourceLimitError
)
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
Expand All @@ -495,4 +507,4 @@ def test_server_notices_mxid_special_cased(self):
user = "@user:server"
self.auth_blocking._server_notices_mxid = user
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
self.get_success(self.auth.check_auth_blocking(user))
self.get_success(self.auth_blocking.check_auth_blocking(user))
2 changes: 1 addition & 1 deletion tests/handlers/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# MAU tests
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
self.auth_blocking = hs.get_auth()._auth_blocking
self.auth_blocking = hs.get_auth_blocking()
self.auth_blocking._max_mau_value = 50

self.small_number_of_users = 1
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ async def get_or_create_user(
"""
if localpart is None:
raise SynapseError(400, "Request must include user id")
await self.hs.get_auth().check_auth_blocking()
await self.hs.get_auth_blocking().check_auth_blocking()
need_register = True

try:
Expand Down
2 changes: 1 addition & 1 deletion tests/handlers/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def prepare(self, reactor, clock, hs: HomeServer):

# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
self.auth_blocking = self.hs.get_auth()._auth_blocking
self.auth_blocking = self.hs.get_auth_blocking()

def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:test"
Expand Down
22 changes: 14 additions & 8 deletions tests/server_notices/test_resource_limits_server_notices.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ def test_maybe_send_server_notice_to_user_flag_off(self):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed"""

self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None)
)
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
Expand All @@ -112,7 +114,7 @@ def test_maybe_send_server_notice_to_user_remove_blocked_notice_noop(self):
"""
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock(
self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
Expand All @@ -132,7 +134,7 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice(self):
"""
Test when user does not have blocked notice, but should have one
"""
self._rlsn._auth.check_auth_blocking = Mock(
self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None),
side_effect=ResourceLimitError(403, "foo"),
)
Expand All @@ -145,7 +147,9 @@ def test_maybe_send_server_notice_to_user_add_blocked_notice_noop(self):
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None)
)

self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))

Expand All @@ -156,7 +160,9 @@ def test_maybe_send_server_notice_to_user_not_in_mau_cohort(self):
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None)
)
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(None)
)
Expand All @@ -170,7 +176,7 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_unblocked(self):
Test that when server is over MAU limit and alerting is suppressed, then
an alert message is not sent into the room
"""
self._rlsn._auth.check_auth_blocking = Mock(
self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
Expand All @@ -185,7 +191,7 @@ def test_check_hs_disabled_unaffected_by_mau_alert_suppression(self):
"""
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
self._rlsn._auth.check_auth_blocking = Mock(
self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
Expand All @@ -202,7 +208,7 @@ def test_maybe_send_server_notice_when_alerting_suppressed_room_blocked(self):
When the room is already in a blocked state, test that when alerting
is suppressed that the room is returned to an unblocked state.
"""
self._rlsn._auth.check_auth_blocking = Mock(
self._rlsn._auth_blocking.check_auth_blocking = Mock(
return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
Expand Down