Skip to content

Commit

Permalink
fix: avoid thread blocking in ParallelSink (#4333)
Browse files Browse the repository at this point in the history
  • Loading branch information
ndr-brt authored Jul 4, 2024
1 parent d12b933 commit a5c736e
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,46 +52,61 @@
/**
* Provides core services for the Data Plane Framework.
*/
@Provides({ DataPlaneManager.class, DataTransferExecutorServiceContainer.class, TransferServiceRegistry.class })
@Provides({ DataPlaneManager.class, TransferServiceRegistry.class })
@Extension(value = DataPlaneFrameworkExtension.NAME)
public class DataPlaneFrameworkExtension implements ServiceExtension {

public static final String NAME = "Data Plane Framework";
private static final int DEFAULT_TRANSFER_THREADS = 20;

@Setting(value = "the iteration wait time in milliseconds in the data plane state machine. Default value " + DEFAULT_ITERATION_WAIT, type = "long")
@Setting(
value = "the iteration wait time in milliseconds in the data plane state machine.",
defaultValue = DEFAULT_ITERATION_WAIT + "",
type = "long")
private static final String DATAPLANE_MACHINE_ITERATION_WAIT_MILLIS = "edc.dataplane.state-machine.iteration-wait-millis";

@Setting(value = "the batch size in the data plane state machine. Default value " + DEFAULT_BATCH_SIZE, type = "int")
@Setting(
value = "the batch size in the data plane state machine.",
defaultValue = DEFAULT_BATCH_SIZE + "",
type = "int"
)
private static final String DATAPLANE_MACHINE_BATCH_SIZE = "edc.dataplane.state-machine.batch-size";

@Setting(value = "how many times a specific operation must be tried before terminating the dataplane with error", type = "int", defaultValue = DEFAULT_SEND_RETRY_LIMIT + "")
@Setting(
value = "how many times a specific operation must be tried before terminating the dataplane with error",
defaultValue = DEFAULT_SEND_RETRY_LIMIT + "",
type = "int"
)
private static final String DATAPLANE_SEND_RETRY_LIMIT = "edc.dataplane.send.retry.limit";

@Setting(value = "The base delay for the dataplane retry mechanism in millisecond", type = "long", defaultValue = DEFAULT_SEND_RETRY_BASE_DELAY + "")
@Setting(
value = "The base delay for the dataplane retry mechanism in millisecond",
defaultValue = DEFAULT_SEND_RETRY_BASE_DELAY + "",
type = "long"
)
private static final String DATAPLANE_SEND_RETRY_BASE_DELAY_MS = "edc.dataplane.send.retry.base-delay.ms";

@Setting
@Setting(
value = "Size of the transfer thread pool. It is advisable to set it bigger than the state machine batch size",
defaultValue = DEFAULT_TRANSFER_THREADS + "",
type = "int"
)
private static final String TRANSFER_THREADS = "edc.dataplane.transfer.threads";
private static final int DEFAULT_TRANSFER_THREADS = 10;

private DataPlaneManagerImpl dataPlaneManager;

@Inject
private TransferServiceSelectionStrategy transferServiceSelectionStrategy;

@Inject
private DataPlaneStore store;

@Inject
private TransferProcessApiClient transferProcessApiClient;

@Inject
private ExecutorInstrumentation executorInstrumentation;

@Inject
private Telemetry telemetry;

@Inject
private Clock clock;

@Inject
private PipelineService pipelineService;
@Inject
Expand All @@ -112,12 +127,6 @@ public String name() {
public void initialize(ServiceExtensionContext context) {
var monitor = context.getMonitor();

var numThreads = context.getSetting(TRANSFER_THREADS, DEFAULT_TRANSFER_THREADS);
var executorService = Executors.newFixedThreadPool(numThreads);
var executorContainer = new DataTransferExecutorServiceContainer(
executorInstrumentation.instrument(executorService, "Data plane transfers"));
context.registerService(DataTransferExecutorServiceContainer.class, executorContainer);

var transferServiceRegistry = new TransferServiceRegistryImpl(transferServiceSelectionStrategy);
transferServiceRegistry.registerTransferService(pipelineService);
context.registerService(TransferServiceRegistry.class, transferServiceRegistry);
Expand All @@ -131,6 +140,7 @@ public void initialize(ServiceExtensionContext context) {
.clock(clock)
.entityRetryProcessConfiguration(getEntityRetryProcessConfiguration(context))
.executorInstrumentation(executorInstrumentation)
.authorizationService(authorizationService)
.transferServiceRegistry(transferServiceRegistry)
.store(store)
.transferProcessClient(transferProcessApiClient)
Expand All @@ -154,6 +164,14 @@ public void shutdown() {
}
}

@Provider
public DataTransferExecutorServiceContainer dataTransferExecutorServiceContainer(ServiceExtensionContext context) {
var numThreads = context.getSetting(TRANSFER_THREADS, DEFAULT_TRANSFER_THREADS);
var executorService = Executors.newFixedThreadPool(numThreads);
return new DataTransferExecutorServiceContainer(
executorInstrumentation.instrument(executorService, "Data plane transfers"));
}

@Provider
public DataPlaneAuthorizationService authorizationService(ServiceExtensionContext context) {
if (authorizationService == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import io.opentelemetry.instrumentation.annotations.WithSpan;
import org.eclipse.edc.connector.dataplane.spi.pipeline.DataSink;
import org.eclipse.edc.connector.dataplane.spi.pipeline.DataSource;
import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamFailure;
import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamResult;
import org.eclipse.edc.spi.EdcException;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.result.AbstractResult;
import org.eclipse.edc.spi.telemetry.Telemetry;
import org.eclipse.edc.util.stream.PartitionIterator;
import org.jetbrains.annotations.NotNull;
Expand All @@ -30,10 +31,7 @@
import java.util.concurrent.ExecutorService;
import java.util.function.Supplier;

import static java.lang.String.format;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.supplyAsync;
import static org.eclipse.edc.connector.dataplane.spi.pipeline.StreamResult.failure;
import static org.eclipse.edc.util.async.AsyncUtils.asyncAllOf;

/**
Expand All @@ -49,28 +47,26 @@ public abstract class ParallelSink implements DataSink {
@WithSpan
@Override
public CompletableFuture<StreamResult<Object>> transfer(DataSource source) {
try {
var streamResult = source.openPartStream();
if (streamResult.failed()) {
return completedFuture(failure(streamResult.getFailure()));
}

try (var partStream = streamResult.getContent()) {
return PartitionIterator.streamOf(partStream, partitionSize)
.map(this::processPartsAsync)
.collect(asyncAllOf())
.thenApply(results -> results.stream()
.filter(AbstractResult::failed)
.findFirst()
.map(r -> StreamResult.<Object>error(String.join(",", r.getFailureMessages())))
.orElseGet(this::complete))
.exceptionally(throwable -> StreamResult.error("Unhandled exception raised when transferring data: " + throwable.getMessage()));
}
} catch (Exception e) {
var errorMessage = format("Error processing data transfer request - Request ID: %s", requestId);
monitor.severe(errorMessage, e);
return CompletableFuture.completedFuture(StreamResult.error(errorMessage));
}
return supplyAsync(() -> source.openPartStream().orElseThrow(StreamException::new), executorService)
.thenCompose(parts -> {
try (parts) {
return PartitionIterator.streamOf(parts, partitionSize)
.map(this::processPartsAsync)
.collect(asyncAllOf())
.thenApply(results -> results.stream()
.filter(StreamResult::failed)
.findFirst()
.map(r -> StreamResult.failure(r.getFailure()))
.orElseGet(this::complete));
}
})
.exceptionally(throwable -> {
if (throwable instanceof StreamException streamException) {
return StreamResult.failure(streamException.failure);
} else {
return StreamResult.error("Error processing data transfer request - Request ID: %s. Message: %s".formatted(requestId, throwable.getMessage()));
}
});
}

@NotNull
Expand All @@ -95,6 +91,16 @@ protected StreamResult<Object> complete() {
return StreamResult.success();
}

private static class StreamException extends EdcException {

private final StreamFailure failure;

StreamException(StreamFailure failure) {
super(failure.getFailureDetail());
this.failure = failure;
}
}

protected abstract static class Builder<B extends Builder<B, T>, T extends ParallelSink> {
protected T sink;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,54 @@

package org.eclipse.edc.connector.dataplane.util.sink;

import org.assertj.core.api.Assertions;
import org.eclipse.edc.connector.dataplane.spi.pipeline.DataSource;
import org.eclipse.edc.connector.dataplane.spi.pipeline.InputStreamDataSource;
import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamFailure;
import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamResult;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.telemetry.Telemetry;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayInputStream;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import static java.lang.String.format;
import static java.time.temporal.ChronoUnit.MILLIS;
import static java.util.UUID.randomUUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class ParallelSinkTest {

private final Monitor monitor = mock(Monitor.class);
private final ExecutorService executor = Executors.newFixedThreadPool(2);
private final String dataSourceName = "test-datasource-name";
private final String dataSourceContent = "test-content";
private final Duration timeout = Duration.of(500, MILLIS);
private final String errorMessage = "test-errormessage";
private final InputStreamDataSource dataSource = new InputStreamDataSource(
dataSourceName,
new ByteArrayInputStream(dataSourceContent.getBytes()));
private final String dataFlowRequestId = randomUUID().toString();
FakeParallelSink fakeSink;

@BeforeEach
void setup() {
fakeSink = new FakeParallelSink();
fakeSink.monitor = monitor;
fakeSink.telemetry = new Telemetry(); // default noop implementation
fakeSink.executorService = executor;
fakeSink.requestId = dataFlowRequestId;
}
private final FakeParallelSink fakeSink = new FakeParallelSink.Builder().monitor(mock())
.executorService(Executors.newFixedThreadPool(2))
.requestId(dataFlowRequestId).build();

@Test
void transfer_succeeds() {
assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS)
.satisfies(transferResult -> assertThat(transferResult.succeeded()).isTrue());
var dataSource = dataSource();

Assertions.assertThat(fakeSink.parts).containsExactly(dataSource);
var future = fakeSink.transfer(dataSource);

assertThat(future).succeedsWithin(timeout)
.satisfies(transferResult -> assertThat(transferResult.succeeded()).isTrue());
assertThat(fakeSink.parts).containsExactly(dataSource);
assertThat(fakeSink.complete).isEqualTo(1);
}

@Test
void transfer_whenCompleteFails_fails() {
var dataSource = dataSource();
fakeSink.completeResponse = StreamResult.error("General error");
assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS)
.isEqualTo(fakeSink.completeResponse);

var future = fakeSink.transfer(dataSource);

assertThat(future).succeedsWithin(timeout).isEqualTo(fakeSink.completeResponse);
}

@Test
Expand All @@ -81,17 +70,23 @@ void transfer_whenExceptionOpeningPartStream_fails() {

when(dataSourceMock.openPartStream()).thenThrow(new RuntimeException(errorMessage));

assertThat(fakeSink.transfer(dataSourceMock)).succeedsWithin(500, TimeUnit.MILLISECONDS)
var future = fakeSink.transfer(dataSourceMock);

assertThat(future).succeedsWithin(timeout)
.satisfies(transferResult -> assertThat(transferResult.failed()).isTrue())
.satisfies(transferResult -> assertThat(transferResult.getFailureMessages()).containsExactly(format("Error processing data transfer request - Request ID: %s", dataFlowRequestId)));
.satisfies(transferResult -> assertThat(transferResult.getFailureDetail())
.contains("Error processing data transfer request").contains(dataFlowRequestId).contains(errorMessage));
assertThat(fakeSink.complete).isEqualTo(0);
}

@Test
void transfer_whenFailureDuringTransfer_fails() {
var dataSource = dataSource();
fakeSink.transferResultSupplier = () -> StreamResult.error(errorMessage);

assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS)
var future = fakeSink.transfer(dataSource);

assertThat(future).succeedsWithin(timeout)
.satisfies(transferResult -> assertThat(transferResult.failed()).isTrue())
.satisfies(transferResult -> assertThat(transferResult.getFailure().getReason()).isEqualTo(StreamFailure.Reason.GENERAL_ERROR))
.satisfies(transferResult -> assertThat(transferResult.getFailureMessages()).containsExactly(errorMessage));
Expand All @@ -102,20 +97,40 @@ void transfer_whenFailureDuringTransfer_fails() {

@Test
void transfer_whenExceptionDuringTransfer_fails() {
var dataSource = dataSource();
fakeSink.transferResultSupplier = () -> {
throw new RuntimeException(errorMessage);
};

assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS)
var future = fakeSink.transfer(dataSource);

assertThat(future).succeedsWithin(timeout)
.satisfies(transferResult -> assertThat(transferResult.failed()).isTrue())
.satisfies(transferResult -> assertThat(transferResult.getFailure().getReason()).isEqualTo(StreamFailure.Reason.GENERAL_ERROR))
.satisfies(transferResult -> assertThat(transferResult.getFailureMessages())
.containsExactly("Unhandled exception raised when transferring data: java.lang.RuntimeException: " + errorMessage));
.satisfies(transferResult -> assertThat(transferResult.getFailureDetail())
.contains("Error processing data transfer request").contains(dataFlowRequestId).contains(errorMessage));

assertThat(fakeSink.parts).containsExactly(dataSource);
assertThat(fakeSink.complete).isEqualTo(0);
}

@Test
void shouldNotBlock_whenDataSourceIsIndefinite() {
var infiniteStream = IntStream.iterate(0, i -> i + 1).mapToObj(i -> mock(DataSource.Part.class));
var dataSource = mock(DataSource.class);
when(dataSource.openPartStream()).thenReturn(StreamResult.success(infiniteStream));

var future = fakeSink.transfer(dataSource);

assertThat(future).isNotNull();
}

private InputStreamDataSource dataSource() {
return new InputStreamDataSource(
"test-datasource-name",
new ByteArrayInputStream("test-content".getBytes()));
}

private static class FakeParallelSink extends ParallelSink {

List<DataSource.Part> parts;
Expand All @@ -134,5 +149,17 @@ protected StreamResult<Object> complete() {
complete++;
return completeResponse;
}

public static class Builder extends ParallelSink.Builder<Builder, FakeParallelSink> {

protected Builder() {
super(new FakeParallelSink());
}

@Override
protected void validate() {

}
}
}
}

0 comments on commit a5c736e

Please sign in to comment.