Skip to content

Commit

Permalink
Feature/save identity provider id (#99)
Browse files Browse the repository at this point in the history
* feat: save identity provider id when first login

* feat: save multiple identity provider ids for same email

* fix: oauth2 client get OidcUser

* feat: generate random nickname

* refactor: coding style
  • Loading branch information
lohas1107 authored Jun 25, 2023
1 parent 52ccd33 commit 33a4006
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ interface UserRepository {
fun createUser(user: User): User
fun deleteAll()
fun findAllById(ids: Collection<Id>): List<User>
fun findByEmail(email: String): User?
fun update(user: User): User
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,46 @@ import tw.waterballsa.gaas.application.eventbus.EventBus
import tw.waterballsa.gaas.application.repositories.UserRepository
import tw.waterballsa.gaas.domain.User
import tw.waterballsa.gaas.events.UserCreatedEvent
import java.util.UUID.randomUUID
import javax.inject.Named

@Named
class CreateUserUseCase(
private val userRepository: UserRepository,
private val eventBus: EventBus,
) {
fun execute(request: Request) = when {
userRepository.existsUserByEmail(request.email) -> {}
else -> {
val user = userRepository.createUser(request.toUser())
val event = user.toUserCreatedEvent()
eventBus.broadcast(event)
fun execute(request: Request) {
val user = userRepository.findByEmail(request.email)

with (request) {
when {
user == null -> createUser()
!user.hasIdentity(identityProviderId) -> user.addUserIdentity(identityProviderId)
}
}
}

class Request(val email: String) {
fun toUser(): User = User(email = email)
private fun Request.createUser() {
val user = userRepository.createUser(toUser())
val event = user.toUserCreatedEvent()
eventBus.broadcast(event)
}

private fun User.addUserIdentity(identityProviderId: String) {
addIdentity(identityProviderId)
userRepository.update(this)
}

class Request(
val email: String,
val identityProviderId: String,
)

private fun Request.toUser(): User = User(
email = email,
nickname = "user_${randomUUID()}",
identities = mutableListOf(identityProviderId)
)
}

fun User.toUserCreatedEvent(): UserCreatedEvent = UserCreatedEvent(id!!, email, nickname)
14 changes: 11 additions & 3 deletions domain/src/main/kotlin/tw/waterballsa/gaas/domain/User.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
package tw.waterballsa.gaas.domain
import tw.waterballsa.gaas.domain.Room.Player

class User(
val id: Id? = null,
val email: String,
var nickname: String = "",
val email: String = "",
val nickname: String = "",
val identities: MutableList<String> = mutableListOf()
) {
@JvmInline
value class Id(val value: String)

fun hasIdentity(identityProviderId: String): Boolean {
return identities.contains(identityProviderId)
}

fun addIdentity(identityProviderId: String) {
identities.add(identityProviderId)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ class CustomSuccessHandler(
) {
authentication as OAuth2AuthenticationToken

val email = authentication.principal.let { it as OidcUser }.email
createUserUseCase.execute(CreateUserUseCase.Request(email))
val oidcUser = authentication.principal.let { it as OidcUser }
createUserUseCase.execute(
CreateUserUseCase.Request(oidcUser.email, oidcUser.subject)
)

val accessTokenValue = authorizedClientService.loadAuthorizedClient<OAuth2AuthorizedClient>(
authentication.authorizedClientRegistrationId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.springframework.http.ResponseEntity
import org.springframework.security.core.annotation.AuthenticationPrincipal
import org.springframework.security.oauth2.core.oidc.OidcIdToken
import org.springframework.security.oauth2.core.oidc.user.OidcUser
import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RestController
import org.springframework.web.util.UriComponentsBuilder
Expand All @@ -23,9 +24,9 @@ class OAuth2Controller(
private lateinit var frontendUrl: String

@GetMapping
fun home(@AuthenticationPrincipal principal: OidcUser?): String {
fun home(@AuthenticationPrincipal principal: Jwt): String {
createUserUseCase.execute(principal.toRequest())
return principal?.idToken?.tokenValue ?: "index"
return principal.tokenValue ?: "index"
}

@GetMapping("/login-successfully")
Expand All @@ -43,6 +44,8 @@ class OAuth2Controller(
}
}

private fun OidcUser?.toRequest(): CreateUserUseCase.Request = CreateUserUseCase.Request(
this?.userInfo?.email ?: throw PlatformException("User email is null")
)
fun Jwt.toRequest(): CreateUserUseCase.Request =
CreateUserUseCase.Request(
email = claims["email"] as String? ?: throw PlatformException("JWT email should exist."),
identityProviderId = subject ?: throw PlatformException("JWT subject should exist.")
)
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,10 @@ class SpringUserRepository(

override fun findAllById(ids: Collection<User.Id>): List<User> =
userDAO.findAllById(ids.map(User.Id::value)).map(UserData::toDomain)

override fun findByEmail(email: String): User? =
userDAO.findByEmail(email)?.toDomain()

override fun update(user: User): User =
userDAO.save(user.toData()).toDomain()
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ import tw.waterballsa.gaas.spring.repositories.data.UserData
@Repository
interface UserDAO : MongoRepository<UserData, String> {
fun existsByEmail(email: String): Boolean
fun findByEmail(email: String): UserData?
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,24 @@ import tw.waterballsa.gaas.domain.User
class UserData(
@Id
var id: String? = null,
private var email: String? = null,
var nickname: String? = null
val email: String = "",
val nickname: String = "",
val identities: List<String> = emptyList(),
) {

fun toDomain(): User =
User(
User.Id(id!!),
email!!,
nickname!!
email,
nickname,
identities.toMutableList()
)
}

fun User.toData(): UserData =
UserData(
id = id?.value,
email = email,
nickname = nickname
nickname = nickname,
identities = identities
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.http.MediaType.APPLICATION_JSON
import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt
import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.ResultActions
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder
import tw.waterballsa.gaas.domain.User

@SpringBootTest
@AutoConfigureMockMvc
Expand All @@ -19,6 +22,20 @@ abstract class AbstractSpringBootTest {
@Autowired
private lateinit var objectMapper: ObjectMapper

protected final val mockUser: User = User(
User.Id("1"),
"[email protected]",
"user-437b200d-da9c-449e-b147-114b4822b5aa",
mutableListOf("google-oauth2|102527320242660434908")
)

protected final fun String.toJwt(): Jwt =
Jwt.withTokenValue("mock-token")
.header("alg", "none")
.subject(this)
.claim("email", mockUser.email)
.build()

protected fun <T> ResultActions.getBody(type: Class<T>): T =
andReturn().response.contentAsString.let { objectMapper.readValue(it, type) }

Expand All @@ -29,4 +46,8 @@ abstract class AbstractSpringBootTest {

protected fun MockHttpServletRequestBuilder.withJson(request: Any): MockHttpServletRequestBuilder =
contentType(APPLICATION_JSON).content(request.toJson())

protected fun MockHttpServletRequestBuilder.withJwt(jwt: Jwt): MockHttpServletRequestBuilder =
with(jwt().jwt(jwt))

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@ import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.security.oauth2.core.oidc.OidcIdToken
import org.springframework.security.oauth2.core.oidc.OidcUserInfo
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser
import org.springframework.security.oauth2.core.oidc.user.OidcUser
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oidcLogin
import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.test.web.servlet.ResultActions
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status
Expand All @@ -20,63 +16,88 @@ class OAuth2ControllerTest @Autowired constructor(
val userRepository: UserRepository,
) : AbstractSpringBootTest() {

private final val googleIdentityProviderId = "google-oauth2|102527320242660434908"
private final val discordIdentityProviderId = "discord|102527320242660434908"

private final val googleOAuth2Jwt = googleIdentityProviderId.toJwt()
private final val discordOAuth2Jwt = discordIdentityProviderId.toJwt()

private final val invalidJwt = Jwt(
"invalid_token",
null,
null,
mapOf("alg" to "none"),
mapOf("no_email" to "none")
)

@BeforeEach
fun cleanUp() {
userRepository.deleteAll()
}

@Test
fun givenInvalidUserInfo_whenUserLogin_thenShouldLoginFailed() {
givenInvalidUserInfo()
.whenLogin()
fun whenUserLoginWithInvalidJwt_thenShouldLoginFailed() {
whenUserLogin(invalidJwt)
.thenShouldLoginFailed()
}

@Test
fun givenNewUser_whenUserLogin_thenCreateUser() {
givenNewUserInfo().assertUserNotExists()
.whenLogin()
.thenShouldLoginSuccessfully()
fun givenUserHasLoggedInViaGoogle_whenUserLoginWithGoogleOAuth2Jwt_thenLoginSuccessfully() {
givenUserHasLoggedInViaGoogle()
whenUserLogin(googleOAuth2Jwt)
.thenLoginSuccessfully()
}

@Test
fun givenOldUser_whenUserLogin_thenShouldLoginSuccessfully() {
givenOldUserInfo().assertUserExists()
.whenLogin()
.thenShouldLoginSuccessfully()
fun givenUserHasLoggedInViaGoogle_whenUserLoginWithDiscordOAuth2Jwt_thenUserHaveNewIdentity() {
givenUserHasLoggedInViaGoogle()
whenUserLogin(discordOAuth2Jwt)
.thenUserHaveNewIdentity(googleIdentityProviderId, discordIdentityProviderId)
}

private fun givenInvalidUserInfo(): OidcUser = givenUserInfo(null)
@Test
fun whenUserLoginAtTheFirstTime_thenCreateNewUser() {
whenUserLogin(googleOAuth2Jwt)
.thenCreateNewUser()
}

private fun givenNewUserInfo(): OidcUser = givenUserInfo(OidcUserInfo(mapOf("email" to "[email protected]")))
private fun givenUserHasLoggedInViaGoogle(): User =
userRepository.createUser(mockUser)

private fun givenOldUserInfo(): OidcUser {
val userInfo = givenUserInfo(OidcUserInfo(mapOf("email" to "[email protected]")))
userRepository.createUser(User(email = userInfo.email))
return userInfo
}
private fun whenUserLogin(jwt: Jwt): ResultActions =
mockMvc.perform(get("/").withJwt(jwt))

private fun givenUserInfo(oidcUserInfo: OidcUserInfo?): OidcUser = DefaultOidcUser(
listOf(),
OidcIdToken("token", null, null, mapOf("sub" to "my_sub")),
oidcUserInfo
)
private fun ResultActions.thenShouldLoginFailed() {
andExpect(status().isBadRequest)
}

private fun OidcUser.assertUserExists(): OidcUser = this.also {
assertThat(userRepository.existsUserByEmail(userInfo.email)).isTrue()
private fun ResultActions.thenLoginSuccessfully() {
andExpect(status().isOk)
}

private fun OidcUser.assertUserNotExists(): OidcUser = this.also {
assertThat(userRepository.existsUserByEmail(userInfo.email)).isFalse()
private fun ResultActions.thenUserHaveNewIdentity(vararg identityProviderIds: String) {
thenLoginSuccessfully()
userRepository.findByEmail(mockUser.email)
?.thenWouldHaveIdentityProviderIds(*identityProviderIds)
}

private fun OidcUser.whenLogin(): ResultActions =
mockMvc.perform(get("/").with(oidcLogin().oidcUser(this)))
private fun ResultActions.thenCreateNewUser() {
thenLoginSuccessfully()
userRepository.findByEmail(mockUser.email)
.thenNickNameShouldBeRandomName()
.thenWouldHaveIdentityProviderIds(googleIdentityProviderId)
}

private fun ResultActions.thenShouldLoginSuccessfully(): ResultActions =
andExpect(status().isOk)
private fun User?.thenWouldHaveIdentityProviderIds(vararg identityProviderIds: String): User {
assertThat(this).isNotNull
assertThat(this!!.identities).containsAll(identityProviderIds.toList())
return this
}

private fun ResultActions.thenShouldLoginFailed(): ResultActions =
andExpect(status().isBadRequest)
private fun User?.thenNickNameShouldBeRandomName(): User {
assertThat(this).isNotNull
assertThat(this!!.nickname).startsWith("user_")
return this
}

}

0 comments on commit 33a4006

Please sign in to comment.