Skip to content

Commit

Permalink
fix: fixes ConcurrentMod exception when accessing streams metadata
Browse files Browse the repository at this point in the history
fixes: confluentinc#4639

Until the Streams bug https://issues.apache.org/jira/browse/KAFKA-9668 is fixed, ksql needs to protect itself from ConcurrentMod exceptions when accessing `KafkaStreams.allMetadata`.

This change accesses the internals of `KafkaStreams` to acquire a reference to the field that needs to be synchronised to protect against the concurrent modification.
  • Loading branch information
big-andy-coates committed Mar 5, 2020
1 parent c3cd132 commit eff3321
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 14 deletions.
96 changes: 83 additions & 13 deletions ksql-engine/src/main/java/io/confluent/ksql/util/QueryMetadata.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,27 @@

package io.confluent.ksql.util;

import static java.util.Objects.requireNonNull;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.confluent.ksql.internal.QueryStateListener;
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import java.lang.Thread.UncaughtExceptionHandler;
import java.lang.reflect.Field;
import java.time.Duration;
import java.util.Collection;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Consumer;
import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.LagInfo;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.errors.StreamsException;
import org.apache.kafka.streams.processor.internals.StreamsMetadataState;
import org.apache.kafka.streams.state.StreamsMetadata;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -40,6 +44,8 @@ public abstract class QueryMetadata {

private static final Logger LOG = LoggerFactory.getLogger(QueryMetadata.class);

private static final Field STREAMS_INTERNAL_FIELD = getStreamsInternalField();

private final String statementString;
private final KafkaStreams kafkaStreams;
private final String executionPlan;
Expand All @@ -51,6 +57,7 @@ public abstract class QueryMetadata {
private final Set<SourceName> sourceNames;
private final LogicalSchema logicalSchema;
private final Long closeTimeout;
private final StreamsMetadataState streamsMetadataState;

private Optional<QueryStateListener> queryStateListener = Optional.empty();
private boolean everStarted = false;
Expand All @@ -70,34 +77,69 @@ public QueryMetadata(
final long closeTimeout
) {
// CHECKSTYLE_RULES.ON: ParameterNumberCheck
this.statementString = Objects.requireNonNull(statementString, "statementString");
this.kafkaStreams = Objects.requireNonNull(kafkaStreams, "kafkaStreams");
this.executionPlan = Objects.requireNonNull(executionPlan, "executionPlan");
this.queryApplicationId = Objects.requireNonNull(queryApplicationId, "queryApplicationId");
this.topology = Objects.requireNonNull(topology, "kafkaTopicClient");
this(
statementString,
kafkaStreams,
getStreamsMetadataState(kafkaStreams),
logicalSchema,
sourceNames,
executionPlan,
queryApplicationId,
topology,
streamsProperties,
overriddenProperties,
closeCallback,
closeTimeout
);
}

/**
 * Primary constructor, visible for testing so a {@link StreamsMetadataState} can be injected.
 *
 * @param statementString the KSQL statement text this query was built from.
 * @param kafkaStreams the underlying Kafka Streams app executing the query.
 * @param streamsMetadataState the internal Streams metadata holder, used as the lock object to
 *        work around https://issues.apache.org/jira/browse/KAFKA-9668.
 * @param logicalSchema the schema of the query result.
 * @param sourceNames the sources the query reads from.
 * @param executionPlan human-readable execution plan.
 * @param queryApplicationId the {@code application.id} of the Streams app.
 * @param topology the Streams topology being executed.
 * @param streamsProperties properties the Streams app was built with (defensively copied).
 * @param overriddenProperties per-query property overrides (defensively copied).
 * @param closeCallback invoked when this query is closed.
 * @param closeTimeout how long, in ms, to wait for the Streams app to shut down on close.
 */
@VisibleForTesting
// CHECKSTYLE_RULES.OFF: ParameterNumberCheck
QueryMetadata(
    final String statementString,
    final KafkaStreams kafkaStreams,
    final StreamsMetadataState streamsMetadataState,
    final LogicalSchema logicalSchema,
    final Set<SourceName> sourceNames,
    final String executionPlan,
    final String queryApplicationId,
    final Topology topology,
    final Map<String, Object> streamsProperties,
    final Map<String, Object> overriddenProperties,
    final Consumer<QueryMetadata> closeCallback,
    final long closeTimeout
) {
  // CHECKSTYLE_RULES.ON: ParameterNumberCheck
  this.statementString = requireNonNull(statementString, "statementString");
  this.kafkaStreams = requireNonNull(kafkaStreams, "kafkaStreams");
  this.streamsMetadataState = requireNonNull(streamsMetadataState, "streamsMetadataState");
  this.executionPlan = requireNonNull(executionPlan, "executionPlan");
  this.queryApplicationId = requireNonNull(queryApplicationId, "queryApplicationId");
  // Fix: message previously said "kafkaTopicClient" — a copy/paste error.
  this.topology = requireNonNull(topology, "topology");
  this.streamsProperties =
      ImmutableMap.copyOf(
          // Fix: message previously misspelled as "streamsPropeties".
          requireNonNull(streamsProperties, "streamsProperties"));
  this.overriddenProperties =
      ImmutableMap.copyOf(
          requireNonNull(overriddenProperties, "overriddenProperties"));
  this.closeCallback = requireNonNull(closeCallback, "closeCallback");
  this.sourceNames = requireNonNull(sourceNames, "sourceNames");
  this.logicalSchema = requireNonNull(logicalSchema, "logicalSchema");
  this.closeTimeout = closeTimeout;
}

/**
 * Copy constructor: shares all state with {@code other}, including the underlying
 * {@link KafkaStreams} instance, but installs a fresh close callback.
 *
 * @param other the query to copy from.
 * @param closeCallback invoked when this query is closed; must not be null.
 */
protected QueryMetadata(final QueryMetadata other, final Consumer<QueryMetadata> closeCallback) {
  // Only the callback is new; everything else is taken verbatim from the source query.
  this.closeCallback = requireNonNull(closeCallback, "closeCallback");
  this.statementString = other.statementString;
  this.executionPlan = other.executionPlan;
  this.queryApplicationId = other.queryApplicationId;
  this.kafkaStreams = other.kafkaStreams;
  this.streamsMetadataState = other.streamsMetadataState;
  this.topology = other.topology;
  this.streamsProperties = other.streamsProperties;
  this.overriddenProperties = other.overriddenProperties;
  this.sourceNames = other.sourceNames;
  this.logicalSchema = other.logicalSchema;
  this.closeTimeout = other.closeTimeout;
}

Expand Down Expand Up @@ -146,7 +188,10 @@ public Map<String, Map<Integer, LagInfo>> getAllLocalStorePartitionLags() {

public Collection<StreamsMetadata> getAllMetadata() {
try {
return ImmutableList.copyOf(kafkaStreams.allMetadata());
      // Synchronized block needed until https://issues.apache.org/jira/browse/KAFKA-9668 is fixed.
synchronized (streamsMetadataState) {
return ImmutableList.copyOf(kafkaStreams.allMetadata());
}
} catch (IllegalStateException e) {
LOG.error(e.getMessage());
}
Expand Down Expand Up @@ -214,4 +259,29 @@ public void start() {
public String getTopologyDescription() {
return topology.describe().toString();
}

/*
Use reflection to get at the StreamsMetadataState instance, which ksql needs to synchronize on
if it is to avoid the ConcurrentModificationException caused by this bug:
https://issues.apache.org/jira/browse/KAFKA-9668.
Yes, this is brittle. But it can be removed once the above bug is fixed.
*/
/**
 * Extracts the private {@code streamsMetadataState} field from the supplied app via
 * reflection, using the cached accessible {@code Field} handle.
 *
 * @param kafkaStreams the Streams app to read the internal state from.
 * @return the app's internal {@link StreamsMetadataState}.
 * @throws IllegalStateException if the field cannot be read.
 */
static StreamsMetadataState getStreamsMetadataState(final KafkaStreams kafkaStreams) {
  try {
    final Object internalState = STREAMS_INTERNAL_FIELD.get(kafkaStreams);
    return (StreamsMetadataState) internalState;
  } catch (final IllegalAccessException e) {
    throw new IllegalStateException("Failed to access KafkaStreams.streamsMetadataState", e);
  }
}

/**
 * Looks up the private {@code KafkaStreams.streamsMetadataState} field once, at class load,
 * and makes it accessible so it can be read reflectively later.
 *
 * @return the accessible {@code Field} handle.
 * @throws IllegalStateException if the field does not exist (e.g. a Streams version change).
 */
private static Field getStreamsInternalField() {
  final Field field;
  try {
    field = KafkaStreams.class.getDeclaredField("streamsMetadataState");
  } catch (final NoSuchFieldException e) {
    throw new IllegalStateException("Failed to get KafkaStreams.streamsMetadataState", e);
  }
  // The field is private, so suppress Java access checks for this handle.
  field.setAccessible(true);
  return field;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

package io.confluent.ksql.util;

import com.google.common.annotations.VisibleForTesting;
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.query.BlockingRowQueue;
import io.confluent.ksql.query.LimitHandler;
Expand All @@ -26,6 +27,7 @@
import java.util.function.Consumer;
import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.processor.internals.StreamsMetadataState;

/**
* Metadata of a transient query, e.g. {@code SELECT * FROM FOO;}.
Expand All @@ -50,9 +52,45 @@ public TransientQueryMetadata(
final Consumer<QueryMetadata> closeCallback,
final long closeTimeout) {
// CHECKSTYLE_RULES.ON: ParameterNumberCheck
this(
statementString,
kafkaStreams,
getStreamsMetadataState(kafkaStreams),
logicalSchema,
sourceNames,
executionPlan,
rowQueue,
queryApplicationId,
topology,
streamsProperties,
overriddenProperties,
closeCallback,
closeTimeout
);
}

// CHECKSTYLE_RULES.OFF: ParameterNumberCheck
@VisibleForTesting
TransientQueryMetadata(
final String statementString,
final KafkaStreams kafkaStreams,
final StreamsMetadataState streamsMetadataState,
final LogicalSchema logicalSchema,
final Set<SourceName> sourceNames,
final String executionPlan,
final BlockingRowQueue rowQueue,
final String queryApplicationId,
final Topology topology,
final Map<String, Object> streamsProperties,
final Map<String, Object> overriddenProperties,
final Consumer<QueryMetadata> closeCallback,
final long closeTimeout
) {
// CHECKSTYLE_RULES.ON: ParameterNumberCheck
super(
statementString,
kafkaStreams,
streamsMetadataState,
logicalSchema,
sourceNames,
executionPlan,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,37 @@
package io.confluent.ksql.util;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.confluent.ksql.internal.QueryStateListener;
import io.confluent.ksql.name.ColumnName;
import io.confluent.ksql.name.SourceName;
import io.confluent.ksql.schema.ksql.LogicalSchema;
import io.confluent.ksql.schema.ksql.types.SqlTypes;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.KafkaStreams.State;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.processor.internals.StreamsMetadataState;
import org.apache.kafka.streams.state.StreamsMetadata;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -53,14 +64,17 @@ public class QueryMetadataTest {
.valueColumn(ColumnName.of("f0"), SqlTypes.STRING)
.build();

private static final Set<SourceName> SOME_SOURCES = ImmutableSet.of(SourceName.of("s1"), SourceName.of("s2"));
private static final Set<SourceName> SOME_SOURCES = ImmutableSet
.of(SourceName.of("s1"), SourceName.of("s2"));
private static final Long closeTimeout = KsqlConfig.KSQL_SHUTDOWN_TIMEOUT_MS_DEFAULT;

@Mock
private Topology topoplogy;
@Mock
private KafkaStreams kafkaStreams;
@Mock
private StreamsMetadataState streamsMetadataState;
@Mock
private QueryStateListener listener;
@Mock
private Consumer<QueryMetadata> closeCallback;
Expand All @@ -73,6 +87,7 @@ public void setup() {
query = new QueryMetadata(
"foo",
kafkaStreams,
streamsMetadataState,
SOME_SCHEMA,
SOME_SOURCES,
"bar",
Expand Down Expand Up @@ -207,4 +222,52 @@ public void shouldReturnSources() {
public void shouldReturnSchema() {
assertThat(query.getLogicalSchema(), is(SOME_SCHEMA));
}

@Test
public void shouldGetAllMetadataAsImmutableCopy() {
  // When:
  final Object metadata = query.getAllMetadata();

  // Then: callers must get a snapshot, not a view of Streams' internal mutable list.
  assertThat(metadata, is(instanceOf(ImmutableList.class)));
}

/*
Until https://issues.apache.org/jira/browse/KAFKA-9668 is fixed, `allMetadata` returns a ref
to internal mutable state. That state is mutated by other threads, leading to ConcurrentMod
exceptions. This test ensures ksqlDB has a workaround in place: a background thread mutates
the list while holding the StreamsMetadataState lock, while the main thread repeatedly copies
it via getAllMetadata(). Without the synchronization in getAllMetadata() the copy would throw
ConcurrentModificationException.
*/
@Test
public void shouldGetAllMetadataThreadSafe() {
  // Local alias so the lambda below captures the same lock object the workaround uses.
  final StreamsMetadataState streamsMetadataState = this.streamsMetadataState;

  // The mutable list Streams would hand back; the mock returns it un-copied, as per KAFKA-9668.
  final List<StreamsMetadata> allMetadata = new ArrayList<>();
  when(kafkaStreams.allMetadata()).thenReturn(allMetadata);

  final ExecutorService executor = Executors.newSingleThreadExecutor();
  final AtomicBoolean running = new AtomicBoolean(true);

  try {
    // Background writer: continuously grows/clears the list, but only while holding the
    // StreamsMetadataState lock — mirroring what Streams' rebalance thread does.
    executor.submit(() -> {
      while (running.get()) {
        synchronized (streamsMetadataState) {
          if (allMetadata.size() < 100) {
            allMetadata.add(mock(StreamsMetadata.class));
          } else {
            allMetadata.clear();
          }
        }
      }
    });

    // Main thread: hammer getAllMetadata(); any copy taken without the lock would likely
    // throw ConcurrentModificationException somewhere in 10k iterations.
    for (int i = 0; i != 10_000; ++i) {
      query.getAllMetadata();
    }
  } finally {
    // Stop the writer and wait for it so the mock isn't touched after the test ends.
    running.set(false);
    executor.shutdownNow();
    try {
      executor.awaitTermination(1, TimeUnit.MINUTES);
    } catch (final InterruptedException e) {
      // Best-effort shutdown only; nothing useful to do if interrupted here.
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.function.Consumer;
import org.apache.kafka.streams.KafkaStreams;
import org.apache.kafka.streams.Topology;
import org.apache.kafka.streams.processor.internals.StreamsMetadataState;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -44,6 +45,8 @@ public class TransientQueryMetadataTest {
@Mock
private KafkaStreams kafkaStreams;
@Mock
private StreamsMetadataState streamsMetadataState;
@Mock
private LogicalSchema logicalSchema;
@Mock
private Set<SourceName> sourceNames;
Expand All @@ -64,6 +67,7 @@ public void setUp() {
query = new TransientQueryMetadata(
SQL,
kafkaStreams,
streamsMetadataState,
logicalSchema,
sourceNames,
EXECUTION_PLAN,
Expand Down

0 comments on commit eff3321

Please sign in to comment.