diff --git a/core/src/test/java/kafka/test/junit/RaftClusterInvocationContext.java b/core/src/test/java/kafka/test/junit/RaftClusterInvocationContext.java index df8990b0c92f2..9857d4c92cd39 100644 --- a/core/src/test/java/kafka/test/junit/RaftClusterInvocationContext.java +++ b/core/src/test/java/kafka/test/junit/RaftClusterInvocationContext.java @@ -48,7 +48,6 @@ import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -67,15 +66,11 @@ public class RaftClusterInvocationContext implements TestTemplateInvocationConte private final String baseDisplayName; private final ClusterConfig clusterConfig; - private final AtomicReference clusterReference; - private final AtomicReference zkReference; private final boolean isCombined; public RaftClusterInvocationContext(String baseDisplayName, ClusterConfig clusterConfig, boolean isCombined) { this.baseDisplayName = baseDisplayName; this.clusterConfig = clusterConfig; - this.clusterReference = new AtomicReference<>(); - this.zkReference = new AtomicReference<>(); this.isCombined = isCombined; } @@ -86,67 +81,43 @@ public String getDisplayName(int invocationIndex) { @Override public List getAdditionalExtensions() { - RaftClusterInstance clusterInstance = new RaftClusterInstance(clusterReference, zkReference, clusterConfig, isCombined); + RaftClusterInstance clusterInstance = new RaftClusterInstance(clusterConfig, isCombined); return Arrays.asList( - (BeforeTestExecutionCallback) context -> { - TestKitNodes nodes = new TestKitNodes.Builder(). - setBootstrapMetadataVersion(clusterConfig.metadataVersion()). - setCombined(isCombined). - setNumBrokerNodes(clusterConfig.numBrokers()). - setPerServerProperties(clusterConfig.perServerOverrideProperties()). - setNumDisksPerBroker(clusterConfig.numDisksPerBroker()). - setNumControllerNodes(clusterConfig.numControllers()).build(); - KafkaClusterTestKit.Builder builder = new KafkaClusterTestKit.Builder(nodes); - - if (Boolean.parseBoolean(clusterConfig.serverProperties().getOrDefault("zookeeper.metadata.migration.enable", "false"))) { - zkReference.set(new EmbeddedZookeeper()); - builder.setConfigProp("zookeeper.connect", String.format("localhost:%d", zkReference.get().port())); - } - // Copy properties into the TestKit builder - clusterConfig.serverProperties().forEach(builder::setConfigProp); - // KAFKA-12512 need to pass security protocol and listener name here - KafkaClusterTestKit cluster = builder.build(); - clusterReference.set(cluster); - cluster.format(); - if (clusterConfig.isAutoStart()) { - cluster.startup(); - kafka.utils.TestUtils.waitUntilTrue( - () -> cluster.brokers().get(0).brokerState() == BrokerState.RUNNING, - () -> "Broker never made it to RUNNING state.", - org.apache.kafka.test.TestUtils.DEFAULT_MAX_WAIT_MS, - 100L); - } - }, - (AfterTestExecutionCallback) context -> clusterInstance.stop(), - new ClusterInstanceParameterResolver(clusterInstance) + (BeforeTestExecutionCallback) context -> { + clusterInstance.format(); + if (clusterConfig.isAutoStart()) { + clusterInstance.start(); + } + }, + (AfterTestExecutionCallback) context -> clusterInstance.stop(), + new ClusterInstanceParameterResolver(clusterInstance) ); } public static class RaftClusterInstance implements ClusterInstance { - private final AtomicReference clusterReference; - private final AtomicReference zkReference; private final ClusterConfig clusterConfig; final AtomicBoolean started = new AtomicBoolean(false); final AtomicBoolean stopped = new AtomicBoolean(false); + final AtomicBoolean formated = new AtomicBoolean(false); private final ConcurrentLinkedQueue admins = new ConcurrentLinkedQueue<>(); + private EmbeddedZookeeper embeddedZookeeper; + private KafkaClusterTestKit clusterTestKit; private final boolean isCombined; - RaftClusterInstance(AtomicReference clusterReference, AtomicReference zkReference, ClusterConfig clusterConfig, boolean isCombined) { - this.clusterReference = clusterReference; - this.zkReference = zkReference; + RaftClusterInstance(ClusterConfig clusterConfig, boolean isCombined) { this.clusterConfig = clusterConfig; this.isCombined = isCombined; } @Override public String bootstrapServers() { - return clusterReference.get().bootstrapServers(); + return clusterTestKit.bootstrapServers(); } @Override public String bootstrapControllers() { - return clusterReference.get().bootstrapControllers(); + return clusterTestKit.bootstrapControllers(); } @Override @@ -193,25 +164,30 @@ public Set controllerIds() { @Override public KafkaClusterTestKit getUnderlying() { - return clusterReference.get(); + return clusterTestKit; } @Override public Admin createAdminClient(Properties configOverrides) { - Admin admin = Admin.create(clusterReference.get(). - newClientPropertiesBuilder(configOverrides).build()); + Admin admin = Admin.create(clusterTestKit.newClientPropertiesBuilder(configOverrides).build()); admins.add(admin); return admin; } @Override public void start() { - if (started.compareAndSet(false, true)) { - try { - clusterReference.get().startup(); - } catch (Exception e) { - throw new RuntimeException("Failed to start Raft server", e); + try { + format(); + if (started.compareAndSet(false, true)) { + clusterTestKit.startup(); + kafka.utils.TestUtils.waitUntilTrue( + () -> this.clusterTestKit.brokers().get(0).brokerState() == BrokerState.RUNNING, + () -> "Broker never made it to RUNNING state.", + org.apache.kafka.test.TestUtils.DEFAULT_MAX_WAIT_MS, + 100L); } + } catch (Exception e) { + throw new RuntimeException("Failed to start Raft server", e); } } @@ -220,9 +196,9 @@ public void stop() { if (stopped.compareAndSet(false, true)) { admins.forEach(admin -> Utils.closeQuietly(admin, "admin")); admins.clear(); - Utils.closeQuietly(clusterReference.get(), "cluster"); - if (zkReference.get() != null) { - Utils.closeQuietly(zkReference.get(), "zk"); + Utils.closeQuietly(clusterTestKit, "cluster"); + if (embeddedZookeeper != null) { + Utils.closeQuietly(embeddedZookeeper, "zk"); } } } @@ -240,27 +216,51 @@ public void startBroker(int brokerId) { @Override public void waitForReadyBrokers() throws InterruptedException { try { - clusterReference.get().waitForReadyBrokers(); + clusterTestKit.waitForReadyBrokers(); } catch (ExecutionException e) { throw new AssertionError("Failed while waiting for brokers to become ready", e); } } - private BrokerServer findBrokerOrThrow(int brokerId) { - return Optional.ofNullable(clusterReference.get().brokers().get(brokerId)) - .orElseThrow(() -> new IllegalArgumentException("Unknown brokerId " + brokerId)); - } @Override public Map brokers() { - return clusterReference.get().brokers().entrySet() + return clusterTestKit.brokers().entrySet() .stream() .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); } @Override public Map controllers() { - return Collections.unmodifiableMap(clusterReference.get().controllers()); + return Collections.unmodifiableMap(clusterTestKit.controllers()); + } + + public void format() throws Exception { + if (formated.compareAndSet(false, true)) { + TestKitNodes nodes = new TestKitNodes.Builder() + .setBootstrapMetadataVersion(clusterConfig.metadataVersion()) + .setCombined(isCombined) + .setNumBrokerNodes(clusterConfig.numBrokers()) + .setNumDisksPerBroker(clusterConfig.numDisksPerBroker()) + .setPerServerProperties(clusterConfig.perServerOverrideProperties()) + .setNumControllerNodes(clusterConfig.numControllers()).build(); + KafkaClusterTestKit.Builder builder = new KafkaClusterTestKit.Builder(nodes); + if (Boolean.parseBoolean(clusterConfig.serverProperties() + .getOrDefault("zookeeper.metadata.migration.enable", "false"))) { + this.embeddedZookeeper = new EmbeddedZookeeper(); + builder.setConfigProp("zookeeper.connect", String.format("localhost:%d", embeddedZookeeper.port())); + } + // Copy properties into the TestKit builder + clusterConfig.serverProperties().forEach(builder::setConfigProp); + // KAFKA-12512 need to pass security protocol and listener name here + this.clusterTestKit = builder.build(); + this.clusterTestKit.format(); + } + } + + private BrokerServer findBrokerOrThrow(int brokerId) { + return Optional.ofNullable(clusterTestKit.brokers().get(brokerId)) + .orElseThrow(() -> new IllegalArgumentException("Unknown brokerId " + brokerId)); } }