diff --git a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java index 36db5ac317..2c54e87ce2 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/RestActionUtils.java @@ -8,6 +8,9 @@ import static org.opensearch.ml.common.MLModel.MODEL_CONTENT_FIELD; import static org.opensearch.ml.common.MLModel.OLD_MODEL_CONTENT_FIELD; +import java.security.AccessController; +import java.security.PrivilegedActionException; +import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; @@ -30,7 +33,6 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.Nullable; -import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.commons.ConfigConstants; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; @@ -44,6 +46,8 @@ import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.search.internal.InternalSearchResponse; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; @@ -71,9 +75,12 @@ public class RestActionUtils { public static final String PARAMETER_TOOL_NAME = "tool_name"; public static final String OPENDISTRO_SECURITY_CONFIG_PREFIX = "_opendistro_security_"; - public static final String OPENDISTRO_SECURITY_SSL_PRINCIPAL = OPENDISTRO_SECURITY_CONFIG_PREFIX + "ssl_principal"; + + public static final String OPENDISTRO_SECURITY_USER = OPENDISTRO_SECURITY_CONFIG_PREFIX + "user"; static final Set adminDn = new HashSet<>(); + static final Set adminUsernames = new HashSet(); + static final ObjectMapper objectMapper = new ObjectMapper(); public static String getAlgorithm(RestRequest request) { String algorithm = request.param(PARAMETER_ALGORITHM); @@ -212,7 +219,7 @@ public static Optional getStringParam(RestRequest request, String paramN */ public static User getUserContext(Client client) { String userStr = client.threadPool().getThreadContext().getTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT); - logger.debug("Filtering result by " + userStr); + logger.debug("Current user is " + userStr); return User.parse(userStr); } @@ -226,13 +233,25 @@ public static boolean isSuperAdminUser(ClusterService clusterService, Client cli logger.debug("{} is registered as an admin dn", dn); adminDn.add(new LdapName(dn)); } catch (final InvalidNameException e) { - logger.error("Unable to parse admin dn {}", dn, e); + logger.debug("Unable to parse admin dn {}", dn, e); + adminUsernames.add(dn); } } - ThreadContext threadContext = client.threadPool().getThreadContext(); - final String sslPrincipal = threadContext.getTransient(OPENDISTRO_SECURITY_SSL_PRINCIPAL); - return isAdminDN(sslPrincipal); + Object userObject = client.threadPool().getThreadContext().getTransient(OPENDISTRO_SECURITY_USER); + if (userObject == null) + return false; + try { + return AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + String userContext = objectMapper.writeValueAsString(userObject); + final JsonNode node = objectMapper.readTree(userContext); + final String userName = node.get("name").asText(); + + return isAdminDN(userName); + }); + } catch (PrivilegedActionException e) { + throw new RuntimeException(e); + } } private static boolean isAdminDN(String dn) { @@ -241,7 +260,7 @@ private static boolean isAdminDN(String dn) { try { return isAdminDN(new LdapName(dn)); } catch (InvalidNameException e) { - return false; + return adminUsernames.contains(dn); } } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java index 22947b6407..bf2714c618 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/RestActionUtilsTests.java @@ -297,16 +297,20 @@ public void testIsSuperAdminUser() { ThreadContext threadContext = new ThreadContext(Settings.EMPTY); when(clusterService.getSettings()) - .thenReturn(Settings.builder().putList(RestActionUtils.SECURITY_AUTHCZ_ADMIN_DN, "cn=admin").build()); + .thenReturn( + Settings.builder().putList(RestActionUtils.SECURITY_AUTHCZ_ADMIN_DN, "CN=kirk,OU=client,O=client,L=test, C=de").build() + ); when(client.threadPool()).thenReturn(mock(ThreadPool.class)); when(client.threadPool().getThreadContext()).thenReturn(threadContext); - threadContext.putTransient(RestActionUtils.OPENDISTRO_SECURITY_SSL_PRINCIPAL, "cn=admin"); + threadContext.putTransient(RestActionUtils.OPENDISTRO_SECURITY_USER, Map.of("name", "CN=kirk,OU=client,O=client,L=test,C=de")); boolean isAdmin = RestActionUtils.isSuperAdminUser(clusterService, client); Assert.assertTrue(isAdmin); } + // Need to add a test case to cover non Ldap user + @Test public void testIsSuperAdminUser_NotAdmin() { ClusterService clusterService = mock(ClusterService.class); @@ -317,7 +321,7 @@ public void testIsSuperAdminUser_NotAdmin() { .thenReturn(Settings.builder().putList(RestActionUtils.SECURITY_AUTHCZ_ADMIN_DN, "cn=admin").build()); when(client.threadPool()).thenReturn(mock(ThreadPool.class)); when(client.threadPool().getThreadContext()).thenReturn(threadContext); - threadContext.putTransient(RestActionUtils.OPENDISTRO_SECURITY_SSL_PRINCIPAL, "cn=notadmin"); + threadContext.putTransient(RestActionUtils.OPENDISTRO_SECURITY_USER, Map.of("name", "nonAdmin")); boolean isAdmin = RestActionUtils.isSuperAdminUser(clusterService, client); Assert.assertFalse(isAdmin);