From d99247699e83ffa41237b41e4c6cac887fef6644 Mon Sep 17 00:00:00 2001 From: Jakub Malek Date: Tue, 6 Aug 2024 11:11:25 +0200 Subject: [PATCH] #1360 Prevent ungraceful JdbcSourceTask stoppage with a maximum duration time for the poll operation --- .../source/JdbcSourceConnectorConfig.java | 78 ++++-- .../connect/jdbc/source/JdbcSourceTask.java | 105 ++++---- .../source/JdbcSourceTaskPollExecutor.java | 142 ++++++++++ .../source/JdbcSourceConnectorConfigTest.java | 61 ++++- .../JdbcSourceTaskPollExecutorTest.java | 243 ++++++++++++++++++ 5 files changed, 552 insertions(+), 77 deletions(-) create mode 100644 src/main/java/io/confluent/connect/jdbc/source/JdbcSourceTaskPollExecutor.java create mode 100644 src/test/java/io/confluent/connect/jdbc/source/JdbcSourceTaskPollExecutorTest.java diff --git a/src/main/java/io/confluent/connect/jdbc/source/JdbcSourceConnectorConfig.java b/src/main/java/io/confluent/connect/jdbc/source/JdbcSourceConnectorConfig.java index 17bebf0ba..d294287c2 100644 --- a/src/main/java/io/confluent/connect/jdbc/source/JdbcSourceConnectorConfig.java +++ b/src/main/java/io/confluent/connect/jdbc/source/JdbcSourceConnectorConfig.java @@ -26,16 +26,6 @@ import java.util.Map; import java.util.TimeZone; import java.util.concurrent.atomic.AtomicReference; - -import com.microsoft.sqlserver.jdbc.SQLServerConnection; -import io.confluent.connect.jdbc.dialect.DatabaseDialect; -import io.confluent.connect.jdbc.dialect.DatabaseDialects; -import io.confluent.connect.jdbc.util.DatabaseDialectRecommender; -import io.confluent.connect.jdbc.util.DateTimeUtils; -import io.confluent.connect.jdbc.util.EnumRecommender; -import io.confluent.connect.jdbc.util.QuoteMethod; -import io.confluent.connect.jdbc.util.TimeZoneValidator; - import java.util.function.BiFunction; import java.util.function.Function; import java.util.regex.Pattern; @@ -44,6 +34,7 @@ import org.apache.kafka.common.config.Config; import org.apache.kafka.common.config.ConfigDef; import 
org.apache.kafka.common.config.ConfigDef.Importance; +import org.apache.kafka.common.config.ConfigDef.Range; import org.apache.kafka.common.config.ConfigDef.Recommender; import org.apache.kafka.common.config.ConfigDef.Type; import org.apache.kafka.common.config.ConfigDef.Validator; @@ -56,10 +47,20 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import com.microsoft.sqlserver.jdbc.ISQLServerConnection; + +import io.confluent.connect.jdbc.dialect.DatabaseDialect; +import io.confluent.connect.jdbc.dialect.DatabaseDialects; +import io.confluent.connect.jdbc.util.DatabaseDialectRecommender; +import io.confluent.connect.jdbc.util.DateTimeUtils; +import io.confluent.connect.jdbc.util.EnumRecommender; +import io.confluent.connect.jdbc.util.QuoteMethod; +import io.confluent.connect.jdbc.util.TimeZoneValidator; + public class JdbcSourceConnectorConfig extends AbstractConfig { private static final Logger LOG = LoggerFactory.getLogger(JdbcSourceConnectorConfig.class); - private static Pattern INVALID_CHARS = Pattern.compile("[^a-zA-Z0-9._-]"); + private static final Pattern INVALID_CHARS = Pattern.compile("[^a-zA-Z0-9._-]"); public static final String CONNECTION_PREFIX = "connection."; @@ -101,6 +102,19 @@ public class JdbcSourceConnectorConfig extends AbstractConfig { public static final int POLL_INTERVAL_MS_DEFAULT = 5000; private static final String POLL_INTERVAL_MS_DISPLAY = "Poll Interval (ms)"; + public static final String POLL_MAX_WAIT_TIME_MS_CONFIG = "poll.max.wait.time.ms"; + public static final String POLL_MAX_WAIT_TIME_MS_DOC = "The maximum time in ms to wait by " + + "the worker task for the poll operation. This includes additional poll.interval.ms " + + "wait time applied in between subsequent poll calls. If the set maximum time is exceeded, " + + "the task will signal no-data to the worker. The polling operation however will not be " + + "interrupted until the task is stopped. 
Each time the worker polls the records from the " + "source task it will either wait for the result from the previously started polling " + "operation or a new polling operation will be started. " + "When the poll.max.wait.time.ms is set to zero, then the worker will wait indefinitely " + "until the polling operation is finished."; + public static final int POLL_MAX_WAIT_TIME_MS_DEFAULT = 1_000; + private static final String POLL_MAX_DURATION_MS_DISPLAY = "Poll Max Wait Time (ms)"; + public static final String BATCH_MAX_ROWS_CONFIG = "batch.max.rows"; private static final String BATCH_MAX_ROWS_DOC = "Maximum number of rows to include in a single batch when polling for new data. This " @@ -314,7 +328,7 @@ public class JdbcSourceConnectorConfig extends AbstractConfig { public static final String QUERY_SUFFIX_CONFIG = "query.suffix"; public static final String QUERY_SUFFIX_DEFAULT = ""; - public static final String QUERY_SUFFIX_DOC = + public static final String QUERY_SUFFIX_DOC = "Suffix to append at the end of the generated query."; public static final String QUERY_SUFFIX_DISPLAY = "Query suffix"; @@ -401,18 +415,15 @@ public Config validateMultiConfigs(Config config) { } else { dialect = DatabaseDialects.findBestFor(this.getString(CONNECTION_URL_CONFIG), this); } - if (!dialect.name().equals( - DatabaseDialects.create( - SqlServerDatabaseDialectName, this - ).name() - ) - ) { - configValues - .get(JdbcSourceConnectorConfig.TRANSACTION_ISOLATION_MODE_CONFIG) - .addErrorMessage("Isolation mode of `" - + TransactionIsolationMode.SQL_SERVER_SNAPSHOT.name() - + "` can only be configured with a Sql Server Dialect" - ); + try (DatabaseDialect sqlServerDialect = DatabaseDialects.create( + SqlServerDatabaseDialectName, this)) { + if (!dialect.name().equals(sqlServerDialect.name())) { + configValues + .get(JdbcSourceConnectorConfig.TRANSACTION_ISOLATION_MODE_CONFIG) + .addErrorMessage("Isolation mode of `" + + TransactionIsolationMode.SQL_SERVER_SNAPSHOT.name() + + 
"` can only be configured with a Sql Server Dialect"); + } } } @@ -694,6 +705,17 @@ private static final void addConnectorOptions(ConfigDef config) { ++orderInGroup, Width.SHORT, POLL_INTERVAL_MS_DISPLAY + ).define( + POLL_MAX_WAIT_TIME_MS_CONFIG, + Type.INT, + POLL_MAX_WAIT_TIME_MS_DEFAULT, + Range.atLeast(0), + Importance.MEDIUM, + POLL_MAX_WAIT_TIME_MS_DOC, + CONNECTOR_GROUP, + ++orderInGroup, + Width.SHORT, + POLL_MAX_DURATION_MS_DISPLAY ).define( BATCH_MAX_ROWS_CONFIG, Type.INT, @@ -792,7 +814,7 @@ public JdbcSourceConnectorConfig(Map props) { } public String topicPrefix() { - return getString(JdbcSourceTaskConfig.TOPIC_PREFIX_CONFIG).trim(); + return getString(TOPIC_PREFIX_CONFIG).trim(); } /** @@ -914,7 +936,7 @@ public static NumericMapping get(JdbcSourceConnectorConfig config) { if (newMappingConfig != null) { return get(config.getString(JdbcSourceConnectorConfig.NUMERIC_MAPPING_CONFIG)); } - if (config.getBoolean(JdbcSourceTaskConfig.NUMERIC_PRECISION_MAPPING_CONFIG)) { + if (config.getBoolean(NUMERIC_PRECISION_MAPPING_CONFIG)) { return NumericMapping.PRECISION_ONLY; } return NumericMapping.NONE; @@ -993,7 +1015,7 @@ public static int get(TransactionIsolationMode mode) { case SERIALIZABLE: return Connection.TRANSACTION_SERIALIZABLE; case SQL_SERVER_SNAPSHOT: - return SQLServerConnection.TRANSACTION_SNAPSHOT; + return ISQLServerConnection.TRANSACTION_SNAPSHOT; default: return -1; } @@ -1010,7 +1032,7 @@ public NumericMapping numericMapping() { } public TimeZone timeZone() { - String dbTimeZone = getString(JdbcSourceTaskConfig.DB_TIMEZONE_CONFIG); + String dbTimeZone = getString(DB_TIMEZONE_CONFIG); return TimeZone.getTimeZone(ZoneId.of(dbTimeZone)); } diff --git a/src/main/java/io/confluent/connect/jdbc/source/JdbcSourceTask.java b/src/main/java/io/confluent/connect/jdbc/source/JdbcSourceTask.java index 413e1658d..3288503f6 100644 --- a/src/main/java/io/confluent/connect/jdbc/source/JdbcSourceTask.java +++ 
b/src/main/java/io/confluent/connect/jdbc/source/JdbcSourceTask.java @@ -15,19 +15,9 @@ package io.confluent.connect.jdbc.source; -import java.sql.SQLNonTransientException; -import java.util.TimeZone; -import org.apache.kafka.common.config.ConfigException; -import org.apache.kafka.common.utils.SystemTime; -import org.apache.kafka.common.utils.Time; -import org.apache.kafka.connect.errors.ConnectException; -import org.apache.kafka.connect.source.SourceRecord; -import org.apache.kafka.connect.source.SourceTask; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.sql.Connection; import java.sql.SQLException; +import java.sql.SQLNonTransientException; import java.util.ArrayList; import java.util.Arrays; import java.util.Calendar; @@ -39,19 +29,29 @@ import java.util.Map; import java.util.PriorityQueue; import java.util.Set; +import java.util.TimeZone; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Function; import java.util.stream.Collectors; +import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.utils.SystemTime; +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.source.SourceRecord; +import org.apache.kafka.connect.source.SourceTask; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import io.confluent.connect.jdbc.dialect.DatabaseDialect; import io.confluent.connect.jdbc.dialect.DatabaseDialects; +import io.confluent.connect.jdbc.source.JdbcSourceConnectorConfig.TransactionIsolationMode; import io.confluent.connect.jdbc.util.CachedConnectionProvider; import io.confluent.connect.jdbc.util.ColumnDefinition; import io.confluent.connect.jdbc.util.ColumnId; import io.confluent.connect.jdbc.util.TableId; import io.confluent.connect.jdbc.util.Version; -import io.confluent.connect.jdbc.source.JdbcSourceConnectorConfig.TransactionIsolationMode; /** * 
JdbcSourceTask is a Kafka Connect SourceTask implementation that reads from JDBC databases and @@ -66,6 +66,7 @@ public class JdbcSourceTask extends SourceTask { private Time time; private JdbcSourceTaskConfig config; private DatabaseDialect dialect; + private JdbcSourceTaskPollExecutor pollExecutor; //Visible for Testing CachedConnectionProvider cachedConnectionProvider; PriorityQueue tableQueue = new PriorityQueue<>(); @@ -98,7 +99,7 @@ public void start(Map properties) { List tables = config.getList(JdbcSourceTaskConfig.TABLES_CONFIG); Boolean tablesFetched = config.getBoolean(JdbcSourceTaskConfig.TABLES_FETCHED); - String query = config.getString(JdbcSourceTaskConfig.QUERY_CONFIG); + String query = config.getString(JdbcSourceConnectorConfig.QUERY_CONFIG); if ((tables.isEmpty() && query.isEmpty())) { // We are still waiting for the tables call to complete. @@ -155,13 +156,13 @@ public void start(Map properties) { List tablesOrQuery = queryMode == TableQuerier.QueryMode.QUERY ? Collections.singletonList(query) : tables; - String mode = config.getString(JdbcSourceTaskConfig.MODE_CONFIG); + String mode = config.getString(JdbcSourceConnectorConfig.MODE_CONFIG); //used only in table mode Map>> partitionsByTableFqn = new HashMap<>(); Map, Map> offsets = null; - if (mode.equals(JdbcSourceTaskConfig.MODE_INCREMENTING) - || mode.equals(JdbcSourceTaskConfig.MODE_TIMESTAMP) - || mode.equals(JdbcSourceTaskConfig.MODE_TIMESTAMP_INCREMENTING)) { + if (mode.equals(JdbcSourceConnectorConfig.MODE_INCREMENTING) + || mode.equals(JdbcSourceConnectorConfig.MODE_TIMESTAMP) + || mode.equals(JdbcSourceConnectorConfig.MODE_TIMESTAMP_INCREMENTING)) { List> partitions = new ArrayList<>(tables.size()); switch (queryMode) { case TABLE: @@ -187,15 +188,15 @@ public void start(Map properties) { } String incrementingColumn - = config.getString(JdbcSourceTaskConfig.INCREMENTING_COLUMN_NAME_CONFIG); + = config.getString(JdbcSourceConnectorConfig.INCREMENTING_COLUMN_NAME_CONFIG); List 
timestampColumns - = config.getList(JdbcSourceTaskConfig.TIMESTAMP_COLUMN_NAME_CONFIG); + = config.getList(JdbcSourceConnectorConfig.TIMESTAMP_COLUMN_NAME_CONFIG); Long timestampDelayInterval - = config.getLong(JdbcSourceTaskConfig.TIMESTAMP_DELAY_INTERVAL_MS_CONFIG); + = config.getLong(JdbcSourceConnectorConfig.TIMESTAMP_DELAY_INTERVAL_MS_CONFIG); boolean validateNonNulls - = config.getBoolean(JdbcSourceTaskConfig.VALIDATE_NON_NULL_CONFIG); + = config.getBoolean(JdbcSourceConnectorConfig.VALIDATE_NON_NULL_CONFIG); TimeZone timeZone = config.timeZone(); - String suffix = config.getString(JdbcSourceTaskConfig.QUERY_SUFFIX_CONFIG).trim(); + String suffix = config.getString(JdbcSourceConnectorConfig.QUERY_SUFFIX_CONFIG).trim(); if (queryMode.equals(TableQuerier.QueryMode.TABLE)) { validateColumnsExist(mode, incrementingColumn, timestampColumns, tables.get(0)); @@ -246,17 +247,17 @@ public void start(Map properties) { JdbcSourceConnectorConfig.TimestampGranularity timestampGranularity = JdbcSourceConnectorConfig.TimestampGranularity.get(config); - if (mode.equals(JdbcSourceTaskConfig.MODE_BULK)) { + if (mode.equals(JdbcSourceConnectorConfig.MODE_BULK)) { tableQueue.add( new BulkTableQuerier( - dialect, - queryMode, - tableOrQuery, - topicPrefix, + dialect, + queryMode, + tableOrQuery, + topicPrefix, suffix ) ); - } else if (mode.equals(JdbcSourceTaskConfig.MODE_INCREMENTING)) { + } else if (mode.equals(JdbcSourceConnectorConfig.MODE_INCREMENTING)) { tableQueue.add( new TimestampIncrementingTableQuerier( dialect, @@ -272,7 +273,7 @@ public void start(Map properties) { timestampGranularity ) ); - } else if (mode.equals(JdbcSourceTaskConfig.MODE_TIMESTAMP)) { + } else if (mode.equals(JdbcSourceConnectorConfig.MODE_TIMESTAMP)) { tableQueue.add( new TimestampTableQuerier( dialect, @@ -287,7 +288,7 @@ public void start(Map properties) { timestampGranularity ) ); - } else if (mode.endsWith(JdbcSourceTaskConfig.MODE_TIMESTAMP_INCREMENTING)) { + } else if 
(mode.endsWith(JdbcSourceConnectorConfig.MODE_TIMESTAMP_INCREMENTING)) { tableQueue.add( new TimestampIncrementingTableQuerier( dialect, @@ -305,12 +306,12 @@ public void start(Map properties) { ); } } + maxRetriesPerQuerier = config.getInt(JdbcSourceConnectorConfig.QUERY_RETRIES_CONFIG); + pollExecutor = new JdbcSourceTaskPollExecutor(time, config, this::doPoll); running.set(true); taskThreadId.set(Thread.currentThread().getId()); log.info("Started JDBC source task"); - - maxRetriesPerQuerier = config.getInt(JdbcSourceConnectorConfig.QUERY_RETRIES_CONFIG); } private void validateColumnsExist( @@ -324,16 +325,16 @@ private void validateColumnsExist( Set columnNames = defnsById.keySet().stream().map(ColumnId::name) .map(String::toLowerCase).collect(Collectors.toSet()); - if ((mode.equals(JdbcSourceTaskConfig.MODE_INCREMENTING) - || mode.equals(JdbcSourceTaskConfig.MODE_TIMESTAMP_INCREMENTING)) + if ((mode.equals(JdbcSourceConnectorConfig.MODE_INCREMENTING) + || mode.equals(JdbcSourceConnectorConfig.MODE_TIMESTAMP_INCREMENTING)) && !incrementingColumn.isEmpty() && !columnNames.contains(incrementingColumn.toLowerCase(Locale.getDefault()))) { throw new ConfigException("Incrementing column: " + incrementingColumn + " does not exist."); } - if ((mode.equals(JdbcSourceTaskConfig.MODE_TIMESTAMP) - || mode.equals(JdbcSourceTaskConfig.MODE_TIMESTAMP_INCREMENTING)) + if ((mode.equals(JdbcSourceConnectorConfig.MODE_TIMESTAMP) + || mode.equals(JdbcSourceConnectorConfig.MODE_TIMESTAMP_INCREMENTING)) && !timestampColumns.isEmpty()) { Set missingTsColumns = timestampColumns.stream() @@ -443,14 +444,18 @@ protected void closeResources() { @Override public List poll() throws InterruptedException { + return pollExecutor.poll(); + } + + private List doPoll() { log.trace("Polling for new data"); // If the call to get tables has not completed we will not do anything. // This is only valid in table mode. 
Boolean tablesFetched = config.getBoolean(JdbcSourceTaskConfig.TABLES_FETCHED); - String query = config.getString(JdbcSourceTaskConfig.QUERY_CONFIG); + String query = config.getString(JdbcSourceConnectorConfig.QUERY_CONFIG); if (query.isEmpty() && !tablesFetched) { - final long sleepMs = config.getInt(JdbcSourceTaskConfig.POLL_INTERVAL_MS_CONFIG); + final long sleepMs = config.getInt(JdbcSourceConnectorConfig.POLL_INTERVAL_MS_CONFIG); log.trace("Waiting for tables to be fetched from the database. No records will be polled. " + "Waiting {} ms to poll", sleepMs); time.sleep(sleepMs); @@ -458,19 +463,19 @@ public List poll() throws InterruptedException { } Map consecutiveEmptyResults = tableQueue.stream().collect( - Collectors.toMap(Function.identity(), (q) -> 0)); + Collectors.toMap(Function.identity(), q -> 0)); while (running.get()) { final TableQuerier querier = tableQueue.peek(); if (!querier.querying()) { // If not in the middle of an update, wait for next update time final long nextUpdate = querier.getLastUpdate() - + config.getInt(JdbcSourceTaskConfig.POLL_INTERVAL_MS_CONFIG); + + config.getInt(JdbcSourceConnectorConfig.POLL_INTERVAL_MS_CONFIG); final long now = time.milliseconds(); final long sleepMs = Math.min(nextUpdate - now, 100); if (sleepMs > 0) { - log.trace("Waiting {} ms to poll {} next", nextUpdate - now, querier.toString()); + log.trace("Waiting {} ms to poll {} next", nextUpdate - now, querier); time.sleep(sleepMs); continue; // Re-check stop flag before continuing } @@ -478,10 +483,10 @@ public List poll() throws InterruptedException { final List results = new ArrayList<>(); try { - log.debug("Checking for next block of results from {}", querier.toString()); + log.debug("Checking for next block of results from {}", querier); querier.maybeStartQuery(cachedConnectionProvider.getConnection()); - int batchMaxRows = config.getInt(JdbcSourceTaskConfig.BATCH_MAX_ROWS_CONFIG); + int batchMaxRows = 
config.getInt(JdbcSourceConnectorConfig.BATCH_MAX_ROWS_CONFIG); boolean hadNext = true; while (results.size() < batchMaxRows && (hadNext = querier.next())) { results.add(querier.extractRecord()); @@ -496,7 +501,7 @@ public List poll() throws InterruptedException { if (results.isEmpty()) { consecutiveEmptyResults.compute(querier, (k, v) -> v + 1); - log.trace("No updates for {}", querier.toString()); + log.trace("No updates for {}", querier); if (Collections.min(consecutiveEmptyResults.values()) >= CONSECUTIVE_EMPTY_RESULTS_BEFORE_RETURN) { @@ -554,11 +559,19 @@ private void shutdown() { if (querier != null) { resetAndRequeueHead(querier, true); } + closePollExecutor(); closeResources(); } + private void closePollExecutor() { + if (pollExecutor != null) { + pollExecutor.close(); + pollExecutor = null; + } + } + private void resetAndRequeueHead(TableQuerier expectedHead, boolean resetOffset) { - log.debug("Resetting querier {}", expectedHead.toString()); + log.debug("Resetting querier {}", expectedHead); TableQuerier removedQuerier = tableQueue.poll(); assert removedQuerier == expectedHead; expectedHead.reset(time.milliseconds(), resetOffset); @@ -588,7 +601,7 @@ private void validateNonNullable( String columnName = defn.id().name(); if (columnName.equalsIgnoreCase(incrementingColumn)) { incrementingOptional = defn.isOptional(); - } else if (lowercaseTsColumns.contains(columnName.toLowerCase(Locale.getDefault()))) { + } else if (lowercaseTsColumns.contains(columnName.toLowerCase(Locale.ROOT))) { if (!defn.isOptional()) { atLeastOneTimestampNotOptional = true; } diff --git a/src/main/java/io/confluent/connect/jdbc/source/JdbcSourceTaskPollExecutor.java b/src/main/java/io/confluent/connect/jdbc/source/JdbcSourceTaskPollExecutor.java new file mode 100644 index 000000000..e34c419e8 --- /dev/null +++ b/src/main/java/io/confluent/connect/jdbc/source/JdbcSourceTaskPollExecutor.java @@ -0,0 +1,142 @@ +/* + * Copyright 2018 Confluent Inc. 
+ * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. You may obtain a copy of the + * License at + * + * http://www.confluent.io/confluent-community-license + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ + +package io.confluent.connect.jdbc.source; + +import static io.confluent.connect.jdbc.source.JdbcSourceConnectorConfig.POLL_MAX_WAIT_TIME_MS_CONFIG; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +import java.io.Closeable; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +import org.apache.kafka.common.utils.Time; +import org.apache.kafka.connect.errors.ConnectException; +import org.apache.kafka.connect.source.SourceRecord; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Utility responsible for managing execution of the JDBC source task poll operation. + *

<p> + * When the poll.max.wait.time.ms is set to zero, the executor will simply execute the poll + * operation directly in the current thread. + * Otherwise, the poll operation will be executed in a new thread and the executor will + * wait up to the configured poll.max.wait.time.ms time for the started thread to finish. + * If the thread is not finished in time, the executor will return a null list of + * source records, signaling there is no data to the worker. + * In the next poll call, the executor will either try to wait again for the previously + * started thread to finish, or create a new one and apply the same wait logic. + *
</p>
+ */ +final class JdbcSourceTaskPollExecutor implements Closeable { + private static final Logger LOG = LoggerFactory.getLogger(JdbcSourceTaskPollExecutor.class); + private static final List NO_DATA = null; + + private final AtomicReference pollFuture = new AtomicReference<>(); + + private final Time time; + private final ExecutorService pollTaskExecutor; + private final Supplier> pollOperation; + private final int pollMaxWaitTimeMs; + + /** + * Creates the {@link JdbcSourceTaskPollExecutor}. + * + * @param time + * the component providing the current time measurement + * @param config + * the configuration of the JDBC source connector + * @param pollOperation + * the poll operation function + */ + JdbcSourceTaskPollExecutor(Time time, JdbcSourceConnectorConfig config, + Supplier> pollOperation) { + this.time = requireNonNull(time, "time must not be null"); + this.pollOperation = requireNonNull(pollOperation, "pollOperation must not be null"); + pollMaxWaitTimeMs = requireNonNull(config, "config must not be null") + .getInt(POLL_MAX_WAIT_TIME_MS_CONFIG); + pollTaskExecutor = Executors.newSingleThreadExecutor(task -> { + Thread thread = new Thread(task); + thread.setName(String.format("%s-poll-thread-%s", + config.getString("name"), config.getString("task.id"))); + return thread; + }); + } + + List poll() throws InterruptedException { + if (pollMaxWaitTimeMs <= 0) { + // waiting without timeout + return pollOperation.get(); + } + PollingFuture polling = getOrCreatePollingFuture(); + try { + List result = polling.future.get(pollMaxWaitTimeMs, MILLISECONDS); + pollFuture.compareAndSet(polling, null); + return result; + } catch (@SuppressWarnings("unused") TimeoutException e) { + LOG.info("Polling exceeded maximum duration of {}ms the total elapsed time is {}ms", + pollMaxWaitTimeMs, polling.elapsed(time)); + return NO_DATA; + } catch (ExecutionException e) { + if (e.getCause() instanceof RuntimeException) { + throw (RuntimeException) e.getCause(); + } + throw new 
ConnectException("Error while polling", e); + } + } + + private PollingFuture getOrCreatePollingFuture() { + return pollFuture.updateAndGet(polling -> polling != null + ? polling : new PollingFuture(time.milliseconds(), + pollTaskExecutor.submit(pollOperation::get))); + } + + @Override + public void close() { + cancelCurrentPolling(); + pollTaskExecutor.shutdown(); + } + + private void cancelCurrentPolling() { + pollFuture.updateAndGet(polling -> { + if (polling != null) { + polling.future.cancel(true); + } + return null; + }); + } + + private static class PollingFuture { + private final long startTimeMillis; + private final Future> future; + + private PollingFuture(long startTimeMillis, Future> future) { + this.startTimeMillis = startTimeMillis; + this.future = future; + } + + private long elapsed(Time time) { + return time.milliseconds() - startTimeMillis; + } + } +} diff --git a/src/test/java/io/confluent/connect/jdbc/source/JdbcSourceConnectorConfigTest.java b/src/test/java/io/confluent/connect/jdbc/source/JdbcSourceConnectorConfigTest.java index 7059cafcc..166e9655a 100644 --- a/src/test/java/io/confluent/connect/jdbc/source/JdbcSourceConnectorConfigTest.java +++ b/src/test/java/io/confluent/connect/jdbc/source/JdbcSourceConnectorConfigTest.java @@ -28,6 +28,7 @@ import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; @@ -37,6 +38,7 @@ import io.confluent.connect.jdbc.source.JdbcSourceConnectorConfig.CachedRecommenderValues; import io.confluent.connect.jdbc.source.JdbcSourceConnectorConfig.CachingRecommender; +import static java.util.Collections.emptyList; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -94,7 +96,7 @@ public void testConnectionAttemptsAtLeastOne() { } @Test - public void 
testConfigTableNameRecommenderWithoutSchemaOrTableTypes() throws Exception { + public void testConfigTableNameRecommenderWithoutSchemaOrTableTypes() { props.put(JdbcSourceConnectorConfig.CONNECTION_URL_CONFIG, db.getUrl()); configDef = JdbcSourceConnectorConfig.baseConfigDef(); results = configDef.validate(props); @@ -104,7 +106,7 @@ public void testConfigTableNameRecommenderWithoutSchemaOrTableTypes() throws Exc } @Test - public void testConfigTableNameRecommenderWitSchemaAndWithoutTableTypes() throws Exception { + public void testConfigTableNameRecommenderWitSchemaAndWithoutTableTypes() { props.put(JdbcSourceConnectorConfig.CONNECTION_URL_CONFIG, db.getUrl()); props.put(JdbcSourceConnectorConfig.SCHEMA_PATTERN_CONFIG, "PRIVATE_SCHEMA"); configDef = JdbcSourceConnectorConfig.baseConfigDef(); @@ -115,7 +117,7 @@ public void testConfigTableNameRecommenderWitSchemaAndWithoutTableTypes() throws } @Test - public void testConfigTableNameRecommenderWithSchemaAndTableTypes() throws Exception { + public void testConfigTableNameRecommenderWithSchemaAndTableTypes() { props.put(JdbcSourceConnectorConfig.CONNECTION_URL_CONFIG, db.getUrl()); props.put(JdbcSourceConnectorConfig.SCHEMA_PATTERN_CONFIG, "PRIVATE_SCHEMA"); props.put(JdbcSourceConnectorConfig.TABLE_TYPE_CONFIG, "VIEW"); @@ -253,6 +255,59 @@ public void testTooLongTopicPrefix() { assertFalse(connectionAttemptsConfig.errorMessages().isEmpty()); } + @Test + public void testMaxPollingWaitTimeMs() { + // given + props.put(JdbcSourceConnectorConfig.POLL_MAX_WAIT_TIME_MS_CONFIG, "7777"); + + // when + Map validatedConfig = + JdbcSourceConnectorConfig.baseConfigDef().validateAll(props); + + // then + ConfigValue configValue = + validatedConfig.get(JdbcSourceConnectorConfig.POLL_MAX_WAIT_TIME_MS_CONFIG); + assertNotNull(configValue); + assertEquals(7777, configValue.value()); + assertEquals(emptyList(), configValue.errorMessages()); + assertEquals(emptyList(), configValue.recommendedValues()); + } + + @Test + public void 
testMaxPollingWaitTimeMsDefaultValue() { + // when + Map validatedConfig = + JdbcSourceConnectorConfig.baseConfigDef().validateAll(props); + + // then + ConfigValue configValue = + validatedConfig.get(JdbcSourceConnectorConfig.POLL_MAX_WAIT_TIME_MS_CONFIG); + assertNotNull(configValue); + assertEquals(1000, configValue.value()); + assertEquals(emptyList(), configValue.errorMessages()); + assertEquals(emptyList(), configValue.recommendedValues()); + } + + @Test + public void testMaxPollingWaitTimeMsInvalidValue() { + // given + props.put(JdbcSourceConnectorConfig.POLL_MAX_WAIT_TIME_MS_CONFIG, "-1"); + + // when + Map validatedConfig = + JdbcSourceConnectorConfig.baseConfigDef().validateAll(props); + + // then + ConfigValue configValue = + validatedConfig.get(JdbcSourceConnectorConfig.POLL_MAX_WAIT_TIME_MS_CONFIG); + assertNotNull(configValue); + assertEquals(-1, configValue.value()); + assertEquals(Arrays.asList("Invalid value -1 for configuration " + + "poll.max.wait.time.ms: Value must be at least 0"), + configValue.errorMessages()); + assertEquals(emptyList(), configValue.recommendedValues()); + } + @SuppressWarnings("unchecked") protected void assertContains(Collection actual, T... expected) { for (T e : expected) { diff --git a/src/test/java/io/confluent/connect/jdbc/source/JdbcSourceTaskPollExecutorTest.java b/src/test/java/io/confluent/connect/jdbc/source/JdbcSourceTaskPollExecutorTest.java new file mode 100644 index 000000000..f93305d6d --- /dev/null +++ b/src/test/java/io/confluent/connect/jdbc/source/JdbcSourceTaskPollExecutorTest.java @@ -0,0 +1,243 @@ +/* + * Copyright 2018 Confluent Inc. + * + * Licensed under the Confluent Community License (the "License"); you may not use + * this file except in compliance with the License. 
You may obtain a copy of the
+ * License at
+ *
+ * http://www.confluent.io/confluent-community-license
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OF ANY KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations under the License.
+ */
+
+package io.confluent.connect.jdbc.source;
+
+import static io.confluent.connect.jdbc.source.JdbcSourceConnectorConfig.POLL_MAX_WAIT_TIME_MS_CONFIG;
+import static java.util.Arrays.asList;
+import static java.util.Collections.unmodifiableList;
+import static java.util.concurrent.TimeUnit.MILLISECONDS;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertThrows;
+import static org.mockito.Mockito.doReturn;
+import static org.mockito.Mockito.mock;
+import static org.testcontainers.shaded.org.awaitility.Awaitility.await;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.connect.errors.ConnectException;
+import org.apache.kafka.connect.source.SourceRecord;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Unit tests for {@link JdbcSourceTaskPollExecutor}.
+ *
+ * <p>The executor under test is wired to {@link #poll()}, a stand-in poll function that drains
+ * {@link #recordsToPoll}, remembers the name of the thread it ran on and counts the calls that
+ * ended with the thread's interrupt flag set. An optional {@link #pollInterceptor} runs at the
+ * start of every call so a test can block the poll or make it throw.
+ */
+public class JdbcSourceTaskPollExecutorTest {
+  private static final Logger LOG = LoggerFactory.getLogger(JdbcSourceTaskPollExecutorTest.class);
+  /** Upper bound for every test method and for the latch wait in {@link #blockPollingWith}. */
+  private static final long TEST_TIMEOUT_MS = 1_000L;
+
+  /** Records handed out (and removed) by the next call to {@link #poll()}. */
+  private final Queue<SourceRecord> recordsToPoll = new LinkedBlockingQueue<>();
+  /** Name of the calling thread for each {@link #poll()} invocation, in call order. */
+  private final Queue<String> pollThreadNames = new LinkedBlockingQueue<>();
+  /** Number of {@link #poll()} calls that finished while their thread was interrupted. */
+  private final AtomicInteger interruptedPollThreads = new AtomicInteger();
+  /** Hook executed at the beginning of {@link #poll()}; {@code null} means no hook. */
+  private final AtomicReference<Runnable> pollInterceptor = new AtomicReference<>();
+  private JdbcSourceConnectorConfig config;
+
+  /** Creates a fresh config mock and resets all state shared between test methods. */
+  @Before
+  public void setup() {
+    config = mock(JdbcSourceConnectorConfig.class);
+    doReturn("test-connector").when(config).getString("name");
+    doReturn("7").when(config).getString("task.id");
+
+    recordsToPoll.clear();
+    pollThreadNames.clear();
+    interruptedPollThreads.set(0);
+    pollInterceptor.set(null);
+  }
+
+  @Test(timeout = TEST_TIMEOUT_MS)
+  public void shouldPollSourceRecordsInTheCurrentThreadWhenPollMaxWaitTimeIsZero() throws InterruptedException {
+    // given
+    List<SourceRecord> expected = asList(mock(SourceRecord.class), mock(SourceRecord.class));
+    recordsToPoll.addAll(expected);
+    doReturn(0).when(config).getInt(POLL_MAX_WAIT_TIME_MS_CONFIG);
+    try (JdbcSourceTaskPollExecutor tested = createPollExecutor()) {
+      // when
+      List<SourceRecord> result = tested.poll();
+
+      // then
+      assertEquals(expected, result);
+      assertEquals(1, pollThreadNames.size());
+      assertEquals(Thread.currentThread().getName(), pollThreadNames.peek());
+      assertEquals(0, interruptedPollThreads.get());
+    }
+  }
+
+  @Test(timeout = TEST_TIMEOUT_MS)
+  public void shouldPollSourceRecordsInAnotherThreadWhenPollMaxWaitTimeIsAboveZero() throws InterruptedException {
+    // given
+    List<SourceRecord> expected = asList(mock(SourceRecord.class));
+    recordsToPoll.addAll(expected);
+    doReturn(1).when(config).getInt(POLL_MAX_WAIT_TIME_MS_CONFIG);
+    try (JdbcSourceTaskPollExecutor tested = createPollExecutor()) {
+      // when
+      List<SourceRecord> result = tested.poll();
+
+      // then
+      assertEquals(expected, result);
+      assertEquals(1, pollThreadNames.size());
+      assertEquals("test-connector-poll-thread-7", pollThreadNames.peek());
+      assertNotEquals(Thread.currentThread().getName(), pollThreadNames.peek());
+      assertEquals(0, interruptedPollThreads.get());
+    }
+  }
+
+  @Test(timeout = TEST_TIMEOUT_MS)
+  public void shouldReturnNullSourceRecordsWhenPollMaxWaitTimeIsExceeded() throws InterruptedException {
+    // given
+    List<SourceRecord> expected = asList(mock(SourceRecord.class));
+    CountDownLatch pollLatch = new CountDownLatch(1);
+    recordsToPoll.addAll(expected);
+    blockPollingWith(pollLatch);
+    doReturn(100).when(config).getInt(POLL_MAX_WAIT_TIME_MS_CONFIG);
+    try (JdbcSourceTaskPollExecutor tested = createPollExecutor()) {
+      // when
+      List<SourceRecord> result1 = tested.poll();
+
+      // then
+      assertNull(result1);
+      assertEquals(1, pollThreadNames.size());
+      assertNotEquals(Thread.currentThread().getName(), pollThreadNames.peek());
+
+      // when called again
+      List<SourceRecord> result2 = tested.poll();
+
+      // then still no result
+      assertNull(result2);
+      assertEquals(1, pollThreadNames.size());
+
+      // when unblocked
+      pollLatch.countDown();
+      List<SourceRecord> result3 = tested.poll();
+
+      // then polling results should be finally returned
+      assertEquals(expected, result3);
+      assertEquals(1, pollThreadNames.size());
+      assertEquals(0, interruptedPollThreads.get());
+    }
+  }
+
+  @Test(timeout = TEST_TIMEOUT_MS)
+  public void shouldReturnResultOfPreviouslyStartedAndFinishedPolling() throws InterruptedException {
+    // given
+    List<SourceRecord> expected = asList(mock(SourceRecord.class), mock(SourceRecord.class));
+    List<SourceRecord> expectedNextResult = asList(mock(SourceRecord.class));
+    CountDownLatch pollLatch = new CountDownLatch(1);
+    recordsToPoll.addAll(expected);
+    doReturn(100).when(config).getInt(POLL_MAX_WAIT_TIME_MS_CONFIG);
+    blockPollingWith(pollLatch);
+    try (JdbcSourceTaskPollExecutor tested = createPollExecutor()) {
+      assertNull(tested.poll());
+      pollLatch.countDown();
+      // waiting for the started polling to be finished
+      await().atMost(200, MILLISECONDS).until(() -> recordsToPoll.isEmpty());
+
+      // when
+      List<SourceRecord> result1 = tested.poll();
+
+      // then
+      assertEquals(expected, result1);
+      assertEquals(1, pollThreadNames.size());
+      assertNotEquals(Thread.currentThread().getName(), pollThreadNames.peek());
+
+      // when new records are added and next polling is called
+      recordsToPoll.addAll(expectedNextResult);
+      List<SourceRecord> result2 = tested.poll();
+
+      // then the new results should be returned in a new thread
+      assertEquals(expectedNextResult, result2);
+      assertEquals(2, pollThreadNames.size());
+      assertEquals(0, interruptedPollThreads.get());
+    }
+  }
+
+  @Test(timeout = TEST_TIMEOUT_MS)
+  public void shouldRethrowPollRuntimeException() {
+    ConnectException expectedException = new ConnectException("something bad happened:(");
+    doReturn(50).when(config).getInt(POLL_MAX_WAIT_TIME_MS_CONFIG);
+    pollInterceptor.set(() -> {
+      throw expectedException;
+    });
+    try (JdbcSourceTaskPollExecutor tested = createPollExecutor()) {
+      // when
+      Exception caughtException = assertThrows(Exception.class, tested::poll);
+
+      // then
+      assertEquals(expectedException, caughtException);
+    }
+  }
+
+  @Test(timeout = TEST_TIMEOUT_MS)
+  public void shouldCancelCurrentPollingOnClose() throws InterruptedException {
+    // given
+    CountDownLatch pollLatch = new CountDownLatch(1);
+    recordsToPoll.offer(mock(SourceRecord.class));
+    doReturn(50).when(config).getInt(POLL_MAX_WAIT_TIME_MS_CONFIG);
+    blockPollingWith(pollLatch);
+    try (JdbcSourceTaskPollExecutor tested = createPollExecutor()) {
+      assertNull(tested.poll());
+
+      // when (try-with-resources calls close() a second time when this block exits)
+      tested.close();
+
+      // then
+      // The stand-in poll() empties the queue before it bumps the interrupt counter, so
+      // waiting for the empty queue alone would race with the assertion below.
+      await().atMost(200, MILLISECONDS)
+          .until(() -> recordsToPoll.isEmpty() && interruptedPollThreads.get() == 1);
+      assertEquals(1, interruptedPollThreads.get());
+    }
+  }
+
+  /** Builds the executor under test on top of the mocked config and {@link #poll()}. */
+  private JdbcSourceTaskPollExecutor createPollExecutor() {
+    return new JdbcSourceTaskPollExecutor(Time.SYSTEM, config, this::poll);
+  }
+
+  /**
+   * Installs a poll interceptor that parks the polling thread until {@code latch} is released.
+   *
+   * <p>The wait is capped at {@link #TEST_TIMEOUT_MS}; running into the cap raises an
+   * {@link AssertionError}. An interrupt ends the wait early, is logged, and leaves the
+   * thread's interrupt flag set so that {@link #poll()} can count it.
+   *
+   * @param latch latch the polling thread waits on while its count is above zero
+   */
+  private void blockPollingWith(CountDownLatch latch) {
+    pollInterceptor.set(() -> {
+      if (latch.getCount() > 0) {
+        try {
+          if (!latch.await(TEST_TIMEOUT_MS, MILLISECONDS)) {
+            throw new AssertionError("Timeout while waiting to unblock polling");
+          }
+        } catch (InterruptedException e) {
+          // Restore the flag: poll() inspects it after draining the queue.
+          Thread.currentThread().interrupt();
+          LOG.warn("Polling lock interrupted", e);
+        }
+      }
+    });
+  }
+
+  /**
+   * Stand-in poll function handed to the executor under test.
+   *
+   * <p>Records the calling thread's name, runs the interceptor if one is installed, drains
+   * {@link #recordsToPoll} and finally counts the call as interrupted when the thread's
+   * interrupt flag is set.
+   *
+   * @return an unmodifiable list of the drained records, empty when nothing was queued
+   */
+  private List<SourceRecord> poll() {
+    Runnable pollHandler = pollInterceptor.get();
+    pollThreadNames.offer(Thread.currentThread().getName());
+    if (pollHandler != null) {
+      pollHandler.run();
+    }
+    final List<SourceRecord> result = new ArrayList<>();
+    while (!recordsToPoll.isEmpty()) {
+      result.add(recordsToPoll.poll());
+    }
+    if (Thread.currentThread().isInterrupted()) {
+      interruptedPollThreads.incrementAndGet();
+    }
+    return unmodifiableList(result);
+  }
+}