diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/BinPackingNodeAllocatorService.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/BinPackingNodeAllocatorService.java
index c40d27fbb19dc..151f94b5ed347 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/BinPackingNodeAllocatorService.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/BinPackingNodeAllocatorService.java
@@ -512,7 +512,7 @@ public BinPackingSimulation(
                     continue;
                 }
                 long nodeReservedMemory = preReservedMemory.getOrDefault(node.getNodeIdentifier(), 0L);
-                nodesRemainingMemory.put(node.getNodeIdentifier(), memoryPoolInfo.getMaxBytes() - nodeReservedMemory);
+                nodesRemainingMemory.put(node.getNodeIdentifier(), max(memoryPoolInfo.getMaxBytes() - nodeReservedMemory, 0L));
             }
 
             nodesRemainingMemoryRuntimeAdjusted = new HashMap<>();
@@ -540,7 +540,7 @@ public BinPackingSimulation(
                 // if globally reported memory usage of node is greater than computed one lets use that.
                 // it can be greater if there are tasks executed on cluster which do not have task retries enabled.
                 nodeUsedMemoryRuntimeAdjusted = max(nodeUsedMemoryRuntimeAdjusted, memoryPoolInfo.getReservedBytes());
-                nodesRemainingMemoryRuntimeAdjusted.put(node.getNodeIdentifier(), memoryPoolInfo.getMaxBytes() - nodeUsedMemoryRuntimeAdjusted);
+                nodesRemainingMemoryRuntimeAdjusted.put(node.getNodeIdentifier(), max(memoryPoolInfo.getMaxBytes() - nodeUsedMemoryRuntimeAdjusted, 0L));
             }
         }
 
@@ -610,10 +610,10 @@ private void subtractFromRemainingMemory(String nodeIdentifier, long memoryLease
         {
             nodesRemainingMemoryRuntimeAdjusted.compute(
                     nodeIdentifier,
-                    (key, free) -> free - memoryLease);
+                    (key, free) -> max(free - memoryLease, 0));
             nodesRemainingMemory.compute(
                     nodeIdentifier,
-                    (key, free) -> free - memoryLease);
+                    (key, free) -> max(free - memoryLease, 0));
         }
 
         private boolean isNodeEmpty(String nodeIdentifier)
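
Reviewer note: all three hunks above apply the same saturating subtraction, clamping computed free memory at zero. This matters for the no-memory scheduling introduced below: a zero-byte lease can only pass a `lease <= remaining` style admission check if oversubscription cannot drive `remaining` negative. A minimal standalone sketch with toy values (illustrative, not the allocator's actual bookkeeping):

```java
import static java.lang.Math.max;

public class ClampExample
{
    public static void main(String[] args)
    {
        long pool = 8L << 30;     // 8 GB memory pool on a node
        long reserved = 9L << 30; // node already oversubscribed by 1 GB

        long raw = pool - reserved;              // -1 GB remaining
        long clamped = max(pool - reserved, 0L); // 0 remaining after the clamp

        long zeroLease = 0L; // what a no-memory task asks for
        System.out.println(zeroLease <= raw);     // false: a zero-byte task would be refused
        System.out.println(zeroLease <= clamped); // true: a zero-byte task still fits a full node
    }
}
```
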
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java
index 85f00d4b28c42..98e9eab5651ae 100644
--- a/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/EventDrivenFaultTolerantQueryScheduler.java
@@ -37,6 +37,9 @@
 import io.airlift.units.Duration;
 import io.opentelemetry.api.trace.Tracer;
 import io.trino.Session;
+import io.trino.connector.informationschema.InformationSchemaTableHandle;
+import io.trino.connector.system.GlobalSystemConnector;
+import io.trino.connector.system.SystemTableHandle;
 import io.trino.exchange.SpoolingExchangeInput;
 import io.trino.execution.BasicStageStats;
 import io.trino.execution.ExecutionFailureInfo;
@@ -85,12 +88,14 @@
 import io.trino.sql.planner.PlanFragmentIdAllocator;
 import io.trino.sql.planner.PlanNodeIdAllocator;
 import io.trino.sql.planner.SubPlan;
+import io.trino.sql.planner.optimizations.PlanNodeSearcher;
 import io.trino.sql.planner.plan.AggregationNode;
 import io.trino.sql.planner.plan.PlanFragmentId;
 import io.trino.sql.planner.plan.PlanNode;
 import io.trino.sql.planner.plan.PlanNodeId;
 import io.trino.sql.planner.plan.RefreshMaterializedViewNode;
 import io.trino.sql.planner.plan.RemoteSourceNode;
+import io.trino.sql.planner.plan.TableScanNode;
 import it.unimi.dsi.fastutil.ints.Int2ObjectMap;
 import it.unimi.dsi.fastutil.ints.Int2ObjectOpenHashMap;
 import it.unimi.dsi.fastutil.ints.IntOpenHashSet;
@@ -1150,6 +1155,7 @@ private void createStageExecution(SubPlan subPlan, boolean rootFragment, Map<St
                 return;
             }
 
+            boolean noMemoryFragment = isNoMemoryFragment(fragment);
             StageExecution execution = new StageExecution(
                     queryStateMachine,
                     taskDescriptorStorage,
@@ -1165,7 +1171,8 @@ private void createStageExecution(SubPlan subPlan, boolean rootFragment, Map<St
                     taskSource,
                     sinkPartitioningScheme,
                     exchange,
-                    partitionMemoryEstimatorFactory.createPartitionMemoryEstimator(),
+                    noMemoryFragment,
+                    noMemoryFragment ? new NoMemoryPartitionMemoryEstimator() : partitionMemoryEstimatorFactory.createPartitionMemoryEstimator(),
                     maxTaskExecutionAttempts,
@@ -1219,6 +1226,29 @@ private void createStageExecution(SubPlan subPlan, boolean rootFragment, Map<St
             stageExecutions.put(execution.getStageId(), execution);
         }
 
+        private boolean isNoMemoryFragment(PlanFragment fragment)
+        {
+            // If source fragments are not all "no-memory" assume that they may produce significant amounts of data.
+            // We stay on the safe side and use standard memory estimation for this fragment.
+            if (!fragment.getRemoteSourceNodes().stream().flatMap(node -> node.getSourceFragmentIds().stream())
+                    .allMatch(sourceFragmentId -> stageExecutions.get(getStageId(sourceFragmentId)).isNoMemoryFragment())) {
+                return false;
+            }
+
+            // If fragment source is not reading any external tables or only accesses information_schema assume it does not need significant amount of memory.
+            // Allow scheduling even if whole server memory is pre allocated.
+            List<PlanNode> tableScanNodes = PlanNodeSearcher.searchFrom(fragment.getRoot()).whereIsInstanceOfAny(TableScanNode.class).findAll();
+            return tableScanNodes.stream().allMatch(node -> isMetadataTableScan((TableScanNode) node));
+        }
+
+        private static boolean isMetadataTableScan(TableScanNode tableScanNode)
+        {
+            return (tableScanNode.getTable().getConnectorHandle() instanceof InformationSchemaTableHandle) ||
+                    (tableScanNode.getTable().getCatalogHandle().getCatalogName().equals(GlobalSystemConnector.NAME) &&
+                            (tableScanNode.getTable().getConnectorHandle() instanceof SystemTableHandle systemHandle) &&
+                            systemHandle.getSchemaName().equals("jdbc"));
+        }
+
         private StageId getStageId(PlanFragmentId fragmentId)
         {
             return StageId.create(queryStateMachine.getQueryId(), fragmentId);
@@ -1523,6 +1553,7 @@ private static class StageExecution
         private final EventDrivenTaskSource taskSource;
         private final FaultTolerantPartitioningScheme sinkPartitioningScheme;
         private final Exchange exchange;
+        private final boolean noMemoryFragment;
         private final PartitionMemoryEstimator partitionMemoryEstimator;
         private final int maxTaskExecutionAttempts;
         private final int schedulingPriority;
@@ -1559,6 +1590,7 @@ private StageExecution(
                 EventDrivenTaskSource taskSource,
                 FaultTolerantPartitioningScheme sinkPartitioningScheme,
                 Exchange exchange,
+                boolean noMemoryFragment,
                 PartitionMemoryEstimator partitionMemoryEstimator,
                 int maxTaskExecutionAttempts,
                 int schedulingPriority,
@@ -1576,6 +1608,7 @@ private StageExecution(
             this.taskSource = requireNonNull(taskSource, "taskSource is null");
             this.sinkPartitioningScheme = requireNonNull(sinkPartitioningScheme, "sinkPartitioningScheme is null");
             this.exchange = requireNonNull(exchange, "exchange is null");
+            this.noMemoryFragment = noMemoryFragment;
             this.partitionMemoryEstimator = requireNonNull(partitionMemoryEstimator, "partitionMemoryEstimator is null");
             this.maxTaskExecutionAttempts = maxTaskExecutionAttempts;
             this.schedulingPriority = schedulingPriority;
@@ -1628,6 +1661,11 @@ public boolean isExchangeClosed()
             return exchangeClosed;
         }
 
+        public boolean isNoMemoryFragment()
+        {
+            return noMemoryFragment;
+        }
+
         public void addPartition(int partitionId, NodeRequirements nodeRequirements)
         {
             if (getState().isDone()) {
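
Reviewer note: the classification walks the plan bottom-up. Stages are created sources-first, so each upstream stage's no-memory flag is already recorded when a downstream fragment is checked, and a fragment qualifies only if every upstream fragment qualified and every table scan it contains reads metadata (information_schema, or system.jdbc in the global system catalog). A toy model of that decision, using made-up record types rather than Trino's PlanFragment/TableScanNode (illustration only):

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

record ToyScan(String catalog, String schema) {}

record ToyFragment(int id, List<Integer> sourceFragmentIds, List<ToyScan> tableScans) {}

public class NoMemoryFragmentExample
{
    // mirrors isMetadataTableScan: information_schema and system.jdbc scans count as metadata
    static boolean isMetadataScan(ToyScan scan)
    {
        return scan.schema().equals("information_schema") ||
                (scan.catalog().equals("system") && scan.schema().equals("jdbc"));
    }

    // mirrors isNoMemoryFragment: all upstream fragments must already be "no-memory"
    // and every table scan in this fragment must be a metadata scan
    static boolean isNoMemory(ToyFragment fragment, Map<Integer, Boolean> decided)
    {
        return fragment.sourceFragmentIds().stream().allMatch(id -> decided.getOrDefault(id, false))
                && fragment.tableScans().stream().allMatch(NoMemoryFragmentExample::isMetadataScan);
    }

    public static void main(String[] args)
    {
        Map<Integer, Boolean> decided = new HashMap<>();
        // leaf fragment scans only system.jdbc tables (e.g. backing a SHOW TABLES)
        ToyFragment leaf = new ToyFragment(1, List.of(), List.of(new ToyScan("system", "jdbc")));
        decided.put(leaf.id(), isNoMemory(leaf, decided));
        // root fragment only aggregates the leaf's output, no scans of its own
        ToyFragment root = new ToyFragment(0, List.of(1), List.of());
        System.out.println(isNoMemory(root, decided)); // true -> whole query can bypass the memory wait
    }
}
```
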
diff --git a/core/trino-main/src/main/java/io/trino/execution/scheduler/NoMemoryPartitionMemoryEstimator.java b/core/trino-main/src/main/java/io/trino/execution/scheduler/NoMemoryPartitionMemoryEstimator.java
new file mode 100644
index 0000000000000..78b62761040d1
--- /dev/null
+++ b/core/trino-main/src/main/java/io/trino/execution/scheduler/NoMemoryPartitionMemoryEstimator.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.execution.scheduler;
+
+import io.airlift.units.DataSize;
+import io.trino.Session;
+import io.trino.spi.ErrorCode;
+
+import java.util.Optional;
+
+public class NoMemoryPartitionMemoryEstimator
+        implements PartitionMemoryEstimator
+{
+    @Override
+    public MemoryRequirements getInitialMemoryRequirements(Session session, DataSize defaultMemoryLimit)
+    {
+        return new MemoryRequirements(DataSize.ofBytes(0));
+    }
+
+    @Override
+    public MemoryRequirements getNextRetryMemoryRequirements(Session session, MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, ErrorCode errorCode)
+    {
+        return new MemoryRequirements(DataSize.ofBytes(0));
+    }
+
+    @Override
+    public void registerPartitionFinished(Session session, MemoryRequirements previousMemoryRequirements, DataSize peakMemoryUsage, boolean success, Optional<ErrorCode> errorCode) {}
+}
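
Reviewer note: both estimate methods pin the requirement at zero bytes and registerPartitionFinished deliberately learns nothing, so a retried partition of a no-memory fragment asks for zero again instead of escalating the way the regular estimator does. A quick driver sketch (assumes MemoryRequirements exposes getRequiredMemory() as the PartitionMemoryEstimator interface in this codebase defines; the null Session is acceptable only because the estimator ignores it):

```java
import io.airlift.units.DataSize;
import io.trino.execution.scheduler.NoMemoryPartitionMemoryEstimator;
import io.trino.execution.scheduler.PartitionMemoryEstimator.MemoryRequirements;

public class NoMemoryEstimatorDemo
{
    public static void main(String[] args)
    {
        NoMemoryPartitionMemoryEstimator estimator = new NoMemoryPartitionMemoryEstimator();

        // initial ask is zero bytes regardless of the configured default limit
        MemoryRequirements initial = estimator.getInitialMemoryRequirements(null, DataSize.of(4, DataSize.Unit.GIGABYTE));

        // even after a failure with observed peak usage, the retry ask stays at zero
        MemoryRequirements retry = estimator.getNextRetryMemoryRequirements(null, initial, DataSize.of(1, DataSize.Unit.GIGABYTE), null);

        System.out.println(initial.getRequiredMemory()); // 0B
        System.out.println(retry.getRequiredMemory());   // 0B
    }
}
```
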
diff --git a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestDistributedFaultTolerantEngineOnlyQueries.java b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestDistributedFaultTolerantEngineOnlyQueries.java
index f970e177a4d74..f108ed8098589 100644
--- a/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestDistributedFaultTolerantEngineOnlyQueries.java
+++ b/testing/trino-faulttolerant-tests/src/test/java/io/trino/faulttolerant/TestDistributedFaultTolerantEngineOnlyQueries.java
@@ -14,10 +14,14 @@
 package io.trino.faulttolerant;
 
 import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.MoreCollectors;
+import io.trino.Session;
 import io.trino.connector.MockConnectorFactory;
 import io.trino.connector.MockConnectorPlugin;
+import io.trino.execution.QueryState;
 import io.trino.plugin.exchange.filesystem.FileSystemExchangePlugin;
 import io.trino.plugin.memory.MemoryQueryRunner;
+import io.trino.server.BasicQueryInfo;
 import io.trino.testing.AbstractDistributedEngineOnlyQueries;
 import io.trino.testing.DistributedQueryRunner;
 import io.trino.testing.FaultTolerantExecutionConnectorTestHelper;
@@ -25,8 +29,14 @@
 import io.trino.tpch.TpchTable;
 import org.testng.annotations.Test;
 
+import java.util.Optional;
+import java.util.concurrent.ExecutorService;
+
 import static io.airlift.testing.Closeables.closeAllSuppress;
 import static io.trino.testing.TestingNames.randomNameSuffix;
+import static io.trino.testing.assertions.Assert.assertEventually;
+import static java.util.concurrent.Executors.newCachedThreadPool;
+import static org.assertj.core.api.Assertions.assertThat;
 
 public class TestDistributedFaultTolerantEngineOnlyQueries
         extends AbstractDistributedEngineOnlyQueries
@@ -97,4 +107,66 @@ t2 AS (
 
         assertUpdate("DROP TABLE " + tableName);
     }
+
+    @Test(timeOut = 30_000)
+    public void testMetadataOnlyQueries()
+            throws InterruptedException
+    {
+        // enforce single task uses whole node
+        Session highTaskMemorySession = Session.builder(getSession())
+                .setSystemProperty("fault_tolerant_execution_coordinator_task_memory", "500GB")
+                .setSystemProperty("fault_tolerant_execution_task_memory", "500GB")
+                .build();
+
+        ExecutorService backgroundExecutor = newCachedThreadPool();
+        try {
+            String longQuery = "select count(*) long_query_count FROM lineitem l1 cross join lineitem l2 cross join lineitem l3 where l1.orderkey * l2.orderkey * l3.orderkey = 1";
+            backgroundExecutor.submit(() -> {
+                query(highTaskMemorySession, longQuery);
+            });
+            assertEventually(() -> queryIsInState(longQuery, QueryState.RUNNING));
+
+            assertThat(query("DESCRIBE lineitem")).succeeds();
+            assertThat(query("SHOW TABLES")).succeeds();
+            assertThat(query("SHOW TABLES LIKE 'line%'")).succeeds();
+            assertThat(query("SHOW SCHEMAS")).succeeds();
+            assertThat(query("SHOW SCHEMAS LIKE 'def%'")).succeeds();
+            assertThat(query("SHOW CATALOGS")).succeeds();
+            assertThat(query("SHOW CATALOGS LIKE 'mem%'")).succeeds();
+            assertThat(query("SHOW FUNCTIONS")).succeeds();
+            assertThat(query("SHOW FUNCTIONS LIKE 'split%'")).succeeds();
+            assertThat(query("SHOW COLUMNS FROM lineitem")).succeeds();
+            assertThat(query("SHOW SESSION")).succeeds();
+            assertThat(query("SELECT count(*) FROM information_schema.tables")).succeeds();
+            assertThat(query("SELECT * FROM system.jdbc.tables WHERE table_schem LIKE 'def%'")).succeeds();
+
+            // check non-metadata queries still wait for resources
+            String nonMetadataQuery = "select count(*) non_metadata_query_count from nation";
+            backgroundExecutor.submit(() -> {
+                query(nonMetadataQuery);
+            });
+            assertEventually(() -> queryIsInState(nonMetadataQuery, QueryState.STARTING));
+            Thread.sleep(1000); // wait a bit longer and query should be still STARTING
+            assertThat(queryState(nonMetadataQuery).orElseThrow()).isEqualTo(QueryState.STARTING);
+
+            // long query should be still running
+            assertThat(queryState(longQuery).orElseThrow()).isEqualTo(QueryState.RUNNING);
+        }
+        finally {
+            backgroundExecutor.shutdownNow();
+        }
+    }
+
+    private Optional<QueryState> queryState(String queryText)
+    {
+        return getDistributedQueryRunner().getCoordinator().getQueryManager().getQueries().stream()
+                .filter(query -> query.getQuery().equals(queryText))
+                .collect(MoreCollectors.toOptional())
+                .map(BasicQueryInfo::getState);
+    }
+
+    private boolean queryIsInState(String queryText, QueryState queryState)
+    {
+        return queryState(queryText).map(state -> state == queryState).orElse(false);
+    }
 }
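
Reviewer note: the test coordinates its two background queries purely by polling coordinator state, so the positive checks lean on assertEventually while the negative check (the non-metadata query staying STARTING) is only time-bounded by the Thread.sleep. For readers unfamiliar with the eventually pattern, a rough standalone model of the retry loop it provides (an assumption about its behavior, not Airlift's or Trino's actual implementation):

```java
import java.time.Duration;
import java.time.Instant;

public class EventuallyExample
{
    // retry an assertion until it passes or the deadline expires
    static void eventually(Duration timeout, Runnable assertion) throws InterruptedException
    {
        Instant deadline = Instant.now().plus(timeout);
        while (true) {
            try {
                assertion.run();
                return; // assertion passed
            }
            catch (AssertionError e) {
                if (Instant.now().isAfter(deadline)) {
                    throw e; // give up and surface the last failure
                }
                Thread.sleep(50); // back off briefly before retrying
            }
        }
    }

    public static void main(String[] args) throws InterruptedException
    {
        long start = System.nanoTime();
        eventually(Duration.ofSeconds(5), () -> {
            // stand-in for queryIsInState(longQuery, QueryState.RUNNING)
            if (System.nanoTime() - start < 200_000_000L) {
                throw new AssertionError("not running yet");
            }
        });
        System.out.println("condition reached");
    }
}
```
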