From cacd4afbbba4fdd412de8118e00dce17477c78d2 Mon Sep 17 00:00:00 2001 From: adel-signal Date: Mon, 21 Oct 2024 11:59:29 -0700 Subject: [PATCH] Add /v2/calling/relays This supports returning IceServers from multiple providers at once --- .../textsecuregcm/WhisperServerService.java | 2 + .../controllers/CallRoutingController.java | 5 + .../controllers/CallRoutingControllerV2.java | 110 +++++++++ .../CallRoutingControllerV2Test.java | 219 ++++++++++++++++++ 4 files changed, 336 insertions(+) create mode 100644 service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2.java create mode 100644 service/src/test/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2Test.java diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java index d9d4555ac..09f2842ac 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/WhisperServerService.java @@ -113,6 +113,7 @@ import org.whispersystems.textsecuregcm.controllers.AttachmentControllerV4; import org.whispersystems.textsecuregcm.controllers.CallLinkController; import org.whispersystems.textsecuregcm.controllers.CallRoutingController; +import org.whispersystems.textsecuregcm.controllers.CallRoutingControllerV2; import org.whispersystems.textsecuregcm.controllers.CertificateController; import org.whispersystems.textsecuregcm.controllers.ChallengeController; import org.whispersystems.textsecuregcm.controllers.DeviceController; @@ -1099,6 +1100,7 @@ protected void configureServer(final ServerBuilder serverBuilder) { experimentEnrollmentManager), new ArchiveController(backupAuthManager, backupManager), new CallRoutingController(rateLimiters, callRouter, turnTokenGenerator, experimentEnrollmentManager, cloudflareTurnCredentialsManager), + new CallRoutingControllerV2(rateLimiters, callRouter, turnTokenGenerator, experimentEnrollmentManager, cloudflareTurnCredentialsManager), new CallLinkController(rateLimiters, callingGenericZkSecretParams), new CertificateController(new CertificateGenerator(config.getDeliveryCertificate().certificate().value(), config.getDeliveryCertificate().ecPrivateKey(), config.getDeliveryCertificate().expiresDays()), diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingController.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingController.java index f158149a8..fb1548101 100644 --- a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingController.java +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingController.java @@ -1,3 +1,8 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + package org.whispersystems.textsecuregcm.controllers; import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; diff --git a/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2.java b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2.java new file mode 100644 index 000000000..de674b554 --- /dev/null +++ b/service/src/main/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2.java @@ -0,0 +1,110 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.controllers; + +import static org.whispersystems.textsecuregcm.metrics.MetricsUtil.name; + +import io.dropwizard.auth.Auth; +import io.micrometer.core.instrument.Counter; +import io.micrometer.core.instrument.Metrics; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.responses.ApiResponse; +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import javax.ws.rs.GET; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.container.ContainerRequestContext; +import javax.ws.rs.core.Context; +import javax.ws.rs.core.MediaType; +import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager; +import org.whispersystems.textsecuregcm.auth.TurnToken; +import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; +import org.whispersystems.textsecuregcm.calls.routing.TurnCallRouter; +import org.whispersystems.textsecuregcm.calls.routing.TurnServerOptions; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.filters.RemoteAddressFilter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.websocket.auth.ReadOnly; + +@io.swagger.v3.oas.annotations.tags.Tag(name = "Calling") +@Path("/v2/calling") +public class CallRoutingControllerV2 { + + private static final int TURN_INSTANCE_LIMIT = 2; + private static final Counter INVALID_IP_COUNTER = Metrics.counter(name(CallRoutingControllerV2.class, "invalidIP")); + private final RateLimiters rateLimiters; + private final TurnCallRouter turnCallRouter; + private final TurnTokenGenerator tokenGenerator; + private final ExperimentEnrollmentManager experimentEnrollmentManager; + private final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager; + + public CallRoutingControllerV2( + final RateLimiters rateLimiters, + final TurnCallRouter turnCallRouter, + final TurnTokenGenerator tokenGenerator, + final ExperimentEnrollmentManager experimentEnrollmentManager, + final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager + ) { + this.rateLimiters = rateLimiters; + this.turnCallRouter = turnCallRouter; + this.tokenGenerator = tokenGenerator; + this.experimentEnrollmentManager = experimentEnrollmentManager; + this.cloudflareTurnCredentialsManager = cloudflareTurnCredentialsManager; + } + + @GET + @Path("/relays") + @Produces(MediaType.APPLICATION_JSON) + @Operation( + summary = "Get 1:1 calling relay options for the client", + description = """ + Get 1:1 relay addresses in IpV4, Ipv6, and URL formats. + """ + ) + @ApiResponse(responseCode = "200", description = "`JSON` with call endpoints.", useReturnTypeSchema = true) + @ApiResponse(responseCode = "400", description = "Invalid get call endpoint request.") + @ApiResponse(responseCode = "401", description = "Account authentication check failed.") + @ApiResponse(responseCode = "422", description = "Invalid request format.") + @ApiResponse(responseCode = "429", description = "Rate limited.") + public GetCallingRelaysResponse getCallingRelays( + final @ReadOnly @Auth AuthenticatedDevice auth, + @Context ContainerRequestContext requestContext + ) throws RateLimitExceededException, IOException { + UUID aci = auth.getAccount().getUuid(); + rateLimiters.getCallEndpointLimiter().validate(aci); + + List tokens = new ArrayList<>(); + if (experimentEnrollmentManager.isEnrolled(auth.getAccount().getNumber(), aci, "cloudflareTurn")) { + tokens.add(cloudflareTurnCredentialsManager.retrieveFromCloudflare()); + } + + Optional address = Optional.empty(); + try { + final String remoteAddress = (String) requestContext.getProperty( + RemoteAddressFilter.REMOTE_ADDRESS_ATTRIBUTE_NAME); + address = Optional.of(InetAddress.getByName(remoteAddress)); + } catch (UnknownHostException e) { + INVALID_IP_COUNTER.increment(); + } + + TurnServerOptions options = turnCallRouter.getRoutingFor(aci, address, TURN_INSTANCE_LIMIT); + tokens.add(tokenGenerator.generateWithTurnServerOptions(options)); + + return new GetCallingRelaysResponse(tokens); + } + + public record GetCallingRelaysResponse( + List relays + ) { + } +} diff --git a/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2Test.java b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2Test.java new file mode 100644 index 000000000..55f5d93ea --- /dev/null +++ b/service/src/test/java/org/whispersystems/textsecuregcm/controllers/CallRoutingControllerV2Test.java @@ -0,0 +1,219 @@ +/* + * Copyright 2024 Signal Messenger, LLC + * SPDX-License-Identifier: AGPL-3.0-only + */ + +package org.whispersystems.textsecuregcm.controllers; + +import io.dropwizard.auth.AuthValueFactoryProvider; +import io.dropwizard.testing.junit5.DropwizardExtensionsSupport; +import io.dropwizard.testing.junit5.ResourceExtension; +import org.glassfish.jersey.test.grizzly.GrizzlyWebTestContainerFactory; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.whispersystems.textsecuregcm.auth.AuthenticatedDevice; +import org.whispersystems.textsecuregcm.auth.CloudflareTurnCredentialsManager; +import org.whispersystems.textsecuregcm.auth.TurnToken; +import org.whispersystems.textsecuregcm.auth.TurnTokenGenerator; +import org.whispersystems.textsecuregcm.calls.routing.TurnCallRouter; +import org.whispersystems.textsecuregcm.calls.routing.TurnServerOptions; +import org.whispersystems.textsecuregcm.configuration.dynamic.DynamicConfiguration; +import org.whispersystems.textsecuregcm.experiment.ExperimentEnrollmentManager; +import org.whispersystems.textsecuregcm.limits.RateLimiter; +import org.whispersystems.textsecuregcm.limits.RateLimiters; +import org.whispersystems.textsecuregcm.mappers.RateLimitExceededExceptionMapper; +import org.whispersystems.textsecuregcm.storage.DynamicConfigurationManager; +import org.whispersystems.textsecuregcm.tests.util.AuthHelper; +import org.whispersystems.textsecuregcm.util.SystemMapper; +import org.whispersystems.textsecuregcm.util.TestRemoteAddressFilterProvider; +import javax.ws.rs.core.Response; +import java.io.IOException; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.Optional; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.*; + +@ExtendWith(DropwizardExtensionsSupport.class) +class CallRoutingControllerV2Test { + + private static final String GET_CALL_RELAYS_PATH = "v2/calling/relays"; + private static final String REMOTE_ADDRESS = "123.123.123.1"; + private static final TurnServerOptions TURN_SERVER_OPTIONS = new TurnServerOptions( + "example.domain.org", + List.of("stun:12.34.56.78"), + List.of("stun:example.domain.org") + ); + private static final TurnToken CLOUDFLARE_TURN_TOKEN = new TurnToken( + "ABC", + "XYZ", + List.of("turn:cloudflare.example.com:3478?transport=udp"), + null, + "cf.example.com"); + + private static final RateLimiters rateLimiters = mock(RateLimiters.class); + private static final RateLimiter getCallEndpointLimiter = mock(RateLimiter.class); + private static final DynamicConfigurationManager dynamicConfigurationManager = mock( + DynamicConfigurationManager.class); + private static final ExperimentEnrollmentManager experimentEnrollmentManager = mock( + ExperimentEnrollmentManager.class); + private static final TurnTokenGenerator turnTokenGenerator = new TurnTokenGenerator(dynamicConfigurationManager, + "bloop".getBytes(StandardCharsets.UTF_8)); + private static final CloudflareTurnCredentialsManager cloudflareTurnCredentialsManager = mock( + CloudflareTurnCredentialsManager.class); + private static final TurnCallRouter turnCallRouter = mock(TurnCallRouter.class); + + private static final ResourceExtension resources = ResourceExtension.builder() + .addProvider(AuthHelper.getAuthFilter()) + .addProvider(new AuthValueFactoryProvider.Binder<>(AuthenticatedDevice.class)) + .addProvider(new RateLimitExceededExceptionMapper()) + .addProvider(new TestRemoteAddressFilterProvider(REMOTE_ADDRESS)) + .setMapper(SystemMapper.jsonMapper()) + .setTestContainerFactory(new GrizzlyWebTestContainerFactory()) + .addResource(new CallRoutingControllerV2(rateLimiters, turnCallRouter, turnTokenGenerator, + experimentEnrollmentManager, cloudflareTurnCredentialsManager)) + .build(); + + @BeforeEach + void setup() { + when(rateLimiters.getCallEndpointLimiter()).thenReturn(getCallEndpointLimiter); + } + + @AfterEach + void tearDown() { + reset(experimentEnrollmentManager, dynamicConfigurationManager, rateLimiters, getCallEndpointLimiter, + turnCallRouter); + } + + void initializeMocksWith(Optional signalTurn, Optional cloudflare) { + signalTurn.ifPresent(options -> { + try { + when(turnCallRouter.getRoutingFor( + eq(AuthHelper.VALID_UUID), + eq(Optional.of(InetAddress.getByName(REMOTE_ADDRESS))), + anyInt()) + ).thenReturn(options); + } catch (UnknownHostException ignored) { + } + }); + cloudflare.ifPresent(token -> { + when(experimentEnrollmentManager.isEnrolled(AuthHelper.VALID_NUMBER, AuthHelper.VALID_UUID, "cloudflareTurn")) + .thenReturn(true); + try { + when(cloudflareTurnCredentialsManager.retrieveFromCloudflare()).thenReturn(token); + } catch (IOException ignored) { + } + }); + } + + @Test + void testGetRelaysSignalRoutingOnly() { + TurnServerOptions options = TURN_SERVER_OPTIONS; + initializeMocksWith(Optional.of(options), Optional.empty()); + + try (Response rawResponse = resources.getJerseyTest() + .target(GET_CALL_RELAYS_PATH) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertThat(rawResponse.getStatus()).isEqualTo(200); + + CallRoutingControllerV2.GetCallingRelaysResponse response = rawResponse.readEntity( + CallRoutingControllerV2.GetCallingRelaysResponse.class); + List relays = response.relays(); + assertThat(relays).hasSize(1); + assertThat(relays.getFirst().username()).isNotEmpty(); + assertThat(relays.getFirst().password()).isNotEmpty(); + assertThat(relays.getFirst().hostname()).isEqualTo(options.hostname()); + assertThat(relays.getFirst().urlsWithIps()).isEqualTo(options.urlsWithIps()); + assertThat(relays.getFirst().urls()).isEqualTo(options.urlsWithHostname()); + } + } + + @Test + void testGetRelaysBothRouting() { + TurnServerOptions options = TURN_SERVER_OPTIONS; + initializeMocksWith(Optional.of(options), Optional.of(CLOUDFLARE_TURN_TOKEN)); + + try (Response rawResponse = resources.getJerseyTest() + .target(GET_CALL_RELAYS_PATH) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertThat(rawResponse.getStatus()).isEqualTo(200); + + CallRoutingControllerV2.GetCallingRelaysResponse response = rawResponse.readEntity( + CallRoutingControllerV2.GetCallingRelaysResponse.class); + + List relays = response.relays(); + assertThat(relays).hasSize(2); + + assertThat(relays.getFirst()).isEqualTo(CLOUDFLARE_TURN_TOKEN); + + TurnToken token = relays.get(1); + assertThat(token.username()).isNotEmpty(); + assertThat(token.password()).isNotEmpty(); + assertThat(token.hostname()).isEqualTo(options.hostname()); + assertThat(token.urlsWithIps()).isEqualTo(options.urlsWithIps()); + assertThat(token.urls()).isEqualTo(options.urlsWithHostname()); + } + } + + @Test + void testGetRelaysInvalidIpSuccess() throws UnknownHostException { + TurnServerOptions options = new TurnServerOptions( + "example.domain.org", + List.of(), + List.of("stun:example.domain.org") + ); + + when(turnCallRouter.getRoutingFor( + eq(AuthHelper.VALID_UUID), + eq(Optional.of(InetAddress.getByName(REMOTE_ADDRESS))), + anyInt()) + ).thenReturn(options); + try (Response rawResponse = resources.getJerseyTest() + .target(GET_CALL_RELAYS_PATH) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertThat(rawResponse.getStatus()).isEqualTo(200); + CallRoutingControllerV2.GetCallingRelaysResponse response = rawResponse.readEntity( + CallRoutingControllerV2.GetCallingRelaysResponse.class + ); + + assertThat(response.relays()).hasSize(1); + TurnToken token = response.relays().getFirst(); + assertThat(token.username()).isNotEmpty(); + assertThat(token.password()).isNotEmpty(); + assertThat(token.hostname()).isEqualTo(options.hostname()); + assertThat(token.urlsWithIps()).isEqualTo(options.urlsWithIps()); + assertThat(token.urls()).isEqualTo(options.urlsWithHostname()); + } + } + + @Test + void testGetRelaysRateLimited() throws RateLimitExceededException { + doThrow(new RateLimitExceededException(null)) + .when(getCallEndpointLimiter).validate(AuthHelper.VALID_UUID); + + try (final Response response = resources.getJerseyTest() + .target(GET_CALL_RELAYS_PATH) + .request() + .header("Authorization", AuthHelper.getAuthHeader(AuthHelper.VALID_UUID, AuthHelper.VALID_PASSWORD)) + .get()) { + + assertThat(response.getStatus()).isEqualTo(429); + } + } +}