Skip to content

Commit

Permalink
feat: save multiple identity provider ids for same email
Browse files Browse the repository at this point in the history
  • Loading branch information
lohas1107 committed Jun 23, 2023
1 parent 9915515 commit df6f0c2
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ import tw.waterballsa.gaas.domain.User.Id

interface UserRepository {
fun findById(id: Id): User?
fun existsByIdentitiesIn(identityProviderId: String): Boolean
fun existsUserByEmail(email: String): Boolean
fun createUser(user: User): User
fun deleteAll()
fun findAllById(ids: Collection<Id>): List<User>
fun findByEmail(email: String): User?
fun update(user: User): User
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,27 @@ class CreateUserUseCase(
private val userRepository: UserRepository,
private val eventBus: EventBus,
) {
fun execute(request: Request) = when {
userRepository.existsByIdentitiesIn(request.identityProviderId) -> {}
else -> {
val user = userRepository.createUser(request.toUser())
val event = user.toUserCreatedEvent()
eventBus.broadcast(event)
fun execute(request: Request) {
var user = userRepository.findByEmail(request.email)

when {
user == null -> {
user = userRepository.createUser(request.toUser())
val event = user.toUserCreatedEvent()
eventBus.broadcast(event)
}
user.doesNotHaveIdentity(request.identityProviderId) -> {
user = user.addIdentity(request.identityProviderId)
userRepository.update(user)
}
}
}

class Request(val identityProviderId: String) {
fun toUser(): User = User(identities = listOf(identityProviderId))
class Request(
val email: String,
val identityProviderId: String
) {
fun toUser(): User = User(email = email, identities = listOf(identityProviderId))
}
}

Expand Down
11 changes: 11 additions & 0 deletions domain/src/main/kotlin/tw/waterballsa/gaas/domain/User.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,15 @@ class User(
) {
@JvmInline
value class Id(val value: String)

fun hasIdentity(identityProviderId: String): Boolean {
return identities.contains(identityProviderId)
}

fun doesNotHaveIdentity(identityProviderId: String): Boolean {
return !hasIdentity(identityProviderId)
}

fun addIdentity(identityProviderId: String): User =
User(id, email, nickname, identities + identityProviderId)
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@ import org.springframework.security.core.Authentication
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient
import org.springframework.security.oauth2.client.OAuth2AuthorizedClientService
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken
import org.springframework.security.oauth2.core.oidc.user.OidcUser
import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.security.web.authentication.AuthenticationSuccessHandler
import tw.waterballsa.gaas.application.usecases.CreateUserUseCase
import tw.waterballsa.gaas.spring.controllers.toRequest
import javax.servlet.http.HttpServletRequest
import javax.servlet.http.HttpServletResponse

Expand All @@ -25,8 +26,8 @@ class CustomSuccessHandler(
) {
authentication as OAuth2AuthenticationToken

val email = authentication.principal.let { it as OidcUser }.email
createUserUseCase.execute(CreateUserUseCase.Request(email))
val jwt = authentication.principal.let { it as Jwt }
createUserUseCase.execute(jwt.toRequest())

val accessTokenValue = authorizedClientService.loadAuthorizedClient<OAuth2AuthorizedClient>(
authentication.authorizedClientRegistrationId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class OAuth2Controller(
}
}

private fun Jwt.toRequest(): CreateUserUseCase.Request =
fun Jwt.toRequest(): CreateUserUseCase.Request =
CreateUserUseCase.Request(
this.subject ?: throw PlatformException("JWT subject is null")
email = claims["email"] as String? ?: throw PlatformException("JWT email should exist."),
identityProviderId = subject ?: throw PlatformException("JWT subject should exist.")
)
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ class SpringUserRepository(
override fun findById(id: User.Id): User? =
userDAO.findById(id.value).mapOrNull(UserData::toDomain)

override fun existsByIdentitiesIn(identityProviderId: String): Boolean =
userDAO.existsByIdentitiesIn(mutableListOf(mutableListOf(identityProviderId)))

override fun existsUserByEmail(email: String): Boolean = userDAO.existsByEmail(email)

override fun createUser(user: User): User = userDAO.save(user.toData()).toDomain()
Expand All @@ -28,4 +25,10 @@ class SpringUserRepository(

override fun findAllById(ids: Collection<User.Id>): List<User> =
userDAO.findAllById(ids.map(User.Id::value)).map(UserData::toDomain)

override fun findByEmail(email: String): User? =
userDAO.findByEmail(email)?.toDomain()

override fun update(user: User): User =
userDAO.save(user.toData()).toDomain()
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ import tw.waterballsa.gaas.spring.repositories.data.UserData
@Repository
interface UserDAO : MongoRepository<UserData, String> {
fun existsByEmail(email: String): Boolean
fun existsByIdentitiesIn(identities: MutableCollection<MutableList<String>>): Boolean
fun findByEmail(email: String): UserData?
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,17 @@ import tw.waterballsa.gaas.domain.User
class UserData(
@Id
var id: String? = null,
private var email: String? = null,
var email: String? = null,
var nickname: String? = null,
var identities: List<String> = emptyList()
var identities: List<String>? = null
) {

fun toDomain(): User =
User(
User.Id(id!!),
email!!,
nickname!!,
identities
identities!!
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.http.MediaType.APPLICATION_JSON
import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.ResultActions
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder
import tw.waterballsa.gaas.domain.User

@SpringBootTest
@AutoConfigureMockMvc
Expand All @@ -19,6 +21,8 @@ abstract class AbstractSpringBootTest {
@Autowired
private lateinit var objectMapper: ObjectMapper

protected val mockUser: User = mockDefaultUser()

protected fun <T> ResultActions.getBody(type: Class<T>): T =
andReturn().response.contentAsString.let { objectMapper.readValue(it, type) }

Expand All @@ -29,4 +33,18 @@ abstract class AbstractSpringBootTest {

protected fun MockHttpServletRequestBuilder.withJson(request: Any): MockHttpServletRequestBuilder =
contentType(APPLICATION_JSON).content(request.toJson())

protected fun mockDefaultUser(): User =
User(User.Id("1"),
"[email protected]",
"user-437b200d-da9c-449e-b147-114b4822b5aa",
listOf("google-oauth2|102527320242660434908")
)

protected fun mockJwt(subject: String, email: String): Jwt =
Jwt.withTokenValue("mock-token")
.header("alg", "none")
.subject(subject)
.claim("email", email)
.build()
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt
import org.springframework.test.web.servlet.ResultActions
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status
import tw.waterballsa.gaas.application.repositories.UserRepository
import tw.waterballsa.gaas.domain.User
Expand All @@ -17,63 +17,92 @@ class OAuth2ControllerTest @Autowired constructor(
val userRepository: UserRepository,
) : AbstractSpringBootTest() {

val email = "[email protected]"
val googleIdentityProviderId = "google-oauth2|102527320242660434908"
val discordIdentityProviderId = "discord|102527320242660434908"

@BeforeEach
fun cleanUp() {
userRepository.deleteAll()
}

@Test
fun givenInvalidJwtSubject_whenUserLogin_thenShouldLoginFailed() {
givenInvalidJwtSubject()
?.whenUserLogin()
?.thenShouldLoginFailed()
givenInvalidJwt()
.whenUserLogin()
.thenShouldLoginFailed()
}

@Test
fun givenNewUserWithJwtSubject_whenUserLogin_thenCreateUser() {
givenNewUserWithJwtSubject()
fun givenOldEmail_andOldIdentityProviderId_whenUserLogin_thenLoginSuccessfully() {
givenGoogleOAuth2Jwt()
.whenUserLogin()
.thenCreateUser()
.thenLoginSuccessfully()
}

@Test
fun givenOldUser_whenUserLogin_thenShouldLoginSuccessfully() {
givenOldUserWithJwtSubject()
fun givenOldEmail_andNewIdentityProviderId_whenUserLogin_thenSaveNewIdentityProviderId() {
givenDiscordOAuth2Jwt()
.whenUserLogin()
.thenLoginSuccessfully()
.thenSaveNewIdentityProviderId()
}

@Test
fun givenNewEmail_andNewIdentityProviderId_whenUserLogin_thenCreateNewUser() {
givenJwt(googleIdentityProviderId, email)
.whenUserLogin()
.thenCreateNewUser()
}

private fun givenInvalidJwtSubject(): String? = null
private fun givenInvalidJwt(): Jwt =
Jwt("invalid_token",
null,
null,
mapOf("alg" to "none"),
mapOf("no_email" to "none"))

private fun givenNewUserWithJwtSubject(): String {
val subject = "google-oauth2|102527320242660434908"
assertThat(userRepository.existsByIdentitiesIn(subject)).isFalse()
return subject
private fun givenGoogleOAuth2Jwt(): Jwt {
userRepository.createUser(mockUser)
return givenJwt(googleIdentityProviderId, mockUser.email)
}

private fun givenOldUserWithJwtSubject(): String {
val subject = givenNewUserWithJwtSubject()
userRepository.createUser(User(identities = listOf(subject)))
assertThat(userRepository.existsByIdentitiesIn(subject)).isTrue()
return subject
private fun givenDiscordOAuth2Jwt(): Jwt {
userRepository.createUser(mockUser)
return givenJwt(discordIdentityProviderId, mockUser.email)
}

private fun String.whenUserLogin(): ResultActions =
mockMvc.perform(get("/").with(jwt().jwt { it.subject(this) }))
private fun givenJwt(identityProviderId: String, email: String): Jwt =
mockJwt(identityProviderId, email)

private fun Jwt.whenUserLogin(): ResultActions =
mockMvc.perform(get("/").with(jwt().jwt(this)))

private fun ResultActions.thenShouldLoginFailed() {
this.andExpect(status().isBadRequest)
.andExpect(jsonPath("$").value("JWT subject is null"))
}

private fun ResultActions.thenLoginSuccessfully() {
this.andExpect(status().isOk)
}

private fun ResultActions.thenCreateUser() {
private fun ResultActions.thenSaveNewIdentityProviderId() {
thenLoginSuccessfully()
val subject = "google-oauth2|102527320242660434908"
assertThat(userRepository.existsByIdentitiesIn(subject)).isTrue()
userRepository.findByEmail(email)!!
.thenSaveIdentityProviderId(googleIdentityProviderId)
.thenSaveIdentityProviderId(discordIdentityProviderId)
}

private fun ResultActions.thenCreateNewUser() {
thenLoginSuccessfully()
userRepository.findByEmail(email)!!
.thenSaveIdentityProviderId(googleIdentityProviderId)
}

private fun User.thenSaveIdentityProviderId(identityProviderId: String): User {
assertThat(this).isNotNull
assertThat(identities).isNotEmpty
assertThat(identities).contains(identityProviderId)
return this
}

}

0 comments on commit df6f0c2

Please sign in to comment.