Skip to content

Commit

Permalink
Ensure module instances are not shared by multiple JDBC catalogs
Browse files Browse the repository at this point in the history
When a JdbcPlugin is instantiated, creating an instance of a module at
that time causes the instance to be shared by all connectors provided by
the plugin. This is problematic when the module extends
AbstractConfigurationAwareModule, as it holds a reference to the
ConfigurationFactory, which is set dynamically during bootstrap. If
catalogs are loaded concurrently, this can lead to situations where a
connector accesses the configuration of another connector.
  • Loading branch information
piotrrzysko committed Nov 7, 2024
1 parent a68ed81 commit 33b6886
Show file tree
Hide file tree
Showing 25 changed files with 35 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import io.trino.spi.type.TypeManager;

import java.util.Map;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.isNullOrEmpty;
Expand All @@ -36,13 +37,13 @@ public class JdbcConnectorFactory
implements ConnectorFactory
{
private final String name;
private final Module module;
private final Supplier<Module> moduleSupplier;

public JdbcConnectorFactory(String name, Module module)
public JdbcConnectorFactory(String name, Supplier<Module> moduleSupplier)
{
checkArgument(!isNullOrEmpty(name), "name is null or empty");
this.name = name;
this.module = module;
this.moduleSupplier = requireNonNull(moduleSupplier, "moduleSupplier is null");
}

@Override
Expand All @@ -55,7 +56,6 @@ public String getName()
public Connector create(String catalogName, Map<String, String> requiredConfig, ConnectorContext context)
{
requireNonNull(requiredConfig, "requiredConfig is null");
requireNonNull(module, "module is null");
checkStrictSpiVersionMatch(context, this);

Bootstrap app = new Bootstrap(
Expand All @@ -65,7 +65,7 @@ public Connector create(String catalogName, Map<String, String> requiredConfig,
binder -> binder.bind(OpenTelemetry.class).toInstance(context.getOpenTelemetry()),
binder -> binder.bind(CatalogName.class).toInstance(new CatalogName(catalogName)),
new JdbcModule(),
module);
requireNonNull(moduleSupplier.get(), "moduleSupplier.get() is null"));

Injector injector = app
.doNotInitializeLogging()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import io.trino.spi.Plugin;
import io.trino.spi.connector.ConnectorFactory;

import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Strings.isNullOrEmpty;
import static io.airlift.configuration.ConfigurationAwareModule.combine;
Expand All @@ -28,23 +30,23 @@ public class JdbcPlugin
implements Plugin
{
private final String name;
private final Module module;
private final Supplier<Module> moduleSupplier;

public JdbcPlugin(String name, Module module)
public JdbcPlugin(String name, Supplier<Module> moduleSupplier)
{
checkArgument(!isNullOrEmpty(name), "name is null or empty");
this.name = name;
this.module = requireNonNull(module, "module is null");
this.moduleSupplier = requireNonNull(moduleSupplier, "moduleSupplier is null");
}

@Override
public Iterable<ConnectorFactory> getConnectorFactories()
{
return ImmutableList.of(new JdbcConnectorFactory(
name,
combine(
() -> combine(
new CredentialProviderModule(),
new ExtraCredentialsBasedIdentityCacheMappingModule(),
module)));
moduleSupplier.get())));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public static QueryRunner createH2QueryRunner(

createSchema(properties, "tpch");

queryRunner.installPlugin(new JdbcPlugin("base_jdbc", module));
queryRunner.installPlugin(new JdbcPlugin("base_jdbc", () -> module));
queryRunner.createCatalog("jdbc", "base_jdbc", properties);

copyTpchTables(queryRunner, "tpch", TINY_SCHEMA_NAME, tables);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public void testRuleBasedIdentifierCanBeUsedTogetherWithCacheBased()

private static ConnectorFactory getConnectorFactory()
{
Plugin plugin = new JdbcPlugin("jdbc", new TestingH2JdbcModule());
Plugin plugin = new JdbcPlugin("jdbc", TestingH2JdbcModule::new);
return getOnlyElement(plugin.getConnectorFactories());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class TestJmxStats
public void testJmxStatsExposure()
throws Exception
{
Plugin plugin = new JdbcPlugin("base_jdbc", new TestingH2JdbcModule());
Plugin plugin = new JdbcPlugin("base_jdbc", TestingH2JdbcModule::new);
ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories());
factory.create(
"test",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class ClickHousePlugin
{
public ClickHousePlugin()
{
super("clickhouse", new ClickHouseClientModule());
super("clickhouse", ClickHouseClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class DruidJdbcPlugin
{
public DruidJdbcPlugin()
{
super("druid", new DruidJdbcClientModule());
super("druid", DruidJdbcClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class ExamplePlugin
{
public ExamplePlugin()
{
super("example_jdbc", new ExampleClientModule());
super("example_jdbc", ExampleClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class ExasolPlugin
{
public ExasolPlugin()
{
super("exasol", new ExasolClientModule());
super("exasol", ExasolClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class IgnitePlugin
{
public IgnitePlugin()
{
super("ignite", new IgniteClientModule());
super("ignite", IgniteClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class MariaDbPlugin
{
public MariaDbPlugin()
{
super("mariadb", new MariaDbClientModule());
super("mariadb", MariaDbClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class MySqlPlugin
{
public MySqlPlugin()
{
super("mysql", new MySqlClientModule());
super("mysql", MySqlClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class OraclePlugin
{
public OraclePlugin()
{
super("oracle", new OracleClientModule());
super("oracle", OracleClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ public class PostgreSqlPlugin
{
public PostgreSqlPlugin()
{
super("postgresql", combine(new PostgreSqlClientModule(), new PostgreSqlConnectionFactoryModule()));
super("postgresql", () -> combine(new PostgreSqlClientModule(), new PostgreSqlConnectionFactoryModule()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ protected QueryRunner createQueryRunner()
.setAdditionalSetup(runner -> {
runner.installPlugin(new JdbcPlugin(
"counting_postgresql",
combine(new PostgreSqlClientModule(), new TestingPostgreSqlModule(connectionFactory))));
() -> combine(new PostgreSqlClientModule(), new TestingPostgreSqlModule(connectionFactory))));
runner.createCatalog("counting_postgresql", "counting_postgresql", ImmutableMap.of(
"connection-url", postgreSqlServer.getJdbcUrl(),
"connection-user", postgreSqlServer.getUser(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ protected QueryRunner createQueryRunner()
.setAdditionalSetup(runner -> {
runner.installPlugin(new JdbcPlugin(
"counting_postgresql",
combine(new PostgreSqlClientModule(), new TestingPostgreSqlModule(connectionFactory))));
() -> combine(new PostgreSqlClientModule(), new TestingPostgreSqlModule(connectionFactory))));
runner.createCatalog("counting_postgresql", "counting_postgresql", ImmutableMap.of(
"connection-url", postgreSqlServer.getJdbcUrl(),
"connection-user", postgreSqlServer.getUser(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class RedshiftPlugin
{
public RedshiftPlugin()
{
super("redshift", new RedshiftClientModule());
super("redshift", RedshiftClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ public class SingleStorePlugin
{
public SingleStorePlugin()
{
super("singlestore", new SingleStoreClientModule());
super("singlestore", SingleStoreClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class SnowflakePlugin
{
public SnowflakePlugin()
{
super("snowflake", new SnowflakeClientModule());
super("snowflake", SnowflakeClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ public class SqlServerPlugin
{
public SqlServerPlugin()
{
super("sqlserver", combine(new SqlServerClientModule(), new SqlServerConnectionFactoryModule()));
super("sqlserver", () -> combine(new SqlServerClientModule(), new SqlServerConnectionFactoryModule()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ protected QueryRunner createQueryRunner()
.setAdditionalSetup(runner -> {
runner.installPlugin(new JdbcPlugin(
"counting_sqlserver",
combine(new SqlServerClientModule(), new TestingSqlServerModule(connectionFactory))));
() -> combine(new SqlServerClientModule(), new TestingSqlServerModule(connectionFactory))));
runner.createCatalog("counting_sqlserver", "counting_sqlserver", ImmutableMap.of(
"connection-url", sqlServer.getJdbcUrl(),
"connection-user", sqlServer.getUsername(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ protected QueryRunner createQueryRunner()
.setAdditionalSetup(runner -> {
runner.installPlugin(new JdbcPlugin(
"counting_sqlserver",
combine(new SqlServerClientModule(), new TestingSqlServerModule(connectionFactory))));
() -> combine(new SqlServerClientModule(), new TestingSqlServerModule(connectionFactory))));
runner.createCatalog("counting_sqlserver", "counting_sqlserver", ImmutableMap.of(
"connection-url", sqlServer.getJdbcUrl(),
"connection-user", sqlServer.getUsername(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ public class VerticaPlugin
{
public VerticaPlugin()
{
super("vertica", new VerticaClientModule());
super("vertica", VerticaClientModule::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ public ScalarFunctionImplementation getScalarFunctionImplementation(FunctionId f
}))
.build()));
queryRunner.createCatalog("mock", "mock");
queryRunner.installPlugin(new JdbcPlugin("base_jdbc", new TestingH2JdbcModule()));
queryRunner.installPlugin(new JdbcPlugin("base_jdbc", TestingH2JdbcModule::new));
queryRunner.createCatalog("jdbc", "base_jdbc", TestingH2JdbcModule.createProperties());
for (String tableName : ImmutableList.of("orders", "nation", "region", "lineitem")) {
queryRunner.execute(format("CREATE TABLE %1$s AS SELECT * FROM tpch.tiny.%1$s WITH NO DATA", tableName));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ protected void configureCatalog(QueryRunner queryRunner)
queryRunner.installPlugin(new TpchPlugin());
queryRunner.createCatalog("tpch", "tpch", ImmutableMap.of());

queryRunner.installPlugin(new JdbcPlugin("base_jdbc", new TestingH2JdbcModule()));
queryRunner.installPlugin(new JdbcPlugin("base_jdbc", TestingH2JdbcModule::new));
Map<String, String> jdbcConfigurationProperties = TestingH2JdbcModule.createProperties();
queryRunner.createCatalog("jdbc", "base_jdbc", jdbcConfigurationProperties);

Expand Down

0 comments on commit 33b6886

Please sign in to comment.