From 32bcb13ac4158f860e8a6613bb75217c1934f0e8 Mon Sep 17 00:00:00 2001 From: Ievgen Degtiarenko Date: Fri, 12 Apr 2024 10:50:11 +0200 Subject: [PATCH 01/20] Introduce an easy way to get node id by its name (#107392) Our test utility returns the node name when starting a new node. A lot of APIs (such as routing table or node shutdown) require a node id. This change introduces a simple way to retrieve the node id based on its name. --- .../http/PrevalidateNodeRemovalRestIT.java | 2 +- .../cluster/PrevalidateNodeRemovalIT.java | 6 ++-- .../cluster/PrevalidateShardPathIT.java | 4 +-- .../discovery/ClusterDisruptionIT.java | 2 +- .../store/IndicesStoreIntegrationIT.java | 4 +-- .../nodesinfo/SimpleNodesInfoIT.java | 35 +++++++++---------- .../persistent/PersistentTasksExecutorIT.java | 2 +- .../AbstractIndexRecoveryIntegTestCase.java | 5 ++- .../elasticsearch/test/ESIntegTestCase.java | 4 +++ .../ComponentVersionsNodesInfoIT.java | 8 ++--- .../xpack/enrich/EnrichMultiNodeIT.java | 3 +- .../monitoring/MultiNodesStatsTests.java | 4 +-- .../local/LocalExporterIntegTests.java | 4 +-- ...movalWithSearchableSnapshotIntegTests.java | 2 +- .../shutdown/DesiredBalanceShutdownIT.java | 3 +- .../xpack/shutdown/NodeShutdownPluginsIT.java | 20 ++--------- .../shutdown/NodeShutdownReadinessIT.java | 14 -------- .../xpack/shutdown/NodeShutdownShardsIT.java | 13 ------- .../xpack/shutdown/NodeShutdownTasksIT.java | 19 ++-------- 19 files changed, 45 insertions(+), 109 deletions(-) diff --git a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/PrevalidateNodeRemovalRestIT.java b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/PrevalidateNodeRemovalRestIT.java index 17d0f04b9e2cf..ae1764310a34d 100644 --- a/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/PrevalidateNodeRemovalRestIT.java +++ b/qa/smoke-test-http/src/javaRestTest/java/org/elasticsearch/http/PrevalidateNodeRemovalRestIT.java @@ -29,7 +29,7 @@ public class PrevalidateNodeRemovalRestIT extends HttpSmokeTestCase { public void testRestStatusCode() throws IOException { String node1Name = internalCluster().getRandomNodeName(); - String node1Id = internalCluster().clusterService(node1Name).localNode().getId(); + String node1Id = getNodeId(node1Name); ensureGreen(); RestClient client = getRestClient(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateNodeRemovalIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateNodeRemovalIT.java index f53e559bfda5d..38921840a2c64 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateNodeRemovalIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateNodeRemovalIT.java @@ -58,7 +58,7 @@ public void testNodeRemovalFromNonRedCluster() throws Exception { PrevalidateNodeRemovalRequest.Builder req = PrevalidateNodeRemovalRequest.builder(); switch (randomIntBetween(0, 2)) { case 0 -> req.setNames(nodeName); - case 1 -> req.setIds(internalCluster().clusterService(nodeName).localNode().getId()); + case 1 -> req.setIds(getNodeId(nodeName)); case 2 -> req.setExternalIds(internalCluster().clusterService(nodeName).localNode().getExternalId()); default -> throw new IllegalStateException("Unexpected value"); } @@ -156,7 +156,7 @@ public void testNodeRemovalFromRedClusterWithLocalShardCopy() throws Exception { // Prevalidate removal of node1 PrevalidateNodeRemovalRequest req = PrevalidateNodeRemovalRequest.builder().setNames(node1).build(); 
PrevalidateNodeRemovalResponse resp = client().execute(PrevalidateNodeRemovalAction.INSTANCE, req).get(); - String node1Id = internalCluster().clusterService(node1).localNode().getId(); + String node1Id = getNodeId(node1); assertFalse(resp.getPrevalidation().isSafe()); assertThat(resp.getPrevalidation().message(), equalTo("removal of the following nodes might not be safe: [" + node1Id + "]")); assertThat(resp.getPrevalidation().nodes().size(), equalTo(1)); @@ -187,7 +187,7 @@ public void testNodeRemovalFromRedClusterWithTimeout() throws Exception { .timeout(TimeValue.timeValueSeconds(1)); PrevalidateNodeRemovalResponse resp = client().execute(PrevalidateNodeRemovalAction.INSTANCE, req).get(); assertFalse("prevalidation result should return false", resp.getPrevalidation().isSafe()); - String node2Id = internalCluster().clusterService(node2).localNode().getId(); + String node2Id = getNodeId(node2); assertThat( resp.getPrevalidation().message(), equalTo("cannot prevalidate removal of nodes with the following IDs: [" + node2Id + "]") diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java index 560a525ec526c..77bcaf1e1970c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/PrevalidateShardPathIT.java @@ -53,8 +53,8 @@ public void testCheckShards() throws Exception { .stream() .map(ShardRouting::shardId) .collect(Collectors.toSet()); - String node1Id = internalCluster().clusterService(node1).localNode().getId(); - String node2Id = internalCluster().clusterService(node2).localNode().getId(); + String node1Id = getNodeId(node1); + String node2Id = getNodeId(node2); Set shardIdsToCheck = new HashSet<>(shardIds); boolean includeUnknownShardId = randomBoolean(); if (includeUnknownShardId) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java b/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java index a0efb81c18668..c661894840261 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/discovery/ClusterDisruptionIT.java @@ -326,7 +326,7 @@ public void testSendingShardFailure() throws Exception { String nonMasterNode = randomFrom(nonMasterNodes); assertAcked(prepareCreate("test").setSettings(indexSettings(3, 2))); ensureGreen(); - String nonMasterNodeId = internalCluster().clusterService(nonMasterNode).localNode().getId(); + String nonMasterNodeId = getNodeId(nonMasterNode); // fail a random shard ShardRouting failedShard = randomFrom( diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/store/IndicesStoreIntegrationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/store/IndicesStoreIntegrationIT.java index ca749eeaef545..5805eab831230 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/store/IndicesStoreIntegrationIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/store/IndicesStoreIntegrationIT.java @@ -386,8 +386,8 @@ public void testShardActiveElseWhere() throws Exception { final String masterNode = internalCluster().getMasterName(); final String nonMasterNode = nodes.get(0).equals(masterNode) ? 
nodes.get(1) : nodes.get(0); - final String masterId = internalCluster().clusterService(masterNode).localNode().getId(); - final String nonMasterId = internalCluster().clusterService(nonMasterNode).localNode().getId(); + final String masterId = getNodeId(masterNode); + final String nonMasterId = getNodeId(nonMasterNode); final int numShards = scaledRandomIntBetween(2, 10); assertAcked(prepareCreate("test").setSettings(indexSettings(numShards, 0))); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/nodesinfo/SimpleNodesInfoIT.java b/server/src/internalClusterTest/java/org/elasticsearch/nodesinfo/SimpleNodesInfoIT.java index cafc0e9426eea..a5700c319aa59 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/nodesinfo/SimpleNodesInfoIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/nodesinfo/SimpleNodesInfoIT.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.admin.cluster.health.ClusterHealthResponse; import org.elasticsearch.action.admin.cluster.node.info.NodesInfoRequest; import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.monitor.os.OsInfo; @@ -29,16 +28,16 @@ @ClusterScope(scope = Scope.TEST, numDataNodes = 0) public class SimpleNodesInfoIT extends ESIntegTestCase { - public void testNodesInfos() throws Exception { - List nodesIds = internalCluster().startNodes(2); - final String node_1 = nodesIds.get(0); - final String node_2 = nodesIds.get(1); + public void testNodesInfos() { + List nodesNames = internalCluster().startNodes(2); + final String node_1 = nodesNames.get(0); + final String node_2 = nodesNames.get(1); ClusterHealthResponse clusterHealth = clusterAdmin().prepareHealth().setWaitForGreenStatus().setWaitForNodes("2").get(); logger.info("--> done cluster_health, status {}", clusterHealth.getStatus()); - String server1NodeId = internalCluster().getInstance(ClusterService.class, node_1).state().nodes().getLocalNodeId(); - String server2NodeId = internalCluster().getInstance(ClusterService.class, node_2).state().nodes().getLocalNodeId(); + String server1NodeId = getNodeId(node_1); + String server2NodeId = getNodeId(node_2); logger.info("--> started nodes: {} and {}", server1NodeId, server2NodeId); NodesInfoResponse response = clusterAdmin().prepareNodesInfo().get(); @@ -68,16 +67,16 @@ public void testNodesInfos() throws Exception { assertThat(response.getNodesMap().get(server2NodeId), notNullValue()); } - public void testNodesInfosTotalIndexingBuffer() throws Exception { - List nodesIds = internalCluster().startNodes(2); - final String node_1 = nodesIds.get(0); - final String node_2 = nodesIds.get(1); + public void testNodesInfosTotalIndexingBuffer() { + List nodesNames = internalCluster().startNodes(2); + final String node_1 = nodesNames.get(0); + final String node_2 = nodesNames.get(1); ClusterHealthResponse clusterHealth = clusterAdmin().prepareHealth().setWaitForGreenStatus().setWaitForNodes("2").get(); logger.info("--> done cluster_health, status {}", clusterHealth.getStatus()); - String server1NodeId = internalCluster().getInstance(ClusterService.class, node_1).state().nodes().getLocalNodeId(); - String server2NodeId = internalCluster().getInstance(ClusterService.class, node_2).state().nodes().getLocalNodeId(); + String server1NodeId = getNodeId(node_1); + String server2NodeId = getNodeId(node_2); 
logger.info("--> started nodes: {} and {}", server1NodeId, server2NodeId); NodesInfoResponse response = clusterAdmin().prepareNodesInfo().get(); @@ -103,19 +102,19 @@ public void testNodesInfosTotalIndexingBuffer() throws Exception { } public void testAllocatedProcessors() throws Exception { - List nodesIds = internalCluster().startNodes( + List nodeNames = internalCluster().startNodes( Settings.builder().put(EsExecutors.NODE_PROCESSORS_SETTING.getKey(), 2.9).build(), Settings.builder().put(EsExecutors.NODE_PROCESSORS_SETTING.getKey(), 5.9).build() ); - final String node_1 = nodesIds.get(0); - final String node_2 = nodesIds.get(1); + final String node_1 = nodeNames.get(0); + final String node_2 = nodeNames.get(1); ClusterHealthResponse clusterHealth = clusterAdmin().prepareHealth().setWaitForGreenStatus().setWaitForNodes("2").get(); logger.info("--> done cluster_health, status {}", clusterHealth.getStatus()); - String server1NodeId = internalCluster().getInstance(ClusterService.class, node_1).state().nodes().getLocalNodeId(); - String server2NodeId = internalCluster().getInstance(ClusterService.class, node_2).state().nodes().getLocalNodeId(); + String server1NodeId = getNodeId(node_1); + String server2NodeId = getNodeId(node_2); logger.info("--> started nodes: {} and {}", server1NodeId, server2NodeId); NodesInfoResponse response = clusterAdmin().prepareNodesInfo().get(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java b/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java index 813c06d9f02f3..d71718f3f3a6b 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/persistent/PersistentTasksExecutorIT.java @@ -145,7 +145,7 @@ public void testPersistentActionWithNoAvailableNode() throws Exception { Settings nodeSettings = Settings.builder().put(nodeSettings(0, Settings.EMPTY)).put("node.attr.test_attr", "test").build(); String newNode = internalCluster().startNode(nodeSettings); - String newNodeId = internalCluster().clusterService(newNode).localNode().getId(); + String newNodeId = getNodeId(newNode); waitForTaskToStart(); TaskInfo taskInfo = clusterAdmin().prepareListTasks().setActions(TestPersistentTasksExecutor.NAME + "[c]").get().getTasks().get(0); diff --git a/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java index a5ace3e357f90..97f17858e753d 100644 --- a/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/indices/recovery/AbstractIndexRecoveryIntegTestCase.java @@ -17,7 +17,6 @@ import org.elasticsearch.cluster.NodeConnectionsService; import org.elasticsearch.cluster.action.shard.ShardStateAction; import org.elasticsearch.cluster.metadata.IndexMetadata; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.CircuitBreakingException; @@ -135,7 +134,7 @@ protected void checkTransientErrorsDuringRecoveryAreRetried(String recoveryActio ensureSearchable(indexName); ClusterStateResponse stateResponse = clusterAdmin().prepareState().get(); - final String blueNodeId = 
internalCluster().getInstance(ClusterService.class, blueNodeName).localNode().getId(); + final String blueNodeId = getNodeId(blueNodeName); assertFalse(stateResponse.getState().getRoutingNodes().node(blueNodeId).isEmpty()); @@ -231,7 +230,7 @@ public void checkDisconnectsWhileRecovering(String recoveryActionToBlock) throws ensureSearchable(indexName); ClusterStateResponse stateResponse = clusterAdmin().prepareState().get(); - final String blueNodeId = internalCluster().getInstance(ClusterService.class, blueNodeName).localNode().getId(); + final String blueNodeId = getNodeId(blueNodeName); assertFalse(stateResponse.getState().getRoutingNodes().node(blueNodeId).isEmpty()); diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java index 11d4754eaa596..1056c766e17ca 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESIntegTestCase.java @@ -1093,6 +1093,10 @@ public static void awaitClusterState(Logger logger, String viaNode, Predicate nodesIds = internalCluster().startNodes(1); - final String node_1 = nodesIds.get(0); + final String node_1 = internalCluster().startNode(); ClusterHealthResponse clusterHealth = clusterAdmin().prepareHealth().setWaitForGreenStatus().setWaitForNodes("1").get(); logger.info("--> done cluster_health, status {}", clusterHealth.getStatus()); - String server1NodeId = internalCluster().getInstance(ClusterService.class, node_1).state().nodes().getLocalNodeId(); + String server1NodeId = getNodeId(node_1); logger.info("--> started nodes: {}", server1NodeId); NodesInfoResponse response = clusterAdmin().prepareNodesInfo().get(); diff --git a/x-pack/plugin/enrich/src/internalClusterTest/java/org/elasticsearch/xpack/enrich/EnrichMultiNodeIT.java b/x-pack/plugin/enrich/src/internalClusterTest/java/org/elasticsearch/xpack/enrich/EnrichMultiNodeIT.java index b81a5e6b902b3..26e38252a4572 100644 --- a/x-pack/plugin/enrich/src/internalClusterTest/java/org/elasticsearch/xpack/enrich/EnrichMultiNodeIT.java +++ b/x-pack/plugin/enrich/src/internalClusterTest/java/org/elasticsearch/xpack/enrich/EnrichMultiNodeIT.java @@ -20,7 +20,6 @@ import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.ingest.PutPipelineRequest; import org.elasticsearch.action.search.SearchRequest; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.Maps; @@ -268,7 +267,7 @@ private static void enrich(Map> keys, String coordinatingNo EnrichStatsAction.Response statsResponse = client().execute(EnrichStatsAction.INSTANCE, new EnrichStatsAction.Request()) .actionGet(); assertThat(statsResponse.getCoordinatorStats().size(), equalTo(internalCluster().size())); - String nodeId = internalCluster().getInstance(ClusterService.class, coordinatingNode).localNode().getId(); + String nodeId = getNodeId(coordinatingNode); CoordinatorStats stats = statsResponse.getCoordinatorStats().stream().filter(s -> s.getNodeId().equals(nodeId)).findAny().get(); assertThat(stats.getNodeId(), equalTo(nodeId)); assertThat(stats.getRemoteRequestsTotal(), greaterThanOrEqualTo(1L)); diff --git a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/MultiNodesStatsTests.java 
b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/MultiNodesStatsTests.java index c8aae302e357b..3c085b9bb2820 100644 --- a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/MultiNodesStatsTests.java +++ b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/MultiNodesStatsTests.java @@ -87,9 +87,7 @@ public void testMultipleNodes() throws Exception { assertThat(((StringTerms) aggregation).getBuckets().size(), equalTo(nbNodes)); for (String nodeName : internalCluster().getNodeNames()) { - StringTerms.Bucket bucket = ((StringTerms) aggregation).getBucketByKey( - internalCluster().clusterService(nodeName).localNode().getId() - ); + StringTerms.Bucket bucket = ((StringTerms) aggregation).getBucketByKey(getNodeId(nodeName)); // At least 1 doc must exist per node, but it can be more than 1 // because the first node may have already collected many node stats documents // whereas the last node just started to collect node stats. diff --git a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterIntegTests.java b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterIntegTests.java index ef4f22f852b37..69ac9d4ddd876 100644 --- a/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterIntegTests.java +++ b/x-pack/plugin/monitoring/src/test/java/org/elasticsearch/xpack/monitoring/exporter/local/LocalExporterIntegTests.java @@ -173,7 +173,7 @@ public void testExport() throws Exception { aggregation.getBuckets().size() ); for (String nodeName : internalCluster().getNodeNames()) { - String nodeId = internalCluster().clusterService(nodeName).localNode().getId(); + String nodeId = getNodeId(nodeName); Terms.Bucket bucket = aggregation.getBucketByKey(nodeId); assertTrue("No bucket found for node id [" + nodeId + "]", bucket != null); assertTrue(bucket.getDocCount() >= 1L); @@ -208,7 +208,7 @@ public void testExport() throws Exception { response -> { Terms aggregation = response.getAggregations().get("agg_nodes_ids"); for (String nodeName : internalCluster().getNodeNames()) { - String nodeId = internalCluster().clusterService(nodeName).localNode().getId(); + String nodeId = getNodeId(nodeName); Terms.Bucket bucket = aggregation.getBucketByKey(nodeId); assertTrue("No bucket found for node id [" + nodeId + "]", bucket != null); assertTrue(bucket.getDocCount() >= 1L); diff --git a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/PrevalidateNodeRemovalWithSearchableSnapshotIntegTests.java b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/PrevalidateNodeRemovalWithSearchableSnapshotIntegTests.java index a651c4b30fcb1..37e2427ae6891 100644 --- a/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/PrevalidateNodeRemovalWithSearchableSnapshotIntegTests.java +++ b/x-pack/plugin/searchable-snapshots/src/internalClusterTest/java/org/elasticsearch/xpack/searchablesnapshots/PrevalidateNodeRemovalWithSearchableSnapshotIntegTests.java @@ -63,7 +63,7 @@ public void testNodeRemovalFromClusterWihRedSearchableSnapshotIndex() throws Exc PrevalidateNodeRemovalRequest.Builder req = PrevalidateNodeRemovalRequest.builder(); switch (randomIntBetween(0, 2)) { case 0 -> req.setNames(node2); - case 1 -> 
req.setIds(internalCluster().clusterService(node2).localNode().getId()); + case 1 -> req.setIds(getNodeId(node2)); case 2 -> req.setExternalIds(internalCluster().clusterService(node2).localNode().getExternalId()); default -> throw new IllegalStateException("Unexpected value"); } diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/DesiredBalanceShutdownIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/DesiredBalanceShutdownIT.java index ceedda30626c6..ce1704639527d 100644 --- a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/DesiredBalanceShutdownIT.java +++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/DesiredBalanceShutdownIT.java @@ -11,7 +11,6 @@ import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata; -import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.plugins.Plugin; @@ -36,7 +35,7 @@ protected Collection> nodePlugins() { public void testDesiredBalanceWithShutdown() throws Exception { final var oldNodeName = internalCluster().startNode(); - final var oldNodeId = internalCluster().getInstance(ClusterService.class, oldNodeName).localNode().getId(); + final var oldNodeId = getNodeId(oldNodeName); createIndex( INDEX, diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownPluginsIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownPluginsIT.java index 3a1280307b739..c87fa08e8c972 100644 --- a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownPluginsIT.java +++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownPluginsIT.java @@ -9,10 +9,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.admin.cluster.node.info.NodeInfo; -import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.plugins.Plugin; import org.elasticsearch.plugins.ShutdownAwarePlugin; import org.elasticsearch.test.ESIntegTestCase; @@ -44,21 +41,8 @@ public void testShutdownAwarePlugin() throws Exception { final String shutdownNode; final String remainNode; - NodesInfoResponse nodes = clusterAdmin().prepareNodesInfo().clear().get(); - final String node1Id = nodes.getNodes() - .stream() - .map(NodeInfo::getNode) - .filter(node -> node.getName().equals(node1)) - .map(DiscoveryNode::getId) - .findFirst() - .orElseThrow(); - final String node2Id = nodes.getNodes() - .stream() - .map(NodeInfo::getNode) - .filter(node -> node.getName().equals(node2)) - .map(DiscoveryNode::getId) - .findFirst() - .orElseThrow(); + final String node1Id = getNodeId(node1); + final String node2Id = getNodeId(node2); if (randomBoolean()) { shutdownNode = node1Id; diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownReadinessIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownReadinessIT.java 
index af0713665731c..6dfbb8360e763 100644 --- a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownReadinessIT.java +++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownReadinessIT.java @@ -7,10 +7,7 @@ package org.elasticsearch.xpack.shutdown; -import org.elasticsearch.action.admin.cluster.node.info.NodeInfo; -import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.plugins.Plugin; @@ -93,17 +90,6 @@ private void deleteNodeShutdown(String nodeId) { assertAcked(client().execute(DeleteShutdownNodeAction.INSTANCE, new DeleteShutdownNodeAction.Request(nodeId))); } - private String getNodeId(String nodeName) { - NodesInfoResponse nodes = clusterAdmin().prepareNodesInfo().clear().get(); - return nodes.getNodes() - .stream() - .map(NodeInfo::getNode) - .filter(node -> node.getName().equals(nodeName)) - .map(DiscoveryNode::getId) - .findFirst() - .orElseThrow(); - } - private void assertNoShuttingDownNodes(String nodeId) throws ExecutionException, InterruptedException { var response = client().execute(GetShutdownStatusAction.INSTANCE, new GetShutdownStatusAction.Request(nodeId)).get(); assertThat(response.getShutdownStatuses(), empty()); diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownShardsIT.java b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownShardsIT.java index fad05e6f213d5..fda2a5755be55 100644 --- a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownShardsIT.java +++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownShardsIT.java @@ -8,13 +8,11 @@ package org.elasticsearch.xpack.shutdown; import org.elasticsearch.action.admin.cluster.allocation.ClusterAllocationExplainResponse; -import org.elasticsearch.action.admin.cluster.node.info.NodeInfo; import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.metadata.IndexMetadata; import org.elasticsearch.cluster.metadata.SingleNodeShutdownMetadata; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.routing.RoutingNodesHelper; import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardRoutingState; @@ -456,17 +454,6 @@ private String findIdOfNodeWithPrimaryShard(String indexName) { ); } - private String getNodeId(String nodeName) { - NodesInfoResponse nodes = clusterAdmin().prepareNodesInfo().clear().get(); - return nodes.getNodes() - .stream() - .map(NodeInfo::getNode) - .filter(node -> node.getName().equals(nodeName)) - .map(DiscoveryNode::getId) - .findFirst() - .orElseThrow(); - } - private void putNodeShutdown(String nodeId, SingleNodeShutdownMetadata.Type type, String nodeReplacementName) throws Exception { assertAcked( client().execute( diff --git a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java 
b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java index 7c32311237c57..dc4e6b9c53fda 100644 --- a/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java +++ b/x-pack/plugin/shutdown/src/internalClusterTest/java/org/elasticsearch/xpack/shutdown/NodeShutdownTasksIT.java @@ -12,8 +12,6 @@ import org.elasticsearch.ResourceAlreadyExistsException; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.admin.cluster.node.info.NodeInfo; -import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; @@ -78,21 +76,8 @@ public void testTasksAreNotAssignedToShuttingDownNode() throws Exception { final String shutdownNode; final String candidateNode; - NodesInfoResponse nodes = clusterAdmin().prepareNodesInfo().clear().get(); - final String node1Id = nodes.getNodes() - .stream() - .map(NodeInfo::getNode) - .filter(node -> node.getName().equals(node1)) - .map(DiscoveryNode::getId) - .findFirst() - .orElseThrow(); - final String node2Id = nodes.getNodes() - .stream() - .map(NodeInfo::getNode) - .filter(node -> node.getName().equals(node2)) - .map(DiscoveryNode::getId) - .findFirst() - .orElseThrow(); + final String node1Id = getNodeId(node1); + final String node2Id = getNodeId(node2); if (randomBoolean()) { shutdownNode = node1Id; From b68cae424437fa487ce7b92c1ce16dde995caa47 Mon Sep 17 00:00:00 2001 From: Albert Zaharovits Date: Fri, 12 Apr 2024 12:00:50 +0300 Subject: [PATCH 02/20] Unregistered setting to disable the native role mapping store (#107345) This PR introduces a new unregistered setting that can be used (by other plugins that register the setting) to disable the index-based native role mappings store. --- .../DisableNativeRoleMappingsStoreTests.java | 157 +++++++++++++++++ .../mapper/NativeRoleMappingStore.java | 161 +++++++++++------- .../NativeRoleMappingBaseRestHandler.java | 47 +++++ .../RestDeleteRoleMappingAction.java | 3 +- .../RestGetRoleMappingsAction.java | 3 +- .../rolemapping/RestPutRoleMappingAction.java | 3 +- 6 files changed, 307 insertions(+), 67 deletions(-) create mode 100644 x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authz/store/DisableNativeRoleMappingsStoreTests.java create mode 100644 x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/NativeRoleMappingBaseRestHandler.java diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authz/store/DisableNativeRoleMappingsStoreTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authz/store/DisableNativeRoleMappingsStoreTests.java new file mode 100644 index 0000000000000..4f56d783e117c --- /dev/null +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authz/store/DisableNativeRoleMappingsStoreTests.java @@ -0,0 +1,157 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.security.authz.store; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.client.Request; +import org.elasticsearch.client.RequestOptions; +import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.SecurityIntegTestCase; +import org.elasticsearch.test.SecuritySettingsSource; +import org.elasticsearch.test.SecuritySettingsSourceField; +import org.elasticsearch.xpack.core.security.action.rolemapping.DeleteRoleMappingRequest; +import org.elasticsearch.xpack.core.security.action.rolemapping.PutRoleMappingRequest; +import org.elasticsearch.xpack.core.security.authc.RealmConfig; +import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper; +import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; +import org.elasticsearch.xpack.core.security.authc.support.mapper.ExpressionRoleMapping; +import org.elasticsearch.xpack.security.authc.support.mapper.NativeRoleMappingStore; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ExecutionException; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.emptyIterable; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class DisableNativeRoleMappingsStoreTests extends SecurityIntegTestCase { + + @Override + protected Collection> nodePlugins() { + List> plugins = new ArrayList<>(super.nodePlugins()); + plugins.add(PrivateCustomPlugin.class); + return plugins; + } + + @Override + protected boolean addMockHttpTransport() { + return false; // need real http + } + + public void testPutRoleMappingDisallowed() { + // transport action + NativeRoleMappingStore nativeRoleMappingStore = internalCluster().getInstance(NativeRoleMappingStore.class); + PlainActionFuture future = new PlainActionFuture<>(); + nativeRoleMappingStore.putRoleMapping(new PutRoleMappingRequest(), future); + ExecutionException e = expectThrows(ExecutionException.class, future::get); + assertThat(e.getMessage(), containsString("Native role mapping management is disabled")); + // rest request + Request request = new Request("POST", "_security/role_mapping/" + randomAlphaOfLength(8)); + RequestOptions.Builder options = request.getOptions().toBuilder(); + options.addHeader( + "Authorization", + UsernamePasswordToken.basicAuthHeaderValue( + SecuritySettingsSource.TEST_USER_NAME, + new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray()) + ) + ); + request.setOptions(options); + ResponseException e2 = expectThrows(ResponseException.class, () -> getRestClient().performRequest(request)); + assertThat(e2.getMessage(), containsString("Native role mapping management is not enabled in this Elasticsearch instance")); + assertThat(e2.getResponse().getStatusLine().getStatusCode(), is(410)); // gone + } + + public void testDeleteRoleMappingDisallowed() { + // transport action + NativeRoleMappingStore nativeRoleMappingStore = internalCluster().getInstance(NativeRoleMappingStore.class); + PlainActionFuture future = new PlainActionFuture<>(); + nativeRoleMappingStore.deleteRoleMapping(new DeleteRoleMappingRequest(), future); + ExecutionException e = 
expectThrows(ExecutionException.class, future::get); + assertThat(e.getMessage(), containsString("Native role mapping management is disabled")); + // rest request + Request request = new Request("DELETE", "_security/role_mapping/" + randomAlphaOfLength(8)); + RequestOptions.Builder options = request.getOptions().toBuilder(); + options.addHeader( + "Authorization", + UsernamePasswordToken.basicAuthHeaderValue( + SecuritySettingsSource.TEST_USER_NAME, + new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray()) + ) + ); + request.setOptions(options); + ResponseException e2 = expectThrows(ResponseException.class, () -> getRestClient().performRequest(request)); + assertThat(e2.getMessage(), containsString("Native role mapping management is not enabled in this Elasticsearch instance")); + assertThat(e2.getResponse().getStatusLine().getStatusCode(), is(410)); // gone + } + + public void testGetRoleMappingDisallowed() throws Exception { + // transport action + NativeRoleMappingStore nativeRoleMappingStore = internalCluster().getInstance(NativeRoleMappingStore.class); + PlainActionFuture> future = new PlainActionFuture<>(); + nativeRoleMappingStore.getRoleMappings(randomFrom(Set.of(randomAlphaOfLength(8)), null), future); + assertThat(future.get(), emptyIterable()); + // rest request + Request request = new Request("GET", "_security/role_mapping/" + randomAlphaOfLength(8)); + RequestOptions.Builder options = request.getOptions().toBuilder(); + options.addHeader( + "Authorization", + UsernamePasswordToken.basicAuthHeaderValue( + SecuritySettingsSource.TEST_USER_NAME, + new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray()) + ) + ); + request.setOptions(options); + ResponseException e2 = expectThrows(ResponseException.class, () -> getRestClient().performRequest(request)); + assertThat(e2.getMessage(), containsString("Native role mapping management is not enabled in this Elasticsearch instance")); + assertThat(e2.getResponse().getStatusLine().getStatusCode(), is(410)); // gone + } + + public void testResolveRoleMappings() throws Exception { + NativeRoleMappingStore nativeRoleMappingStore = internalCluster().getInstance(NativeRoleMappingStore.class); + UserRoleMapper.UserData userData = new UserRoleMapper.UserData( + randomAlphaOfLength(4), + null, + randomFrom(Set.of(randomAlphaOfLength(4)), Set.of()), + Map.of(), + mock(RealmConfig.class) + ); + PlainActionFuture> future = new PlainActionFuture<>(); + nativeRoleMappingStore.resolveRoles(userData, future); + assertThat(future.get(), emptyIterable()); + } + + public static class PrivateCustomPlugin extends Plugin { + + public static final Setting NATIVE_ROLE_MAPPINGS_SETTING = Setting.boolSetting( + "xpack.security.authc.native_role_mappings.enabled", + true, + Setting.Property.NodeScope + ); + + public PrivateCustomPlugin() {} + + @Override + public Settings additionalSettings() { + return Settings.builder().put(NATIVE_ROLE_MAPPINGS_SETTING.getKey(), false).build(); + } + + @Override + public List> getSettings() { + return List.of(NATIVE_ROLE_MAPPINGS_SETTING); + } + } +} diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStore.java index 4abf2e53d0264..926626f2eaf10 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStore.java +++ 
b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/support/mapper/NativeRoleMappingStore.java @@ -48,7 +48,6 @@ import java.io.IOException; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -85,6 +84,17 @@ */ public class NativeRoleMappingStore implements UserRoleMapper { + /** + * This setting is never registered by the security plugin - in order to disable the native role APIs + * another plugin must register it as a boolean setting and cause it to be set to `false`. + * + * If this setting is set to false then + *
+     * <ul>
+     *     <li>the Rest APIs for native role mappings management are disabled.</li>
+     *     <li>The native role mappings store will not map any roles to any user.</li>
+     * </ul>
+ */ + public static final String NATIVE_ROLE_MAPPINGS_ENABLED = "xpack.security.authc.native_role_mappings.enabled"; private static final Logger logger = LogManager.getLogger(NativeRoleMappingStore.class); static final String DOC_TYPE_FIELD = "doc_type"; static final String DOC_TYPE_ROLE_MAPPING = "role-mapping"; @@ -105,6 +115,7 @@ public class NativeRoleMappingStore implements UserRoleMapper { private final List realmsToRefresh = new CopyOnWriteArrayList<>(); private final boolean lastLoadCacheEnabled; private final AtomicReference> lastLoadRef = new AtomicReference<>(null); + private final boolean enabled; public NativeRoleMappingStore(Settings settings, Client client, SecurityIndexManager securityIndex, ScriptService scriptService) { this.settings = settings; @@ -112,16 +123,7 @@ public NativeRoleMappingStore(Settings settings, Client client, SecurityIndexMan this.securityIndex = securityIndex; this.scriptService = scriptService; this.lastLoadCacheEnabled = LAST_LOAD_CACHE_ENABLED_SETTING.get(settings); - } - - private static String getNameFromId(String id) { - assert id.startsWith(ID_PREFIX); - return id.substring(ID_PREFIX.length()); - } - - // package-private for testing - static String getIdForName(String name) { - return ID_PREFIX + name; + this.enabled = settings.getAsBoolean(NATIVE_ROLE_MAPPINGS_ENABLED, true); } /** @@ -129,6 +131,10 @@ static String getIdForName(String name) { * package private for unit testing */ protected void loadMappings(ActionListener> listener) { + if (enabled == false) { + listener.onResponse(List.of()); + return; + } if (securityIndex.isIndexUpToDate() == false) { listener.onFailure( new IllegalStateException( @@ -164,32 +170,21 @@ protected void loadMappings(ActionListener> listener () -> format("failed to load role mappings from index [%s] skipping all mappings.", SECURITY_MAIN_ALIAS), ex ); - listener.onResponse(Collections.emptyList()); + listener.onResponse(List.of()); })), doc -> buildMapping(getNameFromId(doc.getId()), doc.getSourceRef()) ); } } - protected static ExpressionRoleMapping buildMapping(String id, BytesReference source) { - try ( - XContentParser parser = XContentHelper.createParserNotCompressed( - LoggingDeprecationHandler.XCONTENT_PARSER_CONFIG, - source, - XContentType.JSON - ) - ) { - return ExpressionRoleMapping.parse(id, parser); - } catch (Exception e) { - logger.warn(() -> "Role mapping [" + id + "] cannot be parsed and will be skipped", e); - return null; - } - } - /** * Stores (create or update) a single mapping in the index */ public void putRoleMapping(PutRoleMappingRequest request, ActionListener listener) { + if (enabled == false) { + listener.onFailure(new IllegalStateException("Native role mapping management is disabled")); + return; + } // Validate all templates before storing the role mapping for (TemplateRoleName templateRoleName : request.getRoleTemplates()) { templateRoleName.validate(scriptService); @@ -201,6 +196,10 @@ public void putRoleMapping(PutRoleMappingRequest request, ActionListener listener) { + if (enabled == false) { + listener.onFailure(new IllegalStateException("Native role mapping management is disabled")); + return; + } modifyMapping(request.getName(), this::innerDeleteMapping, request, listener); } @@ -229,6 +228,10 @@ private void modifyMapping( } private void innerPutMapping(PutRoleMappingRequest request, ActionListener listener) { + if (enabled == false) { + listener.onFailure(new IllegalStateException("Native role mapping management is disabled")); + return; + } final ExpressionRoleMapping 
mapping = request.getMapping(); securityIndex.prepareIndexIfNeededThenExecute(listener::onFailure, () -> { final XContentBuilder xContentBuilder; @@ -266,6 +269,10 @@ public void onFailure(Exception e) { } private void innerDeleteMapping(DeleteRoleMappingRequest request, ActionListener listener) { + if (enabled == false) { + listener.onFailure(new IllegalStateException("Native role mapping management is disabled")); + return; + } final SecurityIndexManager frozenSecurityIndex = securityIndex.defensiveCopy(); if (frozenSecurityIndex.indexExists() == false) { listener.onResponse(false); @@ -307,7 +314,9 @@ public void onFailure(Exception e) { * Otherwise it retrieves the specified mappings by name. */ public void getRoleMappings(Set names, ActionListener> listener) { - if (names == null || names.isEmpty()) { + if (enabled == false) { + listener.onResponse(List.of()); + } else if (names == null || names.isEmpty()) { getMappings(listener); } else { getMappings(listener.safeMap(mappings -> mappings.stream().filter(m -> names.contains(m.getName())).toList())); @@ -315,10 +324,14 @@ public void getRoleMappings(Set names, ActionListener> listener) { + if (enabled == false) { + listener.onResponse(List.of()); + return; + } final SecurityIndexManager frozenSecurityIndex = securityIndex.defensiveCopy(); if (frozenSecurityIndex.indexExists() == false) { logger.debug("The security index does not exist - no role mappings can be loaded"); - listener.onResponse(Collections.emptyList()); + listener.onResponse(List.of()); return; } final List lastLoad = lastLoadRef.get(); @@ -329,7 +342,7 @@ private void getMappings(ActionListener> listener) { listener.onResponse(lastLoad); } else { logger.debug("The security index exists but is closed - no role mappings can be loaded"); - listener.onResponse(Collections.emptyList()); + listener.onResponse(List.of()); } } else if (frozenSecurityIndex.isAvailable(SEARCH_SHARDS) == false) { final ElasticsearchException unavailableReason = frozenSecurityIndex.getUnavailableReason(SEARCH_SHARDS); @@ -365,20 +378,15 @@ List getLastLoad() { * */ public void usageStats(ActionListener> listener) { - if (securityIndex.indexIsClosed() || securityIndex.isAvailable(SEARCH_SHARDS) == false) { - reportStats(listener, Collections.emptyList()); + if (enabled == false) { + reportStats(listener, List.of()); + } else if (securityIndex.indexIsClosed() || securityIndex.isAvailable(SEARCH_SHARDS) == false) { + reportStats(listener, List.of()); } else { getMappings(ActionListener.wrap(mappings -> reportStats(listener, mappings), listener::onFailure)); } } - private static void reportStats(ActionListener> listener, List mappings) { - Map usageStats = new HashMap<>(); - usageStats.put("size", mappings.size()); - usageStats.put("enabled", mappings.stream().filter(ExpressionRoleMapping::isEnabled).count()); - listener.onResponse(usageStats); - } - public void onSecurityIndexStateChange(SecurityIndexManager.State previousState, SecurityIndexManager.State currentState) { if (isMoveFromRedToNonRed(previousState, currentState) || isIndexDeleted(previousState, currentState) @@ -388,28 +396,6 @@ public void onSecurityIndexStateChange(SecurityIndexManager.State previousState, } } - private void refreshRealms(ActionListener listener, Result result) { - if (realmsToRefresh.isEmpty()) { - listener.onResponse(result); - return; - } - - final String[] realmNames = this.realmsToRefresh.toArray(Strings.EMPTY_ARRAY); - executeAsyncWithOrigin( - client, - SECURITY_ORIGIN, - ClearRealmCacheAction.INSTANCE, - new 
ClearRealmCacheRequest().realms(realmNames), - ActionListener.wrap(response -> { - logger.debug(() -> format("Cleared cached in realms [%s] due to role mapping change", Arrays.toString(realmNames))); - listener.onResponse(result); - }, ex -> { - logger.warn(() -> "Failed to clear cache for realms [" + Arrays.toString(realmNames) + "]", ex); - listener.onFailure(ex); - }) - ); - } - @Override public void resolveRoles(UserData user, ActionListener> listener) { getRoleMappings(null, ActionListener.wrap(mappings -> { @@ -438,4 +424,57 @@ public void resolveRoles(UserData user, ActionListener> listener) { public void refreshRealmOnChange(CachingRealm realm) { realmsToRefresh.add(realm.name()); } + + private void refreshRealms(ActionListener listener, Result result) { + if (enabled == false || realmsToRefresh.isEmpty()) { + listener.onResponse(result); + return; + } + final String[] realmNames = this.realmsToRefresh.toArray(Strings.EMPTY_ARRAY); + executeAsyncWithOrigin( + client, + SECURITY_ORIGIN, + ClearRealmCacheAction.INSTANCE, + new ClearRealmCacheRequest().realms(realmNames), + ActionListener.wrap(response -> { + logger.debug(() -> format("Cleared cached in realms [%s] due to role mapping change", Arrays.toString(realmNames))); + listener.onResponse(result); + }, ex -> { + logger.warn(() -> "Failed to clear cache for realms [" + Arrays.toString(realmNames) + "]", ex); + listener.onFailure(ex); + }) + ); + } + + protected static ExpressionRoleMapping buildMapping(String id, BytesReference source) { + try ( + XContentParser parser = XContentHelper.createParserNotCompressed( + LoggingDeprecationHandler.XCONTENT_PARSER_CONFIG, + source, + XContentType.JSON + ) + ) { + return ExpressionRoleMapping.parse(id, parser); + } catch (Exception e) { + logger.warn(() -> "Role mapping [" + id + "] cannot be parsed and will be skipped", e); + return null; + } + } + + // package-private for testing + static String getIdForName(String name) { + return ID_PREFIX + name; + } + + private static void reportStats(ActionListener> listener, List mappings) { + Map usageStats = new HashMap<>(); + usageStats.put("size", mappings.size()); + usageStats.put("enabled", mappings.stream().filter(ExpressionRoleMapping::isEnabled).count()); + listener.onResponse(usageStats); + } + + private static String getNameFromId(String id) { + assert id.startsWith(ID_PREFIX); + return id.substring(ID_PREFIX.length()); + } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/NativeRoleMappingBaseRestHandler.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/NativeRoleMappingBaseRestHandler.java new file mode 100644 index 0000000000000..e0d692814988b --- /dev/null +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/NativeRoleMappingBaseRestHandler.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.security.rest.action.rolemapping; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.security.authc.support.mapper.NativeRoleMappingStore; +import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler; + +abstract class NativeRoleMappingBaseRestHandler extends SecurityBaseRestHandler { + + private static final Logger logger = LogManager.getLogger(NativeRoleMappingBaseRestHandler.class); + + NativeRoleMappingBaseRestHandler(Settings settings, XPackLicenseState licenseState) { + super(settings, licenseState); + } + + @Override + protected Exception innerCheckFeatureAvailable(RestRequest request) { + Boolean nativeRoleMappingsEnabled = settings.getAsBoolean(NativeRoleMappingStore.NATIVE_ROLE_MAPPINGS_ENABLED, true); + if (nativeRoleMappingsEnabled == false) { + logger.debug( + "Attempt to call [{} {}] but [{}] is [{}]", + request.method(), + request.rawPath(), + NativeRoleMappingStore.NATIVE_ROLE_MAPPINGS_ENABLED, + settings.get(NativeRoleMappingStore.NATIVE_ROLE_MAPPINGS_ENABLED) + ); + return new ElasticsearchStatusException( + "Native role mapping management is not enabled in this Elasticsearch instance", + RestStatus.GONE + ); + } else { + return null; + } + } +} diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestDeleteRoleMappingAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestDeleteRoleMappingAction.java index ee1952e359dd2..5964228009c4b 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestDeleteRoleMappingAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestDeleteRoleMappingAction.java @@ -19,7 +19,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.security.action.rolemapping.DeleteRoleMappingRequestBuilder; import org.elasticsearch.xpack.core.security.action.rolemapping.DeleteRoleMappingResponse; -import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler; import java.io.IOException; import java.util.List; @@ -30,7 +29,7 @@ * Rest endpoint to delete a role-mapping from the {@link org.elasticsearch.xpack.security.authc.support.mapper.NativeRoleMappingStore} */ @ServerlessScope(Scope.INTERNAL) -public class RestDeleteRoleMappingAction extends SecurityBaseRestHandler { +public class RestDeleteRoleMappingAction extends NativeRoleMappingBaseRestHandler { public RestDeleteRoleMappingAction(Settings settings, XPackLicenseState licenseState) { super(settings, licenseState); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestGetRoleMappingsAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestGetRoleMappingsAction.java index 36b3f05668d0a..7a3378d843bca 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestGetRoleMappingsAction.java +++ 
b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestGetRoleMappingsAction.java @@ -20,7 +20,6 @@ import org.elasticsearch.xpack.core.security.action.rolemapping.GetRoleMappingsRequestBuilder; import org.elasticsearch.xpack.core.security.action.rolemapping.GetRoleMappingsResponse; import org.elasticsearch.xpack.core.security.authc.support.mapper.ExpressionRoleMapping; -import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler; import java.io.IOException; import java.util.List; @@ -31,7 +30,7 @@ * Rest endpoint to retrieve a role-mapping from the org.elasticsearch.xpack.security.authc.support.mapper.NativeRoleMappingStore */ @ServerlessScope(Scope.INTERNAL) -public class RestGetRoleMappingsAction extends SecurityBaseRestHandler { +public class RestGetRoleMappingsAction extends NativeRoleMappingBaseRestHandler { public RestGetRoleMappingsAction(Settings settings, XPackLicenseState licenseState) { super(settings, licenseState); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestPutRoleMappingAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestPutRoleMappingAction.java index bb6b07c1c3c95..e7e24037543fa 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestPutRoleMappingAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/rolemapping/RestPutRoleMappingAction.java @@ -19,7 +19,6 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.security.action.rolemapping.PutRoleMappingRequestBuilder; import org.elasticsearch.xpack.core.security.action.rolemapping.PutRoleMappingResponse; -import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler; import java.io.IOException; import java.util.List; @@ -33,7 +32,7 @@ * @see org.elasticsearch.xpack.security.authc.support.mapper.NativeRoleMappingStore */ @ServerlessScope(Scope.INTERNAL) -public class RestPutRoleMappingAction extends SecurityBaseRestHandler { +public class RestPutRoleMappingAction extends NativeRoleMappingBaseRestHandler { public RestPutRoleMappingAction(Settings settings, XPackLicenseState licenseState) { super(settings, licenseState); From c20ba4df840954d4e2c60f494c66b7c4100b06ea Mon Sep 17 00:00:00 2001 From: David Turner Date: Fri, 12 Apr 2024 12:19:05 +0100 Subject: [PATCH 03/20] Verify target node during primary relocation (#107407) If a primary relocation is cancelled and then restarted on a different node then today we may fail with a somewhat-opaque message saying that the `relocation target is no longer part of the replication group`. This commit adds a preliminary check on the target node ID so that we can fail in this situation with a message that should be easier for users to understand. 
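In essence the fix adds a target-node guard in verifyRelocatingState (condensed here from the diff below; the surrounding shard-state checks are elided):

    // Fail fast when the in-flight handoff targets a node other than the
    // routing entry's current relocation target.
    if (targetNodeId != null && targetNodeId.equals(shardRouting.relocatingNodeId()) == false) {
        throw new IllegalIndexShardStateException(
            shardId,
            IndexShardState.STARTED,
            ": shard is no longer relocating to node [" + targetNodeId + "]: " + shardRouting
        );
    }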
--- .../elasticsearch/index/shard/IndexShard.java | 28 ++++++++-- .../index/shard/IndexShardTests.java | 54 +++++++++++++------ 2 files changed, 64 insertions(+), 18 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index 046483a6b074f..ccc0ccd30d578 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -757,6 +757,16 @@ public IndexShardState markAsRecovering(String reason, RecoveryState recoverySta private final AtomicBoolean primaryReplicaResyncInProgress = new AtomicBoolean(); + // temporary compatibility shim while adding targetNodeId parameter to dependencies + @Deprecated(forRemoval = true) + public void relocated( + final String targetAllocationId, + final BiConsumer> consumer, + final ActionListener listener + ) throws IllegalIndexShardStateException, IllegalStateException { + relocated(null, targetAllocationId, consumer, listener); + } + /** * Completes the relocation. Operations are blocked and current operations are drained before changing state to relocated. The provided * {@link BiConsumer} is executed after all operations are successfully blocked. @@ -768,6 +778,7 @@ public IndexShardState markAsRecovering(String reason, RecoveryState recoverySta * @throws IllegalStateException if the relocation target is no longer part of the replication group */ public void relocated( + final String targetNodeId, final String targetAllocationId, final BiConsumer> consumer, final ActionListener listener @@ -788,7 +799,7 @@ public void onResponse(Releasable releasable) { * context via a network operation. Doing this under the mutex can implicitly block the cluster state update thread * on network operations. 
*/ - verifyRelocatingState(); + verifyRelocatingState(targetNodeId); final ReplicationTracker.PrimaryContext primaryContext = replicationTracker.startRelocationHandoff( targetAllocationId ); @@ -803,7 +814,7 @@ public void onResponse(Void unused) { try { // make changes to primaryMode and relocated flag only under mutex synchronized (mutex) { - verifyRelocatingState(); + verifyRelocatingState(targetNodeId); replicationTracker.completeRelocationHandoff(); } wrappedInnerListener.onResponse(null); @@ -857,7 +868,8 @@ public void onFailure(Exception e) { } } - private void verifyRelocatingState() { + // TODO only nullable temporarily, remove once deprecated relocated() override is removed, see ES-6725 + private void verifyRelocatingState(@Nullable String targetNodeId) { if (state != IndexShardState.STARTED) { throw new IndexShardNotStartedException(shardId, state); } @@ -871,6 +883,16 @@ private void verifyRelocatingState() { throw new IllegalIndexShardStateException(shardId, IndexShardState.STARTED, ": shard is no longer relocating " + shardRouting); } + if (targetNodeId != null) { + if (targetNodeId.equals(shardRouting.relocatingNodeId()) == false) { + throw new IllegalIndexShardStateException( + shardId, + IndexShardState.STARTED, + ": shard is no longer relocating to node [" + targetNodeId + "]: " + shardRouting + ); + } + } + if (primaryReplicaResyncInProgress.get()) { throw new IllegalIndexShardStateException( shardId, diff --git a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java index c2706a7a3cf22..73d6fb1a184a2 100644 --- a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java +++ b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java @@ -1916,10 +1916,15 @@ public void testDelayedOperationsBeforeAndAfterRelocated() throws Exception { Thread recoveryThread = new Thread(() -> { try { startRecovery.await(); - shard.relocated(routing.getTargetRelocatingShard().allocationId().getId(), (primaryContext, listener) -> { - relocationStarted.countDown(); - listener.onResponse(null); - }, ActionListener.noop()); + shard.relocated( + routing.relocatingNodeId(), + routing.getTargetRelocatingShard().allocationId().getId(), + (primaryContext, listener) -> { + relocationStarted.countDown(); + listener.onResponse(null); + }, + ActionListener.noop() + ); } catch (InterruptedException e) { throw new RuntimeException(e); } @@ -2123,29 +2128,48 @@ protected void doRun() throws Exception { closeShards(shard); } - public void testRelocateMissingTarget() throws Exception { + public void testRelocateMismatchedTarget() throws Exception { final IndexShard shard = newStartedShard(true); final ShardRouting original = shard.routingEntry(); - final ShardRouting toNode1 = ShardRoutingHelper.relocate(original, "node_1"); - IndexShardTestCase.updateRoutingEntry(shard, toNode1); + + final ShardRouting wrongTargetNodeShardRouting = ShardRoutingHelper.relocate(original, "node_1"); + IndexShardTestCase.updateRoutingEntry(shard, wrongTargetNodeShardRouting); + IndexShardTestCase.updateRoutingEntry(shard, original); + + final ShardRouting wrongTargetAllocationIdShardRouting = ShardRoutingHelper.relocate(original, "node_2"); + IndexShardTestCase.updateRoutingEntry(shard, wrongTargetAllocationIdShardRouting); IndexShardTestCase.updateRoutingEntry(shard, original); - final ShardRouting toNode2 = ShardRoutingHelper.relocate(original, "node_2"); - IndexShardTestCase.updateRoutingEntry(shard, 
toNode2); + + final ShardRouting correctShardRouting = ShardRoutingHelper.relocate(original, "node_2"); + IndexShardTestCase.updateRoutingEntry(shard, correctShardRouting); + final AtomicBoolean relocated = new AtomicBoolean(); - final IllegalStateException error = expectThrows( + + final IllegalIndexShardStateException wrongNodeException = expectThrows( + IllegalIndexShardStateException.class, + () -> blockingCallRelocated(shard, wrongTargetNodeShardRouting, (ctx, listener) -> relocated.set(true)) + ); + assertThat( + wrongNodeException.getMessage(), + equalTo("CurrentState[STARTED] : shard is no longer relocating to node [node_1]: " + correctShardRouting) + ); + assertFalse(relocated.get()); + + final IllegalStateException wrongTargetIdException = expectThrows( IllegalStateException.class, - () -> blockingCallRelocated(shard, toNode1, (ctx, listener) -> relocated.set(true)) + () -> blockingCallRelocated(shard, wrongTargetAllocationIdShardRouting, (ctx, listener) -> relocated.set(true)) ); assertThat( - error.getMessage(), + wrongTargetIdException.getMessage(), equalTo( "relocation target [" - + toNode1.getTargetRelocatingShard().allocationId().getId() + + wrongTargetAllocationIdShardRouting.getTargetRelocatingShard().allocationId().getId() + "] is no longer part of the replication group" ) ); assertFalse(relocated.get()); - blockingCallRelocated(shard, toNode2, (ctx, listener) -> { + + blockingCallRelocated(shard, correctShardRouting, (ctx, listener) -> { relocated.set(true); listener.onResponse(null); }); @@ -4937,7 +4961,7 @@ private static void blockingCallRelocated( BiConsumer> consumer ) { PlainActionFuture.get( - f -> indexShard.relocated(routing.getTargetRelocatingShard().allocationId().getId(), consumer, f) + f -> indexShard.relocated(routing.relocatingNodeId(), routing.getTargetRelocatingShard().allocationId().getId(), consumer, f) ); } } From 7c3dac181aa13e509ffdfa5d3996b2a0355de8cf Mon Sep 17 00:00:00 2001 From: Fang Xing <155562079+fang-xing-esql@users.noreply.github.com> Date: Fri, 12 Apr 2024 08:44:17 -0400 Subject: [PATCH 04/20] support binary comparison in implicit casting (#107388) --- .../xpack/esql/analysis/Analyzer.java | 65 ++----------------- .../xpack/esql/analysis/AnalyzerTests.java | 30 ++++----- .../xpack/esql/analysis/VerifierTests.java | 2 +- 3 files changed, 17 insertions(+), 80 deletions(-) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 13e088b81c95f..02969ed56798f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -52,6 +52,7 @@ import org.elasticsearch.xpack.ql.expression.function.FunctionRegistry; import org.elasticsearch.xpack.ql.expression.function.UnresolvedFunction; import org.elasticsearch.xpack.ql.expression.function.scalar.ScalarFunction; +import org.elasticsearch.xpack.ql.expression.predicate.BinaryOperator; import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.ql.index.EsIndex; import org.elasticsearch.xpack.ql.plan.TableIdentifier; @@ -62,7 +63,6 @@ import org.elasticsearch.xpack.ql.plan.logical.Project; import org.elasticsearch.xpack.ql.rule.ParameterizedRule; import org.elasticsearch.xpack.ql.rule.ParameterizedRuleExecutor; -import org.elasticsearch.xpack.ql.rule.Rule; import 
org.elasticsearch.xpack.ql.rule.RuleExecutor; import org.elasticsearch.xpack.ql.session.Configuration; import org.elasticsearch.xpack.ql.tree.Source; @@ -94,7 +94,6 @@ import static org.elasticsearch.common.logging.LoggerMessageFormat.format; import static org.elasticsearch.xpack.core.enrich.EnrichPolicy.GEO_MATCH_TYPE; import static org.elasticsearch.xpack.esql.stats.FeatureMetric.LIMIT; -import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.dateTimeToLong; import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_POINT; import static org.elasticsearch.xpack.esql.type.EsqlDataTypes.GEO_SHAPE; import static org.elasticsearch.xpack.ql.type.DataTypes.DATETIME; @@ -124,7 +123,7 @@ public class Analyzer extends ParameterizedRuleExecutor("Finish Analysis", Limiter.ONCE, new AddImplicitLimit(), new PromoteStringsInDateComparisons()); + var finish = new Batch<>("Finish Analysis", Limiter.ONCE, new AddImplicitLimit()); rules = List.of(resolution, finish); } @@ -778,58 +777,6 @@ public LogicalPlan apply(LogicalPlan logicalPlan, AnalyzerContext context) { } } - private static class PromoteStringsInDateComparisons extends Rule { - - @Override - public LogicalPlan apply(LogicalPlan plan) { - return plan.transformExpressionsUp(BinaryComparison.class, PromoteStringsInDateComparisons::promote); - } - - private static Expression promote(BinaryComparison cmp) { - if (cmp.resolved() == false) { - return cmp; - } - var left = cmp.left(); - var right = cmp.right(); - boolean modified = false; - if (left.dataType() == DATETIME) { - if (right.dataType() == KEYWORD && right.foldable() && ((right instanceof EsqlScalarFunction) == false)) { - right = stringToDate(right); - modified = true; - } - } else { - if (right.dataType() == DATETIME) { - if (left.dataType() == KEYWORD && left.foldable() && ((left instanceof EsqlScalarFunction) == false)) { - left = stringToDate(left); - modified = true; - } - } - } - return modified ? cmp.replaceChildren(List.of(left, right)) : cmp; - } - - private static Expression stringToDate(Expression stringExpression) { - var str = stringExpression.fold().toString(); - - Long millis = null; - // TODO: better control over this string format - do we want this to be flexible or always redirect folks to use date parsing - try { - millis = str == null ? 
null : dateTimeToLong(str); - } catch (Exception ex) { // in case of exception, millis will be null which will trigger an error - } - - var source = stringExpression.source(); - Expression result; - if (millis == null) { - var errorMessage = format(null, "Invalid date [{}]", str); - result = new UnresolvedAttribute(source, source.text(), null, errorMessage); - } else { - result = new Literal(source, millis, DATETIME); - } - return result; - } - } - private BitSet gatherPreAnalysisMetrics(LogicalPlan plan, BitSet b) { // count only the explicit "limit" the user added, otherwise all queries will have a "limit" and telemetry won't reflect reality if (plan.collectFirstChildren(Limit.class::isInstance).isEmpty() == false) { @@ -852,11 +799,9 @@ private static Expression cast(ScalarFunction f, EsqlFunctionRegistry registry) if (f instanceof EsqlScalarFunction esf) { return processScalarFunction(esf, registry); } - - if (f instanceof EsqlArithmeticOperation eao) { - return processArithmeticOperation(eao); + if (f instanceof EsqlArithmeticOperation || f instanceof BinaryComparison) { + return processBinaryOperator((BinaryOperator) f); } - return f; } @@ -888,7 +833,7 @@ private static Expression processScalarFunction(EsqlScalarFunction f, EsqlFuncti return childrenChanged ? f.replaceChildren(newChildren) : f; } - private static Expression processArithmeticOperation(EsqlArithmeticOperation o) { + private static Expression processBinaryOperator(BinaryOperator o) { Expression left = o.left(); Expression right = o.right(); if (left.resolved() == false || right.resolved() == false) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index f4ecf38915a29..7a85ca1628048 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -1003,13 +1003,7 @@ public void testCompareIntToString() { from test | where emp_no COMPARISON "foo" """.replace("COMPARISON", comparison))); - assertThat( - e.getMessage(), - containsString( - "first argument of [emp_no COMPARISON \"foo\"] is [numeric] so second argument must also be [numeric] but was [keyword]" - .replace("COMPARISON", comparison) - ) - ); + assertThat(e.getMessage(), containsString("Cannot convert string [foo] to [INTEGER]".replace("COMPARISON", comparison))); } } @@ -1019,13 +1013,7 @@ public void testCompareStringToInt() { from test | where "foo" COMPARISON emp_no """.replace("COMPARISON", comparison))); - assertThat( - e.getMessage(), - containsString( - "first argument of [\"foo\" COMPARISON emp_no] is [keyword] so second argument must also be [keyword] but was [integer]" - .replace("COMPARISON", comparison) - ) - ); + assertThat(e.getMessage(), containsString("Cannot convert string [foo] to [INTEGER]".replace("COMPARISON", comparison))); } } @@ -1051,11 +1039,15 @@ public void testCompareStringToDate() { public void testCompareDateToStringFails() { for (String comparison : COMPARISONS) { - verifyUnsupported(""" - from test - | where date COMPARISON "not-a-date" - | keep date - """.replace("COMPARISON", comparison), "Invalid date [not-a-date]", "mapping-multi-field-variation.json"); + verifyUnsupported( + """ + from test + | where date COMPARISON "not-a-date" + | keep date + """.replace("COMPARISON", comparison), + "Cannot convert string [not-a-date] to [DATETIME]", + 
"mapping-multi-field-variation.json" + ); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index e558dbe615642..8275f76d9a55c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -317,7 +317,7 @@ public void testSumOnDate() { public void testWrongInputParam() { assertEquals( - "1:19: first argument of [emp_no == ?] is [numeric] so second argument must also be [numeric] but was [keyword]", + "1:29: Cannot convert string [foo] to [INTEGER], error [Cannot parse number [foo]]", error("from test | where emp_no == ?", "foo") ); From 6a300509cdbb03de32a55a0f5c52bb35ddb7576f Mon Sep 17 00:00:00 2001 From: Artem Prigoda Date: Fri, 12 Apr 2024 15:04:08 +0200 Subject: [PATCH 05/20] Add metric for calculating index flush time excluding waiting on locks (#107196) Add a new `total_time_excluding_waiting_on_lock metric` to the index flush stats that measures the flushing time excluding waiting on the flush lock. This metrics provides a more granular view on flush performance and without the overhead of flush throttling. Resolves ES-7201 --- docs/changelog/107196.yaml | 5 +++ docs/reference/cluster/nodes-stats.asciidoc | 1 + .../org/elasticsearch/TransportVersions.java | 1 + .../elasticsearch/index/engine/Engine.java | 7 ++++ .../index/engine/InternalEngine.java | 11 ++++++ .../elasticsearch/index/flush/FlushStats.java | 37 ++++++++++++++++--- .../elasticsearch/index/shard/IndexShard.java | 7 +++- .../monitor/metrics/NodeMetrics.java | 24 ++++++++++++ .../cluster/node/stats/NodeStatsTests.java | 2 +- .../index/shard/IndexShardTests.java | 36 ++++++++++++++++++ 10 files changed, 124 insertions(+), 7 deletions(-) create mode 100644 docs/changelog/107196.yaml diff --git a/docs/changelog/107196.yaml b/docs/changelog/107196.yaml new file mode 100644 index 0000000000000..9892ccf71856f --- /dev/null +++ b/docs/changelog/107196.yaml @@ -0,0 +1,5 @@ +pr: 107196 +summary: Add metric for calculating index flush time excluding waiting on locks +area: Engine +type: enhancement +issues: [] diff --git a/docs/reference/cluster/nodes-stats.asciidoc b/docs/reference/cluster/nodes-stats.asciidoc index c008b074acccd..07328ba98bcec 100644 --- a/docs/reference/cluster/nodes-stats.asciidoc +++ b/docs/reference/cluster/nodes-stats.asciidoc @@ -626,6 +626,7 @@ Total time spent performing flush operations. (integer) Total time in milliseconds spent performing flush operations. + ======= `warmer`:: diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 50008832712b4..5fc62afb0c27d 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -171,6 +171,7 @@ static TransportVersion def(int id) { public static final TransportVersion MODIFY_DATA_STREAM_FAILURE_STORES = def(8_630_00_0); public static final TransportVersion ML_INFERENCE_RERANK_NEW_RESPONSE_FORMAT = def(8_631_00_0); public static final TransportVersion HIGHLIGHTERS_TAGS_ON_FIELD_LEVEL = def(8_632_00_0); + public static final TransportVersion TRACK_FLUSH_TIME_EXCLUDING_WAITING_ON_LOCKS = def(8_633_00_0); /* * STOP! READ THIS FIRST! 
No, really, diff --git a/server/src/main/java/org/elasticsearch/index/engine/Engine.java b/server/src/main/java/org/elasticsearch/index/engine/Engine.java index a910e496ce1b5..8ee536ec72248 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/Engine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/Engine.java @@ -395,6 +395,13 @@ boolean throttleLockIsHeldByCurrentThread() { // to be used in assertions and te */ public abstract void trimOperationsFromTranslog(long belowTerm, long aboveSeqNo) throws EngineException; + /** + * Returns the total time flushes have been executed excluding waiting on locks. + */ + public long getTotalFlushTimeExcludingWaitingOnLockInMillis() { + return 0; + } + /** A Lock implementation that always allows the lock to be acquired */ protected static final class NoOpLock implements Lock { diff --git a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java index 0b2e83532d030..d4371d71a4324 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java @@ -52,6 +52,7 @@ import org.elasticsearch.common.lucene.uid.VersionsAndSeqNoResolver; import org.elasticsearch.common.lucene.uid.VersionsAndSeqNoResolver.DocIdAndSeqNo; import org.elasticsearch.common.metrics.CounterMetric; +import org.elasticsearch.common.metrics.MeanMetric; import org.elasticsearch.common.unit.ByteSizeUnit; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.Maps; @@ -107,6 +108,7 @@ import java.util.Optional; import java.util.Set; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -177,6 +179,8 @@ public class InternalEngine extends Engine { private final CounterMetric numDocDeletes = new CounterMetric(); private final CounterMetric numDocAppends = new CounterMetric(); private final CounterMetric numDocUpdates = new CounterMetric(); + private final MeanMetric totalFlushTimeExcludingWaitingOnLock = new MeanMetric(); + private final NumericDocValuesField softDeletesField = Lucene.newSoftDeletesField(); private final SoftDeletesPolicy softDeletesPolicy; private final LastRefreshedCheckpointListener lastRefreshedCheckpointListener; @@ -2195,6 +2199,7 @@ protected void flushHoldingLock(boolean force, boolean waitIfOngoing, ActionList logger.trace("acquired flush lock immediately"); } + final long startTime = System.nanoTime(); try { // Only flush if (1) Lucene has uncommitted docs, or (2) forced by caller, or (3) the // newly created commit points to a different translog generation (can free translog), @@ -2246,6 +2251,7 @@ protected void flushHoldingLock(boolean force, boolean waitIfOngoing, ActionList listener.onFailure(e); return; } finally { + totalFlushTimeExcludingWaitingOnLock.inc(System.nanoTime() - startTime); flushLock.unlock(); logger.trace("released flush lock"); } @@ -3066,6 +3072,11 @@ long getNumDocUpdates() { return numDocUpdates.count(); } + @Override + public long getTotalFlushTimeExcludingWaitingOnLockInMillis() { + return TimeUnit.NANOSECONDS.toMillis(totalFlushTimeExcludingWaitingOnLock.sum()); + } + @Override public int countChanges(String source, long fromSeqNo, long toSeqNo) throws IOException { ensureOpen(); diff --git 
a/server/src/main/java/org/elasticsearch/index/flush/FlushStats.java b/server/src/main/java/org/elasticsearch/index/flush/FlushStats.java index 7114b7b0e5c4f..e514a6d2adac0 100644 --- a/server/src/main/java/org/elasticsearch/index/flush/FlushStats.java +++ b/server/src/main/java/org/elasticsearch/index/flush/FlushStats.java @@ -8,6 +8,7 @@ package org.elasticsearch.index.flush; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; @@ -23,6 +24,7 @@ public class FlushStats implements Writeable, ToXContentFragment { private long total; private long periodic; private long totalTimeInMillis; + private long totalTimeExcludingWaitingOnLockInMillis; public FlushStats() { @@ -32,18 +34,22 @@ public FlushStats(StreamInput in) throws IOException { total = in.readVLong(); totalTimeInMillis = in.readVLong(); periodic = in.readVLong(); + totalTimeExcludingWaitingOnLockInMillis = in.getTransportVersion() + .onOrAfter(TransportVersions.TRACK_FLUSH_TIME_EXCLUDING_WAITING_ON_LOCKS) ? in.readVLong() : 0L; } - public FlushStats(long total, long periodic, long totalTimeInMillis) { + public FlushStats(long total, long periodic, long totalTimeInMillis, long totalTimeExcludingWaitingOnLockInMillis) { this.total = total; this.periodic = periodic; this.totalTimeInMillis = totalTimeInMillis; + this.totalTimeExcludingWaitingOnLockInMillis = totalTimeExcludingWaitingOnLockInMillis; } - public void add(long total, long periodic, long totalTimeInMillis) { + public void add(long total, long periodic, long totalTimeInMillis, long totalTimeWithoutWaitingInMillis) { this.total += total; this.periodic += periodic; this.totalTimeInMillis += totalTimeInMillis; + this.totalTimeExcludingWaitingOnLockInMillis += totalTimeWithoutWaitingInMillis; } public void add(FlushStats flushStats) { @@ -57,6 +63,7 @@ public void addTotals(FlushStats flushStats) { this.total += flushStats.total; this.periodic += flushStats.periodic; this.totalTimeInMillis += flushStats.totalTimeInMillis; + this.totalTimeExcludingWaitingOnLockInMillis += flushStats.totalTimeExcludingWaitingOnLockInMillis; } /** @@ -81,18 +88,30 @@ public long getTotalTimeInMillis() { } /** - * The total time merges have been executed. + * The total time flushes have been executed. */ public TimeValue getTotalTime() { return new TimeValue(totalTimeInMillis); } + /** + * The total time flushes have been executed excluding waiting time on locks (in milliseconds). 
+ */ + public long getTotalTimeExcludingWaitingOnLockMillis() { + return totalTimeExcludingWaitingOnLockInMillis; + } + @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(Fields.FLUSH); builder.field(Fields.TOTAL, total); builder.field(Fields.PERIODIC, periodic); builder.humanReadableField(Fields.TOTAL_TIME_IN_MILLIS, Fields.TOTAL_TIME, getTotalTime()); + builder.humanReadableField( + Fields.TOTAL_TIME_EXCLUDING_WAITING_ON_LOCK_IN_MILLIS, + Fields.TOTAL_TIME_EXCLUDING_WAITING, + new TimeValue(getTotalTimeExcludingWaitingOnLockMillis()) + ); builder.endObject(); return builder; } @@ -103,6 +122,8 @@ static final class Fields { static final String PERIODIC = "periodic"; static final String TOTAL_TIME = "total_time"; static final String TOTAL_TIME_IN_MILLIS = "total_time_in_millis"; + static final String TOTAL_TIME_EXCLUDING_WAITING = "total_time_excluding_waiting"; + static final String TOTAL_TIME_EXCLUDING_WAITING_ON_LOCK_IN_MILLIS = "total_time_excluding_waiting_on_lock_in_millis"; } @Override @@ -110,6 +131,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeVLong(total); out.writeVLong(totalTimeInMillis); out.writeVLong(periodic); + if (out.getTransportVersion().onOrAfter(TransportVersions.TRACK_FLUSH_TIME_EXCLUDING_WAITING_ON_LOCKS)) { + out.writeVLong(totalTimeExcludingWaitingOnLockInMillis); + } } @Override @@ -117,11 +141,14 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; FlushStats that = (FlushStats) o; - return total == that.total && totalTimeInMillis == that.totalTimeInMillis && periodic == that.periodic; + return total == that.total + && totalTimeInMillis == that.totalTimeInMillis + && periodic == that.periodic + && totalTimeExcludingWaitingOnLockInMillis == that.totalTimeExcludingWaitingOnLockInMillis; } @Override public int hashCode() { - return Objects.hash(total, totalTimeInMillis, periodic); + return Objects.hash(total, totalTimeInMillis, periodic, totalTimeExcludingWaitingOnLockInMillis); } } diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index ccc0ccd30d578..a52b289493cd6 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -1329,7 +1329,12 @@ public RefreshStats refreshStats() { } public FlushStats flushStats() { - return new FlushStats(flushMetric.count(), periodicFlushMetric.count(), TimeUnit.NANOSECONDS.toMillis(flushMetric.sum())); + return new FlushStats( + flushMetric.count(), + periodicFlushMetric.count(), + TimeUnit.NANOSECONDS.toMillis(flushMetric.sum()), + getEngineOrNull() != null ? 
getEngineOrNull().getTotalFlushTimeExcludingWaitingOnLockInMillis() : 0L + ); } public DocsStats docStats() { diff --git a/server/src/main/java/org/elasticsearch/monitor/metrics/NodeMetrics.java b/server/src/main/java/org/elasticsearch/monitor/metrics/NodeMetrics.java index e689898b05da6..68cbcdb5657f9 100644 --- a/server/src/main/java/org/elasticsearch/monitor/metrics/NodeMetrics.java +++ b/server/src/main/java/org/elasticsearch/monitor/metrics/NodeMetrics.java @@ -651,6 +651,29 @@ private void registerAsyncMetrics(MeterRegistry registry) { ) ); + metrics.add( + registry.registerLongAsyncCounter( + "es.flush.total.time", + "The total time flushes have been executed excluding waiting time on locks", + "milliseconds", + () -> new LongWithAttributes( + stats.getOrRefresh() != null ? stats.getOrRefresh().getIndices().getFlush().getTotalTimeInMillis() : 0L + ) + ) + ); + + metrics.add( + registry.registerLongAsyncCounter( + "es.flush.total_excluding_lock_waiting.time", + "The total time flushes have been executed excluding waiting time on locks", + "milliseconds", + () -> new LongWithAttributes( + stats.getOrRefresh() != null + ? stats.getOrRefresh().getIndices().getFlush().getTotalTimeExcludingWaitingOnLockMillis() + : 0L + ) + ) + ); } /** @@ -680,6 +703,7 @@ private long bytesUsedByGCGen(Optional optionalMem, String name) { private NodeStats getNodeStats() { CommonStatsFlags flags = new CommonStatsFlags( CommonStatsFlags.Flag.Indexing, + CommonStatsFlags.Flag.Flush, CommonStatsFlags.Flag.Get, CommonStatsFlags.Flag.Search, CommonStatsFlags.Flag.Merge, diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodeStatsTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodeStatsTests.java index b91ea304c5da6..e502904004fef 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodeStatsTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/node/stats/NodeStatsTests.java @@ -628,7 +628,7 @@ private static CommonStats createShardLevelCommonStats() { indicesCommonStats.getMerge().add(mergeStats); indicesCommonStats.getRefresh().add(new RefreshStats(++iota, ++iota, ++iota, ++iota, ++iota)); - indicesCommonStats.getFlush().add(new FlushStats(++iota, ++iota, ++iota)); + indicesCommonStats.getFlush().add(new FlushStats(++iota, ++iota, ++iota, ++iota)); indicesCommonStats.getWarmer().add(new WarmerStats(++iota, ++iota, ++iota)); indicesCommonStats.getCompletion().add(new CompletionStats(++iota, null)); indicesCommonStats.getTranslog().add(new TranslogStats(++iota, ++iota, ++iota, ++iota, ++iota)); diff --git a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java index 73d6fb1a184a2..df4bde959d6ca 100644 --- a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java +++ b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java @@ -89,6 +89,7 @@ import org.elasticsearch.index.fielddata.IndexFieldData; import org.elasticsearch.index.fielddata.IndexFieldDataCache; import org.elasticsearch.index.fielddata.IndexFieldDataService; +import org.elasticsearch.index.flush.FlushStats; import org.elasticsearch.index.mapper.DocumentParsingException; import org.elasticsearch.index.mapper.IdFieldMapper; import org.elasticsearch.index.mapper.LuceneDocument; @@ -155,6 +156,7 @@ import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; import 
java.util.concurrent.Executor; +import java.util.concurrent.Executors; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; @@ -192,6 +194,7 @@ import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.matchesRegex; import static org.hamcrest.Matchers.not; @@ -3998,6 +4001,39 @@ public void testFlushOnIdle() throws Exception { closeShards(shard); } + public void testFlushTimeExcludingWaiting() throws Exception { + IndexShard shard = newStartedShard(); + for (int i = 0; i < randomIntBetween(4, 10); i++) { + indexDoc(shard, "_doc", Integer.toString(i)); + } + + int numFlushes = randomIntBetween(2, 5); + var flushesLatch = new CountDownLatch(numFlushes); + var executor = Executors.newFixedThreadPool(numFlushes); + for (int i = 0; i < numFlushes; i++) { + executor.submit(() -> { + shard.flush(new FlushRequest().waitIfOngoing(true).force(true)); + flushesLatch.countDown(); + }); + } + safeAwait(flushesLatch); + + FlushStats flushStats = shard.flushStats(); + assertThat( + "Flush time excluding waiting should be captured", + flushStats.getTotalTimeExcludingWaitingOnLockMillis(), + greaterThan(0L) + ); + assertThat( + "Flush time excluding waiting should less than flush time with waiting", + flushStats.getTotalTimeExcludingWaitingOnLockMillis(), + lessThan(flushStats.getTotalTime().millis()) + ); + + closeShards(shard); + executor.shutdown(); + } + @TestLogging(reason = "testing traces of concurrent flushes", value = "org.elasticsearch.index.engine.Engine:TRACE") public void testFlushOnIdleConcurrentFlushDoesNotWait() throws Exception { final MockLogAppender mockLogAppender = new MockLogAppender(); From 898df4865cea5eecb1d793138afa1c5c8e9b2b7b Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Fri, 12 Apr 2024 15:17:36 +0200 Subject: [PATCH 06/20] Lossy source mapping validation shouldn't fail if source mode is specified. (#107412) Removed the assert that didn't allow this and added lossy source mapping tests with source mode. --- .../index/mapper/SourceFieldMapper.java | 13 +++++++------ .../index/mapper/SourceFieldMapperTests.java | 13 +++++++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java index 4a6eaa5b26c39..233faf462400b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMapper.java @@ -185,12 +185,13 @@ public SourceFieldMapper build() { if (mode.get() == Mode.DISABLED) { disallowed.add("mode=disabled"); } - assert disallowed.isEmpty() == false; - throw new MapperParsingException( - disallowed.size() == 1 - ? "Parameter [" + disallowed.get(0) + "] is not allowed in source" - : "Parameters [" + String.join(",", disallowed) + "] are not allowed in source" - ); + if (disallowed.isEmpty() == false) { + throw new MapperParsingException( + disallowed.size() == 1 + ? 
"Parameter [" + disallowed.get(0) + "] is not allowed in source" + : "Parameters [" + String.join(",", disallowed) + "] are not allowed in source" + ); + } } SourceFieldMapper sourceFieldMapper = new SourceFieldMapper( mode.get(), diff --git a/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java index 47b8bb3be36b7..a5264512d8086 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMapperTests.java @@ -24,6 +24,7 @@ import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; public class SourceFieldMapperTests extends MetadataMapperTestCase { @@ -242,6 +243,18 @@ public void testSyntheticSourceInTimeSeries() throws IOException { public void testSupportsNonDefaultParameterValues() throws IOException { Settings settings = Settings.builder().put(SourceFieldMapper.LOSSY_PARAMETERS_ALLOWED_SETTING_NAME, false).build(); + { + var sourceFieldMapper = createMapperService(settings, topMapping(b -> b.startObject("_source").endObject())).documentMapper() + .sourceMapper(); + assertThat(sourceFieldMapper, notNullValue()); + } + { + var sourceFieldMapper = createMapperService( + settings, + topMapping(b -> b.startObject("_source").field("mode", randomBoolean() ? "synthetic" : "stored").endObject()) + ).documentMapper().sourceMapper(); + assertThat(sourceFieldMapper, notNullValue()); + } Exception e = expectThrows( MapperParsingException.class, () -> createMapperService(settings, topMapping(b -> b.startObject("_source").field("enabled", false).endObject())) From 469f4e32fb92ce6680b24accda77a1e8f867cfb5 Mon Sep 17 00:00:00 2001 From: Luigi Dell'Aquila Date: Fri, 12 Apr 2024 16:45:10 +0200 Subject: [PATCH 07/20] ESQL: make one more test deterministic (#107423) Minimal double precision errors on multi-node execution: ``` row 0 column 0:0: expected "176.82" but was "176.82000000000002" ``` --- .../esql/qa/testFixtures/src/main/resources/stats.csv-spec | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index 867ff127c90e8..749c44d1f6ece 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -235,7 +235,7 @@ h:double ; sumOfScaledFloat -from employees | stats h = sum(height.scaled_float); +from employees | stats h = sum(height.scaled_float) | eval h = round(h, 10); h:double 176.82 From 96b513a7dea6286df8fc171bf68b2f864a3add05 Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Fri, 12 Apr 2024 08:01:23 -0700 Subject: [PATCH 08/20] Bulk loading enrich fields in ESQL (#106796) Today, the enrich lookup processes input terms one by one: querying one term, then loading enrich fields for matching documents of that term immediately. However, this approach can add significant overhead, such as the driver run loop, creating/releasing many pages, and especially excessive number of I/O seeks during loading _source, fields. This PR accumulates matching documents up to 256 before loading enrich fields. The 256 limit is chosen to avoid a significant sorting cost and long waits for cancellation. 
--- docs/changelog/106796.yaml | 5 + .../esql/enrich/EnrichLookupService.java | 1 + .../enrich/EnrichQuerySourceOperator.java | 135 +++++++++++------- .../EnrichQuerySourceOperatorTests.java | 79 ++++------ 4 files changed, 116 insertions(+), 104 deletions(-) create mode 100644 docs/changelog/106796.yaml diff --git a/docs/changelog/106796.yaml b/docs/changelog/106796.yaml new file mode 100644 index 0000000000000..83eb99dba1603 --- /dev/null +++ b/docs/changelog/106796.yaml @@ -0,0 +1,5 @@ +pr: 106796 +summary: Bulk loading enrich fields in ESQL +area: ES|QL +type: enhancement +issues: [] diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java index e5d4e58d9d61b..366fb4ff55ba6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java @@ -270,6 +270,7 @@ private void doLookup( }; var queryOperator = new EnrichQuerySourceOperator( driverContext.blockFactory(), + EnrichQuerySourceOperator.DEFAULT_MAX_PAGE_SIZE, queryList, searchExecutionContext.getIndexReader() ); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperator.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperator.java index b0582e211fdba..6937f1a8c7772 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperator.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperator.java @@ -15,7 +15,6 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorable; import org.apache.lucene.search.ScoreMode; -import org.apache.lucene.search.Weight; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.DocVector; import org.elasticsearch.compute.data.IntBlock; @@ -36,14 +35,17 @@ final class EnrichQuerySourceOperator extends SourceOperator { private final BlockFactory blockFactory; private final QueryList queryList; - private int queryPosition; - private Weight weight = null; + private int queryPosition = -1; private final IndexReader indexReader; - private int leafIndex = 0; private final IndexSearcher searcher; + private final int maxPageSize; - EnrichQuerySourceOperator(BlockFactory blockFactory, QueryList queryList, IndexReader indexReader) { + // using smaller pages enables quick cancellation and reduces sorting costs + static final int DEFAULT_MAX_PAGE_SIZE = 256; + + EnrichQuerySourceOperator(BlockFactory blockFactory, int maxPageSize, QueryList queryList, IndexReader indexReader) { this.blockFactory = blockFactory; + this.maxPageSize = maxPageSize; this.queryList = queryList; this.indexReader = indexReader; this.searcher = new IndexSearcher(indexReader); @@ -59,62 +61,96 @@ public boolean isFinished() { @Override public Page getOutput() { - if (leafIndex == indexReader.leaves().size()) { - queryPosition++; - leafIndex = 0; - weight = null; - } - if (isFinished()) { - return null; - } - if (weight == null) { - Query query = queryList.getQuery(queryPosition); - if (query != null) { - try { - query = searcher.rewrite(new ConstantScoreQuery(query)); - weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1.0f); - } catch (IOException e) { - throw new UncheckedIOException(e); - } + int 
estimatedSize = Math.min(maxPageSize, queryList.getPositionCount() - queryPosition); + IntVector.Builder positionsBuilder = null; + IntVector.Builder docsBuilder = null; + IntVector.Builder segmentsBuilder = null; + try { + positionsBuilder = blockFactory.newIntVectorBuilder(estimatedSize); + docsBuilder = blockFactory.newIntVectorBuilder(estimatedSize); + if (indexReader.leaves().size() > 1) { + segmentsBuilder = blockFactory.newIntVectorBuilder(estimatedSize); } + int totalMatches = 0; + do { + Query query = nextQuery(); + if (query == null) { + assert isFinished(); + break; + } + query = searcher.rewrite(new ConstantScoreQuery(query)); + final var weight = searcher.createWeight(query, ScoreMode.COMPLETE_NO_SCORES, 1.0f); + if (weight == null) { + continue; + } + for (LeafReaderContext leaf : indexReader.leaves()) { + var scorer = weight.bulkScorer(leaf); + if (scorer == null) { + continue; + } + final DocCollector collector = new DocCollector(docsBuilder); + scorer.score(collector, leaf.reader().getLiveDocs()); + int matches = collector.matches; + + if (segmentsBuilder != null) { + for (int i = 0; i < matches; i++) { + segmentsBuilder.appendInt(leaf.ord); + } + } + for (int i = 0; i < matches; i++) { + positionsBuilder.appendInt(queryPosition); + } + totalMatches += matches; + } + } while (totalMatches < maxPageSize); + + return buildPage(totalMatches, positionsBuilder, segmentsBuilder, docsBuilder); + } catch (IOException e) { + throw new UncheckedIOException(e); + } finally { + Releasables.close(docsBuilder, segmentsBuilder, positionsBuilder); } + } + + Page buildPage(int positions, IntVector.Builder positionsBuilder, IntVector.Builder segmentsBuilder, IntVector.Builder docsBuilder) { + IntVector positionsVector = null; + IntVector shardsVector = null; + IntVector segmentsVector = null; + IntVector docsVector = null; + Page page = null; try { - return queryOneLeaf(weight, leafIndex++); - } catch (IOException ex) { - throw new UncheckedIOException(ex); + positionsVector = positionsBuilder.build(); + shardsVector = blockFactory.newConstantIntVector(0, positions); + if (segmentsBuilder == null) { + segmentsVector = blockFactory.newConstantIntVector(0, positions); + } else { + segmentsVector = segmentsBuilder.build(); + } + docsVector = docsBuilder.build(); + page = new Page(new DocVector(shardsVector, segmentsVector, docsVector, null).asBlock(), positionsVector.asBlock()); + } finally { + if (page == null) { + Releasables.close(positionsBuilder, segmentsVector, docsBuilder, positionsVector, shardsVector, docsVector); + } } + return page; } - private Page queryOneLeaf(Weight weight, int leafIndex) throws IOException { - if (weight == null) { - return null; - } - LeafReaderContext leafReaderContext = indexReader.leaves().get(leafIndex); - var scorer = weight.bulkScorer(leafReaderContext); - if (scorer == null) { - return null; - } - IntVector docs = null, segments = null, shards = null, positions = null; - boolean success = false; - try (IntVector.Builder docsBuilder = blockFactory.newIntVectorBuilder(1)) { - scorer.score(new DocCollector(docsBuilder), leafReaderContext.reader().getLiveDocs()); - docs = docsBuilder.build(); - final int positionCount = docs.getPositionCount(); - segments = blockFactory.newConstantIntVector(leafIndex, positionCount); - shards = blockFactory.newConstantIntVector(0, positionCount); - positions = blockFactory.newConstantIntVector(queryPosition, positionCount); - Page page = new Page(new DocVector(shards, segments, docs, true).asBlock(), positions.asBlock()); - 
success = true; - return page; - } finally { - if (success == false) { - Releasables.close(docs, shards, segments, positions); + private Query nextQuery() { + ++queryPosition; + while (isFinished() == false) { + Query query = queryList.getQuery(queryPosition); + if (query != null) { + return query; } + ++queryPosition; } + return null; } private static class DocCollector implements LeafCollector { final IntVector.Builder docIds; + int matches = 0; DocCollector(IntVector.Builder docIds) { this.docIds = docIds; @@ -127,6 +163,7 @@ public void setScorer(Scorable scorer) { @Override public void collect(int doc) { + ++matches; docIds.appendInt(doc); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperatorTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperatorTests.java index 7f8e1f7113e22..eef29f0681fbd 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperatorTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/enrich/EnrichQuerySourceOperatorTests.java @@ -48,6 +48,7 @@ import static org.elasticsearch.xpack.ql.type.DataTypes.KEYWORD; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.mockito.Mockito.mock; public class EnrichQuerySourceOperatorTests extends ESTestCase { @@ -120,60 +121,26 @@ public void testQueries() throws Exception { // 3 -> [] -> [] // 4 -> [a1] -> [3] // 5 -> [] -> [] - EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, queryList, reader); - { - Page p0 = queryOperator.getOutput(); - assertNotNull(p0); - assertThat(p0.getPositionCount(), equalTo(2)); - IntVector docs = getDocVector(p0, 0); - assertThat(docs.getInt(0), equalTo(1)); - assertThat(docs.getInt(1), equalTo(4)); - Block positions = p0.getBlock(1); - assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(0)); - assertThat(BlockUtils.toJavaObject(positions, 1), equalTo(0)); - p0.releaseBlocks(); - } - { - Page p1 = queryOperator.getOutput(); - assertNotNull(p1); - assertThat(p1.getPositionCount(), equalTo(3)); - IntVector docs = getDocVector(p1, 0); - assertThat(docs.getInt(0), equalTo(0)); - assertThat(docs.getInt(1), equalTo(1)); - assertThat(docs.getInt(2), equalTo(2)); - Block positions = p1.getBlock(1); - assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(1)); - assertThat(BlockUtils.toJavaObject(positions, 1), equalTo(1)); - assertThat(BlockUtils.toJavaObject(positions, 2), equalTo(1)); - p1.releaseBlocks(); - } - { - Page p2 = queryOperator.getOutput(); - assertNull(p2); - } - { - Page p3 = queryOperator.getOutput(); - assertNull(p3); - } - { - Page p4 = queryOperator.getOutput(); - assertNotNull(p4); - assertThat(p4.getPositionCount(), equalTo(1)); - IntVector docs = getDocVector(p4, 0); - assertThat(docs.getInt(0), equalTo(3)); - Block positions = p4.getBlock(1); - assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(4)); - p4.releaseBlocks(); - } - { - Page p5 = queryOperator.getOutput(); - assertNull(p5); - } - { - assertFalse(queryOperator.isFinished()); - Page p6 = queryOperator.getOutput(); - assertNull(p6); - } + EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, 128, queryList, reader); + Page p0 = queryOperator.getOutput(); + assertNotNull(p0); + assertThat(p0.getPositionCount(), equalTo(6)); + IntVector docs = getDocVector(p0, 0); + assertThat(docs.getInt(0), 
equalTo(1)); + assertThat(docs.getInt(1), equalTo(4)); + assertThat(docs.getInt(2), equalTo(0)); + assertThat(docs.getInt(3), equalTo(1)); + assertThat(docs.getInt(4), equalTo(2)); + assertThat(docs.getInt(5), equalTo(3)); + + Block positions = p0.getBlock(1); + assertThat(BlockUtils.toJavaObject(positions, 0), equalTo(0)); + assertThat(BlockUtils.toJavaObject(positions, 1), equalTo(0)); + assertThat(BlockUtils.toJavaObject(positions, 2), equalTo(1)); + assertThat(BlockUtils.toJavaObject(positions, 3), equalTo(1)); + assertThat(BlockUtils.toJavaObject(positions, 4), equalTo(1)); + assertThat(BlockUtils.toJavaObject(positions, 5), equalTo(4)); + p0.releaseBlocks(); assertTrue(queryOperator.isFinished()); IOUtils.close(reader, dir, inputTerms); } @@ -220,13 +187,15 @@ public void testRandomMatchQueries() throws Exception { } MappedFieldType uidField = new KeywordFieldMapper.KeywordFieldType("uid"); var queryList = QueryList.termQueryList(uidField, mock(SearchExecutionContext.class), inputTerms, KEYWORD); - EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, queryList, reader); + int maxPageSize = between(1, 256); + EnrichQuerySourceOperator queryOperator = new EnrichQuerySourceOperator(blockFactory, maxPageSize, queryList, reader); Map> actualPositions = new HashMap<>(); while (queryOperator.isFinished() == false) { Page page = queryOperator.getOutput(); if (page != null) { IntVector docs = getDocVector(page, 0); IntBlock positions = page.getBlock(1); + assertThat(positions.getPositionCount(), lessThanOrEqualTo(maxPageSize)); for (int i = 0; i < page.getPositionCount(); i++) { int doc = docs.getInt(i); int position = positions.getInt(i); From b85b9dcf9c50621b4de7c17f99e3208aace98514 Mon Sep 17 00:00:00 2001 From: Jedr Blaszyk Date: Fri, 12 Apr 2024 17:03:16 +0200 Subject: [PATCH 09/20] [Connector API] Simplify updating draft filtering rules (#107364) --- .../332_connector_update_filtering.yml | 485 +++++------------- .../connector/ConnectorFiltering.java | 117 +++-- .../connector/ConnectorIndexService.java | 77 ++- ...ansportUpdateConnectorFilteringAction.java | 28 +- .../UpdateConnectorFilteringAction.java | 80 ++- .../filtering/FilteringAdvancedSnippet.java | 34 +- .../connector/filtering/FilteringRule.java | 30 +- .../connector/filtering/FilteringRules.java | 6 +- .../filtering/FilteringValidationInfo.java | 7 + .../connector/ConnectorIndexServiceTests.java | 79 ++- .../connector/ConnectorTestUtils.java | 1 - ...eringActionRequestBWCSerializingTests.java | 4 +- 12 files changed, 503 insertions(+), 445 deletions(-) diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/332_connector_update_filtering.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/332_connector_update_filtering.yml index ac102db163767..5734fdfe67ce8 100644 --- a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/332_connector_update_filtering.yml +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/332_connector_update_filtering.yml @@ -13,91 +13,27 @@ setup: is_native: false service_type: super-connector --- -"Update Connector Filtering with advanced snippet value array": +"Update Connector Filtering - Update draft": - do: connector.update_filtering: connector_id: test-connector body: - filtering: - - active: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: 
- - tables: - - some_table - query: 'SELECT id, st_geohash(coordinates) FROM my_db.some_table;' - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: DEFAULT - draft: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: - - tables: - - some_table - query: 'SELECT id, st_geohash(coordinates) FROM my_db.some_table;' - - tables: - - another_table - query: 'SELECT id, st_geohash(coordinates) FROM my_db.another_table;' - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - - active: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-1 - order: 0 - policy: include - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: TEST - draft: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-1 - order: 0 - policy: exclude - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid + advanced_snippet: + created_at: "2023-05-25T12:30:00.000Z" + updated_at: "2023-05-25T12:30:00.000Z" + value: + - tables: + - some_table + query: 'SELECT id, st_geohash(coordinates) FROM my_db.some_table;' + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: _ + id: DEFAULT + order: 0 + policy: include + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: ".*" - match: { result: updated } @@ -105,19 +41,88 @@ setup: connector.get: connector_id: test-connector + + - match: { filtering.0.draft.advanced_snippet.created_at: "2023-05-25T12:30:00.000Z" } + - match: { filtering.0.draft.advanced_snippet.value.0.tables.0.: "some_table" } + - match: { filtering.0.draft.rules.0.id: DEFAULT } + - match: { filtering.0.draft.validation.errors: [] } + - match: { filtering.0.draft.validation.state: edited } + + # Default domain and active should be unchanged - match: { filtering.0.domain: DEFAULT } - - match: { filtering.0.active.advanced_snippet.created_at: "2023-05-25T12:30:00.000Z" } - - match: { filtering.0.active.advanced_snippet.value.0.tables.0.: "some_table" } - - match: { filtering.0.active.rules.0.id: "RULE-ACTIVE-0" } - - match: { filtering.0.draft.rules.0.id: "RULE-DRAFT-0" } + - match: { filtering.0.active.advanced_snippet.value: {} } + - match: { filtering.0.active.advanced_snippet.value: {} } + - match: { filtering.0.active.rules.0.field: _ } + - match: { filtering.0.active.rules.0.id: DEFAULT } + - match: { filtering.0.active.rules.0.rule: regex } + + +--- +"Update Connector Filtering - Update draft rules only": + - do: + connector.update_filtering: + connector_id: test-connector + body: + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: my_field + id: MY-RULE-1 + order: 0 + policy: exclude + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: "tax-.*" + - created_at: "2023-05-25T12:30:00.000Z" + field: _ + id: DEFAULT + order: 1 + policy: include + rule: regex + updated_at: 
"2023-05-25T12:30:00.000Z" + value: ".*" + + - match: { result: updated } - - match: { filtering.1.domain: TEST } - - match: { filtering.1.active.advanced_snippet.created_at: "2021-05-25T12:30:00.000Z" } - - match: { filtering.1.active.rules.0.id: "RULE-ACTIVE-1" } - - match: { filtering.1.draft.rules.0.id: "RULE-DRAFT-1" } + - do: + connector.get: + connector_id: test-connector + + - match: { filtering.0.draft.rules.0.id: MY-RULE-1 } + - match: { filtering.0.draft.rules.1.id: DEFAULT } + + # Default domain and active should be unchanged + - match: { filtering.0.domain: DEFAULT } + - match: { filtering.0.active.advanced_snippet.value: {} } + - match: { filtering.0.active.advanced_snippet.value: {} } + - match: { filtering.0.active.rules.0.field: _ } + - match: { filtering.0.active.rules.0.id: DEFAULT } + - match: { filtering.0.active.rules.0.rule: regex } --- -"Update Connector Filtering with advanced snippet value object": +"Update Connector Filtering - Update draft advanced snippet only": + - do: + connector.update_filtering: + connector_id: test-connector + body: + advanced_snippet: + created_at: "2023-05-25T12:30:00.000Z" + updated_at: "2023-05-25T12:30:00.000Z" + value: + - tables: + - some_table + query: 'SELECT id, st_geohash(coordinates) FROM my_db.some_table;' + + - match: { result: updated } + + - do: + connector.get: + connector_id: test-connector + + - match: { filtering.0.draft.advanced_snippet.created_at: "2023-05-25T12:30:00.000Z" } + - match: { filtering.0.draft.advanced_snippet.value.0.tables.0.: "some_table" } + +--- +"Update Connector Filtering - Update full filtering object": - do: connector.update_filtering: connector_id: test-connector @@ -132,7 +137,7 @@ setup: rules: - created_at: "2023-05-25T12:30:00.000Z" field: _ - id: RULE-ACTIVE-0 + id: DEFAULT order: 0 policy: include rule: regex @@ -141,7 +146,6 @@ setup: validation: errors: [] state: valid - domain: DEFAULT draft: advanced_snippet: created_at: "2023-05-25T12:30:00.000Z" @@ -150,7 +154,7 @@ setup: rules: - created_at: "2023-05-25T12:30:00.000Z" field: _ - id: RULE-DRAFT-0 + id: DEFAULT order: 0 policy: include rule: regex @@ -159,41 +163,7 @@ setup: validation: errors: [] state: valid - - active: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-1 - order: 0 - policy: include - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: TEST - draft: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-1 - order: 0 - policy: exclude - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid + - match: { result: updated } @@ -204,13 +174,9 @@ setup: - match: { filtering.0.domain: DEFAULT } - match: { filtering.0.active.advanced_snippet.created_at: "2023-05-25T12:30:00.000Z" } - match: { filtering.0.active.advanced_snippet.value.some_filtering_key: "some_filtering_value" } - - match: { filtering.0.active.rules.0.id: "RULE-ACTIVE-0" } - - match: { filtering.0.draft.rules.0.id: "RULE-DRAFT-0" } + - match: { filtering.0.active.rules.0.id: "DEFAULT" } + - match: { filtering.0.draft.rules.0.id: "DEFAULT" } - - match: { filtering.1.domain: TEST } - - match: { filtering.1.active.advanced_snippet.created_at: 
"2021-05-25T12:30:00.000Z" } - - match: { filtering.1.active.rules.0.id: "RULE-ACTIVE-1" } - - match: { filtering.1.draft.rules.0.id: "RULE-DRAFT-1" } --- "Update Connector Filtering with value literal - Wrong advanced snippet value": @@ -219,77 +185,34 @@ setup: connector.update_filtering: connector_id: test-connector body: - filtering: - - active: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: "string literal" - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: DEFAULT - draft: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - - active: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-1 - order: 0 - policy: include - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: TEST - draft: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-1 - order: 0 - policy: exclude - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid + advanced_snippet: + value: "string literal" + +--- +"Update Connector Filtering with value literal - Empty rules": + - do: + catch: "bad_request" + connector.update_filtering: + connector_id: test-connector + body: + rules: [ ] + +--- +"Update Connector Filtering with value literal - Default rule not present": + - do: + catch: "bad_request" + connector.update_filtering: + connector_id: test-connector + body: + rules: + - created_at: "2023-05-25T12:30:00.000Z" + field: my_field + id: MY_RULE + order: 0 + policy: exclude + rule: regex + updated_at: "2023-05-25T12:30:00.000Z" + value: "hello-not-default-rule.*" --- "Update Connector Filtering - Connector doesn't exist": @@ -298,77 +221,8 @@ setup: connector.update_filtering: connector_id: test-non-existent-connector body: - filtering: - - active: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: DEFAULT - draft: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - - active: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-1 - order: 0 - policy: include - rule: regex - updated_at: 
"2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: TEST - draft: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-1 - order: 0 - policy: exclude - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid + advanced_snippet: + value: {} --- "Update Connector Filtering - Required fields are missing": @@ -376,9 +230,7 @@ setup: catch: "bad_request" connector.update_filtering: connector_id: test-connector - body: - filtering: - - domain: some_domain + body: {} - match: status: 400 @@ -390,74 +242,7 @@ setup: connector.update_filtering: connector_id: test-connector body: - filtering: - - active: - advanced_snippet: - created_at: "this-is-not-a-datetime-!!!!" - updated_at: "2023-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: DEFAULT - draft: - advanced_snippet: - created_at: "2023-05-25T12:30:00.000Z" - updated_at: "2023-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2023-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-0 - order: 0 - policy: include - rule: regex - updated_at: "2023-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - - active: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-ACTIVE-1 - order: 0 - policy: include - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid - domain: TEST - draft: - advanced_snippet: - created_at: "2021-05-25T12:30:00.000Z" - updated_at: "2021-05-25T12:30:00.000Z" - value: {} - rules: - - created_at: "2021-05-25T12:30:00.000Z" - field: _ - id: RULE-DRAFT-1 - order: 0 - policy: exclude - rule: regex - updated_at: "2021-05-25T12:30:00.000Z" - value: ".*" - validation: - errors: [] - state: valid + rules: [ ] + advanced_snippet: + updated_at: "wrong datetime" + value: { } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorFiltering.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorFiltering.java index 62a8a68cea5ca..4d357f459cb2f 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorFiltering.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorFiltering.java @@ -43,26 +43,23 @@ */ public class ConnectorFiltering implements Writeable, ToXContentObject { - private final FilteringRules active; - private final String domain; - private final FilteringRules draft; + private FilteringRules active; + private final String domain = "DEFAULT"; // Connectors always use DEFAULT domain, users should not modify it via API + private FilteringRules draft; /** * Constructs a new ConnectorFiltering instance. * * @param active The active filtering rules. - * @param domain The domain associated with the filtering. * @param draft The draft filtering rules. 
*/ - public ConnectorFiltering(FilteringRules active, String domain, FilteringRules draft) { + public ConnectorFiltering(FilteringRules active, FilteringRules draft) { this.active = active; - this.domain = domain; this.draft = draft; } public ConnectorFiltering(StreamInput in) throws IOException { this.active = new FilteringRules(in); - this.domain = in.readString(); this.draft = new FilteringRules(in); } @@ -78,22 +75,27 @@ public FilteringRules getDraft() { return draft; } + public ConnectorFiltering setActive(FilteringRules active) { + this.active = active; + return this; + } + + public ConnectorFiltering setDraft(FilteringRules draft) { + this.draft = draft; + return this; + } + private static final ParseField ACTIVE_FIELD = new ParseField("active"); - private static final ParseField DOMAIN_FIELD = new ParseField("domain"); private static final ParseField DRAFT_FIELD = new ParseField("draft"); private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "connector_filtering", true, - args -> new ConnectorFiltering.Builder().setActive((FilteringRules) args[0]) - .setDomain((String) args[1]) - .setDraft((FilteringRules) args[2]) - .build() + args -> new ConnectorFiltering.Builder().setActive((FilteringRules) args[0]).setDraft((FilteringRules) args[1]).build() ); static { PARSER.declareObject(constructorArg(), (p, c) -> FilteringRules.fromXContent(p), ACTIVE_FIELD); - PARSER.declareString(constructorArg(), DOMAIN_FIELD); PARSER.declareObject(constructorArg(), (p, c) -> FilteringRules.fromXContent(p), DRAFT_FIELD); } @@ -102,7 +104,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); { builder.field(ACTIVE_FIELD.getPreferredName(), active); - builder.field(DOMAIN_FIELD.getPreferredName(), domain); + builder.field("domain", domain); // We still want to write the DEFAULT domain to the index builder.field(DRAFT_FIELD.getPreferredName(), draft); } builder.endObject(); @@ -124,7 +126,6 @@ public static ConnectorFiltering fromXContentBytes(BytesReference source, XConte @Override public void writeTo(StreamOutput out) throws IOException { active.writeTo(out); - out.writeString(domain); draft.writeTo(out); } @@ -141,10 +142,41 @@ public int hashCode() { return Objects.hash(active, domain, draft); } + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser, Void> CONNECTOR_FILTERING_PARSER = + new ConstructingObjectParser<>( + "connector_filtering_parser", + true, + args -> (List) args[0] + + ); + + static { + CONNECTOR_FILTERING_PARSER.declareObjectArray( + constructorArg(), + (p, c) -> ConnectorFiltering.fromXContent(p), + Connector.FILTERING_FIELD + ); + } + + /** + * Deserializes the {@link ConnectorFiltering} property from a {@link Connector} byte representation. + * + * @param source Byte representation of the {@link Connector}. + * @param xContentType {@link XContentType} of the content (e.g., JSON). + * @return List of {@link ConnectorFiltering} objects. 
+ */ + public static List fromXContentBytesConnectorFiltering(BytesReference source, XContentType xContentType) { + try (XContentParser parser = XContentHelper.createParser(XContentParserConfiguration.EMPTY, source, xContentType)) { + return CONNECTOR_FILTERING_PARSER.parse(parser, null); + } catch (IOException e) { + throw new ElasticsearchParseException("Failed to parse a connector filtering.", e); + } + } + public static class Builder { private FilteringRules active; - private String domain; private FilteringRules draft; public Builder setActive(FilteringRules active) { @@ -152,21 +184,33 @@ public Builder setActive(FilteringRules active) { return this; } - public Builder setDomain(String domain) { - this.domain = domain; - return this; - } - public Builder setDraft(FilteringRules draft) { this.draft = draft; return this; } public ConnectorFiltering build() { - return new ConnectorFiltering(active, domain, draft); + return new ConnectorFiltering(active, draft); } } + public static boolean isDefaultRulePresentInFilteringRules(List rules) { + FilteringRule defaultRule = getDefaultFilteringRule(null); + return rules.stream().anyMatch(rule -> rule.equalsExceptForTimestampsAndOrder(defaultRule)); + } + + public static FilteringRule getDefaultFilteringRule(Instant timestamp) { + return new FilteringRule.Builder().setCreatedAt(timestamp) + .setField("_") + .setId("DEFAULT") + .setOrder(0) + .setPolicy(FilteringPolicy.INCLUDE) + .setRule(FilteringRuleCondition.REGEX) + .setUpdatedAt(timestamp) + .setValue(".*") + .build(); + } + public static ConnectorFiltering getDefaultConnectorFilteringConfig() { Instant currentTimestamp = Instant.now(); @@ -178,19 +222,7 @@ public static ConnectorFiltering getDefaultConnectorFilteringConfig() { .setAdvancedSnippetValue(Collections.emptyMap()) .build() ) - .setRules( - List.of( - new FilteringRule.Builder().setCreatedAt(currentTimestamp) - .setField("_") - .setId("DEFAULT") - .setOrder(0) - .setPolicy(FilteringPolicy.INCLUDE) - .setRule(FilteringRuleCondition.REGEX) - .setUpdatedAt(currentTimestamp) - .setValue(".*") - .build() - ) - ) + .setRules(List.of(getDefaultFilteringRule(currentTimestamp))) .setFilteringValidationInfo( new FilteringValidationInfo.Builder().setValidationErrors(Collections.emptyList()) .setValidationState(FilteringValidationState.VALID) @@ -198,7 +230,6 @@ public static ConnectorFiltering getDefaultConnectorFilteringConfig() { ) .build() ) - .setDomain("DEFAULT") .setDraft( new FilteringRules.Builder().setAdvancedSnippet( new FilteringAdvancedSnippet.Builder().setAdvancedSnippetCreatedAt(currentTimestamp) @@ -206,19 +237,7 @@ public static ConnectorFiltering getDefaultConnectorFilteringConfig() { .setAdvancedSnippetValue(Collections.emptyMap()) .build() ) - .setRules( - List.of( - new FilteringRule.Builder().setCreatedAt(currentTimestamp) - .setField("_") - .setId("DEFAULT") - .setOrder(0) - .setPolicy(FilteringPolicy.INCLUDE) - .setRule(FilteringRuleCondition.REGEX) - .setUpdatedAt(currentTimestamp) - .setValue(".*") - .build() - ) - ) + .setRules(List.of(getDefaultFilteringRule(currentTimestamp))) .setFilteringValidationInfo( new FilteringValidationInfo.Builder().setValidationErrors(Collections.emptyList()) .setValidationState(FilteringValidationState.VALID) diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java index bceeece6ec17b..20b9a8ec74027 100644 
--- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/ConnectorIndexService.java @@ -40,12 +40,12 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.action.PostConnectorAction; import org.elasticsearch.xpack.application.connector.action.PutConnectorAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorApiKeyIdAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorConfigurationAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorErrorAction; -import org.elasticsearch.xpack.application.connector.action.UpdateConnectorFilteringAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorIndexNameAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorLastSyncStatsAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorNameAction; @@ -54,6 +54,10 @@ import org.elasticsearch.xpack.application.connector.action.UpdateConnectorSchedulingAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorServiceTypeAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorStatusAction; +import org.elasticsearch.xpack.application.connector.filtering.FilteringAdvancedSnippet; +import org.elasticsearch.xpack.application.connector.filtering.FilteringRule; +import org.elasticsearch.xpack.application.connector.filtering.FilteringRules; +import org.elasticsearch.xpack.application.connector.filtering.FilteringValidationInfo; import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJob; import org.elasticsearch.xpack.application.connector.syncjob.ConnectorSyncJobIndexService; @@ -70,6 +74,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; +import static org.elasticsearch.xpack.application.connector.ConnectorFiltering.fromXContentBytesConnectorFiltering; /** * A service that manages persistent {@link Connector} configurations. @@ -555,19 +560,19 @@ public void updateConnectorNameOrDescription(UpdateConnectorNameAction.Request r } /** - * Updates the {@link ConnectorFiltering} property of a {@link Connector}. + * Sets the {@link ConnectorFiltering} property of a {@link Connector}. * - * @param request Request for updating connector filtering property. - * @param listener Listener to respond to a successful response or an error. + * @param connectorId The ID of the {@link Connector} to update. + * @param filtering The list of {@link ConnectorFiltering} . + * @param listener Listener to respond to a successful response or an error. 
*/ - public void updateConnectorFiltering(UpdateConnectorFilteringAction.Request request, ActionListener listener) { + public void updateConnectorFiltering(String connectorId, List filtering, ActionListener listener) { try { - String connectorId = request.getConnectorId(); final UpdateRequest updateRequest = new UpdateRequest(CONNECTOR_INDEX_NAME, connectorId).doc( new IndexRequest(CONNECTOR_INDEX_NAME).opType(DocWriteRequest.OpType.INDEX) .id(connectorId) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) - .source(Map.of(Connector.FILTERING_FIELD.getPreferredName(), request.getFiltering())) + .source(Map.of(Connector.FILTERING_FIELD.getPreferredName(), filtering)) ); client.update(updateRequest, new DelegatingIndexNotFoundActionListener<>(connectorId, listener, (l, updateResponse) -> { if (updateResponse.getResult() == UpdateResponse.Result.NOT_FOUND) { @@ -581,6 +586,64 @@ public void updateConnectorFiltering(UpdateConnectorFilteringAction.Request requ } } + /** + * Updates the draft filtering in a given {@link Connector}. + * + * @param connectorId The ID of the {@link Connector} to be updated. + * @param advancedSnippet An instance of {@link FilteringAdvancedSnippet}. + * @param rules A list of instances of {@link FilteringRule} to be applied. + * @param listener Listener to respond to a successful response or an error. + */ + public void updateConnectorFilteringDraft( + String connectorId, + FilteringAdvancedSnippet advancedSnippet, + List rules, + ActionListener listener + ) { + try { + getConnector(connectorId, listener.delegateFailure((l, connector) -> { + List connectorFilteringList = fromXContentBytesConnectorFiltering( + connector.getSourceRef(), + XContentType.JSON + ); + // Connectors represent their filtering configuration as a singleton list + ConnectorFiltering connectorFilteringSingleton = connectorFilteringList.get(0); + + // If advanced snippet or rules are not defined, keep the current draft state + FilteringAdvancedSnippet newDraftAdvancedSnippet = advancedSnippet == null + ? connectorFilteringSingleton.getDraft().getAdvancedSnippet() + : advancedSnippet; + + List newDraftRules = rules == null ? connectorFilteringSingleton.getDraft().getRules() : rules; + + ConnectorFiltering connectorFilteringWithUpdatedDraft = connectorFilteringSingleton.setDraft( + new FilteringRules.Builder().setRules(newDraftRules) + .setAdvancedSnippet(newDraftAdvancedSnippet) + .setFilteringValidationInfo(FilteringValidationInfo.getInitialDraftValidationInfo()) + .build() + ); + + final UpdateRequest updateRequest = new UpdateRequest(CONNECTOR_INDEX_NAME, connectorId).doc( + new IndexRequest(CONNECTOR_INDEX_NAME).opType(DocWriteRequest.OpType.INDEX) + .id(connectorId) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .source(Map.of(Connector.FILTERING_FIELD.getPreferredName(), List.of(connectorFilteringWithUpdatedDraft))) + ); + + client.update(updateRequest, new DelegatingIndexNotFoundActionListener<>(connectorId, listener, (ll, updateResponse) -> { + if (updateResponse.getResult() == UpdateResponse.Result.NOT_FOUND) { + ll.onFailure(new ResourceNotFoundException(connectorNotFoundErrorMsg(connectorId))); + return; + } + ll.onResponse(updateResponse); + })); + })); + + } catch (Exception e) { + listener.onFailure(e); + } + } + /** * Updates the lastSeen property of a {@link Connector}. 
* diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorFilteringAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorFilteringAction.java index 658a8075121af..ac3b3212c02da 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorFilteringAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/TransportUpdateConnectorFilteringAction.java @@ -16,7 +16,12 @@ import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.application.connector.ConnectorFiltering; import org.elasticsearch.xpack.application.connector.ConnectorIndexService; +import org.elasticsearch.xpack.application.connector.filtering.FilteringAdvancedSnippet; +import org.elasticsearch.xpack.application.connector.filtering.FilteringRule; + +import java.util.List; public class TransportUpdateConnectorFilteringAction extends HandledTransportAction< UpdateConnectorFilteringAction.Request, @@ -47,6 +52,27 @@ protected void doExecute( UpdateConnectorFilteringAction.Request request, ActionListener listener ) { - connectorIndexService.updateConnectorFiltering(request, listener.map(r -> new ConnectorUpdateActionResponse(r.getResult()))); + String connectorId = request.getConnectorId(); + List filtering = request.getFiltering(); + FilteringAdvancedSnippet advancedSnippet = request.getAdvancedSnippet(); + List rules = request.getRules(); + // If [filtering] is not present in request body, it means that user's intention is to + // update draft's rules or advanced snippet + if (request.getFiltering() == null) { + connectorIndexService.updateConnectorFilteringDraft( + connectorId, + advancedSnippet, + rules, + listener.map(r -> new ConnectorUpdateActionResponse(r.getResult())) + ); + } + // Otherwise override the whole filtering object (discouraged in docs) + else { + connectorIndexService.updateConnectorFiltering( + connectorId, + filtering, + listener.map(r -> new ConnectorUpdateActionResponse(r.getResult())) + ); + } } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java index 9d55c12e4b7a1..566a01b855b99 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringAction.java @@ -15,6 +15,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -23,13 +24,17 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.application.connector.Connector; import org.elasticsearch.xpack.application.connector.ConnectorFiltering; +import 
org.elasticsearch.xpack.application.connector.filtering.FilteringAdvancedSnippet; +import org.elasticsearch.xpack.application.connector.filtering.FilteringRule; +import org.elasticsearch.xpack.application.connector.filtering.FilteringRules; import java.io.IOException; import java.util.List; import java.util.Objects; import static org.elasticsearch.action.ValidateActions.addValidationError; -import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; +import static org.elasticsearch.xpack.application.connector.ConnectorFiltering.isDefaultRulePresentInFilteringRules; public class UpdateConnectorFilteringAction { @@ -41,17 +46,31 @@ private UpdateConnectorFilteringAction() {/* no instances */} public static class Request extends ConnectorActionRequest implements ToXContentObject { private final String connectorId; + @Nullable private final List filtering; + @Nullable + private final FilteringAdvancedSnippet advancedSnippet; + @Nullable + private final List rules; - public Request(String connectorId, List filtering) { + public Request( + String connectorId, + List filtering, + FilteringAdvancedSnippet advancedSnippet, + List rules + ) { this.connectorId = connectorId; this.filtering = filtering; + this.advancedSnippet = advancedSnippet; + this.rules = rules; } public Request(StreamInput in) throws IOException { super(in); this.connectorId = in.readString(); this.filtering = in.readOptionalCollectionAsList(ConnectorFiltering::new); + this.advancedSnippet = new FilteringAdvancedSnippet(in); + this.rules = in.readCollectionAsList(FilteringRule::new); } public String getConnectorId() { @@ -62,6 +81,14 @@ public List getFiltering() { return filtering; } + public FilteringAdvancedSnippet getAdvancedSnippet() { + return advancedSnippet; + } + + public List getRules() { + return rules; + } + @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; @@ -70,8 +97,29 @@ public ActionRequestValidationException validate() { validationException = addValidationError("[connector_id] cannot be [null] or [\"\"].", validationException); } + // If [filtering] is not present in the request payload it means that the user should define [rules] and/or [advanced_snippet] if (filtering == null) { - validationException = addValidationError("[filtering] cannot be [null].", validationException); + if (rules == null && advancedSnippet == null) { + validationException = addValidationError("[advanced_snippet] and [rules] cannot be both [null].", validationException); + } else if (rules != null) { + if (rules.isEmpty()) { + validationException = addValidationError("[rules] cannot be an empty list.", validationException); + } else if (isDefaultRulePresentInFilteringRules(rules) == false) { + validationException = addValidationError( + "[rules] need to include the default filtering rule.", + validationException + ); + } + } + } + // If [filtering] is present we don't expect [rules] and [advances_snippet] in the request body + else { + if (rules != null || advancedSnippet != null) { + validationException = addValidationError( + "If [filtering] is specified, [rules] and [advanced_snippet] should not be present in the request body.", + validationException + ); + } } return validationException; @@ -82,11 +130,22 @@ public ActionRequestValidationException validate() { new ConstructingObjectParser<>( "connector_update_filtering_request", false, - ((args, 
connectorId) -> new UpdateConnectorFilteringAction.Request(connectorId, (List) args[0])) + ((args, connectorId) -> new UpdateConnectorFilteringAction.Request( + connectorId, + (List) args[0], + (FilteringAdvancedSnippet) args[1], + (List) args[2] + )) ); static { - PARSER.declareObjectArray(constructorArg(), (p, c) -> ConnectorFiltering.fromXContent(p), Connector.FILTERING_FIELD); + PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> ConnectorFiltering.fromXContent(p), Connector.FILTERING_FIELD); + PARSER.declareObject( + optionalConstructorArg(), + (p, c) -> FilteringAdvancedSnippet.fromXContent(p), + FilteringRules.ADVANCED_SNIPPET_FIELD + ); + PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> FilteringRule.fromXContent(p), FilteringRules.RULES_FIELD); } public static UpdateConnectorFilteringAction.Request fromXContentBytes( @@ -110,6 +169,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); { builder.field(Connector.FILTERING_FIELD.getPreferredName(), filtering); + builder.field(FilteringRules.ADVANCED_SNIPPET_FIELD.getPreferredName(), advancedSnippet); + builder.xContentList(FilteringRules.RULES_FIELD.getPreferredName(), rules); } builder.endObject(); return builder; @@ -120,6 +181,8 @@ public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(connectorId); out.writeOptionalCollection(filtering); + advancedSnippet.writeTo(out); + out.writeCollection(rules); } @Override @@ -127,12 +190,15 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Request request = (Request) o; - return Objects.equals(connectorId, request.connectorId) && Objects.equals(filtering, request.filtering); + return Objects.equals(connectorId, request.connectorId) + && Objects.equals(filtering, request.filtering) + && Objects.equals(advancedSnippet, request.advancedSnippet) + && Objects.equals(rules, request.rules); } @Override public int hashCode() { - return Objects.hash(connectorId, filtering); + return Objects.hash(connectorId, filtering, advancedSnippet, rules); } } } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringAdvancedSnippet.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringAdvancedSnippet.java index 384fbc7bb5340..62da1dab08358 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringAdvancedSnippet.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringAdvancedSnippet.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -24,6 +25,7 @@ import java.util.Objects; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; /** * Represents an advanced snippet used in filtering processes, providing detailed criteria or rules. @@ -31,8 +33,9 @@ * actual snippet content represented as a map. 
*/ public class FilteringAdvancedSnippet implements Writeable, ToXContentObject { - + @Nullable private final Instant advancedSnippetCreatedAt; + @Nullable private final Instant advancedSnippetUpdatedAt; private final Object advancedSnippetValue; @@ -48,8 +51,8 @@ private FilteringAdvancedSnippet(Instant advancedSnippetCreatedAt, Instant advan } public FilteringAdvancedSnippet(StreamInput in) throws IOException { - this.advancedSnippetCreatedAt = in.readInstant(); - this.advancedSnippetUpdatedAt = in.readInstant(); + this.advancedSnippetCreatedAt = in.readOptionalInstant(); + this.advancedSnippetUpdatedAt = in.readOptionalInstant(); this.advancedSnippetValue = in.readGenericValue(); } @@ -57,6 +60,18 @@ public FilteringAdvancedSnippet(StreamInput in) throws IOException { private static final ParseField UPDATED_AT_FIELD = new ParseField("updated_at"); private static final ParseField VALUE_FIELD = new ParseField("value"); + public Instant getAdvancedSnippetCreatedAt() { + return advancedSnippetCreatedAt; + } + + public Instant getAdvancedSnippetUpdatedAt() { + return advancedSnippetUpdatedAt; + } + + public Object getAdvancedSnippetValue() { + return advancedSnippetValue; + } + @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( "connector_filtering_advanced_snippet", @@ -69,13 +84,13 @@ public FilteringAdvancedSnippet(StreamInput in) throws IOException { static { PARSER.declareField( - constructorArg(), + optionalConstructorArg(), (p, c) -> ConnectorUtils.parseInstant(p, CREATED_AT_FIELD.getPreferredName()), CREATED_AT_FIELD, ObjectParser.ValueType.STRING ); PARSER.declareField( - constructorArg(), + optionalConstructorArg(), (p, c) -> ConnectorUtils.parseInstant(p, UPDATED_AT_FIELD.getPreferredName()), UPDATED_AT_FIELD, ObjectParser.ValueType.STRING @@ -108,8 +123,8 @@ public static FilteringAdvancedSnippet fromXContent(XContentParser parser) throw @Override public void writeTo(StreamOutput out) throws IOException { - out.writeInstant(advancedSnippetCreatedAt); - out.writeInstant(advancedSnippetUpdatedAt); + out.writeOptionalInstant(advancedSnippetCreatedAt); + out.writeOptionalInstant(advancedSnippetUpdatedAt); out.writeGenericValue(advancedSnippetValue); } @@ -133,14 +148,15 @@ public static class Builder { private Instant advancedSnippetCreatedAt; private Instant advancedSnippetUpdatedAt; private Object advancedSnippetValue; + private final Instant currentTimestamp = Instant.now(); public Builder setAdvancedSnippetCreatedAt(Instant advancedSnippetCreatedAt) { - this.advancedSnippetCreatedAt = advancedSnippetCreatedAt; + this.advancedSnippetCreatedAt = Objects.requireNonNullElse(advancedSnippetCreatedAt, currentTimestamp); return this; } public Builder setAdvancedSnippetUpdatedAt(Instant advancedSnippetUpdatedAt) { - this.advancedSnippetUpdatedAt = advancedSnippetUpdatedAt; + this.advancedSnippetUpdatedAt = Objects.requireNonNullElse(advancedSnippetUpdatedAt, currentTimestamp); return this; } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRule.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRule.java index 02571078f4e21..3829eb7442522 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRule.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRule.java @@ -23,6 +23,7 
@@ import java.util.Objects; import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; /** * Represents a single rule used for filtering in a data processing or querying context. @@ -75,13 +76,13 @@ public FilteringRule( } public FilteringRule(StreamInput in) throws IOException { - this.createdAt = in.readInstant(); + this.createdAt = in.readOptionalInstant(); this.field = in.readString(); this.id = in.readString(); this.order = in.readInt(); this.policy = in.readEnum(FilteringPolicy.class); this.rule = in.readEnum(FilteringRuleCondition.class); - this.updatedAt = in.readInstant(); + this.updatedAt = in.readOptionalInstant(); this.value = in.readString(); } @@ -110,7 +111,7 @@ public FilteringRule(StreamInput in) throws IOException { static { PARSER.declareField( - constructorArg(), + optionalConstructorArg(), (p, c) -> ConnectorUtils.parseInstant(p, CREATED_AT_FIELD.getPreferredName()), CREATED_AT_FIELD, ObjectParser.ValueType.STRING @@ -131,7 +132,7 @@ public FilteringRule(StreamInput in) throws IOException { ObjectParser.ValueType.STRING ); PARSER.declareField( - constructorArg(), + optionalConstructorArg(), (p, c) -> ConnectorUtils.parseInstant(p, UPDATED_AT_FIELD.getPreferredName()), UPDATED_AT_FIELD, ObjectParser.ValueType.STRING @@ -160,13 +161,13 @@ public static FilteringRule fromXContent(XContentParser parser) throws IOExcepti @Override public void writeTo(StreamOutput out) throws IOException { - out.writeInstant(createdAt); + out.writeOptionalInstant(createdAt); out.writeString(field); out.writeString(id); out.writeInt(order); out.writeEnum(policy); out.writeEnum(rule); - out.writeInstant(updatedAt); + out.writeOptionalInstant(updatedAt); out.writeString(value); } @@ -185,6 +186,18 @@ public boolean equals(Object o) { && Objects.equals(value, that.value); } + /** + * Compares this {@code FilteringRule} to another rule for equality, ignoring differences + * in created_at, updated_at timestamps and order. 
+ */ + public boolean equalsExceptForTimestampsAndOrder(FilteringRule that) { + return Objects.equals(field, that.field) + && Objects.equals(id, that.id) + && policy == that.policy + && rule == that.rule + && Objects.equals(value, that.value); + } + @Override public int hashCode() { return Objects.hash(createdAt, field, id, order, policy, rule, updatedAt, value); @@ -200,9 +213,10 @@ public static class Builder { private FilteringRuleCondition rule; private Instant updatedAt; private String value; + private final Instant currentTimestamp = Instant.now(); public Builder setCreatedAt(Instant createdAt) { - this.createdAt = createdAt; + this.createdAt = Objects.requireNonNullElse(createdAt, currentTimestamp); return this; } @@ -232,7 +246,7 @@ public Builder setRule(FilteringRuleCondition rule) { } public Builder setUpdatedAt(Instant updatedAt) { - this.updatedAt = updatedAt; + this.updatedAt = Objects.requireNonNullElse(updatedAt, currentTimestamp); return this; } diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRules.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRules.java index fb4e25131449d..35d18d23450b1 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRules.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringRules.java @@ -69,9 +69,9 @@ public FilteringValidationInfo getFilteringValidationInfo() { return filteringValidationInfo; } - private static final ParseField ADVANCED_SNIPPET_FIELD = new ParseField("advanced_snippet"); - private static final ParseField RULES_FIELD = new ParseField("rules"); - private static final ParseField VALIDATION_FIELD = new ParseField("validation"); + public static final ParseField ADVANCED_SNIPPET_FIELD = new ParseField("advanced_snippet"); + public static final ParseField RULES_FIELD = new ParseField("rules"); + public static final ParseField VALIDATION_FIELD = new ParseField("validation"); @SuppressWarnings("unchecked") private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringValidationInfo.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringValidationInfo.java index c0cd80d867592..cd197bf0538e4 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringValidationInfo.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/filtering/FilteringValidationInfo.java @@ -18,6 +18,7 @@ import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Objects; @@ -105,6 +106,12 @@ public int hashCode() { return Objects.hash(validationErrors, validationState); } + public static FilteringValidationInfo getInitialDraftValidationInfo() { + return new FilteringValidationInfo.Builder().setValidationErrors(Collections.emptyList()) + .setValidationState(FilteringValidationState.EDITED) + .build(); + } + public static class Builder { private List validationErrors; diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java 
b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java index f483887c4d81b..ea510086fcf8c 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorIndexServiceTests.java @@ -28,7 +28,6 @@ import org.elasticsearch.xpack.application.connector.action.UpdateConnectorApiKeyIdAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorConfigurationAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorErrorAction; -import org.elasticsearch.xpack.application.connector.action.UpdateConnectorFilteringAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorIndexNameAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorLastSeenAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorLastSyncStatsAction; @@ -38,6 +37,9 @@ import org.elasticsearch.xpack.application.connector.action.UpdateConnectorSchedulingAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorServiceTypeAction; import org.elasticsearch.xpack.application.connector.action.UpdateConnectorStatusAction; +import org.elasticsearch.xpack.application.connector.filtering.FilteringAdvancedSnippet; +import org.elasticsearch.xpack.application.connector.filtering.FilteringRule; +import org.elasticsearch.xpack.application.connector.filtering.FilteringValidationInfo; import org.junit.Before; import java.util.ArrayList; @@ -248,17 +250,46 @@ public void testUpdateConnectorFiltering() throws Exception { .mapToObj((i) -> ConnectorTestUtils.getRandomConnectorFiltering()) .collect(Collectors.toList()); - UpdateConnectorFilteringAction.Request updateFilteringRequest = new UpdateConnectorFilteringAction.Request( - connectorId, - filteringList - ); - - DocWriteResponse updateResponse = awaitUpdateConnectorFiltering(updateFilteringRequest); + DocWriteResponse updateResponse = awaitUpdateConnectorFiltering(connectorId, filteringList); assertThat(updateResponse.status(), equalTo(RestStatus.OK)); Connector indexedConnector = awaitGetConnector(connectorId); assertThat(filteringList, equalTo(indexedConnector.getFiltering())); } + public void testUpdateConnectorFiltering_updateDraft() throws Exception { + Connector connector = ConnectorTestUtils.getRandomConnector(); + String connectorId = randomUUID(); + + DocWriteResponse resp = buildRequestAndAwaitPutConnector(connectorId, connector); + assertThat(resp.status(), anyOf(equalTo(RestStatus.CREATED), equalTo(RestStatus.OK))); + + FilteringAdvancedSnippet advancedSnippet = ConnectorTestUtils.getRandomConnectorFiltering().getDraft().getAdvancedSnippet(); + List rules = ConnectorTestUtils.getRandomConnectorFiltering().getDraft().getRules(); + + DocWriteResponse updateResponse = awaitUpdateConnectorFilteringDraft(connectorId, advancedSnippet, rules); + assertThat(updateResponse.status(), equalTo(RestStatus.OK)); + Connector indexedConnector = awaitGetConnector(connectorId); + + // Assert that draft got updated + assertThat(advancedSnippet, equalTo(indexedConnector.getFiltering().get(0).getDraft().getAdvancedSnippet())); + assertThat(rules, equalTo(indexedConnector.getFiltering().get(0).getDraft().getRules())); + // Assert that draft is marked as EDITED + assertThat( + FilteringValidationInfo.getInitialDraftValidationInfo(), + 
equalTo(indexedConnector.getFiltering().get(0).getDraft().getFilteringValidationInfo()) + ); + // Assert that default active rules are unchanged, avoid comparing timestamps + assertThat( + ConnectorFiltering.getDefaultConnectorFilteringConfig().getActive().getAdvancedSnippet().getAdvancedSnippetValue(), + equalTo(indexedConnector.getFiltering().get(0).getActive().getAdvancedSnippet().getAdvancedSnippetValue()) + ); + // Assert that domain is unchanged + assertThat( + ConnectorFiltering.getDefaultConnectorFilteringConfig().getDomain(), + equalTo(indexedConnector.getFiltering().get(0).getDomain()) + ); + } + public void testUpdateConnectorLastSeen() throws Exception { Connector connector = ConnectorTestUtils.getRandomConnector(); String connectorId = randomUUID(); @@ -717,11 +748,11 @@ public void onFailure(Exception e) { return resp.get(); } - private UpdateResponse awaitUpdateConnectorFiltering(UpdateConnectorFilteringAction.Request updateFiltering) throws Exception { + private UpdateResponse awaitUpdateConnectorFiltering(String connectorId, List filtering) throws Exception { CountDownLatch latch = new CountDownLatch(1); final AtomicReference resp = new AtomicReference<>(null); final AtomicReference exc = new AtomicReference<>(null); - connectorIndexService.updateConnectorFiltering(updateFiltering, new ActionListener<>() { + connectorIndexService.updateConnectorFiltering(connectorId, filtering, new ActionListener<>() { @Override public void onResponse(UpdateResponse indexResponse) { @@ -744,6 +775,36 @@ public void onFailure(Exception e) { return resp.get(); } + private UpdateResponse awaitUpdateConnectorFilteringDraft( + String connectorId, + FilteringAdvancedSnippet advancedSnippet, + List rules + ) throws Exception { + CountDownLatch latch = new CountDownLatch(1); + final AtomicReference resp = new AtomicReference<>(null); + final AtomicReference exc = new AtomicReference<>(null); + connectorIndexService.updateConnectorFilteringDraft(connectorId, advancedSnippet, rules, new ActionListener<>() { + @Override + public void onResponse(UpdateResponse indexResponse) { + resp.set(indexResponse); + latch.countDown(); + } + + @Override + public void onFailure(Exception e) { + exc.set(e); + latch.countDown(); + } + }); + + assertTrue("Timeout waiting for update filtering request", latch.await(REQUEST_TIMEOUT_SECONDS, TimeUnit.SECONDS)); + if (exc.get() != null) { + throw exc.get(); + } + assertNotNull("Received null response from update filtering request", resp.get()); + return resp.get(); + } + private UpdateResponse awaitUpdateConnectorIndexName(UpdateConnectorIndexNameAction.Request updateIndexNameRequest) throws Exception { CountDownLatch latch = new CountDownLatch(1); final AtomicReference resp = new AtomicReference<>(null); diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorTestUtils.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorTestUtils.java index 35a910b5641a9..876a1092a1d5b 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorTestUtils.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/ConnectorTestUtils.java @@ -210,7 +210,6 @@ public static ConnectorFiltering getRandomConnectorFiltering() { ) .build() ) - .setDomain(randomAlphaOfLength(10)) .setDraft( new FilteringRules.Builder().setAdvancedSnippet( new 
FilteringAdvancedSnippet.Builder().setAdvancedSnippetCreatedAt(currentTimestamp) diff --git a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionRequestBWCSerializingTests.java b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionRequestBWCSerializingTests.java index 1d433d58be6ad..6874f4b2a1b36 100644 --- a/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionRequestBWCSerializingTests.java +++ b/x-pack/plugin/ent-search/src/test/java/org/elasticsearch/xpack/application/connector/action/UpdateConnectorFilteringActionRequestBWCSerializingTests.java @@ -31,7 +31,9 @@ protected UpdateConnectorFilteringAction.Request createTestInstance() { this.connectorId = randomUUID(); return new UpdateConnectorFilteringAction.Request( connectorId, - List.of(ConnectorTestUtils.getRandomConnectorFiltering(), ConnectorTestUtils.getRandomConnectorFiltering()) + List.of(ConnectorTestUtils.getRandomConnectorFiltering(), ConnectorTestUtils.getRandomConnectorFiltering()), + ConnectorTestUtils.getRandomConnectorFiltering().getActive().getAdvancedSnippet(), + ConnectorTestUtils.getRandomConnectorFiltering().getActive().getRules() ); } From 0660f7fcda96c70256255d57898ecbe979815218 Mon Sep 17 00:00:00 2001 From: Jake Landis Date: Fri, 12 Apr 2024 10:05:07 -0500 Subject: [PATCH 10/20] ES|QL with RCS 2.0 security fix (#107079) This commit provides the ES security changes to support internal ES|QL actions when running ES|QL queries across clusters that use RCS 2.0 (API keys) as the security model. The tests have been updated to illustrate the primary workflow working. --- .../privilege/ClusterPrivilegeResolver.java | 4 +- .../authz/privilege/IndexPrivilege.java | 6 +- .../operator/exchange/ExchangeService.java | 17 + .../RemoteClusterSecurityEsqlIT.java | 505 +++++++++++++++--- .../SecurityServerTransportInterceptor.java | 6 +- 5 files changed, 456 insertions(+), 82 deletions(-) diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ClusterPrivilegeResolver.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ClusterPrivilegeResolver.java index 47e4a6913897b..3774efcdd2ad2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ClusterPrivilegeResolver.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/ClusterPrivilegeResolver.java @@ -180,7 +180,9 @@ public class ClusterPrivilegeResolver { RemoteClusterNodesAction.TYPE.name(), XPackInfoAction.NAME, // esql enrich - "cluster:monitor/xpack/enrich/esql/resolve_policy" + "cluster:monitor/xpack/enrich/esql/resolve_policy", + "cluster:internal:data/read/esql/open_exchange", + "cluster:internal:data/read/esql/exchange" ); private static final Set CROSS_CLUSTER_REPLICATION_PATTERN = Set.of( RemoteClusterService.REMOTE_CLUSTER_HANDSHAKE_ACTION_NAME, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java index 066924b21c99f..674706eb9af49 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilege.java @@ -86,10 +86,8 @@ public final class IndexPrivilege extends Privilege { ClusterSearchShardsAction.NAME, TransportSearchShardsAction.TYPE.name(), TransportResolveClusterAction.NAME, - // cross clusters query for ESQL - "internal:data/read/esql/open_exchange", - "internal:data/read/esql/exchange", - "indices:data/read/esql/cluster" + "indices:data/read/esql", + "indices:data/read/esql/compute" ); private static final Automaton CREATE_AUTOMATON = patterns( "indices:data/write/index*", diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java index a8afce1a3b223..da014ada387d6 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java @@ -50,8 +50,10 @@ public final class ExchangeService extends AbstractLifecycleComponent { // TODO: Make this a child action of the data node transport to ensure that exchanges // are accessed only by the user initialized the session. public static final String EXCHANGE_ACTION_NAME = "internal:data/read/esql/exchange"; + public static final String EXCHANGE_ACTION_NAME_FOR_CCS = "cluster:internal:data/read/esql/exchange"; private static final String OPEN_EXCHANGE_ACTION_NAME = "internal:data/read/esql/open_exchange"; + private static final String OPEN_EXCHANGE_ACTION_NAME_FOR_CCS = "cluster:internal:data/read/esql/open_exchange"; /** * The time interval for an exchange sink handler to be considered inactive and subsequently @@ -85,6 +87,21 @@ public void registerTransportHandler(TransportService transportService) { OpenExchangeRequest::new, new OpenExchangeRequestHandler() ); + + // This allows the system user access this action when executed over CCS and the API key based security model is in use + transportService.registerRequestHandler( + EXCHANGE_ACTION_NAME_FOR_CCS, + this.executor, + ExchangeRequest::new, + new ExchangeTransportAction() + ); + transportService.registerRequestHandler( + OPEN_EXCHANGE_ACTION_NAME_FOR_CCS, + this.executor, + OpenExchangeRequest::new, + new OpenExchangeRequestHandler() + ); + } /** diff --git a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java index e181a3542d446..4c7e96c26b7d6 100644 --- a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java +++ b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java @@ -36,7 +36,10 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.Stream; +import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -83,7 +86,11 @@ public class RemoteClusterSecurityEsqlIT extends AbstractRemoteClusterSecurityTe { "search": [ { - "names": 
["index*", "not_found_index", "employees"] + "names": ["index*", "not_found_index", "employees", "employees2"] + }, + { + "names": ["employees3"], + "query": {"term" : {"department" : "engineering"}} } ] }"""); @@ -188,40 +195,8 @@ public void populateData() throws Exception { performRequestWithAdminUser(client, new Request("DELETE", "/countries")); }; // Fulfilling cluster - { - setupEnrich.accept(fulfillingClusterClient); - Request createIndex = new Request("PUT", "employees"); - createIndex.setJsonEntity(""" - { - "mappings": { - "properties": { - "emp_id": { "type": "keyword" }, - "department": {"type": "keyword" } - } - } - } - """); - assertOK(performRequestAgainstFulfillingCluster(createIndex)); - final Request bulkRequest = new Request("POST", "/_bulk?refresh=true"); - bulkRequest.setJsonEntity(Strings.format(""" - { "index": { "_index": "employees" } } - { "emp_id": "1", "department" : "engineering" } - { "index": { "_index": "employees" } } - { "emp_id": "3", "department" : "sales" } - { "index": { "_index": "employees" } } - { "emp_id": "5", "department" : "marketing" } - { "index": { "_index": "employees" } } - { "emp_id": "7", "department" : "engineering" } - { "index": { "_index": "employees" } } - { "emp_id": "9", "department" : "sales" } - """)); - assertOK(performRequestAgainstFulfillingCluster(bulkRequest)); - } - // Querying cluster - // Index some documents, to use them in a mixed-cluster search - setupEnrich.accept(client()); - Request createIndex = new Request("PUT", "employees"); - createIndex.setJsonEntity(""" + setupEnrich.accept(fulfillingClusterClient); + String employeesMapping = """ { "mappings": { "properties": { @@ -230,9 +205,57 @@ public void populateData() throws Exception { } } } - """); + """; + Request createIndex = new Request("PUT", "employees"); + createIndex.setJsonEntity(employeesMapping); + assertOK(performRequestAgainstFulfillingCluster(createIndex)); + Request createIndex2 = new Request("PUT", "employees2"); + createIndex2.setJsonEntity(employeesMapping); + assertOK(performRequestAgainstFulfillingCluster(createIndex2)); + Request createIndex3 = new Request("PUT", "employees3"); + createIndex3.setJsonEntity(employeesMapping); + assertOK(performRequestAgainstFulfillingCluster(createIndex3)); + Request bulkRequest = new Request("POST", "/_bulk?refresh=true"); + bulkRequest.setJsonEntity(Strings.format(""" + { "index": { "_index": "employees" } } + { "emp_id": "1", "department" : "engineering" } + { "index": { "_index": "employees" } } + { "emp_id": "3", "department" : "sales" } + { "index": { "_index": "employees" } } + { "emp_id": "5", "department" : "marketing" } + { "index": { "_index": "employees" } } + { "emp_id": "7", "department" : "engineering" } + { "index": { "_index": "employees" } } + { "emp_id": "9", "department" : "sales" } + { "index": { "_index": "employees2" } } + { "emp_id": "11", "department" : "engineering" } + { "index": { "_index": "employees2" } } + { "emp_id": "13", "department" : "sales" } + { "index": { "_index": "employees3" } } + { "emp_id": "21", "department" : "engineering" } + { "index": { "_index": "employees3" } } + { "emp_id": "23", "department" : "sales" } + { "index": { "_index": "employees3" } } + { "emp_id": "25", "department" : "engineering" } + { "index": { "_index": "employees3" } } + { "emp_id": "27", "department" : "sales" } + """)); + assertOK(performRequestAgainstFulfillingCluster(bulkRequest)); + + // Querying cluster + // Index some documents, to use them in a mixed-cluster search + 
setupEnrich.accept(client()); + + createIndex = new Request("PUT", "employees"); + createIndex.setJsonEntity(employeesMapping); assertOK(adminClient().performRequest(createIndex)); - final Request bulkRequest = new Request("POST", "/_bulk?refresh=true"); + createIndex2 = new Request("PUT", "employees2"); + createIndex2.setJsonEntity(employeesMapping); + assertOK(adminClient().performRequest(createIndex2)); + createIndex3 = new Request("PUT", "employees3"); + createIndex3.setJsonEntity(employeesMapping); + assertOK(adminClient().performRequest(createIndex3)); + bulkRequest = new Request("POST", "/_bulk?refresh=true"); bulkRequest.setJsonEntity(Strings.format(""" { "index": { "_index": "employees" } } { "emp_id": "2", "department" : "management" } @@ -242,6 +265,14 @@ public void populateData() throws Exception { { "emp_id": "6", "department" : "marketing"} { "index": { "_index": "employees"} } { "emp_id": "8", "department" : "support"} + { "index": { "_index": "employees2"} } + { "emp_id": "10", "department" : "management"} + { "index": { "_index": "employees2"} } + { "emp_id": "12", "department" : "engineering"} + { "index": { "_index": "employees3"} } + { "emp_id": "20", "department" : "management"} + { "index": { "_index": "employees3"} } + { "emp_id": "22", "department" : "engineering"} """)); assertOK(client().performRequest(bulkRequest)); @@ -259,7 +290,7 @@ public void populateData() throws Exception { "remote_indices": [ { "names": ["employees"], - "privileges": ["read", "read_cross_cluster"], + "privileges": ["read"], "clusters": ["my_remote_cluster"] } ] @@ -278,56 +309,303 @@ public void populateData() throws Exception { public void wipeData() throws Exception { CheckedConsumer wipe = client -> { performRequestWithAdminUser(client, new Request("DELETE", "/employees")); + performRequestWithAdminUser(client, new Request("DELETE", "/employees2")); + performRequestWithAdminUser(client, new Request("DELETE", "/employees3")); performRequestWithAdminUser(client, new Request("DELETE", "/_enrich/policy/countries")); }; wipe.accept(fulfillingClusterClient); wipe.accept(client()); } - @AwaitsFix(bugUrl = "cross-clusters query doesn't work with RCS 2.0") + @SuppressWarnings("unchecked") public void testCrossClusterQuery() throws Exception { configureRemoteCluster(); populateData(); + + // query remote cluster only + Response response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); + assertOK(response); + assertRemoteOnlyResults(response); + + // query remote and local cluster + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees,employees + | SORT emp_id ASC + | LIMIT 10""")); + assertOK(response); + assertRemoteAndLocalResults(response); + + // query remote cluster only - but also include employees2 which the user does not have access to + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees,my_remote_cluster:employees2 + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); + assertOK(response); + assertRemoteOnlyResults(response); // same as above since the user only has access to employees + + // query remote and local cluster - but also include employees2 which the user does not have access to + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees,my_remote_cluster:employees2,employees,employees2 + | SORT emp_id ASC + | LIMIT 10""")); + 
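// assertRemoteAndLocalResults (defined at the bottom of this class) expects all nine accessible
// employees docs: remote emp_ids 1, 3, 5, 7, 9 plus local emp_ids 2, 4, 6, 8.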
assertOK(response); + assertRemoteAndLocalResults(response); // same as above since the user only has access to employees + + // update role to include both employees and employees2 for the remote cluster + final var putRoleRequest = new Request("PUT", "/_security/role/" + REMOTE_SEARCH_ROLE); + putRoleRequest.setJsonEntity(""" + { + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["employees*"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"] + } + ] + }"""); + response = adminClient().performRequest(putRoleRequest); + assertOK(response); + + // query remote cluster only - but also include employees2 to which the user now has access + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees,my_remote_cluster:employees2 + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); + assertOK(response); + assertRemoteOnlyAgainst2IndexResults(response); + } + + @SuppressWarnings("unchecked") + public void testCrossClusterQueryWithRemoteDLSAndFLS() throws Exception { + configureRemoteCluster(); + populateData(); + + // ensure user has access to the employees3 index + final var putRoleRequest = new Request("PUT", "/_security/role/" + REMOTE_SEARCH_ROLE); + putRoleRequest.setJsonEntity(""" + { + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["employees*"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"] + + } + ] + }"""); + Response response = adminClient().performRequest(putRoleRequest); + assertOK(response); + + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees3 + | SORT emp_id ASC + | LIMIT 10 + | KEEP emp_id, department""")); + assertOK(response); + + Map responseAsMap = entityAsMap(response); + List columns = (List) responseAsMap.get("columns"); + List values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(2, values.size()); + List flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? ((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + // the API key has DLS set to: "query": {"term" : {"department" : "engineering"}} + assertThat(flatList, containsInAnyOrder("21", "25", "engineering", "engineering")); + + // add DLS to the remote indices in the role to restrict access to only emp_id = 21 + putRoleRequest.setJsonEntity(""" + { + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["employees*"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"], + "query": {"term" : {"emp_id" : "21"}} + + } + ] + }"""); + response = adminClient().performRequest(putRoleRequest); + assertOK(response); + + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees3 + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); + assertOK(response); + + responseAsMap = entityAsMap(response); + columns = (List) responseAsMap.get("columns"); + values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(1, values.size()); + flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? 
((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + // the API key has DLS set to: "query": {"term" : {"department" : "engineering"}} + // AND this role has DLS set to: "query": {"term" : {"emp_id" : "21"}} + assertThat(flatList, containsInAnyOrder("21", "engineering")); + + // add FLS to the remote indices in the role to restrict access to only the department field + putRoleRequest.setJsonEntity(""" + { + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["employees*"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"], + "query": {"term" : {"emp_id" : "21"}}, + "field_security": {"grant": [ "department" ]} + } + ] + }"""); + response = adminClient().performRequest(putRoleRequest); + assertOK(response); + + response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees3 + | LIMIT 2 + """)); + assertOK(response); + responseAsMap = entityAsMap(response); + columns = (List) responseAsMap.get("columns"); + values = (List) responseAsMap.get("values"); + assertEquals(1, columns.size()); + assertEquals(1, values.size()); + flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? ((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + // the API key has DLS set to: "query": {"term" : {"department" : "engineering"}} + // AND this role has DLS set to: "query": {"term" : {"emp_id" : "21"}} + // AND this role has FLS set to: "field_security": {"grant": [ "department" ]} + assertThat(flatList, containsInAnyOrder("engineering")); + } + + @SuppressWarnings("unchecked") + @AwaitsFix(bugUrl = "this trips ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH_COORDINATION)") + // comment out those assertions in EsqlIndexResolver and TransportFieldCapabilitiesAction to see this test pass + public void testCrossClusterQueryAgainstInvalidRemote() throws Exception { + configureRemoteCluster(); + populateData(); + + // avoids getting 404 errors + updateClusterSettings( + randomBoolean() + ? 
Settings.builder().put("cluster.remote.invalid_remote.seeds", fulfillingCluster.getRemoteClusterServerEndpoint(0)).build() + : Settings.builder() + .put("cluster.remote.invalid_remote.mode", "proxy") + .put("cluster.remote.invalid_remote.proxy_address", fulfillingCluster.getRemoteClusterServerEndpoint(0)) + .build() + ); + + // invalid remote with local index should return local results + var q = "FROM invalid_remote:employees,employees | SORT emp_id DESC | LIMIT 10"; + Response response = performRequestWithRemoteSearchUser(esqlRequest(q)); + assertOK(response); + assertLocalOnlyResults(response); + + // only calling an invalid remote should error + ResponseException error = expectThrows(ResponseException.class, () -> { + var q2 = "FROM invalid_remote:employees | SORT emp_id DESC | LIMIT 10"; + performRequestWithRemoteSearchUser(esqlRequest(q2)); + }); + assertThat(error.getResponse().getStatusLine().getStatusCode(), equalTo(401)); + assertThat(error.getMessage(), containsString("unable to find apikey")); + } + + @SuppressWarnings("unchecked") + public void testCrossClusterQueryWithOnlyRemotePrivs() throws Exception { + configureRemoteCluster(); + populateData(); + // Query cluster - { + var putRoleRequest = new Request("PUT", "/_security/role/" + REMOTE_SEARCH_ROLE); + putRoleRequest.setJsonEntity(""" { - Response response = performRequestWithRemoteSearchUser(esqlRequest(""" - FROM my_remote_cluster:employees - | SORT emp_id ASC - | LIMIT 2 - | KEEP emp_id, department""")); - assertOK(response); - Map values = entityAsMap(response); - } + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["employees"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"] + } + ] + }"""); + assertOK(adminClient().performRequest(putRoleRequest)); + + // query appropriate privs + Response response = performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); + assertOK(response); + assertRemoteOnlyResults(response); + + // without the remote index priv + putRoleRequest.setJsonEntity(""" { - Response response = performRequestWithRemoteSearchUser(esqlRequest(""" - FROM my_remote_cluster:employees,employees - | SORT emp_id ASC - | LIMIT 10""")); - assertOK(response); + "indices": [{"names": [""], "privileges": ["read_cross_cluster"]}], + "remote_indices": [ + { + "names": ["idontexist"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"] + } + ] + }"""); + assertOK(adminClient().performRequest(putRoleRequest)); - } - // Check that authentication fails if we use a non-existent API key - updateClusterSettings( - randomBoolean() - ? 
Settings.builder() - .put("cluster.remote.invalid_remote.seeds", fulfillingCluster.getRemoteClusterServerEndpoint(0)) - .build() - : Settings.builder() - .put("cluster.remote.invalid_remote.mode", "proxy") - .put("cluster.remote.invalid_remote.proxy_address", fulfillingCluster.getRemoteClusterServerEndpoint(0)) - .build() - ); - for (String indices : List.of("my_remote_cluster:employees,employees", "my_remote_cluster:employees")) { - ResponseException error = expectThrows(ResponseException.class, () -> { - var q = "FROM " + indices + "| SORT emp_id DESC | LIMIT 10"; - performRequestWithLocalSearchUser(esqlRequest(q)); - }); - assertThat(error.getResponse().getStatusLine().getStatusCode(), equalTo(403)); - assertThat(error.getResponse().getStatusLine().getStatusCode(), equalTo(401)); - assertThat(error.getMessage(), containsString("unable to find apikey")); - } - } + ResponseException error = expectThrows(ResponseException.class, () -> performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department"""))); + assertThat(error.getResponse().getStatusLine().getStatusCode(), equalTo(400)); + assertThat(error.getMessage(), containsString("Unknown index [my_remote_cluster:employees]")); + + // no local privs at all will fail + final var putRoleNoLocalPrivs = new Request("PUT", "/_security/role/" + REMOTE_SEARCH_ROLE); + putRoleNoLocalPrivs.setJsonEntity(""" + { + "indices": [], + "remote_indices": [ + { + "names": ["employees"], + "privileges": ["read"], + "clusters": ["my_remote_cluster"] + } + ] + }"""); + assertOK(adminClient().performRequest(putRoleNoLocalPrivs)); + + error = expectThrows(ResponseException.class, () -> { performRequestWithRemoteSearchUser(esqlRequest(""" + FROM my_remote_cluster:employees + | SORT emp_id ASC + | LIMIT 2 + | KEEP emp_id, department""")); }); + + assertThat(error.getResponse().getStatusLine().getStatusCode(), equalTo(403)); + assertThat( + error.getMessage(), + containsString( + "action [indices:data/read/esql] is unauthorized for user [remote_search_user] with effective roles [remote_search], " + + "this action is granted by the index privileges [read,read_cross_cluster,all]" + ) + ); } @AwaitsFix(bugUrl = "cross-clusters enrich doesn't work with RCS 2.0") @@ -360,7 +638,7 @@ public void testCrossClusterEnrich() throws Exception { "remote_indices": [ { "names": ["employees"], - "privileges": ["read", "read_cross_cluster"], + "privileges": ["read"], "clusters": ["my_remote_cluster"] } ] @@ -434,4 +712,79 @@ private Response performRequestWithLocalSearchUser(final Request request) throws ); return client().performRequest(request); } + + @SuppressWarnings("unchecked") + private void assertRemoteOnlyResults(Response response) throws IOException { + Map responseAsMap = entityAsMap(response); + List columns = (List) responseAsMap.get("columns"); + List values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(2, values.size()); + List flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? 
((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + assertThat(flatList, containsInAnyOrder("1", "3", "engineering", "sales")); + } + + @SuppressWarnings("unchecked") + private void assertRemoteOnlyAgainst2IndexResults(Response response) throws IOException { + Map responseAsMap = entityAsMap(response); + List columns = (List) responseAsMap.get("columns"); + List values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(2, values.size()); + List flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? ((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + assertThat(flatList, containsInAnyOrder("1", "11", "engineering", "engineering")); + } + + @SuppressWarnings("unchecked") + private void assertLocalOnlyResults(Response response) throws IOException { + Map responseAsMap = entityAsMap(response); + List columns = (List) responseAsMap.get("columns"); + List values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(4, values.size()); + List flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? ((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + // local results + assertThat(flatList, containsInAnyOrder("2", "4", "6", "8", "support", "management", "engineering", "marketing")); + } + + @SuppressWarnings("unchecked") + private void assertRemoteAndLocalResults(Response response) throws IOException { + Map responseAsMap = entityAsMap(response); + List columns = (List) responseAsMap.get("columns"); + List values = (List) responseAsMap.get("values"); + assertEquals(2, columns.size()); + assertEquals(9, values.size()); + List flatList = values.stream() + .flatMap(innerList -> innerList instanceof List ? ((List) innerList).stream() : Stream.empty()) + .collect(Collectors.toList()); + assertThat( + flatList, + containsInAnyOrder( + "1", + "2", + "3", + "4", + "5", + "6", + "7", + "8", + "9", + "engineering", + "engineering", + "engineering", + "management", + "sales", + "sales", + "marketing", + "marketing", + "support" + ) + ); + } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java index ca08f63a09bb0..462b41a519460 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java @@ -77,7 +77,11 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor "internal:admin/ccr/restore/session/clear", "indices:internal/admin/ccr/restore/session/clear", "internal:admin/ccr/restore/file_chunk/get", - "indices:internal/admin/ccr/restore/file_chunk/get" + "indices:internal/admin/ccr/restore/file_chunk/get", + "internal:data/read/esql/open_exchange", + "cluster:internal:data/read/esql/open_exchange", + "internal:data/read/esql/exchange", + "cluster:internal:data/read/esql/exchange" ); private final AuthenticationService authcService; From 05a849832988d03ed669f0bfe118017540a2e007 Mon Sep 17 00:00:00 2001 From: Mark Tozzi Date: Fri, 12 Apr 2024 11:07:24 -0400 Subject: [PATCH 11/20] [ESQL] Divide by zero should return null (#107076) This PR wires up our full suite of testing for div and mod. 
Doing so revealed that there were still cases where we returned NaNs and infinite values (mostly when dividing by zero), so I also fixed that to return null with a warning. The warning is not correct in all cases - it is possible to get an infinite value without dividing by zero (e.g. a very large double divided by a number very close to zero, the result of which would exceed the maximum finite double value), but the error message will still report "/ by zero". I think this is fine for now, and we can always add finer-grained warnings if we feel the need. --- .../xpack/esql/qa/rest/RestEsqlTestCase.java | 23 +- .../src/main/resources/math.csv-spec | 15 -- .../arithmetic/DivDoublesEvaluator.java | 21 +- .../arithmetic/ModDoublesEvaluator.java | 21 +- .../predicate/operator/arithmetic/Div.java | 18 +- .../predicate/operator/arithmetic/Mod.java | 18 +- .../function/AbstractFunctionTestCase.java | 1 + .../expression/function/TestCaseSupplier.java | 63 ++--- .../scalar/convert/ToIntegerTests.java | 8 +- .../function/scalar/convert/ToLongTests.java | 8 +- .../scalar/convert/ToUnsignedLongTests.java | 8 +- .../operator/arithmetic/AddTests.java | 22 +- .../operator/arithmetic/DivTests.java | 237 +++++++++--------- .../operator/arithmetic/ModTests.java | 237 +++++++++--------- .../LocalPhysicalPlanOptimizerTests.java | 15 +- 15 files changed, 383 insertions(+), 332 deletions(-) diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java index 6883f71c9ee14..86d48aca3baed 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/RestEsqlTestCase.java @@ -94,8 +94,6 @@ public abstract class RestEsqlTestCase extends ESRestTestCase { // larger than any (unsigned) long private static final String HUMONGOUS_DOUBLE = "1E300"; - private static final String INFINITY = "1.0/0.0"; - private static final String NAN = "0.0/0.0"; public static boolean shouldLog() { return false; } @@ -431,22 +429,19 @@ public void testOutOfRangeComparisons() throws IOException { String equalPlusMinus = randomFrom(" == ", " == -"); // TODO: once we do not support infinity and NaN anymore, remove INFINITY/NAN cases. 
// https://github.com/elastic/elasticsearch/issues/98698#issuecomment-1847423390 - String humongousPositiveLiteral = randomFrom(HUMONGOUS_DOUBLE, INFINITY); - String nanOrNull = randomFrom(NAN, "to_double(null)"); List trueForSingleValuesPredicates = List.of( - lessOrLessEqual + humongousPositiveLiteral, - largerOrLargerEqual + " -" + humongousPositiveLiteral, - inEqualPlusMinus + humongousPositiveLiteral, - inEqualPlusMinus + NAN + lessOrLessEqual + HUMONGOUS_DOUBLE, + largerOrLargerEqual + " -" + HUMONGOUS_DOUBLE, + inEqualPlusMinus + HUMONGOUS_DOUBLE ); List alwaysFalsePredicates = List.of( - lessOrLessEqual + " -" + humongousPositiveLiteral, - largerOrLargerEqual + humongousPositiveLiteral, - equalPlusMinus + humongousPositiveLiteral, - lessOrLessEqual + nanOrNull, - largerOrLargerEqual + nanOrNull, - equalPlusMinus + nanOrNull, + lessOrLessEqual + " -" + HUMONGOUS_DOUBLE, + largerOrLargerEqual + HUMONGOUS_DOUBLE, + equalPlusMinus + HUMONGOUS_DOUBLE, + lessOrLessEqual + "to_double(null)", + largerOrLargerEqual + "to_double(null)", + equalPlusMinus + "to_double(null)", inEqualPlusMinus + "to_double(null)" ); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec index 905eac30a3012..399e1b5dc791b 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/math.csv-spec @@ -1213,21 +1213,6 @@ a:double // end::floor-result[] ; -ceilFloorOfInfinite -row i = 1.0/0.0 | eval c = ceil(i), f = floor(i); - -i:double | c:double | f:double -Infinity | Infinity | Infinity -; - -ceilFloorOfNegativeInfinite -row i = -1.0/0.0 | eval c = ceil(i), f = floor(i); - -i:double | c:double | f:double --Infinity | -Infinity | -Infinity -; - - ceilFloorOfInteger row i = 1 | eval c = ceil(i), f = floor(i); diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivDoublesEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivDoublesEvaluator.java index bb9f55f2b5b85..88bf948749ffc 100644 --- a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivDoublesEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivDoublesEvaluator.java @@ -4,6 +4,7 @@ // 2.0. 
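// A hedged illustration of the caveat from the commit message above (values are hypothetical and
// not part of the generated evaluator): a finite double divided by a small non-zero double can
// overflow to Infinity, and the NaN/Infinity guard in Div.processDoubles reports that as
// "/ by zero" even though the divisor is non-zero.
//
//     double lhs = 1e308;        // close to Double.MAX_VALUE
//     double rhs = 1e-10;        // small, but clearly not zero
//     double value = lhs / rhs;  // overflows to Double.POSITIVE_INFINITY
//     Double.isInfinite(value);  // true -> ArithmeticException("/ by zero") -> null plus a warning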
package org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic; +import java.lang.ArithmeticException; import java.lang.IllegalArgumentException; import java.lang.Override; import java.lang.String; @@ -50,7 +51,7 @@ public Block eval(Page page) { if (rhsVector == null) { return eval(page.getPositionCount(), lhsBlock, rhsBlock); } - return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + return eval(page.getPositionCount(), lhsVector, rhsVector); } } } @@ -80,16 +81,26 @@ public DoubleBlock eval(int positionCount, DoubleBlock lhsBlock, DoubleBlock rhs result.appendNull(); continue position; } - result.appendDouble(Div.processDoubles(lhsBlock.getDouble(lhsBlock.getFirstValueIndex(p)), rhsBlock.getDouble(rhsBlock.getFirstValueIndex(p)))); + try { + result.appendDouble(Div.processDoubles(lhsBlock.getDouble(lhsBlock.getFirstValueIndex(p)), rhsBlock.getDouble(rhsBlock.getFirstValueIndex(p)))); + } catch (ArithmeticException e) { + warnings.registerException(e); + result.appendNull(); + } } return result.build(); } } - public DoubleVector eval(int positionCount, DoubleVector lhsVector, DoubleVector rhsVector) { - try(DoubleVector.Builder result = driverContext.blockFactory().newDoubleVectorBuilder(positionCount)) { + public DoubleBlock eval(int positionCount, DoubleVector lhsVector, DoubleVector rhsVector) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { position: for (int p = 0; p < positionCount; p++) { - result.appendDouble(Div.processDoubles(lhsVector.getDouble(p), rhsVector.getDouble(p))); + try { + result.appendDouble(Div.processDoubles(lhsVector.getDouble(p), rhsVector.getDouble(p))); + } catch (ArithmeticException e) { + warnings.registerException(e); + result.appendNull(); + } } return result.build(); } diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModDoublesEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModDoublesEvaluator.java index 8d441ffe10a48..3afcac77973fb 100644 --- a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModDoublesEvaluator.java +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModDoublesEvaluator.java @@ -4,6 +4,7 @@ // 2.0. 
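// A companion note for modulo (hypothetical snippet, not part of the generated code): for doubles
// the JVM does not throw on x % 0.0 - it simply yields NaN - so the NaN/Infinity check in
// Mod.processDoubles is what turns that case into the ArithmeticException("/ by zero") warning
// and a null result.
//
//     double r = 5.0 % 0.0;      // NaN; the JVM itself throws no exception here
//     Double.isNaN(r);           // true -> treated as "/ by zero"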
package org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic; +import java.lang.ArithmeticException; import java.lang.IllegalArgumentException; import java.lang.Override; import java.lang.String; @@ -50,7 +51,7 @@ public Block eval(Page page) { if (rhsVector == null) { return eval(page.getPositionCount(), lhsBlock, rhsBlock); } - return eval(page.getPositionCount(), lhsVector, rhsVector).asBlock(); + return eval(page.getPositionCount(), lhsVector, rhsVector); } } } @@ -80,16 +81,26 @@ public DoubleBlock eval(int positionCount, DoubleBlock lhsBlock, DoubleBlock rhs result.appendNull(); continue position; } - result.appendDouble(Mod.processDoubles(lhsBlock.getDouble(lhsBlock.getFirstValueIndex(p)), rhsBlock.getDouble(rhsBlock.getFirstValueIndex(p)))); + try { + result.appendDouble(Mod.processDoubles(lhsBlock.getDouble(lhsBlock.getFirstValueIndex(p)), rhsBlock.getDouble(rhsBlock.getFirstValueIndex(p)))); + } catch (ArithmeticException e) { + warnings.registerException(e); + result.appendNull(); + } } return result.build(); } } - public DoubleVector eval(int positionCount, DoubleVector lhsVector, DoubleVector rhsVector) { - try(DoubleVector.Builder result = driverContext.blockFactory().newDoubleVectorBuilder(positionCount)) { + public DoubleBlock eval(int positionCount, DoubleVector lhsVector, DoubleVector rhsVector) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { position: for (int p = 0; p < positionCount; p++) { - result.appendDouble(Mod.processDoubles(lhsVector.getDouble(p), rhsVector.getDouble(p))); + try { + result.appendDouble(Mod.processDoubles(lhsVector.getDouble(p), rhsVector.getDouble(p))); + } catch (ArithmeticException e) { + warnings.registerException(e); + result.appendNull(); + } } return result.build(); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java index 170e3de6e4209..73863d308f6e4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Div.java @@ -13,6 +13,7 @@ import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; import org.elasticsearch.xpack.ql.type.DataType; +import org.elasticsearch.xpack.ql.util.NumericUtils; import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.DIV; import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.longToUnsignedLong; @@ -63,21 +64,34 @@ public ArithmeticOperationFactory binaryComparisonInverse() { @Evaluator(extraName = "Ints", warnExceptions = { ArithmeticException.class }) static int processInts(int lhs, int rhs) { + if (rhs == 0) { + throw new ArithmeticException("/ by zero"); + } return lhs / rhs; } @Evaluator(extraName = "Longs", warnExceptions = { ArithmeticException.class }) static long processLongs(long lhs, long rhs) { + if (rhs == 0L) { + throw new ArithmeticException("/ by zero"); + } return lhs / rhs; } @Evaluator(extraName = "UnsignedLongs", warnExceptions = { ArithmeticException.class }) static long processUnsignedLongs(long lhs, long rhs) { + if (rhs == NumericUtils.ZERO_AS_UNSIGNED_LONG) { + throw new ArithmeticException("/ by zero"); + } return 
longToUnsignedLong(Long.divideUnsigned(longToUnsignedLong(lhs, true), longToUnsignedLong(rhs, true)), true); } - @Evaluator(extraName = "Doubles") + @Evaluator(extraName = "Doubles", warnExceptions = { ArithmeticException.class }) static double processDoubles(double lhs, double rhs) { - return lhs / rhs; + double value = lhs / rhs; + if (Double.isNaN(value) || Double.isInfinite(value)) { + throw new ArithmeticException("/ by zero"); + } + return value; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mod.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mod.java index bc1ad8fcb5f94..df3b8f27c4880 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mod.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/Mod.java @@ -11,6 +11,7 @@ import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.tree.NodeInfo; import org.elasticsearch.xpack.ql.tree.Source; +import org.elasticsearch.xpack.ql.util.NumericUtils; import static org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.EsqlArithmeticOperation.OperationSymbol.MOD; import static org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter.longToUnsignedLong; @@ -42,21 +43,34 @@ protected Mod replaceChildren(Expression left, Expression right) { @Evaluator(extraName = "Ints", warnExceptions = { ArithmeticException.class }) static int processInts(int lhs, int rhs) { + if (rhs == 0) { + throw new ArithmeticException("/ by zero"); + } return lhs % rhs; } @Evaluator(extraName = "Longs", warnExceptions = { ArithmeticException.class }) static long processLongs(long lhs, long rhs) { + if (rhs == 0L) { + throw new ArithmeticException("/ by zero"); + } return lhs % rhs; } @Evaluator(extraName = "UnsignedLongs", warnExceptions = { ArithmeticException.class }) static long processUnsignedLongs(long lhs, long rhs) { + if (rhs == NumericUtils.ZERO_AS_UNSIGNED_LONG) { + throw new ArithmeticException("/ by zero"); + } return longToUnsignedLong(Long.remainderUnsigned(longToUnsignedLong(lhs, true), longToUnsignedLong(rhs, true)), true); } - @Evaluator(extraName = "Doubles") + @Evaluator(extraName = "Doubles", warnExceptions = { ArithmeticException.class }) static double processDoubles(double lhs, double rhs) { - return lhs % rhs; + double value = lhs % rhs; + if (Double.isNaN(value) || Double.isInfinite(value)) { + throw new ArithmeticException("/ by zero"); + } + return value; } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index b2d00a98dfa6c..0b6c64679dc1f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -857,6 +857,7 @@ protected static String typeErrorMessage(boolean includeOrdinal, List forBinaryCastingToDouble( return suppliers; } - private static void casesCrossProduct( + public static void casesCrossProduct( BinaryOperator expected, List lhsSuppliers, List rhsSuppliers, @@ -251,10 +251,10 @@ private static TestCaseSupplier testCaseSupplier( public static List 
castToDoubleSuppliersFromRange(Double Min, Double Max) { List suppliers = new ArrayList<>(); - suppliers.addAll(intCases(Min.intValue(), Max.intValue())); - suppliers.addAll(longCases(Min.longValue(), Max.longValue())); - suppliers.addAll(ulongCases(BigInteger.valueOf((long) Math.ceil(Min)), BigInteger.valueOf((long) Math.floor(Max)))); - suppliers.addAll(doubleCases(Min, Max)); + suppliers.addAll(intCases(Min.intValue(), Max.intValue(), true)); + suppliers.addAll(longCases(Min.longValue(), Max.longValue(), true)); + suppliers.addAll(ulongCases(BigInteger.valueOf((long) Math.ceil(Min)), BigInteger.valueOf((long) Math.floor(Max)), true)); + suppliers.addAll(doubleCases(Min, Max, true)); return suppliers; } @@ -279,7 +279,7 @@ public NumericTypeTestConfig get(DataType type) { } } - private static DataType widen(DataType lhs, DataType rhs) { + public static DataType widen(DataType lhs, DataType rhs) { if (lhs == rhs) { return lhs; } @@ -292,21 +292,22 @@ private static DataType widen(DataType lhs, DataType rhs) { throw new IllegalArgumentException("Invalid numeric widening lhs: [" + lhs + "] rhs: [" + rhs + "]"); } - private static List getSuppliersForNumericType(DataType type, Number min, Number max) { + public static List getSuppliersForNumericType(DataType type, Number min, Number max, boolean includeZero) { if (type == DataTypes.INTEGER) { - return intCases(NumericUtils.saturatingIntValue(min), NumericUtils.saturatingIntValue(max)); + return intCases(NumericUtils.saturatingIntValue(min), NumericUtils.saturatingIntValue(max), includeZero); } if (type == DataTypes.LONG) { - return longCases(min.longValue(), max.longValue()); + return longCases(min.longValue(), max.longValue(), includeZero); } if (type == DataTypes.UNSIGNED_LONG) { return ulongCases( min instanceof BigInteger ? (BigInteger) min : BigInteger.valueOf(Math.max(min.longValue(), 0L)), - max instanceof BigInteger ? (BigInteger) max : BigInteger.valueOf(Math.max(max.longValue(), 0L)) + max instanceof BigInteger ? 
(BigInteger) max : BigInteger.valueOf(Math.max(max.longValue(), 0L)), + includeZero ); } if (type == DataTypes.DOUBLE) { - return doubleCases(min.doubleValue(), max.doubleValue()); + return doubleCases(min.doubleValue(), max.doubleValue(), includeZero); } throw new IllegalArgumentException("bogus numeric type [" + type + "]"); } @@ -315,7 +316,8 @@ public static List forBinaryWithWidening( NumericTypeTestConfigs typeStuff, String lhsName, String rhsName, - List warnings + List warnings, + boolean allowRhsZero ) { List suppliers = new ArrayList<>(); List numericTypes = List.of(DataTypes.INTEGER, DataTypes.LONG, DataTypes.DOUBLE); @@ -336,13 +338,13 @@ public static List forBinaryWithWidening( + "]"; casesCrossProduct( (l, r) -> expectedTypeStuff.expected().apply((Number) l, (Number) r), - getSuppliersForNumericType(lhsType, expectedTypeStuff.min(), expectedTypeStuff.max()), - getSuppliersForNumericType(rhsType, expectedTypeStuff.min(), expectedTypeStuff.max()), + getSuppliersForNumericType(lhsType, expectedTypeStuff.min(), expectedTypeStuff.max(), true), + getSuppliersForNumericType(rhsType, expectedTypeStuff.min(), expectedTypeStuff.max(), allowRhsZero), evaluatorToString, warnings, suppliers, expected, - true + false ); } } @@ -358,7 +360,8 @@ public static List forBinaryNotCasting( DataType expectedType, List lhsSuppliers, List rhsSuppliers, - List warnings + List warnings, + boolean symmetric ) { List suppliers = new ArrayList<>(); casesCrossProduct( @@ -369,7 +372,7 @@ public static List forBinaryNotCasting( warnings, suppliers, expectedType, - true + symmetric ); return suppliers; } @@ -389,7 +392,7 @@ public static void forUnaryInt( unaryNumeric( suppliers, expectedEvaluatorToString, - intCases(lowerBound, upperBound), + intCases(lowerBound, upperBound, true), expectedType, n -> expectedValue.apply(n.intValue()), n -> expectedWarnings.apply(n.intValue()) @@ -423,7 +426,7 @@ public static void forUnaryLong( unaryNumeric( suppliers, expectedEvaluatorToString, - longCases(lowerBound, upperBound), + longCases(lowerBound, upperBound, true), expectedType, n -> expectedValue.apply(n.longValue()), expectedWarnings @@ -457,7 +460,7 @@ public static void forUnaryUnsignedLong( unaryNumeric( suppliers, expectedEvaluatorToString, - ulongCases(lowerBound, upperBound), + ulongCases(lowerBound, upperBound, true), expectedType, n -> expectedValue.apply((BigInteger) n), n -> expectedWarnings.apply((BigInteger) n) @@ -503,7 +506,7 @@ public static void forUnaryDouble( unaryNumeric( suppliers, expectedEvaluatorToString, - doubleCases(lowerBound, upperBound), + doubleCases(lowerBound, upperBound, true), expectedType, n -> expectedValue.apply(n.doubleValue()), n -> expectedWarnings.apply(n.doubleValue()) @@ -729,9 +732,9 @@ public static void unary( unary(suppliers, expectedEvaluatorToString, valueSuppliers, expectedOutputType, expected, unused -> warnings); } - public static List intCases(int min, int max) { + public static List intCases(int min, int max, boolean includeZero) { List cases = new ArrayList<>(); - if (0 <= max && 0 >= min) { + if (0 <= max && 0 >= min && includeZero) { cases.add(new TypedDataSupplier("<0 int>", () -> 0, DataTypes.INTEGER)); } @@ -753,9 +756,9 @@ public static List intCases(int min, int max) { return cases; } - public static List longCases(long min, long max) { + public static List longCases(long min, long max, boolean includeZero) { List cases = new ArrayList<>(); - if (0L <= max && 0L >= min) { + if (0L <= max && 0L >= min && includeZero) { cases.add(new 
TypedDataSupplier("<0 long>", () -> 0L, DataTypes.LONG)); } @@ -778,11 +781,11 @@ public static List longCases(long min, long max) { return cases; } - public static List ulongCases(BigInteger min, BigInteger max) { + public static List ulongCases(BigInteger min, BigInteger max, boolean includeZero) { List cases = new ArrayList<>(); // Zero - if (BigInteger.ZERO.compareTo(max) <= 0 && BigInteger.ZERO.compareTo(min) >= 0) { + if (BigInteger.ZERO.compareTo(max) <= 0 && BigInteger.ZERO.compareTo(min) >= 0 && includeZero) { cases.add(new TypedDataSupplier("<0 unsigned long>", () -> BigInteger.ZERO, DataTypes.UNSIGNED_LONG)); } @@ -818,11 +821,11 @@ public static List ulongCases(BigInteger min, BigInteger max) return cases; } - public static List doubleCases(double min, double max) { + public static List doubleCases(double min, double max, boolean includeZero) { List cases = new ArrayList<>(); // Zeros - if (0d <= max && 0d >= min) { + if (0d <= max && 0d >= min && includeZero) { cases.add(new TypedDataSupplier("<0 double>", () -> 0.0d, DataTypes.DOUBLE)); cases.add(new TypedDataSupplier("<-0 double>", () -> -0.0d, DataTypes.DOUBLE)); } @@ -1046,7 +1049,7 @@ public static List versionCases(String prefix) { ); } - private static String getCastEvaluator(String original, DataType current, DataType target) { + public static String getCastEvaluator(String original, DataType current, DataType target) { if (current == target) { return original; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToIntegerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToIntegerTests.java index 3a6cb86b7a3c6..e6f6cb7e978f7 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToIntegerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToIntegerTests.java @@ -178,7 +178,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.intCases(Integer.MIN_VALUE, Integer.MAX_VALUE) + TestCaseSupplier.intCases(Integer.MIN_VALUE, Integer.MAX_VALUE, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -196,7 +196,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Integer.MIN_VALUE, Integer.MAX_VALUE) + TestCaseSupplier.doubleCases(Integer.MIN_VALUE, Integer.MAX_VALUE, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -214,7 +214,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Integer.MIN_VALUE - 1d) + TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Integer.MIN_VALUE - 1d, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -237,7 +237,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Integer.MAX_VALUE + 1d, Double.POSITIVE_INFINITY) + TestCaseSupplier.doubleCases(Integer.MAX_VALUE + 1d, Double.POSITIVE_INFINITY, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToLongTests.java 
b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToLongTests.java index 031ce6193bcc4..1879b7ce97ea8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToLongTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToLongTests.java @@ -129,7 +129,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.longCases(Long.MIN_VALUE, Long.MAX_VALUE) + TestCaseSupplier.longCases(Long.MIN_VALUE, Long.MAX_VALUE, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -147,7 +147,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Long.MIN_VALUE, Long.MAX_VALUE) + TestCaseSupplier.doubleCases(Long.MIN_VALUE, Long.MAX_VALUE, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -165,7 +165,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Long.MIN_VALUE - 1d) + TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, Long.MIN_VALUE - 1d, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -188,7 +188,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Long.MAX_VALUE + 1d, Double.POSITIVE_INFINITY) + TestCaseSupplier.doubleCases(Long.MAX_VALUE + 1d, Double.POSITIVE_INFINITY, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToUnsignedLongTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToUnsignedLongTests.java index 8d5ee002a8f78..3cb9c813fd0b5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToUnsignedLongTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/convert/ToUnsignedLongTests.java @@ -165,7 +165,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.ulongCases(BigInteger.ZERO, UNSIGNED_LONG_MAX) + TestCaseSupplier.ulongCases(BigInteger.ZERO, UNSIGNED_LONG_MAX, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -183,7 +183,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(0, UNSIGNED_LONG_MAX_AS_DOUBLE) + TestCaseSupplier.doubleCases(0, UNSIGNED_LONG_MAX_AS_DOUBLE, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -201,7 +201,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, -1d) + TestCaseSupplier.doubleCases(Double.NEGATIVE_INFINITY, -1d, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( @@ -224,7 +224,7 @@ public static Iterable parameters() { TestCaseSupplier.unary( suppliers, evaluatorName.apply("String"), - TestCaseSupplier.doubleCases(UNSIGNED_LONG_MAX_AS_DOUBLE + 10e5, Double.POSITIVE_INFINITY) + TestCaseSupplier.doubleCases(UNSIGNED_LONG_MAX_AS_DOUBLE + 10e5, 
Double.POSITIVE_INFINITY, true) .stream() .map( tds -> new TestCaseSupplier.TypedDataSupplier( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java index 6a74dd13c1e3a..143f7e5aaba9f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/AddTests.java @@ -65,7 +65,8 @@ public static Iterable parameters() { ), "lhs", "rhs", - List.of() + List.of(), + true ) ); @@ -79,9 +80,10 @@ public static Iterable parameters() { "rhs", (l, r) -> (((BigInteger) l).add((BigInteger) r)), DataTypes.UNSIGNED_LONG, - TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE)), - TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE)), - List.of() + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + List.of(), + true ) ); @@ -96,7 +98,8 @@ public static Iterable parameters() { EsqlDataTypes.DATE_PERIOD, TestCaseSupplier.datePeriodCases(), TestCaseSupplier.datePeriodCases(), - List.of() + List.of(), + true ) ); suppliers.addAll( @@ -108,7 +111,8 @@ public static Iterable parameters() { EsqlDataTypes.TIME_DURATION, TestCaseSupplier.timeDurationCases(), TestCaseSupplier.timeDurationCases(), - List.of() + List.of(), + true ) ); @@ -134,7 +138,8 @@ public static Iterable parameters() { DataTypes.DATETIME, TestCaseSupplier.dateCases(), TestCaseSupplier.datePeriodCases(), - List.of() + List.of(), + true ) ); suppliers.addAll( @@ -159,7 +164,8 @@ public static Iterable parameters() { DataTypes.DATETIME, TestCaseSupplier.dateCases(), TestCaseSupplier.timeDurationCases(), - List.of() + List.of(), + true ) ); suppliers.addAll(TestCaseSupplier.dateCases().stream().mapMulti((tds, consumer) -> { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java index 4aa8786f2cd69..1f5d57394ff4d 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/DivTests.java @@ -10,7 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; -import org.elasticsearch.compute.data.Block; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.tree.Source; @@ -18,140 +18,149 @@ import org.elasticsearch.xpack.ql.type.DataTypes; import java.math.BigInteger; +import java.util.ArrayList; import java.util.List; +import java.util.Set; +import java.util.function.BiFunction; import java.util.function.Supplier; -import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; -import static org.elasticsearch.xpack.ql.util.NumericUtils.ZERO_AS_UNSIGNED_LONG; -import 
static org.elasticsearch.xpack.ql.util.NumericUtils.asLongUnsigned; -import static org.elasticsearch.xpack.ql.util.NumericUtils.unsignedLongAsBigInteger; -import static org.hamcrest.Matchers.equalTo; - -public class DivTests extends AbstractArithmeticTestCase { +public class DivTests extends AbstractFunctionTestCase { public DivTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @ParametersFactory public static Iterable parameters() { - return parameterSuppliersFromTypedData(List.of(new TestCaseSupplier("Int / Int", () -> { - int lhs = randomInt(); - int rhs; - do { - rhs = randomInt(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.INTEGER, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.INTEGER, "rhs") - ), - "DivIntsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - DataTypes.INTEGER, - equalTo(lhs / rhs) - ); - }), new TestCaseSupplier("Long / Long", () -> { - long lhs = randomLong(); - long rhs; - do { - rhs = randomLong(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.LONG, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.LONG, "rhs") + List suppliers = new ArrayList<>(); + suppliers.addAll( + TestCaseSupplier.forBinaryWithWidening( + new TestCaseSupplier.NumericTypeTestConfigs( + new TestCaseSupplier.NumericTypeTestConfig( + (Integer.MIN_VALUE >> 1) - 1, + (Integer.MAX_VALUE >> 1) - 1, + (l, r) -> l.intValue() / r.intValue(), + "DivIntsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + (Long.MIN_VALUE >> 1) - 1, + (Long.MAX_VALUE >> 1) - 1, + (l, r) -> l.longValue() / r.longValue(), + "DivLongsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + Double.NEGATIVE_INFINITY, + Double.POSITIVE_INFINITY, + (l, r) -> l.doubleValue() / r.doubleValue(), + "DivDoublesEvaluator" + ) ), - "DivLongsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - DataTypes.LONG, - equalTo(lhs / rhs) - ); - }), new TestCaseSupplier("Double / Double", () -> { - double lhs = randomDouble(); - double rhs; - do { - rhs = randomDouble(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.DOUBLE, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.DOUBLE, "rhs") - ), - "DivDoublesEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(lhs / rhs) - ); - })/*, new TestCaseSupplier("ULong / ULong", () -> { - // Ensure we don't have an overflow - long lhs = randomLong(); - long rhs; - do { - rhs = randomLong(); - } while (rhs == 0); - BigInteger lhsBI = unsignedLongAsBigInteger(lhs); - BigInteger rhsBI = unsignedLongAsBigInteger(rhs); - return new TestCase( - Source.EMPTY, - List.of(new TypedData(lhs, DataTypes.UNSIGNED_LONG, "lhs"), new TypedData(rhs, DataTypes.UNSIGNED_LONG, "rhs")), - "DivUnsignedLongsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - equalTo(asLongUnsigned(lhsBI.divide(rhsBI).longValue())) - ); - }) - */ - )); - } + "lhs", + "rhs", + List.of(), + false + ) + ); + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "DivUnsignedLongsEvaluator", + "lhs", + "rhs", + (l, r) -> (((BigInteger) l).divide((BigInteger) r)), + DataTypes.UNSIGNED_LONG, + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + TestCaseSupplier.ulongCases(BigInteger.ONE, BigInteger.valueOf(Long.MAX_VALUE), 
true), + List.of(), + false + ) + ); - // run dedicated test to avoid the JVM optimized ArithmeticException that lacks a message - public void testDivisionByZero() { - DataType testCaseType = testCase.getData().get(0).type(); - List data = switch (testCaseType.typeName()) { - case "INTEGER" -> List.of(randomInt(), 0); - case "LONG" -> List.of(randomLong(), 0L); - case "UNSIGNED_LONG" -> List.of(randomLong(), ZERO_AS_UNSIGNED_LONG); - default -> null; - }; - if (data != null) { - var op = build(Source.EMPTY, field("lhs", testCaseType), field("rhs", testCaseType)); - try (Block block = evaluator(op).get(driverContext()).eval(row(data))) { - assertCriticalWarnings( - "Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.", - "Line -1:-1: java.lang.ArithmeticException: / by zero" + suppliers = errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers), DivTests::divErrorMessageString); + + // Divide by zero cases - all of these should warn and return null + TestCaseSupplier.NumericTypeTestConfigs typeStuff = new TestCaseSupplier.NumericTypeTestConfigs( + new TestCaseSupplier.NumericTypeTestConfig( + (Integer.MIN_VALUE >> 1) - 1, + (Integer.MAX_VALUE >> 1) - 1, + (l, r) -> null, + "DivIntsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + (Long.MIN_VALUE >> 1) - 1, + (Long.MAX_VALUE >> 1) - 1, + (l, r) -> null, + "DivLongsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + Double.NEGATIVE_INFINITY, + Double.POSITIVE_INFINITY, + (l, r) -> null, + "DivDoublesEvaluator" + ) + ); + List numericTypes = List.of(DataTypes.INTEGER, DataTypes.LONG, DataTypes.DOUBLE); + + for (DataType lhsType : numericTypes) { + for (DataType rhsType : numericTypes) { + DataType expected = TestCaseSupplier.widen(lhsType, rhsType); + TestCaseSupplier.NumericTypeTestConfig expectedTypeStuff = typeStuff.get(expected); + BiFunction evaluatorToString = (lhs, rhs) -> expectedTypeStuff.evaluatorName() + + "[" + + "lhs" + + "=" + + TestCaseSupplier.getCastEvaluator("Attribute[channel=0]", lhs, expected) + + ", " + + "rhs" + + "=" + + TestCaseSupplier.getCastEvaluator("Attribute[channel=1]", rhs, expected) + + "]"; + TestCaseSupplier.casesCrossProduct( + (l1, r1) -> expectedTypeStuff.expected().apply((Number) l1, (Number) r1), + TestCaseSupplier.getSuppliersForNumericType(lhsType, expectedTypeStuff.min(), expectedTypeStuff.max(), true), + TestCaseSupplier.getSuppliersForNumericType(rhsType, 0, 0, true), + evaluatorToString, + List.of( + "Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.", + "Line -1:-1: java.lang.ArithmeticException: / by zero" + ), + suppliers, + expected, + false ); - assertNull(toJavaObject(block, 0)); } } - } - @Override - protected boolean rhsOk(Object o) { - if (o instanceof Number n) { - return n.doubleValue() != 0; - } - return true; - } - - @Override - protected Div build(Source source, Expression lhs, Expression rhs) { - return new Div(source, lhs, rhs); - } + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "DivUnsignedLongsEvaluator", + "lhs", + "rhs", + (l, r) -> null, + DataTypes.UNSIGNED_LONG, + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.ZERO, true), + List.of( + "Line -1:-1: evaluation of [] failed, treating result as null. 
Only first 20 failures recorded.", + "Line -1:-1: java.lang.ArithmeticException: / by zero" + ), + false + ) + ); - @Override - protected double expectedValue(double lhs, double rhs) { - return lhs / rhs; + return parameterSuppliersFromTypedData(suppliers); } - @Override - protected int expectedValue(int lhs, int rhs) { - return lhs / rhs; - } + private static String divErrorMessageString(boolean includeOrdinal, List> validPerPosition, List types) { + try { + return typeErrorMessage(includeOrdinal, validPerPosition, types); + } catch (IllegalStateException e) { + // This means all the positional args were okay, so the expected error is from the combination + return "[/] has arguments with incompatible types [" + types.get(0).typeName() + "] and [" + types.get(1).typeName() + "]"; - @Override - protected long expectedValue(long lhs, long rhs) { - return lhs / rhs; + } } @Override - protected long expectedUnsignedLongValue(long lhs, long rhs) { - BigInteger lhsBI = unsignedLongAsBigInteger(lhs); - BigInteger rhsBI = unsignedLongAsBigInteger(rhs); - return asLongUnsigned(lhsBI.divide(rhsBI).longValue()); + protected Expression build(Source source, List args) { + return new Div(source, args.get(0), args.get(1)); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModTests.java index 5beaf0b782af7..03fbbf6a21ebe 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/arithmetic/ModTests.java @@ -10,7 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.Name; import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; -import org.elasticsearch.compute.data.Block; +import org.elasticsearch.xpack.esql.expression.function.AbstractFunctionTestCase; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; import org.elasticsearch.xpack.ql.expression.Expression; import org.elasticsearch.xpack.ql.tree.Source; @@ -18,140 +18,149 @@ import org.elasticsearch.xpack.ql.type.DataTypes; import java.math.BigInteger; +import java.util.ArrayList; import java.util.List; +import java.util.Set; +import java.util.function.BiFunction; import java.util.function.Supplier; -import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; -import static org.elasticsearch.xpack.ql.util.NumericUtils.ZERO_AS_UNSIGNED_LONG; -import static org.elasticsearch.xpack.ql.util.NumericUtils.asLongUnsigned; -import static org.elasticsearch.xpack.ql.util.NumericUtils.unsignedLongAsBigInteger; -import static org.hamcrest.Matchers.equalTo; - -public class ModTests extends AbstractArithmeticTestCase { +public class ModTests extends AbstractFunctionTestCase { public ModTests(@Name("TestCase") Supplier testCaseSupplier) { this.testCase = testCaseSupplier.get(); } @ParametersFactory public static Iterable parameters() { - return parameterSuppliersFromTypedData(List.of(new TestCaseSupplier("Int % Int", () -> { - int lhs = randomInt(); - int rhs; - do { - rhs = randomInt(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.INTEGER, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.INTEGER, "rhs") - ), - "ModIntsEvaluator[lhs=Attribute[channel=0], 
rhs=Attribute[channel=1]]", - DataTypes.INTEGER, - equalTo(lhs % rhs) - ); - }), new TestCaseSupplier("Long % Long", () -> { - long lhs = randomLong(); - long rhs; - do { - rhs = randomLong(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.LONG, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.LONG, "rhs") + List suppliers = new ArrayList<>(); + suppliers.addAll( + TestCaseSupplier.forBinaryWithWidening( + new TestCaseSupplier.NumericTypeTestConfigs( + new TestCaseSupplier.NumericTypeTestConfig( + (Integer.MIN_VALUE >> 1) - 1, + (Integer.MAX_VALUE >> 1) - 1, + (l, r) -> l.intValue() % r.intValue(), + "ModIntsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + (Long.MIN_VALUE >> 1) - 1, + (Long.MAX_VALUE >> 1) - 1, + (l, r) -> l.longValue() % r.longValue(), + "ModLongsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + Double.NEGATIVE_INFINITY, + Double.POSITIVE_INFINITY, + (l, r) -> l.doubleValue() % r.doubleValue(), + "ModDoublesEvaluator" + ) ), - "ModLongsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - DataTypes.LONG, - equalTo(lhs % rhs) - ); - }), new TestCaseSupplier("Double % Double", () -> { - double lhs = randomDouble(); - double rhs; - do { - rhs = randomDouble(); - } while (rhs == 0); - return new TestCaseSupplier.TestCase( - List.of( - new TestCaseSupplier.TypedData(lhs, DataTypes.DOUBLE, "lhs"), - new TestCaseSupplier.TypedData(rhs, DataTypes.DOUBLE, "rhs") - ), - "ModDoublesEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - DataTypes.DOUBLE, - equalTo(lhs % rhs) - ); - })/*, new TestCaseSupplier("ULong % ULong", () -> { - // Ensure we don't have an overflow - long lhs = randomLong(); - long rhs; - do { - rhs = randomLong(); - } while (rhs == 0); - BigInteger lhsBI = unsignedLongAsBigInteger(lhs); - BigInteger rhsBI = unsignedLongAsBigInteger(rhs); - return new TestCase( - Source.EMPTY, - List.of(new TypedData(lhs, DataTypes.UNSIGNED_LONG, "lhs"), new TypedData(rhs, DataTypes.UNSIGNED_LONG, "rhs")), - "ModUnsignedLongsEvaluator[lhs=Attribute[channel=0], rhs=Attribute[channel=1]]", - equalTo(asLongUnsigned(lhsBI.mod(rhsBI).longValue())) - ); - }) - */ - )); - } + "lhs", + "rhs", + List.of(), + false + ) + ); + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "ModUnsignedLongsEvaluator", + "lhs", + "rhs", + (l, r) -> (((BigInteger) l).mod((BigInteger) r)), + DataTypes.UNSIGNED_LONG, + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + TestCaseSupplier.ulongCases(BigInteger.ONE, BigInteger.valueOf(Long.MAX_VALUE), true), + List.of(), + false + ) + ); - // run dedicated test to avoid the JVM optimized ArithmeticException that lacks a message - public void testDivisionByZero() { - DataType testCaseType = testCase.getData().get(0).type(); - List data = switch (testCaseType.typeName()) { - case "INTEGER" -> List.of(randomInt(), 0); - case "LONG" -> List.of(randomLong(), 0L); - case "UNSIGNED_LONG" -> List.of(randomLong(), ZERO_AS_UNSIGNED_LONG); - default -> null; - }; - if (data != null) { - var op = build(Source.EMPTY, field("lhs", testCaseType), field("rhs", testCaseType)); - try (Block block = evaluator(op).get(driverContext()).eval(row(data))) { - assertCriticalWarnings( - "Line -1:-1: evaluation of [] failed, treating result as null. 
Only first 20 failures recorded.", - "Line -1:-1: java.lang.ArithmeticException: / by zero" + suppliers = errorsForCasesWithoutExamples(anyNullIsNull(true, suppliers), ModTests::modErrorMessageString); + + // Divide by zero cases - all of these should warn and return null + TestCaseSupplier.NumericTypeTestConfigs typeStuff = new TestCaseSupplier.NumericTypeTestConfigs( + new TestCaseSupplier.NumericTypeTestConfig( + (Integer.MIN_VALUE >> 1) - 1, + (Integer.MAX_VALUE >> 1) - 1, + (l, r) -> null, + "ModIntsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + (Long.MIN_VALUE >> 1) - 1, + (Long.MAX_VALUE >> 1) - 1, + (l, r) -> null, + "ModLongsEvaluator" + ), + new TestCaseSupplier.NumericTypeTestConfig( + Double.NEGATIVE_INFINITY, + Double.POSITIVE_INFINITY, + (l, r) -> null, + "ModDoublesEvaluator" + ) + ); + List numericTypes = List.of(DataTypes.INTEGER, DataTypes.LONG, DataTypes.DOUBLE); + + for (DataType lhsType : numericTypes) { + for (DataType rhsType : numericTypes) { + DataType expected = TestCaseSupplier.widen(lhsType, rhsType); + TestCaseSupplier.NumericTypeTestConfig expectedTypeStuff = typeStuff.get(expected); + BiFunction evaluatorToString = (lhs, rhs) -> expectedTypeStuff.evaluatorName() + + "[" + + "lhs" + + "=" + + TestCaseSupplier.getCastEvaluator("Attribute[channel=0]", lhs, expected) + + ", " + + "rhs" + + "=" + + TestCaseSupplier.getCastEvaluator("Attribute[channel=1]", rhs, expected) + + "]"; + TestCaseSupplier.casesCrossProduct( + (l1, r1) -> expectedTypeStuff.expected().apply((Number) l1, (Number) r1), + TestCaseSupplier.getSuppliersForNumericType(lhsType, expectedTypeStuff.min(), expectedTypeStuff.max(), true), + TestCaseSupplier.getSuppliersForNumericType(rhsType, 0, 0, true), + evaluatorToString, + List.of( + "Line -1:-1: evaluation of [] failed, treating result as null. Only first 20 failures recorded.", + "Line -1:-1: java.lang.ArithmeticException: / by zero" + ), + suppliers, + expected, + false ); - assertNull(toJavaObject(block, 0)); } } - } - @Override - protected boolean rhsOk(Object o) { - if (o instanceof Number n) { - return n.doubleValue() != 0; - } - return true; - } - - @Override - protected Mod build(Source source, Expression lhs, Expression rhs) { - return new Mod(source, lhs, rhs); - } + suppliers.addAll( + TestCaseSupplier.forBinaryNotCasting( + "ModUnsignedLongsEvaluator", + "lhs", + "rhs", + (l, r) -> null, + DataTypes.UNSIGNED_LONG, + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.valueOf(Long.MAX_VALUE), true), + TestCaseSupplier.ulongCases(BigInteger.ZERO, BigInteger.ZERO, true), + List.of( + "Line -1:-1: evaluation of [] failed, treating result as null. 
Only first 20 failures recorded.", + "Line -1:-1: java.lang.ArithmeticException: / by zero" + ), + false + ) + ); - @Override - protected double expectedValue(double lhs, double rhs) { - return lhs % rhs; + return parameterSuppliersFromTypedData(suppliers); } - @Override - protected int expectedValue(int lhs, int rhs) { - return lhs % rhs; - } + private static String modErrorMessageString(boolean includeOrdinal, List> validPerPosition, List types) { + try { + return typeErrorMessage(includeOrdinal, validPerPosition, types); + } catch (IllegalStateException e) { + // This means all the positional args were okay, so the expected error is from the combination + return "[%] has arguments with incompatible types [" + types.get(0).typeName() + "] and [" + types.get(1).typeName() + "]"; - @Override - protected long expectedValue(long lhs, long rhs) { - return lhs % rhs; + } } @Override - protected long expectedUnsignedLongValue(long lhs, long rhs) { - BigInteger lhsBI = unsignedLongAsBigInteger(lhs); - BigInteger rhsBI = unsignedLongAsBigInteger(rhs); - return asLongUnsigned(lhsBI.mod(rhsBI).longValue()); + protected Expression build(Source source, List args) { + return new Mod(source, args.get(0), args.get(1)); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java index 80deb0ea83d86..f6aeb89faff0e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizerTests.java @@ -657,11 +657,9 @@ public void testOutOfRangeFilterPushdown() { new OutOfRangeTestCase("byte", smallerThanInteger, largerThanInteger), new OutOfRangeTestCase("short", smallerThanInteger, largerThanInteger), new OutOfRangeTestCase("integer", smallerThanInteger, largerThanInteger), - new OutOfRangeTestCase("long", smallerThanLong, largerThanLong), + new OutOfRangeTestCase("long", smallerThanLong, largerThanLong) // TODO: add unsigned_long https://github.com/elastic/elasticsearch/issues/102935 // TODO: add half_float, float https://github.com/elastic/elasticsearch/issues/100130 - new OutOfRangeTestCase("double", "-1.0/0.0", "1.0/0.0"), - new OutOfRangeTestCase("scaled_float", "-1.0/0.0", "1.0/0.0") ); final String LT = "<"; @@ -678,8 +676,7 @@ public void testOutOfRangeFilterPushdown() { GT + testCase.tooLow, GTE + testCase.tooLow, NEQ + testCase.tooHigh, - NEQ + testCase.tooLow, - NEQ + "0.0/0.0" + NEQ + testCase.tooLow ); List alwaysFalsePredicates = List.of( LT + testCase.tooLow, @@ -687,12 +684,7 @@ public void testOutOfRangeFilterPushdown() { GT + testCase.tooHigh, GTE + testCase.tooHigh, EQ + testCase.tooHigh, - EQ + testCase.tooLow, - LT + "0.0/0.0", - LTE + "0.0/0.0", - GT + "0.0/0.0", - GTE + "0.0/0.0", - EQ + "0.0/0.0" + EQ + testCase.tooLow ); for (String truePredicate : trueForSingleValuesPredicates) { @@ -700,6 +692,7 @@ public void testOutOfRangeFilterPushdown() { var query = "from test | where " + comparison; Source expectedSource = new Source(1, 18, comparison); + logger.info("Query: " + query); EsQueryExec actualQueryExec = doTestOutOfRangeFilterPushdown(query, allTypeMappingAnalyzer); assertThat(actualQueryExec.query(), is(instanceOf(SingleValueQuery.Builder.class))); From 5ebb14abbe514183e733101ccb697836d1a79334 Mon Sep 17 00:00:00 2001 From: Rene Groeschke Date: Fri, 12 
Apr 2024 17:25:13 +0200 Subject: [PATCH 12/20] Fix encoding issue on windows for doc snippet handling unit testing (#107427) Fixes #107386 --- .../gradle/internal/doc/DocSnippetTaskSpec.groovy | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/doc/DocSnippetTaskSpec.groovy b/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/doc/DocSnippetTaskSpec.groovy index 85ce3c1804474..96888357d8433 100644 --- a/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/doc/DocSnippetTaskSpec.groovy +++ b/build-tools-internal/src/test/groovy/org/elasticsearch/gradle/internal/doc/DocSnippetTaskSpec.groovy @@ -535,7 +535,7 @@ GET /_analyze ] } ], - "text": "My license plate is ٢٥٠١٥" + "text": "My license plate is empty" } ---- """ @@ -557,7 +557,7 @@ GET /_analyze ] } ], - "text": "My license plate is ٢٥٠١٥" + "text": "My license plate is empty" }""" } From beed68d1b98331f7ea4c7a6c89c0d65a49d9c8ef Mon Sep 17 00:00:00 2001 From: Niels Bauman <33722607+nielsbauman@users.noreply.github.com> Date: Fri, 12 Apr 2024 18:06:03 +0200 Subject: [PATCH 13/20] Rename failure store field and methods (#107399) This name better reflects that the field is a boolean. --- .../action/GetDataStreamsTransportAction.java | 2 +- .../DeleteDataStreamTransportActionTests.java | 2 +- .../action/GetDataStreamsResponseTests.java | 8 ++-- .../admin/indices/get/GetIndexRequest.java | 2 +- .../rollover/MetadataRolloverService.java | 2 +- .../action/bulk/BulkOperation.java | 6 +-- .../action/bulk/BulkRequestModifier.java | 2 +- .../action/bulk/TransportBulkAction.java | 4 +- .../datastreams/GetDataStreamAction.java | 6 +-- .../action/index/IndexRequest.java | 2 +- .../action/support/IndicesOptions.java | 16 +++---- .../metadata/ComposableIndexTemplate.java | 6 +-- .../cluster/metadata/DataStream.java | 48 ++++++++++--------- .../cluster/metadata/DataStreamAction.java | 6 +-- .../metadata/IndexNameExpressionResolver.java | 12 +++-- .../cluster/metadata/Metadata.java | 6 +-- .../MetadataCreateDataStreamService.java | 2 +- .../indices/RestRolloverIndexAction.java | 2 +- .../MetadataRolloverServiceTests.java | 2 +- .../action/bulk/BulkOperationTests.java | 14 +++--- .../action/bulk/TransportBulkActionTests.java | 4 +- .../cluster/metadata/DataStreamTests.java | 2 +- .../metadata/DataStreamTestHelper.java | 10 ++-- 23 files changed, 87 insertions(+), 79 deletions(-) diff --git a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/GetDataStreamsTransportAction.java b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/GetDataStreamsTransportAction.java index 41e62508cafbb..0fc00ad9ebe59 100644 --- a/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/GetDataStreamsTransportAction.java +++ b/modules/data-streams/src/main/java/org/elasticsearch/datastreams/action/GetDataStreamsTransportAction.java @@ -139,7 +139,7 @@ static GetDataStreamAction.Response innerOperation( Map backingIndicesSettingsValues = new HashMap<>(); Metadata metadata = state.getMetadata(); collectIndexSettingsValues(dataStream, backingIndicesSettingsValues, metadata, dataStream.getIndices()); - if (DataStream.isFailureStoreEnabled() && dataStream.getFailureIndices().isEmpty() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && dataStream.getFailureIndices().isEmpty() == false) { collectIndexSettingsValues(dataStream, backingIndicesSettingsValues, metadata, 
dataStream.getFailureIndices()); } diff --git a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/DeleteDataStreamTransportActionTests.java b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/DeleteDataStreamTransportActionTests.java index a5c3b348b1f1b..d394db9523cce 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/DeleteDataStreamTransportActionTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/DeleteDataStreamTransportActionTests.java @@ -57,7 +57,7 @@ public void testDeleteDataStream() { } public void testDeleteDataStreamWithFailureStore() { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); final String dataStreamName = "my-data-stream"; final List otherIndices = randomSubsetOf(List.of("foo", "bar", "baz")); diff --git a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/GetDataStreamsResponseTests.java b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/GetDataStreamsResponseTests.java index 2f268ea705a59..ec6e624794a03 100644 --- a/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/GetDataStreamsResponseTests.java +++ b/modules/data-streams/src/test/java/org/elasticsearch/datastreams/action/GetDataStreamsResponseTests.java @@ -81,7 +81,7 @@ public void testResponseIlmAndDataStreamLifecycleRepresentation() throws Excepti .setAllowCustomRouting(true) .setIndexMode(IndexMode.STANDARD) .setLifecycle(new DataStreamLifecycle()) - .setFailureStore(true) + .setFailureStoreEnabled(true) .setFailureIndices(failureStores) .build(); @@ -158,7 +158,7 @@ public void testResponseIlmAndDataStreamLifecycleRepresentation() throws Excepti is(ManagedBy.LIFECYCLE.displayValue) ); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { List failureStoresRepresentation = (List) dataStreamMap.get( DataStream.FAILURE_INDICES_FIELD.getPreferredName() ); @@ -184,7 +184,7 @@ public void testResponseIlmAndDataStreamLifecycleRepresentation() throws Excepti .setAllowCustomRouting(true) .setIndexMode(IndexMode.STANDARD) .setLifecycle(new DataStreamLifecycle(null, null, false)) - .setFailureStore(true) + .setFailureStoreEnabled(true) .setFailureIndices(failureStores) .build(); @@ -250,7 +250,7 @@ public void testResponseIlmAndDataStreamLifecycleRepresentation() throws Excepti is(ManagedBy.UNMANAGED.displayValue) ); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { List failureStoresRepresentation = (List) dataStreamMap.get( DataStream.FAILURE_INDICES_FIELD.getPreferredName() ); diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/get/GetIndexRequest.java b/server/src/main/java/org/elasticsearch/action/admin/indices/get/GetIndexRequest.java index a550350c20f6b..0b94e89fcc64d 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/get/GetIndexRequest.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/get/GetIndexRequest.java @@ -95,7 +95,7 @@ public static Feature[] fromRequest(RestRequest request) { public GetIndexRequest() { super( - DataStream.isFailureStoreEnabled() + DataStream.isFailureStoreFeatureFlagEnabled() ? 
IndicesOptions.builder(IndicesOptions.strictExpandOpen()) .failureStoreOptions( IndicesOptions.FailureStoreOptions.builder().includeRegularIndices(true).includeFailureIndices(true) diff --git a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverService.java b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverService.java index 75852098170c6..cef0b3797b1d4 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverService.java +++ b/server/src/main/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverService.java @@ -640,7 +640,7 @@ static void validate( ); } var dataStream = (DataStream) indexAbstraction; - if (isFailureStoreRollover && dataStream.isFailureStore() == false) { + if (isFailureStoreRollover && dataStream.isFailureStoreEnabled() == false) { throw new IllegalArgumentException( "unable to roll over failure store because [" + indexAbstraction.getName() + "] does not have the failure store enabled" ); diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java index 412e4f3c875e8..ea4d278227849 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java @@ -306,7 +306,7 @@ private void executeBulkRequestsByShard( } private void redirectFailuresOrCompleteBulkOperation() { - if (DataStream.isFailureStoreEnabled() && failureStoreRedirects.isEmpty() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && failureStoreRedirects.isEmpty() == false) { doRedirectFailures(); } else { completeBulkOperation(); @@ -412,7 +412,7 @@ private void completeShardOperation() { */ private static String getRedirectTarget(DocWriteRequest docWriteRequest, Metadata metadata) { // Feature flag guard - if (DataStream.isFailureStoreEnabled() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() == false) { return null; } // Do not resolve a failure store for documents that were already headed to one @@ -431,7 +431,7 @@ private static String getRedirectTarget(DocWriteRequest docWriteRequest, Meta Index concreteIndex = ia.getWriteIndex(); IndexAbstraction writeIndexAbstraction = metadata.getIndicesLookup().get(concreteIndex.getName()); DataStream parentDataStream = writeIndexAbstraction.getParentDataStream(); - if (parentDataStream != null && parentDataStream.isFailureStore()) { + if (parentDataStream != null && parentDataStream.isFailureStoreEnabled()) { // Keep the data stream name around to resolve the redirect to failure store if the shard level request fails. return parentDataStream.getName(); } diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java index 2112ad48bec62..d0a75bdf109c5 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkRequestModifier.java @@ -215,7 +215,7 @@ synchronized void markItemAsDropped(int slot) { * @param e the failure encountered. */ public void markItemForFailureStore(int slot, String targetIndexName, Exception e) { - if (DataStream.isFailureStoreEnabled() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() == false) { // Assert false for development, but if we somehow find ourselves here, default to failure logic. 
assert false : "Attempting to route a failed write request type to a failure store but the failure store is not enabled! " diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 3494701cf5b7a..13c4009cbc3e2 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -738,7 +738,7 @@ public boolean isForceExecution() { * or if it matches a template that has a data stream failure store enabled. */ static boolean shouldStoreFailure(String indexName, Metadata metadata, long epochMillis) { - return DataStream.isFailureStoreEnabled() + return DataStream.isFailureStoreFeatureFlagEnabled() && resolveFailureStoreFromMetadata(indexName, metadata, epochMillis).or( () -> resolveFailureStoreFromTemplate(indexName, metadata) ).orElse(false); @@ -774,7 +774,7 @@ private static Optional resolveFailureStoreFromMetadata(String indexNam DataStream targetDataStream = writeAbstraction.getParentDataStream(); // We will store the failure if the write target belongs to a data stream with a failure store. - return Optional.of(targetDataStream != null && targetDataStream.isFailureStore()); + return Optional.of(targetDataStream != null && targetDataStream.isFailureStoreEnabled()); } /** diff --git a/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java b/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java index 36f2ff4fffa96..1a2103d665b38 100644 --- a/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java +++ b/server/src/main/java/org/elasticsearch/action/datastreams/GetDataStreamAction.java @@ -319,7 +319,7 @@ public XContentBuilder toXContent( builder.endArray(); } builder.field(DataStream.GENERATION_FIELD.getPreferredName(), dataStream.getGeneration()); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { builder.field(DataStream.FAILURE_INDICES_FIELD.getPreferredName()); builder.startArray(); for (Index failureStore : dataStream.getFailureIndices()) { @@ -358,8 +358,8 @@ public XContentBuilder toXContent( builder.field(ALLOW_CUSTOM_ROUTING.getPreferredName(), dataStream.isAllowCustomRouting()); builder.field(REPLICATED.getPreferredName(), dataStream.isReplicated()); builder.field(ROLLOVER_ON_WRITE.getPreferredName(), dataStream.rolloverOnWrite()); - if (DataStream.isFailureStoreEnabled()) { - builder.field(DataStream.FAILURE_STORE_FIELD.getPreferredName(), dataStream.isFailureStore()); + if (DataStream.isFailureStoreFeatureFlagEnabled()) { + builder.field(DataStream.FAILURE_STORE_FIELD.getPreferredName(), dataStream.isFailureStoreEnabled()); } if (dataStream.getAutoShardingEvent() != null) { DataStreamAutoShardingEvent autoShardingEvent = dataStream.getAutoShardingEvent(); diff --git a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java index d142db2d5a1ab..9d0eeb20dacef 100644 --- a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java +++ b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java @@ -858,7 +858,7 @@ public IndexRequest setRequireDataStream(boolean requireDataStream) { @Override public Index getConcreteWriteIndex(IndexAbstraction ia, Metadata metadata) { - if (DataStream.isFailureStoreEnabled() && 
writeToFailureStore) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && writeToFailureStore) { if (ia.isDataStreamRelated() == false) { throw new ElasticsearchException( "Attempting to write a document to a failure store but the targeted index is not a data stream" diff --git a/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java b/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java index e46a7bd5f0ec2..1070a5d0bddd0 100644 --- a/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java +++ b/server/src/main/java/org/elasticsearch/action/support/IndicesOptions.java @@ -1109,7 +1109,7 @@ public static IndicesOptions fromRequest(RestRequest request, IndicesOptions def request.param(ConcreteTargetOptions.IGNORE_UNAVAILABLE), request.param(WildcardOptions.ALLOW_NO_INDICES), request.param(GatekeeperOptions.IGNORE_THROTTLED), - DataStream.isFailureStoreEnabled() + DataStream.isFailureStoreFeatureFlagEnabled() ? request.param(FailureStoreOptions.FAILURE_STORE) : FailureStoreOptions.INCLUDE_ONLY_REGULAR_INDICES, defaultSettings @@ -1117,7 +1117,7 @@ public static IndicesOptions fromRequest(RestRequest request, IndicesOptions def } public static IndicesOptions fromMap(Map map, IndicesOptions defaultSettings) { - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { return fromParameters( map.containsKey(WildcardOptions.EXPAND_WILDCARDS) ? map.get(WildcardOptions.EXPAND_WILDCARDS) : map.get("expandWildcards"), map.containsKey(ConcreteTargetOptions.IGNORE_UNAVAILABLE) @@ -1155,8 +1155,8 @@ public static boolean isIndicesOptions(String name) { || "ignoreThrottled".equals(name) || WildcardOptions.ALLOW_NO_INDICES.equals(name) || "allowNoIndices".equals(name) - || (DataStream.isFailureStoreEnabled() && FailureStoreOptions.FAILURE_STORE.equals(name)) - || (DataStream.isFailureStoreEnabled() && "failureStore".equals(name)); + || (DataStream.isFailureStoreFeatureFlagEnabled() && FailureStoreOptions.FAILURE_STORE.equals(name)) + || (DataStream.isFailureStoreFeatureFlagEnabled() && "failureStore".equals(name)); } public static IndicesOptions fromParameters( @@ -1187,7 +1187,7 @@ public static IndicesOptions fromParameters( WildcardOptions wildcards = WildcardOptions.parseParameters(wildcardsString, allowNoIndicesString, defaultSettings.wildcardOptions); GatekeeperOptions gatekeeperOptions = GatekeeperOptions.parseParameter(ignoreThrottled, defaultSettings.gatekeeperOptions); - FailureStoreOptions failureStoreOptions = DataStream.isFailureStoreEnabled() + FailureStoreOptions failureStoreOptions = DataStream.isFailureStoreFeatureFlagEnabled() ? 
FailureStoreOptions.parseParameters(failureStoreString, defaultSettings.failureStoreOptions) : FailureStoreOptions.DEFAULT; @@ -1205,7 +1205,7 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par concreteTargetOptions.toXContent(builder, params); wildcardOptions.toXContent(builder, params); gatekeeperOptions.toXContent(builder, params); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { failureStoreOptions.toXContent(builder, params); } return builder; @@ -1276,7 +1276,7 @@ public static IndicesOptions fromXContent(XContentParser parser, @Nullable Indic allowNoIndices = parser.booleanValue(); } else if (IGNORE_THROTTLED_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { generalOptions.ignoreThrottled(parser.booleanValue()); - } else if (DataStream.isFailureStoreEnabled() + } else if (DataStream.isFailureStoreFeatureFlagEnabled() && FAILURE_STORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { failureStoreOptions = FailureStoreOptions.parseParameters(parser.text(), failureStoreOptions); } else { @@ -1423,7 +1423,7 @@ public String toString() { + ignoreAliases() + ", ignore_throttled=" + ignoreThrottled() - + (DataStream.isFailureStoreEnabled() + + (DataStream.isFailureStoreFeatureFlagEnabled() ? ", include_regular_indices=" + includeRegularIndices() + ", include_failure_indices=" diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplate.java b/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplate.java index 8e8e6fff4cc6a..e6e48bfbd46b3 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplate.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/ComposableIndexTemplate.java @@ -376,14 +376,14 @@ public static class DataStreamTemplate implements Writeable, ToXContentObject { args -> new DataStreamTemplate( args[0] != null && (boolean) args[0], args[1] != null && (boolean) args[1], - DataStream.isFailureStoreEnabled() && args[2] != null && (boolean) args[2] + DataStream.isFailureStoreFeatureFlagEnabled() && args[2] != null && (boolean) args[2] ) ); static { PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), HIDDEN); PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), ALLOW_CUSTOM_ROUTING); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), FAILURE_STORE); } } @@ -478,7 +478,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(); builder.field("hidden", hidden); builder.field(ALLOW_CUSTOM_ROUTING.getPreferredName(), allowCustomRouting); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { builder.field(FAILURE_STORE.getPreferredName(), failureStore); } builder.endObject(); diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java index c9d37a9efebc8..33dab20a81494 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStream.java @@ -73,7 +73,7 @@ public final class DataStream implements SimpleDiffable, ToXContentO public static final TransportVersion ADDED_FAILURE_STORE_TRANSPORT_VERSION = TransportVersions.V_8_12_0; public 
static final TransportVersion ADDED_AUTO_SHARDING_EVENT_VERSION = TransportVersions.DATA_STREAM_AUTO_SHARDING_EVENT; - public static boolean isFailureStoreEnabled() { + public static boolean isFailureStoreFeatureFlagEnabled() { return FAILURE_STORE_FEATURE_FLAG.isEnabled(); } @@ -115,7 +115,7 @@ public static boolean isFailureStoreEnabled() { @Nullable private final DataStreamLifecycle lifecycle; private final boolean rolloverOnWrite; - private final boolean failureStore; + private final boolean failureStoreEnabled; private final List failureIndices; private volatile Set failureStoreLookup; @Nullable @@ -132,7 +132,7 @@ public DataStream( boolean allowCustomRouting, IndexMode indexMode, DataStreamLifecycle lifecycle, - boolean failureStore, + boolean failureStoreEnabled, List failureIndices, boolean rolloverOnWrite, @Nullable DataStreamAutoShardingEvent autoShardingEvent @@ -149,7 +149,7 @@ public DataStream( allowCustomRouting, indexMode, lifecycle, - failureStore, + failureStoreEnabled, failureIndices, rolloverOnWrite, autoShardingEvent @@ -169,7 +169,7 @@ public DataStream( boolean allowCustomRouting, IndexMode indexMode, DataStreamLifecycle lifecycle, - boolean failureStore, + boolean failureStoreEnabled, List failureIndices, boolean rolloverOnWrite, @Nullable DataStreamAutoShardingEvent autoShardingEvent @@ -187,7 +187,7 @@ public DataStream( this.allowCustomRouting = allowCustomRouting; this.indexMode = indexMode; this.lifecycle = lifecycle; - this.failureStore = failureStore; + this.failureStoreEnabled = failureStoreEnabled; this.failureIndices = failureIndices; assert assertConsistent(this.indices); assert replicated == false || rolloverOnWrite == false : "replicated data streams cannot be marked for lazy rollover"; @@ -243,7 +243,7 @@ public Index getWriteIndex() { */ @Nullable public Index getFailureStoreWriteIndex() { - return isFailureStore() == false || failureIndices.isEmpty() ? null : failureIndices.get(failureIndices.size() - 1); + return isFailureStoreEnabled() == false || failureIndices.isEmpty() ? null : failureIndices.get(failureIndices.size() - 1); } /** @@ -389,8 +389,8 @@ public boolean isAllowCustomRouting() { * * @return Whether this data stream should store ingestion failures. */ - public boolean isFailureStore() { - return failureStore; + public boolean isFailureStoreEnabled() { + return failureStoreEnabled; } @Nullable @@ -1017,7 +1017,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalWriteable(lifecycle); } if (out.getTransportVersion().onOrAfter(DataStream.ADDED_FAILURE_STORE_TRANSPORT_VERSION)) { - out.writeBoolean(failureStore); + out.writeBoolean(failureStoreEnabled); out.writeCollection(failureIndices); } if (out.getTransportVersion().onOrAfter(TransportVersions.LAZY_ROLLOVER_ADDED)) { @@ -1048,8 +1048,10 @@ public void writeTo(StreamOutput out) throws IOException { private static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("data_stream", args -> { // Fields behind a feature flag need to be parsed last otherwise the parser will fail when the feature flag is disabled. // Until the feature flag is removed we keep them separately to be mindful of this. - boolean failureStoreEnabled = DataStream.isFailureStoreEnabled() && args[12] != null && (boolean) args[12]; - List failureStoreIndices = DataStream.isFailureStoreEnabled() && args[13] != null ? 
(List) args[13] : List.of(); + boolean failureStoreEnabled = DataStream.isFailureStoreFeatureFlagEnabled() && args[12] != null && (boolean) args[12]; + List failureStoreIndices = DataStream.isFailureStoreFeatureFlagEnabled() && args[13] != null + ? (List) args[13] + : List.of(); return new DataStream( (String) args[0], (List) args[1], @@ -1094,7 +1096,7 @@ public void writeTo(StreamOutput out) throws IOException { AUTO_SHARDING_FIELD ); // The fields behind the feature flag should always be last. - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), FAILURE_STORE_FIELD); PARSER.declareObjectArray( ConstructingObjectParser.optionalConstructorArg(), @@ -1130,7 +1132,7 @@ public XContentBuilder toXContent( .endObject(); builder.xContentList(INDICES_FIELD.getPreferredName(), indices); builder.field(GENERATION_FIELD.getPreferredName(), generation); - if (DataStream.isFailureStoreEnabled() && failureIndices.isEmpty() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && failureIndices.isEmpty() == false) { builder.xContentList(FAILURE_INDICES_FIELD.getPreferredName(), failureIndices); } if (metadata != null) { @@ -1140,8 +1142,8 @@ public XContentBuilder toXContent( builder.field(REPLICATED_FIELD.getPreferredName(), replicated); builder.field(SYSTEM_FIELD.getPreferredName(), system); builder.field(ALLOW_CUSTOM_ROUTING.getPreferredName(), allowCustomRouting); - if (DataStream.isFailureStoreEnabled()) { - builder.field(FAILURE_STORE_FIELD.getPreferredName(), failureStore); + if (DataStream.isFailureStoreFeatureFlagEnabled()) { + builder.field(FAILURE_STORE_FIELD.getPreferredName(), failureStoreEnabled); } if (indexMode != null) { builder.field(INDEX_MODE.getPreferredName(), indexMode); @@ -1175,7 +1177,7 @@ public boolean equals(Object o) { && allowCustomRouting == that.allowCustomRouting && indexMode == that.indexMode && Objects.equals(lifecycle, that.lifecycle) - && failureStore == that.failureStore + && failureStoreEnabled == that.failureStoreEnabled && failureIndices.equals(that.failureIndices) && rolloverOnWrite == that.rolloverOnWrite && Objects.equals(autoShardingEvent, that.autoShardingEvent); @@ -1194,7 +1196,7 @@ public int hashCode() { allowCustomRouting, indexMode, lifecycle, - failureStore, + failureStoreEnabled, failureIndices, rolloverOnWrite, autoShardingEvent @@ -1361,7 +1363,7 @@ public static class Builder { @Nullable private DataStreamLifecycle lifecycle = null; private boolean rolloverOnWrite = false; - private boolean failureStore = false; + private boolean failureStoreEnabled = false; private List failureIndices = List.of(); @Nullable private DataStreamAutoShardingEvent autoShardingEvent = null; @@ -1385,7 +1387,7 @@ public Builder(DataStream dataStream) { indexMode = dataStream.indexMode; lifecycle = dataStream.lifecycle; rolloverOnWrite = dataStream.rolloverOnWrite; - failureStore = dataStream.failureStore; + failureStoreEnabled = dataStream.failureStoreEnabled; failureIndices = dataStream.failureIndices; autoShardingEvent = dataStream.autoShardingEvent; } @@ -1451,8 +1453,8 @@ public Builder setRolloverOnWrite(boolean rolloverOnWrite) { return this; } - public Builder setFailureStore(boolean failureStore) { - this.failureStore = failureStore; + public Builder setFailureStoreEnabled(boolean failureStoreEnabled) { + this.failureStoreEnabled = failureStoreEnabled; return this; } @@ -1479,7 +1481,7 @@ public DataStream build() { 
allowCustomRouting, indexMode, lifecycle, - failureStore, + failureStoreEnabled, failureIndices, rolloverOnWrite, autoShardingEvent diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamAction.java b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamAction.java index 0148315c322be..f260b48cd7b7a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamAction.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/DataStreamAction.java @@ -142,7 +142,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.startObject(type.fieldName); builder.field(DATA_STREAM.getPreferredName(), dataStream); builder.field(INDEX.getPreferredName(), index); - if (DataStream.isFailureStoreEnabled() && failureStore) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && failureStore) { builder.field(FAILURE_STORE.getPreferredName(), failureStore); } builder.endObject(); @@ -180,7 +180,7 @@ public static DataStreamAction fromXContent(XContentParser parser) throws IOExce ObjectParser.ValueType.STRING ); ADD_BACKING_INDEX_PARSER.declareField(DataStreamAction::setIndex, XContentParser::text, INDEX, ObjectParser.ValueType.STRING); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { ADD_BACKING_INDEX_PARSER.declareField( DataStreamAction::setFailureStore, XContentParser::booleanValue, @@ -195,7 +195,7 @@ public static DataStreamAction fromXContent(XContentParser parser) throws IOExce ObjectParser.ValueType.STRING ); REMOVE_BACKING_INDEX_PARSER.declareField(DataStreamAction::setIndex, XContentParser::text, INDEX, ObjectParser.ValueType.STRING); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { REMOVE_BACKING_INDEX_PARSER.declareField( DataStreamAction::setFailureStore, XContentParser::booleanValue, diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolver.java b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolver.java index b88292d4ed79b..effc89d8e535a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolver.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/IndexNameExpressionResolver.java @@ -387,7 +387,7 @@ Index[] concreteIndices(Context context, String... 
indexExpressions) { resolveIndicesForDataStream(context, (DataStream) indexAbstraction, concreteIndicesResult); } else if (indexAbstraction.getType() == Type.ALIAS && indexAbstraction.isDataStreamRelated() - && DataStream.isFailureStoreEnabled() + && DataStream.isFailureStoreFeatureFlagEnabled() && context.getOptions().includeFailureIndices()) { // Collect the data streams involved Set aliasDataStreams = new HashSet<>(); @@ -453,11 +453,13 @@ private static void resolveWriteIndexForDataStreams(Context context, DataStream } private static boolean shouldIncludeRegularIndices(IndicesOptions indicesOptions) { - return DataStream.isFailureStoreEnabled() == false || indicesOptions.includeRegularIndices(); + return DataStream.isFailureStoreFeatureFlagEnabled() == false || indicesOptions.includeRegularIndices(); } private static boolean shouldIncludeFailureIndices(IndicesOptions indicesOptions, DataStream dataStream) { - return DataStream.isFailureStoreEnabled() && indicesOptions.includeFailureIndices() && dataStream.isFailureStore(); + return DataStream.isFailureStoreFeatureFlagEnabled() + && indicesOptions.includeFailureIndices() + && dataStream.isFailureStoreEnabled(); } private static boolean resolvesToMoreThanOneIndex(IndexAbstraction indexAbstraction, Context context) { @@ -566,11 +568,11 @@ private static boolean shouldTrackConcreteIndex(Context context, IndicesOptions // Exclude this one as it's a net-new system index, and we explicitly don't want those. return false; } - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { IndexAbstraction indexAbstraction = context.getState().metadata().getIndicesLookup().get(index.getName()); if (context.options.allowFailureIndices() == false) { DataStream parentDataStream = indexAbstraction.getParentDataStream(); - if (parentDataStream != null && parentDataStream.isFailureStore()) { + if (parentDataStream != null && parentDataStream.isFailureStoreEnabled()) { if (parentDataStream.isFailureStoreIndex(index.getName())) { if (options.ignoreUnavailable()) { return false; diff --git a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java index f424861c5b7ff..fec209960597b 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/Metadata.java @@ -2598,8 +2598,8 @@ private static void collectIndices( private static boolean assertContainsIndexIfDataStream(DataStream parent, IndexMetadata indexMetadata) { assert parent == null || parent.getIndices().stream().anyMatch(index -> indexMetadata.getIndex().getName().equals(index.getName())) - || (DataStream.isFailureStoreEnabled() - && parent.isFailureStore() + || (DataStream.isFailureStoreFeatureFlagEnabled() + && parent.isFailureStoreEnabled() && parent.getFailureIndices().stream().anyMatch(index -> indexMetadata.getIndex().getName().equals(index.getName()))) : "Expected data stream [" + parent.getName() + "] to contain index " + indexMetadata.getIndex(); return true; @@ -2622,7 +2622,7 @@ private static void collectDataStreams( for (Index i : dataStream.getIndices()) { indexToDataStreamLookup.put(i.getName(), dataStream); } - if (DataStream.isFailureStoreEnabled() && dataStream.isFailureStore()) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && dataStream.isFailureStoreEnabled()) { for (Index i : dataStream.getFailureIndices()) { indexToDataStreamLookup.put(i.getName(), dataStream); } diff 
--git a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateDataStreamService.java b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateDataStreamService.java index 3c3ff0d130f0a..2d1d38ac926d6 100644 --- a/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateDataStreamService.java +++ b/server/src/main/java/org/elasticsearch/cluster/metadata/MetadataCreateDataStreamService.java @@ -418,7 +418,7 @@ public static ClusterState createFailureStoreIndex( String failureStoreIndexName, @Nullable BiConsumer metadataTransformer ) throws Exception { - if (DataStream.isFailureStoreEnabled() == false) { + if (DataStream.isFailureStoreFeatureFlagEnabled() == false) { return currentState; } diff --git a/server/src/main/java/org/elasticsearch/rest/action/admin/indices/RestRolloverIndexAction.java b/server/src/main/java/org/elasticsearch/rest/action/admin/indices/RestRolloverIndexAction.java index 98895a49fae6e..1718d9af7e5c8 100644 --- a/server/src/main/java/org/elasticsearch/rest/action/admin/indices/RestRolloverIndexAction.java +++ b/server/src/main/java/org/elasticsearch/rest/action/admin/indices/RestRolloverIndexAction.java @@ -53,7 +53,7 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC rolloverIndexRequest.lazy(request.paramAsBoolean("lazy", false)); rolloverIndexRequest.timeout(request.paramAsTime("timeout", rolloverIndexRequest.timeout())); rolloverIndexRequest.masterNodeTimeout(request.paramAsTime("master_timeout", rolloverIndexRequest.masterNodeTimeout())); - if (DataStream.isFailureStoreEnabled()) { + if (DataStream.isFailureStoreFeatureFlagEnabled()) { boolean failureStore = request.paramAsBoolean("target_failure_store", false); if (failureStore) { rolloverIndexRequest.setIndicesOptions( diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java index d386eb40aea43..0bf92df006894 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/indices/rollover/MetadataRolloverServiceTests.java @@ -744,7 +744,7 @@ public void testValidation() throws Exception { // ensure no replicate data stream .promoteDataStream(); rolloverTarget = dataStream.getName(); - if (dataStream.isFailureStore() && randomBoolean()) { + if (dataStream.isFailureStoreEnabled() && randomBoolean()) { failureStoreOptions = new FailureStoreOptions(false, true); sourceIndexName = dataStream.getFailureStoreWriteIndex().getName(); defaultRolloverIndexName = DataStream.getDefaultFailureStoreName( diff --git a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java index 23395556761f1..b662f439a0e6f 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/BulkOperationTests.java @@ -363,7 +363,7 @@ public void testBulkToDataStreamFailingEntireShard() throws Exception { * A bulk operation to a data stream with a failure store enabled should redirect any shard level failures to the failure store. 
*/ public void testFailingEntireShardRedirectsToFailureStore() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -393,7 +393,7 @@ public void testFailingEntireShardRedirectsToFailureStore() throws Exception { * failure store. */ public void testFailingDocumentRedirectsToFailureStore() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -423,7 +423,7 @@ public void testFailingDocumentRedirectsToFailureStore() throws Exception { * a shard-level failure while writing to the failure store indices. */ public void testFailureStoreShardFailureRejectsDocument() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -467,7 +467,7 @@ public void testFailureStoreShardFailureRejectsDocument() throws Exception { * instead will simply report its original failure in the response, with the conversion failure present as a suppressed exception. */ public void testFailedDocumentCanNotBeConvertedFails() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -505,7 +505,7 @@ public void testFailedDocumentCanNotBeConvertedFails() throws Exception { * non-retryable block when the redirected documents would be sent to the shard-level action. */ public void testBlockedClusterRejectsFailureStoreDocument() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -560,7 +560,7 @@ public void testBlockedClusterRejectsFailureStoreDocument() throws Exception { * retryable block to clear when the redirected documents would be sent to the shard-level action. */ public void testOperationTimeoutRejectsFailureStoreDocument() throws Exception { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); @@ -623,7 +623,7 @@ public void testOperationTimeoutRejectsFailureStoreDocument() throws Exception { * for a retryable block to clear when the redirected documents would be sent to the shard-level action. 
*/ public void testNodeClosureRejectsFailureStoreDocument() { - Assume.assumeTrue(DataStream.isFailureStoreEnabled()); + Assume.assumeTrue(DataStream.isFailureStoreFeatureFlagEnabled()); // Requests that go to two separate shards BulkRequest bulkRequest = new BulkRequest(); diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index 960397033f602..c27263f43eff1 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -366,7 +366,7 @@ public void testRejectionAfterCreateIndexIsPropagated() { } public void testResolveFailureStoreFromMetadata() throws Exception { - assumeThat(DataStream.isFailureStoreEnabled(), is(true)); + assumeThat(DataStream.isFailureStoreFeatureFlagEnabled(), is(true)); String dataStreamWithFailureStore = "test-data-stream-failure-enabled"; String dataStreamWithoutFailureStore = "test-data-stream-failure-disabled"; @@ -425,7 +425,7 @@ public void testResolveFailureStoreFromMetadata() throws Exception { } public void testResolveFailureStoreFromTemplate() throws Exception { - assumeThat(DataStream.isFailureStoreEnabled(), is(true)); + assumeThat(DataStream.isFailureStoreFeatureFlagEnabled(), is(true)); String dsTemplateWithFailureStore = "test-data-stream-failure-enabled"; String dsTemplateWithoutFailureStore = "test-data-stream-failure-disabled"; diff --git a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java index 6ce9caee48d31..14c38a13f3730 100644 --- a/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/metadata/DataStreamTests.java @@ -93,7 +93,7 @@ protected DataStream mutateInstance(DataStream instance) { var allowsCustomRouting = instance.isAllowCustomRouting(); var indexMode = instance.getIndexMode(); var lifecycle = instance.getLifecycle(); - var failureStore = instance.isFailureStore(); + var failureStore = instance.isFailureStoreEnabled(); var failureIndices = instance.getFailureIndices(); var rolloverOnWrite = instance.rolloverOnWrite(); var autoShardingEvent = instance.getAutoShardingEvent(); diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java index 91e057bb58e71..6c038470b158d 100644 --- a/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java +++ b/test/framework/src/main/java/org/elasticsearch/cluster/metadata/DataStreamTestHelper.java @@ -153,7 +153,7 @@ public static DataStream newInstance( .setMetadata(metadata) .setReplicated(replicated) .setLifecycle(lifecycle) - .setFailureStore(failureStores.isEmpty() == false) + .setFailureStoreEnabled(failureStores.isEmpty() == false) .setFailureIndices(failureStores) .build(); } @@ -460,7 +460,11 @@ public static void getClusterStateWithDataStreams( ComposableIndexTemplate.builder() .indexPatterns(List.of("*")) .dataStreamTemplate( - new ComposableIndexTemplate.DataStreamTemplate(false, false, DataStream.isFailureStoreEnabled() && storeFailures) + new ComposableIndexTemplate.DataStreamTemplate( + false, + false, + DataStream.isFailureStoreFeatureFlagEnabled() && storeFailures + ) ) .build() ); @@ -476,7 +480,7 @@ 
public static void getClusterStateWithDataStreams( allIndices.addAll(backingIndices); List failureStores = new ArrayList<>(); - if (DataStream.isFailureStoreEnabled() && storeFailures) { + if (DataStream.isFailureStoreFeatureFlagEnabled() && storeFailures) { for (int failureStoreNumber = 1; failureStoreNumber <= dsTuple.v2(); failureStoreNumber++) { failureStores.add( createIndexMetadata( From 65bee01cfe36a289d0b4534c230788c6b10eb29e Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 12 Apr 2024 12:21:03 -0400 Subject: [PATCH 14/20] Loosen error for nearest score in VectorIT#testQuantizedVectorSearch (#107382) Co-authored-by: Elastic Machine --- .../java/org/elasticsearch/upgrades/VectorSearchIT.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/VectorSearchIT.java b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/VectorSearchIT.java index 34b2f5d723949..e78e0978b1d80 100644 --- a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/VectorSearchIT.java +++ b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/VectorSearchIT.java @@ -285,7 +285,6 @@ public void testByteVectorSearch() throws Exception { assertThat((double) hits.get(0).get("_score"), closeTo(0.028571429, 0.0001)); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/107332") public void testQuantizedVectorSearch() throws Exception { assumeTrue( "Quantized vector search is not supported on this version", @@ -357,7 +356,7 @@ public void testQuantizedVectorSearch() throws Exception { assertThat(extractValue(response, "hits.total.value"), equalTo(2)); hits = extractValue(response, "hits.hits"); assertThat(hits.get(0).get("_id"), equalTo("0")); - assertThat((double) hits.get(0).get("_score"), closeTo(0.9934857, 0.0001)); + assertThat((double) hits.get(0).get("_score"), closeTo(0.9934857, 0.005)); } private void indexVectors(String indexName) throws Exception { From 6f17d03e10a2aedd5435a520d510ee097b05fb56 Mon Sep 17 00:00:00 2001 From: Luigi Dell'Aquila Date: Fri, 12 Apr 2024 18:42:40 +0200 Subject: [PATCH 15/20] ESQL: Reduce size of HeapAttackIT.testSortByManyLongsSuccess (#107400) Reducing the number of fields to avoid failures on CI environments where we have much less heap (<300M for the circuit breaker). 
--- .../org/elasticsearch/xpack/esql/heap_attack/Clusters.java | 3 ++- .../org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java index fbc191a12d8b0..72e08c340ea0c 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/Clusters.java @@ -19,7 +19,8 @@ static ElasticsearchCluster buildCluster() { .nodes(2) .module("test-esql-heap-attack") .setting("xpack.security.enabled", "false") - .setting("xpack.license.self_generated.type", "trial"); + .setting("xpack.license.self_generated.type", "trial") + .jvmArg("-Xmx512m"); String javaVersion = JvmInfo.jvmInfo().version(); if (javaVersion.equals("20") || javaVersion.equals("21")) { // see https://github.com/elastic/elasticsearch/issues/99592 diff --git a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java index 2f3826f8423b8..4f43817b7b92c 100644 --- a/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java +++ b/test/external-modules/esql-heap-attack/src/javaRestTest/java/org/elasticsearch/xpack/esql/heap_attack/HeapAttackIT.java @@ -77,7 +77,7 @@ public void skipOnAborted() { */ public void testSortByManyLongsSuccess() throws IOException { initManyLongs(); - Response response = sortByManyLongs(2000); + Response response = sortByManyLongs(500); Map map = responseAsMap(response); ListMatcher columns = matchesList().item(matchesMap().entry("name", "a").entry("type", "long")) .item(matchesMap().entry("name", "b").entry("type", "long")); From 2bc4966f76f76540f45a1f543a57b7c622f469ca Mon Sep 17 00:00:00 2001 From: Oleksandr Kolomiiets Date: Fri, 12 Apr 2024 10:17:43 -0700 Subject: [PATCH 16/20] Allow decoding IP ranges from doc values (#107376) --- .../main/java/org/elasticsearch/index/mapper/RangeType.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/mapper/RangeType.java b/server/src/main/java/org/elasticsearch/index/mapper/RangeType.java index f8100e794dbd9..f339269d93636 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/RangeType.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/RangeType.java @@ -107,9 +107,8 @@ public BytesRef encodeRanges(Set ranges) throws IOExcept } @Override - public List decodeRanges(BytesRef bytes) { - // TODO: Implement this. - throw new UnsupportedOperationException(); + public List decodeRanges(BytesRef bytes) throws IOException { + return BinaryRangeUtil.decodeIPRanges(bytes); } @Override From 965ebab631e7af42d6bbba583306a6a50a2878cf Mon Sep 17 00:00:00 2001 From: Mayya Sharipova Date: Fri, 12 Apr 2024 13:44:50 -0400 Subject: [PATCH 17/20] Percolator named queries: rewrite for matched info (#107432) PR #103084 introduced an ability to return matched_queries during percolate process for all percolator queries containing `_name` field. 
But there was a bug with complex queries, as they were not rewritten before obraining their Weight function. This fixes the bug by ensuring all queries are first rewritten. Closes #107176 --- docs/changelog/107432.yaml | 6 ++ .../PercolatorMatchedSlotSubFetchPhase.java | 3 +- .../test/30_matched_complex_queries.yml | 86 +++++++++++++++++++ 3 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 docs/changelog/107432.yaml create mode 100644 modules/percolator/src/yamlRestTest/resources/rest-api-spec/test/30_matched_complex_queries.yml diff --git a/docs/changelog/107432.yaml b/docs/changelog/107432.yaml new file mode 100644 index 0000000000000..c492644c5baf2 --- /dev/null +++ b/docs/changelog/107432.yaml @@ -0,0 +1,6 @@ +pr: 107432 +summary: "Percolator named queries: rewrite for matched info" +area: Percolator +type: bug +issues: + - 107176 diff --git a/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorMatchedSlotSubFetchPhase.java b/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorMatchedSlotSubFetchPhase.java index 83703dcf10971..fe4bfc7741c87 100644 --- a/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorMatchedSlotSubFetchPhase.java +++ b/modules/percolator/src/main/java/org/elasticsearch/percolator/PercolatorMatchedSlotSubFetchPhase.java @@ -85,8 +85,9 @@ public void process(HitContext hitContext) throws IOException { // This is not a document with a percolator field. continue; } - query = pc.filterNestedDocs(query, fetchContext.getSearchExecutionContext().indexVersionCreated()); IndexSearcher percolatorIndexSearcher = pc.percolateQuery.getPercolatorIndexSearcher(); + query = pc.filterNestedDocs(query, fetchContext.getSearchExecutionContext().indexVersionCreated()); + query = percolatorIndexSearcher.rewrite(query); int memoryIndexMaxDoc = percolatorIndexSearcher.getIndexReader().maxDoc(); TopDocs topDocs = percolatorIndexSearcher.search(query, memoryIndexMaxDoc, new Sort(SortField.FIELD_DOC)); if (topDocs.totalHits.value == 0) { diff --git a/modules/percolator/src/yamlRestTest/resources/rest-api-spec/test/30_matched_complex_queries.yml b/modules/percolator/src/yamlRestTest/resources/rest-api-spec/test/30_matched_complex_queries.yml new file mode 100644 index 0000000000000..eb0c020a8199e --- /dev/null +++ b/modules/percolator/src/yamlRestTest/resources/rest-api-spec/test/30_matched_complex_queries.yml @@ -0,0 +1,86 @@ +setup: + - requires: + cluster_features: ["gte_v8.14.0"] + reason: "Displaying matched complex named queries within percolator queries was fixed in 8.14" + - do: + indices.create: + index: houses + body: + mappings: + dynamic: strict + properties: + my_query: + type: percolator + description: + type: text + num_of_bedrooms: + type: integer + type: + type: keyword + price: + type: integer + + - do: + index: + refresh: true + index: houses + id: query_cheap_houses_with_swimming_pool + body: + my_query: + { + "bool": { + "should": [ + { "range": { "price": { "lte": 399999, "_name": "cheap_query" } } }, + { "wildcard": { "description": { "value": "swim*", "_name": "swimming_pool_query" } } } + ] + } + } + + - do: + index: + refresh: true + index: houses + id: query_big_houses_with_fireplace + body: + my_query: + { + "bool": { + "should": [ + { "range": { "num_of_bedrooms": { "gte": 3, "_name": "big_house_query" } } }, + { "query_string": { "query": "fire*", "fields" : ["description"], "_name": "fireplace_query" } } + ] + } + } + +--- +"Matched named queries within percolator queries: percolate 
existing document": + - do: + index: + refresh: true + index: houses + id: house1 + body: + description: "house with a beautiful fireplace and swimming pool" + num_of_bedrooms: 3 + type: detached + price: 1000000 + + - do: + search: + index: houses + body: + query: + percolate: + field: my_query + index: houses + id: house1 + + - match: { hits.total.value: 2 } + + - match: { hits.hits.0._id: query_big_houses_with_fireplace } + - match: { hits.hits.0.fields._percolator_document_slot: [ 0 ] } + - match: { hits.hits.0.fields._percolator_document_slot_0_matched_queries: [ "big_house_query", "fireplace_query" ] } + + - match: { hits.hits.1._id: query_cheap_houses_with_swimming_pool } + - match: { hits.hits.1.fields._percolator_document_slot: [ 0 ] } + - match: { hits.hits.1.fields._percolator_document_slot_0_matched_queries: [ "swimming_pool_query" ] } From aedc07da4e17ed23697ed7f881b26a8042680a11 Mon Sep 17 00:00:00 2001 From: "Mark J. Hoy" Date: Fri, 12 Apr 2024 15:58:03 -0400 Subject: [PATCH 18/20] Add Azure OpenAI Embeddings Inference Service (#107178) * initial start to Azure OpenAI Embeddings * some cleanups; adding more tests; breaking * cleanups; all test so far passing; * cleanups; checkstyle; finish tests * checkstyle cleanups; spotless apply * remove String.format usage * smoke tested and working; some cleanups * cleanup unneeded comments * cleanup wayward comment * finalize tests; set model as URI holder * fixups after rebase; notably add timeout param * cleanups; remove AzureResponse in favour of OpenAI * ensure dimensions_set_by_user cannot be in request * move AzureOpenAiSecretSettings to azureopenai pkg * fix lint * add similarity for service settings; cleanups; * allow request similarity;correct secret validation --- .../org/elasticsearch/TransportVersions.java | 1 + .../InferenceNamedWriteablesProvider.java | 27 + .../xpack/inference/InferencePlugin.java | 2 + .../azureopenai/AzureOpenAiActionCreator.java | 35 + .../azureopenai/AzureOpenAiActionVisitor.java | 17 + .../AzureOpenAiEmbeddingsAction.java | 53 + .../azureopenai/AzureOpenAiAccount.java | 40 + .../AzureOpenAiResponseHandler.java | 52 + ...nAiEmbeddingsExecutableRequestCreator.java | 63 + .../OpenAiChatCompletionResponseHandler.java | 2 +- .../openai/OpenAiResponseHandler.java | 2 +- .../AzureOpenAiEmbeddingsRequest.java | 110 ++ .../AzureOpenAiEmbeddingsRequestEntity.java | 49 + .../azureopenai/AzureOpenAiRequest.java | 12 + .../request/azureopenai/AzureOpenAiUtils.java | 20 + .../inference/services/ServiceUtils.java | 19 + .../azureopenai/AzureOpenAiModel.java | 49 + .../AzureOpenAiSecretSettings.java | 101 ++ .../azureopenai/AzureOpenAiService.java | 296 +++++ .../azureopenai/AzureOpenAiServiceFields.java | 16 + .../AzureOpenAiEmbeddingsModel.java | 116 ++ ...reOpenAiEmbeddingsRequestTaskSettings.java | 54 + .../AzureOpenAiEmbeddingsServiceSettings.java | 282 ++++ .../AzureOpenAiEmbeddingsTaskSettings.java | 114 ++ .../AzureOpenAiActionCreatorTests.java | 454 +++++++ .../AzureOpenAiEmbeddingsActionTests.java | 219 +++ .../AzureOpenAiResponseHandlerTests.java | 88 ++ ...ureOpenAiEmbeddingsRequestEntityTests.java | 77 ++ .../AzureOpenAiEmbeddingsRequestTests.java | 118 ++ .../AzureOpenAiSecretSettingsTests.java | 160 +++ .../azureopenai/AzureOpenAiServiceTests.java | 1180 +++++++++++++++++ .../AzureOpenAiEmbeddingsModelTests.java | 121 ++ ...nAiEmbeddingsRequestTaskSettingsTests.java | 56 + ...eOpenAiEmbeddingsServiceSettingsTests.java | 389 ++++++ ...zureOpenAiEmbeddingsTaskSettingsTests.java | 107 ++ 35 files 
changed, 4499 insertions(+), 2 deletions(-) create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsAction.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiAccount.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsExecutableRequestCreator.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntity.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequest.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiUtils.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceFields.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java create mode 100644 x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandlerTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntityTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestTests.java create mode 100644 
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java create mode 100644 x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 5fc62afb0c27d..978ad1ce31e28 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -172,6 +172,7 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_RERANK_NEW_RESPONSE_FORMAT = def(8_631_00_0); public static final TransportVersion HIGHLIGHTERS_TAGS_ON_FIELD_LEVEL = def(8_632_00_0); public static final TransportVersion TRACK_FLUSH_TIME_EXCLUDING_WAITING_ON_LOCKS = def(8_633_00_0); + public static final TransportVersion ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS = def(8_634_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index cfb21ad5a1d94..21bd73c3821c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -24,6 +24,9 @@ import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingByteResults; import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettings; import org.elasticsearch.xpack.inference.services.cohere.CohereServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettings; @@ -202,6 +205,30 @@ public static List getNamedWriteables() { new NamedWriteableRegistry.Entry(TaskSettings.class, CohereRerankTaskSettings.NAME, CohereRerankTaskSettings::new) ); + // Azure OpenAI + namedWriteables.add( + new NamedWriteableRegistry.Entry( + AzureOpenAiSecretSettings.class, + AzureOpenAiSecretSettings.NAME, + AzureOpenAiSecretSettings::new + ) + ); + + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + 
AzureOpenAiEmbeddingsServiceSettings.NAME, + AzureOpenAiEmbeddingsServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + TaskSettings.class, + AzureOpenAiEmbeddingsTaskSettings.NAME, + AzureOpenAiEmbeddingsTaskSettings::new + ) + ); + return namedWriteables; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 3b2c0b3c4ac3e..f41f9a97cec18 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -56,6 +56,7 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; import org.elasticsearch.xpack.inference.services.cohere.CohereService; import org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalService; import org.elasticsearch.xpack.inference.services.elser.ElserInternalService; @@ -176,6 +177,7 @@ public List getInferenceServiceFactories() { context -> new HuggingFaceService(httpFactory.get(), serviceComponents.get()), context -> new OpenAiService(httpFactory.get(), serviceComponents.get()), context -> new CohereService(httpFactory.get(), serviceComponents.get()), + context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreator.java new file mode 100644 index 0000000000000..39eaaceae08bc --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreator.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.azureopenai; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; + +import java.util.Map; +import java.util.Objects; + +/** + * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type. 
+ */ +public class AzureOpenAiActionCreator implements AzureOpenAiActionVisitor { + private final Sender sender; + private final ServiceComponents serviceComponents; + + public AzureOpenAiActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map taskSettings) { + var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, taskSettings); + return new AzureOpenAiEmbeddingsAction(sender, overriddenModel, serviceComponents); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java new file mode 100644 index 0000000000000..49d1ce61b12dd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionVisitor.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.azureopenai; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; + +import java.util.Map; + +public interface AzureOpenAiActionVisitor { + ExecutableAction create(AzureOpenAiEmbeddingsModel model, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsAction.java new file mode 100644 index 0000000000000..a682ad2bb23d5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsAction.java @@ -0,0 +1,53 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.azureopenai; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiEmbeddingsExecutableRequestCreator; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class AzureOpenAiEmbeddingsAction implements ExecutableAction { + + private final String errorMessage; + private final AzureOpenAiEmbeddingsExecutableRequestCreator requestCreator; + private final Sender sender; + + public AzureOpenAiEmbeddingsAction(Sender sender, AzureOpenAiEmbeddingsModel model, ServiceComponents serviceComponents) { + Objects.requireNonNull(serviceComponents); + Objects.requireNonNull(model); + this.sender = Objects.requireNonNull(sender); + requestCreator = new AzureOpenAiEmbeddingsExecutableRequestCreator(model, serviceComponents.truncator()); + errorMessage = constructFailedToSendRequestMessage(model.getUri(), "Azure OpenAI embeddings"); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener); + + sender.send(requestCreator, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, errorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiAccount.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiAccount.java new file mode 100644 index 0000000000000..db1f91cc751ee --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiAccount.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.azureopenai; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; + +import java.util.Objects; + +public record AzureOpenAiAccount( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable SecureString apiKey, + @Nullable SecureString entraId +) { + + public AzureOpenAiAccount { + Objects.requireNonNull(resourceName); + Objects.requireNonNull(deploymentId); + Objects.requireNonNull(apiVersion); + Objects.requireNonNullElse(apiKey, entraId); + } + + public static AzureOpenAiAccount fromModel(AzureOpenAiEmbeddingsModel model) { + return new AzureOpenAiAccount( + model.getServiceSettings().resourceName(), + model.getServiceSettings().deploymentId(), + model.getServiceSettings().apiVersion(), + model.getSecretSettings().apiKey(), + model.getSecretSettings().entraId() + ); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java new file mode 100644 index 0000000000000..2f72088327468 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandler.java @@ -0,0 +1,52 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.azureopenai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.openai.OpenAiResponseHandler; +import org.elasticsearch.xpack.inference.external.request.Request; + +import static org.elasticsearch.xpack.inference.external.http.retry.ResponseHandlerUtils.getFirstHeaderOrUnknown; + +public class AzureOpenAiResponseHandler extends OpenAiResponseHandler { + + /** + * These headers for Azure OpenAi are mostly the same as the OpenAi ones with the major exception + * that there is no information returned about the request limit or the tokens limit + * + * Microsoft does not seem to have any published information in their docs about this, but more + * information can be found in the following Medium article and accompanying code: + * - https://pablo-81685.medium.com/azure-openais-api-headers-unpacked-6dbe881e732a + * - https://github.com/pablosalvador10/gbbai-azure-ai-aoai + */ + static final String REMAINING_REQUESTS = "x-ratelimit-remaining-requests"; + // The remaining number of tokens that are permitted before exhausting the rate limit. 
+ static final String REMAINING_TOKENS = "x-ratelimit-remaining-tokens"; + + public AzureOpenAiResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction); + } + + @Override + protected RetryException buildExceptionHandling429(Request request, HttpResult result) { + return new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result)); + } + + static String buildRateLimitErrorMessage(HttpResult result) { + var response = result.response(); + var remainingTokens = getFirstHeaderOrUnknown(response, REMAINING_TOKENS); + var remainingRequests = getFirstHeaderOrUnknown(response, REMAINING_REQUESTS); + var usageMessage = Strings.format("Remaining tokens [%s]. Remaining requests [%s].", remainingTokens, remainingRequests); + + return RATE_LIMIT + ". " + usageMessage; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsExecutableRequestCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsExecutableRequestCreator.java new file mode 100644 index 0000000000000..b3f53d5f3f236 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsExecutableRequestCreator.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiAccount; +import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class AzureOpenAiEmbeddingsExecutableRequestCreator implements ExecutableRequestCreator { + + private static final Logger logger = LogManager.getLogger(AzureOpenAiEmbeddingsExecutableRequestCreator.class); + + private static final ResponseHandler HANDLER = createEmbeddingsHandler(); + + private static ResponseHandler createEmbeddingsHandler() { + return new AzureOpenAiResponseHandler("azure openai text embedding", OpenAiEmbeddingsResponseEntity::fromResponse); + } + + private final Truncator truncator; + private final AzureOpenAiEmbeddingsModel model; + private final AzureOpenAiAccount account; + + public AzureOpenAiEmbeddingsExecutableRequestCreator(AzureOpenAiEmbeddingsModel model, Truncator 
truncator) { + this.model = Objects.requireNonNull(model); + this.account = AzureOpenAiAccount.fromModel(model); + this.truncator = Objects.requireNonNull(truncator); + } + + @Override + public Runnable create( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + HttpClientContext context, + ActionListener listener + ) { + var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, account, truncatedInput, model); + return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java index 5924356e610a3..7ca7cf0422fd9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiChatCompletionResponseHandler.java @@ -18,7 +18,7 @@ public OpenAiChatCompletionResponseHandler(String requestType, ResponseParser pa } @Override - RetryException buildExceptionHandling429(Request request, HttpResult result) { + protected RetryException buildExceptionHandling429(Request request, HttpResult result) { // We don't retry, if the chat completion input is too large return new RetryException(false, buildError(RATE_LIMIT, request, result)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java index db7ca8d6bdc63..c23b94351c187 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiResponseHandler.java @@ -83,7 +83,7 @@ void checkForFailureStatusCode(Request request, HttpResult result) throws RetryE } } - RetryException buildExceptionHandling429(Request request, HttpResult result) { + protected RetryException buildExceptionHandling429(Request request, HttpResult result) { return new RetryException(true, buildError(buildRateLimitErrorMessage(result), request, result)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequest.java new file mode 100644 index 0000000000000..c943d5f54b4ff --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequest.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiAccount; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; +import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.API_KEY; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.ENTRA_ID; + +public class AzureOpenAiEmbeddingsRequest implements AzureOpenAiRequest { + private static final String MISSING_AUTHENTICATION_ERROR_MESSAGE = + "The request does not have any authentication methods set. One of [%s] or [%s] is required."; + + private final Truncator truncator; + private final AzureOpenAiAccount account; + private final Truncator.TruncationResult truncationResult; + private final URI uri; + private final AzureOpenAiEmbeddingsModel model; + + public AzureOpenAiEmbeddingsRequest( + Truncator truncator, + AzureOpenAiAccount account, + Truncator.TruncationResult input, + AzureOpenAiEmbeddingsModel model + ) { + this.truncator = Objects.requireNonNull(truncator); + this.account = Objects.requireNonNull(account); + this.truncationResult = Objects.requireNonNull(input); + this.model = Objects.requireNonNull(model); + this.uri = model.getUri(); + } + + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(uri); + + String requestEntity = Strings.toString( + new AzureOpenAiEmbeddingsRequestEntity( + truncationResult.input(), + model.getTaskSettings().user(), + model.getServiceSettings().dimensions(), + model.getServiceSettings().dimensionsSetByUser() + ) + ); + + ByteArrayEntity byteEntity = new ByteArrayEntity(requestEntity.getBytes(StandardCharsets.UTF_8)); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(new BasicHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType())); + + var entraId = model.getSecretSettings().entraId(); + var apiKey = model.getSecretSettings().apiKey(); + + if (entraId != null && entraId.isEmpty() == false) { + httpPost.setHeader(createAuthBearerHeader(entraId)); + } else if (apiKey != null && apiKey.isEmpty() == false) { + httpPost.setHeader(new BasicHeader(API_KEY_HEADER, apiKey.toString())); + } else { + // should never happen due to the checks on the secret settings, but just in case + ValidationException validationException = new ValidationException(); + validationException.addValidationError(Strings.format(MISSING_AUTHENTICATION_ERROR_MESSAGE, API_KEY, ENTRA_ID)); + throw validationException; + } + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return this.uri; + } + + 
@Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + + return new AzureOpenAiEmbeddingsRequest(truncator, account, truncatedInput, model); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..2a9a93e99d4e4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntity.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record AzureOpenAiEmbeddingsRequestEntity( + List input, + @Nullable String user, + @Nullable Integer dimensions, + boolean dimensionsSetByUser +) implements ToXContentObject { + + private static final String INPUT_FIELD = "input"; + private static final String USER_FIELD = "user"; + private static final String DIMENSIONS_FIELD = "dimensions"; + + public AzureOpenAiEmbeddingsRequestEntity { + Objects.requireNonNull(input); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INPUT_FIELD, input); + + if (user != null) { + builder.field(USER_FIELD, user); + } + + if (dimensionsSetByUser && dimensions != null) { + builder.field(DIMENSIONS_FIELD, dimensions); + } + + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequest.java new file mode 100644 index 0000000000000..edb7c70b3903e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiRequest.java @@ -0,0 +1,12 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +import org.elasticsearch.xpack.inference.external.request.Request; + +public interface AzureOpenAiRequest extends Request {} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiUtils.java new file mode 100644 index 0000000000000..16a02a4c06c1c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiUtils.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +public class AzureOpenAiUtils { + + public static final String HOST_SUFFIX = "openai.azure.com"; + public static final String OPENAI_PATH = "openai"; + public static final String DEPLOYMENTS_PATH = "deployments"; + public static final String EMBEDDINGS_PATH = "embeddings"; + public static final String API_VERSION_PARAMETER = "api-version"; + public static final String API_KEY_HEADER = "api-key"; + + private AzureOpenAiUtils() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 72808b6de8132..1631755149578 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -139,6 +139,10 @@ public static String invalidValue(String settingName, String scope, String inval ); } + public static String invalidSettingError(String settingName, String scope) { + return Strings.format("[%s] does not allow the setting [%s]", scope, settingName); + } + // TODO improve URI validation logic public static URI convertToUri(@Nullable String url, String settingName, String settingScope, ValidationException validationException) { try { @@ -186,6 +190,21 @@ public static SecureString extractRequiredSecureString( return new SecureString(Objects.requireNonNull(requiredField).toCharArray()); } + public static SecureString extractOptionalSecureString( + Map map, + String settingName, + String scope, + ValidationException validationException + ) { + String optionalField = extractOptionalString(map, settingName, scope, validationException); + + if (validationException.validationErrors().isEmpty() == false || optionalField == null) { + return null; + } + + return new SecureString(optionalField.toCharArray()); + } + public static SimilarityMeasure extractSimilarity(Map map, String scope, ValidationException validationException) { return extractOptionalEnum( map, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java new file mode 100644 index 0000000000000..66070cab0e517 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiModel.java @@ 
-0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor; + +import java.net.URI; +import java.util.Map; + +public abstract class AzureOpenAiModel extends Model { + + protected URI uri; + + public AzureOpenAiModel(ModelConfigurations configurations, ModelSecrets secrets) { + super(configurations, secrets); + } + + protected AzureOpenAiModel(AzureOpenAiModel model, TaskSettings taskSettings) { + super(model, taskSettings); + this.uri = model.getUri(); + } + + protected AzureOpenAiModel(AzureOpenAiModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + this.uri = model.getUri(); + } + + public abstract ExecutableAction accept(AzureOpenAiActionVisitor creator, Map taskSettings); + + public URI getUri() { + return uri; + } + + // Needed for testing + public void setUri(URI newUri) { + this.uri = newUri; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java new file mode 100644 index 0000000000000..f871fe6c080a1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettings.java @@ -0,0 +1,101 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalSecureString; + +public record AzureOpenAiSecretSettings(@Nullable SecureString apiKey, @Nullable SecureString entraId) implements SecretSettings { + + public static final String NAME = "azure_openai_secret_settings"; + public static final String API_KEY = "api_key"; + public static final String ENTRA_ID = "entra_id"; + + public static AzureOpenAiSecretSettings fromMap(@Nullable Map map) { + if (map == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + SecureString secureApiToken = extractOptionalSecureString(map, API_KEY, ModelSecrets.SECRET_SETTINGS, validationException); + SecureString secureEntraId = extractOptionalSecureString(map, ENTRA_ID, ModelSecrets.SECRET_SETTINGS, validationException); + + if (secureApiToken == null && secureEntraId == null) { + validationException.addValidationError( + format("[secret_settings] must have either the [%s] or the [%s] key set", API_KEY, ENTRA_ID) + ); + } + + if (secureApiToken != null && secureEntraId != null) { + validationException.addValidationError( + format("[secret_settings] must have only one of the [%s] or the [%s] key set", API_KEY, ENTRA_ID) + ); + } + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AzureOpenAiSecretSettings(secureApiToken, secureEntraId); + } + + public AzureOpenAiSecretSettings { + Objects.requireNonNullElse(apiKey, entraId); + } + + public AzureOpenAiSecretSettings(StreamInput in) throws IOException { + this(in.readOptionalSecureString(), in.readOptionalSecureString()); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + if (apiKey != null) { + builder.field(API_KEY, apiKey.toString()); + } + + if (entraId != null) { + builder.field(ENTRA_ID, entraId.toString()); + } + + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalSecureString(apiKey); + out.writeOptionalSecureString(entraId); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java new file mode 100644 index 0000000000000..f20c262053d10 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -0,0 
+1,296 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class AzureOpenAiService extends SenderService { + public static final String NAME = "azureopenai"; + + public AzureOpenAiService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String inferenceEntityId, + TaskType taskType, + Map config, + Set platformArchitectures, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map 
taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + AzureOpenAiModel model = createModel( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + private static AzureOpenAiModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + private static AzureOpenAiModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + if (taskType == TaskType.TEXT_EMBEDDING) { + return new AzureOpenAiEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + } + + throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + + @Override + public AzureOpenAiModel parsePersistedConfigWithSecrets( + String inferenceEntityId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + public AzureOpenAiModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return createModelFromPersistent( + inferenceEntityId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(inferenceEntityId, NAME) + ); + } + + @Override + protected void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof AzureOpenAiModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + AzureOpenAiModel azureOpenAiModel = (AzureOpenAiModel) model; + var actionCreator = new AzureOpenAiActionCreator(getSender(), getServiceComponents()); + + var action = azureOpenAiModel.accept(actionCreator, taskSettings); + action.execute(new DocumentsOnlyInput(input), timeout, listener); + } + + @Override + protected void doInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("Azure OpenAI service does not support inference with query input"); + } + + @Override + protected void 
doChunkedInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener> listener + ) { + ActionListener inferListener = listener.delegateFailureAndWrap( + (delegate, response) -> delegate.onResponse(translateToChunkedResults(input, response)) + ); + + doInfer(model, input, taskSettings, inputType, timeout, inferListener); + } + + private static List translateToChunkedResults( + List inputs, + InferenceServiceResults inferenceResults + ) { + if (inferenceResults instanceof TextEmbeddingResults textEmbeddingResults) { + return ChunkedTextEmbeddingResults.of(inputs, textEmbeddingResults); + } else if (inferenceResults instanceof ErrorInferenceResults error) { + return List.of(new ErrorChunkedInferenceResults(error.getException())); + } else { + throw createInvalidChunkedResultException(inferenceResults.getWriteableName()); + } + } + + /** + * For text embedding models get the embedding size and + * update the service settings. + * + * @param model The new model + * @param listener The listener + */ + @Override + public void checkModelConfig(Model model, ActionListener listener) { + if (model instanceof AzureOpenAiEmbeddingsModel embeddingsModel) { + ServiceUtils.getEmbeddingSize( + model, + this, + listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size))) + ); + } else { + listener.onResponse(model); + } + } + + private AzureOpenAiEmbeddingsModel updateModelWithEmbeddingDetails(AzureOpenAiEmbeddingsModel model, int embeddingSize) { + if (model.getServiceSettings().dimensionsSetByUser() + && model.getServiceSettings().dimensions() != null + && model.getServiceSettings().dimensions() != embeddingSize) { + throw new ElasticsearchStatusException( + Strings.format( + "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. " + + "Please recreate the [%s] configuration with the correct dimensions", + embeddingSize, + model.getServiceSettings().dimensions(), + model.getConfigurations().getInferenceEntityId() + ), + RestStatus.BAD_REQUEST + ); + } + + var similarityFromModel = model.getServiceSettings().similarity(); + var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + AzureOpenAiEmbeddingsServiceSettings serviceSettings = new AzureOpenAiEmbeddingsServiceSettings( + model.getServiceSettings().resourceName(), + model.getServiceSettings().deploymentId(), + model.getServiceSettings().apiVersion(), + embeddingSize, + model.getServiceSettings().dimensionsSetByUser(), + model.getServiceSettings().maxInputTokens(), + similarityToUse + ); + + return new AzureOpenAiEmbeddingsModel(model, serviceSettings); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceFields.java new file mode 100644 index 0000000000000..a3786ff27224b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceFields.java @@ -0,0 +1,16 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +public class AzureOpenAiServiceFields { + + public static final String RESOURCE_NAME = "resource_name"; + public static final String DEPLOYMENT_ID = "deployment_id"; + public static final String API_VERSION = "api_version"; + public static final String USER = "user"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java new file mode 100644 index 0000000000000..4c3272013f0e2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModel.java @@ -0,0 +1,116 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionVisitor; +import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiModel; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +import static org.elasticsearch.core.Strings.format; + +public class AzureOpenAiEmbeddingsModel extends AzureOpenAiModel { + + public static AzureOpenAiEmbeddingsModel of(AzureOpenAiEmbeddingsModel model, Map taskSettings) { + if (taskSettings == null || taskSettings.isEmpty()) { + return model; + } + + var requestTaskSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(taskSettings); + return new AzureOpenAiEmbeddingsModel(model, AzureOpenAiEmbeddingsTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); + } + + public AzureOpenAiEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + AzureOpenAiEmbeddingsServiceSettings.fromMap(serviceSettings, context), + AzureOpenAiEmbeddingsTaskSettings.fromMap(taskSettings), + AzureOpenAiSecretSettings.fromMap(secrets) + ); + } + + // Should only be used directly for testing + AzureOpenAiEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + AzureOpenAiEmbeddingsServiceSettings serviceSettings, + AzureOpenAiEmbeddingsTaskSettings taskSettings, + @Nullable AzureOpenAiSecretSettings secrets + ) { + super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, 
taskSettings), new ModelSecrets(secrets)); + try { + this.uri = getEmbeddingsUri(serviceSettings.resourceName(), serviceSettings.deploymentId(), serviceSettings.apiVersion()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + private AzureOpenAiEmbeddingsModel(AzureOpenAiEmbeddingsModel originalModel, AzureOpenAiEmbeddingsTaskSettings taskSettings) { + super(originalModel, taskSettings); + } + + public AzureOpenAiEmbeddingsModel(AzureOpenAiEmbeddingsModel originalModel, AzureOpenAiEmbeddingsServiceSettings serviceSettings) { + super(originalModel, serviceSettings); + } + + @Override + public AzureOpenAiEmbeddingsServiceSettings getServiceSettings() { + return (AzureOpenAiEmbeddingsServiceSettings) super.getServiceSettings(); + } + + @Override + public AzureOpenAiEmbeddingsTaskSettings getTaskSettings() { + return (AzureOpenAiEmbeddingsTaskSettings) super.getTaskSettings(); + } + + @Override + public AzureOpenAiSecretSettings getSecretSettings() { + return (AzureOpenAiSecretSettings) super.getSecretSettings(); + } + + @Override + public ExecutableAction accept(AzureOpenAiActionVisitor creator, Map taskSettings) { + return creator.create(this, taskSettings); + } + + public static URI getEmbeddingsUri(String resourceName, String deploymentId, String apiVersion) throws URISyntaxException { + String hostname = format("%s.%s", resourceName, AzureOpenAiUtils.HOST_SUFFIX); + return new URIBuilder().setScheme("https") + .setHost(hostname) + .setPathSegments( + AzureOpenAiUtils.OPENAI_PATH, + AzureOpenAiUtils.DEPLOYMENTS_PATH, + deploymentId, + AzureOpenAiUtils.EMBEDDINGS_PATH + ) + .addParameter(AzureOpenAiUtils.API_VERSION_PARAMETER, apiVersion) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java new file mode 100644 index 0000000000000..dc7012203a9c8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettings.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER; + +/** + * This class handles extracting Azure OpenAI task settings from a request. The difference between this class and + * {@link AzureOpenAiEmbeddingsTaskSettings} is that this class considers all fields as optional. It will not throw an error if a field + * is missing. This allows overriding persistent task settings. 
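+ *
+ * The merge with the persisted settings happens in {@link AzureOpenAiEmbeddingsTaskSettings#of}, where any field left
+ * unset in the request falls back to its stored value.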
+ * @param user a unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse
+ */
+public record AzureOpenAiEmbeddingsRequestTaskSettings(@Nullable String user) {
+    private static final Logger logger = LogManager.getLogger(AzureOpenAiEmbeddingsRequestTaskSettings.class);
+
+    public static final AzureOpenAiEmbeddingsRequestTaskSettings EMPTY_SETTINGS = new AzureOpenAiEmbeddingsRequestTaskSettings(null);
+
+    /**
+     * Extracts the task settings from a map. All settings are considered optional and the absence of a setting
+     * does not throw an error.
+     *
+     * @param map the settings received from a request
+     * @return a {@link AzureOpenAiEmbeddingsRequestTaskSettings}
+     */
+    public static AzureOpenAiEmbeddingsRequestTaskSettings fromMap(Map<String, Object> map) {
+        if (map.isEmpty()) {
+            return AzureOpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS;
+        }
+
+        ValidationException validationException = new ValidationException();
+
+        String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException);
+
+        if (validationException.validationErrors().isEmpty() == false) {
+            throw validationException;
+        }
+
+        return new AzureOpenAiEmbeddingsRequestTaskSettings(user);
+    }
+}
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java
new file mode 100644
index 0000000000000..c3d9e3eb69a5d
--- /dev/null
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java
@@ -0,0 +1,282 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME; + +/** + * Defines the service settings for interacting with OpenAI's text embedding models. 
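+ * The required settings are {@code resource_name}, {@code deployment_id} and {@code api_version};
+ * {@code dimensions}, {@code max_input_tokens} and {@code similarity} are optional.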
+ */ +public class AzureOpenAiEmbeddingsServiceSettings implements ServiceSettings { + + public static final String NAME = "azure_openai_embeddings_service_settings"; + + static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + + public static AzureOpenAiEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var settings = fromMap(map, validationException, context); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AzureOpenAiEmbeddingsServiceSettings(settings); + } + + private static CommonFields fromMap( + Map map, + ValidationException validationException, + ConfigurationParseContext context + ) { + String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException); + String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + + Boolean dimensionsSetByUser = extractOptionalBoolean( + map, + DIMENSIONS_SET_BY_USER, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + + switch (context) { + case REQUEST -> { + if (dimensionsSetByUser != null) { + validationException.addValidationError( + ServiceUtils.invalidSettingError(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + dimensionsSetByUser = dims != null; + } + case PERSISTENT -> { + if (dimensionsSetByUser == null) { + validationException.addValidationError( + ServiceUtils.missingSettingErrorMsg(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + } + } + + return new CommonFields( + resourceName, + deploymentId, + apiVersion, + dims, + Boolean.TRUE.equals(dimensionsSetByUser), + maxTokens, + similarity + ); + } + + private record CommonFields( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable Integer dimensions, + Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable SimilarityMeasure similarity + ) {} + + private final String resourceName; + private final String deploymentId; + private final String apiVersion; + private final Integer dimensions; + private final Boolean dimensionsSetByUser; + private final Integer maxInputTokens; + private final SimilarityMeasure similarity; + + public AzureOpenAiEmbeddingsServiceSettings( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable Integer dimensions, + Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable SimilarityMeasure similarity + ) { + this.resourceName = resourceName; + this.deploymentId = deploymentId; + this.apiVersion = apiVersion; + this.dimensions = dimensions; + this.dimensionsSetByUser = Objects.requireNonNull(dimensionsSetByUser); + this.maxInputTokens = maxInputTokens; + this.similarity = similarity; + } + + public AzureOpenAiEmbeddingsServiceSettings(StreamInput in) throws IOException { + resourceName = in.readString(); + deploymentId = in.readString(); + apiVersion = in.readString(); + dimensions = in.readOptionalVInt(); + dimensionsSetByUser = 
in.readBoolean(); + maxInputTokens = in.readOptionalVInt(); + similarity = in.readOptionalEnum(SimilarityMeasure.class); + } + + private AzureOpenAiEmbeddingsServiceSettings(CommonFields fields) { + this( + fields.resourceName, + fields.deploymentId, + fields.apiVersion, + fields.dimensions, + fields.dimensionsSetByUser, + fields.maxInputTokens, + fields.similarity + ); + } + + public String resourceName() { + return resourceName; + } + + public String deploymentId() { + return deploymentId; + } + + public String apiVersion() { + return apiVersion; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + public Boolean dimensionsSetByUser() { + return dimensionsSetByUser; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + + builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + + builder.endObject(); + return builder; + } + + private void toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(RESOURCE_NAME, resourceName); + builder.field(DEPLOYMENT_ID, deploymentId); + builder.field(API_VERSION, apiVersion); + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + } + + @Override + public ToXContentObject getFilteredXContentObject() { + return (builder, params) -> { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + + builder.endObject(); + return builder; + }; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(resourceName); + out.writeString(deploymentId); + out.writeString(apiVersion); + out.writeOptionalVInt(dimensions); + out.writeBoolean(dimensionsSetByUser); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AzureOpenAiEmbeddingsServiceSettings that = (AzureOpenAiEmbeddingsServiceSettings) o; + + return Objects.equals(resourceName, that.resourceName) + && Objects.equals(deploymentId, that.deploymentId) + && Objects.equals(apiVersion, that.apiVersion) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(similarity, that.similarity); + } + + @Override + public int hashCode() { + return Objects.hash(resourceName, deploymentId, apiVersion, dimensions, dimensionsSetByUser, maxInputTokens, similarity); + } +} diff --git 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java new file mode 100644 index 0000000000000..49329a55a18ef --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettings.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.USER; + +/** + * Defines the task settings for the openai service. + * + * User is an optional unique identifier representing the end-user, which can help OpenAI to monitor and detect abuse + * see the openai docs for more details + */ +public class AzureOpenAiEmbeddingsTaskSettings implements TaskSettings { + + public static final String NAME = "azure_openai_embeddings_task_settings"; + + public static AzureOpenAiEmbeddingsTaskSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String user = extractOptionalString(map, USER, ModelConfigurations.TASK_SETTINGS, validationException); + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AzureOpenAiEmbeddingsTaskSettings(user); + } + + /** + * Creates a new {@link AzureOpenAiEmbeddingsTaskSettings} object by overriding the values in originalSettings with the ones + * passed in via requestSettings if the fields are not null. + * @param originalSettings the original {@link AzureOpenAiEmbeddingsTaskSettings} from the inference entity configuration from storage + * @param requestSettings the {@link AzureOpenAiEmbeddingsTaskSettings} from the request + * @return a new {@link AzureOpenAiEmbeddingsTaskSettings} + */ + public static AzureOpenAiEmbeddingsTaskSettings of( + AzureOpenAiEmbeddingsTaskSettings originalSettings, + AzureOpenAiEmbeddingsRequestTaskSettings requestSettings + ) { + var userToUse = requestSettings.user() == null ? 
originalSettings.user : requestSettings.user(); + return new AzureOpenAiEmbeddingsTaskSettings(userToUse); + } + + private final String user; + + public AzureOpenAiEmbeddingsTaskSettings(@Nullable String user) { + this.user = user; + } + + public AzureOpenAiEmbeddingsTaskSettings(StreamInput in) throws IOException { + this.user = in.readOptionalString(); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (user != null) { + builder.field(USER, user); + } + builder.endObject(); + return builder; + } + + public String user() { + return user; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_AZURE_OPENAI_EMBEDDINGS; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(user); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AzureOpenAiEmbeddingsTaskSettings that = (AzureOpenAiEmbeddingsTaskSettings) o; + return Objects.equals(user, that.user); + } + + @Override + public int hashCode() { + return Objects.hash(user); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java new file mode 100644 index 0000000000000..4bdba67beec17 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -0,0 +1,454 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.action.azureopenai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockRequest; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.http.retry.RetrySettingsTests.buildSettingsWithRetryFields; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsRequestTaskSettingsTests.getRequestTaskSettingsMap; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class AzureOpenAiActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testCreate_AzureOpenAiEmbeddingsModel() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + 
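// A canned Azure OpenAI embeddings response served by the mock web server, so the test needs no real Azure endpoint. +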
String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("resource", "deployment", "apiversion", "orig_user", "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user"); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(1)); + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + validateRequestMapWithUser(requestMap, List.of("abc"), "overridden_user"); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testCreate_AzureOpenAiEmbeddingsModel_WithoutUser() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap(null); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(1)); + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + validateRequestMapWithUser(requestMap, List.of("abc"), null); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testCreate_AzureOpenAiEmbeddingsModel_FailsFromInvalidResponseFormat() throws IOException { + // timeout as zero for no retries + var settings = buildSettingsWithRetryFields( + TimeValue.timeValueMillis(1), + TimeValue.timeValueMinutes(1), + TimeValue.timeValueSeconds(0) + ); + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager, settings); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); 
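+ // The response below deliberately misnames the "data" field so parsing fails and the error surfaces to the caller.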
+ + String responseJson = """ + { + "object": "list", + "data_does_not_exist": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user"); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer))) + ); + assertThat(thrownException.getCause().getMessage(), is("Failed to find required field [data] in OpenAI embeddings response")); + + assertThat(webServer.requests(), hasSize(1)); + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + validateRequestMapWithUser(requestMap, List.of("abc"), "overridden_user"); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From413StatusCode() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + // note - there is no complete documentation on Azure's error messages + // but this error and response has been verified manually via CURL + var contentTooLargeErrorMessage = + "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" + + "0 for the completion). 
Please reduce your prompt; or completion length."; + + String responseJsonContentTooLarge = Strings.format(""" + { + "error": { + "message": "%s", + "type": "invalid_request_error", + "param": null, + "code": null + } + } + """, contentTooLargeErrorMessage); + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(413).setBody(responseJsonContentTooLarge)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user"); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abcd")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(2)); + { + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + validateRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user"); + } + { + validateRequestWithApiKey(webServer.requests().get(1), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(1).getBody()); + validateRequestMapWithUser(requestMap, List.of("ab"), "overridden_user"); + } + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testExecute_ReturnsSuccessfulResponse_AfterTruncating_From400StatusCode() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + // note - there is no complete documentation on Azure's error messages + // but this error and response has been verified manually via CURL + var contentTooLargeErrorMessage = + "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" + + "0 for the completion). 
Please reduce your prompt; or completion length."; + + String responseJsonContentTooLarge = Strings.format(""" + { + "error": { + "message": "%s", + "type": "invalid_request_error", + "param": null, + "code": null + } + } + """, contentTooLargeErrorMessage); + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(400).setBody(responseJsonContentTooLarge)); + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = createModel("resource", "deployment", "apiversion", null, "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user"); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abcd")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(2)); + { + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + validateRequestMapWithUser(requestMap, List.of("abcd"), "overridden_user"); + } + { + validateRequestWithApiKey(webServer.requests().get(1), "apikey"); + + var requestMap = entityAsMap(webServer.requests().get(1).getBody()); + validateRequestMapWithUser(requestMap, List.of("ab"), "overridden_user"); + } + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public void testExecute_TruncatesInputBeforeSending() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + // truncated to 1 token = 3 characters + var model = createModel("resource", "deployment", "apiversion", null, false, 1, null, null, "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + var actionCreator = new AzureOpenAiActionCreator(sender, createWithEmptySettings(threadPool)); + var overriddenTaskSettings = getRequestTaskSettingsMap("overridden_user"); + var action = (AzureOpenAiEmbeddingsAction) actionCreator.create(model, overriddenTaskSettings); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("super long input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(1)); + validateRequestWithApiKey(webServer.requests().get(0), "apikey"); + + var requestMap = 
entityAsMap(webServer.requests().get(0).getBody()); + validateRequestMapWithUser(requestMap, List.of("sup"), "overridden_user"); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + private void validateRequestMapWithUser(Map requestMap, List input, @Nullable String user) { + var expectedSize = user == null ? 1 : 2; + + assertThat(requestMap.size(), is(expectedSize)); + assertThat(requestMap.get("input"), is(input)); + + if (user != null) { + assertThat(requestMap.get("user"), is(user)); + } + } + + private void validateRequestWithApiKey(MockRequest request, String apiKey) { + assertNull(request.getUri().getQuery()); + assertThat(request.getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(request.getHeader(AzureOpenAiUtils.API_KEY_HEADER), equalTo(apiKey)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java new file mode 100644 index 0000000000000..e8eac1a13b180 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiEmbeddingsActionTests.java @@ -0,0 +1,219 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.azureopenai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; 
+import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests.createModel; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class AzureOpenAiEmbeddingsActionTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse() throws IOException { + var senderFactory = new HttpRequestSender.Factory( + ServiceComponentsTests.createWithEmptySettings(threadPool), + clientManager, + mockClusterServiceEmpty() + ); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(AzureOpenAiUtils.API_KEY_HEADER), equalTo("apikey")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), is(2)); + assertThat(requestMap.get("input"), is(List.of("abc"))); + assertThat(requestMap.get("user"), is("user")); + } + } + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = 
expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer)))); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled_WhenUrlIsNull() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[1]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer)))); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is(format("Failed to send Azure OpenAI embeddings request to [%s]", getUrl(webServer)))); + } + + private AzureOpenAiEmbeddingsAction createAction( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable String user, + String apiKey, + Sender sender, + String inferenceEntityId + ) { + AzureOpenAiEmbeddingsModel model = null; + try { + model = createModel(resourceName, deploymentId, apiVersion, user, apiKey, null, inferenceEntityId); + model.setUri(new URI(getUrl(webServer))); + var action = new AzureOpenAiEmbeddingsAction(sender, model, createWithEmptySettings(threadPool)); + return action; + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandlerTests.java 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandlerTests.java new file mode 100644 index 0000000000000..b18d9d76651d5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/azureopenai/AzureOpenAiResponseHandlerTests.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.azureopenai; + +import org.apache.http.HttpResponse; +import org.apache.http.StatusLine; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.http.HttpResult; + +import java.nio.charset.StandardCharsets; + +import static org.hamcrest.Matchers.containsString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AzureOpenAiResponseHandlerTests extends ESTestCase { + + public void testBuildRateLimitErrorMessage() { + int statusCode = 429; + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + var response = mock(HttpResponse.class); + when(response.getStatusLine()).thenReturn(statusLine); + var httpResult = new HttpResult(response, new byte[] {}); + + { + when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS)).thenReturn( + new BasicHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS, "2999") + ); + when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS)).thenReturn( + new BasicHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS, "99800") + ); + + var error = AzureOpenAiResponseHandler.buildRateLimitErrorMessage(httpResult); + assertThat(error, containsString("Remaining tokens [99800]. Remaining requests [2999]")); + } + + { + when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS)).thenReturn(null); + var error = AzureOpenAiResponseHandler.buildRateLimitErrorMessage(httpResult); + assertThat(error, containsString("Remaining tokens [unknown]. Remaining requests [2999]")); + } + + { + when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS)).thenReturn( + new BasicHeader(AzureOpenAiResponseHandler.REMAINING_REQUESTS, "2999") + ); + when(response.getFirstHeader(AzureOpenAiResponseHandler.REMAINING_TOKENS)).thenReturn(null); + var error = AzureOpenAiResponseHandler.buildRateLimitErrorMessage(httpResult); + assertThat(error, containsString("Remaining tokens [unknown]. Remaining requests [2999]")); + } + } + + private static HttpResult createContentTooLargeResult(int statusCode) { + return createResult( + statusCode, + "This model's maximum context length is 8192 tokens, however you requested 13531 tokens (13531 in your prompt;" + + "0 for the completion). Please reduce your prompt; or completion length." 
+ ); + } + + private static HttpResult createResult(int statusCode, String message) { + var statusLine = mock(StatusLine.class); + when(statusLine.getStatusCode()).thenReturn(statusCode); + var httpResponse = mock(HttpResponse.class); + when(httpResponse.getStatusLine()).thenReturn(statusLine); + + String responseJson = Strings.format(""" + { + "error": { + "message": "%s", + "type": "content_too_large", + "param": null, + "code": null + } + } + """, message); + + return new HttpResult(httpResponse, responseJson.getBytes(StandardCharsets.UTF_8)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..14283ed53eed9 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestEntityTests.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class AzureOpenAiEmbeddingsRequestEntityTests extends ESTestCase { + + public void testXContent_WritesUserWhenDefined() throws IOException { + var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), "testuser", null, false); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"input":["abc"],"user":"testuser"}""")); + } + + public void testXContent_DoesNotWriteUserWhenItIsNull() throws IOException { + var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, null, false); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"input":["abc"]}""")); + } + + public void testXContent_DoesNotWriteDimensionsWhenNotSetByUser() throws IOException { + var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, 100, false); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"input":["abc"]}""")); + } + + public void testXContent_DoesNotWriteDimensionsWhenNull_EvenIfSetByUserIsTrue() throws IOException { + var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, null, true); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"input":["abc"]}""")); + } + + public 
void testXContent_WritesDimensionsWhenNonNull_AndSetByUserIsTrue() throws IOException { + var entity = new AzureOpenAiEmbeddingsRequestEntity(List.of("abc"), null, 100, true); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"input":["abc"],"dimensions":100}""")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..8e7c831a9820f --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/azureopenai/AzureOpenAiEmbeddingsRequestTests.java @@ -0,0 +1,118 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.azureopenai; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiAccount; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class AzureOpenAiEmbeddingsRequestTests extends ESTestCase { + public void testCreateRequest_WithApiKeyDefined() throws IOException, URISyntaxException { + var request = createRequest("resource", "deployment", "apiVersion", "apikey", null, "abc", "user"); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var expectedUri = AzureOpenAiEmbeddingsModel.getEmbeddingsUri("resource", "deployment", "apiVersion").toString(); + assertThat(httpPost.getURI().toString(), is(expectedUri)); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(API_KEY_HEADER).getValue(), is("apikey")); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("input"), is(List.of("abc"))); + assertThat(requestMap.get("user"), is("user")); + } + + public void testCreateRequest_WithEntraIdDefined() throws IOException, URISyntaxException { + var request = createRequest("resource", 
"deployment", "apiVersion", null, "entraId", "abc", "user"); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + var expectedUri = AzureOpenAiEmbeddingsModel.getEmbeddingsUri("resource", "deployment", "apiVersion").toString(); + assertThat(httpPost.getURI().toString(), is(expectedUri)); + + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer entraId")); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(2)); + assertThat(requestMap.get("input"), is(List.of("abc"))); + assertThat(requestMap.get("user"), is("user")); + } + + public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { + var request = createRequest("resource", "deployment", "apiVersion", "apikey", null, "abcd", null); + var truncatedRequest = request.truncate(); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(1)); + assertThat(requestMap.get("input"), is(List.of("ab"))); + } + + public void testIsTruncated_ReturnsTrue() { + var request = createRequest("resource", "deployment", "apiVersion", "apikey", null, "abcd", null); + assertFalse(request.getTruncationInfo()[0]); + + var truncatedRequest = request.truncate(); + assertTrue(truncatedRequest.getTruncationInfo()[0]); + } + + public static AzureOpenAiEmbeddingsRequest createRequest( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable String apiKey, + @Nullable String entraId, + String input, + @Nullable String user + ) { + var embeddingsModel = AzureOpenAiEmbeddingsModelTests.createModel( + resourceName, + deploymentId, + apiVersion, + user, + apiKey, + entraId, + "id" + ); + var account = AzureOpenAiAccount.fromModel(embeddingsModel); + + return new AzureOpenAiEmbeddingsRequest( + TruncatorTests.createTruncator(), + account, + new Truncator.TruncationResult(List.of(input), new boolean[] { false }), + embeddingsModel + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java new file mode 100644 index 0000000000000..97fa6efc962bb --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiSecretSettingsTests.java @@ -0,0 +1,160 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.API_KEY; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings.ENTRA_ID; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AzureOpenAiSecretSettingsTests extends AbstractWireSerializingTestCase<AzureOpenAiSecretSettings> { + + public static AzureOpenAiSecretSettings createRandom() { + return new AzureOpenAiSecretSettings( + new SecureString(randomAlphaOfLength(15).toCharArray()), + new SecureString(randomAlphaOfLength(15).toCharArray()) + ); + } + + public void testFromMap_ApiKey_Only() { + var serviceSettings = AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiSecretSettings.API_KEY, "abc"))); + assertThat(new AzureOpenAiSecretSettings(new SecureString("abc".toCharArray()), null), is(serviceSettings)); + } + + public void testFromMap_EntraId_Only() { + var serviceSettings = AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(ENTRA_ID, "xyz"))); + assertThat(new AzureOpenAiSecretSettings(null, new SecureString("xyz".toCharArray())), is(serviceSettings)); + } + + public void testFromMap_ReturnsNull_WhenMapIsNull() { + assertNull(AzureOpenAiSecretSettings.fromMap(null)); + } + + public void testFromMap_MissingApiKeyAndEntraId_ThrowsError() { + var thrownException = expectThrows(ValidationException.class, () -> AzureOpenAiSecretSettings.fromMap(new HashMap<>())); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "[secret_settings] must have either the [%s] or the [%s] key set", + AzureOpenAiSecretSettings.API_KEY, + ENTRA_ID + ) + ) + ); + } + + public void testFromMap_HasBothApiKeyAndEntraId_ThrowsError() { + var mapValues = getAzureOpenAiSecretSettingsMap("apikey", "entraid"); + var thrownException = expectThrows(ValidationException.class, () -> AzureOpenAiSecretSettings.fromMap(mapValues)); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "[secret_settings] must have only one of the [%s] or the [%s] key set", + AzureOpenAiSecretSettings.API_KEY, + ENTRA_ID + ) + ) + ); + }
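+    // Editor's sketch (not part of this patch): the two tests above pin down the exclusive-or
+    // contract that AzureOpenAiSecretSettings.fromMap presumably enforces over the two
+    // credentials. Names below are assumptions, shown only to make the contract concrete:
+    //
+    //     if (apiKey == null && entraId == null) {
+    //         validationException.addValidationError("[secret_settings] must have either the [api_key] or the [entra_id] key set");
+    //     } else if (apiKey != null && entraId != null) {
+    //         validationException.addValidationError("[secret_settings] must have only one of the [api_key] or the [entra_id] key set");
+    //     }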
+ + public void testFromMap_EmptyApiKey_ThrowsError() { + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiSecretSettings.API_KEY, ""))) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "[secret_settings] Invalid value empty string. [%s] must be a non-empty string", + AzureOpenAiSecretSettings.API_KEY + ) + ) + ); + } + + public void testFromMap_EmptyEntraId_ThrowsError() { + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiSecretSettings.fromMap(new HashMap<>(Map.of(ENTRA_ID, ""))) + ); + + assertThat( + thrownException.getMessage(), + containsString(Strings.format("[secret_settings] Invalid value empty string. [%s] must be a non-empty string", ENTRA_ID)) + ); + } + + // test toXContent + public void testToXContent_WritesApiKeyOnlyWhenEntraIdIsNull() throws IOException { + var testSettings = new AzureOpenAiSecretSettings(new SecureString("apikey"), null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + testSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expectedResult = Strings.format("{\"%s\":\"apikey\"}", API_KEY); + assertThat(xContentResult, CoreMatchers.is(expectedResult)); + } + + public void testToXContent_WritesEntraIdOnlyWhenApiKeyIsNull() throws IOException { + var testSettings = new AzureOpenAiSecretSettings(null, new SecureString("entraid")); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + testSettings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + var expectedResult = Strings.format("{\"%s\":\"entraid\"}", ENTRA_ID); + assertThat(xContentResult, CoreMatchers.is(expectedResult)); + } + + @Override + protected Writeable.Reader<AzureOpenAiSecretSettings> instanceReader() { + return AzureOpenAiSecretSettings::new; + } + + @Override + protected AzureOpenAiSecretSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AzureOpenAiSecretSettings mutateInstance(AzureOpenAiSecretSettings instance) throws IOException { + return createRandom(); + } + + public static Map<String, Object> getAzureOpenAiSecretSettingsMap(@Nullable String apiKey, @Nullable String entraId) { + var map = new HashMap<String, Object>(); + if (apiKey != null) { + map.put(AzureOpenAiSecretSettings.API_KEY, apiKey); + } + if (entraId != null) { + map.put(ENTRA_ID, entraId); + } + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java new file mode 100644 index 0000000000000..4e65d987a26ad --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -0,0 +1,1180 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.inference.services.azureopenai; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ChunkedNlpInferenceResults; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsModelTests; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils.API_KEY_HEADER; +import static org.elasticsearch.xpack.inference.results.ChunkedTextEmbeddingResultsTests.asMapWithListsInsteadOfArrays; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectation; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettingsTests.getAzureOpenAiSecretSettingsMap; +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getPersistentAzureOpenAiServiceSettingsMap; 
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettingsTests.getRequestAzureOpenAiServiceSettingsMap; +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettingsTests.getAzureOpenAiRequestTaskSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class AzureOpenAiServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException { + try (var service = createAzureOpenAiService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createAzureOpenAiService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [azureopenai] service does not support task type [sparse_embedding]")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void 
testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createAzureOpenAiService()) { + var config = getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + config.put("extra_key", "value"); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createAzureOpenAiService()) { + var serviceSettings = getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + serviceSettings, + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + ); + }); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = createAzureOpenAiService()) { + var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user"); + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + taskSettingsMap, + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + ); + }); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createAzureOpenAiService()) { + var secretSettingsMap = getAzureOpenAiSecretSettingsMap("secret", null); + secretSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + secretSettingsMap + ); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( 
+ e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [azureopenai] service") + ); + }); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_MovesModel() throws IOException { + try (var service = createAzureOpenAiService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getRequestAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnAzureOpenAiEmbeddingsModel() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [azureopenai] service, please delete and add the service again") + ); + } + } + + public void 
testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + try (var service = createAzureOpenAiService()) { + var secretSettingsMap = getAzureOpenAiSecretSettingsMap("secret", null); + secretSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512), + getAzureOpenAiRequestTaskSettingsMap("user"), + secretSettingsMap + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512), + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + persistedConfig.secrets.put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + 
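+            // Editor's note: parsePersistedConfigWithSecrets deliberately tolerates unknown keys
+            // (contrast with parseRequestConfig above, which rejects them with an
+            // ElasticsearchStatusException), presumably so that configurations stored by older
+            // nodes keep loading after an upgrade.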
assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createAzureOpenAiService()) { + var serviceSettingsMap = getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + serviceSettingsMap, + getAzureOpenAiRequestTaskSettingsMap("user"), + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createAzureOpenAiService()) { + var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user"); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", 100, 512), + taskSettingsMap, + getAzureOpenAiSecretSettingsMap("secret", null) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(100)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfig_CreatesAnAzureOpenAiEmbeddingsModel() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + 
getAzureOpenAiRequestTaskSettingsMap("user") + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user") + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [azureopenai] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createAzureOpenAiService()) { + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + getAzureOpenAiRequestTaskSettingsMap("user") + ); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createAzureOpenAiService()) { + var serviceSettingsMap = getPersistentAzureOpenAiServiceSettingsMap( + "resource_name", + "deployment_id", + "api_version", + null, + null + ); + serviceSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getAzureOpenAiRequestTaskSettingsMap("user")); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void 
testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createAzureOpenAiService()) { + var taskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user"); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap( + getPersistentAzureOpenAiServiceSettingsMap("resource_name", "deployment_id", "api_version", null, null), + taskSettingsMap + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AzureOpenAiEmbeddingsModel.class)); + + var embeddingsModel = (AzureOpenAiEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().resourceName(), is("resource_name")); + assertThat(embeddingsModel.getServiceSettings().deploymentId(), is("deployment_id")); + assertThat(embeddingsModel.getServiceSettings().apiVersion(), is("api_version")); + assertThat(embeddingsModel.getTaskSettings().user(), is("user")); + assertNull(embeddingsModel.getSecretSettings()); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender(anyString())).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new AzureOpenAiService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(anyString()); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testInfer_SendsRequest() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel("resource", "deployment", "apiversion", "user", "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), Matchers.is(buildExpectation(List.of(List.of(0.0123F, -0.0123F))))); + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + 
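+            // Editor's note: with an api_key credential the request authenticates via the
+            // Azure-specific "api-key" header asserted below; the Entra ID variant (see
+            // AzureOpenAiEmbeddingsRequestTests above) sends "Authorization: Bearer <token>" instead.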
assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), Matchers.is(2)); + assertThat(requestMap.get("input"), Matchers.is(List.of("abc"))); + assertThat(requestMap.get("user"), Matchers.is("user")); + } + } + + public void testCheckModelConfig_IncludesMaxTokens() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + null, + false, + 100, + null, + "user", + "apikey", + null, + "id" + ); + model.setUri(new URI(getUrl(webServer))); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 2, + false, + 100, + SimilarityMeasure.DOT_PRODUCT, + "user", + "apikey", + null, + "id" + ) + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); + } + } + + public void testCheckModelConfig_HasSimilarity() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + null, + false, + null, + SimilarityMeasure.COSINE, + "user", + "apikey", + null, + "id" + ); + model.setUri(new URI(getUrl(webServer))); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 2, + false, + null, + SimilarityMeasure.COSINE, + "user", + "apikey", + null, + "id" + ) + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); + } + } + + public void testCheckModelConfig_AddsDefaultSimilarityDotProduct() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = 
new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + null, + false, + null, + null, + "user", + "apikey", + null, + "id" + ); + model.setUri(new URI(getUrl(webServer))); + + PlainActionFuture<Model> listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 2, + false, + null, + SimilarityMeasure.DOT_PRODUCT, + "user", + "apikey", + null, + "id" + ) + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); + } + }
+ + public void testCheckModelConfig_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 3, + true, + 100, + null, + "user", + "apikey", + null, + "id" + ); + model.setUri(new URI(getUrl(webServer))); + + PlainActionFuture<Model> listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + exception.getMessage(), + is( + "The retrieved embeddings size [2] does not match the size specified in the settings [3]. " + + "Please recreate the [id] configuration with the correct dimensions" + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user", "dimensions", 3))); + } + }
+ + public void testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException, + URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 100, + false, + 100, + null, + "user", + "apikey", + null, + "id" + ); + model.setUri(new URI(getUrl(webServer))); + + PlainActionFuture<Model> listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AzureOpenAiEmbeddingsModelTests.createModel( + "resource", + "deployment", + "apiversion", + 2, + false, + 100, + SimilarityMeasure.DOT_PRODUCT, + "user", + "apikey", + null, + "id" + ) + ) + ); + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat(requestMap, Matchers.is(Map.of("input", List.of("how big"), "user", "user"))); + } + }
+ + public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "error": { + "message": "Incorrect API key provided:", + "type": "invalid_request_error", + "param": null, + "code": "invalid_api_key" + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel("resource", "deployment", "apiversion", "user", "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + assertThat(error.getMessage(), containsString("Error message: [Incorrect API key provided:]")); + assertThat(webServer.requests(), hasSize(1)); + } + }
+ + public void testChunkedInfer_CallsInfer_ConvertsFloatResponse() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new AzureOpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = AzureOpenAiEmbeddingsModelTests.createModel("resource", "deployment", "apiversion", "user", "apikey", null, "id"); + model.setUri(new URI(getUrl(webServer))); + PlainActionFuture<List<ChunkedInferenceServiceResults>> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT).get(0); + assertThat(result, CoreMatchers.instanceOf(ChunkedTextEmbeddingResults.class)); + + assertThat( + asMapWithListsInsteadOfArrays((ChunkedTextEmbeddingResults) result), + Matchers.is( + Map.of( + ChunkedTextEmbeddingResults.FIELD_NAME, + List.of( + Map.of( + ChunkedNlpInferenceResults.TEXT, + "abc", + ChunkedNlpInferenceResults.INFERENCE, + List.of((double) 0.0123f, (double) -0.0123f) + ) + ) + ) + ) + ); + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(API_KEY_HEADER), equalTo("apikey")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), Matchers.is(2)); + assertThat(requestMap.get("input"), Matchers.is(List.of("abc"))); + assertThat(requestMap.get("user"), Matchers.is("user")); + } + }
+ + private AzureOpenAiService createAzureOpenAiService() { + return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + } + + private Map<String, Object> getRequestConfigMap( + Map<String, Object> serviceSettings, + Map<String, Object> taskSettings, + Map<String, Object> secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + private PersistedConfig getPersistedConfigMap( + Map<String, Object> serviceSettings, + Map<String, Object> taskSettings, + Map<String, Object> secretSettings + ) { + + return new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) + ); + } + + private PersistedConfig getPersistedConfigMap(Map<String, Object> serviceSettings, Map<String, Object> taskSettings) { + + return new PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + null + ); + } + + private record PersistedConfig(Map<String, Object> config, Map<String, Object> secrets) {} +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java new file mode 100644 index 0000000000000..f161cd0b823fe --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsModelTests.java @@ -0,0 +1,121 @@ +/* + * Copyright 
Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiSecretSettings; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsTaskSettingsTests.getAzureOpenAiRequestTaskSettingsMap; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class AzureOpenAiEmbeddingsModelTests extends ESTestCase { + + public void testOverrideWith_OverridesUser() { + var model = createModel("resource", "deployment", "apiversion", null, "api_key", null, "id"); + var requestTaskSettingsMap = getAzureOpenAiRequestTaskSettingsMap("user_override"); + + var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, requestTaskSettingsMap); + + assertThat(overriddenModel, is(createModel("resource", "deployment", "apiversion", "user_override", "api_key", null, "id"))); + } + + public void testOverrideWith_EmptyMap() { + var model = createModel("resource", "deployment", "apiversion", null, "api_key", null, "id"); + + Map<String, Object> requestTaskSettingsMap = Map.of(); + + var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, requestTaskSettingsMap); + assertThat(overriddenModel, sameInstance(model)); + } + + public void testOverrideWith_NullMap() { + var model = createModel("resource", "deployment", "apiversion", null, "api_key", null, "id"); + + var overriddenModel = AzureOpenAiEmbeddingsModel.of(model, null); + assertThat(overriddenModel, sameInstance(model)); + } + + public void testCreateModel_FromUpdatedServiceSettings() { + var model = createModel("resource", "deployment", "apiversion", "user", "api_key", null, "id"); + var updatedSettings = new AzureOpenAiEmbeddingsServiceSettings( + "resource", + "deployment", + "override_apiversion", + null, + false, + null, + null + ); + + var overriddenModel = new AzureOpenAiEmbeddingsModel(model, updatedSettings); + + assertThat(overriddenModel, is(createModel("resource", "deployment", "override_apiversion", "user", "api_key", null, "id"))); + } + + public static AzureOpenAiEmbeddingsModel createModel( + String resourceName, + String deploymentId, + String apiVersion, + String user, + @Nullable String apiKey, + @Nullable String entraId, + String inferenceEntityId + ) { + var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null; + var secureEntraId = entraId != null ? 
new SecureString(entraId.toCharArray()) : null; + return new AzureOpenAiEmbeddingsModel( + inferenceEntityId, + TaskType.TEXT_EMBEDDING, + "service", + new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, false, null, null), + new AzureOpenAiEmbeddingsTaskSettings(user), + new AzureOpenAiSecretSettings(secureApiKey, secureEntraId) + ); + } + + public static AzureOpenAiEmbeddingsModel createModel( + String resourceName, + String deploymentId, + String apiVersion, + @Nullable Integer dimensions, + Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable SimilarityMeasure similarity, + @Nullable String user, + @Nullable String apiKey, + @Nullable String entraId, + String inferenceEntityId + ) { + var secureApiKey = apiKey != null ? new SecureString(apiKey.toCharArray()) : null; + var secureEntraId = entraId != null ? new SecureString(entraId.toCharArray()) : null; + + return new AzureOpenAiEmbeddingsModel( + inferenceEntityId, + TaskType.TEXT_EMBEDDING, + "service", + new AzureOpenAiEmbeddingsServiceSettings( + resourceName, + deploymentId, + apiVersion, + dimensions, + dimensionsSetByUser, + maxInputTokens, + similarity + ), + new AzureOpenAiEmbeddingsTaskSettings(user), + new AzureOpenAiSecretSettings(secureApiKey, secureEntraId) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java new file mode 100644 index 0000000000000..3ff73e0f23656 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsRequestTaskSettingsTests.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.services.azureopenai.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields; + +import java.util.HashMap; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AzureOpenAiEmbeddingsRequestTaskSettingsTests extends ESTestCase { + public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() { + var settings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of())); + assertThat(settings, is(AzureOpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS)); + } + + public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() { + var settings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model"))); + assertNull(settings.user()); + } + + public void testFromMap_ReturnsUser() { + var settings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, "user"))); + assertThat(settings.user(), is("user")); + } + + public void testFromMap_WhenUserIsEmpty_ThrowsValidationException() { + var exception = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(new HashMap<>(Map.of(OpenAiServiceFields.USER, ""))) + ); + + assertThat(exception.getMessage(), containsString("[user] must be a non-empty string")); + } + + public static Map<String, Object> getRequestTaskSettingsMap(@Nullable String user) { + var map = new HashMap<String, Object>(); + + if (user != null) { + map.put(OpenAiServiceFields.USER, user); + } + + return map; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..be184956b2034 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java @@ -0,0 +1,389 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.SimilarityMeasure;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xcontent.XContentBuilder;
+import org.elasticsearch.xcontent.XContentFactory;
+import org.elasticsearch.xcontent.XContentType;
+import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
+import org.elasticsearch.xpack.inference.services.ServiceFields;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;
+import org.hamcrest.CoreMatchers;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
+import static org.elasticsearch.xpack.inference.services.azureopenai.embeddings.AzureOpenAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER;
+import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase<AzureOpenAiEmbeddingsServiceSettings> {
+
+    private static AzureOpenAiEmbeddingsServiceSettings createRandom() {
+        var resourceName = randomAlphaOfLength(8);
+        var deploymentId = randomAlphaOfLength(8);
+        var apiVersion = randomAlphaOfLength(8);
+        Integer dims = randomBoolean() ? 1536 : null;
+        Integer maxInputTokens = randomBoolean() ? null : randomIntBetween(128, 256);
+        return new AzureOpenAiEmbeddingsServiceSettings(
+            resourceName,
+            deploymentId,
+            apiVersion,
+            dims,
+            randomBoolean(),
+            maxInputTokens,
+            null
+        );
+    }
+
+    public void testFromMap_Request_CreatesSettingsCorrectly() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+        var dims = 1536;
+        var maxInputTokens = 512;
+        var serviceSettings = AzureOpenAiEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AzureOpenAiServiceFields.RESOURCE_NAME,
+                    resourceName,
+                    AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                    deploymentId,
+                    AzureOpenAiServiceFields.API_VERSION,
+                    apiVersion,
+                    ServiceFields.DIMENSIONS,
+                    dims,
+                    ServiceFields.MAX_INPUT_TOKENS,
+                    maxInputTokens,
+                    SIMILARITY,
+                    SimilarityMeasure.COSINE.toString()
+                )
+            ),
+            ConfigurationParseContext.REQUEST
+        );
+
+        assertThat(
+            serviceSettings,
+            is(
+                new AzureOpenAiEmbeddingsServiceSettings(
+                    resourceName,
+                    deploymentId,
+                    apiVersion,
+                    dims,
+                    true,
+                    maxInputTokens,
+                    SimilarityMeasure.COSINE
+                )
+            )
+        );
+    }
+
+    public void testFromMap_Request_DimensionsSetByUser_IsFalse_WhenDimensionsAreNotPresent() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+        var maxInputTokens = 512;
+        var serviceSettings = AzureOpenAiEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AzureOpenAiServiceFields.RESOURCE_NAME,
+                    resourceName,
+                    AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                    deploymentId,
+                    AzureOpenAiServiceFields.API_VERSION,
+                    apiVersion,
+                    ServiceFields.MAX_INPUT_TOKENS,
+                    maxInputTokens
+                )
+            ),
+            ConfigurationParseContext.REQUEST
+        );
+
+        assertThat(
+            serviceSettings,
+            is(new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, false, maxInputTokens, null))
+        );
+    }
+
+    public void testFromMap_Request_DimensionsSetByUser_ShouldThrowWhenPresent() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+        var maxInputTokens = 512;
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(
+                new HashMap<>(
+                    Map.of(
+                        AzureOpenAiServiceFields.RESOURCE_NAME,
+                        resourceName,
+                        AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                        deploymentId,
+                        AzureOpenAiServiceFields.API_VERSION,
+                        apiVersion,
+                        ServiceFields.MAX_INPUT_TOKENS,
+                        maxInputTokens,
+                        ServiceFields.DIMENSIONS,
+                        1024,
+                        DIMENSIONS_SET_BY_USER,
+                        false
+                    )
+                ),
+                ConfigurationParseContext.REQUEST
+            )
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            containsString(
+                Strings.format("Validation Failed: 1: [service_settings] does not allow the setting [%s];", DIMENSIONS_SET_BY_USER)
+            )
+        );
+    }
+
+    public void testFromMap_Persistent_CreatesSettingsCorrectly() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+        var encodingFormat = "float";
+        var dims = 1536;
+        var maxInputTokens = 512;
+
+        var serviceSettings = AzureOpenAiEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AzureOpenAiServiceFields.RESOURCE_NAME,
+                    resourceName,
+                    AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                    deploymentId,
+                    AzureOpenAiServiceFields.API_VERSION,
+                    apiVersion,
+                    ServiceFields.DIMENSIONS,
+                    dims,
+                    DIMENSIONS_SET_BY_USER,
+                    false,
+                    ServiceFields.MAX_INPUT_TOKENS,
+                    maxInputTokens,
+                    SIMILARITY,
+                    SimilarityMeasure.DOT_PRODUCT.toString()
+                )
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(
+            serviceSettings,
+            is(
+                new AzureOpenAiEmbeddingsServiceSettings(
+                    resourceName,
+                    deploymentId,
+                    apiVersion,
+                    dims,
+                    false,
+                    maxInputTokens,
+                    SimilarityMeasure.DOT_PRODUCT
+                )
+            )
+        );
+    }
+
+    public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIsNull() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+
+        var settings = AzureOpenAiEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AzureOpenAiServiceFields.RESOURCE_NAME,
+                    resourceName,
+                    AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                    deploymentId,
+                    AzureOpenAiServiceFields.API_VERSION,
+                    apiVersion,
+                    DIMENSIONS_SET_BY_USER,
+                    true
+                )
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(settings, is(new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, true, null, null)));
+    }
+
+    public void testFromMap_PersistentContext_DoesNotThrowException_WhenSimilarityIsPresent() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+
+        var settings = AzureOpenAiEmbeddingsServiceSettings.fromMap(
+            new HashMap<>(
+                Map.of(
+                    AzureOpenAiServiceFields.RESOURCE_NAME,
+                    resourceName,
+                    AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                    deploymentId,
+                    AzureOpenAiServiceFields.API_VERSION,
+                    apiVersion,
+                    DIMENSIONS_SET_BY_USER,
+                    true,
+                    SIMILARITY,
+                    SimilarityMeasure.COSINE.toString()
+                )
+            ),
+            ConfigurationParseContext.PERSISTENT
+        );
+
+        assertThat(
+            settings,
+            is(new AzureOpenAiEmbeddingsServiceSettings(resourceName, deploymentId, apiVersion, null, true, null, SimilarityMeasure.COSINE))
+        );
+    }
+
+    public void testFromMap_PersistentContext_ThrowsException_WhenDimensionsSetByUserIsNull() {
+        var resourceName = "this-resource";
+        var deploymentId = "this-deployment";
+        var apiVersion = "2024-01-01";
+
+        var exception = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(
+                new HashMap<>(
+                    Map.of(
+                        AzureOpenAiServiceFields.RESOURCE_NAME,
+                        resourceName,
+                        AzureOpenAiServiceFields.DEPLOYMENT_ID,
+                        deploymentId,
+                        AzureOpenAiServiceFields.API_VERSION,
+                        apiVersion,
+                        ServiceFields.DIMENSIONS,
+                        1
+                    )
+                ),
+                ConfigurationParseContext.PERSISTENT
+            )
+        );
+
+        assertThat(
+            exception.getMessage(),
+            containsString("Validation Failed: 1: [service_settings] does not contain the required setting [dimensions_set_by_user];")
+        );
+    }
+
+    public void testToXContent_WritesDimensionsSetByUserTrue() throws IOException {
+        var entity = new AzureOpenAiEmbeddingsServiceSettings("resource", "deployment", "apiVersion", null, true, null, null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, CoreMatchers.is("""
+            {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """
+            "dimensions_set_by_user":true}"""));
+    }
+
+    public void testToXContent_WritesAllValues() throws IOException {
+        var entity = new AzureOpenAiEmbeddingsServiceSettings("resource", "deployment", "apiVersion", 1024, false, 512, null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        entity.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, CoreMatchers.is("""
+            {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """
+            "dimensions":1024,"max_input_tokens":512,"dimensions_set_by_user":false}"""));
+    }
+
+    public void testToFilteredXContent_WritesAllValues_ExceptDimensionsSetByUser() throws IOException {
+        var entity = new AzureOpenAiEmbeddingsServiceSettings("resource", "deployment", "apiVersion", 1024, false, 512, null);
+
+        XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
+        var filteredXContent = entity.getFilteredXContentObject();
+        filteredXContent.toXContent(builder, null);
+        String xContentResult = Strings.toString(builder);
+
+        assertThat(xContentResult, CoreMatchers.is("""
+            {"resource_name":"resource","deployment_id":"deployment","api_version":"apiVersion",""" + """
+            "dimensions":1024,"max_input_tokens":512}"""));
+    }
+
+    @Override
+    protected Writeable.Reader<AzureOpenAiEmbeddingsServiceSettings> instanceReader() {
+        return AzureOpenAiEmbeddingsServiceSettings::new;
+    }
+
+    @Override
+    protected AzureOpenAiEmbeddingsServiceSettings createTestInstance() {
+        return createRandom();
+    }
+
+    @Override
+    protected AzureOpenAiEmbeddingsServiceSettings mutateInstance(AzureOpenAiEmbeddingsServiceSettings instance) throws IOException {
+        return createRandom();
+    }
+
+    public static Map<String, Object> getPersistentAzureOpenAiServiceSettingsMap(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable Integer dimensions,
+        @Nullable Integer maxInputTokens
+    ) {
+        var map = new HashMap<String, Object>();
+
+        map.put(AzureOpenAiServiceFields.RESOURCE_NAME, resourceName);
+        map.put(AzureOpenAiServiceFields.DEPLOYMENT_ID, deploymentId);
+        map.put(AzureOpenAiServiceFields.API_VERSION, apiVersion);
+
+        if (dimensions != null) {
+            map.put(ServiceFields.DIMENSIONS, dimensions);
+            map.put(DIMENSIONS_SET_BY_USER, true);
+        } else {
+            map.put(DIMENSIONS_SET_BY_USER, false);
+        }
+
+        if (maxInputTokens != null) {
+            map.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens);
+        }
+
+        return map;
+    }
+
+    public static Map<String, Object> getRequestAzureOpenAiServiceSettingsMap(
+        String resourceName,
+        String deploymentId,
+        String apiVersion,
+        @Nullable Integer dimensions,
+        @Nullable Integer maxInputTokens
+    ) {
+        var map = new HashMap<String, Object>();
+
+        map.put(AzureOpenAiServiceFields.RESOURCE_NAME, resourceName);
+        map.put(AzureOpenAiServiceFields.DEPLOYMENT_ID, deploymentId);
+        map.put(AzureOpenAiServiceFields.API_VERSION, apiVersion);
+
+        if (dimensions != null) {
+            map.put(ServiceFields.DIMENSIONS, dimensions);
+        }
+
+        if (maxInputTokens != null) {
+            map.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens);
+        }
+
+        return map;
+    }
+}
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java
new file mode 100644
index 0000000000000..cc2d8b9b67620
--- /dev/null
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsTaskSettingsTests.java
@@ -0,0 +1,107 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.inference.services.azureopenai.embeddings;
+
+import org.elasticsearch.common.Strings;
+import org.elasticsearch.common.ValidationException;
+import org.elasticsearch.common.io.stream.Writeable;
+import org.elasticsearch.core.Nullable;
+import org.elasticsearch.test.AbstractWireSerializingTestCase;
+import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields;
+import org.hamcrest.MatcherAssert;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.hamcrest.Matchers.is;
+
+public class AzureOpenAiEmbeddingsTaskSettingsTests extends AbstractWireSerializingTestCase<AzureOpenAiEmbeddingsTaskSettings> {
+
+    public static AzureOpenAiEmbeddingsTaskSettings createRandomWithUser() {
+        return new AzureOpenAiEmbeddingsTaskSettings(randomAlphaOfLength(15));
+    }
+
+    /**
+     * The created settings can have the user set to null.
+     */
+    public static AzureOpenAiEmbeddingsTaskSettings createRandom() {
+        var user = randomBoolean() ? randomAlphaOfLength(15) : null;
+        return new AzureOpenAiEmbeddingsTaskSettings(user);
+    }
+
+    public void testFromMap_WithUser() {
+        assertEquals(
+            new AzureOpenAiEmbeddingsTaskSettings("user"),
+            AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user")))
+        );
+    }
+
+    public void testFromMap_UserIsEmptyString() {
+        var thrownException = expectThrows(
+            ValidationException.class,
+            () -> AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "")))
+        );
+
+        MatcherAssert.assertThat(
+            thrownException.getMessage(),
+            is(Strings.format("Validation Failed: 1: [task_settings] Invalid value empty string. [user] must be a non-empty string;"))
+        );
+    }
+
+    public void testFromMap_MissingUser_DoesNotThrowException() {
+        var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of()));
+        assertNull(taskSettings.user());
+    }
+
+    public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() {
+        var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user")));
+
+        var overriddenTaskSettings = AzureOpenAiEmbeddingsTaskSettings.of(
+            taskSettings,
+            AzureOpenAiEmbeddingsRequestTaskSettings.EMPTY_SETTINGS
+        );
+        MatcherAssert.assertThat(overriddenTaskSettings, is(taskSettings));
+    }
+
+    public void testOverrideWith_UsesOverriddenSettings() {
+        var taskSettings = AzureOpenAiEmbeddingsTaskSettings.fromMap(new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user")));
+
+        var requestTaskSettings = AzureOpenAiEmbeddingsRequestTaskSettings.fromMap(
+            new HashMap<>(Map.of(AzureOpenAiServiceFields.USER, "user2"))
+        );
+
+        var overriddenTaskSettings = AzureOpenAiEmbeddingsTaskSettings.of(taskSettings, requestTaskSettings);
+        MatcherAssert.assertThat(overriddenTaskSettings, is(new AzureOpenAiEmbeddingsTaskSettings("user2")));
+    }
+
+    @Override
+    protected Writeable.Reader<AzureOpenAiEmbeddingsTaskSettings> instanceReader() {
+        return AzureOpenAiEmbeddingsTaskSettings::new;
+    }
+
+    @Override
+    protected AzureOpenAiEmbeddingsTaskSettings createTestInstance() {
+        return createRandomWithUser();
+    }
+
+    @Override
+    protected AzureOpenAiEmbeddingsTaskSettings mutateInstance(AzureOpenAiEmbeddingsTaskSettings instance) throws IOException {
+        return createRandomWithUser();
+    }
+
+    public static Map<String, Object> getAzureOpenAiRequestTaskSettingsMap(@Nullable String user) {
+        var map = new HashMap<String, Object>();
+
+        if (user != null) {
+            map.put(AzureOpenAiServiceFields.USER, user);
+        }
+
+        return map;
+    }
+}
From 9810b65d9603e142be43cc5cc8824006f087208b Mon Sep 17 00:00:00 2001
From: Nhat Nguyen
Date: Fri, 12 Apr 2024 14:24:49 -0700
Subject: [PATCH 19/20] Use response executor when handling failure in field-caps (#107370)

The transport service may call onFailure with a thread pool other than
search_coordinator. This change adds a workaround to ensure that the
merging of field-caps always occurs on the search_coordinator.
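For illustration only, here is a minimal self-contained sketch of the
forking pattern this change applies. ForkingOnFailureActionListener in the
diff below is the real class added by this patch; every name in this sketch
is hypothetical:

    import java.util.concurrent.Executor;
    import java.util.function.Consumer;

    // Delivers failures on a designated executor while passing responses
    // through unchanged, mirroring the listener added in this change.
    final class ForkingOnFailure<T> {
        private final Executor executor; // e.g. the search_coordinator pool
        private final Consumer<T> onResponse;
        private final Consumer<Exception> onFailure;

        ForkingOnFailure(Executor executor, Consumer<T> onResponse, Consumer<Exception> onFailure) {
            this.executor = executor;
            this.onResponse = onResponse;
            this.onFailure = onFailure;
        }

        void handleResponse(T response) {
            onResponse.accept(response); // responses already arrive on the right pool
        }

        void handleFailure(Exception e) {
            executor.execute(() -> onFailure.accept(e)); // fork failures back to it
        }
    }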
---
 docs/changelog/107370.yaml                    |  5 +++++
 .../fieldcaps/CCSFieldCapabilitiesIT.java     | 20 ++++++++++++++++++
 .../TransportFieldCapabilitiesAction.java     | 21 ++++++++++++++++++-
 .../RemoteClusterSecurityEsqlIT.java          |  3 ---
 4 files changed, 45 insertions(+), 4 deletions(-)
 create mode 100644 docs/changelog/107370.yaml

diff --git a/docs/changelog/107370.yaml b/docs/changelog/107370.yaml
new file mode 100644
index 0000000000000..e7bdeef68cffe
--- /dev/null
+++ b/docs/changelog/107370.yaml
@@ -0,0 +1,5 @@
+pr: 107370
+summary: Fork when handling remote field-caps responses
+area: Search
+type: bug
+issues: []
diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java
index afc62323ca544..08c4d2aab4bc9 100644
--- a/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java
+++ b/server/src/internalClusterTest/java/org/elasticsearch/search/fieldcaps/CCSFieldCapabilitiesIT.java
@@ -25,8 +25,10 @@
 import java.util.List;
 
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.hamcrest.Matchers.arrayContaining;
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.hamcrest.Matchers.equalTo;
+import static org.hamcrest.Matchers.hasSize;
 
 public class CCSFieldCapabilitiesIT extends AbstractMultiClustersTestCase {
 
@@ -35,6 +37,11 @@ protected Collection<String> remoteClusterAlias() {
         return List.of("remote_cluster");
     }
 
+    @Override
+    protected boolean reuseClusters() {
+        return false;
+    }
+
     @Override
     protected Collection<Class<? extends Plugin>> nodePlugins(String clusterAlias) {
         final List<Class<? extends Plugin>> plugins = new ArrayList<>(super.nodePlugins(clusterAlias));
@@ -105,4 +112,17 @@ public void testFailuresFromRemote() {
         assertEquals(IllegalArgumentException.class, ex.getClass());
         assertEquals("I throw because I choose to.", ex.getMessage());
     }
+
+    public void testFailedToConnectToRemoteCluster() throws Exception {
+        String localIndex = "local_index";
+        assertAcked(client(LOCAL_CLUSTER).admin().indices().prepareCreate(localIndex));
+        client(LOCAL_CLUSTER).prepareIndex(localIndex).setId("1").setSource("foo", "bar").get();
+        client(LOCAL_CLUSTER).admin().indices().prepareRefresh(localIndex).get();
+        cluster("remote_cluster").close();
+        FieldCapabilitiesResponse response = client().prepareFieldCaps("*", "remote_cluster:*").setFields("*").get();
+        assertThat(response.getIndices(), arrayContaining(localIndex));
+        List<FieldCapabilitiesFailure> failures = response.getFailures();
+        assertThat(failures, hasSize(1));
+        assertThat(failures.get(0).getIndices(), arrayContaining("remote_cluster:*"));
+    }
 }
diff --git a/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java b/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java
index e6acaba8307f6..7a8ea12568006 100644
--- a/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java
+++ b/server/src/main/java/org/elasticsearch/action/fieldcaps/TransportFieldCapabilitiesAction.java
@@ -15,6 +15,7 @@
 import org.elasticsearch.action.ActionType;
 import org.elasticsearch.action.OriginalIndices;
 import org.elasticsearch.action.RemoteClusterActionType;
+import org.elasticsearch.action.support.AbstractThreadedActionListener;
 import org.elasticsearch.action.support.ActionFilters;
 import org.elasticsearch.action.support.ChannelActionListener;
 import org.elasticsearch.action.support.HandledTransportAction;
@@ -252,7 +253,14 @@ private void doExecuteForked(Task task, FieldCapabilitiesRequest request, final
             remoteClusterClient.execute(
                 TransportFieldCapabilitiesAction.REMOTE_TYPE,
                 remoteRequest,
-                ActionListener.releaseAfter(remoteListener, refs.acquire())
+                // The underlying transport service may call onFailure with a thread pool other than search_coordinator.
+                // This fork is a workaround to ensure that the merging of field-caps always occurs on the search_coordinator.
+                // TODO: remove this workaround after we fixed https://github.com/elastic/elasticsearch/issues/107439
+                new ForkingOnFailureActionListener<>(
+                    searchCoordinationExecutor,
+                    true,
+                    ActionListener.releaseAfter(remoteListener, refs.acquire())
+                )
             );
         }
     }
@@ -569,4 +577,15 @@ public void messageReceived(FieldCapabilitiesNodeRequest request, TransportChann
             });
         }
     }
+
+    private static class ForkingOnFailureActionListener<Response> extends AbstractThreadedActionListener<Response> {
+        ForkingOnFailureActionListener(Executor executor, boolean forceExecution, ActionListener<Response> delegate) {
+            super(executor, forceExecution, delegate);
+        }
+
+        @Override
+        public void onResponse(Response response) {
+            delegate.onResponse(response);
+        }
+    }
 }
diff --git a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java
index 4c7e96c26b7d6..2c393ea7ed1df 100644
--- a/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java
+++ b/x-pack/plugin/security/qa/multi-cluster/src/javaRestTest/java/org/elasticsearch/xpack/remotecluster/RemoteClusterSecurityEsqlIT.java
@@ -494,9 +494,6 @@ public void testCrossClusterQueryWithRemoteDLSAndFLS() throws Exception {
         assertThat(flatList, containsInAnyOrder("engineering"));
     }
 
-    @SuppressWarnings("unchecked")
-    @AwaitsFix(bugUrl = "this trips ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH_COORDINATION)")
-    // comment out those assertions in EsqlIndexResolver and TransportFieldCapabilitiesAction to see this test pass
     public void testCrossClusterQueryAgainstInvalidRemote() throws Exception {
         configureRemoteCluster();
         populateData();
From adb6aa91ee67e31463fdc16e557dd4238a6ef188 Mon Sep 17 00:00:00 2001
From: Athena Brown
Date: Fri, 12 Apr 2024 16:36:11 -0600
Subject: [PATCH 20/20] Add extension point for file-based role validation (#107177)

This commit adds an extension point for Security extensions to add a
file-based role validator, which will be used to apply rules to the
roles provided in roles.yml. If this validation fails, the system will
treat it like the roles YAML file could not be read: if this happens at
startup, the node will fail to start, while if the roles.yml file is
changed at runtime and fails validation, the failure will be logged and
the active roles not updated. There is no change to the default
behavior of Elasticsearch.
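As an illustration of the new extension point, a Security extension might
supply a validator along these lines (a hypothetical sketch: the class name
and prefix rule are invented, while FileRoleValidator, RoleDescriptor, and
ActionRequestValidationException are the real types this patch touches):

    package org.example.security; // hypothetical extension package

    import org.elasticsearch.action.ActionRequestValidationException;
    import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
    import org.elasticsearch.xpack.security.authz.FileRoleValidator;

    // Rejects file-based roles whose names carry a reserved prefix.
    // Returning null means the role passes validation.
    public class ReservedPrefixRoleValidator implements FileRoleValidator {
        @Override
        public ActionRequestValidationException validatePredefinedRole(RoleDescriptor roleDescriptor) {
            if (roleDescriptor.getName().startsWith("_managed_")) {
                ActionRequestValidationException e = new ActionRequestValidationException();
                e.addValidationError("file-based role names may not start with [_managed_]");
                return e;
            }
            return null;
        }
    }

A non-null return is treated like an unreadable roles.yml, giving the
startup and reload failure semantics described above.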
---
 .../xpack/security/Security.java              |  10 +-
 .../security/authz/FileRoleValidator.java     |  29 +++++
 .../security/authz/store/FileRolesStore.java  | 107 +++++++++---------
 .../authz/store/FileRolesStoreTests.java      |  56 ++++++---
 4 files changed, 133 insertions(+), 69 deletions(-)
 create mode 100644 x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/FileRoleValidator.java

diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
index f4457dcbbfaa9..837c58ab6542d 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/Security.java
@@ -296,6 +296,7 @@
 import org.elasticsearch.xpack.security.authz.AuthorizationDenialMessages;
 import org.elasticsearch.xpack.security.authz.AuthorizationService;
 import org.elasticsearch.xpack.security.authz.DlsFlsRequestCacheDifferentiator;
+import org.elasticsearch.xpack.security.authz.FileRoleValidator;
 import org.elasticsearch.xpack.security.authz.ReservedRoleNameChecker;
 import org.elasticsearch.xpack.security.authz.SecuritySearchOperationListener;
 import org.elasticsearch.xpack.security.authz.accesscontrol.OptOutQueryCache;
@@ -588,6 +589,7 @@ public class Security extends Plugin
     private final SetOnce<List<ReloadableSecurityComponent>> reloadableComponents = new SetOnce<>();
     private final SetOnce<AuthorizationDenialMessages> authorizationDenialMessages = new SetOnce<>();
     private final SetOnce<ReservedRoleNameChecker.Factory> reservedRoleNameCheckerFactory = new SetOnce<>();
+    private final SetOnce<FileRoleValidator> fileRoleValidator = new SetOnce<>();
     private final SetOnce<SecondaryAuthActions> secondaryAuthActions = new SetOnce<>();
 
     public Security(Settings settings) {
@@ -828,7 +830,6 @@ Collection<Object> createComponents(
         dlsBitsetCache.set(new DocumentSubsetBitsetCache(settings, threadPool));
         final FieldPermissionsCache fieldPermissionsCache = new FieldPermissionsCache(settings);
 
-        this.fileRolesStore.set(new FileRolesStore(settings, environment, resourceWatcherService, getLicenseState(), xContentRegistry));
         final NativeRolesStore nativeRolesStore = new NativeRolesStore(
             settings,
             client,
@@ -859,6 +860,12 @@ Collection<Object> createComponents(
         if (reservedRoleNameCheckerFactory.get() == null) {
             reservedRoleNameCheckerFactory.set(new ReservedRoleNameChecker.Factory.Default());
         }
+        if (fileRoleValidator.get() == null) {
+            fileRoleValidator.set(new FileRoleValidator.Default());
+        }
+        this.fileRolesStore.set(
+            new FileRolesStore(settings, environment, resourceWatcherService, getLicenseState(), xContentRegistry, fileRoleValidator.get())
+        );
         final ReservedRoleNameChecker reservedRoleNameChecker = reservedRoleNameCheckerFactory.get().create(fileRolesStore.get()::exists);
         components.add(new PluginComponentBinding<>(ReservedRoleNameChecker.class, reservedRoleNameChecker));
 
@@ -2118,6 +2125,7 @@ public void loadExtensions(ExtensionLoader loader) {
         loadSingletonExtensionAndSetOnce(loader, hasPrivilegesRequestBuilderFactory, HasPrivilegesRequestBuilderFactory.class);
         loadSingletonExtensionAndSetOnce(loader, authorizationDenialMessages, AuthorizationDenialMessages.class);
         loadSingletonExtensionAndSetOnce(loader, reservedRoleNameCheckerFactory, ReservedRoleNameChecker.Factory.class);
+        loadSingletonExtensionAndSetOnce(loader, fileRoleValidator, FileRoleValidator.class);
         loadSingletonExtensionAndSetOnce(loader, secondaryAuthActions, SecondaryAuthActions.class);
     }
 
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/FileRoleValidator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/FileRoleValidator.java
new file mode 100644
index 0000000000000..9f4705d34b320
--- /dev/null
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/FileRoleValidator.java
@@ -0,0 +1,29 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0; you may not use this file except in compliance with the Elastic License
+ * 2.0.
+ */
+
+package org.elasticsearch.xpack.security.authz;
+
+import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.xpack.core.security.authz.RoleDescriptor;
+
+/**
+ * Provides a check which will be applied to roles in the file-based roles store.
+ */
+@FunctionalInterface
+public interface FileRoleValidator {
+    ActionRequestValidationException validatePredefinedRole(RoleDescriptor roleDescriptor);
+
+    /**
+     * The default file role validator used in stateful Elasticsearch, a no-op.
+     */
+    class Default implements FileRoleValidator {
+        @Override
+        public ActionRequestValidationException validatePredefinedRole(RoleDescriptor roleDescriptor) {
+            return null;
+        }
+    }
+}
diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/FileRolesStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/FileRolesStore.java
index d7c8f11c467f2..d769e44f2d38d 100644
--- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/FileRolesStore.java
+++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/FileRolesStore.java
@@ -12,6 +12,7 @@
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.ElasticsearchParseException;
 import org.elasticsearch.action.ActionListener;
+import org.elasticsearch.action.ActionRequestValidationException;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.common.util.set.Sets;
@@ -36,13 +37,13 @@
 import org.elasticsearch.xpack.core.security.authz.support.DLSRoleQueryValidator;
 import org.elasticsearch.xpack.core.security.support.NoOpLogger;
 import org.elasticsearch.xpack.core.security.support.Validation;
+import org.elasticsearch.xpack.security.authz.FileRoleValidator;
 
 import java.io.IOException;
 import java.nio.charset.StandardCharsets;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.ArrayList;
-import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -55,7 +56,9 @@
 import java.util.stream.Collectors;
 
 import static java.util.Collections.emptyMap;
+import static java.util.Collections.emptySet;
 import static java.util.Collections.unmodifiableMap;
+import static java.util.Collections.unmodifiableSet;
 import static org.elasticsearch.core.Strings.format;
 import static org.elasticsearch.xpack.core.security.SecurityField.DOCUMENT_LEVEL_SECURITY_FEATURE;
 
@@ -67,6 +70,7 @@ public class FileRolesStore implements BiConsumer<Set<String>, ActionListener<RoleRetrievalResult>> {
 
     private final Settings settings;
     private final Path file;
     private final XPackLicenseState licenseState;
     private final NamedXContentRegistry xContentRegistry;
+    private final FileRoleValidator roleValidator;
     private final List<Consumer<Set<String>>> listeners = new ArrayList<>();
 
@@ -78,9 +82,10 @@ public FileRolesStore(
         Environment env,
         ResourceWatcherService watcherService,
         XPackLicenseState licenseState,
-        NamedXContentRegistry xContentRegistry
+        NamedXContentRegistry xContentRegistry,
+        FileRoleValidator roleValidator
     ) throws IOException {
-        this(settings, env, watcherService, null, licenseState, xContentRegistry);
+        this(settings, env, watcherService, null, roleValidator, licenseState, xContentRegistry);
     }
 
     FileRolesStore(
@@ -88,11 +93,13 @@ public FileRolesStore(
         Environment env,
         ResourceWatcherService watcherService,
         Consumer<Set<String>> listener,
+        FileRoleValidator roleValidator,
         XPackLicenseState licenseState,
         NamedXContentRegistry xContentRegistry
     ) throws IOException {
         this.settings = settings;
         this.file = resolveFile(env);
+        this.roleValidator = roleValidator;
         if (listener != null) {
             listeners.add(listener);
         }
@@ -101,7 +108,7 @@ public FileRolesStore(
         FileWatcher watcher = new FileWatcher(file.getParent());
         watcher.addListener(new FileListener());
         watcherService.add(watcher, ResourceWatcherService.Frequency.HIGH);
-        permissions = parseFile(file, logger, settings, licenseState, xContentRegistry);
+        permissions = parseFile(file, logger, settings, licenseState, xContentRegistry, roleValidator);
     }
 
     @Override
@@ -176,27 +183,45 @@ public static Path resolveFile(Environment env) {
     }
 
     public static Set<String> parseFileForRoleNames(Path path, Logger logger) {
-        // EMPTY is safe here because we never use namedObject as we are just parsing role names
-        return parseRoleDescriptors(path, logger, false, Settings.EMPTY, NamedXContentRegistry.EMPTY).keySet();
-    }
+        if (logger == null) {
+            logger = NoOpLogger.INSTANCE;
+        }
+
+        Map<String, RoleDescriptor> roles = new HashMap<>();
+        logger.trace("attempting to read roles file located at [{}]", path.toAbsolutePath());
+        if (Files.exists(path)) {
+            try {
+                List<String> roleSegments = roleSegments(path);
+                for (String segment : roleSegments) {
+                    RoleDescriptor rd = parseRoleDescriptor(
+                        segment,
+                        path,
+                        logger,
+                        false,
+                        Settings.EMPTY,
+                        NamedXContentRegistry.EMPTY,
+                        new FileRoleValidator.Default()
+                    );
+                    if (rd != null) {
+                        roles.put(rd.getName(), rd);
+                    }
+                }
+            } catch (IOException ioe) {
+                logger.error(() -> format("failed to read roles file [%s]. skipping all roles...", path.toAbsolutePath()), ioe);
+                return emptySet();
+            }
+        }
+        return unmodifiableSet(roles.keySet());
 
-    public static Map<String, RoleDescriptor> parseFile(
-        Path path,
-        Logger logger,
-        Settings settings,
-        XPackLicenseState licenseState,
-        NamedXContentRegistry xContentRegistry
-    ) {
-        return parseFile(path, logger, true, settings, licenseState, xContentRegistry);
-    }
+    }
 
     public static Map<String, RoleDescriptor> parseFile(
         Path path,
         Logger logger,
-        boolean resolvePermission,
         Settings settings,
         XPackLicenseState licenseState,
-        NamedXContentRegistry xContentRegistry
+        NamedXContentRegistry xContentRegistry,
+        FileRoleValidator roleValidator
     ) {
         if (logger == null) {
             logger = NoOpLogger.INSTANCE;
@@ -210,7 +235,7 @@ public static Map<String, RoleDescriptor> parseFile(
         final boolean isDlsLicensed = DOCUMENT_LEVEL_SECURITY_FEATURE.checkWithoutTracking(licenseState);
 
         for (String segment : roleSegments) {
-            RoleDescriptor descriptor = parseRoleDescriptor(segment, path, logger, resolvePermission, settings, xContentRegistry);
+            RoleDescriptor descriptor = parseRoleDescriptor(segment, path, logger, true, settings, xContentRegistry, roleValidator);
             if (descriptor != null) {
                 if (ReservedRolesStore.isReserved(descriptor.getName())) {
                     logger.warn(
@@ -243,36 +268,6 @@ public static Map<String, RoleDescriptor> parseFile(
         return unmodifiableMap(roles);
     }
 
-    public static Map<String, RoleDescriptor> parseRoleDescriptors(
-        Path path,
-        Logger logger,
-        boolean resolvePermission,
-        Settings settings,
-        NamedXContentRegistry xContentRegistry
-    ) {
-        if (logger == null) {
-            logger = NoOpLogger.INSTANCE;
-        }
-
-        Map<String, RoleDescriptor> roles = new HashMap<>();
-        logger.trace("attempting to read roles file located at [{}]", path.toAbsolutePath());
-        if (Files.exists(path)) {
-            try {
-                List<String> roleSegments = roleSegments(path);
-                for (String segment : roleSegments) {
-                    RoleDescriptor rd = parseRoleDescriptor(segment, path, logger, resolvePermission, settings, xContentRegistry);
-                    if (rd != null) {
-                        roles.put(rd.getName(), rd);
-                    }
-                }
-            } catch (IOException ioe) {
-                logger.error(() -> format("failed to read roles file [%s]. skipping all roles...", path.toAbsolutePath()), ioe);
-                return emptyMap();
-            }
-        }
-        return unmodifiableMap(roles);
-    }
-
     @Nullable
     static RoleDescriptor parseRoleDescriptor(
         String segment,
@@ -280,7 +275,8 @@ static RoleDescriptor parseRoleDescriptor(
         Path path,
         Logger logger,
         boolean resolvePermissions,
        Settings settings,
-        NamedXContentRegistry xContentRegistry
+        NamedXContentRegistry xContentRegistry,
+        FileRoleValidator roleValidator
     ) {
         String roleName = null;
         XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withRegistry(xContentRegistry)
@@ -311,7 +307,7 @@ static RoleDescriptor parseRoleDescriptor(
                     // we pass true as last parameter because we do not want to reject files if field permissions
                     // are given in 2.x syntax
                     RoleDescriptor descriptor = RoleDescriptor.parse(roleName, parser, true, false);
-                    return checkDescriptor(descriptor, path, logger, settings, xContentRegistry);
+                    return checkDescriptor(descriptor, path, logger, settings, xContentRegistry, roleValidator);
                 } else {
                     logger.error("invalid role definition [{}] in roles file [{}]. skipping role...", roleName, path.toAbsolutePath());
                     return null;
@@ -344,7 +340,8 @@ private static RoleDescriptor checkDescriptor(
         Path path,
         Logger logger,
         Settings settings,
-        NamedXContentRegistry xContentRegistry
+        NamedXContentRegistry xContentRegistry,
+        FileRoleValidator roleValidator
     ) {
         String roleName = descriptor.getName();
         // first check if FLS/DLS is enabled on the role...
@@ -374,6 +371,10 @@ private static RoleDescriptor checkDescriptor(
                 }
             }
         }
+        ActionRequestValidationException ex = roleValidator.validatePredefinedRole(descriptor);
+        if (ex != null) {
+            throw ex;
+        }
         return descriptor;
     }
 
@@ -417,7 +418,7 @@ public synchronized void onFileChanged(Path file) {
             if (file.equals(FileRolesStore.this.file)) {
                 final Map<String, RoleDescriptor> previousPermissions = permissions;
                 try {
-                    permissions = parseFile(file, logger, settings, licenseState, xContentRegistry);
+                    permissions = parseFile(file, logger, settings, licenseState, xContentRegistry, roleValidator);
                 } catch (Exception e) {
                     logger.error(
                         () -> format("could not reload roles file [%s]. Current roles remain unmodified", file.toAbsolutePath()),
@@ -431,7 +432,7 @@ public synchronized void onFileChanged(Path file) {
                     .map(Map.Entry::getKey)
                     .collect(Collectors.toSet());
                 final Set<String> addedRoles = Sets.difference(permissions.keySet(), previousPermissions.keySet());
-                final Set<String> changedRoles = Collections.unmodifiableSet(Sets.union(changedOrMissingRoles, addedRoles));
+                final Set<String> changedRoles = unmodifiableSet(Sets.union(changedOrMissingRoles, addedRoles));
                 if (changedRoles.isEmpty() == false) {
                     logger.info("updated roles (roles file [{}] {})", file.toAbsolutePath(), Files.exists(file) ? "changed" : "removed");
                     listeners.forEach(c -> c.accept(changedRoles));
diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/FileRolesStoreTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/FileRolesStoreTests.java
index 0f9dd06983792..65f2919541e07 100644
--- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/FileRolesStoreTests.java
+++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/FileRolesStoreTests.java
@@ -39,6 +39,7 @@
 import org.elasticsearch.xpack.core.security.authz.privilege.IndexPrivilege;
 import org.elasticsearch.xpack.core.security.authz.store.ReservedRolesStore;
 import org.elasticsearch.xpack.core.security.support.Automatons;
+import org.elasticsearch.xpack.security.authz.FileRoleValidator;
 import org.junit.BeforeClass;
 
 import java.io.BufferedWriter;
@@ -104,7 +105,8 @@ public void testParseFile() throws Exception {
             logger,
             Settings.builder().put(XPackSettings.DLS_FLS_ENABLED.getKey(), true).build(),
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         assertThat(roles, notNullValue());
         assertThat(roles.size(), is(10));
@@ -295,7 +297,8 @@ public void testParseFileWithRemoteIndices() throws IllegalAccessException, IOException {
             logger,
             Settings.builder().put(XPackSettings.DLS_FLS_ENABLED.getKey(), true).build(),
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         assertThat(roles, notNullValue());
         assertThat(roles.size(), is(2));
@@ -359,7 +362,8 @@ public void testParseFileWithFLSAndDLSDisabled() throws Exception {
             logger,
             Settings.builder().put(XPackSettings.DLS_FLS_ENABLED.getKey(), false).build(),
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         assertThat(roles, notNullValue());
         assertThat(roles.size(), is(7));
@@ -410,7 +414,14 @@ public void testParseFileWithFLSAndDLSUnlicensed() throws Exception {
         events.clear();
         MockLicenseState licenseState = mock(MockLicenseState.class);
         when(licenseState.isAllowed(DOCUMENT_LEVEL_SECURITY_FEATURE)).thenReturn(false);
-        Map<String, RoleDescriptor> roles = FileRolesStore.parseFile(path, logger, Settings.EMPTY, licenseState, xContentRegistry());
+        Map<String, RoleDescriptor> roles = FileRolesStore.parseFile(
+            path,
+            logger,
+            Settings.EMPTY,
+            licenseState,
+            xContentRegistry(),
+            new FileRoleValidator.Default()
+        );
         assertThat(roles, notNullValue());
         assertThat(roles.size(), is(10));
         assertNotNull(roles.get("role_fields"));
@@ -445,7 +456,8 @@ public void testDefaultRolesFile() throws Exception {
             logger,
             Settings.EMPTY,
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         assertThat(roles, notNullValue());
         assertThat(roles.size(), is(0));
@@ -474,7 +486,7 @@ public void testAutoReload() throws Exception {
         FileRolesStore store = new FileRolesStore(settings, env, watcherService, roleSet -> {
             modifiedRoles.addAll(roleSet);
             latch.countDown();
-        }, TestUtils.newTestLicenseState(), xContentRegistry());
+        }, new FileRoleValidator.Default(), TestUtils.newTestLicenseState(), xContentRegistry());
 
         Set<RoleDescriptor> descriptors = store.roleDescriptors(Collections.singleton("role1"));
         assertThat(descriptors, notNullValue());
@@ -534,7 +546,7 @@ public void testAutoReload() throws Exception {
             if (roleSet.contains("dummy1")) {
                 truncateLatch.countDown();
             }
-        }, TestUtils.newTestLicenseState(), xContentRegistry());
+        }, new FileRoleValidator.Default(), TestUtils.newTestLicenseState(), xContentRegistry());
 
         final Set<String> allRolesPreTruncate = store.getAllRoleNames();
         assertTrue(allRolesPreTruncate.contains("role5"));
@@ -563,7 +575,7 @@ public void testAutoReload() throws Exception {
             if (roleSet.contains("dummy2")) {
                 modifyLatch.countDown();
             }
-        }, TestUtils.newTestLicenseState(), xContentRegistry());
+        }, new FileRoleValidator.Default(), TestUtils.newTestLicenseState(), xContentRegistry());
 
         try (BufferedWriter writer = Files.newBufferedWriter(tmp, StandardCharsets.UTF_8, StandardOpenOption.TRUNCATE_EXISTING)) {
             writer.append("role5:").append(System.lineSeparator());
@@ -596,7 +608,8 @@ public void testThatEmptyFileDoesNotResultInLoop() throws Exception {
             logger,
             Settings.EMPTY,
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         assertThat(roles.keySet(), is(empty()));
     }
@@ -611,7 +624,8 @@ public void testThatInvalidRoleDefinitions() throws Exception {
             logger,
             Settings.EMPTY,
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         assertThat(roles.size(), is(1));
         assertThat(roles, hasKey("valid_role"));
@@ -660,7 +674,8 @@ public void testReservedRoles() throws Exception {
             logger,
             Settings.EMPTY,
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         assertThat(roles, notNullValue());
         assertThat(roles.size(), is(2));
@@ -696,7 +711,8 @@ public void testUsageStats() throws Exception {
             env,
             mock(ResourceWatcherService.class),
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
 
         Map<String, Object> usageStats = store.usageStats();
@@ -723,14 +739,16 @@ public void testExists() throws Exception {
             env,
             mock(ResourceWatcherService.class),
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         Map<String, RoleDescriptor> roles = FileRolesStore.parseFile(
             path,
             logger,
             Settings.builder().put(XPackSettings.DLS_FLS_ENABLED.getKey(), true).build(),
             TestUtils.newTestLicenseState(),
-            xContentRegistry()
+            xContentRegistry(),
+            new FileRoleValidator.Default()
         );
         assertThat(roles, notNullValue());
         assertThat(roles.size(), is(10));
@@ -745,7 +763,15 @@ public void testBWCFieldPermissions() throws IOException {
         Path path = getDataPath("roles2xformat.yml");
         byte[] bytes = Files.readAllBytes(path);
         String roleString = new String(bytes, Charset.defaultCharset());
-        RoleDescriptor role = FileRolesStore.parseRoleDescriptor(roleString, path, logger, true, Settings.EMPTY, xContentRegistry());
+        RoleDescriptor role = FileRolesStore.parseRoleDescriptor(
+            roleString,
+            path,
+            logger,
+            true,
+            Settings.EMPTY,
+            xContentRegistry(),
+            new FileRoleValidator.Default()
+        );
         RoleDescriptor.IndicesPrivileges indicesPrivileges = role.getIndicesPrivileges()[0];
         assertThat(indicesPrivileges.getGrantedFields(), arrayContaining("foo", "boo"));
         assertNull(indicesPrivileges.getDeniedFields());