Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Security/Extension] Role encryption/decryption #2620

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.security.authtoken.jwt;

import java.util.Arrays;
import java.util.Base64;

import javax.crypto.Cipher;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

public class EncryptionDecryptionUtil {

public static String encrypt(final String secret, final String data) {

byte[] decodedKey = Base64.getDecoder().decode(secret);

try {
Cipher cipher = Cipher.getInstance("AES");
// rebuild key using SecretKeySpec
SecretKey originalKey = new SecretKeySpec(Arrays.copyOf(decodedKey, 16), "AES");
cipher.init(Cipher.ENCRYPT_MODE, originalKey);
byte[] cipherText = cipher.doFinal(data.getBytes("UTF-8"));
return Base64.getEncoder().encodeToString(cipherText);
} catch (Exception e) {
throw new RuntimeException(
"Error occured while encrypting data", e);
}
}

public static String decrypt(final String secret, final String encryptedString) {

byte[] decodedKey = Base64.getDecoder().decode(secret);

try {
Cipher cipher = Cipher.getInstance("AES");
// rebuild key using SecretKeySpec
SecretKey originalKey = new SecretKeySpec(Arrays.copyOf(decodedKey, 16), "AES");
cipher.init(Cipher.DECRYPT_MODE, originalKey);
byte[] cipherText = cipher.doFinal(Base64.getDecoder().decode(encryptedString));
return new String(cipherText);
} catch (Exception e) {
throw new RuntimeException("Error occured while decrypting data", e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import java.time.Instant;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.LongSupplier;
Expand Down Expand Up @@ -43,6 +44,7 @@ public class JwtVendor {

private static JsonMapObjectReaderWriter jsonMapReaderWriter = new JsonMapObjectReaderWriter();

private String claimsEncryptionKey;
private JsonWebKey signingKey;
private JoseJwtProducer jwtProducer;
private final LongSupplier timeProvider;
Expand All @@ -59,6 +61,11 @@ public JwtVendor(Settings settings) {
throw new RuntimeException(e);
}
this.jwtProducer = jwtProducer;
if (settings.get("encryption_key") == null) {
throw new RuntimeException("encryption_key cannot be null");
} else {
this.claimsEncryptionKey = settings.get("encryption_key");
}
timeProvider = System::currentTimeMillis;
}

Expand All @@ -71,6 +78,11 @@ public JwtVendor(Settings settings, final LongSupplier timeProvider) {
throw new RuntimeException(e);
}
this.jwtProducer = jwtProducer;
if (settings.get("encryption_key") == null) {
throw new RuntimeException("encryption_key cannot be null");
} else {
this.claimsEncryptionKey = settings.get("encryption_key");
}
this.timeProvider = timeProvider;
}

Expand Down Expand Up @@ -126,7 +138,7 @@ public Set<String> mapRoles(final User user, final TransportAddress caller) {
return this.configModel.mapSecurityRoles(user, caller);
}

public String createJwt(String issuer, String subject, String audience, Integer expirySeconds) throws Exception {
public String createJwt(String issuer, String subject, String audience, Integer expirySeconds, List<String> roles) throws Exception {
long timeMillis = timeProvider.getAsLong();
Instant now = Instant.ofEpochMilli(timeProvider.getAsLong());

Expand Down Expand Up @@ -154,7 +166,12 @@ public String createJwt(String issuer, String subject, String audience, Integer
throw new Exception("The expiration time should be a positive integer");
}

// TODO: Should call preparelaims() if we need roles in claim;
if (roles != null) {
String listOfRoles = String.join(",", roles);
jwtClaims.setProperty("roles", EncryptionDecryptionUtil.encrypt(claimsEncryptionKey, listOfRoles));
} else {
throw new Exception("Roles cannot be null");
}

String encodedJwt = jwtProducer.processJwt(jwt);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@

package org.opensearch.security.authtoken.jwt;

import java.util.List;
import java.util.function.LongSupplier;

import org.apache.commons.lang3.RandomStringUtils;
import org.apache.cxf.rs.security.jose.jwk.JsonWebKey;
import org.apache.cxf.rs.security.jose.jws.JwsJwtCompactConsumer;
import org.apache.cxf.rs.security.jose.jwt.JwtToken;
Expand Down Expand Up @@ -42,17 +44,20 @@ public void testCreateJwkFromSettingsWithoutSigningKey() throws Exception{
}

RyanL1997 marked this conversation as resolved.
Show resolved Hide resolved
@Test
public void testCreateJwt() throws Exception {
public void testCreateJwtWithRoles() throws Exception {
String issuer = "cluster_0";
String subject = "admin";
String audience = "extension_0";
List<String> roles = List.of("IT", "HR");
String expectedRoles = "IT,HR";
Integer expirySeconds = 300;
LongSupplier currentTime = () -> (int)100;
Settings settings = Settings.builder().put("signing_key", "abc123").build();
String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16);
Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build();
Long expectedExp = currentTime.getAsLong() + (expirySeconds * 1000);

JwtVendor jwtVendor = new JwtVendor(settings, currentTime);
String encodedJwt = jwtVendor.createJwt(issuer, subject, audience, expirySeconds);
String encodedJwt = jwtVendor.createJwt(issuer, subject, audience, expirySeconds, roles);

JwsJwtCompactConsumer jwtConsumer = new JwsJwtCompactConsumer(encodedJwt);
JwtToken jwt = jwtConsumer.getJwtToken();
Expand All @@ -63,18 +68,52 @@ public void testCreateJwt() throws Exception {
Assert.assertNotNull(jwt.getClaim("iat"));
Assert.assertNotNull(jwt.getClaim("exp"));
Assert.assertEquals(expectedExp, jwt.getClaim("exp"));
Assert.assertNotEquals(expectedRoles, jwt.getClaim("roles"));
Assert.assertEquals(expectedRoles, EncryptionDecryptionUtil.decrypt(claimsEncryptionKey, jwt.getClaim("roles").toString()));
RyanL1997 marked this conversation as resolved.
Show resolved Hide resolved
}

@Test (expected = Exception.class)
public void testCreateJwtWithBadExpiry() throws Exception {
String issuer = "cluster_0";
String subject = "admin";
String audience = "extension_0";
List <String> roles = List.of("admin");
Integer expirySeconds = -300;
String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16);

Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build();
JwtVendor jwtVendor = new JwtVendor(settings);

jwtVendor.createJwt(issuer, subject, audience, expirySeconds, roles);
}

@Test (expected = Exception.class)
public void testCreateJwtWithBadEncryptionKey() throws Exception {
String issuer = "cluster_0";
String subject = "admin";
String audience = "extension_0";
List <String> roles = List.of("admin");
Integer expirySeconds = 300;

Settings settings = Settings.builder().put("signing_key", "abc123").build();
JwtVendor jwtVendor = new JwtVendor(settings);

jwtVendor.createJwt(issuer, subject, audience, expirySeconds);
jwtVendor.createJwt(issuer, subject, audience, expirySeconds, roles);
}

@Test (expected = Exception.class)
public void testCreateJwtWithBadRoles() throws Exception {
String issuer = "cluster_0";
String subject = "admin";
String audience = "extension_0";
List <String> roles = null;
Integer expirySecond = 300;
String claimsEncryptionKey = RandomStringUtils.randomAlphanumeric(16);

Settings settings = Settings.builder().put("signing_key", "abc123").put("encryption_key", claimsEncryptionKey).build();

JwtVendor jwtVendor = new JwtVendor(settings);

jwtVendor.createJwt(issuer, subject, audience, expirySecond, roles);
}
}