Skip to content

Commit

Permalink
Implement memory accounting for TaskDescriptor
Browse files Browse the repository at this point in the history
  • Loading branch information
arhimondr authored and losipiuk committed Feb 9, 2022
1 parent 3904620 commit 099ea40
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import org.openjdk.jol.info.ClassLayout;

import java.util.Objects;

Expand All @@ -25,6 +26,8 @@

public class Lifespan
{
private static final int INSTANCE_SIZE = ClassLayout.parseClass(Lifespan.class).instanceSize();

private static final Lifespan TASK_WIDE = new Lifespan(false, 0);

private final boolean grouped;
Expand Down Expand Up @@ -93,4 +96,9 @@ public int hashCode()
{
return Objects.hash(grouped, groupId);
}

public long getRetainedSizeInBytes()
{
return INSTANCE_SIZE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,21 @@
import com.google.common.collect.ImmutableSet;
import io.trino.connector.CatalogName;
import io.trino.spi.HostAddress;
import org.openjdk.jol.info.ClassLayout;

import java.util.Objects;
import java.util.Optional;
import java.util.Set;

import static com.google.common.base.MoreObjects.toStringHelper;
import static io.airlift.slice.SizeOf.estimatedSizeOf;
import static io.airlift.slice.SizeOf.sizeOf;
import static java.util.Objects.requireNonNull;

public class NodeRequirements
{
private static final int INSTANCE_SIZE = ClassLayout.parseClass(NodeRequirements.class).instanceSize();

private final Optional<CatalogName> catalogName;
private final Set<HostAddress> addresses;

Expand Down Expand Up @@ -78,4 +83,11 @@ public String toString()
.add("addresses", addresses)
.toString();
}

public long getRetainedSizeInBytes()
{
return INSTANCE_SIZE
+ sizeOf(catalogName, CatalogName::getRetainedSizeInBytes)
+ estimatedSizeOf(addresses, HostAddress::getRetainedSizeInBytes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.log.Logger;
Expand Down Expand Up @@ -163,7 +164,7 @@ else if (partitioning.equals(SOURCE_DISTRIBUTION)) {
public static class SingleDistributionTaskSource
implements TaskSource
{
private final Multimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles;
private final ListMultimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles;

private boolean finished;

Expand All @@ -173,17 +174,17 @@ public static SingleDistributionTaskSource create(PlanFragment fragment, Multima
return new SingleDistributionTaskSource(getInputsForRemoteSources(fragment.getRemoteSourceNodes(), exchangeSourceHandles));
}

public SingleDistributionTaskSource(Multimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles)
public SingleDistributionTaskSource(ListMultimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles)
{
this.exchangeSourceHandles = ImmutableMultimap.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null"));
this.exchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null"));
}

@Override
public List<TaskDescriptor> getMoreTasks()
{
List<TaskDescriptor> result = ImmutableList.of(new TaskDescriptor(
0,
ImmutableMultimap.of(),
ImmutableListMultimap.of(),
exchangeSourceHandles,
new NodeRequirements(Optional.empty(), ImmutableSet.of())));
finished = true;
Expand Down Expand Up @@ -380,7 +381,7 @@ public List<TaskDescriptor> getMoreTasks()
return ImmutableList.of();
}

Map<Integer, Multimap<PlanNodeId, Split>> partitionToSplitsMap = new HashMap<>();
Map<Integer, ListMultimap<PlanNodeId, Split>> partitionToSplitsMap = new HashMap<>();
Map<Integer, HostAddress> partitionToNodeMap = new HashMap<>();
for (Map.Entry<PlanNodeId, SplitSource> entry : splitSources.entrySet()) {
SplitSource splitSource = entry.getValue();
Expand Down Expand Up @@ -426,8 +427,8 @@ public List<TaskDescriptor> getMoreTasks()
int taskPartitionId = 0;
ImmutableList.Builder<TaskDescriptor> result = ImmutableList.builder();
for (Integer partition : union(partitionToSplitsMap.keySet(), partitionToExchangeSourceHandlesMap.keySet())) {
Multimap<PlanNodeId, Split> splits = partitionToSplitsMap.getOrDefault(partition, ImmutableMultimap.of());
Multimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles = ImmutableListMultimap.<PlanNodeId, ExchangeSourceHandle>builder()
ListMultimap<PlanNodeId, Split> splits = partitionToSplitsMap.getOrDefault(partition, ImmutableListMultimap.of());
ListMultimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles = ImmutableListMultimap.<PlanNodeId, ExchangeSourceHandle>builder()
.putAll(partitionToExchangeSourceHandlesMap.getOrDefault(partition, ImmutableMultimap.of()))
.putAll(replicatedExchangeSourceHandles)
.build();
Expand Down Expand Up @@ -476,7 +477,7 @@ public static class SourceDistributionTaskSource
private final PlanNodeId partitionedSourceNodeId;
private final TableExecuteContextManager tableExecuteContextManager;
private final SplitSource splitSource;
private final Multimap<PlanNodeId, ExchangeSourceHandle> replicatedExchangeSourceHandles;
private final ListMultimap<PlanNodeId, ExchangeSourceHandle> replicatedExchangeSourceHandles;
private final int splitBatchSize;
private final LongConsumer getSplitTimeRecorder;
private final Optional<CatalogName> catalogRequirement;
Expand Down Expand Up @@ -528,7 +529,7 @@ public SourceDistributionTaskSource(
PlanNodeId partitionedSourceNodeId,
TableExecuteContextManager tableExecuteContextManager,
SplitSource splitSource,
Multimap<PlanNodeId, ExchangeSourceHandle> replicatedExchangeSourceHandles,
ListMultimap<PlanNodeId, ExchangeSourceHandle> replicatedExchangeSourceHandles,
int splitBatchSize,
LongConsumer getSplitTimeRecorder,
Optional<CatalogName> catalogRequirement,
Expand Down Expand Up @@ -674,7 +675,7 @@ public void close()
}
}

private static Multimap<PlanNodeId, ExchangeSourceHandle> getReplicatedExchangeSourceHandles(PlanFragment fragment, Multimap<PlanFragmentId, ExchangeSourceHandle> handles)
private static ListMultimap<PlanNodeId, ExchangeSourceHandle> getReplicatedExchangeSourceHandles(PlanFragment fragment, Multimap<PlanFragmentId, ExchangeSourceHandle> handles)
{
return getInputsForRemoteSources(
fragment.getRemoteSourceNodes().stream()
Expand All @@ -683,7 +684,7 @@ private static Multimap<PlanNodeId, ExchangeSourceHandle> getReplicatedExchangeS
handles);
}

private static Multimap<PlanNodeId, ExchangeSourceHandle> getPartitionedExchangeSourceHandles(PlanFragment fragment, Multimap<PlanFragmentId, ExchangeSourceHandle> handles)
private static ListMultimap<PlanNodeId, ExchangeSourceHandle> getPartitionedExchangeSourceHandles(PlanFragment fragment, Multimap<PlanFragmentId, ExchangeSourceHandle> handles)
{
return getInputsForRemoteSources(
fragment.getRemoteSourceNodes().stream()
Expand All @@ -703,7 +704,7 @@ private static Map<PlanFragmentId, PlanNodeId> getSourceFragmentToRemoteSourceNo
return result.build();
}

private static Multimap<PlanNodeId, ExchangeSourceHandle> getInputsForRemoteSources(
private static ListMultimap<PlanNodeId, ExchangeSourceHandle> getInputsForRemoteSources(
List<RemoteSourceNode> remoteSources,
Multimap<PlanFragmentId, ExchangeSourceHandle> exchangeSourceHandles)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,33 +13,40 @@
*/
package io.trino.execution.scheduler;

import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.Multimap;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ListMultimap;
import io.trino.metadata.Split;
import io.trino.spi.exchange.ExchangeSourceHandle;
import io.trino.sql.planner.plan.PlanNodeId;
import org.openjdk.jol.info.ClassLayout;

import java.util.Objects;

import static com.google.common.base.MoreObjects.toStringHelper;
import static com.google.common.collect.Multimaps.asMap;
import static io.airlift.slice.SizeOf.estimatedSizeOf;
import static java.util.Objects.requireNonNull;

public class TaskDescriptor
{
private static final int INSTANCE_SIZE = ClassLayout.parseClass(TaskDescriptor.class).instanceSize();

private final int partitionId;
private final Multimap<PlanNodeId, Split> splits;
private final Multimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles;
private final ListMultimap<PlanNodeId, Split> splits;
private final ListMultimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles;
private final NodeRequirements nodeRequirements;

private transient volatile long retainedSizeInBytes;

public TaskDescriptor(
int partitionId,
Multimap<PlanNodeId, Split> splits,
Multimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles,
ListMultimap<PlanNodeId, Split> splits,
ListMultimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles,
NodeRequirements nodeRequirements)
{
this.partitionId = partitionId;
this.splits = ImmutableMultimap.copyOf(requireNonNull(splits, "splits is null"));
this.exchangeSourceHandles = ImmutableMultimap.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null"));
this.splits = ImmutableListMultimap.copyOf(requireNonNull(splits, "splits is null"));
this.exchangeSourceHandles = ImmutableListMultimap.copyOf(requireNonNull(exchangeSourceHandles, "exchangeSourceHandles is null"));
this.nodeRequirements = requireNonNull(nodeRequirements, "nodeRequirements is null");
}

Expand All @@ -48,12 +55,12 @@ public int getPartitionId()
return partitionId;
}

public Multimap<PlanNodeId, Split> getSplits()
public ListMultimap<PlanNodeId, Split> getSplits()
{
return splits;
}

public Multimap<PlanNodeId, ExchangeSourceHandle> getExchangeSourceHandles()
public ListMultimap<PlanNodeId, ExchangeSourceHandle> getExchangeSourceHandles()
{
return exchangeSourceHandles;
}
Expand Down Expand Up @@ -92,4 +99,17 @@ public String toString()
.add("nodeRequirements", nodeRequirements)
.toString();
}

public long getRetainedSizeInBytes()
{
long result = retainedSizeInBytes;
if (result == 0) {
result = INSTANCE_SIZE
+ estimatedSizeOf(asMap(splits), PlanNodeId::getRetainedSizeInBytes, splits -> estimatedSizeOf(splits, Split::getRetainedSizeInBytes))
+ estimatedSizeOf(asMap(exchangeSourceHandles), PlanNodeId::getRetainedSizeInBytes, handles -> estimatedSizeOf(handles, ExchangeSourceHandle::getRetainedSizeInBytes))
+ nodeRequirements.getRetainedSizeInBytes();
retainedSizeInBytes = result;
}
return result;
}
}
11 changes: 11 additions & 0 deletions core/trino-main/src/main/java/io/trino/metadata/Split.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.trino.spi.HostAddress;
import io.trino.spi.SplitWeight;
import io.trino.spi.connector.ConnectorSplit;
import org.openjdk.jol.info.ClassLayout;

import java.util.List;

Expand All @@ -28,6 +29,8 @@

public final class Split
{
private static final int INSTANCE_SIZE = ClassLayout.parseClass(Split.class).instanceSize();

private final CatalogName catalogName;
private final ConnectorSplit connectorSplit;
private final Lifespan lifespan;
Expand Down Expand Up @@ -90,4 +93,12 @@ public String toString()
.add("lifespan", lifespan)
.toString();
}

public long getRetainedSizeInBytes()
{
return INSTANCE_SIZE
+ catalogName.getRetainedSizeInBytes()
+ connectorSplit.getRetainedSizeInBytes()
+ lifespan.getRetainedSizeInBytes();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonValue;
import org.openjdk.jol.info.ClassLayout;

import javax.annotation.concurrent.Immutable;

import static io.airlift.slice.SizeOf.estimatedSizeOf;
import static java.util.Objects.requireNonNull;

@Immutable
public class PlanNodeId
{
private static final int INSTANCE_SIZE = ClassLayout.parseClass(PlanNodeId.class).instanceSize();

private final String id;

@JsonCreator
Expand Down Expand Up @@ -58,4 +62,10 @@ public int hashCode()
{
return id.hashCode();
}

public long getRetainedSizeInBytes()
{
return INSTANCE_SIZE
+ estimatedSizeOf(id);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import io.airlift.units.DataSize;
import io.trino.connector.CatalogName;
Expand Down Expand Up @@ -73,7 +74,7 @@ public class TestStageTaskSourceFactory
@Test
public void testSingleDistributionTaskSource()
{
Multimap<PlanNodeId, ExchangeSourceHandle> sources = ImmutableListMultimap.<PlanNodeId, ExchangeSourceHandle>builder()
ListMultimap<PlanNodeId, ExchangeSourceHandle> sources = ImmutableListMultimap.<PlanNodeId, ExchangeSourceHandle>builder()
.put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 123))
.put(PLAN_NODE_2, new TestingExchangeSourceHandle(0, 321))
.put(PLAN_NODE_1, new TestingExchangeSourceHandle(0, 222))
Expand Down Expand Up @@ -489,7 +490,7 @@ public void testSourceDistributionTaskSource()

private static SourceDistributionTaskSource createSourceDistributionTaskSource(
List<Split> splits,
Multimap<PlanNodeId, ExchangeSourceHandle> replicatedSources,
ListMultimap<PlanNodeId, ExchangeSourceHandle> replicatedSources,
int splitBatchSize,
int splitsPerTask)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Multimap;
import io.trino.Session;
import io.trino.connector.CatalogName;
Expand Down Expand Up @@ -76,7 +77,7 @@ public TaskSource create(
getHandlesForRemoteSources(fragment.getRemoteSourceNodes(), exchangeSourceHandles));
}

private static Multimap<PlanNodeId, ExchangeSourceHandle> getHandlesForRemoteSources(
private static ListMultimap<PlanNodeId, ExchangeSourceHandle> getHandlesForRemoteSources(
List<RemoteSourceNode> remoteSources,
Multimap<PlanFragmentId, ExchangeSourceHandle> exchangeSourceHandles)
{
Expand All @@ -99,7 +100,7 @@ public static class TestingTaskSource
private final Iterator<Split> splits;
private final int tasksPerBatch;
private final PlanNodeId tableScanPlanNodeId;
private final Multimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles;
private final ListMultimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles;

private final AtomicInteger nextPartitionId = new AtomicInteger();

Expand All @@ -108,7 +109,7 @@ public TestingTaskSource(
List<Split> splits,
int tasksPerBatch,
PlanNodeId tableScanPlanNodeId,
Multimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles)
ListMultimap<PlanNodeId, ExchangeSourceHandle> exchangeSourceHandles)
{
this.catalogRequirement = requireNonNull(catalogRequirement, "catalogRequirement is null");
this.splits = ImmutableList.copyOf(requireNonNull(splits, "splits is null")).iterator();
Expand Down

0 comments on commit 099ea40

Please sign in to comment.