Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-33475: [Java] Add parameter binding for Prepared Statements in JDBC driver #38404

Merged
merged 23 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,13 @@

package org.apache.arrow.driver.jdbc;

import static org.apache.arrow.driver.jdbc.utils.ConvertUtils.convertArrowFieldsToColumnMetaDataList;

import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.Properties;
import java.util.TimeZone;

import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.AvaticaConnection;
import org.apache.calcite.avatica.AvaticaFactory;
import org.apache.calcite.avatica.AvaticaResultSetMetaData;
Expand Down Expand Up @@ -89,12 +86,6 @@ public ArrowFlightPreparedStatement newPreparedStatement(
ArrowFlightSqlClientHandler.PreparedStatement preparedStatement =
flightConnection.getMeta().getPreparedStatement(statementHandle);

if (preparedStatement == null) {
preparedStatement = flightConnection.getClientHandler().prepare(signature.sql);
}
final Schema resultSetSchema = preparedStatement.getDataSetSchema();
signature.columns.addAll(convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields()));

aiguofer marked this conversation as resolved.
Show resolved Hide resolved
return ArrowFlightPreparedStatement.newPreparedStatement(
flightConnection, preparedStatement, statementHandle,
signature, resultType, resultSetConcurrency, resultSetHoldability);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ static ArrowFlightJdbcFlightStreamResultSet fromFlightInfo(
final TimeZone timeZone = TimeZone.getDefault();
final QueryState state = new QueryState();

final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null);
final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null);

final AvaticaResultSetMetaData resultSetMetaData =
new AvaticaResultSetMetaData(null, null, signature);
Expand Down Expand Up @@ -153,11 +153,7 @@ private void executeForCurrentFlightStream() throws SQLException {
currentVectorSchemaRoot = originalRoot;
}

if (schema != null) {
execute(currentVectorSchemaRoot, schema);
} else {
execute(currentVectorSchemaRoot);
}
execute(currentVectorSchemaRoot, schema);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,18 @@

package org.apache.arrow.driver.jdbc;

import static java.util.Objects.isNull;

import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.TimeZone;

import org.apache.arrow.driver.jdbc.utils.ConvertUtils;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.AvaticaResultSet;
import org.apache.calcite.avatica.AvaticaResultSetMetaData;
Expand Down Expand Up @@ -74,7 +72,7 @@ public static ArrowFlightJdbcVectorSchemaRootResultSet fromVectorSchemaRoot(
final TimeZone timeZone = TimeZone.getDefault();
final QueryState state = new QueryState();

final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null);
final Meta.Signature signature = ArrowFlightMetaImpl.newSignature(null, null, null);

final AvaticaResultSetMetaData resultSetMetaData =
new AvaticaResultSetMetaData(null, null, signature);
Expand All @@ -93,17 +91,13 @@ protected AvaticaResultSet execute() throws SQLException {
}

void execute(final VectorSchemaRoot vectorSchemaRoot) {
final List<Field> fields = vectorSchemaRoot.getSchema().getFields();
final List<ColumnMetaData> columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(fields);
signature.columns.clear();
signature.columns.addAll(columns);

this.vectorSchemaRoot = vectorSchemaRoot;
execute(vectorSchemaRoot, vectorSchemaRoot.getSchema());
execute2(new ArrowFlightJdbcCursor(vectorSchemaRoot), this.signature.columns);
}

void execute(final VectorSchemaRoot vectorSchemaRoot, final Schema schema) {
final List<ColumnMetaData> columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(schema.getFields());
Schema currentSchema = schema == null ? vectorSchemaRoot.getSchema() : schema;
final List<ColumnMetaData> columns = ConvertUtils.convertArrowFieldsToColumnMetaDataList(currentSchema.getFields());
signature.columns.clear();
signature.columns.addAll(columns);

Expand Down Expand Up @@ -137,7 +131,7 @@ public void close() {
} catch (final Exception e) {
exceptions.add(e);
}
if (!isNull(statement)) {
if (!Objects.isNull(statement)) {
try {
super.close();
} catch (final Exception e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@

package org.apache.arrow.driver.jdbc;

import static java.lang.String.format;

import java.sql.Connection;
import java.sql.SQLException;
import java.sql.SQLTimeoutException;
Expand All @@ -29,7 +27,10 @@
import java.util.concurrent.ConcurrentHashMap;

import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler.PreparedStatement;
import org.apache.arrow.driver.jdbc.utils.ConvertUtils;
import org.apache.arrow.driver.jdbc.utils.TypedValueBinder;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.AvaticaConnection;
import org.apache.calcite.avatica.AvaticaParameter;
import org.apache.calcite.avatica.ColumnMetaData;
Expand All @@ -54,12 +55,20 @@ public ArrowFlightMetaImpl(final AvaticaConnection connection) {
setDefaultConnectionProperties();
}

static Signature newSignature(final String sql) {
/**
* Construct a signature.
*/
static Signature newSignature(final String sql, Schema resultSetSchema, Schema parameterSchema) {
List<ColumnMetaData> columnMetaData = resultSetSchema == null ?
new ArrayList<>() : ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields());
List<AvaticaParameter> parameters = parameterSchema == null ?
new ArrayList<>() : ConvertUtils.convertArrowFieldsToAvaticaParameters(parameterSchema.getFields());

return new Signature(
new ArrayList<ColumnMetaData>(),
columnMetaData,
sql,
Collections.<AvaticaParameter>emptyList(),
Collections.<String, Object>emptyMap(),
parameters,
Collections.emptyMap(),
null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor
StatementType.SELECT
);
Expand All @@ -84,23 +93,29 @@ public void commit(final ConnectionHandle connectionHandle) {
public ExecuteResult execute(final StatementHandle statementHandle,
final List<TypedValue> typedValues, final long maxRowCount) {
Preconditions.checkArgument(connection.id.equals(statementHandle.connectionId),
"Connection IDs are not consistent");
"Connection IDs are not consistent");
final StatementHandleKey key = new StatementHandleKey(statementHandle);
PreparedStatement preparedStatement = statementHandlePreparedStatementMap.get(key);

if (preparedStatement == null) {
throw new IllegalStateException("Prepared statement not found: " + statementHandle);
}

final TypedValueBinder binder =
new TypedValueBinder(preparedStatement, ((ArrowFlightConnection) connection).getBufferAllocator());
binder.bind(typedValues);

if (statementHandle.signature == null) {
// Update query
final StatementHandleKey key = new StatementHandleKey(statementHandle);
PreparedStatement preparedStatement = statementHandlePreparedStatementMap.get(key);
if (preparedStatement == null) {
throw new IllegalStateException("Prepared statement not found: " + statementHandle);
}
long updatedCount = preparedStatement.executeUpdate();
return new ExecuteResult(Collections.singletonList(MetaResultSet.count(statementHandle.connectionId,
statementHandle.id, updatedCount)));
statementHandle.id, updatedCount)));
} else {
// TODO Why is maxRowCount ignored?
return new ExecuteResult(
Collections.singletonList(MetaResultSet.create(
statementHandle.connectionId, statementHandle.id,
true, statementHandle.signature, null)));
Collections.singletonList(MetaResultSet.create(
statementHandle.connectionId, statementHandle.id,
true, statementHandle.signature, null)));
}
}

Expand All @@ -126,18 +141,24 @@ public Frame fetch(final StatementHandle statementHandle, final long offset,
* the results.
*/
throw AvaticaConnection.HELPER.wrap(
format("%s does not use frames.", this),
String.format("%s does not use frames.", this),
AvaticaConnection.HELPER.unsupported());
}

private PreparedStatement prepareForHandle(final String query, StatementHandle handle) {
final PreparedStatement preparedStatement =
((ArrowFlightConnection) connection).getClientHandler().prepare(query);
handle.signature = newSignature(query, preparedStatement.getDataSetSchema(),
preparedStatement.getParameterSchema());
statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
return preparedStatement;
}

@Override
public StatementHandle prepare(final ConnectionHandle connectionHandle,
final String query, final long maxRowCount) {
final StatementHandle handle = super.createStatement(connectionHandle);
handle.signature = newSignature(query);
final PreparedStatement preparedStatement =
((ArrowFlightConnection) connection).getClientHandler().prepare(query);
statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
prepareForHandle(query, handle);
return handle;
}

Expand All @@ -157,20 +178,18 @@ public ExecuteResult prepareAndExecute(final StatementHandle handle,
final PrepareCallback callback)
throws NoSuchStatementException {
try {
final PreparedStatement preparedStatement =
((ArrowFlightConnection) connection).getClientHandler().prepare(query);
PreparedStatement preparedStatement = prepareForHandle(query, handle);
final StatementType statementType = preparedStatement.getType();
statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
final Signature signature = newSignature(query);

final long updateCount =
statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1;
synchronized (callback.getMonitor()) {
callback.clear();
callback.assign(signature, null, updateCount);
callback.assign(handle.signature, null, updateCount);
}
callback.execute();
final MetaResultSet metaResultSet = MetaResultSet.create(handle.connectionId, handle.id,
false, signature, null);
false, handle.signature, null);
return new ExecuteResult(Collections.singletonList(metaResultSet));
} catch (SQLTimeoutException e) {
// So far AvaticaStatement(executeInternal) only handles NoSuchStatement and Runtime Exceptions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@

package org.apache.arrow.driver.jdbc;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;

import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler;
import org.apache.arrow.driver.jdbc.utils.ConvertUtils;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.AvaticaPreparedStatement;
import org.apache.calcite.avatica.Meta.Signature;
import org.apache.calcite.avatica.Meta.StatementHandle;
Expand All @@ -50,36 +47,6 @@ private ArrowFlightPreparedStatement(final ArrowFlightConnection connection,
this.preparedStatement = Preconditions.checkNotNull(preparedStatement);
}

/**
* Creates a new {@link ArrowFlightPreparedStatement} from the provided information.
*
* @param connection the {@link Connection} to use.
* @param statementHandle the {@link StatementHandle} to use.
* @param signature the {@link Signature} to use.
* @param resultSetType the ResultSet type.
* @param resultSetConcurrency the ResultSet concurrency.
* @param resultSetHoldability the ResultSet holdability.
* @return a new {@link PreparedStatement}.
* @throws SQLException on error.
*/
static ArrowFlightPreparedStatement createNewPreparedStatement(
final ArrowFlightConnection connection,
final StatementHandle statementHandle,
final Signature signature,
final int resultSetType,
final int resultSetConcurrency,
final int resultSetHoldability) throws SQLException {

final ArrowFlightSqlClientHandler.PreparedStatement prepare = connection.getClientHandler().prepare(signature.sql);
final Schema resultSetSchema = prepare.getDataSetSchema();

signature.columns.addAll(ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields()));

return new ArrowFlightPreparedStatement(
connection, prepare, statementHandle,
signature, resultSetType, resultSetConcurrency, resultSetHoldability);
}

aiguofer marked this conversation as resolved.
Show resolved Hide resolved
static ArrowFlightPreparedStatement newPreparedStatement(final ArrowFlightConnection connection,
final ArrowFlightSqlClientHandler.PreparedStatement preparedStmt,
final StatementHandle statementHandle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.calcite.avatica.Meta.StatementType;
import org.slf4j.Logger;
Expand Down Expand Up @@ -155,6 +156,15 @@ public interface PreparedStatement extends AutoCloseable {
*/
Schema getDataSetSchema();

/**
* Gets the {@link Schema} of the parameters for this {@link PreparedStatement}.
*
* @return {@link Schema}.
*/
Schema getParameterSchema();

void setParameters(VectorSchemaRoot parameters);

@Override
void close();
}
Expand Down Expand Up @@ -190,6 +200,16 @@ public Schema getDataSetSchema() {
return preparedStatement.getResultSetSchema();
}

@Override
public Schema getParameterSchema() {
return preparedStatement.getParameterSchema();
}

@Override
public void setParameters(VectorSchemaRoot parameters) {
preparedStatement.setParameters(parameters);
}

@Override
public void close() {
try {
Expand Down
Loading
Loading