Skip to content

Commit

Permalink
Add bearer authentication to feature/identity
Browse files Browse the repository at this point in the history
Signed-off-by: Stephen Crawford <[email protected]>
  • Loading branch information
stephen-crawford committed Dec 21, 2022
1 parent ec7f9b6 commit ac66566
Show file tree
Hide file tree
Showing 9 changed files with 352 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@ public interface Subject {
* throws SubjectNotFound
* throws SubjectDisabled
*/
void login(final AuthenticationToken token);
void login(final AuthenticationToken token) throws RuntimeException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Objects;

import org.apache.shiro.SecurityUtils;
import org.apache.shiro.session.Session;
import org.opensearch.authn.AuthenticationTokenHandler;
import org.opensearch.authn.tokens.AuthenticationToken;
import org.opensearch.authn.Subject;
Expand Down Expand Up @@ -63,10 +64,29 @@ public String toString() {
/**
* Logs the user in via authenticating the user against current Shiro realm
*/
public void login(AuthenticationToken authenticationToken) {
public void login(AuthenticationToken authenticationToken) throws RuntimeException {
org.apache.shiro.authc.AuthenticationToken authToken = AuthenticationTokenHandler.extractShiroAuthToken(authenticationToken);
// Login via shiro realm.
SecurityUtils.getSecurityManager().authenticate(authToken);
// shiroSubject.login(authToken);
ensureUserIsLoggedOut();
shiroSubject.login(authToken);
}

// Logout the user fully before continuing.
private void ensureUserIsLoggedOut() {
try {
// Get the user if one is logged in.
org.apache.shiro.subject.Subject currentUser = SecurityUtils.getSubject();
if (currentUser == null) return;

// Log the user out and kill their session if possible.
currentUser.logout();
Session session = currentUser.getSession(false);
if (session == null) return;

session.stop();
} catch (Exception e) {
// Ignore all errors, as we're trying to silently
// log the user out.
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ static JsonWebKey getDefaultJsonWebKey() {
return jwk;
}

static JsonWebKey getESJsonWebKey() {
JsonWebKey jwk = new JsonWebKey();
jwk.setKeyType(KeyType.OCTET);
jwk.setAlgorithm("ES256");
jwk.setPublicKeyUse(PublicKeyUse.SIGN);
String b64SigningKey = Base64.getEncoder().encodeToString("exchangeKey".getBytes(StandardCharsets.UTF_8));
jwk.setProperty("k", b64SigningKey);
return jwk;
}

public static String createJwt(Map<String, String> claims) {
JoseJwtProducer jwtProducer = new JoseJwtProducer();
jwtProducer.setSignatureProvider(JwsUtils.getSignatureProvider(getDefaultJsonWebKey()));
Expand Down Expand Up @@ -79,4 +89,106 @@ public static String createJwt(Map<String, String> claims) {

return encodedJwt;
}

public static String createEarlyJwt(Map<String, String> claims) {
JoseJwtProducer jwtProducer = new JoseJwtProducer();
jwtProducer.setSignatureProvider(JwsUtils.getSignatureProvider(getDefaultJsonWebKey()));
JwtClaims jwtClaims = new JwtClaims();
JwtToken jwt = new JwtToken(jwtClaims);

jwtClaims.setNotBefore(System.currentTimeMillis() / 1000 + 60 * 60 * 24 * 365); // Not valid until a year from creation
long expiryTime = System.currentTimeMillis() / 1000 + (60 * 60);
jwtClaims.setExpiryTime(expiryTime);

if (claims.containsKey("sub")) {
jwtClaims.setProperty("sub", claims.get("sub"));
} else {
jwtClaims.setProperty("sub", "example_subject");
}

String encodedJwt = jwtProducer.processJwt(jwt);

if (logger.isDebugEnabled()) {
logger.debug(
"Created JWT: "
+ encodedJwt
+ "\n"
+ jsonMapReaderWriter.toJson(jwt.getJwsHeaders())
+ "\n"
+ JwtUtils.claimsToJson(jwt.getClaims())
);
}

return encodedJwt;
}

public static String createExpiredJwt(Map<String, String> claims) {
JoseJwtProducer jwtProducer = new JoseJwtProducer();
jwtProducer.setSignatureProvider(JwsUtils.getSignatureProvider(getDefaultJsonWebKey()));
JwtClaims jwtClaims = new JwtClaims();
JwtToken jwt = new JwtToken(jwtClaims);

long expiryTime = System.currentTimeMillis() / 1000 - 1; // This means the token expired a second before it was made so should never
// be valid
jwtClaims.setExpiryTime(expiryTime);

if (claims.containsKey("sub")) {
jwtClaims.setProperty("sub", claims.get("sub"));
} else {
jwtClaims.setProperty("sub", "example_subject");
}

String encodedJwt = jwtProducer.processJwt(jwt);

if (logger.isDebugEnabled()) {
logger.debug(
"Created JWT: "
+ encodedJwt
+ "\n"
+ jsonMapReaderWriter.toJson(jwt.getJwsHeaders())
+ "\n"
+ JwtUtils.claimsToJson(jwt.getClaims())
);
}

return encodedJwt;
}

public static String createInvalidJwt(Map<String, String> claims) {
JoseJwtProducer jwtProducer = new JoseJwtProducer();
jwtProducer.setSignatureProvider(JwsUtils.getSignatureProvider(getESJsonWebKey()));
JwtClaims jwtClaims = new JwtClaims();
JwtToken jwt = new JwtToken(jwtClaims);

jwtClaims.setNotBefore(System.currentTimeMillis() / 1000);
long expiryTime = System.currentTimeMillis() / 1000 + (60 * 60);
jwtClaims.setExpiryTime(expiryTime);

if (claims.containsKey("sub")) {
jwtClaims.setProperty("sub", claims.get("sub"));
} else {
jwtClaims.setProperty("sub", "example_subject");
}

if (claims.containsKey("iat")) {
jwtClaims.setProperty("iat", claims.get("iat"));
} else {
jwtClaims.setProperty("iat", Instant.now().toString());
}

String encodedJwt = jwtProducer.processJwt(jwt);

if (logger.isDebugEnabled()) {
logger.debug(
"Created JWT: "
+ encodedJwt
+ "\n"
+ jsonMapReaderWriter.toJson(jwt.getJwsHeaders())
+ "\n"
+ JwtUtils.claimsToJson(jwt.getClaims())
);
}

return encodedJwt;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.authn.tokens;

public class BearerAuthToken extends HttpHeaderToken {

private String headerValue;

public BearerAuthToken(String headerValue) {
this.headerValue = headerValue;
}

@Override
public String getHeaderValue() {
return headerValue;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

public class AuthenticationTokenHandlerTests extends OpenSearchTestCase {

public void testShouldExtractBasicAuthTokenSuccessfully() {
public void testShouldExtractBasicAuthTokenSuccessfully() throws RuntimeException {

// The auth header that is part of the request
String authHeader = "Basic YWRtaW46YWRtaW4="; // admin:admin
Expand All @@ -31,7 +31,7 @@ public void testShouldExtractBasicAuthTokenSuccessfully() {
MatcherAssert.assertThat(new String(usernamePasswordToken.getPassword()), equalTo("admin"));
}

public void testShouldExtractBasicAuthTokenSuccessfully_twoSemiColonPassword() {
public void testShouldExtractBasicAuthTokenSuccessfully_twoSemiColonPassword() throws RuntimeException {

// The auth header that is part of the request
String authHeader = "Basic dGVzdDp0ZTpzdA=="; // test:te:st
Expand All @@ -45,7 +45,7 @@ public void testShouldExtractBasicAuthTokenSuccessfully_twoSemiColonPassword() {
MatcherAssert.assertThat(new String(usernamePasswordToken.getPassword()), equalTo("te:st"));
}

public void testShouldReturnNullWhenExtractingInvalidToken() {
public void testShouldReturnNullWhenExtractingInvalidToken() throws RuntimeException {
String authHeader = "Basic Nah";

AuthenticationToken authToken = new BasicAuthToken(authHeader);
Expand All @@ -55,10 +55,11 @@ public void testShouldReturnNullWhenExtractingInvalidToken() {
MatcherAssert.assertThat(usernamePasswordToken, nullValue());
}

public void testShouldReturnNullWhenExtractingNullToken() {
public void testShouldReturnNullWhenExtractingNullToken() throws RuntimeException {

org.apache.shiro.authc.AuthenticationToken shiroAuthToken = AuthenticationTokenHandler.extractShiroAuthToken(null);

MatcherAssert.assertThat(shiroAuthToken, nullValue());
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ public void testCreateJwtWithClaims() {
try {
JwtToken token = JwtVerifier.getVerifiedJwtToken(encodedToken);
assertTrue(token.getClaims().getClaim("sub").equals("testSubject"));
} catch (BadCredentialsException e) {
fail("Unexpected BadCredentialsException thrown");
} catch (RuntimeException e) {
fail("Unexpected RuntimeException thrown");
}
}
}
5 changes: 5 additions & 0 deletions sandbox/modules/identity/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ dependencies {

testImplementation project(path: ':modules:transport-netty4') // for http
testImplementation project(path: ':plugins:transport-nio') // for http


implementation('org.apache.cxf:cxf-rt-rs-security-jose:3.4.5') {
exclude(group: 'jakarta.activation', module: 'jakarta.activation-api')
}
}

//task integTest(type: RestIntegTestTask) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.authn.jwt.JwtVendor;
import org.opensearch.authn.tokens.AuthenticationToken;
import org.opensearch.authn.tokens.BasicAuthToken;
import org.opensearch.authn.tokens.BearerAuthToken;
import org.opensearch.authn.tokens.HttpHeaderToken;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.settings.Settings;
Expand All @@ -32,7 +33,6 @@
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;

Expand Down Expand Up @@ -80,19 +80,16 @@ private boolean checkAndAuthenticateRequest(RestRequest request, RestChannel cha
jwtClaims.put("sub", "subject");
jwtClaims.put("iat", Instant.now().toString());
String encodedJwt = JwtVendor.createJwt(jwtClaims);
String requestInfo = String.format(
Locale.ROOT,
"(nodeName=%s, requestId=%s, path=%s, jwtClaims=%s checkAndAuthenticateRequest)",
client.getLocalNodeId(),
request.getRequestId(),
request.getRequestId(),
jwtClaims
);
if (log.isDebugEnabled()) {
log.debug(requestInfo);
String logMsg = String.format(Locale.ROOT, "Created internal access token %s", encodedJwt);
log.debug("{} {}", requestInfo, logMsg);
}
String prefix = "(nodeName="
+ client.getLocalNodeId()
+ ", requestId="
+ request.getRequestId()
+ ", path="
+ request.path()
+ ", jwtClaims="
+ jwtClaims
+ " checkAndAuthenticateRequest)";
log.info(prefix + " Created internal access token " + encodedJwt);
threadContext.putHeader(ThreadContextConstants.OPENSEARCH_AUTHENTICATION_TOKEN_HEADER, encodedJwt);
}
return true;
Expand Down Expand Up @@ -126,6 +123,8 @@ private boolean authenticate(RestRequest request, RestChannel channel) throws IO
} catch (final AuthenticationException ae) {
log.info("Authentication finally failed: {}", ae.getMessage());
return false;
} catch (RuntimeException e) {
throw new RuntimeException(e);
}
}

Expand Down Expand Up @@ -155,6 +154,7 @@ private boolean authenticate(RestRequest request, RestChannel channel) throws IO
*/
static AuthenticationToken tokenType(String authHeader) {
if (authHeader.contains("Basic")) return new BasicAuthToken(authHeader);
if (authHeader.contains("Bearer")) return new BearerAuthToken(authHeader);
// support other type of header tokens
return null;
}
Expand Down
Loading

0 comments on commit ac66566

Please sign in to comment.