From b5abc84cdf60d6de17be1269b4893c147481f08a Mon Sep 17 00:00:00 2001 From: Houston Putman Date: Wed, 13 Sep 2023 14:02:37 -0400 Subject: [PATCH] SOLR-16951: Cache client pkiAuth headers for a second --- solr/CHANGES.txt | 1 + .../security/PKIAuthenticationPlugin.java | 69 +++++++++++++++---- 2 files changed, 57 insertions(+), 13 deletions(-) diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index 40aead39789..29c0ed4e7af 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -135,6 +135,7 @@ Optimizations * SOLR-16265: reduce memory usage of ContentWriter based requests in Http2SolrClient (Alex Deparvu, Kevin Risden, David Smiley) +* SOLR-16951: Cache client PKIAuth headers, regenerating every second. This speeds up request sending for the Http2SolrClient. (Tomás Fernández Löbbe, Houston Putman) Bug Fixes --------------------- diff --git a/solr/core/src/java/org/apache/solr/security/PKIAuthenticationPlugin.java b/solr/core/src/java/org/apache/solr/security/PKIAuthenticationPlugin.java index cfd4436dac7..1c09d61e93b 100644 --- a/solr/core/src/java/org/apache/solr/security/PKIAuthenticationPlugin.java +++ b/solr/core/src/java/org/apache/solr/security/PKIAuthenticationPlugin.java @@ -26,6 +26,7 @@ import java.security.Principal; import java.security.PublicKey; import java.security.SignatureException; +import java.time.Duration; import java.time.Instant; import java.util.Base64; import java.util.List; @@ -33,6 +34,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -92,7 +94,7 @@ public static void withServerIdentity(final boolean enabled) { private final Map keyCache = new ConcurrentHashMap<>(); private final PublicKeyHandler publicKeyHandler; private final CoreContainer cores; - private static final int MAX_VALIDITY = Integer.getInteger("pkiauth.ttl", 5000); + private static final int MAX_VALIDITY = Integer.getInteger("pkiauth.ttl", 10000); private final String myNodeName; private final HttpHeaderClientInterceptor interceptor = new HttpHeaderClientInterceptor(); private boolean interceptorRegistered = false; @@ -403,12 +405,12 @@ public void onBegin(Request request) { final Optional preFetchedUser = getUserFromJettyRequest(request); if ("v1".equals(System.getProperty(SEND_VERSION))) { preFetchedUser - .map(PKIAuthenticationPlugin.this::generateToken) - .ifPresent(token -> request.header(HEADER, token)); + .map(PKIAuthenticationPlugin.this::getToken) + .ifPresent(token -> request.headers(mutable -> mutable.add(HEADER, token))); } else { preFetchedUser - .map(PKIAuthenticationPlugin.this::generateTokenV2) - .ifPresent(token -> request.header(HEADER_V2, token)); + .map(PKIAuthenticationPlugin.this::getTokenV2) + .ifPresent(token -> request.headers(mutable -> mutable.add(HEADER_V2, token))); } } @@ -485,34 +487,75 @@ private Optional getUser() { } } + private static class CachedToken { + Instant generatedAt; + String token; + + private CachedToken(Instant generatedAt, String token) { + this.generatedAt = generatedAt; + this.token = token; + } + } + + private volatile ConcurrentHashMap> cachedV1Tokens = + new ConcurrentHashMap<>(); + private volatile ConcurrentHashMap> cachedV2Tokens = + new ConcurrentHashMap<>(); + + private static final Duration cacheExpiryTime = Duration.ofSeconds(1); + + private String getToken(String usr) { + AtomicReference tokenRef = + cachedV1Tokens.computeIfAbsent(usr, u -> new AtomicReference<>(generateToken(u))); + if (tokenRef.get().generatedAt.isBefore(Instant.now().minus(cacheExpiryTime))) { + synchronized (tokenRef) { + if (tokenRef.get().generatedAt.isBefore(Instant.now().minus(cacheExpiryTime))) { + tokenRef.set(generateToken(usr)); + } + } + } + return tokenRef.get().token; + } + + private synchronized String getTokenV2(String usr) { + AtomicReference tokenRef = + cachedV2Tokens.computeIfAbsent(usr, u -> new AtomicReference<>(generateTokenV2(u))); + if (tokenRef.get().generatedAt.isBefore(Instant.now().minus(cacheExpiryTime))) { + synchronized (tokenRef) { + if (tokenRef.get().generatedAt.isBefore(Instant.now().minus(cacheExpiryTime))) { + tokenRef.set(generateTokenV2(usr)); + } + } + } + return tokenRef.get().token; + } + @SuppressForbidden(reason = "Needs currentTimeMillis to set current time in header") - private String generateToken(String usr) { + private CachedToken generateToken(String usr) { assert usr != null; String s = usr + " " + System.currentTimeMillis(); byte[] payload = s.getBytes(UTF_8); byte[] payloadCipher = publicKeyHandler.getKeyPair().encrypt(ByteBuffer.wrap(payload)); String base64Cipher = Base64.getEncoder().encodeToString(payloadCipher); log.trace("generateToken: usr={} token={}", usr, base64Cipher); - return myNodeName + " " + base64Cipher; + return new CachedToken(Instant.now(), myNodeName + " " + base64Cipher); } - private String generateTokenV2(String user) { + private CachedToken generateTokenV2(String user) { assert user != null; String s = myNodeName + " " + user + " " + Instant.now().toEpochMilli(); byte[] payload = s.getBytes(UTF_8); byte[] signature = publicKeyHandler.getKeyPair().signSha256(payload); String base64Signature = Base64.getEncoder().encodeToString(signature); - return s + " " + base64Signature; + return new CachedToken(Instant.now(), s + " " + base64Signature); } void setHeader(HttpRequest httpRequest) { if ("v1".equals(System.getProperty(SEND_VERSION))) { - getUser().map(this::generateToken).ifPresent(token -> httpRequest.setHeader(HEADER, token)); + getUser().map(this::getToken).ifPresent(token -> httpRequest.setHeader(HEADER, token)); } else { - getUser() - .map(this::generateTokenV2) - .ifPresent(token -> httpRequest.setHeader(HEADER_V2, token)); + getUser().map(this::getTokenV2).ifPresent(token -> httpRequest.setHeader(HEADER_V2, token)); } }