Skip to content

Commit

Permalink
fix: use user identity to find user (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
kuoche1712003 authored Jul 25, 2023
1 parent 50591b8 commit 065f026
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ class GetUserUseCase(
) {
fun execute(request: Request, presenter: Presenter) {
with(request) {
val user = userRepository.findByEmail(email)
?: throw notFound(User::class).identifyBy("email", email)
val user = userRepository.findByIdentity(userIdentity)
?: throw notFound(User::class).identifyBy("userIdentity", userIdentity)
presenter.present(user)
}
}

class Request(val email: String)
class Request(val userIdentity: String)

interface Presenter {
fun present(user: User)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class UpdateUserUseCase(
fun execute(request: Request, presenter: Presenter) {
with(request) {
validateNicknameDuplicated(nickname)
val user = findUserByEmail(email)
val user = findUserByIdentity(userIdentity)
user.changeNickname(nickname)
val updatedUser = userRepository.update(user)

Expand All @@ -32,11 +32,11 @@ class UpdateUserUseCase(
}
}

private fun findUserByEmail(email: String) =
userRepository.findByEmail(email)
?: throw notFound(User::class).identifyBy("email", email)
private fun findUserByIdentity(userIdentity: String) =
userRepository.findByIdentity(userIdentity)
?: throw notFound(User::class).identifyBy("userIdentity", userIdentity)

data class Request(val email: String, val nickname: String)
data class Request(val userIdentity: String, val nickname: String)
}

private fun User.toUserUpdatedEvent(): UserUpdatedEvent =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ import org.springframework.http.HttpStatus.*
import org.springframework.http.ResponseEntity
import org.springframework.security.core.annotation.AuthenticationPrincipal
import org.springframework.security.oauth2.client.*
import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient
import org.springframework.security.oauth2.core.oidc.user.OidcUser
import org.springframework.web.bind.annotation.*
import org.springframework.web.context.request.NativeWebRequest
import tw.waterballsa.gaas.application.usecases.CreateUserUseCase
Expand All @@ -20,9 +21,12 @@ class OAuth2Controller(
) {

@GetMapping
fun home(@AuthenticationPrincipal principal: Jwt): String {
fun home(
@AuthenticationPrincipal principal: OidcUser,
@RegisteredOAuth2AuthorizedClient client: OAuth2AuthorizedClient
): String {
createUserUseCase.execute(principal.toRequest())
return principal.tokenValue ?: "index"
return client.accessToken.tokenValue ?: "index"
}

@PostMapping("/authenticate")
Expand All @@ -47,13 +51,9 @@ data class AuthenticateToken(
val token: String
)

val Jwt.email: String
get() = claims["email"]?.toString()
?: throw PlatformException("JWT email should exist.")

val Jwt.identityProviderId: String
val OidcUser.identityProviderId: String
get() = subject
?: throw PlatformException("JWT subject should exist.")
?: throw PlatformException("subject should exist.")

private fun Jwt.toRequest(): CreateUserUseCase.Request =
CreateUserUseCase.Request(email, identityProviderId)
private fun OidcUser.toRequest(): CreateUserUseCase.Request =
CreateUserUseCase.Request(email ?: throw PlatformException("email should exist."), identityProviderId)
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ class UserController(
@AuthenticationPrincipal principal: Jwt,
@RequestBody updateUserRequest: UpdateUserRequest,
): UpdateUserViewModel {
val request = updateUserRequest.toRequest(principal.email)
val request = updateUserRequest.toRequest(principal.subject)
val presenter = UpdateUserPresenter()
updateUserUseCase.execute(request, presenter)
return presenter.viewModel
}
}

private fun Jwt.toRequest(): GetUserUseCase.Request =
GetUserUseCase.Request(email)
GetUserUseCase.Request(subject)

data class UpdateUserRequest(val nickname: String) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@ abstract class AbstractSpringBootTest {
mutableListOf("google-oauth2|102527320242660434908")
)

protected final fun String.toJwt(): Jwt = generateJwt(this, mockUser.email)
protected final fun String.toJwt(): Jwt = generateJwt(this)

protected final fun User.toJwt(): Jwt = generateJwt(identities.first(), email)
protected final fun User.toJwt(): Jwt = generateJwt(identities.first())

private fun generateJwt(id: String, email: String): Jwt =
private fun generateJwt(id: String): Jwt =
Jwt.withTokenValue("mock-token")
.header("alg", "none")
.subject(id)
.claim("email", email)
.build()

protected fun <T> ResultActions.getBody(type: Class<T>): T =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.security.oauth2.core.oidc.OidcIdToken
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser
import org.springframework.security.oauth2.core.oidc.user.OidcUser
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oidcLogin
import org.springframework.test.web.servlet.ResultActions
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status
Expand All @@ -19,15 +22,15 @@ class OAuth2ControllerTest @Autowired constructor(
private final val googleIdentityProviderId = "google-oauth2|102527320242660434908"
private final val discordIdentityProviderId = "discord|102527320242660434908"

private final val googleOAuth2Jwt = googleIdentityProviderId.toJwt()
private final val discordOAuth2Jwt = discordIdentityProviderId.toJwt()
private final val googleOAuth2OidcUser = googleIdentityProviderId.toOidcUser()
private final val discordOAuth2OidcUser = discordIdentityProviderId.toOidcUser()

private final val invalidJwt = Jwt(
"invalid_token",
null,
null,
mapOf("alg" to "none"),
mapOf("no_email" to "none")
private final val invalidOidcUser = DefaultOidcUser(
emptyList(),
OidcIdToken.withTokenValue("oidc-token-value")
.subject("invalid-user")
.claim("no_email", "none")
.build(),
)

@BeforeEach
Expand All @@ -37,35 +40,35 @@ class OAuth2ControllerTest @Autowired constructor(

@Test
fun whenUserLoginWithInvalidJwt_thenShouldLoginFailed() {
whenUserLogin(invalidJwt)
whenUserLogin(invalidOidcUser)
.thenShouldLoginFailed()
}

@Test
fun givenUserHasLoggedInViaGoogle_whenUserLoginWithGoogleOAuth2Jwt_thenLoginSuccessfully() {
givenUserHasLoggedInViaGoogle()
whenUserLogin(googleOAuth2Jwt)
whenUserLogin(googleOAuth2OidcUser)
.thenLoginSuccessfully()
}

@Test
fun givenUserHasLoggedInViaGoogle_whenUserLoginWithDiscordOAuth2Jwt_thenUserHaveNewIdentity() {
givenUserHasLoggedInViaGoogle()
whenUserLogin(discordOAuth2Jwt)
whenUserLogin(discordOAuth2OidcUser)
.thenUserHaveNewIdentity(googleIdentityProviderId, discordIdentityProviderId)
}

@Test
fun whenUserLoginAtTheFirstTime_thenCreateNewUser() {
whenUserLogin(googleOAuth2Jwt)
whenUserLogin(googleOAuth2OidcUser)
.thenCreateNewUser()
}

private fun givenUserHasLoggedInViaGoogle(): User =
userRepository.createUser(mockUser)

private fun whenUserLogin(jwt: Jwt): ResultActions =
mockMvc.perform(get("/").withJwt(jwt))
private fun whenUserLogin(oidcUser: OidcUser): ResultActions =
mockMvc.perform(get("/").with(oidcLogin().oidcUser(oidcUser)))

private fun ResultActions.thenShouldLoginFailed() {
andExpect(status().isBadRequest)
Expand Down Expand Up @@ -100,4 +103,12 @@ class OAuth2ControllerTest @Autowired constructor(
return this
}

private fun String.toOidcUser(): OidcUser =
DefaultOidcUser(
emptyList(),
OidcIdToken.withTokenValue("oidc-token-value")
.subject(this)
.claim("email", mockUser.email)
.build(),
)
}

0 comments on commit 065f026

Please sign in to comment.