Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ClusterByStatisticsCollectorImpl to use bytes instead of keys #12998

Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
*/
public class StageDefinition
{
private static final int PARTITION_STATS_MAX_KEYS = 2 << 15; // Avoid immediate downsample of single-bucket collectors
private static final int PARTITION_STATS_MAX_BYTES = 10_000_000; // Avoid immediate downsample of single-bucket collectors
private static final int PARTITION_STATS_MAX_BUCKETS = 5_000; // Limit for TooManyBuckets
private static final int MAX_PARTITIONS = 25_000; // Limit for TooManyPartitions

Expand Down Expand Up @@ -289,7 +289,7 @@ public ClusterByStatisticsCollector createResultKeyStatisticsCollector()
return ClusterByStatisticsCollectorImpl.create(
shuffleSpec.getClusterBy(),
signature,
PARTITION_STATS_MAX_KEYS,
PARTITION_STATS_MAX_BYTES,
PARTITION_STATS_MAX_BUCKETS,
shuffleSpec.doesAggregateByClusterKey(),
shuffleCheckHasMultipleValues
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,39 +56,37 @@ public class ClusterByStatisticsCollectorImpl implements ClusterByStatisticsColl

private final boolean[] hasMultipleValues;

// This can be reworked to accommodate maxSize instead of maxRetainedKeys to account for the skewness in the size of the
// keys depending on the datasource
private final int maxRetainedKeys;
private final int maxRetainedBytes;
private final int maxBuckets;
private int totalRetainedKeys;
private double totalRetainedBytes;

private ClusterByStatisticsCollectorImpl(
final ClusterBy clusterBy,
final RowKeyReader keyReader,
final KeyCollectorFactory<?, ?> keyCollectorFactory,
final int maxRetainedKeys,
final int maxRetainedBytes,
final int maxBuckets,
final boolean checkHasMultipleValues
)
{
this.clusterBy = clusterBy;
this.keyReader = keyReader;
this.keyCollectorFactory = keyCollectorFactory;
this.maxRetainedKeys = maxRetainedKeys;
this.maxRetainedBytes = maxRetainedBytes;
this.buckets = new TreeMap<>(clusterBy.bucketComparator());
this.maxBuckets = maxBuckets;
this.checkHasMultipleValues = checkHasMultipleValues;
this.hasMultipleValues = checkHasMultipleValues ? new boolean[clusterBy.getColumns().size()] : null;

if (maxBuckets > maxRetainedKeys) {
throw new IAE("maxBuckets[%s] cannot be larger than maxRetainedKeys[%s]", maxBuckets, maxRetainedKeys);
if (maxBuckets > maxRetainedBytes) {
throw new IAE("maxBuckets[%s] cannot be larger than maxRetainedBytes[%s]", maxBuckets, maxRetainedBytes);
}
}

public static ClusterByStatisticsCollector create(
final ClusterBy clusterBy,
final RowSignature signature,
final int maxRetainedKeys,
final int maxRetainedBytes,
final int maxBuckets,
final boolean aggregate,
final boolean checkHasMultipleValues
Expand All @@ -101,7 +99,7 @@ public static ClusterByStatisticsCollector create(
clusterBy,
keyReader,
keyCollectorFactory,
maxRetainedKeys,
maxRetainedBytes,
maxBuckets,
checkHasMultipleValues
);
Expand All @@ -126,8 +124,8 @@ public ClusterByStatisticsCollector add(final RowKey key, final int weight)

bucketHolder.keyCollector.add(key, weight);

totalRetainedKeys += bucketHolder.updateRetainedKeys();
if (totalRetainedKeys > maxRetainedKeys) {
totalRetainedBytes += bucketHolder.updateRetainedBytes();
if (totalRetainedBytes > maxRetainedBytes) {
downSample();
}

Expand All @@ -147,15 +145,15 @@ public ClusterByStatisticsCollector addAll(final ClusterByStatisticsCollector ot
//noinspection rawtypes, unchecked
((KeyCollector) bucketHolder.keyCollector).addAll(otherBucketEntry.getValue().keyCollector);

totalRetainedKeys += bucketHolder.updateRetainedKeys();
if (totalRetainedKeys > maxRetainedKeys) {
totalRetainedBytes += bucketHolder.updateRetainedBytes();
if (totalRetainedBytes > maxRetainedBytes) {
downSample();
}
}

if (checkHasMultipleValues) {
for (int i = 0; i < clusterBy.getColumns().size(); i++) {
hasMultipleValues[i] |= that.hasMultipleValues[i];
hasMultipleValues[i] = hasMultipleValues[i] || that.hasMultipleValues[i];
}
}
} else {
Expand All @@ -178,8 +176,8 @@ public ClusterByStatisticsCollector addAll(final ClusterByStatisticsSnapshot sna
//noinspection rawtypes, unchecked
((KeyCollector) bucketHolder.keyCollector).addAll(otherKeyCollector);

totalRetainedKeys += bucketHolder.updateRetainedKeys();
if (totalRetainedKeys > maxRetainedKeys) {
totalRetainedBytes += bucketHolder.updateRetainedBytes();
if (totalRetainedBytes > maxRetainedBytes) {
downSample();
}
}
Expand Down Expand Up @@ -221,7 +219,7 @@ public boolean hasMultipleValues(final int keyPosition)
public ClusterByStatisticsCollector clear()
{
buckets.clear();
totalRetainedKeys = 0;
totalRetainedBytes = 0;
return this;
}

Expand Down Expand Up @@ -365,75 +363,75 @@ private BucketHolder getOrCreateBucketHolder(final RowKey bucketKey)
}

/**
* Reduce the number of retained keys by about half, if possible. May reduce by less than that, or keep the
* Reduce the number of retained bytes by about half, if possible. May reduce by less than that, or keep the
* number the same, if downsampling is not possible. (For example: downsampling is not possible if all buckets
* have been downsampled all the way to one key each.)
*/
private void downSample()
{
int newTotalRetainedKeys = totalRetainedKeys;
final int targetTotalRetainedKeys = totalRetainedKeys / 2;
double newTotalRetainedBytes = totalRetainedBytes;
final double targetTotalRetainedBytes = totalRetainedBytes / 2;

final List<BucketHolder> sortedHolders = new ArrayList<>(buckets.size());

// Only consider holders with more than one retained key. Holders with a single retained key cannot be downsampled.
for (final BucketHolder holder : buckets.values()) {
if (holder.retainedKeys > 1) {
if (holder.keyCollector.estimatedRetainedKeys() > 1) {
sortedHolders.add(holder);
}
}

// Downsample least-dense buckets first. (They're less likely to need high resolution.)
sortedHolders.sort(
Comparator.comparing((BucketHolder holder) ->
(double) holder.keyCollector.estimatedTotalWeight() / holder.retainedKeys)
(double) holder.keyCollector.estimatedTotalWeight() / holder.retainedBytes)
adarshsanjeev marked this conversation as resolved.
Show resolved Hide resolved
);

int i = 0;
while (i < sortedHolders.size() && newTotalRetainedKeys > targetTotalRetainedKeys) {
while (i < sortedHolders.size() && newTotalRetainedBytes > targetTotalRetainedBytes) {
final BucketHolder bucketHolder = sortedHolders.get(i);

// Ignore false return, because we wrap all collectors in DelegateOrMinKeyCollector and can be assured that
// it will downsample all the way to one if needed. Can't do better than that.
bucketHolder.keyCollector.downSample();
newTotalRetainedKeys += bucketHolder.updateRetainedKeys();
newTotalRetainedBytes += bucketHolder.updateRetainedBytes();

if (i == sortedHolders.size() - 1 || sortedHolders.get(i + 1).retainedKeys > bucketHolder.retainedKeys) {
if (i == sortedHolders.size() - 1 || sortedHolders.get(i + 1).retainedBytes > bucketHolder.retainedBytes) {
i++;
}
}

totalRetainedKeys = newTotalRetainedKeys;
totalRetainedBytes = newTotalRetainedBytes;
}

private void assertRetainedKeyCountsAreTrackedCorrectly()
adarshsanjeev marked this conversation as resolved.
Show resolved Hide resolved
{
// Check cached value of retainedBytes in each holder.
assert buckets.values()
.stream()
.allMatch(holder -> holder.retainedKeys == holder.keyCollector.estimatedRetainedKeys());
.allMatch(holder -> holder.retainedBytes == holder.keyCollector.estimatedRetainedBytes());

// Check cached value of totalRetainedKeys.
assert totalRetainedKeys ==
buckets.values().stream().mapToInt(holder -> holder.keyCollector.estimatedRetainedKeys()).sum();
// Check cached value of totalRetainedBytes.
assert totalRetainedBytes ==
buckets.values().stream().mapToDouble(holder -> holder.keyCollector.estimatedRetainedBytes()).sum();
}

private static class BucketHolder
{
private final KeyCollector<?> keyCollector;
private int retainedKeys;
private double retainedBytes;

public BucketHolder(final KeyCollector<?> keyCollector)
{
this.keyCollector = keyCollector;
this.retainedKeys = keyCollector.estimatedRetainedKeys();
this.retainedBytes = keyCollector.estimatedRetainedBytes();
}

public int updateRetainedKeys()
public double updateRetainedBytes()
{
final int newRetainedKeys = keyCollector.estimatedRetainedKeys();
final int difference = newRetainedKeys - retainedKeys;
retainedKeys = newRetainedKeys;
final double newRetainedKeys = keyCollector.estimatedRetainedBytes();
adarshsanjeev marked this conversation as resolved.
Show resolved Hide resolved
final double difference = newRetainedKeys - retainedBytes;
retainedBytes = newRetainedKeys;
return difference;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ public int estimatedRetainedKeys()
}
}

@Override
public double estimatedRetainedBytes()
{
  // No delegate yet: the only retained state is the (optional) minimum key, so report its encoded length.
  if (delegate == null) {
    return minKey == null ? 0 : minKey.array().length;
  }
  // Otherwise the delegate collector owns all retained keys; forward the estimate to it.
  return delegate.estimatedRetainedBytes();
}

@Override
public boolean downSample()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
*/
public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
{
static final int INITIAL_MAX_KEYS = 2 << 15 /* 65,536 */;
static final int SMALLEST_MAX_KEYS = 16;
static final int INITIAL_MAX_BYTES = 5_120_000;
adarshsanjeev marked this conversation as resolved.
Show resolved Hide resolved
static final int SMALLEST_MAX_BYTES = 5000;
adarshsanjeev marked this conversation as resolved.
Show resolved Hide resolved
private static final int MISSING_KEY_WEIGHT = 0;

private final Comparator<RowKey> comparator;
Expand All @@ -71,7 +71,8 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
* collector type, which is based on a more solid statistical foundation.
*/
private final Object2LongSortedMap<RowKey> retainedKeys;
private int maxKeys;
private int maxBytes;
private int retainedBytes;

/**
* Each key is retained with probability 2^(-spaceReductionFactor). This value is incremented on calls to
Expand All @@ -92,7 +93,7 @@ public class DistinctKeyCollector implements KeyCollector<DistinctKeyCollector>
this.comparator = Preconditions.checkNotNull(comparator, "comparator");
this.retainedKeys = Preconditions.checkNotNull(retainedKeys, "retainedKeys");
this.retainedKeys.defaultReturnValue(MISSING_KEY_WEIGHT);
this.maxKeys = INITIAL_MAX_KEYS;
this.maxBytes = INITIAL_MAX_BYTES;
this.spaceReductionFactor = spaceReductionFactor;
this.totalWeightUnadjusted = 0;

Expand Down Expand Up @@ -120,14 +121,16 @@ public void add(RowKey key, long weight)
if (isNewMin && !retainedKeys.isEmpty() && !isKeySelected(retainedKeys.firstKey())) {
// Old min should be kicked out.
totalWeightUnadjusted -= retainedKeys.removeLong(retainedKeys.firstKey());
retainedBytes -= retainedKeys.firstKey().array().length;
adarshsanjeev marked this conversation as resolved.
Show resolved Hide resolved
}

if (retainedKeys.putIfAbsent(key, weight) == MISSING_KEY_WEIGHT) {
// We did add this key. (Previous value was zero, meaning absent.)
totalWeightUnadjusted += weight;
retainedBytes += key.array().length;
}

while (retainedKeys.size() >= maxKeys) {
while (retainedBytes >= maxBytes) {
increaseSpaceReductionFactorIfPossible();
}
}
Expand Down Expand Up @@ -168,6 +171,12 @@ public int estimatedRetainedKeys()
return retainedKeys.size();
}

@Override
public double estimatedRetainedBytes()
{
  // retainedBytes is maintained incrementally as keys are added and evicted,
  // so this is O(1) rather than a re-scan of the retained-key map.
  return retainedBytes;
}

@Override
public RowKey minKey()
{
Expand All @@ -182,13 +191,13 @@ public boolean downSample()
return true;
}

if (maxKeys == SMALLEST_MAX_KEYS) {
if (maxBytes <= SMALLEST_MAX_BYTES) {
return false;
}

maxKeys /= 2;
maxBytes /= 2;

while (retainedKeys.size() >= maxKeys) {
while (retainedBytes >= maxBytes) {
if (!increaseSpaceReductionFactorIfPossible()) {
return false;
}
Expand Down Expand Up @@ -242,10 +251,10 @@ Map<RowKey, Long> getRetainedKeys()
return retainedKeys;
}

@JsonProperty("maxKeys")
int getMaxKeys()
@JsonProperty("maxBytes")
int getMaxBytes()
{
return maxKeys;
return maxBytes;
}

@JsonProperty("spaceReductionFactor")
Expand Down Expand Up @@ -296,6 +305,7 @@ private boolean increaseSpaceReductionFactorIfPossible()

if (!isKeySelected(key)) {
totalWeightUnadjusted -= entry.getLongValue();
retainedBytes -= entry.getKey().array().length;
iterator.remove();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ public interface KeyCollector<CollectorType extends KeyCollector<CollectorType>>
*/
int estimatedRetainedKeys();

/**
* Returns an estimate of the number of bytes currently retained by this collector. This may change over time as
* more keys are added.
*/
double estimatedRetainedBytes();

/**
* Downsample this collector, dropping about half of the keys that are currently retained. Returns true if
* the collector was downsampled, or if it is already retaining zero or one keys. Returns false if the collector is
Expand Down
Loading