Skip to content

Commit

Permalink
[Backport] Make sketch encoding configurable (#17086) (#17153)
Browse files Browse the repository at this point in the history
Makes sketch encoding in MSQ configurable by the user. This would allow a user to configure the sketch encoding method for a specific query.

The default is octet stream encoding.
  • Loading branch information
adarshsanjeev authored Sep 30, 2024
1 parent a16b75a commit e364d84
Show file tree
Hide file tree
Showing 10 changed files with 99 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,8 @@ private QueryDefinition initializeQueryDefAndState(final Closer closer)
this.workerSketchFetcher = new WorkerSketchFetcher(
netClient,
workerManager,
queryKernelConfig.isFaultTolerant()
queryKernelConfig.isFaultTolerant(),
MultiStageQueryContext.getSketchEncoding(querySpec.getQuery().context())
);
closer.register(workerSketchFetcher::close);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.rpc.SketchEncoding;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;

import javax.annotation.Nullable;
Expand Down Expand Up @@ -60,23 +61,25 @@ public ListenableFuture<Void> postWorkOrder(String workerTaskId, WorkOrder workO
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
String workerTaskId,
StageId stageId
StageId stageId,
SketchEncoding sketchEncoding
)
{
return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, stageId));
return wrap(workerTaskId, client, c -> c.fetchClusterByStatisticsSnapshot(workerTaskId, stageId, sketchEncoding));
}

@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
String workerTaskId,
StageId stageId,
long timeChunk
long timeChunk,
SketchEncoding sketchEncoding
)
{
return wrap(
workerTaskId,
client,
c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, stageId, timeChunk)
c -> c.fetchClusterByStatisticsSnapshotForTimeChunk(workerTaskId, stageId, timeChunk, sketchEncoding)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.druid.msq.counters.CounterSnapshotsTree;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.WorkOrder;
import org.apache.druid.msq.rpc.SketchEncoding;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;

import java.io.Closeable;
Expand All @@ -47,7 +48,8 @@ public interface WorkerClient extends Closeable
*/
ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
String workerId,
StageId stageId
StageId stageId,
SketchEncoding sketchEncoding
);

/**
Expand All @@ -57,7 +59,8 @@ ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
String workerId,
StageId stageId,
long timeChunk
long timeChunk,
SketchEncoding sketchEncoding
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.apache.druid.msq.indexing.error.WorkerRpcFailedFault;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernel;
import org.apache.druid.msq.rpc.SketchEncoding;
import org.apache.druid.msq.statistics.ClusterByStatisticsCollector;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
Expand All @@ -57,6 +58,7 @@ public class WorkerSketchFetcher implements AutoCloseable
private static final int DEFAULT_THREAD_COUNT = 4;

private final WorkerClient workerClient;
private final SketchEncoding sketchEncoding;
private final WorkerManager workerManager;

private final boolean retryEnabled;
Expand All @@ -68,10 +70,12 @@ public class WorkerSketchFetcher implements AutoCloseable
public WorkerSketchFetcher(
WorkerClient workerClient,
WorkerManager workerManager,
boolean retryEnabled
boolean retryEnabled,
SketchEncoding sketchEncoding
)
{
this.workerClient = workerClient;
this.sketchEncoding = sketchEncoding;
this.executorService = Execs.multiThreaded(DEFAULT_THREAD_COUNT, "SketchFetcherThreadPool-%d");
this.workerManager = workerManager;
this.retryEnabled = retryEnabled;
Expand All @@ -96,7 +100,7 @@ public void inMemoryFullSketchMerging(
executorService.submit(() -> {
fetchStatsFromWorker(
kernelActions,
() -> workerClient.fetchClusterByStatisticsSnapshot(taskId, stageId),
() -> workerClient.fetchClusterByStatisticsSnapshot(taskId, stageId, sketchEncoding),
taskId,
(kernel, snapshot) ->
kernel.mergeClusterByStatisticsCollectorForAllTimeChunks(stageId, workerNumber, snapshot),
Expand Down Expand Up @@ -252,7 +256,8 @@ public void sequentialTimeChunkMerging(
() -> workerClient.fetchClusterByStatisticsSnapshotForTimeChunk(
taskId,
new StageId(stageId.getQueryId(), stageId.getStageNumber()),
timeChunk
timeChunk,
sketchEncoding
),
taskId,
(kernel, snapshot) -> kernel.mergeClusterByStatisticsCollectorForTimeChunk(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,14 +90,15 @@ public ListenableFuture<Void> postWorkOrder(String workerId, WorkOrder workOrder
@Override
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshot(
String workerId,
StageId stageId
StageId stageId,
SketchEncoding sketchEncoding
)
{
String path = StringUtils.format(
"/keyStatistics/%s/%d?sketchEncoding=%s",
StringUtils.urlEncode(stageId.getQueryId()),
stageId.getStageNumber(),
WorkerResource.SketchEncoding.OCTET_STREAM
sketchEncoding
);

return getClient(workerId).asyncRequest(
Expand All @@ -110,15 +111,16 @@ public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSna
public ListenableFuture<ClusterByStatisticsSnapshot> fetchClusterByStatisticsSnapshotForTimeChunk(
String workerId,
StageId stageId,
long timeChunk
long timeChunk,
SketchEncoding sketchEncoding
)
{
String path = StringUtils.format(
"/keyStatisticsForTimeChunk/%s/%d/%d?sketchEncoding=%s",
StringUtils.urlEncode(stageId.getQueryId()),
stageId.getStageNumber(),
timeChunk,
WorkerResource.SketchEncoding.OCTET_STREAM
sketchEncoding
);

return getClient(workerId).asyncRequest(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.msq.rpc;


import org.apache.druid.msq.statistics.serde.ClusterByStatisticsSnapshotSerde;

/**
* Determines the encoding of key collectors returned by {@link WorkerResource#httpFetchKeyStatistics} and
* {@link WorkerResource#httpFetchKeyStatisticsWithSnapshot}.
*/
public enum SketchEncoding
{
/**
* The key collector is encoded as a byte stream with {@link ClusterByStatisticsSnapshotSerde}.
*/
OCTET_STREAM,
/**
* The key collector is encoded as json
*/
JSON
}
Original file line number Diff line number Diff line change
Expand Up @@ -373,19 +373,4 @@ public Response httpGetCounters(@Context final HttpServletRequest req)
return Response.status(Response.Status.OK).entity(worker.getCounters()).build();
}

/**
* Determines the encoding of key collectors returned by {@link #httpFetchKeyStatistics} and
* {@link #httpFetchKeyStatisticsWithSnapshot}.
*/
public enum SketchEncoding
{
/**
* The key collector is encoded as a byte stream with {@link ClusterByStatisticsSnapshotSerde}.
*/
OCTET_STREAM,
/**
* The key collector is encoded as json
*/
JSON
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.druid.msq.indexing.error.MSQWarnings;
import org.apache.druid.msq.kernel.WorkerAssignmentStrategy;
import org.apache.druid.msq.rpc.ControllerResource;
import org.apache.druid.msq.rpc.SketchEncoding;
import org.apache.druid.msq.sql.MSQMode;
import org.apache.druid.query.QueryContext;
import org.apache.druid.query.QueryContexts;
Expand Down Expand Up @@ -138,6 +139,9 @@ public class MultiStageQueryContext
public static final String CTX_CLUSTER_STATISTICS_MERGE_MODE = "clusterStatisticsMergeMode";
public static final String DEFAULT_CLUSTER_STATISTICS_MERGE_MODE = ClusterStatisticsMergeMode.SEQUENTIAL.toString();

public static final String CTX_SKETCH_ENCODING_MODE = "sketchEncoding";
public static final String DEFAULT_CTX_SKETCH_ENCODING_MODE = SketchEncoding.OCTET_STREAM.toString();

public static final String CTX_ROWS_PER_SEGMENT = "rowsPerSegment";
public static final int DEFAULT_ROWS_PER_SEGMENT = 3000000;

Expand Down Expand Up @@ -265,6 +269,15 @@ public static ClusterStatisticsMergeMode getClusterStatisticsMergeMode(QueryCont
);
}

public static SketchEncoding getSketchEncoding(QueryContext queryContext)
{
return QueryContexts.getAsEnum(
CTX_SKETCH_ENCODING_MODE,
queryContext.getString(CTX_SKETCH_ENCODING_MODE, DEFAULT_CTX_SKETCH_ENCODING_MODE),
SketchEncoding.class
);
}

public static boolean isFinalizeAggregations(final QueryContext queryContext)
{
return queryContext.getBoolean(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.druid.msq.kernel.StageDefinition;
import org.apache.druid.msq.kernel.StageId;
import org.apache.druid.msq.kernel.controller.ControllerQueryKernel;
import org.apache.druid.msq.rpc.SketchEncoding;
import org.apache.druid.msq.statistics.ClusterByStatisticsSnapshot;
import org.apache.druid.msq.statistics.CompleteKeyStatisticsInformation;
import org.junit.After;
Expand Down Expand Up @@ -101,13 +102,13 @@ public void test_submitFetcherTask_parallelFetch() throws InterruptedException

final CountDownLatch latch = new CountDownLatch(TASK_IDS.size());

target = spy(new WorkerSketchFetcher(workerClient, workerManager, true));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM));

// When fetching snapshots, return a mock and add it to queue
doAnswer(invocation -> {
ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class);
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any());
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), any());

target.inMemoryFullSketchMerging((kernelConsumer) -> {
kernelConsumer.accept(kernel);
Expand All @@ -124,13 +125,13 @@ public void test_submitFetcherTask_sequentialFetch() throws InterruptedException
doReturn(true).when(completeKeyStatisticsInformation).isComplete();
final CountDownLatch latch = new CountDownLatch(TASK_IDS.size());

target = spy(new WorkerSketchFetcher(workerClient, workerManager, true));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM));

// When fetching snapshots, return a mock and add it to queue
doAnswer(invocation -> {
ClusterByStatisticsSnapshot snapshot = mock(ClusterByStatisticsSnapshot.class);
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong());
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong(), any());

target.sequentialTimeChunkMerging(
(kernelConsumer) -> {
Expand All @@ -152,7 +153,7 @@ public void test_sequentialMerge_nonCompleteInformation()
{

doReturn(false).when(completeKeyStatisticsInformation).isComplete();
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM));
Assert.assertThrows(ISE.class, () -> target.sequentialTimeChunkMerging(
(ignore) -> {},
completeKeyStatisticsInformation,
Expand All @@ -167,7 +168,7 @@ public void test_inMemoryRetryEnabled_retryInvoked() throws InterruptedException
{
final CountDownLatch latch = new CountDownLatch(TASK_IDS.size());

target = spy(new WorkerSketchFetcher(workerClient, workerManager, true));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM));

workersWithFailedFetchParallel(ImmutableSet.of(TASK_1));

Expand Down Expand Up @@ -196,7 +197,7 @@ public void test_SequentialRetryEnabled_retryInvoked() throws InterruptedExcepti
doReturn(true).when(completeKeyStatisticsInformation).isComplete();
final CountDownLatch latch = new CountDownLatch(TASK_IDS.size());

target = spy(new WorkerSketchFetcher(workerClient, workerManager, true));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, true, SketchEncoding.OCTET_STREAM));

workersWithFailedFetchSequential(ImmutableSet.of(TASK_1));
CountDownLatch retryLatch = new CountDownLatch(1);
Expand All @@ -223,7 +224,7 @@ public void test_SequentialRetryEnabled_retryInvoked() throws InterruptedExcepti
public void test_InMemoryRetryDisabled_multipleFailures() throws InterruptedException
{

target = spy(new WorkerSketchFetcher(workerClient, workerManager, false));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM));

workersWithFailedFetchParallel(ImmutableSet.of(TASK_1, TASK_0));

Expand Down Expand Up @@ -252,7 +253,7 @@ public void test_InMemoryRetryDisabled_multipleFailures() throws InterruptedExce
public void test_InMemoryRetryDisabled_singleFailure() throws InterruptedException
{

target = spy(new WorkerSketchFetcher(workerClient, workerManager, false));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM));

workersWithFailedFetchParallel(ImmutableSet.of(TASK_1));

Expand Down Expand Up @@ -283,7 +284,7 @@ public void test_SequentialRetryDisabled_multipleFailures() throws InterruptedEx
{

doReturn(true).when(completeKeyStatisticsInformation).isComplete();
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM));

workersWithFailedFetchSequential(ImmutableSet.of(TASK_1, TASK_0));

Expand Down Expand Up @@ -315,7 +316,7 @@ public void test_SequentialRetryDisabled_multipleFailures() throws InterruptedEx
public void test_SequentialRetryDisabled_singleFailure() throws InterruptedException
{
doReturn(true).when(completeKeyStatisticsInformation).isComplete();
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false));
target = spy(new WorkerSketchFetcher(workerClient, workerManager, false, SketchEncoding.OCTET_STREAM));

workersWithFailedFetchSequential(ImmutableSet.of(TASK_1));

Expand Down Expand Up @@ -352,7 +353,7 @@ private void workersWithFailedFetchSequential(Set<String> failedTasks)
return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0)));
}
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong());
}).when(workerClient).fetchClusterByStatisticsSnapshotForTimeChunk(any(), any(), anyLong(), any());
}

private void workersWithFailedFetchParallel(Set<String> failedTasks)
Expand All @@ -363,7 +364,7 @@ private void workersWithFailedFetchParallel(Set<String> failedTasks)
return Futures.immediateFailedFuture(new Exception("Task fetch failed :" + invocation.getArgument(0)));
}
return Futures.immediateFuture(snapshot);
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any());
}).when(workerClient).fetchClusterByStatisticsSnapshot(any(), any(), any());
}

}
Loading

0 comments on commit e364d84

Please sign in to comment.