Skip to content

Commit

Permalink
Switch from executors to structured task execution
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanlb committed Aug 28, 2024
1 parent 20582f7 commit 66b2267
Showing 1 changed file with 66 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,36 @@
import com.slack.astra.proto.metadata.Metadata;
import com.slack.astra.proto.service.AstraSearch;
import com.slack.astra.server.AstraQueryServiceBase;
import org.apache.commons.lang3.NotImplementedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.StructuredTaskScope;
import java.util.concurrent.TimeoutException;
import org.apache.commons.lang3.NotImplementedException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class AstraCacheQueryService extends AstraQueryServiceBase {
// Class logger.
private static final Logger LOG = LoggerFactory.getLogger(AstraCacheQueryService.class);
// Number of query threads: read from the "astra.concurrent.query" system property, defaulting to
// availableProcessors() - 1.
// NOTE(review): this declaration (and executorService below) appears twice, once live and once
// commented out, and the cacheSize default appears with both 50 and 100. That looks like the
// removed and added sides of a diff rendered together - confirm against the real file which copy
// is current.
// NOTE(review): on a single-processor host the default evaluates to 0, which
// Executors.newFixedThreadPool rejects - confirm whether a lower bound is needed.
private static final int queryConcurrency =
Integer.parseInt(
System.getProperty(
"astra.concurrent.query",
String.valueOf(Runtime.getRuntime().availableProcessors() - 1)));
// private static final int queryConcurrency =
// Integer.parseInt(
// System.getProperty(
// "astra.concurrent.query",
// String.valueOf(Runtime.getRuntime().availableProcessors() - 1)));

// Cache size: read from the "astra.concurrent.cache" system property. Two default-argument lines
// follow (50 and 100); only one of them can belong to a compilable declaration.
private static final int cacheSize = Integer.parseInt(
System.getProperty(
"astra.concurrent.cache", String.valueOf(50)));
"astra.concurrent.cache", String.valueOf(100)));

// Fixed-size pool of queryConcurrency threads named "cache-query-service-%d"; doSearch passes it
// as the executor to CompletableFuture.supplyAsync.
private final ExecutorService executorService =
Executors.newFixedThreadPool(
queryConcurrency,
new ThreadFactoryBuilder().setNameFormat("cache-query-service-%d").build());
// private final ExecutorService executorService =
// Executors.newFixedThreadPool(
// queryConcurrency,
// new ThreadFactoryBuilder().setNameFormat("cache-query-service-%d").build());

private final BlobStore blobStore;
// Time budget for a query; doSearch uses it to bound how long it waits for the chunk searches.
private final Duration queryTimeout;
Expand Down Expand Up @@ -93,63 +93,66 @@ public AstraCacheQueryService(
searchMetadataStore.createSync(searchMetadata);
}

/**
 * Searches every chunk listed in the request's query and returns the aggregated result.
 *
 * <p>The request is converted to a {@code SearchQuery}; for each entry of {@code query.chunkIds}
 * a searcher is fetched from {@code searcherCache} and searched concurrently. The per-chunk
 * results are combined by {@code SearchResultAggregatorImpl} and converted to the proto type.
 *
 * <p>NOTE(review): the body below contains two implementations interleaved - one built on
 * {@code CompletableFuture} and {@code executorService}, one built on {@code StructuredTaskScope}.
 * This looks like the removed and added sides of a diff shown together rather than one compilable
 * method; confirm against the real file which lines are live.
 *
 * @param request the search request to execute
 * @return the aggregated search result as an {@code AstraSearch.SearchResult}
 * @throws RuntimeException wrapping the exception caught by the outer catch block
 */
@SuppressWarnings("preview")
@Override
public AstraSearch.SearchResult doSearch(AstraSearch.SearchRequest request) {

// todo - timeout?

SearchQuery query = SearchResultUtils.fromSearchRequest(request);

// at this point we already have a list of chunks that (may) be cached - just do the query
// already

// CompletableFuture variant: one future per chunk id, run on executorService.
List<CompletableFuture<SearchResult<LogMessage>>> queryFutures = new ArrayList<>();
query.chunkIds.forEach(
chunkId -> {
queryFutures.add(
CompletableFuture.supplyAsync(
() -> {
try {
// todo - searchStartTime/searchEndtime instead of query start/end time?
return searcherCache
.get(chunkId)
.search(
query.dataset,
query.queryStr,
query.startTimeEpochMs,
query.endTimeEpochMs,
query.howMany,
query.aggBuilder,
query.queryBuilder,
query.sourceFieldFilter);
} catch (ExecutionException e) {
// Rethrown unchecked with the original exception kept as the cause.
throw new RuntimeException(e);
}
},
executorService));
});

try {
// CompletableFuture variant: block until every future completes or the timeout elapses.
CompletableFuture.allOf(queryFutures.toArray(CompletableFuture[]::new))
.get(queryTimeout.get(ChronoUnit.SECONDS), TimeUnit.SECONDS);
} catch (InterruptedException | ExecutionException e) {
// StructuredTaskScope variant: fork one subtask per chunk id inside a scope that is closed by
// try-with-resources.
try (var scope = new StructuredTaskScope<SearchResult<LogMessage>>()) {
List<StructuredTaskScope.Subtask<SearchResult<LogMessage>>> searchSubtasks =
query.chunkIds.stream().map(chunkId -> scope.fork(() -> {
try {
// todo - searchStartTime/searchEndtime instead of query start/end time?
return searcherCache
.get(chunkId)
.search(
query.dataset,
query.queryStr,
query.startTimeEpochMs,
query.endTimeEpochMs,
query.howMany,
query.aggBuilder,
query.queryBuilder,
query.sourceFieldFilter);
} catch (ExecutionException e) {
// Rethrown unchecked with the original exception kept as the cause.
throw new RuntimeException(e);
}
})).toList();

try {
// Deadline is now + queryTimeout in whole seconds; toSeconds() drops any sub-second part.
scope.joinUntil(Instant.now().plusSeconds(queryTimeout.toSeconds()));
} catch (TimeoutException timeoutException) {
LOG.warn("Query timeout", timeoutException);
// After a timeout the scope is shut down and joined before any subtask state is read.
scope.shutdown();
scope.join();
}

// One response entry per subtask: the subtask's value when it succeeded with a non-null result,
// otherwise SearchResult.error().
List<SearchResult<LogMessage>> response = new ArrayList<>(searchSubtasks.size());
for (StructuredTaskScope.Subtask<SearchResult<LogMessage>> searchResult : searchSubtasks) {
try {
if (searchResult.state().equals(StructuredTaskScope.Subtask.State.SUCCESS)) {
response.add(searchResult.get() == null ? SearchResult.error() : searchResult.get());
} else {
response.add(SearchResult.error());
}
} catch (Exception e) {
LOG.error("Error fetching search result", e);
response.add(SearchResult.error());
}
}
// Combine the per-chunk results and convert to the proto response type.
SearchResult<LogMessage> aggregatedResults =
((SearchResultAggregator<LogMessage>) new SearchResultAggregatorImpl<>(query))
.aggregate(response, false);
return SearchResultUtils.toSearchResultProto(aggregatedResults);
}
// NOTE(review): this broad catch also receives an InterruptedException thrown by
// joinUntil/join and does not restore the thread's interrupt flag - confirm that is intended.
} catch (Exception e) {
LOG.error("Search failed with ", e);
throw new RuntimeException(e);
} catch (TimeoutException e) {
// NOTE(review): the first log argument maps every future to a boolean and takes the size of
// that list, so it always equals queryFutures.size() rather than the number cancelled.
LOG.warn(
"Query timeout - {} cancelled, {} total",
queryFutures.stream().map(CompletableFuture::isCancelled).toList().size(),
queryFutures.size());
}

// CompletableFuture variant: collect whatever finished; a future that has not completed
// contributes an empty SearchResult.
List<SearchResult<LogMessage>> searchResults = new ArrayList<>();
for (CompletableFuture<SearchResult<LogMessage>> queryFuture : queryFutures) {
searchResults.add(queryFuture.getNow(new SearchResult<>()));
}

// Combine the per-chunk results and convert to the proto response type.
SearchResult<LogMessage> aggregatedResults =
((SearchResultAggregator<LogMessage>) new SearchResultAggregatorImpl<>(query))
.aggregate(searchResults, false);
return SearchResultUtils.toSearchResultProto(aggregatedResults);
}

@Override
Expand Down

0 comments on commit 66b2267

Please sign in to comment.