Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/save identity provider id #99

Merged
merged 11 commits into from
Jun 25, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ interface UserRepository {
fun createUser(user: User): User
fun deleteAll()
fun findAllById(ids: Collection<Id>): List<User>
fun findByEmail(email: String): User?
fun update(user: User): User
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,41 @@ import tw.waterballsa.gaas.application.eventbus.EventBus
import tw.waterballsa.gaas.application.repositories.UserRepository
import tw.waterballsa.gaas.domain.User
import tw.waterballsa.gaas.events.UserCreatedEvent
import java.util.UUID.randomUUID
import javax.inject.Named

@Named
class CreateUserUseCase(
private val userRepository: UserRepository,
private val eventBus: EventBus,
) {
fun execute(request: Request) = when {
userRepository.existsUserByEmail(request.email) -> {}
else -> {
val user = userRepository.createUser(request.toUser())
val event = user.toUserCreatedEvent()
eventBus.broadcast(event)
fun execute(request: Request) {
var user = userRepository.findByEmail(request.email)
kuoche1712003 marked this conversation as resolved.
Show resolved Hide resolved

when {
user == null -> {
user = userRepository.createUser(request.toUser())
kuoche1712003 marked this conversation as resolved.
Show resolved Hide resolved
val event = user.toUserCreatedEvent()
eventBus.broadcast(event)
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved
}

!user.hasIdentity(request.identityProviderId) -> {
user = user.addIdentity(request.identityProviderId)
userRepository.update(user)
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

class Request(val email: String) {
fun toUser(): User = User(email = email)
}
class Request(
val email: String,
val identityProviderId: String,
)

private fun Request.toUser(): User = User(
email = email,
nickname = "user_${randomUUID()}",
identities = listOf(identityProviderId)
)
}

fun User.toUserCreatedEvent(): UserCreatedEvent = UserCreatedEvent(id!!, email, nickname)
13 changes: 10 additions & 3 deletions domain/src/main/kotlin/tw/waterballsa/gaas/domain/User.kt
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
package tw.waterballsa.gaas.domain
import tw.waterballsa.gaas.domain.Room.Player

class User(
val id: Id? = null,
val email: String,
var nickname: String = "",
val email: String = "",
val nickname: String = "",
val identities: List<String> = emptyList(),
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved
) {
@JvmInline
value class Id(val value: String)

fun hasIdentity(identityProviderId: String): Boolean {
return identities.contains(identityProviderId)
}

fun addIdentity(identityProviderId: String): User =
User(id, email, nickname, identities + identityProviderId)
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ class CustomSuccessHandler(
) {
authentication as OAuth2AuthenticationToken

val email = authentication.principal.let { it as OidcUser }.email
createUserUseCase.execute(CreateUserUseCase.Request(email))
val oidcUser = authentication.principal.let { it as OidcUser }
createUserUseCase.execute(
CreateUserUseCase.Request(oidcUser.email, oidcUser.subject)
)

val accessTokenValue = authorizedClientService.loadAuthorizedClient<OAuth2AuthorizedClient>(
authentication.authorizedClientRegistrationId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import org.springframework.http.ResponseEntity
import org.springframework.security.core.annotation.AuthenticationPrincipal
import org.springframework.security.oauth2.core.oidc.OidcIdToken
import org.springframework.security.oauth2.core.oidc.user.OidcUser
import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RestController
import org.springframework.web.util.UriComponentsBuilder
Expand All @@ -23,9 +24,9 @@ class OAuth2Controller(
private lateinit var frontendUrl: String

@GetMapping
fun home(@AuthenticationPrincipal principal: OidcUser?): String {
fun home(@AuthenticationPrincipal principal: Jwt): String {
createUserUseCase.execute(principal.toRequest())
return principal?.idToken?.tokenValue ?: "index"
return principal.tokenValue ?: "index"
}

@GetMapping("/login-successfully")
Expand All @@ -43,6 +44,8 @@ class OAuth2Controller(
}
}

private fun OidcUser?.toRequest(): CreateUserUseCase.Request = CreateUserUseCase.Request(
this?.userInfo?.email ?: throw PlatformException("User email is null")
)
fun Jwt.toRequest(): CreateUserUseCase.Request =
CreateUserUseCase.Request(
email = claims["email"] as String? ?: throw PlatformException("JWT email should exist."),
identityProviderId = subject ?: throw PlatformException("JWT subject should exist.")
)
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,10 @@ class SpringUserRepository(

override fun findAllById(ids: Collection<User.Id>): List<User> =
userDAO.findAllById(ids.map(User.Id::value)).map(UserData::toDomain)

override fun findByEmail(email: String): User? =
userDAO.findByEmail(email)?.toDomain()

override fun update(user: User): User =
userDAO.save(user.toData()).toDomain()
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ import tw.waterballsa.gaas.spring.repositories.data.UserData
@Repository
interface UserDAO : MongoRepository<UserData, String> {
fun existsByEmail(email: String): Boolean
fun findByEmail(email: String): UserData?
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,24 @@ import tw.waterballsa.gaas.domain.User
class UserData(
@Id
var id: String? = null,
private var email: String? = null,
var nickname: String? = null
val email: String? = null,
val nickname: String? = null,
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved
val identities: List<String>? = null,
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved
) {

fun toDomain(): User =
User(
User.Id(id!!),
email!!,
nickname!!
nickname!!,
identities!!
)
}

fun User.toData(): UserData =
UserData(
id = id?.value,
email = email,
nickname = nickname
nickname = nickname,
identities = identities
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@ import org.springframework.beans.factory.annotation.Autowired
import org.springframework.boot.test.autoconfigure.web.servlet.AutoConfigureMockMvc
import org.springframework.boot.test.context.SpringBootTest
import org.springframework.http.MediaType.APPLICATION_JSON
import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.jwt
import org.springframework.test.web.servlet.MockMvc
import org.springframework.test.web.servlet.ResultActions
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder
import tw.waterballsa.gaas.domain.User

@SpringBootTest
@AutoConfigureMockMvc
Expand All @@ -19,6 +22,13 @@ abstract class AbstractSpringBootTest {
@Autowired
private lateinit var objectMapper: ObjectMapper

protected val mockUser: User = User(
User.Id("1"),
"[email protected]",
"user-437b200d-da9c-449e-b147-114b4822b5aa",
listOf("google-oauth2|102527320242660434908")
)

protected fun <T> ResultActions.getBody(type: Class<T>): T =
andReturn().response.contentAsString.let { objectMapper.readValue(it, type) }

Expand All @@ -29,4 +39,14 @@ abstract class AbstractSpringBootTest {

protected fun MockHttpServletRequestBuilder.withJson(request: Any): MockHttpServletRequestBuilder =
contentType(APPLICATION_JSON).content(request.toJson())

protected fun MockHttpServletRequestBuilder.withJwt(jwt: Jwt): MockHttpServletRequestBuilder =
with(jwt().jwt(jwt))

protected fun MockHttpServletRequestBuilder.withIdentityProviderId(identityProviderId: String): MockHttpServletRequestBuilder =
with(jwt().jwt {
it.subject(identityProviderId)
it.claim("email", mockUser.email)
})
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved

}
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@ import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.springframework.beans.factory.annotation.Autowired
import org.springframework.security.oauth2.core.oidc.OidcIdToken
import org.springframework.security.oauth2.core.oidc.OidcUserInfo
import org.springframework.security.oauth2.core.oidc.user.DefaultOidcUser
import org.springframework.security.oauth2.core.oidc.user.OidcUser
import org.springframework.security.test.web.servlet.request.SecurityMockMvcRequestPostProcessors.oidcLogin
import org.springframework.security.oauth2.jwt.Jwt
import org.springframework.test.web.servlet.ResultActions
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders.get
import org.springframework.test.web.servlet.result.MockMvcResultMatchers.status
Expand All @@ -20,63 +16,90 @@ class OAuth2ControllerTest @Autowired constructor(
val userRepository: UserRepository,
) : AbstractSpringBootTest() {

private final val googleIdentityProviderId = "google-oauth2|102527320242660434908"
private final val discordIdentityProviderId = "discord|102527320242660434908"

@BeforeEach
fun cleanUp() {
userRepository.deleteAll()
}

@Test
fun givenInvalidUserInfo_whenUserLogin_thenShouldLoginFailed() {
givenInvalidUserInfo()
.whenLogin()
fun whenUserLoginWithInvalidJwt_thenShouldLoginFailed() {
whenUserLoginWithInvalidJwt()
.thenShouldLoginFailed()
}

@Test
fun givenNewUser_whenUserLogin_thenCreateUser() {
givenNewUserInfo().assertUserNotExists()
.whenLogin()
.thenShouldLoginSuccessfully()
fun givenUserHasLoggedInViaGoogle_whenUserLoginWithGoogleOAuth2Jwt_thenLoginSuccessfully() {
givenUserHasLoggedInViaGoogle()
whenUserLogin(googleIdentityProviderId)
.thenLoginSuccessfully()
}

@Test
fun givenOldUser_whenUserLogin_thenShouldLoginSuccessfully() {
givenOldUserInfo().assertUserExists()
.whenLogin()
.thenShouldLoginSuccessfully()
fun givenUserHasLoggedInViaGoogle_whenUserLoginWithDiscordOAuth2Jwt_thenUserHaveNewIdentity() {
givenUserHasLoggedInViaGoogle()
whenUserLogin(discordIdentityProviderId)
.thenUserHaveNewIdentity()
}

private fun givenInvalidUserInfo(): OidcUser = givenUserInfo(null)

private fun givenNewUserInfo(): OidcUser = givenUserInfo(OidcUserInfo(mapOf("email" to "[email protected]")))
@Test
fun whenUserLoginWithNewIdentity_thenCreateNewUser() {
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved
whenUserLogin(googleIdentityProviderId)
.thenCreateNewUser()
}

private fun givenOldUserInfo(): OidcUser {
val userInfo = givenUserInfo(OidcUserInfo(mapOf("email" to "[email protected]")))
userRepository.createUser(User(email = userInfo.email))
return userInfo
private fun givenUserHasLoggedInViaGoogle(): User =
userRepository.createUser(mockUser)

private fun whenUserLoginWithInvalidJwt(): ResultActions {
val invalidJwt = Jwt(
"invalid_token",
null,
null,
mapOf("alg" to "none"),
mapOf("no_email" to "none")
)
return mockMvc.perform(get("/").withJwt(invalidJwt))
}

private fun givenUserInfo(oidcUserInfo: OidcUserInfo?): OidcUser = DefaultOidcUser(
listOf(),
OidcIdToken("token", null, null, mapOf("sub" to "my_sub")),
oidcUserInfo
)
private fun whenUserLogin(identityProviderId: String): ResultActions =
mockMvc.perform(get("/").withIdentityProviderId(identityProviderId))

private fun OidcUser.assertUserExists(): OidcUser = this.also {
assertThat(userRepository.existsUserByEmail(userInfo.email)).isTrue()
private fun ResultActions.thenShouldLoginFailed() {
this.andExpect(status().isBadRequest)
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved
}

private fun OidcUser.assertUserNotExists(): OidcUser = this.also {
assertThat(userRepository.existsUserByEmail(userInfo.email)).isFalse()
private fun ResultActions.thenLoginSuccessfully() {
this.andExpect(status().isOk)
}

private fun OidcUser.whenLogin(): ResultActions =
mockMvc.perform(get("/").with(oidcLogin().oidcUser(this)))
private fun ResultActions.thenUserHaveNewIdentity() {
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved
thenLoginSuccessfully()
userRepository.findByEmail(mockUser.email)!!
.thenSaveIdentityProviderId(googleIdentityProviderId)
.thenSaveIdentityProviderId(discordIdentityProviderId)
}

private fun ResultActions.thenShouldLoginSuccessfully(): ResultActions =
andExpect(status().isOk)
private fun ResultActions.thenCreateNewUser() {
thenLoginSuccessfully()
userRepository.findByEmail(mockUser.email)!!
.thenCreateNickname()
.thenSaveIdentityProviderId(googleIdentityProviderId)
}

private fun User.thenSaveIdentityProviderId(identityProviderId: String): User {
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved
assertThat(this).isNotNull
assertThat(identities).isNotEmpty
assertThat(identities).contains(identityProviderId)
return this
}
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved

private fun ResultActions.thenShouldLoginFailed(): ResultActions =
andExpect(status().isBadRequest)
private fun User.thenCreateNickname(): User {
Wally5077 marked this conversation as resolved.
Show resolved Hide resolved
assertThat(this).isNotNull
assertThat(nickname).startsWith("user_")
return this
}

}