Skip to content

Commit

Permalink
apacheGH-33475: [Java] Add parameter binding for Prepared Statements …
Browse files Browse the repository at this point in the history
…in JDBC driver (apache#38404)

This PR is a combination of apache#33961 and apache#14627. The goal is to support parametrized queries through the Arrow Flight SQL JDBC driver.

An Arrow Flight SQL server returns a Schema for the `PreparedStatement` parameters. The driver then converts the `Field` list associated with the Schema into a list of `AvaticaParameter`. When the user sets values for the parameters, Avatica generates a list of `TypedValue`, which we then bind to each parameter vector. This conversion between Arrow and Avatica is handled by implementations of a `AvaticaParameterConverter` interface for each Arrow type. This interface which provides 2 methods:
- createParameter: Create an `AvaticaParameter` from the given Arrow `Field`.
- bindParameter: Cast the given `TypedValue` and bind it to the `FieldVector` at the specified index.

This PR purposely leaves out a few features:
- We currently naively cast the `TypedValue` values assuming users set the type correctly. If this cast fails, we raise an exception letting the user know that the cast is not supported. This could be improved in subsequent PRs to do smarter conversions from other types.
- We currently don't provide conversions for complex types such as List, Map, Struct, Union, Interval, and Duration. The stubs are there so they can be implemented as needed.
- Tests for specific types have not been implemented. I'm not very familiar with a lot of these JDBC types so it's hard to implement rigorous tets.

* Closes: apache#33475
* Closes: apache#35536

Authored-by: Diego Fernandez <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
aiguofer authored and loicalleyne committed Nov 13, 2023
1 parent 1da5d18 commit bd9a646
Show file tree
Hide file tree
Showing 34 changed files with 1,809 additions and 114 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@

package org.apache.arrow.driver.jdbc;

import static org.apache.arrow.driver.jdbc.utils.ConvertUtils.convertArrowFieldsToColumnMetaDataList;

import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.Properties;
import java.util.TimeZone;

import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.AvaticaConnection;
import org.apache.calcite.avatica.AvaticaFactory;
import org.apache.calcite.avatica.AvaticaResultSetMetaData;
Expand Down Expand Up @@ -89,12 +86,6 @@ public ArrowFlightPreparedStatement newPreparedStatement(
ArrowFlightSqlClientHandler.PreparedStatement preparedStatement =
flightConnection.getMeta().getPreparedStatement(statementHandle);

if (preparedStatement == null) {
preparedStatement = flightConnection.getClientHandler().prepare(signature.sql);
}
final Schema resultSetSchema = preparedStatement.getDataSetSchema();
signature.columns.addAll(convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields()));

return ArrowFlightPreparedStatement.newPreparedStatement(
flightConnection, preparedStatement, statementHandle,
signature, resultType, resultSetConcurrency, resultSetHoldability);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo(
final TimeZone timeZone = TimeZone.getDefault();
final QueryState state = new QueryState();

final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null);
final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null);

final AvaticaResultSetMetaData resultSetMetaData =
new AvaticaResultSetMetaData(null, null, signature);
Expand Down Expand Up @@ -154,11 +154,7 @@ private void populateDataForCurrentFlightStream() throws SQLException {
currentVectorSchemaRoot = originalRoot;
}

if (schema != null) {
populateData(currentVectorSchemaRoot, schema);
} else {
populateData(currentVectorSchemaRoot);
}
populateData(currentVectorSchemaRoot, schema);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,18 @@

package org.apache.arrow.driver.jdbc;

import static java.util.Objects.isNull;

import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.TimeZone;

import org.apache.arrow.driver.jdbc.utils.ConvertUtils;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.AvaticaResultSet;
import org.apache.calcite.avatica.AvaticaResultSetMetaData;
Expand Down Expand Up @@ -74,7 +72,7 @@ public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot(
final TimeZone timeZone = TimeZone.getDefault();
final QueryState state = new QueryState();

final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null);
final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null);

final AvaticaResultSetMetaData resultSetMetaData =
new AvaticaResultSetMetaData(null, null, signature);
Expand All @@ -93,17 +91,12 @@ protected AvaticaResultSet execute() throws SQLException {
}

void populateData(final VectorSchemaRoot vectorSchemaRoot) {
final List<Field> fields = vectorSchemaRoot.getSchema().getFields();
final List<ColumnMetaData> columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(fields);
signature.columns.clear();
signature.columns.addAll(columns);

this.vectorSchemaRoot = vectorSchemaRoot;
execute2(new ArrowFlightJdbcCursor(vectorSchemaRoot), this.signature.columns);
populateData(vectorSchemaRoot, null);
}

void populateData(final VectorSchemaRoot vectorSchemaRoot, final Schema schema) {
final List<ColumnMetaData> columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(schema.getFields());
Schema currentSchema = schema == null ? vectorSchemaRoot.getSchema() : schema;
final List<ColumnMetaData> columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(currentSchema.getFields());
signature.columns.clear();
signature.columns.addAll(columns);

Expand Down Expand Up @@ -137,7 +130,7 @@ public void close() {
} catch (final Exception e) {
exceptions.add(e);
}
if (!isNull(statement)) {
if (!Objects.isNull(statement)) {
try {
super.close();
} catch (final Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.arrow.driver.jdbc;

import static java.lang.String.format;

import java.sql.Connection;
import java.sql.SQLException;
import java.sql.SQLTimeoutException;
Expand All @@ -29,7 +27,10 @@
import java.util.concurrent.ConcurrentHashMap;

import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement;
import org.apache.arrow.driver.jdbc.utils.AvaticaParameterBinder;
import org.apache.arrow.driver.jdbc.utils.ConvertUtils;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.AvaticaConnection;
import org.apache.calcite.avatica.AvaticaParameter;
import org.apache.calcite.avatica.ColumnMetaData;
Expand All @@ -54,12 +55,20 @@ public ArrowFlightMetaImpl(final AvaticaConnection connection) {
setDefaultConnectionProperties();
}

static Signature newSignature(final String sql) {
/**
* Construct a signature.
*/
static Signature newSignature(final String sql, Schema resultSetSchema, Schema parameterSchema) {
List<ColumnMetaData> columnMetaData = resultSetSchema == null ?
new ArrayList<>() : ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields());
List<AvaticaParameter> parameters = parameterSchema == null ?
new ArrayList<>() : ConvertUtils.convertArrowFieldsToAvaticaParameters(parameterSchema.getFields());

return new Signature(
new ArrayList<ColumnMetaData>(),
columnMetaData,
sql,
Collections.<AvaticaParameter>emptyList(),
Collections.<String, Object>emptyMap(),
parameters,
Collections.emptyMap(),
null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor
StatementType.SELECT
);
Expand All @@ -84,23 +93,28 @@ public void commit(final ConnectionHandle connectionHandle) {
public ExecuteResult execute(final StatementHandle statementHandle,
final List<TypedValue> typedValues, final long maxRowCount) {
Preconditions.checkArgument(connection.id.equals(statementHandle.connectionId),
"Connection IDs are not consistent");
"Connection IDs are not consistent");
PreparedStatement preparedStatement = getPreparedStatement(statementHandle);

if (preparedStatement == null) {
throw new IllegalStateException("Prepared statement not found: " + statementHandle);
}


new AvaticaParameterBinder(preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator())
.bind(typedValues);

if (statementHandle.signature == null) {
// Update query
final StatementHandleKey key = new StatementHandleKey(statementHandle);
PreparedStatement preparedStatement = statementHandlePreparedStatementMap.get(key);
if (preparedStatement == null) {
throw new IllegalStateException("Prepared statement not found: " + statementHandle);
}
long updatedCount = preparedStatement.executeUpdate();
return new ExecuteResult(Collections.singletonList(MetaResultSet.count(statementHandle.connectionId,
statementHandle.id, updatedCount)));
statementHandle.id, updatedCount)));
} else {
// TODO Why is maxRowCount ignored?
return new ExecuteResult(
Collections.singletonList(MetaResultSet.create(
statementHandle.connectionId, statementHandle.id,
true, statementHandle.signature, null)));
Collections.singletonList(MetaResultSet.create(
statementHandle.connectionId, statementHandle.id,
true, statementHandle.signature, null)));
}
}

Expand All @@ -114,7 +128,23 @@ public ExecuteResult execute(final StatementHandle statementHandle,
public ExecuteBatchResult executeBatch(final StatementHandle statementHandle,
final List<List<TypedValue>> parameterValuesList)
throws IllegalStateException {
throw new IllegalStateException("executeBatch not implemented.");
Preconditions.checkArgument(connection.id.equals(statementHandle.connectionId),
"Connection IDs are not consistent");
PreparedStatement preparedStatement = getPreparedStatement(statementHandle);

if (preparedStatement == null) {
throw new IllegalStateException("Prepared statement not found: " + statementHandle);
}

final AvaticaParameterBinder binder = new AvaticaParameterBinder(preparedStatement,
((ArrowFlightConnection) connection).getBufferAllocator());
for (int i = 0; i < parameterValuesList.size(); i++) {
binder.bind(parameterValuesList.get(i), i);
}

// Update query
long[] updatedCounts = {preparedStatement.executeUpdate()};
return new ExecuteBatchResult(updatedCounts);
}

@Override
Expand All @@ -126,18 +156,24 @@ public Frame fetch(final StatementHandle statementHandle, final long offset,
* the results.
*/
throw AvaticaConnection.HELPER.wrap(
format("%s does not use frames.", this),
String.format("%s does not use frames.", this),
AvaticaConnection.HELPER.unsupported());
}

private PreparedStatement prepareForHandle(final String query, StatementHandle handle) {
final PreparedStatement preparedStatement =
((ArrowFlightConnection) connection).getClientHandler().prepare(query);
handle.signature = newSignature(query, preparedStatement.getDataSetSchema(),
preparedStatement.getParameterSchema());
statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
return preparedStatement;
}

@Override
public StatementHandle prepare(final ConnectionHandle connectionHandle,
final String query, final long maxRowCount) {
final StatementHandle handle = super.createStatement(connectionHandle);
handle.signature = newSignature(query);
final PreparedStatement preparedStatement =
((ArrowFlightConnection) connection).getClientHandler().prepare(query);
statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
prepareForHandle(query, handle);
return handle;
}

Expand All @@ -157,20 +193,18 @@ public ExecuteResult prepareAndExecute(final StatementHandle handle,
final PrepareCallback callback)
throws NoSuchStatementException {
try {
final PreparedStatement preparedStatement =
((ArrowFlightConnection) connection).getClientHandler().prepare(query);
PreparedStatement preparedStatement = prepareForHandle(query, handle);
final StatementType statementType = preparedStatement.getType();
statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
final Signature signature = newSignature(query);

final long updateCount =
statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1;
synchronized (callback.getMonitor()) {
callback.clear();
callback.assign(signature, null, updateCount);
callback.assign(handle.signature, null, updateCount);
}
callback.execute();
final MetaResultSet metaResultSet = MetaResultSet.create(handle.connectionId, handle.id,
false, signature, null);
false, handle.signature, null);
return new ExecuteResult(Collections.singletonList(metaResultSet));
} catch (SQLTimeoutException e) {
// So far AvaticaStatement(executeInternal) only handles NoSuchStatement and Runtime Exceptions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@

package org.apache.arrow.driver.jdbc;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;

import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler;
import org.apache.arrow.driver.jdbc.utils.ConvertUtils;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.AvaticaPreparedStatement;
import org.apache.calcite.avatica.Meta.Signature;
import org.apache.calcite.avatica.Meta.StatementHandle;
Expand All @@ -50,36 +47,6 @@ private ArrowFlightPreparedStatement(final ArrowFlightConnection connection,
this.preparedStatement = Preconditions.checkNotNull(preparedStatement);
}

/**
* Creates a new {@link ArrowFlightPreparedStatement} from the provided information.
*
* @param connection the {@link Connection} to use.
* @param statementHandle the {@link StatementHandle} to use.
* @param signature the {@link Signature} to use.
* @param resultSetType the ResultSet type.
* @param resultSetConcurrency the ResultSet concurrency.
* @param resultSetHoldability the ResultSet holdability.
* @return a new {@link PreparedStatement}.
* @throws SQLException on error.
*/
static ArrowFlightPreparedStatement createNewPreparedStatement(
final ArrowFlightConnection connection,
final StatementHandle statementHandle,
final Signature signature,
final int resultSetType,
final int resultSetConcurrency,
final int resultSetHoldability) throws SQLException {

final ArrowFlightSqlClientHandler.PreparedStatement prepare = connection.getClientHandler().prepare(signature.sql);
final Schema resultSetSchema = prepare.getDataSetSchema();

signature.columns.addAll(ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields()));

return new ArrowFlightPreparedStatement(
connection, prepare, statementHandle,
signature, resultSetType, resultSetConcurrency, resultSetHoldability);
}

static ArrowFlightPreparedStatement newPreparedStatement(final ArrowFlightConnection connection,
final ArrowFlightSqlClientHandler.PreparedStatement preparedStmt,
final StatementHandle statementHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.Meta.StatementType;
import org.slf4j.Logger;
Expand Down Expand Up @@ -206,6 +207,15 @@ public interface PreparedStatement extends AutoCloseable {
*/
Schema getDataSetSchema();

/**
* Gets the {@link Schema} of the parameters for this {@link PreparedStatement}.
*
* @return {@link Schema}.
*/
Schema getParameterSchema();

void setParameters(VectorSchemaRoot parameters);

@Override
void close();
}
Expand Down Expand Up @@ -241,6 +251,16 @@ public Schema getDataSetSchema() {
return preparedStatement.getResultSetSchema();
}

@Override
public Schema getParameterSchema() {
return preparedStatement.getParameterSchema();
}

@Override
public void setParameters(VectorSchemaRoot parameters) {
preparedStatement.setParameters(parameters);
}

@Override
public void close() {
try {
Expand Down
Loading

0 comments on commit bd9a646

Please sign in to comment.