Skip to content

Commit

Permalink
Update settings loading and parsing code
Browse files Browse the repository at this point in the history
  • Loading branch information
rtyley committed Aug 8, 2024
1 parent 59c7132 commit d565dd6
Show file tree
Hide file tree
Showing 16 changed files with 252 additions and 199 deletions.
Original file line number Diff line number Diff line change
@@ -1,59 +1,36 @@
package com.gu.pandomainauth

import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}

import com.amazonaws.services.s3.AmazonS3
import com.gu.pandomainauth.model.PanDomainAuthSettings
import com.gu.pandomainauth.service.CryptoConf
import org.slf4j.LoggerFactory

import java.util.concurrent.{Executors, ScheduledExecutorService}
import scala.language.postfixOps

/**
* PanDomainAuthSettingsRefresher will periodically refresh the pan domain settings and expose them via the "settings" method
*
* @param domain the domain you are authenticating against
* @param system the identifier for your app, typically the same as the subdomain your app runs on
* @param bucketName the bucket where the settings are stored
* @param settingsFileKey the name of the file that contains the private settings for the given domain
* @param s3Client the AWS S3 client that will be used to download the settings from the bucket
* @param scheduler optional scheduler that will be used to run the code that updates the bucket
*/
class PanDomainAuthSettingsRefresher(
val domain: String,
val system: String,
val bucketName: String,
settingsFileKey: String,
val s3Client: AmazonS3,
val s3BucketLoader: S3BucketLoader,
scheduler: ScheduledExecutorService = Executors.newScheduledThreadPool(1)
) {
private val logger = LoggerFactory.getLogger(this.getClass)

// This is deliberately designed to throw an exception during construction if we cannot immediately read the settings
private val authSettings: AtomicReference[PanDomainAuthSettings] = new AtomicReference[PanDomainAuthSettings](loadSettings() match {
case Right(settings) => PanDomainAuthSettings(settings)
case Left(err) => throw Settings.errorToThrowable(err)
})

scheduler.scheduleAtFixedRate(() => refresh(), 1, 1, TimeUnit.MINUTES)

def settings: PanDomainAuthSettings = authSettings.get()

private def loadSettings(): Either[SettingsFailure, Map[String, String]] = {
Settings.fetchSettings(settingsFileKey, bucketName, s3Client).flatMap(Settings.extractSettings)
}

private def refresh(): Unit = {
loadSettings() match {
case Right(settings) =>
logger.debug(s"Updated pan-domain settings for $domain")
authSettings.set(PanDomainAuthSettings(settings))
private val settingsRefresher = new Settings.Refresher[PanDomainAuthSettings](
new Settings.Loader(s3BucketLoader, settingsFileKey),
PanDomainAuthSettings.apply,
scheduler
)
settingsRefresher.start(1)

case Left(err) =>
logger.error(s"Failed to update pan-domain settings for $domain")
Settings.logError(err, logger)
}
}
def settings: PanDomainAuthSettings = settingsRefresher.get()
}


Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
package com.gu.pandomainauth.model

import com.gu.pandomainauth.service.Crypto

import java.security.KeyPair
import com.gu.pandomainauth.SettingsFailure.SettingsResult
import com.gu.pandomainauth.service.{CryptoConf, KeyPair}

case class PanDomainAuthSettings(
signingKeyPair: KeyPair,
Expand Down Expand Up @@ -32,7 +31,7 @@ case class Google2FAGroupSettings(
object PanDomainAuthSettings{
private val legacyCookieNameSetting = "assymCookieName"

def apply(settingMap: Map[String, String]): PanDomainAuthSettings = {
def apply(settingMap: Map[String, String]): SettingsResult[PanDomainAuthSettings] = {
val cookieSettings = CookieSettings(
cookieName = settingMap.getOrElse(legacyCookieNameSetting, settingMap("cookieName"))
)
Expand All @@ -49,12 +48,12 @@ object PanDomainAuthSettings{
serviceAccountCert <- settingMap.get("googleServiceAccountCert");
adminUser <- settingMap.get("google2faUser");
group <- settingMap.get("multifactorGroupId")
) yield {
Google2FAGroupSettings(serviceAccountId, serviceAccountCert, adminUser, group)
}
) yield Google2FAGroupSettings(serviceAccountId, serviceAccountCert, adminUser, group)

PanDomainAuthSettings(
Crypto.keyPairFrom(settingMap),
for {
activeKeyPair <- CryptoConf.SettingsReader(settingMap).activeKeyPair
} yield PanDomainAuthSettings(
activeKeyPair,
cookieSettings,
oAuthSettings,
google2faSettings
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
package com.gu.pandomainauth.service

import com.amazonaws.services.s3.AmazonS3
import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport
import com.google.api.client.googleapis.json.GoogleJsonResponseException
import com.google.api.client.json.gson.GsonFactory
import com.google.api.client.util.SecurityUtils
import com.google.api.services.directory.Directory
import com.google.api.services.directory.model.Groups
import com.google.api.services.directory.DirectoryScopes
import com.google.api.services.directory.{Directory, DirectoryScopes}
import com.google.auth.http.HttpCredentialsAdapter
import com.google.auth.oauth2.ServiceAccountCredentials

import scala.jdk.CollectionConverters._
import com.gu.pandomainauth.S3BucketLoader
import com.gu.pandomainauth.model.{AuthenticatedUser, Google2FAGroupSettings}
import org.slf4j.LoggerFactory

class GroupChecker(config: Google2FAGroupSettings, bucketName: String, s3Client: AmazonS3, appName: String) {
import scala.jdk.CollectionConverters._

class GroupChecker(config: Google2FAGroupSettings, s3BucketLoader: S3BucketLoader, appName: String) {
private val logger = LoggerFactory.getLogger(this.getClass)

private val transport = GoogleNetHttpTransport.newTrustedTransport()
Expand All @@ -36,14 +35,13 @@ class GroupChecker(config: Google2FAGroupSettings, bucketName: String, s3Client:
.build

private def loadServiceAccountPrivateKey = {
val certInputStream = s3Client.getObject(bucketName, config.serviceAccountCert).getObjectContent
val serviceAccountPrivateKey = SecurityUtils.loadPrivateKeyFromKeyStore(
SecurityUtils.getPkcs12KeyStore,
certInputStream,
s3BucketLoader.inputStreamFetching(config.serviceAccountCert),
"notasecret", "privatekey", "notasecret"
)

try { certInputStream.close() } catch { case _ : Throwable => }
try { s3BucketLoader.inputStreamFetching(config.serviceAccountCert).close() } catch { case _ : Throwable => }

serviceAccountPrivateKey
}
Expand Down Expand Up @@ -72,11 +70,11 @@ class GroupChecker(config: Google2FAGroupSettings, bucketName: String, s3Client:

private def hasMoreGroups(groupsResponse: Groups): Boolean = {
val token = groupsResponse.getNextPageToken
token != null && token.length > 0
token != null && token.nonEmpty
}
}

class GoogleGroupChecker(config: Google2FAGroupSettings, bucketName: String, s3Client: AmazonS3, appName: String) extends GroupChecker(config, bucketName, s3Client, appName) {
class GoogleGroupChecker(config: Google2FAGroupSettings, s3BucketLoader: S3BucketLoader, appName: String) extends GroupChecker(config, s3BucketLoader, appName) {

def checkGroups(authenticatedUser: AuthenticatedUser, groupIds: List[String]): Either[String, Boolean] = {
val query = directory.groups().list().setUserKey(authenticatedUser.user.email)
Expand All @@ -86,10 +84,9 @@ class GoogleGroupChecker(config: Google2FAGroupSettings, bucketName: String, s3C

}

class Google2FAGroupChecker(config: Google2FAGroupSettings, bucketName: String, s3Client: AmazonS3, appName: String) extends GroupChecker(config, bucketName, s3Client, appName) {
class Google2FAGroupChecker(config: Google2FAGroupSettings, s3BucketLoader: S3BucketLoader, appName: String) extends GroupChecker(config, s3BucketLoader, appName) {

def checkMultifactor(authenticatedUser: AuthenticatedUser): Boolean = {
def checkMultifactor(authenticatedUser: AuthenticatedUser): Boolean =
hasGroup(authenticatedUser.user.email, config.multifactorGroupId)
}

}
12 changes: 5 additions & 7 deletions pan-domain-auth-example/app/VerifyExample.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.amazonaws.regions.Regions
import com.amazonaws.services.s3.AmazonS3ClientBuilder
import com.gu.pandomainauth.S3BucketLoader.forAwsSdkV1
import com.gu.pandomainauth.model.{Authenticated, AuthenticatedUser, GracePeriod}
import com.gu.pandomainauth.{PanDomain, PublicSettings}
import com.gu.pandomainauth.{PanDomain, PublicSettings, Settings}

object VerifyExample {
// Change this to point to the S3 bucket and key for the settings file
Expand All @@ -14,16 +15,13 @@ object VerifyExample {
val credentials = DefaultAWSCredentialsProviderChain.getInstance()
val s3Client = AmazonS3ClientBuilder.standard().withRegion(region).withCredentials(credentials).build()

val publicSettings = new PublicSettings(settingsFileKey, bucketName, s3Client)
val loader = new Settings.Loader(forAwsSdkV1(s3Client, bucketName), settingsFileKey)
val publicSettings = new PublicSettings(loader)

// Call the start method when your application starts up to ensure the settings are kept up to date
publicSettings.start()

// You can integrate with your own scheduler by calling refresh() which will synchronously update the settings
publicSettings.refresh()

// `publicKey` will return None if a value has not been successfully obtained
val publicKey = publicSettings.publicKey.get
val publicKey = publicSettings.publicKey

// The name of this particular application
val system = "test"
Expand Down
4 changes: 2 additions & 2 deletions pan-domain-auth-example/app/di.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import com.amazonaws.auth.{AWSCredentialsProviderChain, DefaultAWSCredentialsPro
import com.amazonaws.regions.Regions
import com.amazonaws.services.s3.AmazonS3ClientBuilder
import com.gu.pandomainauth.PanDomainAuthSettingsRefresher
import com.gu.pandomainauth.S3BucketLoader.forAwsSdkV1
import controllers.AdminController
import play.api.ApplicationLoader.Context
import play.api.libs.ws.ahc.AhcWSComponents
Expand Down Expand Up @@ -37,9 +38,8 @@ class AppComponents(context: Context) extends BuiltInComponentsFromContext(conte
val panDomainSettings = new PanDomainAuthSettingsRefresher(
domain = "local.dev-gutools.co.uk",
system = "example",
bucketName = bucketName,
settingsFileKey = "local.dev-gutools.co.uk.settings",
s3Client = s3Client
s3BucketLoader = forAwsSdkV1(s3Client, bucketName)
)

val controller = new AdminController(controllerComponents, configuration, wsClient, panDomainSettings)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ trait AuthActions {
val applicationName: String = s"pan-domain-authentication-$system"

val multifactorChecker: Option[Google2FAGroupChecker] = settings.google2FAGroupSettings.map {
new Google2FAGroupChecker(_, panDomainSettings.bucketName, panDomainSettings.s3Client, applicationName)
new Google2FAGroupChecker(_, panDomainSettings.s3BucketLoader, applicationName)
}

/**
Expand Down Expand Up @@ -198,7 +198,7 @@ trait AuthActions {
}

def readAuthenticatedUser(request: RequestHeader): Option[AuthenticatedUser] = readCookie(request) map { cookie =>
CookieUtils.parseCookieData(cookie.cookie.value, settings.signingKeyPair.getPublic)
CookieUtils.parseCookieData(cookie.cookie.value, settings.signingKeyPair.publicKey)
}

def readCookie(request: RequestHeader): Option[PandomainCookie] = {
Expand All @@ -211,7 +211,7 @@ trait AuthActions {
def generateCookie(authedUser: AuthenticatedUser): Cookie =
Cookie(
name = settings.cookieSettings.cookieName,
value = CookieUtils.generateCookieData(authedUser, settings.signingKeyPair.getPrivate),
value = CookieUtils.generateCookieData(authedUser, settings.signingKeyPair.privateKey),
domain = Some(domain),
secure = true,
httpOnly = true
Expand All @@ -237,7 +237,7 @@ trait AuthActions {
*/
def extractAuth(request: RequestHeader): AuthenticationStatus = {
readCookie(request).map { cookie =>
PanDomain.authStatus(cookie.cookie.value, settings.signingKeyPair.getPublic, validateUser, apiGracePeriod, system, cacheValidation, cookie.forceExpiry)
PanDomain.authStatus(cookie.cookie.value, settings.signingKeyPair.publicKey, validateUser, apiGracePeriod, system, cacheValidation, cookie.forceExpiry)
} getOrElse NotAuthenticated
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,51 +1,30 @@
package com.gu.pandomainauth

import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}
import com.amazonaws.services.s3.AmazonS3
import com.gu.pandomainauth.PublicSettings.validateAndParseKeyText
import com.gu.pandomainauth.service.Crypto
import org.slf4j.LoggerFactory
import com.gu.pandomainauth.SettingsFailure.SettingsResult
import com.gu.pandomainauth.service.CryptoConf

import java.security.PublicKey
import java.util.regex.Pattern
import scala.concurrent.ExecutionContext
import java.util.concurrent.{Executors, ScheduledExecutorService}
import scala.concurrent.duration._

/**
* Class that contains the static public settings and includes mechanism for fetching the public key. Once you have an
* instance, call the `start()` method to load the public data.
*
* @param settingsFileKey the settings file for the domain in the S3 bucket (eg local.dev.gutools.co.uk.public.settings)
* @param bucketName the name of the S3 bucket (eg pan-domain-auth-settings)
* @param s3Client the AWS S3 client that will be used to download the settings from the bucket
* @param scheduler optional scheduler that will be used to run the code that updates the bucket
*
* @param scheduler optional scheduler that will be used to run the code that updates the bucket
*/
class PublicSettings(settingsFileKey: String, bucketName: String, s3Client: AmazonS3,
class PublicSettings(loader: Settings.Loader,
scheduler: ScheduledExecutorService = Executors.newScheduledThreadPool(1)) {

private val agent = new AtomicReference[Option[PublicKey]](None)
private val settingsRefresher = new Settings.Refresher[PublicKey](
loader,
CryptoConf.SettingsReader(_).activePublicKey,
scheduler
)

private val logger = LoggerFactory.getLogger(this.getClass)
implicit private val executionContext: ExecutionContext = ExecutionContext.fromExecutor(scheduler)
def start(interval: FiniteDuration = 60.seconds): Unit = settingsRefresher.start(interval.toMinutes.toInt)

def start(interval: FiniteDuration = 60.seconds): Unit = {
scheduler.scheduleAtFixedRate(() => refresh(), 0, interval.toMillis, TimeUnit.MILLISECONDS)
}

def refresh(): Unit = {
PublicSettings.getPublicKey(settingsFileKey, bucketName, s3Client) match {
case Right(publicKey) =>
agent.set(Some(publicKey))
logger.debug("Successfully updated pan-domain public settings")

case Left(err) =>
logger.error("Failed to update pan-domain public settings")
Settings.logError(err, logger)
}
}

def publicKey: Option[PublicKey] = agent.get()
def publicKey: PublicKey = settingsRefresher.get()
}

/**
Expand All @@ -59,18 +38,7 @@ object PublicSettings {
* Fetches the public key from the public S3 bucket
*
* @param domain the domain to fetch the public key for
* @param client implicit dispatch.Http to use for fetching the key
* @param ec implicit execution context to use for fetching the key
*/
def getPublicKey(settingsFileKey: String, bucketName: String, s3Client: AmazonS3): Either[SettingsFailure, PublicKey] = {
fetchSettings(settingsFileKey, bucketName, s3Client) flatMap extractSettings flatMap extractPublicKey
}

private[pandomainauth] def extractPublicKey(settings: Map[String, String]): Either[SettingsFailure, PublicKey] =
settings.get("publicKey").toRight(PublicKeyNotFoundFailure).flatMap(validateAndParseKeyText)

private val KeyPattern: Pattern = "[a-zA-Z0-9+/\n]+={0,3}".r.pattern

private[pandomainauth] def validateAndParseKeyText(pubKeyText: String): Either[SettingsFailure, PublicKey] =
Either.cond(KeyPattern.matcher(pubKeyText).matches, Crypto.publicKeyFor(pubKeyText), PublicKeyFormatFailure)
def getPublicKey(loader: Loader): SettingsResult[PublicKey] =
loader.loadAndParseSettingsMap().flatMap(CryptoConf.SettingsReader(_).activePublicKey)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package com.gu.pandomainauth

import com.amazonaws.services.s3.AmazonS3

import java.io.InputStream

/**
* This trait provides a way to download a file from an S3 bucket, in a way that's agnostic of which
* AWS SDK (v1 or v2) is being used. An instance of S3BucketLoader is *specific* to a particular S3 bucket.
*/
trait S3BucketLoader {
/**
* @param key the key of the file in the S3 bucket, not including the bucket name or a starting "/"
*/
def inputStreamFetching(key: String): InputStream
}

object S3BucketLoader {
/**
* A convenience method to create an S3BucketLoader using AWS SDK v1, the version used by most of our existing code.
* However, codebases that want to use AWS SDK v2 are able to provide their own implementation of S3BucketLoader.
*/
def forAwsSdkV1(s3Client: AmazonS3, bucketName: String): S3BucketLoader =
s3Client.getObject(bucketName, _).getObjectContent
}
Loading

0 comments on commit d565dd6

Please sign in to comment.