Skip to content

Commit

Permalink
Add support for unbinding third-party IDs when they are removed from …
Browse files Browse the repository at this point in the history
…the user account (#3)
  • Loading branch information
anoadragon453 authored Mar 15, 2023
1 parent c9fed2d commit 9158b32
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
/.tox
_trial_temp
__pycache__
/dist
/dist
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
A module that leverages Sydent's internal bind APIs to automatically record 3PIDs
association on a Sydent instance once it's been verified on the local Synapse homeserver.

This module works with Synapse v1.78.0 and above.
This module works with Synapse v1.79.0 and above.

## Installation

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ synapse_bind_sydent = py.typed
[options.extras_require]
dev =
# for tests
matrix-synapse
matrix-synapse >= 1.79.0 # first version to support `on_{add,remove}_user_third_party_identifier` module callbacks
tox
twisted
aiounittest
Expand Down
52 changes: 38 additions & 14 deletions synapse_bind_sydent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,15 @@ def __init__(self, config: SydentBinderConfig, api: ModuleApi) -> None:
self._sydent_bind_url = (
f"{config.sydent_base_url}/_matrix/identity/internal/bind"
)
self._sydent_unbind_url = (
f"{config.sydent_base_url}/_matrix/identity/internal/unbind"
)

self._sydent_host = urlparse(config.sydent_base_url).netloc

self._api.register_third_party_rules_callbacks(
on_threepid_bind=self.on_threepid_bind,
on_add_user_third_party_identifier=self.on_add_user_third_party_identifier,
on_remove_user_third_party_identifier=self.on_remove_user_third_party_identifier,
)

@staticmethod
Expand All @@ -59,16 +63,20 @@ def parse_config(config: Dict[str, Any]) -> SydentBinderConfig:

return SydentBinderConfig(**config)

async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> None:
"""Binds the 3PID to Sydent once it's been associated locally."""
# Get the list of 3PIDs for this user.
async def on_add_user_third_party_identifier(
self, user_id: str, medium: str, address: str
) -> None:
"""
Binds a 3PID on the configured Sydent instance when it is locally associated with a user's account.
"""
# Build the body of the internal bind API request.
body = {
"medium": medium,
"address": address,
"mxid": user_id,
}

# Bind the threepid
# Bind the third-party ID against the configured Sydent using the internal bind API
try:
await self._http_client.post_json_get_json(self._sydent_bind_url, body)
except Exception as e:
Expand All @@ -80,12 +88,28 @@ async def on_threepid_bind(self, user_id: str, medium: str, address: str) -> Non
self._sydent_bind_url,
e,
)
return

# Store the association, so we can use this to unbind later.
await self._api.store_remote_3pid_association(
user_id,
medium,
address,
self._sydent_host,
)

async def on_remove_user_third_party_identifier(
self, user_id: str, medium: str, address: str
) -> None:
"""
Unbinds a 3PID from the configured Sydent instance when it is locally removed from a user's account.
"""
body = {
"medium": medium,
"address": address,
"mxid": user_id,
}

# Unbind the third-party ID from the configured Sydent using the internal unbind API
try:
await self._http_client.post_json_get_json(self._sydent_unbind_url, body)
except Exception as e:
# If there was an error, the IS is likely unreachable, so don't try again.
logger.exception(
"Failed to bind %s 3PID %s to identity server at %s: %s",
medium,
address,
self._sydent_bind_url,
e,
)
37 changes: 27 additions & 10 deletions tests/test_binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,33 @@ async def test_new_assoc(self) -> None:

module = create_module(http_mock=http_client)

await module.on_threepid_bind(user_id, medium, address)
await module.on_add_user_third_party_identifier(user_id, medium, address)

self.assertEqual(http_client.post_json_get_json.call_count, 1)
args = http_client.post_json_get_json.call_args[0]
self.assertEqual(
args[1], {"address": address, "medium": medium, "mxid": user_id}
)
(path, body) = http_client.post_json_get_json.call_args[0]
self.assertTrue(path.endswith("/_matrix/identity/internal/bind"))
self.assertEqual(body, {"address": address, "medium": medium, "mxid": user_id})

store_remote_3pid_association: Mock = module._api.store_remote_3pid_association # type: ignore[assignment]
self.assertEqual(store_remote_3pid_association.call_count, 1)
args = store_remote_3pid_association.call_args[0]
self.assertEqual(args, (user_id, medium, address, "test"))
async def test_remove_assoc(self) -> None:
"""Tests that the right function calls are made when the newly registered user has
a single 3PID associated.
"""
http_client = Mock()
http_client.post_json_get_json = Mock(return_value=make_awaitable(None))

address = "[email protected]"
medium = "email"
user_id = "@jdoe:example.com"

module = create_module(http_mock=http_client)

await module.on_remove_user_third_party_identifier(user_id, medium, address)

self.assertEqual(http_client.post_json_get_json.call_count, 1)

(path, body) = http_client.post_json_get_json.call_args[0]
self.assertTrue(path.endswith("/_matrix/identity/internal/unbind"))
self.assertEqual(body, {"address": address, "medium": medium, "mxid": user_id})

async def test_network_error(self) -> None:
"""Tests that the process is aborted right away if an error was raised when trying
Expand All @@ -59,7 +74,9 @@ async def post_json_get_json(*args: Any, **kwargs: Any) -> None:

module = create_module(http_mock=http_client)

await module.on_threepid_bind("@jdoe:matrix.org", "email", "[email protected]")
await module.on_add_user_third_party_identifier(
"@jdoe:matrix.org", "email", "[email protected]"
)

self.assertEqual(http_client.post_json_get_json.call_count, 1)

Expand Down

0 comments on commit 9158b32

Please sign in to comment.