Skip to content

Commit

Permalink
Inject ConnectionFactory in Jdbc based connectors
Browse files Browse the repository at this point in the history
  • Loading branch information
kokosing committed Jul 4, 2019
1 parent 2a632d3 commit 8082baf
Show file tree
Hide file tree
Showing 11 changed files with 146 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.google.inject.Binder;
import com.google.inject.Module;
import com.google.inject.Provides;
import com.google.inject.Singleton;
import org.h2.Driver;

import java.util.Map;
Expand All @@ -35,9 +36,16 @@ public void configure(Binder binder)
}

@Provides
public JdbcClient provideJdbcClient(BaseJdbcConfig config)
public JdbcClient provideJdbcClient(BaseJdbcConfig config, ConnectionFactory connectionFactory)
{
return new BaseJdbcClient(config, "\"", new DriverConnectionFactory(new Driver(), config));
return new BaseJdbcClient(config, "\"", connectionFactory);
}

@Provides
@Singleton
public ConnectionFactory getConnectionFactory(BaseJdbcConfig config)
{
return new DriverConnectionFactory(new Driver(), config);
}

public static Map<String, String> createProperties()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import com.fasterxml.jackson.core.JsonParser;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableSet;
import com.mysql.jdbc.Driver;
import com.mysql.jdbc.Statement;
import io.airlift.json.ObjectMapperProvider;
import io.airlift.slice.DynamicSliceOutput;
Expand All @@ -27,7 +26,6 @@
import io.prestosql.plugin.jdbc.BaseJdbcConfig;
import io.prestosql.plugin.jdbc.ColumnMapping;
import io.prestosql.plugin.jdbc.ConnectionFactory;
import io.prestosql.plugin.jdbc.DriverConnectionFactory;
import io.prestosql.plugin.jdbc.JdbcColumnHandle;
import io.prestosql.plugin.jdbc.JdbcIdentity;
import io.prestosql.plugin.jdbc.JdbcTableHandle;
Expand Down Expand Up @@ -55,7 +53,6 @@
import java.sql.SQLException;
import java.util.Collection;
import java.util.Optional;
import java.util.Properties;
import java.util.function.BiFunction;

import static com.fasterxml.jackson.core.JsonFactory.Feature.CANONICALIZE_FIELD_NAMES;
Expand All @@ -66,7 +63,6 @@
import static com.mysql.jdbc.SQLError.SQL_STATE_SYNTAX_ERROR;
import static io.airlift.slice.Slices.utf8Slice;
import static io.prestosql.plugin.jdbc.ColumnMapping.DISABLE_PUSHDOWN;
import static io.prestosql.plugin.jdbc.DriverConnectionFactory.basicConnectionProperties;
import static io.prestosql.plugin.jdbc.JdbcErrorCode.JDBC_ERROR;
import static io.prestosql.plugin.jdbc.StandardColumnMappings.realWriteFunction;
import static io.prestosql.plugin.jdbc.StandardColumnMappings.timestampWriteFunctionUsingSqlTimestamp;
Expand All @@ -91,38 +87,12 @@ public class MySqlClient
private final Type jsonType;

@Inject
public MySqlClient(BaseJdbcConfig config, MySqlConfig mySqlConfig, TypeManager typeManager)
throws SQLException
public MySqlClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, TypeManager typeManager)
{
super(config, "`", connectionFactory(config, mySqlConfig));
super(config, "`", connectionFactory);
this.jsonType = typeManager.getType(new TypeSignature(StandardTypes.JSON));
}

private static ConnectionFactory connectionFactory(BaseJdbcConfig config, MySqlConfig mySqlConfig)
throws SQLException
{
Properties connectionProperties = basicConnectionProperties(config);
connectionProperties.setProperty("useInformationSchema", "true");
connectionProperties.setProperty("nullCatalogMeansCurrent", "false");
connectionProperties.setProperty("useUnicode", "true");
connectionProperties.setProperty("characterEncoding", "utf8");
connectionProperties.setProperty("tinyInt1isBit", "false");
if (mySqlConfig.isAutoReconnect()) {
connectionProperties.setProperty("autoReconnect", String.valueOf(mySqlConfig.isAutoReconnect()));
connectionProperties.setProperty("maxReconnects", String.valueOf(mySqlConfig.getMaxReconnects()));
}
if (mySqlConfig.getConnectionTimeout() != null) {
connectionProperties.setProperty("connectTimeout", String.valueOf(mySqlConfig.getConnectionTimeout().toMillis()));
}

return new DriverConnectionFactory(
new Driver(),
config.getConnectionUrl(),
Optional.ofNullable(config.getUserCredentialName()),
Optional.ofNullable(config.getPasswordCredentialName()),
connectionProperties);
}

@Override
protected Collection<String> listSchemas(Connection connection)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,23 @@
package io.prestosql.plugin.mysql;

import com.google.inject.Binder;
import com.google.inject.Provides;
import com.google.inject.Scopes;
import com.google.inject.Singleton;
import com.mysql.jdbc.Driver;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.prestosql.plugin.jdbc.BaseJdbcConfig;
import io.prestosql.plugin.jdbc.ConnectionFactory;
import io.prestosql.plugin.jdbc.DriverConnectionFactory;
import io.prestosql.plugin.jdbc.JdbcClient;

import java.sql.SQLException;
import java.util.Optional;
import java.util.Properties;

import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.configuration.ConfigBinder.configBinder;
import static io.prestosql.plugin.jdbc.DriverConnectionFactory.basicConnectionProperties;

public class MySqlClientModule
extends AbstractConfigurationAwareModule
Expand All @@ -49,4 +55,31 @@ private static void ensureCatalogIsEmpty(String connectionUrl)
throw new RuntimeException(e);
}
}

@Provides
@Singleton
public static ConnectionFactory createConnectionFactory(BaseJdbcConfig config, MySqlConfig mySqlConfig)
throws SQLException
{
Properties connectionProperties = basicConnectionProperties(config);
connectionProperties.setProperty("useInformationSchema", "true");
connectionProperties.setProperty("nullCatalogMeansCurrent", "false");
connectionProperties.setProperty("useUnicode", "true");
connectionProperties.setProperty("characterEncoding", "utf8");
connectionProperties.setProperty("tinyInt1isBit", "false");
if (mySqlConfig.isAutoReconnect()) {
connectionProperties.setProperty("autoReconnect", String.valueOf(mySqlConfig.isAutoReconnect()));
connectionProperties.setProperty("maxReconnects", String.valueOf(mySqlConfig.getMaxReconnects()));
}
if (mySqlConfig.getConnectionTimeout() != null) {
connectionProperties.setProperty("connectTimeout", String.valueOf(mySqlConfig.getConnectionTimeout().toMillis()));
}

return new DriverConnectionFactory(
new Driver(),
config.getConnectionUrl(),
Optional.ofNullable(config.getUserCredentialName()),
Optional.ofNullable(config.getPasswordCredentialName()),
connectionProperties);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import io.prestosql.plugin.jdbc.BlockReadFunction;
import io.prestosql.plugin.jdbc.BlockWriteFunction;
import io.prestosql.plugin.jdbc.ColumnMapping;
import io.prestosql.plugin.jdbc.DriverConnectionFactory;
import io.prestosql.plugin.jdbc.ConnectionFactory;
import io.prestosql.plugin.jdbc.JdbcColumnHandle;
import io.prestosql.plugin.jdbc.JdbcIdentity;
import io.prestosql.plugin.jdbc.JdbcOutputTableHandle;
Expand All @@ -32,7 +32,6 @@
import io.prestosql.spi.type.ArrayType;
import io.prestosql.spi.type.Type;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hbase.client.Scan;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.phoenix.compile.QueryPlan;
Expand All @@ -47,7 +46,6 @@
import org.apache.phoenix.iterate.TableResultIterator;
import org.apache.phoenix.jdbc.DelegatePreparedStatement;
import org.apache.phoenix.jdbc.PhoenixConnection;
import org.apache.phoenix.jdbc.PhoenixEmbeddedDriver.ConnectionInfo;
import org.apache.phoenix.jdbc.PhoenixPreparedStatement;
import org.apache.phoenix.jdbc.PhoenixResultSet;
import org.apache.phoenix.mapreduce.PhoenixInputSplit;
Expand All @@ -62,7 +60,6 @@
import java.io.IOException;
import java.sql.Array;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.JDBCType;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
Expand All @@ -71,9 +68,7 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Properties;
import java.util.function.BiFunction;

import static com.google.common.base.MoreObjects.firstNonNull;
Expand All @@ -84,6 +79,7 @@
import static io.prestosql.plugin.jdbc.StandardColumnMappings.timeWriteFunction;
import static io.prestosql.plugin.jdbc.StandardColumnMappings.varcharColumnMapping;
import static io.prestosql.plugin.phoenix.MetadataUtil.toPhoenixSchemaName;
import static io.prestosql.plugin.phoenix.PhoenixClientModule.getConnectionProperties;
import static io.prestosql.plugin.phoenix.PhoenixErrorCode.PHOENIX_METADATA_ERROR;
import static io.prestosql.plugin.phoenix.PhoenixErrorCode.PHOENIX_QUERY_ERROR;
import static io.prestosql.plugin.phoenix.PhoenixMetadata.DEFAULT_SCHEMA;
Expand All @@ -110,7 +106,6 @@
import static java.sql.Types.TIME_WITH_TIMEZONE;
import static java.sql.Types.VARCHAR;
import static java.util.Collections.nCopies;
import static java.util.Objects.requireNonNull;
import static org.apache.phoenix.coprocessor.BaseScannerRegionObserver.SKIP_REGION_BOUNDARY_CHECK;
import static org.apache.phoenix.util.SchemaUtil.ESCAPE_CHARACTER;

Expand All @@ -120,51 +115,16 @@ public class PhoenixClient
private final Configuration configuration;

@Inject
public PhoenixClient(PhoenixConfig config)
throws SQLException
{
this(config, getConnectionProperties(config));
}

private PhoenixClient(PhoenixConfig config, Properties connectionProperties)
public PhoenixClient(PhoenixConfig config, ConnectionFactory connectionFactory)
throws SQLException
{
super(
ESCAPE_CHARACTER,
new DriverConnectionFactory(
DriverManager.getDriver(config.getConnectionUrl()),
config.getConnectionUrl(),
Optional.empty(),
Optional.empty(),
connectionProperties),
connectionFactory,
config.isCaseInsensitiveNameMatching(),
config.getCaseInsensitiveNameMatchingCacheTtl());
this.configuration = new Configuration(false);
connectionProperties.forEach((k, v) -> configuration.set((String) k, (String) v));
}

private static Properties getConnectionProperties(PhoenixConfig config)
throws SQLException
{
requireNonNull(config, "config is null");
Configuration resourcesConfig = readConfig(config);
Properties connectionProperties = new Properties();
for (Entry<String, String> entry : resourcesConfig) {
connectionProperties.put(entry.getKey(), entry.getValue());
}

ConnectionInfo connectionInfo = ConnectionInfo.create(config.getConnectionUrl());
connectionProperties.putAll(connectionInfo.asProps().asMap());
return connectionProperties;
}

private static Configuration readConfig(PhoenixConfig config)
{
Configuration result = new Configuration(false);
for (String resourcePath : config.getResourceConfigFiles()) {
result.addResource(new Path(resourcePath));
}
return result;
getConnectionProperties(config).forEach((k, v) -> configuration.set((String) k, (String) v));
}

public PhoenixConnection getConnection(JdbcIdentity identity)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
package io.prestosql.plugin.phoenix;

import com.google.inject.Binder;
import com.google.inject.Provides;
import com.google.inject.Scopes;
import com.google.inject.Singleton;
import io.airlift.configuration.AbstractConfigurationAwareModule;
import io.prestosql.plugin.jdbc.ConnectionFactory;
import io.prestosql.plugin.jdbc.DriverConnectionFactory;
import io.prestosql.plugin.jdbc.InternalBaseJdbc;
import io.prestosql.plugin.jdbc.JdbcClient;
import io.prestosql.plugin.jdbc.JdbcPageSinkProvider;
Expand All @@ -25,9 +29,16 @@
import io.prestosql.spi.connector.ConnectorRecordSetProvider;
import io.prestosql.spi.connector.ConnectorSplitManager;
import io.prestosql.spi.type.TypeManager;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.phoenix.jdbc.PhoenixDriver;
import org.apache.phoenix.jdbc.PhoenixEmbeddedDriver;

import java.sql.DriverManager;
import java.sql.SQLException;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;

import static com.google.common.base.Preconditions.checkArgument;
import static io.prestosql.plugin.phoenix.PhoenixErrorCode.PHOENIX_CONFIG_ERROR;
Expand Down Expand Up @@ -73,4 +84,40 @@ private void checkConfiguration(String connectionUrl)
throw new PrestoException(PHOENIX_CONFIG_ERROR, e);
}
}

@Provides
@Singleton
public ConnectionFactory getConnectionFactory(PhoenixConfig config)
throws SQLException
{
return new DriverConnectionFactory(
DriverManager.getDriver(config.getConnectionUrl()),
config.getConnectionUrl(),
Optional.empty(),
Optional.empty(),
getConnectionProperties(config));
}

public static Properties getConnectionProperties(PhoenixConfig config)
throws SQLException
{
Configuration resourcesConfig = readConfig(config);
Properties connectionProperties = new Properties();
for (Map.Entry<String, String> entry : resourcesConfig) {
connectionProperties.put(entry.getKey(), entry.getValue());
}

PhoenixEmbeddedDriver.ConnectionInfo connectionInfo = PhoenixEmbeddedDriver.ConnectionInfo.create(config.getConnectionUrl());
connectionProperties.putAll(connectionInfo.asProps().asMap());
return connectionProperties;
}

private static Configuration readConfig(PhoenixConfig config)
{
Configuration result = new Configuration(false);
for (String resourcePath : config.getResourceConfigFiles()) {
result.addResource(new Path(resourcePath));
}
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import io.prestosql.plugin.jdbc.BlockReadFunction;
import io.prestosql.plugin.jdbc.BlockWriteFunction;
import io.prestosql.plugin.jdbc.ColumnMapping;
import io.prestosql.plugin.jdbc.DriverConnectionFactory;
import io.prestosql.plugin.jdbc.ConnectionFactory;
import io.prestosql.plugin.jdbc.JdbcColumnHandle;
import io.prestosql.plugin.jdbc.JdbcIdentity;
import io.prestosql.plugin.jdbc.JdbcTableHandle;
Expand All @@ -47,7 +47,6 @@
import io.prestosql.spi.type.TypeManager;
import io.prestosql.spi.type.TypeSignature;
import io.prestosql.spi.type.VarcharType;
import org.postgresql.Driver;
import org.postgresql.core.TypeInfo;
import org.postgresql.jdbc.PgConnection;
import org.postgresql.util.PGobject;
Expand Down Expand Up @@ -110,9 +109,13 @@ public class PostgreSqlClient
private final boolean supportArrays;

@Inject
public PostgreSqlClient(BaseJdbcConfig config, PostgreSqlConfig postgreSqlConfig, TypeManager typeManager)
public PostgreSqlClient(
BaseJdbcConfig config,
PostgreSqlConfig postgreSqlConfig,
ConnectionFactory connectionFactory,
TypeManager typeManager)
{
super(config, "\"", new DriverConnectionFactory(new Driver(), config));
super(config, "\"", connectionFactory);
this.jsonType = typeManager.getType(new TypeSignature(StandardTypes.JSON));
this.uuidType = typeManager.getType(new TypeSignature(StandardTypes.UUID));

Expand Down
Loading

0 comments on commit 8082baf

Please sign in to comment.