diff --git a/application/src/main/kotlin/tw/waterballsa/gaas/application/repositories/UserRepository.kt b/application/src/main/kotlin/tw/waterballsa/gaas/application/repositories/UserRepository.kt index 3b32cfbe..cebc0d80 100644 --- a/application/src/main/kotlin/tw/waterballsa/gaas/application/repositories/UserRepository.kt +++ b/application/src/main/kotlin/tw/waterballsa/gaas/application/repositories/UserRepository.kt @@ -9,4 +9,6 @@ interface UserRepository { fun createUser(user: User): User fun deleteAll() fun findAllById(ids: Collection): List + fun findByEmail(email: String): User? + fun update(user: User): User } diff --git a/application/src/main/kotlin/tw/waterballsa/gaas/application/usecases/CreateUserUseCase.kt b/application/src/main/kotlin/tw/waterballsa/gaas/application/usecases/CreateUserUseCase.kt index c4e79f6d..dfac7952 100644 --- a/application/src/main/kotlin/tw/waterballsa/gaas/application/usecases/CreateUserUseCase.kt +++ b/application/src/main/kotlin/tw/waterballsa/gaas/application/usecases/CreateUserUseCase.kt @@ -4,6 +4,7 @@ import tw.waterballsa.gaas.application.eventbus.EventBus import tw.waterballsa.gaas.application.repositories.UserRepository import tw.waterballsa.gaas.domain.User import tw.waterballsa.gaas.events.UserCreatedEvent +import java.util.UUID.randomUUID import javax.inject.Named @Named @@ -11,18 +12,38 @@ class CreateUserUseCase( private val userRepository: UserRepository, private val eventBus: EventBus, ) { - fun execute(request: Request) = when { - userRepository.existsUserByEmail(request.email) -> {} - else -> { - val user = userRepository.createUser(request.toUser()) - val event = user.toUserCreatedEvent() - eventBus.broadcast(event) + fun execute(request: Request) { + val user = userRepository.findByEmail(request.email) + + with (request) { + when { + user == null -> createUser() + !user.hasIdentity(identityProviderId) -> user.addUserIdentity(identityProviderId) + } } } - class Request(val email: String) { - fun toUser(): User = User(email = email) + private fun Request.createUser() { + val user = userRepository.createUser(toUser()) + val event = user.toUserCreatedEvent() + eventBus.broadcast(event) + } + + private fun User.addUserIdentity(identityProviderId: String) { + addIdentity(identityProviderId) + userRepository.update(this) } + + class Request( + val email: String, + val identityProviderId: String, + ) + + private fun Request.toUser(): User = User( + email = email, + nickname = "user_${randomUUID()}", + identities = mutableListOf(identityProviderId) + ) } fun User.toUserCreatedEvent(): UserCreatedEvent = UserCreatedEvent(id!!, email, nickname) diff --git a/domain/src/main/kotlin/tw/waterballsa/gaas/domain/User.kt b/domain/src/main/kotlin/tw/waterballsa/gaas/domain/User.kt index c74de0b0..b8f02636 100644 --- a/domain/src/main/kotlin/tw/waterballsa/gaas/domain/User.kt +++ b/domain/src/main/kotlin/tw/waterballsa/gaas/domain/User.kt @@ -1,11 +1,19 @@ package tw.waterballsa.gaas.domain -import tw.waterballsa.gaas.domain.Room.Player class User( val id: Id? = null, - val email: String, - var nickname: String = "", + val email: String = "", + val nickname: String = "", + val identities: MutableList = mutableListOf() ) { @JvmInline value class Id(val value: String) + + fun hasIdentity(identityProviderId: String): Boolean { + return identities.contains(identityProviderId) + } + + fun addIdentity(identityProviderId: String) { + identities.add(identityProviderId) + } } diff --git a/spring/src/main/kotlin/tw/waterballsa/gaas/spring/configs/securities/CustomSuccessHandler.kt b/spring/src/main/kotlin/tw/waterballsa/gaas/spring/configs/securities/CustomSuccessHandler.kt index e6be1aaf..c07d5187 100644 --- a/spring/src/main/kotlin/tw/waterballsa/gaas/spring/configs/securities/CustomSuccessHandler.kt +++ b/spring/src/main/kotlin/tw/waterballsa/gaas/spring/configs/securities/CustomSuccessHandler.kt @@ -25,8 +25,10 @@ class CustomSuccessHandler( ) { authentication as OAuth2AuthenticationToken - val email = authentication.principal.let { it as OidcUser }.email - createUserUseCase.execute(CreateUserUseCase.Request(email)) + val oidcUser = authentication.principal.let { it as OidcUser } + createUserUseCase.execute( + CreateUserUseCase.Request(oidcUser.email, oidcUser.subject) + ) val accessTokenValue = authorizedClientService.loadAuthorizedClient( authentication.authorizedClientRegistrationId, diff --git a/spring/src/main/kotlin/tw/waterballsa/gaas/spring/controllers/OAuth2Controller.kt b/spring/src/main/kotlin/tw/waterballsa/gaas/spring/controllers/OAuth2Controller.kt index 61cb4307..6bb02183 100644 --- a/spring/src/main/kotlin/tw/waterballsa/gaas/spring/controllers/OAuth2Controller.kt +++ b/spring/src/main/kotlin/tw/waterballsa/gaas/spring/controllers/OAuth2Controller.kt @@ -7,6 +7,7 @@ import org.springframework.http.ResponseEntity import org.springframework.security.core.annotation.AuthenticationPrincipal import org.springframework.security.oauth2.core.oidc.OidcIdToken import org.springframework.security.oauth2.core.oidc.user.OidcUser +import org.springframework.security.oauth2.jwt.Jwt import org.springframework.web.bind.annotation.GetMapping import org.springframework.web.bind.annotation.RestController import org.springframework.web.util.UriComponentsBuilder @@ -23,9 +24,9 @@ class OAuth2Controller( private lateinit var frontendUrl: String @GetMapping - fun home(@AuthenticationPrincipal principal: OidcUser?): String { + fun home(@AuthenticationPrincipal principal: Jwt): String { createUserUseCase.execute(principal.toRequest()) - return principal?.idToken?.tokenValue ?: "index" + return principal.tokenValue ?: "index" } @GetMapping("/login-successfully") @@ -43,6 +44,8 @@ class OAuth2Controller( } } -private fun OidcUser?.toRequest(): CreateUserUseCase.Request = CreateUserUseCase.Request( - this?.userInfo?.email ?: throw PlatformException("User email is null") -) +fun Jwt.toRequest(): CreateUserUseCase.Request = + CreateUserUseCase.Request( + email = claims["email"] as String? ?: throw PlatformException("JWT email should exist."), + identityProviderId = subject ?: throw PlatformException("JWT subject should exist.") + ) diff --git a/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/SpringUserRepository.kt b/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/SpringUserRepository.kt index 235098e3..d24c50ca 100644 --- a/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/SpringUserRepository.kt +++ b/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/SpringUserRepository.kt @@ -25,4 +25,10 @@ class SpringUserRepository( override fun findAllById(ids: Collection): List = userDAO.findAllById(ids.map(User.Id::value)).map(UserData::toDomain) + + override fun findByEmail(email: String): User? = + userDAO.findByEmail(email)?.toDomain() + + override fun update(user: User): User = + userDAO.save(user.toData()).toDomain() } diff --git a/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/dao/UserDAO.kt b/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/dao/UserDAO.kt index 21ba0392..05165457 100644 --- a/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/dao/UserDAO.kt +++ b/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/dao/UserDAO.kt @@ -7,4 +7,5 @@ import tw.waterballsa.gaas.spring.repositories.data.UserData @Repository interface UserDAO : MongoRepository { fun existsByEmail(email: String): Boolean + fun findByEmail(email: String): UserData? } diff --git a/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/data/UserData.kt b/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/data/UserData.kt index f1fb411a..b40dfc56 100644 --- a/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/data/UserData.kt +++ b/spring/src/main/kotlin/tw/waterballsa/gaas/spring/repositories/data/UserData.kt @@ -8,15 +8,17 @@ import tw.waterballsa.gaas.domain.User class UserData( @Id var id: String? = null, - private var email: String? = null, - var nickname: String? = null + val email: String = "", + val nickname: String = "", + val identities: List = emptyList(), ) { fun toDomain(): User = User( User.Id(id!!), - email!!, - nickname!! + email, + nickname, + identities.toMutableList() ) } @@ -24,5 +26,6 @@ fun User.toData(): UserData = UserData( id = id?.value, email = email, - nickname = nickname + nickname = nickname, + identities = identities ) diff --git a/spring/src/test/kotlin/tw/waterballsa/gaas/spring/it/AbstractSpringBootTest.kt b/spring/src/test/kotlin/tw/waterballsa/gaas/spring/it/AbstractSpringBootTest.kt index 47b18f51..d2982417 100644 --- a/spring/src/test/kotlin/tw/waterballsa/gaas/spring/it/AbstractSpringBootTest.kt +++ b/spring/src/test/kotlin/tw/waterballsa/gaas/spring/it/AbstractSpringBootTest.kt @@ -6,9 +6,12 @@ import org.springframework.beans.factory.annotation.Autowired import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc import org.springframework.boot.test.context.SpringBootTest import org.springframework.http.MediaType.APPLICATION_JSON +import org.springframework.security.oauth2.jwt.Jwt +import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt import org.springframework.test.web.servlet.MockMvc import org.springframework.test.web.servlet.ResultActions import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder +import tw.waterballsa.gaas.domain.User @SpringBootTest @AutoConfigureMockMvc @@ -19,6 +22,20 @@ abstract class AbstractSpringBootTest { @Autowired private lateinit var objectMapper: ObjectMapper + protected final val mockUser: User = User( + User.Id("1"), + "user@example.com", + "user-437b200d-da9c-449e-b147-114b4822b5aa", + mutableListOf("google-oauth2|102527320242660434908") + ) + + protected final fun String.toJwt(): Jwt = + Jwt.withTokenValue("mock-token") + .header("alg", "none") + .subject(this) + .claim("email", mockUser.email) + .build() + protected fun ResultActions.getBody(type: Class): T = andReturn().response.contentAsString.let { objectMapper.readValue(it, type) } @@ -29,4 +46,8 @@ abstract class AbstractSpringBootTest { protected fun MockHttpServletRequestBuilder.withJson(request: Any): MockHttpServletRequestBuilder = contentType(APPLICATION_JSON).content(request.toJson()) + + protected fun MockHttpServletRequestBuilder.withJwt(jwt: Jwt): MockHttpServletRequestBuilder = + with(jwt().jwt(jwt)) + } diff --git a/spring/src/test/kotlin/tw/waterballsa/gaas/spring/it/controllers/OAuth2ControllerTest.kt b/spring/src/test/kotlin/tw/waterballsa/gaas/spring/it/controllers/OAuth2ControllerTest.kt index e7acfaa2..0525f459 100644 --- a/spring/src/test/kotlin/tw/waterballsa/gaas/spring/it/controllers/OAuth2ControllerTest.kt +++ b/spring/src/test/kotlin/tw/waterballsa/gaas/spring/it/controllers/OAuth2ControllerTest.kt @@ -4,11 +4,7 @@ import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.springframework.beans.factory.annotation.Autowired -import org.springframework.security.oauth2.core.oidc.OidcIdToken -import org.springframework.security.oauth2.core.oidc.OidcUserInfo -import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser -import org.springframework.security.oauth2.core.oidc.user.OidcUser -import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oidcLogin +import org.springframework.security.oauth2.jwt.Jwt import org.springframework.test.web.servlet.ResultActions import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status @@ -20,63 +16,88 @@ class OAuth2ControllerTest @Autowired constructor( val userRepository: UserRepository, ) : AbstractSpringBootTest() { + private final val googleIdentityProviderId = "google-oauth2|102527320242660434908" + private final val discordIdentityProviderId = "discord|102527320242660434908" + + private final val googleOAuth2Jwt = googleIdentityProviderId.toJwt() + private final val discordOAuth2Jwt = discordIdentityProviderId.toJwt() + + private final val invalidJwt = Jwt( + "invalid_token", + null, + null, + mapOf("alg" to "none"), + mapOf("no_email" to "none") + ) + @BeforeEach fun cleanUp() { userRepository.deleteAll() } @Test - fun givenInvalidUserInfo_whenUserLogin_thenShouldLoginFailed() { - givenInvalidUserInfo() - .whenLogin() + fun whenUserLoginWithInvalidJwt_thenShouldLoginFailed() { + whenUserLogin(invalidJwt) .thenShouldLoginFailed() } @Test - fun givenNewUser_whenUserLogin_thenCreateUser() { - givenNewUserInfo().assertUserNotExists() - .whenLogin() - .thenShouldLoginSuccessfully() + fun givenUserHasLoggedInViaGoogle_whenUserLoginWithGoogleOAuth2Jwt_thenLoginSuccessfully() { + givenUserHasLoggedInViaGoogle() + whenUserLogin(googleOAuth2Jwt) + .thenLoginSuccessfully() } @Test - fun givenOldUser_whenUserLogin_thenShouldLoginSuccessfully() { - givenOldUserInfo().assertUserExists() - .whenLogin() - .thenShouldLoginSuccessfully() + fun givenUserHasLoggedInViaGoogle_whenUserLoginWithDiscordOAuth2Jwt_thenUserHaveNewIdentity() { + givenUserHasLoggedInViaGoogle() + whenUserLogin(discordOAuth2Jwt) + .thenUserHaveNewIdentity(googleIdentityProviderId, discordIdentityProviderId) } - private fun givenInvalidUserInfo(): OidcUser = givenUserInfo(null) + @Test + fun whenUserLoginAtTheFirstTime_thenCreateNewUser() { + whenUserLogin(googleOAuth2Jwt) + .thenCreateNewUser() + } - private fun givenNewUserInfo(): OidcUser = givenUserInfo(OidcUserInfo(mapOf("email" to "user@example.com"))) + private fun givenUserHasLoggedInViaGoogle(): User = + userRepository.createUser(mockUser) - private fun givenOldUserInfo(): OidcUser { - val userInfo = givenUserInfo(OidcUserInfo(mapOf("email" to "other@example.com"))) - userRepository.createUser(User(email = userInfo.email)) - return userInfo - } + private fun whenUserLogin(jwt: Jwt): ResultActions = + mockMvc.perform(get("/").withJwt(jwt)) - private fun givenUserInfo(oidcUserInfo: OidcUserInfo?): OidcUser = DefaultOidcUser( - listOf(), - OidcIdToken("token", null, null, mapOf("sub" to "my_sub")), - oidcUserInfo - ) + private fun ResultActions.thenShouldLoginFailed() { + andExpect(status().isBadRequest) + } - private fun OidcUser.assertUserExists(): OidcUser = this.also { - assertThat(userRepository.existsUserByEmail(userInfo.email)).isTrue() + private fun ResultActions.thenLoginSuccessfully() { + andExpect(status().isOk) } - private fun OidcUser.assertUserNotExists(): OidcUser = this.also { - assertThat(userRepository.existsUserByEmail(userInfo.email)).isFalse() + private fun ResultActions.thenUserHaveNewIdentity(vararg identityProviderIds: String) { + thenLoginSuccessfully() + userRepository.findByEmail(mockUser.email) + ?.thenWouldHaveIdentityProviderIds(*identityProviderIds) } - private fun OidcUser.whenLogin(): ResultActions = - mockMvc.perform(get("/").with(oidcLogin().oidcUser(this))) + private fun ResultActions.thenCreateNewUser() { + thenLoginSuccessfully() + userRepository.findByEmail(mockUser.email) + .thenNickNameShouldBeRandomName() + .thenWouldHaveIdentityProviderIds(googleIdentityProviderId) + } - private fun ResultActions.thenShouldLoginSuccessfully(): ResultActions = - andExpect(status().isOk) + private fun User?.thenWouldHaveIdentityProviderIds(vararg identityProviderIds: String): User { + assertThat(this).isNotNull + assertThat(this!!.identities).containsAll(identityProviderIds.toList()) + return this + } - private fun ResultActions.thenShouldLoginFailed(): ResultActions = - andExpect(status().isBadRequest) + private fun User?.thenNickNameShouldBeRandomName(): User { + assertThat(this).isNotNull + assertThat(this!!.nickname).startsWith("user_") + return this + } }