diff --git a/DependencyInjection/Security/OAuth2Factory.php b/DependencyInjection/Security/OAuth2Factory.php index dc4404bb..30c2fa77 100644 --- a/DependencyInjection/Security/OAuth2Factory.php +++ b/DependencyInjection/Security/OAuth2Factory.php @@ -23,11 +23,13 @@ public function create(ContainerBuilder $container, $id, $config, $userProvider, $providerId = 'security.authentication.provider.oauth2.' . $id; $container ->setDefinition($providerId, new ChildDefinition(OAuth2Provider::class)) - ->replaceArgument('$userProvider', new Reference($userProvider)); + ->replaceArgument('$userProvider', new Reference($userProvider)) + ->replaceArgument('$providerKey', $id); $listenerId = 'security.authentication.listener.oauth2.' . $id; $container - ->setDefinition($listenerId, new ChildDefinition(OAuth2Listener::class)); + ->setDefinition($listenerId, new ChildDefinition(OAuth2Listener::class)) + ->replaceArgument('$providerKey', $id); return [$providerId, $listenerId, OAuth2EntryPoint::class]; } diff --git a/Security/Authentication/Provider/OAuth2Provider.php b/Security/Authentication/Provider/OAuth2Provider.php index f4a86652..3528b2e7 100644 --- a/Security/Authentication/Provider/OAuth2Provider.php +++ b/Security/Authentication/Provider/OAuth2Provider.php @@ -32,11 +32,21 @@ final class OAuth2Provider implements AuthenticationProviderInterface */ private $oauth2TokenFactory; - public function __construct(UserProviderInterface $userProvider, ResourceServer $resourceServer, OAuth2TokenFactory $oauth2TokenFactory) - { + /** + * @var string + */ + private $providerKey; + + public function __construct( + UserProviderInterface $userProvider, + ResourceServer $resourceServer, + OAuth2TokenFactory $oauth2TokenFactory, + string $providerKey + ) { $this->userProvider = $userProvider; $this->resourceServer = $resourceServer; $this->oauth2TokenFactory = $oauth2TokenFactory; + $this->providerKey = $providerKey; } /** @@ -60,7 +70,7 @@ public function authenticate(TokenInterface $token) $request->getAttribute('oauth_user_id') ); - $token = $this->oauth2TokenFactory->createOAuth2Token($request, $user); + $token = $this->oauth2TokenFactory->createOAuth2Token($request, $user, $this->providerKey); $token->setAuthenticated(true); return $token; @@ -71,7 +81,7 @@ public function authenticate(TokenInterface $token) */ public function supports(TokenInterface $token) { - return $token instanceof OAuth2Token; + return $token instanceof OAuth2Token && $this->providerKey === $token->getProviderKey(); } private function getAuthenticatedUser(string $userIdentifier): ?UserInterface diff --git a/Security/Authentication/Token/OAuth2Token.php b/Security/Authentication/Token/OAuth2Token.php index 028232c9..9af825c8 100644 --- a/Security/Authentication/Token/OAuth2Token.php +++ b/Security/Authentication/Token/OAuth2Token.php @@ -10,8 +10,17 @@ final class OAuth2Token extends AbstractToken { - public function __construct(ServerRequestInterface $serverRequest, ?UserInterface $user, string $rolePrefix) - { + /** + * @var string + */ + private $providerKey; + + public function __construct( + ServerRequestInterface $serverRequest, + ?UserInterface $user, + string $rolePrefix, + string $providerKey + ) { $this->setAttribute('server_request', $serverRequest); $this->setAttribute('role_prefix', $rolePrefix); @@ -25,6 +34,8 @@ public function __construct(ServerRequestInterface $serverRequest, ?UserInterfac } parent::__construct(array_unique($roles)); + + $this->providerKey = $providerKey; } /** @@ -35,6 +46,61 @@ public function getCredentials() return $this->getAttribute('server_request')->getAttribute('oauth_access_token_id'); } + public function getProviderKey(): string + { + return $this->providerKey; + } + + public function __serialize(): array + { + if (method_exists(parent::class, '__serialize')) { + // this code path should be the only code path after dropping support for Symfony 3.4 + return [$this->providerKey, parent::__serialize()]; + } + + return [$this->providerKey, $this->getUser(), $this->isAuthenticated(), $this->getRoles(), $this->getAttributes()]; + } + + public function __unserialize(array $data): void + { + if (method_exists(parent::class, '__unserialize')) { + // this code path should be the only code path after dropping support for Symfony 3.4 + [$this->providerKey, $parentData] = $data; + parent::__unserialize($parentData); + + return; + } + + [$this->providerKey] = $data; + + unset($data[0]); + + parent::unserialize(array_values($data)); + } + + /** + * This entire function can be removed when dropping support for Symfony 3.4 + */ + public function serialize() + { + $serialized = [$this->providerKey, parent::serialize(true)]; + + if (method_exists(parent::class, 'doSerialize')) { + return $this->doSerialize($serialized, \func_num_args() ? func_get_arg(0) : null); + } + + return serialize($serialized); + } + + /** + * This entire function can be removed when dropping support for Symfony 3.4 + */ + public function unserialize($serialized) + { + [$this->providerKey, $parentStr] = \is_array($serialized) ? $serialized : unserialize($serialized); + parent::unserialize($parentStr); + } + private function buildRolesFromScopes(): array { $prefix = $this->getAttribute('role_prefix'); diff --git a/Security/Authentication/Token/OAuth2TokenFactory.php b/Security/Authentication/Token/OAuth2TokenFactory.php index 3144f172..61419706 100644 --- a/Security/Authentication/Token/OAuth2TokenFactory.php +++ b/Security/Authentication/Token/OAuth2TokenFactory.php @@ -19,8 +19,8 @@ public function __construct(string $rolePrefix) $this->rolePrefix = $rolePrefix; } - public function createOAuth2Token(ServerRequestInterface $serverRequest, ?UserInterface $user): OAuth2Token + public function createOAuth2Token(ServerRequestInterface $serverRequest, ?UserInterface $user, string $providerKey): OAuth2Token { - return new OAuth2Token($serverRequest, $user, $this->rolePrefix); + return new OAuth2Token($serverRequest, $user, $this->rolePrefix, $providerKey); } } diff --git a/Security/Firewall/OAuth2Listener.php b/Security/Firewall/OAuth2Listener.php index 7691a96f..c481ffbc 100644 --- a/Security/Firewall/OAuth2Listener.php +++ b/Security/Firewall/OAuth2Listener.php @@ -38,16 +38,23 @@ final class OAuth2Listener implements ListenerInterface */ private $oauth2TokenFactory; + /** + * @var string + */ + private $providerKey; + public function __construct( TokenStorageInterface $tokenStorage, AuthenticationManagerInterface $authenticationManager, HttpMessageFactoryInterface $httpMessageFactory, - OAuth2TokenFactory $oauth2TokenFactory + OAuth2TokenFactory $oauth2TokenFactory, + string $providerKey ) { $this->tokenStorage = $tokenStorage; $this->authenticationManager = $authenticationManager; $this->httpMessageFactory = $httpMessageFactory; $this->oauth2TokenFactory = $oauth2TokenFactory; + $this->providerKey = $providerKey; } /** @@ -68,7 +75,7 @@ public function __invoke(GetResponseEvent $event) try { /** @var OAuth2Token $authenticatedToken */ - $authenticatedToken = $this->authenticationManager->authenticate($this->oauth2TokenFactory->createOAuth2Token($request, null)); + $authenticatedToken = $this->authenticationManager->authenticate($this->oauth2TokenFactory->createOAuth2Token($request, null, $this->providerKey)); } catch (AuthenticationException $e) { throw Oauth2AuthenticationFailedException::create($e->getMessage()); } diff --git a/Tests/Unit/OAuth2ProviderTest.php b/Tests/Unit/OAuth2ProviderTest.php new file mode 100644 index 00000000..f1227f2d --- /dev/null +++ b/Tests/Unit/OAuth2ProviderTest.php @@ -0,0 +1,49 @@ +createMock(UserProviderInterface::class), + $this->createMock(ResourceServer::class), + $tokenFactory, + $providerKey + ); + + $this->assertTrue($provider->supports($this->createToken($tokenFactory, $providerKey))); + $this->assertFalse($provider->supports($this->createToken($tokenFactory, $providerKey . 'bar'))); + } + + private function createToken(OAuth2TokenFactory $tokenFactory, string $providerKey): OAuth2Token + { + $scopes = [FixtureFactory::FIXTURE_SCOPE_FIRST]; + $serverRequest = $this->createMock(ServerRequestInterface::class); + $serverRequest->expects($this->once()) + ->method('getAttribute') + ->with('oauth_scopes', []) + ->willReturn($scopes); + + $user = new User(); + + return $tokenFactory->createOAuth2Token($serverRequest, $user, $providerKey); + } +} diff --git a/Tests/Unit/OAuth2TokenFactoryTest.php b/Tests/Unit/OAuth2TokenFactoryTest.php new file mode 100644 index 00000000..f2d11370 --- /dev/null +++ b/Tests/Unit/OAuth2TokenFactoryTest.php @@ -0,0 +1,43 @@ +createMock(ServerRequestInterface::class); + $serverRequest->expects($this->once()) + ->method('getAttribute') + ->with('oauth_scopes', []) + ->willReturn($scopes); + + $user = new User(); + $providerKey = 'main'; + + $token = $factory->createOAuth2Token($serverRequest, $user, $providerKey); + + $this->assertInstanceOf(OAuth2Token::class, $token); + + $roles = $token->getRoles(); + $this->assertCount(1, $roles); + $this->assertSame($rolePrefix . strtoupper($scopes[0]), $roles[0]->getRole()); + + $this->assertFalse($token->isAuthenticated()); + $this->assertSame($user, $token->getUser()); + $this->assertSame($providerKey, $token->getProviderKey()); + } +} diff --git a/Tests/Unit/OAuth2TokenTest.php b/Tests/Unit/OAuth2TokenTest.php new file mode 100644 index 00000000..f4ad6162 --- /dev/null +++ b/Tests/Unit/OAuth2TokenTest.php @@ -0,0 +1,46 @@ +createMock(ServerRequestInterface::class); + $serverRequest->expects($this->once()) + ->method('getAttribute') + ->with('oauth_scopes', []) + ->willReturn($scopes); + + $user = new User(); + $rolePrefix = 'ROLE_OAUTH2_'; + $providerKey = 'main'; + $token = new OAuth2Token($serverRequest, $user, $rolePrefix, $providerKey); + + /** @var OAuth2Token $unserializedToken */ + $unserializedToken = unserialize(serialize($token)); + + $this->assertSame($providerKey, $unserializedToken->getProviderKey()); + + $roles = $unserializedToken->getRoles(); + $this->assertCount(1, $roles); + $expectedRole = $rolePrefix . strtoupper($scopes[0]); + $this->assertSame($expectedRole, $roles[0]->getRole()); + + $this->assertSame($user->getUsername(), $unserializedToken->getUser()->getUsername()); + $this->assertFalse($unserializedToken->isAuthenticated()); + + if (method_exists($token, 'getRoleNames')) { + $this->assertSame([$expectedRole], $token->getRoleNames()); + } + } +}