diff --git a/java/flight/flight-core/pom.xml b/java/flight/flight-core/pom.xml index 3205cb222db95..70618f067b19f 100644 --- a/java/flight/flight-core/pom.xml +++ b/java/flight/flight-core/pom.xml @@ -28,6 +28,11 @@ + + io.dropwizard.metrics + metrics-core + 4.2.16 + org.apache.arrow arrow-format diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java index 03ce13c9780e3..dd4ea730224ad 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightStream.java @@ -28,6 +28,8 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.LinkedBlockingQueue; +import com.codahale.metrics.MetricRegistry; +import com.codahale.metrics.Timer; import org.apache.arrow.flight.ArrowMessage.HeaderType; import org.apache.arrow.flight.grpc.StatusUtils; import org.apache.arrow.memory.ArrowBuf; @@ -51,6 +53,8 @@ import io.grpc.stub.StreamObserver; +import static com.codahale.metrics.MetricRegistry.name; + /** * An adaptor between protobuf streams and flight data streams. */ @@ -85,6 +89,13 @@ public class FlightStream implements AutoCloseable { @VisibleForTesting volatile MetadataVersion metadataVersion = null; + public static final MetricRegistry metrics = new MetricRegistry(); + + private static final Timer fsNext = metrics.timer(name(FlightStream.class, "fsNext")); + private static final Timer fsTake = metrics.timer(name(FlightStream.class, "fsTake")); + private static final Timer fsRb = metrics.timer(name(FlightStream.class, "fsRb")); + private static final Timer fsDict = metrics.timer(name(FlightStream.class, "fsDict")); + /** * Constructs a new instance. * @@ -220,76 +231,87 @@ public void close() throws Exception { * @return Whether or not more data was found. */ public boolean next() { - try { - if (completed.isDone() && queue.isEmpty()) { - return false; - } + try (final Timer.Context context = fsNext.time()) { + try { + if (completed.isDone() && queue.isEmpty()) { + return false; + } - pending--; - requestOutstanding(); + pending--; + requestOutstanding(); - Object data = queue.take(); - if (DONE == data) { - queue.put(DONE); - // Other code ignores the value of this CompletableFuture, only whether it's completed (or has an exception) - completed.complete(null); - return false; - } else if (DONE_EX == data) { - queue.put(DONE_EX); - if (ex instanceof Exception) { - throw (Exception) ex; - } else { - throw new Exception(ex); + Object data; + try (final Timer.Context takeCtx = fsTake.time()) { +// System.out.format("Trying to take @ %dms\n", System.currentTimeMillis()); + data = queue.take(); +// System.out.format("Took @ %dms\n", System.currentTimeMillis()); } - } else { - try (ArrowMessage msg = ((ArrowMessage) data)) { - if (msg.getMessageType() == HeaderType.NONE) { - updateMetadata(msg); - // We received a message without data, so erase any leftover data - if (fulfilledRoot != null) { - fulfilledRoot.clear(); - } - } else if (msg.getMessageType() == HeaderType.RECORD_BATCH) { - checkMetadataVersion(msg); - // Ensure we have the root - root.get().clear(); - try (ArrowRecordBatch arb = msg.asRecordBatch()) { - loader.load(arb); - } - updateMetadata(msg); - } else if (msg.getMessageType() == HeaderType.DICTIONARY_BATCH) { - checkMetadataVersion(msg); - // Ensure we have the root - root.get().clear(); - try (ArrowDictionaryBatch arb = msg.asDictionaryBatch()) { - final long id = arb.getDictionaryId(); - if (dictionaries == null) { - throw new IllegalStateException("Dictionary ownership was claimed by the application."); + if (DONE == data) { + queue.put(DONE); + // Other code ignores the value of this CompletableFuture, only whether it's completed (or has an exception) + completed.complete(null); + return false; + } else if (DONE_EX == data) { + queue.put(DONE_EX); + if (ex instanceof Exception) { + throw (Exception) ex; + } else { + throw new Exception(ex); + } + } else { + try (ArrowMessage msg = ((ArrowMessage) data)) { + if (msg.getMessageType() == HeaderType.NONE) { + updateMetadata(msg); + // We received a message without data, so erase any leftover data + if (fulfilledRoot != null) { + fulfilledRoot.clear(); } - final Dictionary dictionary = dictionaries.lookup(id); - if (dictionary == null) { - throw new IllegalArgumentException("Dictionary not defined in schema: ID " + id); + } else if (msg.getMessageType() == HeaderType.RECORD_BATCH) { + try (final Timer.Context rbCtx = fsRb.time()) { + checkMetadataVersion(msg); + // Ensure we have the root + root.get().clear(); + try (ArrowRecordBatch arb = msg.asRecordBatch()) { + loader.load(arb); + } + updateMetadata(msg); } - - final FieldVector vector = dictionary.getVector(); - final VectorSchemaRoot dictionaryRoot = new VectorSchemaRoot(Collections.singletonList(vector.getField()), - Collections.singletonList(vector), 0); - final VectorLoader dictionaryLoader = new VectorLoader(dictionaryRoot); - dictionaryLoader.load(arb.getDictionary()); + } else if (msg.getMessageType() == HeaderType.DICTIONARY_BATCH) { + try (final Timer.Context dictCtx = fsRb.time()) { + checkMetadataVersion(msg); + // Ensure we have the root + root.get().clear(); + try (ArrowDictionaryBatch arb = msg.asDictionaryBatch()) { + final long id = arb.getDictionaryId(); + if (dictionaries == null) { + throw new IllegalStateException("Dictionary ownership was claimed by the application."); + } + final Dictionary dictionary = dictionaries.lookup(id); + if (dictionary == null) { + throw new IllegalArgumentException("Dictionary not defined in schema: ID " + id); + } + + final FieldVector vector = dictionary.getVector(); + final VectorSchemaRoot dictionaryRoot = new VectorSchemaRoot(Collections.singletonList(vector.getField()), + Collections.singletonList(vector), 0); + final VectorLoader dictionaryLoader = new VectorLoader(dictionaryRoot); + dictionaryLoader.load(arb.getDictionary()); + } + return next(); + } + } else { + throw new UnsupportedOperationException("Message type is unsupported: " + msg.getMessageType()); } - return next(); - } else { - throw new UnsupportedOperationException("Message type is unsupported: " + msg.getMessageType()); + return true; } - return true; } + } catch (RuntimeException e) { + throw e; + } catch (ExecutionException e) { + throw StatusUtils.fromThrowable(e.getCause()); + } catch (Exception e) { + throw new RuntimeException(e); } - } catch (RuntimeException e) { - throw e; - } catch (ExecutionException e) { - throw StatusUtils.fromThrowable(e.getCause()); - } catch (Exception e) { - throw new RuntimeException(e); } } @@ -383,6 +405,7 @@ private void enqueue(AutoCloseable message) { @Override public void onNext(ArrowMessage msg) { +// System.out.format("FlightStream.onNext() @ %dms\n", System.currentTimeMillis()); // Operations here have to be under a lock so that we don't add a message to the queue while in the middle of // close(). requestOutstanding(); diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java index 4c01cb6e5813c..1eb16dbe020a1 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcFlightStreamResultSet.java @@ -17,7 +17,9 @@ package org.apache.arrow.driver.jdbc; +import static com.codahale.metrics.MetricRegistry.name; import static org.apache.arrow.driver.jdbc.utils.FlightStreamQueue.createNewQueue; +import static org.apache.arrow.flight.FlightStream.metrics; import java.sql.ResultSet; import java.sql.ResultSetMetaData; @@ -26,10 +28,12 @@ import java.util.TimeZone; import java.util.concurrent.TimeUnit; +import com.codahale.metrics.Timer; import org.apache.arrow.driver.jdbc.utils.FlightStreamQueue; import org.apache.arrow.driver.jdbc.utils.VectorSchemaRootTransformer; import org.apache.arrow.flight.FlightInfo; import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; @@ -46,6 +50,8 @@ public final class ArrowFlightJdbcFlightStreamResultSet extends ArrowFlightJdbcVectorSchemaRootResultSet { + private static final Timer nextStream = metrics.timer(name(FlightSqlClient.class, "nextStream")); + private final ArrowFlightConnection connection; private FlightStream currentFlightStream; private FlightStreamQueue flightStreamQueue; @@ -55,6 +61,10 @@ public final class ArrowFlightJdbcFlightStreamResultSet private Schema schema; + private static final Timer rsExec = metrics.timer(name(FlightSqlClient.class, "rsExec")); + private static final Timer getStreams = metrics.timer(name(FlightSqlClient.class, "getStreams")); + private static final Timer getNextFlightStream = metrics.timer(name(FlightStreamQueue.class, "getNextFlightStream")); + ArrowFlightJdbcFlightStreamResultSet(final AvaticaStatement statement, final QueryState state, final Meta.Signature signature, @@ -112,31 +122,40 @@ private void loadNewQueue() { } private void loadNewFlightStream() throws SQLException { - if (currentFlightStream != null) { - AutoCloseables.closeNoChecked(currentFlightStream); + try(final Timer.Context context = nextStream.time()) { + if (currentFlightStream != null) { + AutoCloseables.closeNoChecked(currentFlightStream); + } + this.currentFlightStream = getNextFlightStream(true); } - this.currentFlightStream = getNextFlightStream(true); } @Override protected AvaticaResultSet execute() throws SQLException { - final FlightInfo flightInfo = ((ArrowFlightInfoStatement) statement).executeFlightInfoQuery(); - - if (flightInfo != null) { - schema = flightInfo.getSchema(); - execute(flightInfo); + try(final Timer.Context context = rsExec.time()) { +// System.out.format("Getting FlightInfo @ %dms\n", System.currentTimeMillis()); + final FlightInfo flightInfo = ((ArrowFlightInfoStatement) statement).executeFlightInfoQuery(); + + if (flightInfo != null) { + schema = flightInfo.getSchema(); +// System.out.format("Getting FlightDatas @ %dms\n", System.currentTimeMillis()); + execute(flightInfo); + } +// System.out.format("Got FlightDatas @ %dms\n", System.currentTimeMillis()); + return this; } - return this; } private void execute(final FlightInfo flightInfo) throws SQLException { - loadNewQueue(); - flightStreamQueue.enqueue(connection.getClientHandler().getStreams(flightInfo)); - loadNewFlightStream(); + try(final Timer.Context context = getStreams.time()) { + loadNewQueue(); + flightStreamQueue.enqueue(connection.getClientHandler().getStreams(flightInfo)); + loadNewFlightStream(); - // Ownership of the root will be passed onto the cursor. - if (currentFlightStream != null) { - executeForCurrentFlightStream(); + // Ownership of the root will be passed onto the cursor. + if (currentFlightStream != null) { + executeForCurrentFlightStream(); + } } } @@ -239,12 +258,14 @@ public synchronized void close() { } private FlightStream getNextFlightStream(final boolean isExecution) throws SQLException { - if (isExecution) { - final int statementTimeout = statement != null ? statement.getQueryTimeout() : 0; - return statementTimeout != 0 ? - flightStreamQueue.next(statementTimeout, TimeUnit.SECONDS) : flightStreamQueue.next(); - } else { - return flightStreamQueue.next(); + try (final Timer.Context context = getNextFlightStream.time()) { + if (isExecution) { + final int statementTimeout = statement != null ? statement.getQueryTimeout() : 0; + return statementTimeout != 0 ? + flightStreamQueue.next(statementTimeout, TimeUnit.SECONDS) : flightStreamQueue.next(); + } else { + return flightStreamQueue.next(); + } } } } diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java index f825e7d13cef5..29dc703a96c5f 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java @@ -58,7 +58,7 @@ static Signature newSignature(final String sql) { return new Signature( new ArrayList(), sql, - Collections.emptyList(), + new ArrayList(), Collections.emptyMap(), null, // unnecessary, as SQL requests use ArrowFlightJdbcCursor StatementType.SELECT diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java index 80029f38f0958..a94faa49563e5 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatement.java @@ -17,19 +17,34 @@ package org.apache.arrow.driver.jdbc; +import java.math.BigDecimal; import java.sql.Connection; +import java.sql.JDBCType; import java.sql.PreparedStatement; import java.sql.SQLException; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import com.google.flatbuffers.LongVector; import org.apache.arrow.driver.jdbc.client.ArrowFlightSqlClientHandler; import org.apache.arrow.driver.jdbc.utils.ConvertUtils; import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.*; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Text; +import org.apache.calcite.avatica.AvaticaParameter; import org.apache.calcite.avatica.AvaticaPreparedStatement; +import org.apache.calcite.avatica.Meta; import org.apache.calcite.avatica.Meta.Signature; import org.apache.calcite.avatica.Meta.StatementHandle; - +import org.apache.calcite.avatica.QueryState; +import org.apache.calcite.avatica.remote.TypedValue; +import org.joda.time.DateTime; +import org.joda.time.format.DateTimeFormat; /** * Arrow Flight JBCS's implementation {@link PreparedStatement}. @@ -72,8 +87,10 @@ static ArrowFlightPreparedStatement createNewPreparedStatement( final ArrowFlightSqlClientHandler.PreparedStatement prepare = connection.getClientHandler().prepare(signature.sql); final Schema resultSetSchema = prepare.getDataSetSchema(); + final Schema parameterSchema = prepare.getParameterSchema(); signature.columns.addAll(ConvertUtils.convertArrowFieldsToColumnMetaDataList(resultSetSchema.getFields())); + signature.parameters.addAll(ConvertUtils.convertArrowFieldsToAvaticaParameters(parameterSchema.getFields())); return new ArrowFlightPreparedStatement( connection, prepare, statementHandle, @@ -91,8 +108,130 @@ public synchronized void close() throws SQLException { super.close(); } + @Override + public long executeLargeUpdate() throws SQLException { + copyParameters(); + return preparedStatement.executeUpdate(); + } + @Override public FlightInfo executeFlightInfoQuery() throws SQLException { + copyParameters(); return preparedStatement.executeQuery(); } + + private void copyParameters() throws SQLException { + BufferAllocator allocator = new RootAllocator(50000000); + List fields = new ArrayList<>(); + List values = this.getParameterValues(); + for(int i = 0; i < this.getParameterCount(); i++) { + AvaticaParameter param = this.getParameter(i + 1); + switch (param.parameterType) { + case java.sql.Types.TINYINT: + case java.sql.Types.SMALLINT: + case java.sql.Types.INTEGER: + IntVector intVec = new IntVector(param.name, allocator); + intVec.setSafe(0, (int)values.get(i).value); + intVec.setValueCount(1); + fields.add(intVec); + break; + case java.sql.Types.BIGINT: + BigIntVector longVec = new BigIntVector(param.name, allocator); + Object lv = values.get(i).value; + if(lv instanceof Long){ + longVec.setSafe(0, (long)lv); + } else { + longVec.setSafe(0, (int)lv); + } + longVec.setValueCount(1); + fields.add(longVec); + break; + case java.sql.Types.BIT: + case java.sql.Types.BOOLEAN: + BitVector bitVec = new BitVector(param.name, allocator); + bitVec.setSafe(0, (int)values.get(i).value); + bitVec.setValueCount(1); + fields.add(bitVec); + break; + case java.sql.Types.FLOAT: + Float4Vector floatVec = new Float4Vector(param.name, allocator); + TypedValue tfVal = values.get(i); + float floatVal = (float)tfVal.value; + floatVec.setSafe(0, floatVal); + floatVec.setValueCount(1); + fields.add(floatVec); + break; + case java.sql.Types.DOUBLE: + Float8Vector doubleVec = new Float8Vector(param.name, allocator); + doubleVec.setSafe(0, (double)values.get(i).value); + doubleVec.setValueCount(1); + fields.add(doubleVec); + break; + case java.sql.Types.REAL: + case java.sql.Types.NUMERIC: + case java.sql.Types.DECIMAL: + DecimalVector decVec = new DecimalVector(param.name, allocator, param.precision, param.scale); + decVec.setSafe(0, (BigDecimal) values.get(i).value); + decVec.setValueCount(1); + fields.add(decVec); + break; + case java.sql.Types.CHAR: + case java.sql.Types.VARCHAR: + case java.sql.Types.NCHAR: + case java.sql.Types.NVARCHAR: + Text txt = new Text((String) values.get(i).value); + VarCharVector strVec = new VarCharVector(param.name, allocator); + strVec.setSafe(0, txt); + strVec.setValueCount(1); + fields.add(strVec); + break; + case java.sql.Types.LONGVARCHAR: + case java.sql.Types.LONGNVARCHAR: + LargeVarCharVector textVec = new LargeVarCharVector(param.name, allocator); + textVec.setSafe(0, new Text((String) values.get(i).value)); + textVec.setValueCount(1); + fields.add(textVec); + break; + case java.sql.Types.DATE: + case java.sql.Types.TIME: + case java.sql.Types.TIMESTAMP: + case java.sql.Types.TIME_WITH_TIMEZONE: + case java.sql.Types.TIMESTAMP_WITH_TIMEZONE: + DateMilliVector timeVec = new DateMilliVector(param.name, allocator); + TypedValue tmVal = values.get(i); + String dtStr = (String)tmVal.value; + + String pattern = "yyyy-MM-dd HH:mm:ss.SSS"; + DateTime dateTime = DateTime.parse(dtStr, DateTimeFormat.forPattern(pattern)); + + timeVec.setSafe(0, dateTime.getMillis()); + timeVec.setValueCount(1); + fields.add(timeVec); + break; + case java.sql.Types.BINARY: + case java.sql.Types.VARBINARY: + VarBinaryVector binVec = new VarBinaryVector(param.name, allocator); + binVec.setSafe(0, (byte[])values.get(i).value); + binVec.setValueCount(1); + fields.add(binVec); + break; + case java.sql.Types.BLOB: + case java.sql.Types.LONGVARBINARY: + LargeVarBinaryVector blobVec = new LargeVarBinaryVector(param.name, allocator); + blobVec.setSafe(0, (byte[])values.get(i).value); + blobVec.setValueCount(1); + fields.add(blobVec); + break; + case java.sql.Types.NULL: + NullVector nullVec = new NullVector(param.name); + nullVec.setValueCount(1); + fields.add(nullVec); + break; + default: + throw new SQLException("Unknown type: " + param.typeName); + } + } + VectorSchemaRoot parameters = new VectorSchemaRoot(fields); + this.preparedStatement.setParameters(parameters); + } } diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java index 7b059ab02f851..76b792cba4248 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/client/ArrowFlightSqlClientHandler.java @@ -27,6 +27,7 @@ import java.util.Set; import java.util.stream.Collectors; +import com.codahale.metrics.Timer; import org.apache.arrow.driver.jdbc.client.utils.ClientAuthenticationUtils; import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.FlightClient; @@ -48,11 +49,14 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.types.pojo.Schema; import org.apache.calcite.avatica.Meta.StatementType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import static com.codahale.metrics.MetricRegistry.name; + /** * A {@link FlightSqlClient} handler. */ @@ -96,10 +100,12 @@ private CallOption[] getOptions() { * @return a {@code FlightStream} of results. */ public List getStreams(final FlightInfo flightInfo) { - return flightInfo.getEndpoints().stream() - .map(FlightEndpoint::getTicket) - .map(ticket -> sqlClient.getStream(ticket, getOptions())) - .collect(Collectors.toList()); +// try(final Timer.Context context = handlerGetStreams.time()) { + return flightInfo.getEndpoints().stream() + .map(FlightEndpoint::getTicket) + .map(ticket -> sqlClient.getStream(ticket, getOptions())) + .collect(Collectors.toList()); +// } } /** @@ -155,6 +161,15 @@ public interface PreparedStatement extends AutoCloseable { */ Schema getDataSetSchema(); + /** + * Gets the {@link Schema} of the parameters for this {@link PreparedStatement}. + * + * @return {@link Schema}. + */ + Schema getParameterSchema(); + + void setParameters(VectorSchemaRoot parameters); + @Override void close(); } @@ -190,6 +205,16 @@ public Schema getDataSetSchema() { return preparedStatement.getResultSetSchema(); } + @Override + public Schema getParameterSchema() { + return preparedStatement.getParameterSchema(); + } + + @Override + public void setParameters(VectorSchemaRoot parameters) { + preparedStatement.setParameters(parameters); + } + @Override public void close() { try { diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java index 324f991ef09e9..bbf1e4138b060 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/ConvertUtils.java @@ -17,6 +17,13 @@ package org.apache.arrow.driver.jdbc.utils; +import java.math.BigDecimal; +import java.sql.Date; +import java.sql.JDBCType; +import java.sql.Time; +import java.sql.Timestamp; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -25,6 +32,7 @@ import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; +import org.apache.calcite.avatica.AvaticaParameter; import org.apache.calcite.avatica.ColumnMetaData; import org.apache.calcite.avatica.proto.Common; import org.apache.calcite.avatica.proto.Common.ColumnMetaData.Builder; @@ -37,6 +45,125 @@ public final class ConvertUtils { private ConvertUtils() { } + static boolean isSigned(ArrowType.ArrowTypeID type) { + switch (type) { + case Int: + case FloatingPoint: + case Decimal: + case Interval: + case Duration: + return true; + } + return false; + } + + static int toJdbcType(ArrowType type) { + switch (type.getTypeID()) { + case Null: + return JDBCType.NULL.getVendorTypeNumber(); + case Struct: + return JDBCType.STRUCT.getVendorTypeNumber(); + case List: + case LargeList: + case FixedSizeList: + return JDBCType.ARRAY.getVendorTypeNumber(); + case Int: + ArrowType.Int t = (ArrowType.Int)type; + switch (t.getBitWidth()) { + case 64: + return JDBCType.BIGINT.getVendorTypeNumber(); + default: + return JDBCType.INTEGER.getVendorTypeNumber(); + } + case FloatingPoint: + return JDBCType.FLOAT.getVendorTypeNumber(); + case Utf8: + case LargeUtf8: + return JDBCType.VARCHAR.getVendorTypeNumber(); + case Binary: + case LargeBinary: + return JDBCType.VARBINARY.getVendorTypeNumber(); + case FixedSizeBinary: + return JDBCType.BINARY.getVendorTypeNumber(); + case Bool: + return JDBCType.BOOLEAN.getVendorTypeNumber(); + case Decimal: + return JDBCType.DECIMAL.getVendorTypeNumber(); + case Date: + return JDBCType.DATE.getVendorTypeNumber(); + case Time: + return JDBCType.TIME.getVendorTypeNumber(); + case Timestamp: + case Interval: + case Duration: + return JDBCType.TIMESTAMP.getVendorTypeNumber(); + } + return JDBCType.OTHER.getVendorTypeNumber(); + } + + static String getClassName(ArrowType type) { + switch (type.getTypeID()) { + case List: + case LargeList: + case FixedSizeList: + return ArrayList.class.getCanonicalName(); + case Map: + return HashMap.class.getCanonicalName(); + case Int: + ArrowType.Int t = (ArrowType.Int)type; + switch (t.getBitWidth()) { + case 64: + return long.class.getCanonicalName(); + default: + return int.class.getCanonicalName(); + } + case FloatingPoint: + return float.class.getCanonicalName(); + case Utf8: + case LargeUtf8: + return String.class.getCanonicalName(); + case Binary: + case LargeBinary: + case FixedSizeBinary: + return byte[].class.getCanonicalName(); + case Bool: + return boolean.class.getCanonicalName(); + case Decimal: + return BigDecimal.class.getCanonicalName(); + case Date: + return Date.class.getCanonicalName(); + case Time: + return Time.class.getCanonicalName(); + case Timestamp: + case Interval: + case Duration: + return Timestamp.class.getCanonicalName(); + } + return null; + } + + /** + * Convert Fields To Avatica Parameters. + * + * @param fields list of {@link Field}. + * @return list of {@link AvaticaParameter}. + */ + public static List convertArrowFieldsToAvaticaParameters(final List fields) { + List list = new ArrayList<>(); + for(Field field : fields) { + final boolean signed = isSigned(field.getType().getTypeID()); + final int precision = 0; // Would have to know about the actual number + final int scale = 0; // According to https://www.postgresql.org/docs/current/datatype-numeric.html + final int type = toJdbcType(field.getType()); + final String typeName = field.getType().toString(); + final String clazz = getClassName(field.getType()); + final String name = field.getName(); + final AvaticaParameter param = new AvaticaParameter(signed, precision, scale, type, typeName, clazz, name); + list.add(param); + } + return list; + } + /** * Convert Fields To Column MetaData List functions. * diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java index e1d770800e40c..815829c30a325 100644 --- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java +++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/utils/FlightStreamQueue.java @@ -17,8 +17,10 @@ package org.apache.arrow.driver.jdbc.utils; +import static com.codahale.metrics.MetricRegistry.name; import static java.lang.String.format; import static java.util.Collections.synchronizedSet; +import static org.apache.arrow.flight.FlightStream.metrics; import static org.apache.arrow.util.Preconditions.checkNotNull; import static org.apache.arrow.util.Preconditions.checkState; @@ -36,9 +38,11 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import com.codahale.metrics.Timer; import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.FlightRuntimeException; import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.calcite.avatica.AvaticaConnection; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -62,6 +66,9 @@ public class FlightStreamQueue implements AutoCloseable { private final Set allStreams = synchronizedSet(new HashSet<>()); private final AtomicBoolean closed = new AtomicBoolean(); + private static final Timer next = metrics.timer(name(FlightStreamQueue.class, "next")); + private static final Timer take = metrics.timer(name(FlightStreamQueue.class, "take")); + /** * Instantiate a new FlightStreamQueue. */ @@ -97,20 +104,22 @@ interface FlightStreamSupplier { } private FlightStream next(final FlightStreamSupplier flightStreamSupplier) throws SQLException { - checkOpen(); - while (!futures.isEmpty()) { - final Future future = flightStreamSupplier.get(); - futures.remove(future); - try { - final FlightStream stream = future.get(); - if (stream.getRoot().getRowCount() > 0) { - return stream; + try (final Timer.Context context = next.time()) { + checkOpen(); + while (!futures.isEmpty()) { + final Future future = flightStreamSupplier.get(); + futures.remove(future); + try { + final FlightStream stream = future.get(); + if (stream.getRoot().getRowCount() > 0) { + return stream; + } + } catch (final ExecutionException | InterruptedException | CancellationException e) { + throw AvaticaConnection.HELPER.wrap(e.getMessage(), e); } - } catch (final ExecutionException | InterruptedException | CancellationException e) { - throw AvaticaConnection.HELPER.wrap(e.getMessage(), e); } + return null; } - return null; } /** @@ -145,7 +154,9 @@ public FlightStream next(final long timeoutValue, final TimeUnit timeoutUnit) public FlightStream next() throws SQLException { return next(() -> { try { - return completionService.take(); + try (final Timer.Context context = take.time()) { + return completionService.take(); + } } catch (final InterruptedException e) { throw AvaticaConnection.HELPER.wrap(e.getMessage(), e); } diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java index 9b8fa96d2320e..106cc83e9428e 100644 --- a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java +++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightJdbcDriverTest.java @@ -17,20 +17,13 @@ package org.apache.arrow.driver.jdbc; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.sql.Connection; -import java.sql.Driver; -import java.sql.DriverManager; -import java.sql.SQLException; +import java.sql.*; import java.util.Collection; import java.util.Map; import java.util.Properties; +import java.util.concurrent.TimeUnit; +import com.codahale.metrics.ConsoleReporter; import org.apache.arrow.driver.jdbc.authentication.UserPasswordAuthentication; import org.apache.arrow.driver.jdbc.utils.ArrowFlightConnectionConfigImpl.ArrowFlightConnectionProperty; import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer; @@ -42,6 +35,9 @@ import org.junit.ClassRule; import org.junit.Test; +import static org.apache.arrow.flight.FlightStream.metrics; +import static org.junit.jupiter.api.Assertions.*; + /** * Tests for {@link ArrowFlightJdbcDriver}. */ @@ -133,6 +129,413 @@ public void testShouldConnectWhenProvidedWithValidUrl() throws Exception { } } + @Test + public void testQueryParameters() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + try (Connection con = driver.connect(conString, props); Statement stmt = con.createStatement()) { + try { + stmt.execute("create table person (id int, name varchar, primary key(id))"); + } catch (Exception ignored) {} + try(PreparedStatement ps = con.prepareStatement("select * from person where id=$1")) { + ParameterMetaData md = ps.getParameterMetaData(); + assertEquals(1, md.getParameterCount()); + assertEquals("Int", md.getParameterTypeName(1)); + ps.setInt(1, 1); + ResultSet rs = ps.executeQuery(); + assertFalse(rs.next()); // should be no records + } + } + } + + @Test + public void testWarehouse() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + try (Connection con = driver.connect(conString, props); Statement stmt = con.createStatement()) { + try { + stmt.execute("create table warehouse (\n" + + "w_id int,\n" + + "w_name string,\n" + + "w_street_1 string,\n" + + "w_street_2 string,\n" + + "w_city string,\n" + + "w_state string,\n" + + "w_zip string,\n" + + "w_tax float,\n" + + "w_ytd float,\n" + + "primary key (w_id)\n" + + ");\n"); + } catch (Exception ignored) {} + String sql = "SELECT w_street_1, w_street_2, w_city, w_state, w_zip, w_name FROM warehouse WHERE w_id = $1"; + try(PreparedStatement ps = con.prepareStatement(sql)) { + ParameterMetaData md = ps.getParameterMetaData(); + assertEquals(1, md.getParameterCount()); + assertEquals("Int", md.getParameterTypeName(1)); + ps.setInt(1, 1); + ResultSet rs = ps.executeQuery(); + assertNotNull(rs); + } + } + } + + @Test + public void testDistrict() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + try (Connection con = driver.connect(conString, props); Statement stmt = con.createStatement()) { + try { + stmt.execute("create table district (\n" + + "d_id int,\n" + + "d_w_id int,\n" + + "d_name string,\n" + + "d_street_1 string,\n" + + "d_street_2 string,\n" + + "d_city string,\n" + + "d_state string,\n" + + "d_zip string,\n" + + "d_tax float,\n" + + "d_ytd float,\n" + + "d_next_o_id int,\n" + + "primary key (d_w_id, d_id)\n" + + ");\n"); + } catch (Exception ignored) {} + String sql = "UPDATE district SET d_next_o_id = $1 + 1 WHERE d_id = $2 AND d_w_id = $3"; + try(PreparedStatement ps = con.prepareStatement(sql)) { + ps.setLong(1, 0); + ps.setInt(2, 2); + ps.setInt(3, 1); + ps.executeUpdate(); + } + } + } + + @Test + public void testDistrictForUpdate() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + try (Connection con = driver.connect(conString, props); Statement stmt = con.createStatement()) { + try { + stmt.execute("create table district (\n" + + "d_id int,\n" + + "d_w_id int,\n" + + "d_name string,\n" + + "d_street_1 string,\n" + + "d_street_2 string,\n" + + "d_city string,\n" + + "d_state string,\n" + + "d_zip string,\n" + + "d_tax float,\n" + + "d_ytd float,\n" + + "d_next_o_id int,\n" + + "primary key (d_w_id, d_id)\n" + + ");\n"); + } catch (Exception ignored) {} + String sql = "SELECT d_next_o_id, d_tax FROM district WHERE d_id = $1 AND d_w_id = $2 FOR UPDATE"; + try(PreparedStatement ps = con.prepareStatement(sql)) { + ps.setInt(1, 0); + ps.setInt(2, 2); + ps.executeUpdate(); + } + } + } + + @Test + public void testHistory() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + try (Connection con = driver.connect(conString, props); Statement stmt = con.createStatement()) { + try { + stmt.execute("create table history (\n" + + "h_c_id int,\n" + + "h_date timestamp,\n" + + "h_amount float,\n" + + "h_data varchar,\n" + + "PRIMARY KEY(h_c_id)\n" + + ");"); + } catch (Exception ignored) {} + String sql = "INSERT INTO history(h_c_id, h_date, h_amount, h_data) VALUES($1, $2, $3, $4)"; + try(PreparedStatement ps = con.prepareStatement(sql)) { + ps.setInt(1, 1); + ps.setString(2, "2023-02-03 12:31:00.16"); + ps.setFloat(3, 1.0f); + ps.setString(4, "test"); + ResultSet rs = ps.executeQuery(); + assertNotNull(rs); + } + } + } + + @Test + public void testRegion() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + try (Connection con = driver.connect(conString, props); Statement stmt = con.createStatement()) { + try { + stmt.execute("create table region (r_regionkey int, r_name varchar, r_comment varchar, primary key (r_regionkey))"); + } catch (Exception ignored) {} + String sql = "insert into region (r_regionkey, r_name, r_comment) values\n" + + "\t($1, $2, $3),\n" + + "\t($4, $5, $6),\n" + + "\t($7, $8, $9),\n" + + "\t($10, $11, $12)"; + try(PreparedStatement ps = con.prepareStatement(sql)) { + + ps.setInt(1, 1); + ps.setString(2, "Africa"); + ps.setString(3, "Africa"); + + ps.setInt(4, 1); + ps.setString(5, "Americas"); + ps.setString(6, "Americas"); + + ps.setInt(7, 1); + ps.setString(8, "Europe"); + ps.setString(9, "Europe"); + + ps.setInt(10, 1); + ps.setString(11, "Asia"); + ps.setString(12, "Asia"); + + int res = ps.executeUpdate(); + assertNotNull(res); + } + } + } + + @Test + public void testOrders() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + try (Connection con = driver.connect(conString, props); Statement stmt = con.createStatement()) { + try { + stmt.execute("create table orders (\n" + + "o_id int, \n" + + "o_d_id int, \n" + + "o_w_id int,\n" + + "o_c_id int,\n" + + "o_entry_d timestamp,\n" + + "o_carrier_id int,\n" + + "o_ol_cnt int, \n" + + "o_all_local int,\n" + + "PRIMARY KEY(o_w_id, o_d_id, o_id) \n" + + ")"); + } catch (Exception ignored) {} + String sql = "INSERT INTO orders (o_id, o_d_id, o_w_id, o_c_id, o_entry_d, o_ol_cnt, o_all_local) VALUES($1, $2, $3, $4, $5, $6, $7)"; + try(PreparedStatement ps = con.prepareStatement(sql)) { + ps.setInt(1, 1); + ps.setInt(2, 1); + ps.setInt(3, 1); + ps.setInt(4, 1); + ps.setString(5, "2023-02-03 12:31:00.16"); + ps.setInt(6, 1); + ps.setInt(7, 1); + ps.executeUpdate(); + } + } + } + + @Test + public void testSubqyeryParams() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + try (Connection con = driver.connect(conString, props); Statement stmt = con.createStatement()) { + try { + stmt.execute("create table orders (\n" + + "o_id int, \n" + + "o_d_id int, \n" + + "o_w_id int,\n" + + "o_c_id int,\n" + + "o_entry_d timestamp,\n" + + "o_carrier_id int,\n" + + "o_ol_cnt int, \n" + + "o_all_local int,\n" + + "PRIMARY KEY(o_w_id, o_d_id, o_id) \n" + + ")"); + } catch (Exception ignored) {} + String sql = "SELECT o_id, o_entry_d, COALESCE(o_carrier_id,0) FROM orders WHERE o_w_id = $1 AND o_d_id = $2 AND o_c_id = $3 AND o_id = (SELECT MAX(o_id) FROM orders WHERE o_w_id = $4 AND o_d_id = $5 AND o_c_id = $6)"; + try(PreparedStatement ps = con.prepareStatement(sql)) { + ps.setInt(1, 1); + ps.setInt(2, 1); + ps.setInt(3, 1); + ps.setInt(4, 1); + ps.setInt(5, 1); + ps.setInt(6, 1); + ps.executeQuery(); + } + } + } + + @Test + public void testTpcCQuery6() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + String sql = "SELECT\n" + + " 27 as s_quantity,\n" + + " 's_data' as s_data,\n" + + " 's_dist_01' as s_dist_01,\n" + + " 's_dist_02' as s_dist_02,\n" + + " 's_dist_03' as s_dist_03,\n" + + " 's_dist_04' as s_dist_04,\n" + + " 's_dist_05' as s_dist_05,\n" + + " 's_dist_06' as s_dist_06,\n" + + " 's_dist_07' as s_dist_07,\n" + + " 's_dist_08' as s_dist_08,\n" + + " 's_dist_09' as s_dist_09,\n" + + " 's_dist_10' as s_dist_10\n" + + ";\n"; + try (Connection con = driver.connect(conString, props)) { + try(PreparedStatement stmt = con.prepareStatement(sql)) { + long sum = 0; + long cnt = 100; + for(long i = 0; i < cnt; i++) { +// stmt.setInt(1, 1); +// stmt.setInt(2, 1); + long start = System.currentTimeMillis(); + try(ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + int quantity = rs.getInt(1); + System.out.format("quantity=%d\n", quantity); + assertFalse(rs.next()); + long end = System.currentTimeMillis(); + long delta = end - start; + sum += delta; + System.out.format("Selected single row in %dms\n\n", delta); + } + Thread.sleep(100); + } + System.out.format("Average time=%dms\n", sum / cnt); + ConsoleReporter reporter = ConsoleReporter.forRegistry(metrics) + .convertRatesTo(TimeUnit.SECONDS) + .convertDurationsTo(TimeUnit.MILLISECONDS) + .build(); + reporter.report(); + } + } + } + + @Test + public void testCustomer() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + try (Connection con = driver.connect(conString, props); Statement stmt = con.createStatement()) { + try { + stmt.execute("create table customer (\n" + + "c_id int,\n" + + "c_d_id int,\n" + + "c_w_id int,\n" + + "c_first string,\n" + + "c_middle string,\n" + + "c_last string,\n" + + "c_street_1 string,\n" + + "c_street_2 string,\n" + + "c_city string,\n" + + "c_state string,\n" + + "c_zip string,\n" + + "c_phone string,\n" + + "c_since datetime,\n" + + "c_credit string,\n" + + "c_credit_lim int,\n" + + "c_discount float,\n" + + "c_balance float,\n" + + "c_ytd_payment float,\n" + + "c_payment_cnt int,\n" + + "c_delivery_cnt int,\n" + + "c_data string,\n" + + "PRIMARY KEY(c_w_id, c_d_id, c_id)\n" + + ");"); + } catch (Exception ignored) {} + String sql = "SELECT count(c_id) FROM customer WHERE c_w_id = $1 AND c_d_id = $2 AND c_last = $3"; + try(PreparedStatement ps = con.prepareStatement(sql)) { + ParameterMetaData md = ps.getParameterMetaData(); + ps.setInt(1, 1); + ps.setInt(2, 1); + ps.setString(3, "Acme"); + ResultSet rs = ps.executeQuery(); + assertNotNull(rs); + } + } + } + + @Test + public void testDmlParameters() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + try (Connection con = driver.connect(conString, props); Statement stmt = con.createStatement()) { + try { + stmt.execute("create table person (id int, name varchar, primary key(id))"); + } catch (Exception ignored) {} + try(PreparedStatement ps = con.prepareStatement("insert into person (id, name) values ($1, $2)")) { + ParameterMetaData md = ps.getParameterMetaData(); + assertEquals(2, md.getParameterCount()); + assertEquals("Int", md.getParameterTypeName(1)); + assertEquals("Utf8", md.getParameterTypeName(2)); + ps.setInt(1, 1); + ps.setString(2, "Alan"); + assertEquals(-1, ps.executeUpdate()); + ResultSet rs = ps.getResultSet(); + assertNull(rs); + } + } + } + + @Test + public void testSetVariable() throws Exception { + final Driver driver = new ArrowFlightJdbcDriver(); + Properties props = new Properties(); + props.setProperty("user", "admin"); + props.setProperty("password", "password"); + props.setProperty("useEncryption", "false"); + String conString = "jdbc:arrow-flight://127.0.0.1:50060"; + try (Connection con = driver.connect(conString, props); Statement stmt = con.createStatement()) { + assertTrue(stmt.execute("SET UNIQUE_CHECKS=0")); + } + } + @Test public void testConnectWithInsensitiveCasePropertyKeys() throws Exception { // Get the Arrow Flight JDBC driver by providing a URL with insensitive case property keys. diff --git a/java/flight/flight-sql/pom.xml b/java/flight/flight-sql/pom.xml index ee218a2f2aa2e..fce7c82a7cb13 100644 --- a/java/flight/flight-sql/pom.xml +++ b/java/flight/flight-sql/pom.xml @@ -108,6 +108,11 @@ commons-cli 1.4 + + io.dropwizard.metrics + metrics-core + 4.2.16 + diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java index 922495a18e0c9..7ccd6c0e9fa41 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java @@ -17,6 +17,7 @@ package org.apache.arrow.flight.sql; +import static com.codahale.metrics.MetricRegistry.name; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointRequest; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointResult; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionRequest; @@ -56,6 +57,7 @@ import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; +import com.codahale.metrics.Timer; import org.apache.arrow.flight.Action; import org.apache.arrow.flight.CallOption; import org.apache.arrow.flight.CallStatus; @@ -89,6 +91,19 @@ public class FlightSqlClient implements AutoCloseable { private final FlightClient client; +// private static final Timer stmtGetInfo = metrics.timer(name(FlightSqlClient.class, "stmtGetInfo")); +// private static final Timer stmtGetSchema = metrics.timer(name(FlightSqlClient.class, "stmtGetSchema")); +// private static final Timer stmtExecUpdate = metrics.timer(name(FlightSqlClient.class, "stmtExecUpdate")); +// private static final Timer beginTxn = metrics.timer(name(FlightSqlClient.class, "beginTxn")); +// private static final Timer commitTxn = metrics.timer(name(FlightSqlClient.class, "commitTxn")); +// private static final Timer rollbackTxn = metrics.timer(name(FlightSqlClient.class, "rollbackTxn")); +// private static final Timer psFetchSchema = metrics.timer(name(FlightSqlClient.class, "psFetchSchema")); +// private static final Timer psUpdate = metrics.timer(name(FlightSqlClient.class, "psUpdate")); +// private static final Timer psClose = metrics.timer(name(FlightSqlClient.class, "psClose")); +// private static final Timer getSchema = metrics.timer(name(FlightSqlClient.class, "getSchema")); +// private static final Timer getPsSchema = metrics.timer(name(FlightSqlClient.class, "getPsSchema")); +// private static final Timer getRsSchema = metrics.timer(name(FlightSqlClient.class, "getRsSchema")); + public FlightSqlClient(final FlightClient client) { this.client = Objects.requireNonNull(client, "Client cannot be null!"); } @@ -118,7 +133,9 @@ public FlightInfo execute(final String query, Transaction transaction, final Cal builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); } final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); - return client.getInfo(descriptor, options); +// try(final Timer.Context context = stmtGetInfo.time()) { + return client.getInfo(descriptor, options); +// } } /** @@ -160,7 +177,9 @@ public SchemaResult getExecuteSchema(String query, Transaction transaction, Call builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); } final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); - return client.getSchema(descriptor, options); +// try(final Timer.Context context = stmtGetSchema.time()) { + return client.getSchema(descriptor, options); +// } } /** @@ -211,27 +230,29 @@ public long executeUpdate(final String query, final CallOption... options) { * @return the number of rows affected. */ public long executeUpdate(final String query, Transaction transaction, final CallOption... options) { - final CommandStatementUpdate.Builder builder = CommandStatementUpdate.newBuilder().setQuery(query); - if (transaction != null) { - builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); - } +// try(final Timer.Context context = stmtExecUpdate.time()) { + final CommandStatementUpdate.Builder builder = CommandStatementUpdate.newBuilder().setQuery(query); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } - final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); - try (final SyncPutListener putListener = new SyncPutListener()) { - final FlightClient.ClientStreamListener listener = - client.startPut(descriptor, VectorSchemaRoot.of(), putListener, options); - try (final PutResult result = putListener.read()) { - final DoPutUpdateResult doPutUpdateResult = DoPutUpdateResult.parseFrom( - result.getApplicationMetadata().nioBuffer()); - return doPutUpdateResult.getRecordCount(); - } finally { - listener.getResult(); + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + try (final SyncPutListener putListener = new SyncPutListener()) { + final FlightClient.ClientStreamListener listener = + client.startPut(descriptor, VectorSchemaRoot.of(), putListener, options); + try (final PutResult result = putListener.read()) { + final DoPutUpdateResult doPutUpdateResult = DoPutUpdateResult.parseFrom( + result.getApplicationMetadata().nioBuffer()); + return doPutUpdateResult.getRecordCount(); + } finally { + listener.getResult(); + } + } catch (final InterruptedException | ExecutionException e) { + throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final InvalidProtocolBufferException e) { + throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); } - } catch (final InterruptedException | ExecutionException e) { - throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); - } catch (final InvalidProtocolBufferException e) { - throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); - } +// } } /** @@ -263,10 +284,10 @@ public long executeSubstraitUpdate(SubstraitPlan plan, Transaction transaction, final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); try (final SyncPutListener putListener = new SyncPutListener()) { final FlightClient.ClientStreamListener listener = - client.startPut(descriptor, VectorSchemaRoot.of(), putListener, options); + client.startPut(descriptor, VectorSchemaRoot.of(), putListener, options); try (final PutResult result = putListener.read()) { final DoPutUpdateResult doPutUpdateResult = DoPutUpdateResult.parseFrom( - result.getApplicationMetadata().nioBuffer()); + result.getApplicationMetadata().nioBuffer()); return doPutUpdateResult.getRecordCount(); } finally { listener.getResult(); @@ -342,7 +363,9 @@ public SchemaResult getSchemasSchema(final CallOption... options) { * @param options RPC-layer hints for this call. */ public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... options) { - return client.getSchema(descriptor, options); +// try(final Timer.Context context = getSchema.time()) { + return client.getSchema(descriptor, options); +// } } /** @@ -352,7 +375,9 @@ public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... options * @param options RPC-layer hints for this call. */ public FlightStream getStream(Ticket ticket, CallOption... options) { - return client.getStream(ticket, options); +// try(final Timer.Context context = getStream.time()) { + return client.getStream(ticket, options); +// } } /** @@ -713,16 +738,18 @@ public PreparedStatement prepare(String query, CallOption... options) { * @return The representation of the prepared statement which exists on the server. */ public PreparedStatement prepare(String query, Transaction transaction, CallOption... options) { - ActionCreatePreparedStatementRequest.Builder builder = - ActionCreatePreparedStatementRequest.newBuilder().setQuery(query); - if (transaction != null) { - builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); - } - return new PreparedStatement(client, - new Action( - FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), - Any.pack(builder.build()).toByteArray()), - options); +// try(final Timer.Context context = prepare.time()) { + ActionCreatePreparedStatementRequest.Builder builder = + ActionCreatePreparedStatementRequest.newBuilder().setQuery(query); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } + return new PreparedStatement(client, + new Action( + FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), + Any.pack(builder.build()).toByteArray()), + options); +// } } /** @@ -760,18 +787,21 @@ public PreparedStatement prepare(SubstraitPlan plan, Transaction transaction, Ca /** Begin a transaction. */ public Transaction beginTransaction(CallOption... options) { - final Action action = new Action( - FlightSqlUtils.FLIGHT_SQL_BEGIN_TRANSACTION.getType(), - Any.pack(ActionBeginTransactionRequest.getDefaultInstance()).toByteArray()); - final Iterator preparedStatementResults = client.doAction(action, options); - final ActionBeginTransactionResult result = FlightSqlUtils.unpackAndParseOrThrow( - preparedStatementResults.next().getBody(), - ActionBeginTransactionResult.class); - preparedStatementResults.forEachRemaining((ignored) -> { }); - if (result.getTransactionId().isEmpty()) { - throw CallStatus.INTERNAL.withDescription("Server returned an empty transaction ID").toRuntimeException(); - } - return new Transaction(result.getTransactionId().toByteArray()); +// try(final Timer.Context context = beginTxn.time()) { + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_BEGIN_TRANSACTION.getType(), + Any.pack(ActionBeginTransactionRequest.getDefaultInstance()).toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + final ActionBeginTransactionResult result = FlightSqlUtils.unpackAndParseOrThrow( + preparedStatementResults.next().getBody(), + ActionBeginTransactionResult.class); + preparedStatementResults.forEachRemaining((ignored) -> { + }); + if (result.getTransactionId().isEmpty()) { + throw CallStatus.INTERNAL.withDescription("Server returned an empty transaction ID").toRuntimeException(); + } + return new Transaction(result.getTransactionId().toByteArray()); +// } } /** Create a savepoint within a transaction. */ @@ -797,16 +827,19 @@ public Savepoint beginSavepoint(Transaction transaction, String name, CallOption /** Commit a transaction. */ public void commit(Transaction transaction, CallOption... options) { - Preconditions.checkArgument(transaction.getTransactionId().length != 0, "Transaction must be initialized"); - ActionEndTransactionRequest request = ActionEndTransactionRequest.newBuilder() - .setTransactionId(ByteString.copyFrom(transaction.getTransactionId())) - .setActionValue(ActionEndTransactionRequest.EndTransaction.END_TRANSACTION_COMMIT.getNumber()) - .build(); - final Action action = new Action( - FlightSqlUtils.FLIGHT_SQL_END_TRANSACTION.getType(), - Any.pack(request).toByteArray()); - final Iterator preparedStatementResults = client.doAction(action, options); - preparedStatementResults.forEachRemaining((ignored) -> { }); +// try(final Timer.Context context = commitTxn.time()) { + Preconditions.checkArgument(transaction.getTransactionId().length != 0, "Transaction must be initialized"); + ActionEndTransactionRequest request = ActionEndTransactionRequest.newBuilder() + .setTransactionId(ByteString.copyFrom(transaction.getTransactionId())) + .setActionValue(ActionEndTransactionRequest.EndTransaction.END_TRANSACTION_COMMIT.getNumber()) + .build(); + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_END_TRANSACTION.getType(), + Any.pack(request).toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + preparedStatementResults.forEachRemaining((ignored) -> { + }); +// } } /** Release a savepoint. */ @@ -825,16 +858,19 @@ public void release(Savepoint savepoint, CallOption... options) { /** Rollback a transaction. */ public void rollback(Transaction transaction, CallOption... options) { - Preconditions.checkArgument(transaction.getTransactionId().length != 0, "Transaction must be initialized"); - ActionEndTransactionRequest request = ActionEndTransactionRequest.newBuilder() - .setTransactionId(ByteString.copyFrom(transaction.getTransactionId())) - .setActionValue(ActionEndTransactionRequest.EndTransaction.END_TRANSACTION_ROLLBACK.getNumber()) - .build(); - final Action action = new Action( - FlightSqlUtils.FLIGHT_SQL_END_TRANSACTION.getType(), - Any.pack(request).toByteArray()); - final Iterator preparedStatementResults = client.doAction(action, options); - preparedStatementResults.forEachRemaining((ignored) -> { }); +// try(final Timer.Context context = rollbackTxn.time()) { + Preconditions.checkArgument(transaction.getTransactionId().length != 0, "Transaction must be initialized"); + ActionEndTransactionRequest request = ActionEndTransactionRequest.newBuilder() + .setTransactionId(ByteString.copyFrom(transaction.getTransactionId())) + .setActionValue(ActionEndTransactionRequest.EndTransaction.END_TRANSACTION_ROLLBACK.getNumber()) + .build(); + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_END_TRANSACTION.getType(), + Any.pack(request).toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + preparedStatementResults.forEachRemaining((ignored) -> { + }); +// } } /** Rollback to a savepoint. */ @@ -907,11 +943,13 @@ public static class PreparedStatement implements AutoCloseable { PreparedStatement(FlightClient client, Action action, CallOption... options) { this.client = client; - final Iterator preparedStatementResults = client.doAction(action, options); - preparedStatementResult = FlightSqlUtils.unpackAndParseOrThrow( - preparedStatementResults.next().getBody(), - ActionCreatePreparedStatementResult.class); - isClosed = false; +// try(final Timer.Context context = psInit.time()) { + final Iterator preparedStatementResults = client.doAction(action, options); + preparedStatementResult = FlightSqlUtils.unpackAndParseOrThrow( + preparedStatementResults.next().getBody(), + ActionCreatePreparedStatementResult.class); + isClosed = false; +// } } /** @@ -947,11 +985,13 @@ public void clearParameters() { * @return the Schema of the resultset. */ public Schema getResultSetSchema() { - if (resultSetSchema == null) { - final ByteString bytes = preparedStatementResult.getDatasetSchema(); - resultSetSchema = deserializeSchema(bytes); - } - return resultSetSchema; +// try(final Timer.Context context = getRsSchema.time()) { + if (resultSetSchema == null) { + final ByteString bytes = preparedStatementResult.getDatasetSchema(); + resultSetSchema = deserializeSchema(bytes); + } + return resultSetSchema; +// } } /** @@ -960,11 +1000,13 @@ public Schema getResultSetSchema() { * @return the Schema of the parameters. */ public Schema getParameterSchema() { - if (parameterSchema == null) { - final ByteString bytes = preparedStatementResult.getParameterSchema(); - parameterSchema = deserializeSchema(bytes); - } - return parameterSchema; +// try(final Timer.Context context = getPsSchema.time()) { + if (parameterSchema == null) { + final ByteString bytes = preparedStatementResult.getParameterSchema(); + parameterSchema = deserializeSchema(bytes); + } + return parameterSchema; +// } } /** @@ -973,12 +1015,14 @@ public Schema getParameterSchema() { public SchemaResult fetchSchema(CallOption... options) { checkOpen(); - final FlightDescriptor descriptor = FlightDescriptor - .command(Any.pack(CommandPreparedStatementQuery.newBuilder() - .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) - .build()) - .toByteArray()); - return client.getSchema(descriptor, options); +// try(final Timer.Context context = psFetchSchema.time()) { + final FlightDescriptor descriptor = FlightDescriptor + .command(Any.pack(CommandPreparedStatementQuery.newBuilder() + .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) + .build()) + .toByteArray()); + return client.getSchema(descriptor, options); +// } } private Schema deserializeSchema(final ByteString bytes) { @@ -1002,24 +1046,26 @@ private Schema deserializeSchema(final ByteString bytes) { public FlightInfo execute(final CallOption... options) { checkOpen(); - final FlightDescriptor descriptor = FlightDescriptor - .command(Any.pack(CommandPreparedStatementQuery.newBuilder() - .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) - .build()) - .toByteArray()); +// try(final Timer.Context context = psExec.time()) { + final FlightDescriptor descriptor = FlightDescriptor + .command(Any.pack(CommandPreparedStatementQuery.newBuilder() + .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) + .build()) + .toByteArray()); - if (parameterBindingRoot != null && parameterBindingRoot.getRowCount() > 0) { - final SyncPutListener putListener = new SyncPutListener(); + if (parameterBindingRoot != null && parameterBindingRoot.getRowCount() > 0) { + final SyncPutListener putListener = new SyncPutListener(); - FlightClient.ClientStreamListener listener = - client.startPut(descriptor, parameterBindingRoot, putListener, options); + FlightClient.ClientStreamListener listener = + client.startPut(descriptor, parameterBindingRoot, putListener, options); - listener.putNext(); - listener.completed(); - listener.getResult(); - } + listener.putNext(); + listener.completed(); + listener.getResult(); + } - return client.getInfo(descriptor, options); + return client.getInfo(descriptor, options); +// } } /** @@ -1039,29 +1085,31 @@ protected final void checkOpen() { */ public long executeUpdate(final CallOption... options) { checkOpen(); - final FlightDescriptor descriptor = FlightDescriptor - .command(Any.pack(CommandPreparedStatementUpdate.newBuilder() - .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) - .build()) - .toByteArray()); - setParameters(parameterBindingRoot == null ? VectorSchemaRoot.of() : parameterBindingRoot); - final SyncPutListener putListener = new SyncPutListener(); - final FlightClient.ClientStreamListener listener = - client.startPut(descriptor, parameterBindingRoot, putListener, options); - listener.putNext(); - listener.completed(); - try { - final PutResult read = putListener.read(); - try (final ArrowBuf metadata = read.getApplicationMetadata()) { - final DoPutUpdateResult doPutUpdateResult = - DoPutUpdateResult.parseFrom(metadata.nioBuffer()); - return doPutUpdateResult.getRecordCount(); +// try(final Timer.Context context = psUpdate.time()) { + final FlightDescriptor descriptor = FlightDescriptor + .command(Any.pack(CommandPreparedStatementUpdate.newBuilder() + .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) + .build()) + .toByteArray()); + setParameters(parameterBindingRoot == null ? VectorSchemaRoot.of() : parameterBindingRoot); + final SyncPutListener putListener = new SyncPutListener(); + final FlightClient.ClientStreamListener listener = + client.startPut(descriptor, parameterBindingRoot, putListener, options); + listener.putNext(); + listener.completed(); + try { + final PutResult read = putListener.read(); + try (final ArrowBuf metadata = read.getApplicationMetadata()) { + final DoPutUpdateResult doPutUpdateResult = + DoPutUpdateResult.parseFrom(metadata.nioBuffer()); + return doPutUpdateResult.getRecordCount(); + } + } catch (final InterruptedException | ExecutionException e) { + throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final InvalidProtocolBufferException e) { + throw CallStatus.INVALID_ARGUMENT.withCause(e).toRuntimeException(); } - } catch (final InterruptedException | ExecutionException e) { - throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); - } catch (final InvalidProtocolBufferException e) { - throw CallStatus.INVALID_ARGUMENT.withCause(e).toRuntimeException(); - } +// } } /** @@ -1074,18 +1122,20 @@ public void close(final CallOption... options) { return; } isClosed = true; - final Action action = new Action( - FlightSqlUtils.FLIGHT_SQL_CLOSE_PREPARED_STATEMENT.getType(), - Any.pack(ActionClosePreparedStatementRequest.newBuilder() - .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) - .build()) - .toByteArray()); - final Iterator closePreparedStatementResults = client.doAction(action, options); - closePreparedStatementResults.forEachRemaining(result -> { - }); - if (parameterBindingRoot != null) { - parameterBindingRoot.close(); - } +// try(final Timer.Context context = psClose.time()) { + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_CLOSE_PREPARED_STATEMENT.getType(), + Any.pack(ActionClosePreparedStatementRequest.newBuilder() + .setPreparedStatementHandle(preparedStatementResult.getPreparedStatementHandle()) + .build()) + .toByteArray()); + final Iterator closePreparedStatementResults = client.doAction(action, options); + closePreparedStatementResults.forEachRemaining(result -> { + }); + if (parameterBindingRoot != null) { + parameterBindingRoot.close(); + } +// } } @Override diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java index 532921a8ac6e7..5747f7f1e9360 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java @@ -19,6 +19,7 @@ import java.util.List; +import com.codahale.metrics.MetricRegistry; import org.apache.arrow.flight.ActionType; import org.apache.arrow.flight.CallStatus; diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java index d2f73b63737a2..485438bef6e97 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/TestFlightSql.java @@ -37,26 +37,25 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.arrow.flight.auth2.BasicAuthCredentialWriter; +import org.apache.arrow.flight.auth2.ClientBearerHeaderHandler; +import org.apache.arrow.flight.auth2.ClientIncomingAuthHeaderMiddleware; +import org.apache.arrow.flight.client.ClientCookieMiddleware; +import org.apache.arrow.flight.grpc.CredentialCallOption; import org.apache.arrow.flight.sql.FlightSqlClient; import org.apache.arrow.flight.sql.FlightSqlClient.PreparedStatement; import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; import org.apache.arrow.flight.sql.FlightSqlProducer; -import org.apache.arrow.flight.sql.example.FlightSqlExample; import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity; import org.apache.arrow.flight.sql.util.TableRef; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.BitVector; -import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.UInt1Vector; -import org.apache.arrow.vector.UInt4Vector; -import org.apache.arrow.vector.VarBinaryVector; -import org.apache.arrow.vector.VarCharVector; -import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.DenseUnionVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.ipc.ReadChannel; @@ -93,20 +92,12 @@ public class TestFlightSql { private static final Map GET_SQL_INFO_EXPECTED_RESULTS_MAP = new LinkedHashMap<>(); private static final String LOCALHOST = "localhost"; private static BufferAllocator allocator; - private static FlightServer server; - private static FlightSqlClient sqlClient; @BeforeAll public static void setUp() throws Exception { allocator = new RootAllocator(Integer.MAX_VALUE); final Location serverLocation = Location.forGrpcInsecure(LOCALHOST, 0); - server = FlightServer.builder(allocator, serverLocation, new FlightSqlExample(serverLocation)) - .build() - .start(); - - final Location clientLocation = Location.forGrpcInsecure(LOCALHOST, server.getPort()); - sqlClient = new FlightSqlClient(FlightClient.builder(allocator, clientLocation).build()); GET_SQL_INFO_EXPECTED_RESULTS_MAP .put(Integer.toString(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE), "Apache Derby"); @@ -134,11 +125,6 @@ public static void setUp() throws Exception { Integer.toString(SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE_VALUE)); } - @AfterAll - public static void tearDown() throws Exception { - close(sqlClient, server, allocator); - } - private static List> getNonConformingResultsForGetSqlInfo(final List> results) { return getNonConformingResultsForGetSqlInfo(results, FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, @@ -173,822 +159,55 @@ private static List> getNonConformingResultsForGetSqlInfo( } @Test - public void testGetTablesSchema() { - final FlightInfo info = sqlClient.getTables(null, null, null, null, true); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)); - } - - @Test - public void testGetTablesSchemaExcludeSchema() { - final FlightInfo info = sqlClient.getTables(null, null, null, null, false); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)); - } - - @Test - public void testGetTablesResultNoSchema() throws Exception { - try (final FlightStream stream = - sqlClient.getStream( - sqlClient.getTables(null, null, null, null, false) - .getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> { - MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)); - }, - () -> { - final List> results = getResults(stream); - final List> expectedResults = ImmutableList.of( - // catalog_name | schema_name | table_name | table_type | table_schema - asList(null /* TODO No catalog yet */, "SYS", "SYSALIASES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSCHECKS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSCOLPERMS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSCOLUMNS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSCONGLOMERATES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSCONSTRAINTS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSDEPENDS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSFILES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSFOREIGNKEYS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSKEYS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSPERMS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSROLES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSROUTINEPERMS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSSCHEMAS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSSEQUENCES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSSTATEMENTS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSSTATISTICS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSTABLEPERMS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSTABLES", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSTRIGGERS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSUSERS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYS", "SYSVIEWS", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "SYSIBM", "SYSDUMMY1", "SYSTEM TABLE"), - asList(null /* TODO No catalog yet */, "APP", "FOREIGNTABLE", "TABLE"), - asList(null /* TODO No catalog yet */, "APP", "INTTABLE", "TABLE")); - MatcherAssert.assertThat(results, is(expectedResults)); - } - ); - } - } - - @Test - public void testGetTablesResultFilteredNoSchema() throws Exception { - try (final FlightStream stream = - sqlClient.getStream( - sqlClient.getTables(null, null, null, singletonList("TABLE"), false) - .getEndpoints().get(0).getTicket())) { - - Assertions.assertAll( - () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA_NO_SCHEMA)), - () -> { - final List> results = getResults(stream); - final List> expectedResults = ImmutableList.of( - // catalog_name | schema_name | table_name | table_type | table_schema - asList(null /* TODO No catalog yet */, "APP", "FOREIGNTABLE", "TABLE"), - asList(null /* TODO No catalog yet */, "APP", "INTTABLE", "TABLE")); - MatcherAssert.assertThat(results, is(expectedResults)); - } - ); - } - } - - @Test - public void testGetTablesResultFilteredWithSchema() throws Exception { - try (final FlightStream stream = - sqlClient.getStream( - sqlClient.getTables(null, null, null, singletonList("TABLE"), true) - .getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)), - () -> { - MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLES_SCHEMA)); - final List> results = getResults(stream); - final List> expectedResults = ImmutableList.of( - // catalog_name | schema_name | table_name | table_type | table_schema - asList( - null /* TODO No catalog yet */, - "APP", - "FOREIGNTABLE", - "TABLE", - new Schema(asList( - new Field("ID", new FieldType(false, MinorType.INT.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("INTEGER") - .schemaName("APP") - .tableName("FOREIGNTABLE") - .precision(10) - .scale(0) - .isAutoIncrement(true) - .build().getMetadataMap()), null), - new Field("FOREIGNNAME", new FieldType(true, MinorType.VARCHAR.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("VARCHAR") - .schemaName("APP") - .tableName("FOREIGNTABLE") - .precision(100) - .scale(0) - .isAutoIncrement(false) - .build().getMetadataMap()), null), - new Field("VALUE", new FieldType(true, MinorType.INT.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("INTEGER") - .schemaName("APP") - .tableName("FOREIGNTABLE") - .precision(10) - .scale(0) - .isAutoIncrement(false) - .build().getMetadataMap()), null))).toJson()), - asList( - null /* TODO No catalog yet */, - "APP", - "INTTABLE", - "TABLE", - new Schema(asList( - new Field("ID", new FieldType(false, MinorType.INT.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("INTEGER") - .schemaName("APP") - .tableName("INTTABLE") - .precision(10) - .scale(0) - .isAutoIncrement(true) - .build().getMetadataMap()), null), - new Field("KEYNAME", new FieldType(true, MinorType.VARCHAR.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("VARCHAR") - .schemaName("APP") - .tableName("INTTABLE") - .precision(100) - .scale(0) - .isAutoIncrement(false) - .build().getMetadataMap()), null), - new Field("VALUE", new FieldType(true, MinorType.INT.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("INTEGER") - .schemaName("APP") - .tableName("INTTABLE") - .precision(10) - .scale(0) - .isAutoIncrement(false) - .build().getMetadataMap()), null), - new Field("FOREIGNID", new FieldType(true, MinorType.INT.getType(), null, - new FlightSqlColumnMetadata.Builder() - .catalogName("") - .typeName("INTEGER") - .schemaName("APP") - .tableName("INTTABLE") - .precision(10) - .scale(0) - .isAutoIncrement(false) - .build().getMetadataMap()), null))).toJson())); - MatcherAssert.assertThat(results, is(expectedResults)); - } - ); - } - } - - @Test - public void testSimplePreparedStatementSchema() throws Exception { - try (final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable")) { - Assertions.assertAll( - () -> { - final Schema actualSchema = preparedStatement.getResultSetSchema(); - MatcherAssert.assertThat(actualSchema, is(SCHEMA_INT_TABLE)); - - }, - () -> { - final FlightInfo info = preparedStatement.execute(); - MatcherAssert.assertThat(info.getSchema(), is(SCHEMA_INT_TABLE)); - } - ); - } - } - - @Test - public void testSimplePreparedStatementResults() throws Exception { - try (final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable"); - final FlightStream stream = sqlClient.getStream( - preparedStatement.execute().getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)), - () -> MatcherAssert.assertThat(getResults(stream), is(EXPECTED_RESULTS_FOR_STAR_SELECT_QUERY)) - ); - } - } - - @Test - public void testSimplePreparedStatementResultsWithParameterBinding() throws Exception { - try (PreparedStatement prepare = sqlClient.prepare("SELECT * FROM intTable WHERE id = ?")) { - final Schema parameterSchema = prepare.getParameterSchema(); - try (final VectorSchemaRoot insertRoot = VectorSchemaRoot.create(parameterSchema, allocator)) { - insertRoot.allocateNew(); - - final IntVector valueVector = (IntVector) insertRoot.getVector(0); - valueVector.setSafe(0, 1); - insertRoot.setRowCount(1); - - prepare.setParameters(insertRoot); - FlightInfo flightInfo = prepare.execute(); - - FlightStream stream = sqlClient.getStream(flightInfo - .getEndpoints() - .get(0).getTicket()); - - Assertions.assertAll( - () -> MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)), - () -> MatcherAssert.assertThat(getResults(stream), is(EXPECTED_RESULTS_FOR_PARAMETER_BINDING)) - ); - } - } - } - - @Test - public void testSimplePreparedStatementUpdateResults() throws SQLException { - try (PreparedStatement prepare = sqlClient.prepare("INSERT INTO INTTABLE (keyName, value ) VALUES (?, ?)"); - PreparedStatement deletePrepare = sqlClient.prepare("DELETE FROM INTTABLE WHERE keyName = ?")) { - final Schema parameterSchema = prepare.getParameterSchema(); - try (final VectorSchemaRoot insertRoot = VectorSchemaRoot.create(parameterSchema, allocator)) { - final VarCharVector varCharVector = (VarCharVector) insertRoot.getVector(0); - final IntVector valueVector = (IntVector) insertRoot.getVector(1); - final int counter = 10; - insertRoot.allocateNew(); - - final IntStream range = IntStream.range(0, counter); - - range.forEach(i -> { - valueVector.setSafe(i, i * counter); - varCharVector.setSafe(i, new Text("value" + i)); - }); - - insertRoot.setRowCount(counter); - - prepare.setParameters(insertRoot); - final long updatedRows = prepare.executeUpdate(); - - final long deletedRows; - try (final VectorSchemaRoot deleteRoot = VectorSchemaRoot.of(varCharVector)) { - deletePrepare.setParameters(deleteRoot); - deletedRows = deletePrepare.executeUpdate(); - } - Assertions.assertAll( - () -> MatcherAssert.assertThat(updatedRows, is(10L)), - () -> MatcherAssert.assertThat(deletedRows, is(10L)) - ); - } - } - } - - @Test - public void testSimplePreparedStatementUpdateResultsWithoutParameters() throws SQLException { - try (PreparedStatement prepare = sqlClient - .prepare("INSERT INTO INTTABLE (keyName, value ) VALUES ('test', 1000)"); - PreparedStatement deletePrepare = sqlClient.prepare("DELETE FROM INTTABLE WHERE keyName = 'test'")) { - final long updatedRows = prepare.executeUpdate(); - - final long deletedRows = deletePrepare.executeUpdate(); - - Assertions.assertAll( - () -> MatcherAssert.assertThat(updatedRows, is(1L)), - () -> MatcherAssert.assertThat(deletedRows, is(1L)) - ); - } - } - - @Test - public void testSimplePreparedStatementClosesProperly() { - final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable"); - Assertions.assertAll( - () -> { - MatcherAssert.assertThat(preparedStatement.isClosed(), is(false)); - }, - () -> { - preparedStatement.close(); - MatcherAssert.assertThat(preparedStatement.isClosed(), is(true)); - } - ); - } - - @Test - public void testGetCatalogsSchema() { - final FlightInfo info = sqlClient.getCatalogs(); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA)); - } - - @Test - public void testGetCatalogsResults() throws Exception { - try (final FlightStream stream = - sqlClient.getStream(sqlClient.getCatalogs().getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_CATALOGS_SCHEMA)), - () -> { - List> catalogs = getResults(stream); - MatcherAssert.assertThat(catalogs, is(emptyList())); - } - ); - } - } - - @Test - public void testGetTableTypesSchema() { - final FlightInfo info = sqlClient.getTableTypes(); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)); - } - - @Test - public void testGetTableTypesResult() throws Exception { - try (final FlightStream stream = - sqlClient.getStream(sqlClient.getTableTypes().getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> { - MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_TABLE_TYPES_SCHEMA)); - }, - () -> { - final List> tableTypes = getResults(stream); - final List> expectedTableTypes = ImmutableList.of( - // table_type - singletonList("SYNONYM"), - singletonList("SYSTEM TABLE"), - singletonList("TABLE"), - singletonList("VIEW") - ); - MatcherAssert.assertThat(tableTypes, is(expectedTableTypes)); - } - ); - } - } - - @Test - public void testGetSchemasSchema() { - final FlightInfo info = sqlClient.getSchemas(null, null); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA)); - } - - @Test - public void testGetSchemasResult() throws Exception { - try (final FlightStream stream = - sqlClient.getStream(sqlClient.getSchemas(null, null).getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> { - MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SCHEMAS_SCHEMA)); - }, - () -> { - final List> schemas = getResults(stream); - final List> expectedSchemas = ImmutableList.of( - // catalog_name | schema_name - asList(null /* TODO Add catalog. */, "APP"), - asList(null /* TODO Add catalog. */, "NULLID"), - asList(null /* TODO Add catalog. */, "SQLJ"), - asList(null /* TODO Add catalog. */, "SYS"), - asList(null /* TODO Add catalog. */, "SYSCAT"), - asList(null /* TODO Add catalog. */, "SYSCS_DIAG"), - asList(null /* TODO Add catalog. */, "SYSCS_UTIL"), - asList(null /* TODO Add catalog. */, "SYSFUN"), - asList(null /* TODO Add catalog. */, "SYSIBM"), - asList(null /* TODO Add catalog. */, "SYSPROC"), - asList(null /* TODO Add catalog. */, "SYSSTAT")); - MatcherAssert.assertThat(schemas, is(expectedSchemas)); - } - ); - } - } - - @Test - public void testGetPrimaryKey() { - final FlightInfo flightInfo = sqlClient.getPrimaryKeys(TableRef.of(null, null, "INTTABLE")); - final FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket()); - - final List> results = getResults(stream); - - Assertions.assertAll( - () -> MatcherAssert.assertThat(results.size(), is(1)), - () -> { - final List result = results.get(0); - Assertions.assertAll( - () -> MatcherAssert.assertThat(result.get(0), is("")), - () -> MatcherAssert.assertThat(result.get(1), is("APP")), - () -> MatcherAssert.assertThat(result.get(2), is("INTTABLE")), - () -> MatcherAssert.assertThat(result.get(3), is("ID")), - () -> MatcherAssert.assertThat(result.get(4), is("1")), - () -> MatcherAssert.assertThat(result.get(5), notNullValue()) - ); - } - ); - } - - @Test - public void testGetSqlInfoSchema() { - final FlightInfo info = sqlClient.getSqlInfo(); - MatcherAssert.assertThat(info.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)); - } - - @Test - public void testGetSqlInfoResults() throws Exception { - final FlightInfo info = sqlClient.getSqlInfo(); - try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)), - () -> MatcherAssert.assertThat(getNonConformingResultsForGetSqlInfo(getResults(stream)), is(emptyList())) - ); - } - } - - @Test - public void testGetSqlInfoResultsWithSingleArg() throws Exception { - final FlightSql.SqlInfo arg = FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME; - final FlightInfo info = sqlClient.getSqlInfo(arg); - try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> MatcherAssert.assertThat(stream.getSchema(), is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA)), - () -> MatcherAssert.assertThat(getNonConformingResultsForGetSqlInfo(getResults(stream), arg), is(emptyList())) - ); - } - } - - @Test - public void testGetSqlInfoResultsWithTwoArgs() throws Exception { - final FlightSql.SqlInfo[] args = { - FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, - FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION}; - final FlightInfo info = sqlClient.getSqlInfo(args); - try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> MatcherAssert.assertThat( - stream.getSchema(), - is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA) - ), - () -> MatcherAssert.assertThat( - getNonConformingResultsForGetSqlInfo(getResults(stream), args), - is(emptyList()) - ) - ); - } - } - - @Test - public void testGetSqlInfoResultsWithThreeArgs() throws Exception { - final FlightSql.SqlInfo[] args = { - FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME, - FlightSql.SqlInfo.FLIGHT_SQL_SERVER_VERSION, - FlightSql.SqlInfo.SQL_IDENTIFIER_QUOTE_CHAR}; - final FlightInfo info = sqlClient.getSqlInfo(args); - try (final FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> MatcherAssert.assertThat( - stream.getSchema(), - is(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA) - ), - () -> MatcherAssert.assertThat( - getNonConformingResultsForGetSqlInfo(getResults(stream), args), - is(emptyList()) - ) - ); - } - } - - @Test - public void testGetCommandExportedKeys() { - final FlightStream stream = - sqlClient.getStream( - sqlClient.getExportedKeys(TableRef.of(null, null, "FOREIGNTABLE")) - .getEndpoints().get(0).getTicket()); - - final List> results = getResults(stream); - - final List> matchers = asList( - nullValue(String.class), // pk_catalog_name - is("APP"), // pk_schema_name - is("FOREIGNTABLE"), // pk_table_name - is("ID"), // pk_column_name - nullValue(String.class), // fk_catalog_name - is("APP"), // fk_schema_name - is("INTTABLE"), // fk_table_name - is("FOREIGNID"), // fk_column_name - is("1"), // key_sequence - containsString("SQL"), // fk_key_name - containsString("SQL"), // pk_key_name - is("3"), // update_rule - is("3")); // delete_rule - - final List assertions = new ArrayList<>(); - Assertions.assertEquals(1, results.size()); - for (int i = 0; i < matchers.size(); i++) { - final String actual = results.get(0).get(i); - final Matcher expected = matchers.get(i); - assertions.add(() -> MatcherAssert.assertThat(actual, expected)); - } - Assertions.assertAll(assertions); - } - - @Test - public void testGetCommandImportedKeys() { - final FlightStream stream = - sqlClient.getStream( - sqlClient.getImportedKeys(TableRef.of(null, null, "INTTABLE")) - .getEndpoints().get(0).getTicket()); - - final List> results = getResults(stream); - - final List> matchers = asList( - nullValue(String.class), // pk_catalog_name - is("APP"), // pk_schema_name - is("FOREIGNTABLE"), // pk_table_name - is("ID"), // pk_column_name - nullValue(String.class), // fk_catalog_name - is("APP"), // fk_schema_name - is("INTTABLE"), // fk_table_name - is("FOREIGNID"), // fk_column_name - is("1"), // key_sequence - containsString("SQL"), // fk_key_name - containsString("SQL"), // pk_key_name - is("3"), // update_rule - is("3")); // delete_rule - - Assertions.assertEquals(1, results.size()); - final List assertions = new ArrayList<>(); - for (int i = 0; i < matchers.size(); i++) { - final String actual = results.get(0).get(i); - final Matcher expected = matchers.get(i); - assertions.add(() -> MatcherAssert.assertThat(actual, expected)); - } - Assertions.assertAll(assertions); - } - - @Test - public void testGetTypeInfo() { - FlightInfo flightInfo = sqlClient.getXdbcTypeInfo(); - - FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket()); - - final List> results = getResults(stream); - - final List> matchers = ImmutableList.of( - asList("BIGINT", "-5", "19", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", - "BIGINT", "0", "0", - null, null, "10", null), - asList("LONG VARCHAR FOR BIT DATA", "-4", "32700", "X'", "'", emptyList().toString(), "1", "false", "0", "true", - "false", "false", - "LONG VARCHAR FOR BIT DATA", null, null, null, null, null, null), - asList("VARCHAR () FOR BIT DATA", "-3", "32672", "X'", "'", singletonList("length").toString(), "1", "false", - "2", "true", "false", - "false", "VARCHAR () FOR BIT DATA", null, null, null, null, null, null), - asList("CHAR () FOR BIT DATA", "-2", "254", "X'", "'", singletonList("length").toString(), "1", "false", "2", - "true", "false", "false", - "CHAR () FOR BIT DATA", null, null, null, null, null, null), - asList("LONG VARCHAR", "-1", "32700", "'", "'", emptyList().toString(), "1", "true", "1", "true", "false", - "false", - "LONG VARCHAR", null, null, null, null, null, null), - asList("CHAR", "1", "254", "'", "'", singletonList("length").toString(), "1", "true", "3", "true", "false", - "false", "CHAR", null, null, - null, null, null, null), - asList("NUMERIC", "2", "31", null, null, Arrays.asList("precision", "scale").toString(), "1", "false", "2", - "false", "true", "false", - "NUMERIC", "0", "31", null, null, "10", null), - asList("DECIMAL", "3", "31", null, null, Arrays.asList("precision", "scale").toString(), "1", "false", "2", - "false", "true", "false", - "DECIMAL", "0", "31", null, null, "10", null), - asList("INTEGER", "4", "10", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", - "INTEGER", "0", "0", - null, null, "10", null), - asList("SMALLINT", "5", "5", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", - "SMALLINT", "0", - "0", null, null, "10", null), - asList("FLOAT", "6", "52", null, null, singletonList("precision").toString(), "1", "false", "2", "false", - "false", "false", "FLOAT", null, - null, null, null, "2", null), - asList("REAL", "7", "23", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "false", - "REAL", null, null, - null, null, "2", null), - asList("DOUBLE", "8", "52", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "false", - "DOUBLE", null, - null, null, null, "2", null), - asList("VARCHAR", "12", "32672", "'", "'", singletonList("length").toString(), "1", "true", "3", "true", - "false", "false", "VARCHAR", - null, null, null, null, null, null), - asList("BOOLEAN", "16", "1", null, null, emptyList().toString(), "1", "false", "2", "true", "false", "false", - "BOOLEAN", null, - null, null, null, null, null), - asList("DATE", "91", "10", "DATE'", "'", emptyList().toString(), "1", "false", "2", "true", "false", "false", - "DATE", "0", "0", - null, null, "10", null), - asList("TIME", "92", "8", "TIME'", "'", emptyList().toString(), "1", "false", "2", "true", "false", "false", - "TIME", "0", "0", - null, null, "10", null), - asList("TIMESTAMP", "93", "29", "TIMESTAMP'", "'", emptyList().toString(), "1", "false", "2", "true", "false", - "false", - "TIMESTAMP", "0", "9", null, null, "10", null), - asList("OBJECT", "2000", null, null, null, emptyList().toString(), "1", "false", "2", "true", "false", "false", - "OBJECT", null, - null, null, null, null, null), - asList("BLOB", "2004", "2147483647", null, null, singletonList("length").toString(), "1", "false", "0", null, - "false", null, "BLOB", null, - null, null, null, null, null), - asList("CLOB", "2005", "2147483647", "'", "'", singletonList("length").toString(), "1", "true", "1", null, - "false", null, "CLOB", null, - null, null, null, null, null), - asList("XML", "2009", null, null, null, emptyList().toString(), "1", "true", "0", "false", "false", "false", - "XML", null, null, - null, null, null, null)); - MatcherAssert.assertThat(results, is(matchers)); - } - - @Test - public void testGetTypeInfoWithFiltering() { - FlightInfo flightInfo = sqlClient.getXdbcTypeInfo(-5); - - FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket()); - - final List> results = getResults(stream); - - final List> matchers = ImmutableList.of( - asList("BIGINT", "-5", "19", null, null, emptyList().toString(), "1", "false", "2", "false", "false", "true", - "BIGINT", "0", "0", - null, null, "10", null)); - MatcherAssert.assertThat(results, is(matchers)); - } - - @Test - public void testGetCommandCrossReference() { - final FlightInfo flightInfo = sqlClient.getCrossReference(TableRef.of(null, null, - "FOREIGNTABLE"), TableRef.of(null, null, "INTTABLE")); - final FlightStream stream = sqlClient.getStream(flightInfo.getEndpoints().get(0).getTicket()); - - final List> results = getResults(stream); - - final List> matchers = asList( - nullValue(String.class), // pk_catalog_name - is("APP"), // pk_schema_name - is("FOREIGNTABLE"), // pk_table_name - is("ID"), // pk_column_name - nullValue(String.class), // fk_catalog_name - is("APP"), // fk_schema_name - is("INTTABLE"), // fk_table_name - is("FOREIGNID"), // fk_column_name - is("1"), // key_sequence - containsString("SQL"), // fk_key_name - containsString("SQL"), // pk_key_name - is("3"), // update_rule - is("3")); // delete_rule - - Assertions.assertEquals(1, results.size()); - final List assertions = new ArrayList<>(); - for (int i = 0; i < matchers.size(); i++) { - final String actual = results.get(0).get(i); - final Matcher expected = matchers.get(i); - assertions.add(() -> MatcherAssert.assertThat(actual, expected)); - } - Assertions.assertAll(assertions); - } - - @Test - public void testCreateStatementSchema() throws Exception { - final FlightInfo info = sqlClient.execute("SELECT * FROM intTable"); - MatcherAssert.assertThat(info.getSchema(), is(SCHEMA_INT_TABLE)); - - // Consume statement to close connection before cache eviction - try (FlightStream stream = sqlClient.getStream(info.getEndpoints().get(0).getTicket())) { - while (stream.next()) { - // Do nothing - } - } - } - - @Test - public void testCreateStatementResults() throws Exception { - try (final FlightStream stream = sqlClient - .getStream(sqlClient.execute("SELECT * FROM intTable").getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> { - MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)); - }, - () -> { - MatcherAssert.assertThat(getResults(stream), is(EXPECTED_RESULTS_FOR_STAR_SELECT_QUERY)); - } - ); - } - } - - List> getResults(FlightStream stream) { - final List> results = new ArrayList<>(); - while (stream.next()) { - try (final VectorSchemaRoot root = stream.getRoot()) { - final long rowCount = root.getRowCount(); - for (int i = 0; i < rowCount; ++i) { - results.add(new ArrayList<>()); - } - - root.getSchema().getFields().forEach(field -> { - try (final FieldVector fieldVector = root.getVector(field.getName())) { - if (fieldVector instanceof VarCharVector) { - final VarCharVector varcharVector = (VarCharVector) fieldVector; - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - final Text data = varcharVector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : data.toString()); - } - } else if (fieldVector instanceof IntVector) { - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - Object data = fieldVector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); - } - } else if (fieldVector instanceof VarBinaryVector) { - final VarBinaryVector varbinaryVector = (VarBinaryVector) fieldVector; - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - final byte[] data = varbinaryVector.getObject(rowIndex); - final String output; - try { - output = isNull(data) ? - null : - MessageSerializer.deserializeSchema( - new ReadChannel(Channels.newChannel(new ByteArrayInputStream(data)))).toJson(); - } catch (final IOException e) { - throw new RuntimeException("Failed to deserialize schema", e); - } - results.get(rowIndex).add(output); - } - } else if (fieldVector instanceof DenseUnionVector) { - final DenseUnionVector denseUnionVector = (DenseUnionVector) fieldVector; - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - final Object data = denseUnionVector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); - } - } else if (fieldVector instanceof ListVector) { - for (int i = 0; i < fieldVector.getValueCount(); i++) { - if (!fieldVector.isNull(i)) { - List elements = (List) ((ListVector) fieldVector).getObject(i); - List values = new ArrayList<>(); - - for (Text element : elements) { - values.add(element.toString()); - } - results.get(i).add(values.toString()); - } - } - - } else if (fieldVector instanceof UInt4Vector) { - final UInt4Vector uInt4Vector = (UInt4Vector) fieldVector; - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - final Object data = uInt4Vector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); - } - } else if (fieldVector instanceof UInt1Vector) { - final UInt1Vector uInt1Vector = (UInt1Vector) fieldVector; - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - final Object data = uInt1Vector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); - } - } else if (fieldVector instanceof BitVector) { - for (int rowIndex = 0; rowIndex < rowCount; rowIndex++) { - Object data = fieldVector.getObject(rowIndex); - results.get(rowIndex).add(isNull(data) ? null : Objects.toString(data)); - } - } else { - throw new UnsupportedOperationException("Not yet implemented"); + public void testStmt() throws Exception { + String username = "admin"; + String password = "password"; + final Location clientLocation = Location.forGrpcInsecure("127.0.0.1", 50060); + FlightClient.Builder flightBuilder = FlightClient.builder(allocator, clientLocation); + ClientIncomingAuthHeaderMiddleware.Factory authFactory = + new ClientIncomingAuthHeaderMiddleware.Factory(new ClientBearerHeaderHandler()); + flightBuilder.intercept(authFactory); + FlightClientMiddleware.Factory cookieFactory = new ClientCookieMiddleware.Factory(); + flightBuilder.intercept(cookieFactory); + FlightClient flightClient = flightBuilder.build(); + CredentialCallOption basic = new CredentialCallOption(new BasicAuthCredentialWriter(username, password)); + flightClient.handshake(basic); + + FlightSqlClient sqlClient = new FlightSqlClient(flightClient); + FlightSqlClient.PreparedStatement ps = sqlClient.prepare("select 'Hello, FlightSQL!' as salutation", authFactory.getCredentialCallOption()); + Schema resultSetSchema = ps.getResultSetSchema(); + Schema parameterSchema = ps.getParameterSchema(); + for(int i = 0; i < 10; i++) { + int rows = 0; + FlightInfo fi = ps.execute(authFactory.getCredentialCallOption()); // 3ms + Schema schema = fi.getSchema(); // 0ms + for (FlightEndpoint ep : fi.getEndpoints()) { + Ticket ticket = ep.getTicket(); // 0ms + FlightStream stream = sqlClient.getStream(ticket, authFactory.getCredentialCallOption()); // 1ms + + long start = System.currentTimeMillis(); + final VectorSchemaRoot root = stream.getRoot(); // 40ms + long end = System.currentTimeMillis(); + System.out.format("Got %d row in %dms\n", rows, end - start); + + try { + VarCharVector a = (VarCharVector) root.getVector(0); +// BigIntVector a = (BigIntVector) root.getVector(0); + while (stream.next()) { + rows = root.getRowCount(); + for (int row_idx = 0; row_idx < rows; row_idx++) { +// long res = a.get(row_idx); +// System.out.format("s_quantity=%d\n", res); + byte[] res = a.get(row_idx); + System.out.format("s_quantity=%d\n", res.length); } } - }); + } finally { + root.clear(); + } } } - - return results; + sqlClient.close(); } - @Test - public void testExecuteUpdate() { - Assertions.assertAll( - () -> { - long insertedCount = sqlClient.executeUpdate("INSERT INTO INTTABLE (keyName, value) VALUES " + - "('KEYNAME1', 1001), ('KEYNAME2', 1002), ('KEYNAME3', 1003)"); - MatcherAssert.assertThat(insertedCount, is(3L)); - - }, - () -> { - long updatedCount = sqlClient.executeUpdate("UPDATE INTTABLE SET keyName = 'KEYNAME1' " + - "WHERE keyName = 'KEYNAME2' OR keyName = 'KEYNAME3'"); - MatcherAssert.assertThat(updatedCount, is(2L)); - - }, - () -> { - long deletedCount = sqlClient.executeUpdate("DELETE FROM INTTABLE WHERE keyName = 'KEYNAME1'"); - MatcherAssert.assertThat(deletedCount, is(3L)); - } - ); - } - - @Test - public void testQueryWithNoResultsShouldNotHang() throws Exception { - try (final PreparedStatement preparedStatement = sqlClient.prepare("SELECT * FROM intTable WHERE 1 = 0"); - final FlightStream stream = sqlClient - .getStream(preparedStatement.execute().getEndpoints().get(0).getTicket())) { - Assertions.assertAll( - () -> MatcherAssert.assertThat(stream.getSchema(), is(SCHEMA_INT_TABLE)), - () -> { - final List> result = getResults(stream); - MatcherAssert.assertThat(result, is(emptyList())); - } - ); - } - } } diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java deleted file mode 100644 index fe1e1445afc6e..0000000000000 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java +++ /dev/null @@ -1,1295 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.arrow.flight.sql.example; - -import static com.google.common.base.Strings.emptyToNull; -import static com.google.protobuf.Any.pack; -import static com.google.protobuf.ByteString.copyFrom; -import static java.lang.String.format; -import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.Collections.singletonList; -import static java.util.Objects.isNull; -import static java.util.UUID.randomUUID; -import static java.util.stream.IntStream.range; -import static org.apache.arrow.adapter.jdbc.JdbcToArrow.sqlToArrowVectorIterator; -import static org.apache.arrow.adapter.jdbc.JdbcToArrowUtils.jdbcToArrowSchema; -import static org.apache.arrow.flight.sql.impl.FlightSql.*; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; -import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; -import static org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; -import static org.apache.arrow.flight.sql.impl.FlightSql.TicketStatementQuery; -import static org.apache.arrow.util.Preconditions.checkState; -import static org.slf4j.LoggerFactory.getLogger; - -import java.io.ByteArrayOutputStream; -import java.io.File; -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.Channels; -import java.nio.file.Files; -import java.nio.file.NoSuchFileException; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.sql.Connection; -import java.sql.DatabaseMetaData; -import java.sql.DriverManager; -import java.sql.PreparedStatement; -import java.sql.ResultSet; -import java.sql.ResultSetMetaData; -import java.sql.SQLException; -import java.sql.SQLSyntaxErrorException; -import java.sql.Statement; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Calendar; -import java.util.Comparator; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.Objects; -import java.util.Properties; -import java.util.Set; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; -import java.util.function.BiConsumer; -import java.util.function.Consumer; -import java.util.function.Predicate; -import java.util.stream.Stream; - -import org.apache.arrow.adapter.jdbc.ArrowVectorIterator; -import org.apache.arrow.adapter.jdbc.JdbcFieldInfo; -import org.apache.arrow.adapter.jdbc.JdbcParameterBinder; -import org.apache.arrow.adapter.jdbc.JdbcToArrowUtils; -import org.apache.arrow.flight.CallStatus; -import org.apache.arrow.flight.Criteria; -import org.apache.arrow.flight.FlightDescriptor; -import org.apache.arrow.flight.FlightEndpoint; -import org.apache.arrow.flight.FlightInfo; -import org.apache.arrow.flight.FlightServer; -import org.apache.arrow.flight.FlightStream; -import org.apache.arrow.flight.Location; -import org.apache.arrow.flight.PutResult; -import org.apache.arrow.flight.Result; -import org.apache.arrow.flight.SchemaResult; -import org.apache.arrow.flight.Ticket; -import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; -import org.apache.arrow.flight.sql.FlightSqlProducer; -import org.apache.arrow.flight.sql.SqlInfoBuilder; -import org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; -import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; -import org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; -import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; -import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetPrimaryKeys; -import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetSqlInfo; -import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTableTypes; -import org.apache.arrow.flight.sql.impl.FlightSql.CommandGetTables; -import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementQuery; -import org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; -import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; -import org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; -import org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedCaseSensitivity; -import org.apache.arrow.memory.ArrowBuf; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.util.AutoCloseables; -import org.apache.arrow.util.Preconditions; -import org.apache.arrow.vector.BitVector; -import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.IntVector; -import org.apache.arrow.vector.UInt1Vector; -import org.apache.arrow.vector.VarBinaryVector; -import org.apache.arrow.vector.VarCharVector; -import org.apache.arrow.vector.VectorLoader; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.VectorUnloader; -import org.apache.arrow.vector.complex.ListVector; -import org.apache.arrow.vector.complex.impl.UnionListWriter; -import org.apache.arrow.vector.ipc.WriteChannel; -import org.apache.arrow.vector.ipc.message.MessageSerializer; -import org.apache.arrow.vector.types.Types.MinorType; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.Field; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.arrow.vector.types.pojo.Schema; -import org.apache.arrow.vector.util.Text; -import org.apache.commons.dbcp2.ConnectionFactory; -import org.apache.commons.dbcp2.DriverManagerConnectionFactory; -import org.apache.commons.dbcp2.PoolableConnection; -import org.apache.commons.dbcp2.PoolableConnectionFactory; -import org.apache.commons.dbcp2.PoolingDataSource; -import org.apache.commons.pool2.ObjectPool; -import org.apache.commons.pool2.impl.GenericObjectPool; -import org.slf4j.Logger; - -import com.google.common.cache.Cache; -import com.google.common.cache.CacheBuilder; -import com.google.common.cache.RemovalListener; -import com.google.common.cache.RemovalNotification; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.protobuf.ByteString; -import com.google.protobuf.Message; -import com.google.protobuf.ProtocolStringList; - -/** - * Example {@link FlightSqlProducer} implementation showing an Apache Derby backed Flight SQL server that generally - * supports all current features of Flight SQL. - */ -public class FlightSqlExample implements FlightSqlProducer, AutoCloseable { - private static final String DATABASE_URI = "jdbc:derby:target/derbyDB"; - private static final Logger LOGGER = getLogger(FlightSqlExample.class); - private static final Calendar DEFAULT_CALENDAR = JdbcToArrowUtils.getUtcCalendar(); - // ARROW-15315: Use ExecutorService to simulate an async scenario - private final ExecutorService executorService = Executors.newFixedThreadPool(10); - private final Location location; - private final PoolingDataSource dataSource; - private final BufferAllocator rootAllocator = new RootAllocator(); - private final Cache> preparedStatementLoadingCache; - private final Cache> statementLoadingCache; - private final SqlInfoBuilder sqlInfoBuilder; - - public static void main(String[] args) throws Exception { - Location location = Location.forGrpcInsecure("localhost", 55555); - final FlightSqlExample example = new FlightSqlExample(location); - Location listenLocation = Location.forGrpcInsecure("0.0.0.0", 55555); - try (final BufferAllocator allocator = new RootAllocator(); - final FlightServer server = FlightServer.builder(allocator, listenLocation, example).build()) { - server.start(); - server.awaitTermination(); - } - } - - public FlightSqlExample(final Location location) { - // TODO Constructor should not be doing work. - checkState( - removeDerbyDatabaseIfExists() && populateDerbyDatabase(), - "Failed to reset Derby database!"); - final ConnectionFactory connectionFactory = - new DriverManagerConnectionFactory(DATABASE_URI, new Properties()); - final PoolableConnectionFactory poolableConnectionFactory = - new PoolableConnectionFactory(connectionFactory, null); - final ObjectPool connectionPool = new GenericObjectPool<>(poolableConnectionFactory); - - poolableConnectionFactory.setPool(connectionPool); - // PoolingDataSource takes ownership of `connectionPool` - dataSource = new PoolingDataSource<>(connectionPool); - - preparedStatementLoadingCache = - CacheBuilder.newBuilder() - .maximumSize(100) - .expireAfterWrite(10, TimeUnit.MINUTES) - .removalListener(new StatementRemovalListener()) - .build(); - - statementLoadingCache = - CacheBuilder.newBuilder() - .maximumSize(100) - .expireAfterWrite(10, TimeUnit.MINUTES) - .removalListener(new StatementRemovalListener<>()) - .build(); - - this.location = location; - - sqlInfoBuilder = new SqlInfoBuilder(); - try (final Connection connection = dataSource.getConnection()) { - final DatabaseMetaData metaData = connection.getMetaData(); - - sqlInfoBuilder.withFlightSqlServerName(metaData.getDatabaseProductName()) - .withFlightSqlServerVersion(metaData.getDatabaseProductVersion()) - .withFlightSqlServerArrowVersion(metaData.getDriverVersion()) - .withFlightSqlServerReadOnly(metaData.isReadOnly()) - .withFlightSqlServerSql(true) - .withFlightSqlServerSubstrait(false) - .withFlightSqlServerTransaction(SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_NONE) - .withSqlIdentifierQuoteChar(metaData.getIdentifierQuoteString()) - .withSqlDdlCatalog(metaData.supportsCatalogsInDataManipulation()) - .withSqlDdlSchema( metaData.supportsSchemasInDataManipulation()) - .withSqlDdlTable( metaData.allTablesAreSelectable()) - .withSqlIdentifierCase(metaData.storesMixedCaseIdentifiers() ? - SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE : - metaData.storesUpperCaseIdentifiers() ? - SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UPPERCASE : - metaData.storesLowerCaseIdentifiers() ? - SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_LOWERCASE : - SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UNKNOWN) - .withSqlQuotedIdentifierCase(metaData.storesMixedCaseQuotedIdentifiers() ? - SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_CASE_INSENSITIVE : - metaData.storesUpperCaseQuotedIdentifiers() ? - SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UPPERCASE : - metaData.storesLowerCaseQuotedIdentifiers() ? - SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_LOWERCASE : - SqlSupportedCaseSensitivity.SQL_CASE_SENSITIVITY_UNKNOWN); - } catch (SQLException e) { - throw new RuntimeException(e); - } - - } - - private static boolean removeDerbyDatabaseIfExists() { - boolean wasSuccess; - final Path path = Paths.get("target" + File.separator + "derbyDB"); - - try (final Stream walk = Files.walk(path)) { - /* - * Iterate over all paths to delete, mapping each path to the outcome of its own - * deletion as a boolean representing whether or not each individual operation was - * successful; then reduce all booleans into a single answer, and store that into - * `wasSuccess`, which will later be returned by this method. - * If for whatever reason the resulting `Stream` is empty, throw an `IOException`; - * this not expected. - */ - wasSuccess = walk.sorted(Comparator.reverseOrder()).map(Path::toFile).map(File::delete) - .reduce(Boolean::logicalAnd).orElseThrow(IOException::new); - } catch (IOException e) { - /* - * The only acceptable scenario for an `IOException` to be thrown here is if - * an attempt to delete an non-existing file takes place -- which should be - * alright, since they would be deleted anyway. - */ - if (!(wasSuccess = e instanceof NoSuchFileException)) { - LOGGER.error(format("Failed attempt to clear DerbyDB: <%s>", e.getMessage()), e); - } - } - - return wasSuccess; - } - - private static boolean populateDerbyDatabase() { - try (final Connection connection = DriverManager.getConnection("jdbc:derby:target/derbyDB;create=true"); - Statement statement = connection.createStatement()) { - statement.execute("CREATE TABLE foreignTable (" + - "id INT not null primary key GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), " + - "foreignName varchar(100), " + - "value int)"); - statement.execute("CREATE TABLE intTable (" + - "id INT not null primary key GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), " + - "keyName varchar(100), " + - "value int, " + - "foreignId int references foreignTable(id))"); - statement.execute("INSERT INTO foreignTable (foreignName, value) VALUES ('keyOne', 1)"); - statement.execute("INSERT INTO foreignTable (foreignName, value) VALUES ('keyTwo', 0)"); - statement.execute("INSERT INTO foreignTable (foreignName, value) VALUES ('keyThree', -1)"); - statement.execute("INSERT INTO intTable (keyName, value, foreignId) VALUES ('one', 1, 1)"); - statement.execute("INSERT INTO intTable (keyName, value, foreignId) VALUES ('zero', 0, 1)"); - statement.execute("INSERT INTO intTable (keyName, value, foreignId) VALUES ('negative one', -1, 1)"); - } catch (final SQLException e) { - LOGGER.error(format("Failed attempt to populate DerbyDB: <%s>", e.getMessage()), e); - return false; - } - return true; - } - - private static ArrowType getArrowTypeFromJdbcType(final int jdbcDataType, final int precision, final int scale) { - final ArrowType type = - JdbcToArrowUtils.getArrowTypeFromJdbcType(new JdbcFieldInfo(jdbcDataType, precision, scale), DEFAULT_CALENDAR); - return isNull(type) ? ArrowType.Utf8.INSTANCE : type; - } - - private static void saveToVector(final Byte data, final UInt1Vector vector, final int index) { - vectorConsumer( - data, - vector, - fieldVector -> fieldVector.setNull(index), - (theData, fieldVector) -> fieldVector.setSafe(index, theData)); - } - - private static void saveToVector(final Byte data, final BitVector vector, final int index) { - vectorConsumer( - data, - vector, - fieldVector -> fieldVector.setNull(index), - (theData, fieldVector) -> fieldVector.setSafe(index, theData)); - } - - private static void saveToVector(final String data, final VarCharVector vector, final int index) { - preconditionCheckSaveToVector(vector, index); - vectorConsumer(data, vector, fieldVector -> fieldVector.setNull(index), - (theData, fieldVector) -> fieldVector.setSafe(index, new Text(theData))); - } - - private static void saveToVector(final Integer data, final IntVector vector, final int index) { - preconditionCheckSaveToVector(vector, index); - vectorConsumer(data, vector, fieldVector -> fieldVector.setNull(index), - (theData, fieldVector) -> fieldVector.setSafe(index, theData)); - } - - private static void saveToVector(final byte[] data, final VarBinaryVector vector, final int index) { - preconditionCheckSaveToVector(vector, index); - vectorConsumer(data, vector, fieldVector -> fieldVector.setNull(index), - (theData, fieldVector) -> fieldVector.setSafe(index, theData)); - } - - private static void preconditionCheckSaveToVector(final FieldVector vector, final int index) { - Objects.requireNonNull(vector, "vector cannot be null."); - checkState(index >= 0, "Index must be a positive number!"); - } - - private static void vectorConsumer(final T data, final V vector, - final Consumer consumerIfNullable, - final BiConsumer defaultConsumer) { - if (isNull(data)) { - consumerIfNullable.accept(vector); - return; - } - defaultConsumer.accept(data, vector); - } - - private static VectorSchemaRoot getSchemasRoot(final ResultSet data, final BufferAllocator allocator) - throws SQLException { - final VarCharVector catalogs = new VarCharVector("catalog_name", allocator); - final VarCharVector schemas = - new VarCharVector("db_schema_name", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); - final List vectors = ImmutableList.of(catalogs, schemas); - vectors.forEach(FieldVector::allocateNew); - final Map vectorToColumnName = ImmutableMap.of( - catalogs, "TABLE_CATALOG", - schemas, "TABLE_SCHEM"); - saveToVectors(vectorToColumnName, data); - final int rows = vectors.stream().map(FieldVector::getValueCount).findAny().orElseThrow(IllegalStateException::new); - vectors.forEach(vector -> vector.setValueCount(rows)); - return new VectorSchemaRoot(vectors); - } - - private static int saveToVectors(final Map vectorToColumnName, - final ResultSet data, boolean emptyToNull) - throws SQLException { - Predicate alwaysTrue = (resultSet) -> true; - return saveToVectors(vectorToColumnName, data, emptyToNull, alwaysTrue); - } - - private static int saveToVectors(final Map vectorToColumnName, - final ResultSet data, boolean emptyToNull, - Predicate resultSetPredicate) - throws SQLException { - Objects.requireNonNull(vectorToColumnName, "vectorToColumnName cannot be null."); - Objects.requireNonNull(data, "data cannot be null."); - final Set> entrySet = vectorToColumnName.entrySet(); - int rows = 0; - - while (data.next()) { - if (!resultSetPredicate.test(data)) { - continue; - } - for (final Entry vectorToColumn : entrySet) { - final T vector = vectorToColumn.getKey(); - final String columnName = vectorToColumn.getValue(); - if (vector instanceof VarCharVector) { - String thisData = data.getString(columnName); - saveToVector(emptyToNull ? emptyToNull(thisData) : thisData, (VarCharVector) vector, rows); - } else if (vector instanceof IntVector) { - final int intValue = data.getInt(columnName); - saveToVector(data.wasNull() ? null : intValue, (IntVector) vector, rows); - } else if (vector instanceof UInt1Vector) { - final byte byteValue = data.getByte(columnName); - saveToVector(data.wasNull() ? null : byteValue, (UInt1Vector) vector, rows); - } else if (vector instanceof BitVector) { - final byte byteValue = data.getByte(columnName); - saveToVector(data.wasNull() ? null : byteValue, (BitVector) vector, rows); - } else if (vector instanceof ListVector) { - String createParamsValues = data.getString(columnName); - - UnionListWriter writer = ((ListVector) vector).getWriter(); - - BufferAllocator allocator = vector.getAllocator(); - final ArrowBuf buf = allocator.buffer(1024); - - writer.setPosition(rows); - writer.startList(); - - if (createParamsValues != null) { - String[] split = createParamsValues.split(","); - - range(0, split.length) - .forEach(i -> { - byte[] bytes = split[i].getBytes(UTF_8); - Preconditions.checkState(bytes.length < 1024, - "The amount of bytes is greater than what the ArrowBuf supports"); - buf.setBytes(0, bytes); - writer.varChar().writeVarChar(0, bytes.length, buf); - }); - } - buf.close(); - writer.endList(); - } else { - throw CallStatus.INVALID_ARGUMENT.withDescription("Provided vector not supported").toRuntimeException(); - } - } - rows ++; - } - for (final Entry vectorToColumn : entrySet) { - vectorToColumn.getKey().setValueCount(rows); - } - - return rows; - } - - private static void saveToVectors(final Map vectorToColumnName, - final ResultSet data) - throws SQLException { - saveToVectors(vectorToColumnName, data, false); - } - - private static VectorSchemaRoot getTableTypesRoot(final ResultSet data, final BufferAllocator allocator) - throws SQLException { - return getRoot(data, allocator, "table_type", "TABLE_TYPE"); - } - - private static VectorSchemaRoot getCatalogsRoot(final ResultSet data, final BufferAllocator allocator) - throws SQLException { - return getRoot(data, allocator, "catalog_name", "TABLE_CATALOG"); - } - - private static VectorSchemaRoot getRoot(final ResultSet data, final BufferAllocator allocator, - final String fieldVectorName, final String columnName) - throws SQLException { - final VarCharVector dataVector = - new VarCharVector(fieldVectorName, FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); - saveToVectors(ImmutableMap.of(dataVector, columnName), data); - final int rows = dataVector.getValueCount(); - dataVector.setValueCount(rows); - return new VectorSchemaRoot(singletonList(dataVector)); - } - - private static VectorSchemaRoot getTypeInfoRoot(CommandGetXdbcTypeInfo request, ResultSet typeInfo, - final BufferAllocator allocator) - throws SQLException { - Preconditions.checkNotNull(allocator, "BufferAllocator cannot be null."); - - VectorSchemaRoot root = VectorSchemaRoot.create(Schemas.GET_TYPE_INFO_SCHEMA, allocator); - - Map mapper = new HashMap<>(); - mapper.put(root.getVector("type_name"), "TYPE_NAME"); - mapper.put(root.getVector("data_type"), "DATA_TYPE"); - mapper.put(root.getVector("column_size"), "PRECISION"); - mapper.put(root.getVector("literal_prefix"), "LITERAL_PREFIX"); - mapper.put(root.getVector("literal_suffix"), "LITERAL_SUFFIX"); - mapper.put(root.getVector("create_params"), "CREATE_PARAMS"); - mapper.put(root.getVector("nullable"), "NULLABLE"); - mapper.put(root.getVector("case_sensitive"), "CASE_SENSITIVE"); - mapper.put(root.getVector("searchable"), "SEARCHABLE"); - mapper.put(root.getVector("unsigned_attribute"), "UNSIGNED_ATTRIBUTE"); - mapper.put(root.getVector("fixed_prec_scale"), "FIXED_PREC_SCALE"); - mapper.put(root.getVector("auto_increment"), "AUTO_INCREMENT"); - mapper.put(root.getVector("local_type_name"), "LOCAL_TYPE_NAME"); - mapper.put(root.getVector("minimum_scale"), "MINIMUM_SCALE"); - mapper.put(root.getVector("maximum_scale"), "MAXIMUM_SCALE"); - mapper.put(root.getVector("sql_data_type"), "SQL_DATA_TYPE"); - mapper.put(root.getVector("datetime_subcode"), "SQL_DATETIME_SUB"); - mapper.put(root.getVector("num_prec_radix"), "NUM_PREC_RADIX"); - - Predicate predicate; - if (request.hasDataType()) { - predicate = (resultSet) -> { - try { - return resultSet.getInt("DATA_TYPE") == request.getDataType(); - } catch (SQLException e) { - throw new RuntimeException(e); - } - }; - } else { - predicate = (resultSet -> true); - } - - int rows = saveToVectors(mapper, typeInfo, true, predicate); - - root.setRowCount(rows); - return root; - } - - private static VectorSchemaRoot getTablesRoot(final DatabaseMetaData databaseMetaData, - final BufferAllocator allocator, - final boolean includeSchema, - final String catalog, - final String schemaFilterPattern, - final String tableFilterPattern, - final String... tableTypes) - throws SQLException, IOException { - /* - * TODO Fix DerbyDB inconsistency if possible. - * During the early development of this prototype, an inconsistency has been found in the database - * used for this demonstration; as DerbyDB does not operate with the concept of catalogs, fetching - * the catalog name for a given table from `DatabaseMetadata#getColumns` and `DatabaseMetadata#getSchemas` - * returns null, as expected. However, the inconsistency lies in the fact that accessing the same - * information -- that is, the catalog name for a given table -- from `DatabaseMetadata#getSchemas` - * returns an empty String.The temporary workaround for this was making sure we convert the empty Strings - * to null using `com.google.common.base.Strings#emptyToNull`. - */ - Objects.requireNonNull(allocator, "BufferAllocator cannot be null."); - final VarCharVector catalogNameVector = new VarCharVector("catalog_name", allocator); - final VarCharVector schemaNameVector = new VarCharVector("db_schema_name", allocator); - final VarCharVector tableNameVector = - new VarCharVector("table_name", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); - final VarCharVector tableTypeVector = - new VarCharVector("table_type", FieldType.notNullable(MinorType.VARCHAR.getType()), allocator); - - final List vectors = new ArrayList<>(4); - vectors.add(catalogNameVector); - vectors.add(schemaNameVector); - vectors.add(tableNameVector); - vectors.add(tableTypeVector); - - vectors.forEach(FieldVector::allocateNew); - - final Map vectorToColumnName = ImmutableMap.of( - catalogNameVector, "TABLE_CAT", - schemaNameVector, "TABLE_SCHEM", - tableNameVector, "TABLE_NAME", - tableTypeVector, "TABLE_TYPE"); - - try (final ResultSet data = - Objects.requireNonNull( - databaseMetaData, - format("%s cannot be null.", databaseMetaData.getClass().getName())) - .getTables(catalog, schemaFilterPattern, tableFilterPattern, tableTypes)) { - - saveToVectors(vectorToColumnName, data, true); - final int rows = - vectors.stream().map(FieldVector::getValueCount).findAny().orElseThrow(IllegalStateException::new); - vectors.forEach(vector -> vector.setValueCount(rows)); - - if (includeSchema) { - final VarBinaryVector tableSchemaVector = - new VarBinaryVector("table_schema", FieldType.notNullable(MinorType.VARBINARY.getType()), allocator); - tableSchemaVector.allocateNew(rows); - - try (final ResultSet columnsData = - databaseMetaData.getColumns(catalog, schemaFilterPattern, tableFilterPattern, null)) { - final Map> tableToFields = new HashMap<>(); - - while (columnsData.next()) { - final String catalogName = columnsData.getString("TABLE_CAT"); - final String schemaName = columnsData.getString("TABLE_SCHEM"); - final String tableName = columnsData.getString("TABLE_NAME"); - final String typeName = columnsData.getString("TYPE_NAME"); - final String fieldName = columnsData.getString("COLUMN_NAME"); - final int dataType = columnsData.getInt("DATA_TYPE"); - final boolean isNullable = columnsData.getInt("NULLABLE") != DatabaseMetaData.columnNoNulls; - final int precision = columnsData.getInt("COLUMN_SIZE"); - final int scale = columnsData.getInt("DECIMAL_DIGITS"); - boolean isAutoIncrement = - Objects.equals(columnsData.getString("IS_AUTOINCREMENT"), "YES"); - - final List fields = tableToFields.computeIfAbsent(tableName, tableName_ -> new ArrayList<>()); - - final FlightSqlColumnMetadata columnMetadata = new FlightSqlColumnMetadata.Builder() - .catalogName(catalogName) - .schemaName(schemaName) - .tableName(tableName) - .typeName(typeName) - .precision(precision) - .scale(scale) - .isAutoIncrement(isAutoIncrement) - .build(); - - final Field field = - new Field( - fieldName, - new FieldType( - isNullable, - getArrowTypeFromJdbcType(dataType, precision, scale), - null, - columnMetadata.getMetadataMap()), - null); - fields.add(field); - } - - for (int index = 0; index < rows; index++) { - final String tableName = tableNameVector.getObject(index).toString(); - final Schema schema = new Schema(tableToFields.get(tableName)); - saveToVector( - copyFrom(serializeMetadata(schema)).toByteArray(), - tableSchemaVector, index); - } - } - - tableSchemaVector.setValueCount(rows); - vectors.add(tableSchemaVector); - } - } - - return new VectorSchemaRoot(vectors); - } - - private static ByteBuffer serializeMetadata(final Schema schema) { - final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - try { - MessageSerializer.serialize(new WriteChannel(Channels.newChannel(outputStream)), schema); - - return ByteBuffer.wrap(outputStream.toByteArray()); - } catch (final IOException e) { - throw new RuntimeException("Failed to serialize schema", e); - } - } - - @Override - public void getStreamPreparedStatement(final CommandPreparedStatementQuery command, final CallContext context, - final ServerStreamListener listener) { - final ByteString handle = command.getPreparedStatementHandle(); - StatementContext statementContext = preparedStatementLoadingCache.getIfPresent(handle); - Objects.requireNonNull(statementContext); - final PreparedStatement statement = statementContext.getStatement(); - try (final ResultSet resultSet = statement.executeQuery()) { - final Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), DEFAULT_CALENDAR); - try (final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) { - final VectorLoader loader = new VectorLoader(vectorSchemaRoot); - listener.start(vectorSchemaRoot); - - final ArrowVectorIterator iterator = sqlToArrowVectorIterator(resultSet, rootAllocator); - while (iterator.hasNext()) { - final VectorSchemaRoot batch = iterator.next(); - if (batch.getRowCount() == 0) { - break; - } - final VectorUnloader unloader = new VectorUnloader(batch); - loader.load(unloader.getRecordBatch()); - listener.putNext(); - vectorSchemaRoot.clear(); - } - - listener.putNext(); - } - } catch (final SQLException | IOException e) { - LOGGER.error(format("Failed to getStreamPreparedStatement: <%s>.", e.getMessage()), e); - listener.error(CallStatus.INTERNAL.withDescription("Failed to prepare statement: " + e).toRuntimeException()); - } finally { - listener.completed(); - } - } - - @Override - public void closePreparedStatement(final ActionClosePreparedStatementRequest request, final CallContext context, - final StreamListener listener) { - // Running on another thread - executorService.submit(() -> { - try { - preparedStatementLoadingCache.invalidate(request.getPreparedStatementHandle()); - } catch (final Exception e) { - listener.onError(e); - return; - } - listener.onCompleted(); - }); - } - - @Override - public FlightInfo getFlightInfoStatement(final CommandStatementQuery request, final CallContext context, - final FlightDescriptor descriptor) { - ByteString handle = copyFrom(randomUUID().toString().getBytes(UTF_8)); - - try { - // Ownership of the connection will be passed to the context. Do NOT close! - final Connection connection = dataSource.getConnection(); - final Statement statement = connection.createStatement( - ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); - final String query = request.getQuery(); - final StatementContext statementContext = new StatementContext<>(statement, query); - - statementLoadingCache.put(handle, statementContext); - final ResultSet resultSet = statement.executeQuery(query); - - TicketStatementQuery ticket = TicketStatementQuery.newBuilder() - .setStatementHandle(handle) - .build(); - return getFlightInfoForSchema(ticket, descriptor, - jdbcToArrowSchema(resultSet.getMetaData(), DEFAULT_CALENDAR)); - } catch (final SQLException e) { - LOGGER.error( - format("There was a problem executing the prepared statement: <%s>.", e.getMessage()), - e); - throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); - } - } - - @Override - public FlightInfo getFlightInfoPreparedStatement(final CommandPreparedStatementQuery command, - final CallContext context, - final FlightDescriptor descriptor) { - final ByteString preparedStatementHandle = command.getPreparedStatementHandle(); - StatementContext statementContext = - preparedStatementLoadingCache.getIfPresent(preparedStatementHandle); - try { - assert statementContext != null; - PreparedStatement statement = statementContext.getStatement(); - - ResultSetMetaData metaData = statement.getMetaData(); - return getFlightInfoForSchema(command, descriptor, - jdbcToArrowSchema(metaData, DEFAULT_CALENDAR)); - } catch (final SQLException e) { - LOGGER.error( - format("There was a problem executing the prepared statement: <%s>.", e.getMessage()), - e); - throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); - } - } - - @Override - public SchemaResult getSchemaStatement(final CommandStatementQuery command, final CallContext context, - final FlightDescriptor descriptor) { - throw CallStatus.UNIMPLEMENTED.toRuntimeException(); - } - - @Override - public void close() throws Exception { - try { - preparedStatementLoadingCache.cleanUp(); - } catch (Throwable t) { - LOGGER.error(format("Failed to close resources: <%s>", t.getMessage()), t); - } - - AutoCloseables.close(dataSource, rootAllocator); - } - - @Override - public void listFlights(CallContext context, Criteria criteria, StreamListener listener) { - // TODO - build example implementation - throw CallStatus.UNIMPLEMENTED.toRuntimeException(); - } - - @Override - public void createPreparedStatement(final ActionCreatePreparedStatementRequest request, final CallContext context, - final StreamListener listener) { - // Running on another thread - executorService.submit(() -> { - try { - final ByteString preparedStatementHandle = copyFrom(randomUUID().toString().getBytes(UTF_8)); - // Ownership of the connection will be passed to the context. Do NOT close! - final Connection connection = dataSource.getConnection(); - final PreparedStatement preparedStatement = connection.prepareStatement(request.getQuery(), - ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_READ_ONLY); - final StatementContext preparedStatementContext = - new StatementContext<>(preparedStatement, request.getQuery()); - - preparedStatementLoadingCache.put(preparedStatementHandle, preparedStatementContext); - - final Schema parameterSchema = - jdbcToArrowSchema(preparedStatement.getParameterMetaData(), DEFAULT_CALENDAR); - - final ResultSetMetaData metaData = preparedStatement.getMetaData(); - final ByteString bytes = isNull(metaData) ? - ByteString.EMPTY : - ByteString.copyFrom( - serializeMetadata(jdbcToArrowSchema(metaData, DEFAULT_CALENDAR))); - final ActionCreatePreparedStatementResult result = ActionCreatePreparedStatementResult.newBuilder() - .setDatasetSchema(bytes) - .setParameterSchema(copyFrom(serializeMetadata(parameterSchema))) - .setPreparedStatementHandle(preparedStatementHandle) - .build(); - listener.onNext(new Result(pack(result).toByteArray())); - } catch (final SQLException e) { - listener.onError(CallStatus.INTERNAL - .withDescription("Failed to create prepared statement: " + e) - .toRuntimeException()); - return; - } catch (final Throwable t) { - listener.onError(CallStatus.INTERNAL.withDescription("Unknown error: " + t).toRuntimeException()); - return; - } - listener.onCompleted(); - }); - } - - @Override - public void doExchange(CallContext context, FlightStream reader, ServerStreamListener writer) { - // TODO - build example implementation - throw CallStatus.UNIMPLEMENTED.toRuntimeException(); - } - - @Override - public Runnable acceptPutStatement(CommandStatementUpdate command, - CallContext context, FlightStream flightStream, - StreamListener ackStream) { - final String query = command.getQuery(); - - return () -> { - try (final Connection connection = dataSource.getConnection(); - final Statement statement = connection.createStatement()) { - final int result = statement.executeUpdate(query); - - final DoPutUpdateResult build = - DoPutUpdateResult.newBuilder().setRecordCount(result).build(); - - try (final ArrowBuf buffer = rootAllocator.buffer(build.getSerializedSize())) { - buffer.writeBytes(build.toByteArray()); - ackStream.onNext(PutResult.metadata(buffer)); - ackStream.onCompleted(); - } - } catch (SQLSyntaxErrorException e) { - ackStream.onError(CallStatus.INVALID_ARGUMENT - .withDescription("Failed to execute statement (invalid syntax): " + e) - .toRuntimeException()); - } catch (SQLException e) { - ackStream.onError(CallStatus.INTERNAL - .withDescription("Failed to execute statement: " + e) - .toRuntimeException()); - } - }; - } - - @Override - public Runnable acceptPutPreparedStatementUpdate(CommandPreparedStatementUpdate command, CallContext context, - FlightStream flightStream, StreamListener ackStream) { - final StatementContext statement = - preparedStatementLoadingCache.getIfPresent(command.getPreparedStatementHandle()); - - return () -> { - if (statement == null) { - ackStream.onError(CallStatus.NOT_FOUND - .withDescription("Prepared statement does not exist") - .toRuntimeException()); - return; - } - try { - final PreparedStatement preparedStatement = statement.getStatement(); - - while (flightStream.next()) { - final VectorSchemaRoot root = flightStream.getRoot(); - - final int rowCount = root.getRowCount(); - final int recordCount; - - if (rowCount == 0) { - preparedStatement.execute(); - recordCount = preparedStatement.getUpdateCount(); - } else { - final JdbcParameterBinder binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build(); - while (binder.next()) { - preparedStatement.addBatch(); - } - int[] recordCounts = preparedStatement.executeBatch(); - recordCount = Arrays.stream(recordCounts).sum(); - } - - final DoPutUpdateResult build = - DoPutUpdateResult.newBuilder().setRecordCount(recordCount).build(); - - try (final ArrowBuf buffer = rootAllocator.buffer(build.getSerializedSize())) { - buffer.writeBytes(build.toByteArray()); - ackStream.onNext(PutResult.metadata(buffer)); - } - } - } catch (SQLException e) { - ackStream.onError(CallStatus.INTERNAL.withDescription("Failed to execute update: " + e).toRuntimeException()); - return; - } - ackStream.onCompleted(); - }; - } - - @Override - public Runnable acceptPutPreparedStatementQuery(CommandPreparedStatementQuery command, CallContext context, - FlightStream flightStream, StreamListener ackStream) { - final StatementContext statementContext = - preparedStatementLoadingCache.getIfPresent(command.getPreparedStatementHandle()); - - return () -> { - assert statementContext != null; - PreparedStatement preparedStatement = statementContext.getStatement(); - - try { - while (flightStream.next()) { - final VectorSchemaRoot root = flightStream.getRoot(); - final JdbcParameterBinder binder = JdbcParameterBinder.builder(preparedStatement, root).bindAll().build(); - while (binder.next()) { - // Do not execute() - will be done in a getStream call - } - } - - } catch (SQLException e) { - ackStream.onError(CallStatus.INTERNAL - .withDescription("Failed to bind parameters: " + e.getMessage()) - .withCause(e) - .toRuntimeException()); - return; - } - ackStream.onCompleted(); - }; - } - - @Override - public FlightInfo getFlightInfoSqlInfo(final CommandGetSqlInfo request, final CallContext context, - final FlightDescriptor descriptor) { - return getFlightInfoForSchema(request, descriptor, Schemas.GET_SQL_INFO_SCHEMA); - } - - @Override - public void getStreamSqlInfo(final CommandGetSqlInfo command, final CallContext context, - final ServerStreamListener listener) { - this.sqlInfoBuilder.send(command.getInfoList(), listener); - } - - @Override - public FlightInfo getFlightInfoTypeInfo(CommandGetXdbcTypeInfo request, CallContext context, - FlightDescriptor descriptor) { - return getFlightInfoForSchema(request, descriptor, Schemas.GET_TYPE_INFO_SCHEMA); - } - - @Override - public void getStreamTypeInfo(CommandGetXdbcTypeInfo request, CallContext context, - ServerStreamListener listener) { - try (final Connection connection = dataSource.getConnection(); - final ResultSet typeInfo = connection.getMetaData().getTypeInfo(); - final VectorSchemaRoot vectorSchemaRoot = getTypeInfoRoot(request, typeInfo, rootAllocator)) { - listener.start(vectorSchemaRoot); - listener.putNext(); - } catch (SQLException e) { - LOGGER.error(format("Failed to getStreamCatalogs: <%s>.", e.getMessage()), e); - listener.error(e); - } finally { - listener.completed(); - } - } - - @Override - public FlightInfo getFlightInfoCatalogs(final CommandGetCatalogs request, final CallContext context, - final FlightDescriptor descriptor) { - return getFlightInfoForSchema(request, descriptor, Schemas.GET_CATALOGS_SCHEMA); - } - - @Override - public void getStreamCatalogs(final CallContext context, final ServerStreamListener listener) { - try (final Connection connection = dataSource.getConnection(); - final ResultSet catalogs = connection.getMetaData().getCatalogs(); - final VectorSchemaRoot vectorSchemaRoot = getCatalogsRoot(catalogs, rootAllocator)) { - listener.start(vectorSchemaRoot); - listener.putNext(); - } catch (SQLException e) { - LOGGER.error(format("Failed to getStreamCatalogs: <%s>.", e.getMessage()), e); - listener.error(e); - } finally { - listener.completed(); - } - } - - @Override - public FlightInfo getFlightInfoSchemas(final CommandGetDbSchemas request, final CallContext context, - final FlightDescriptor descriptor) { - return getFlightInfoForSchema(request, descriptor, Schemas.GET_SCHEMAS_SCHEMA); - } - - @Override - public void getStreamSchemas(final CommandGetDbSchemas command, final CallContext context, - final ServerStreamListener listener) { - final String catalog = command.hasCatalog() ? command.getCatalog() : null; - final String schemaFilterPattern = command.hasDbSchemaFilterPattern() ? command.getDbSchemaFilterPattern() : null; - try (final Connection connection = dataSource.getConnection(); - final ResultSet schemas = connection.getMetaData().getSchemas(catalog, schemaFilterPattern); - final VectorSchemaRoot vectorSchemaRoot = getSchemasRoot(schemas, rootAllocator)) { - listener.start(vectorSchemaRoot); - listener.putNext(); - } catch (SQLException e) { - LOGGER.error(format("Failed to getStreamSchemas: <%s>.", e.getMessage()), e); - listener.error(e); - } finally { - listener.completed(); - } - } - - @Override - public FlightInfo getFlightInfoTables(final CommandGetTables request, final CallContext context, - final FlightDescriptor descriptor) { - Schema schemaToUse = Schemas.GET_TABLES_SCHEMA; - if (!request.getIncludeSchema()) { - schemaToUse = Schemas.GET_TABLES_SCHEMA_NO_SCHEMA; - } - return getFlightInfoForSchema(request, descriptor, schemaToUse); - } - - @Override - public void getStreamTables(final CommandGetTables command, final CallContext context, - final ServerStreamListener listener) { - final String catalog = command.hasCatalog() ? command.getCatalog() : null; - final String schemaFilterPattern = - command.hasDbSchemaFilterPattern() ? command.getDbSchemaFilterPattern() : null; - final String tableFilterPattern = - command.hasTableNameFilterPattern() ? command.getTableNameFilterPattern() : null; - - final ProtocolStringList protocolStringList = command.getTableTypesList(); - final int protocolSize = protocolStringList.size(); - final String[] tableTypes = - protocolSize == 0 ? null : protocolStringList.toArray(new String[protocolSize]); - - try (final Connection connection = DriverManager.getConnection(DATABASE_URI); - final VectorSchemaRoot vectorSchemaRoot = getTablesRoot( - connection.getMetaData(), - rootAllocator, - command.getIncludeSchema(), - catalog, schemaFilterPattern, tableFilterPattern, tableTypes)) { - listener.start(vectorSchemaRoot); - listener.putNext(); - } catch (SQLException | IOException e) { - LOGGER.error(format("Failed to getStreamTables: <%s>.", e.getMessage()), e); - listener.error(e); - } finally { - listener.completed(); - } - } - - @Override - public FlightInfo getFlightInfoTableTypes(final CommandGetTableTypes request, final CallContext context, - final FlightDescriptor descriptor) { - return getFlightInfoForSchema(request, descriptor, Schemas.GET_TABLE_TYPES_SCHEMA); - } - - @Override - public void getStreamTableTypes(final CallContext context, final ServerStreamListener listener) { - try (final Connection connection = dataSource.getConnection(); - final ResultSet tableTypes = connection.getMetaData().getTableTypes(); - final VectorSchemaRoot vectorSchemaRoot = getTableTypesRoot(tableTypes, rootAllocator)) { - listener.start(vectorSchemaRoot); - listener.putNext(); - } catch (SQLException e) { - LOGGER.error(format("Failed to getStreamTableTypes: <%s>.", e.getMessage()), e); - listener.error(e); - } finally { - listener.completed(); - } - } - - @Override - public FlightInfo getFlightInfoPrimaryKeys(final CommandGetPrimaryKeys request, final CallContext context, - final FlightDescriptor descriptor) { - return getFlightInfoForSchema(request, descriptor, Schemas.GET_PRIMARY_KEYS_SCHEMA); - } - - @Override - public void getStreamPrimaryKeys(final CommandGetPrimaryKeys command, final CallContext context, - final ServerStreamListener listener) { - - final String catalog = command.hasCatalog() ? command.getCatalog() : null; - final String schema = command.hasDbSchema() ? command.getDbSchema() : null; - final String table = command.getTable(); - - try (Connection connection = DriverManager.getConnection(DATABASE_URI)) { - final ResultSet primaryKeys = connection.getMetaData().getPrimaryKeys(catalog, schema, table); - - final VarCharVector catalogNameVector = new VarCharVector("catalog_name", rootAllocator); - final VarCharVector schemaNameVector = new VarCharVector("db_schema_name", rootAllocator); - final VarCharVector tableNameVector = new VarCharVector("table_name", rootAllocator); - final VarCharVector columnNameVector = new VarCharVector("column_name", rootAllocator); - final IntVector keySequenceVector = new IntVector("key_sequence", rootAllocator); - final VarCharVector keyNameVector = new VarCharVector("key_name", rootAllocator); - - final List vectors = - new ArrayList<>( - ImmutableList.of( - catalogNameVector, schemaNameVector, tableNameVector, columnNameVector, keySequenceVector, - keyNameVector)); - vectors.forEach(FieldVector::allocateNew); - - int rows = 0; - for (; primaryKeys.next(); rows++) { - saveToVector(primaryKeys.getString("TABLE_CAT"), catalogNameVector, rows); - saveToVector(primaryKeys.getString("TABLE_SCHEM"), schemaNameVector, rows); - saveToVector(primaryKeys.getString("TABLE_NAME"), tableNameVector, rows); - saveToVector(primaryKeys.getString("COLUMN_NAME"), columnNameVector, rows); - final int key_seq = primaryKeys.getInt("KEY_SEQ"); - saveToVector(primaryKeys.wasNull() ? null : key_seq, keySequenceVector, rows); - saveToVector(primaryKeys.getString("PK_NAME"), keyNameVector, rows); - } - - try (final VectorSchemaRoot vectorSchemaRoot = new VectorSchemaRoot(vectors)) { - vectorSchemaRoot.setRowCount(rows); - - listener.start(vectorSchemaRoot); - listener.putNext(); - } - } catch (SQLException e) { - listener.error(e); - } finally { - listener.completed(); - } - } - - @Override - public FlightInfo getFlightInfoExportedKeys(final CommandGetExportedKeys request, final CallContext context, - final FlightDescriptor descriptor) { - return getFlightInfoForSchema(request, descriptor, Schemas.GET_EXPORTED_KEYS_SCHEMA); - } - - @Override - public void getStreamExportedKeys(final CommandGetExportedKeys command, final CallContext context, - final ServerStreamListener listener) { - String catalog = command.hasCatalog() ? command.getCatalog() : null; - String schema = command.hasDbSchema() ? command.getDbSchema() : null; - String table = command.getTable(); - - try (Connection connection = DriverManager.getConnection(DATABASE_URI); - ResultSet keys = connection.getMetaData().getExportedKeys(catalog, schema, table); - VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) { - listener.start(vectorSchemaRoot); - listener.putNext(); - } catch (SQLException e) { - listener.error(e); - } finally { - listener.completed(); - } - } - - @Override - public FlightInfo getFlightInfoImportedKeys(final CommandGetImportedKeys request, final CallContext context, - final FlightDescriptor descriptor) { - return getFlightInfoForSchema(request, descriptor, Schemas.GET_IMPORTED_KEYS_SCHEMA); - } - - @Override - public void getStreamImportedKeys(final CommandGetImportedKeys command, final CallContext context, - final ServerStreamListener listener) { - String catalog = command.hasCatalog() ? command.getCatalog() : null; - String schema = command.hasDbSchema() ? command.getDbSchema() : null; - String table = command.getTable(); - - try (Connection connection = DriverManager.getConnection(DATABASE_URI); - ResultSet keys = connection.getMetaData().getImportedKeys(catalog, schema, table); - VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) { - listener.start(vectorSchemaRoot); - listener.putNext(); - } catch (final SQLException e) { - listener.error(e); - } finally { - listener.completed(); - } - } - - @Override - public FlightInfo getFlightInfoCrossReference(CommandGetCrossReference request, CallContext context, - FlightDescriptor descriptor) { - return getFlightInfoForSchema(request, descriptor, Schemas.GET_CROSS_REFERENCE_SCHEMA); - } - - @Override - public void getStreamCrossReference(CommandGetCrossReference command, CallContext context, - ServerStreamListener listener) { - final String pkCatalog = command.hasPkCatalog() ? command.getPkCatalog() : null; - final String pkSchema = command.hasPkDbSchema() ? command.getPkDbSchema() : null; - final String fkCatalog = command.hasFkCatalog() ? command.getFkCatalog() : null; - final String fkSchema = command.hasFkDbSchema() ? command.getFkDbSchema() : null; - final String pkTable = command.getPkTable(); - final String fkTable = command.getFkTable(); - - try (Connection connection = DriverManager.getConnection(DATABASE_URI); - ResultSet keys = connection.getMetaData() - .getCrossReference(pkCatalog, pkSchema, pkTable, fkCatalog, fkSchema, fkTable); - VectorSchemaRoot vectorSchemaRoot = createVectors(keys)) { - listener.start(vectorSchemaRoot); - listener.putNext(); - } catch (final SQLException e) { - listener.error(e); - } finally { - listener.completed(); - } - } - - private VectorSchemaRoot createVectors(ResultSet keys) throws SQLException { - final VarCharVector pkCatalogNameVector = new VarCharVector("pk_catalog_name", rootAllocator); - final VarCharVector pkSchemaNameVector = new VarCharVector("pk_db_schema_name", rootAllocator); - final VarCharVector pkTableNameVector = new VarCharVector("pk_table_name", rootAllocator); - final VarCharVector pkColumnNameVector = new VarCharVector("pk_column_name", rootAllocator); - final VarCharVector fkCatalogNameVector = new VarCharVector("fk_catalog_name", rootAllocator); - final VarCharVector fkSchemaNameVector = new VarCharVector("fk_db_schema_name", rootAllocator); - final VarCharVector fkTableNameVector = new VarCharVector("fk_table_name", rootAllocator); - final VarCharVector fkColumnNameVector = new VarCharVector("fk_column_name", rootAllocator); - final IntVector keySequenceVector = new IntVector("key_sequence", rootAllocator); - final VarCharVector fkKeyNameVector = new VarCharVector("fk_key_name", rootAllocator); - final VarCharVector pkKeyNameVector = new VarCharVector("pk_key_name", rootAllocator); - final UInt1Vector updateRuleVector = new UInt1Vector("update_rule", rootAllocator); - final UInt1Vector deleteRuleVector = new UInt1Vector("delete_rule", rootAllocator); - - Map vectorToColumnName = new HashMap<>(); - vectorToColumnName.put(pkCatalogNameVector, "PKTABLE_CAT"); - vectorToColumnName.put(pkSchemaNameVector, "PKTABLE_SCHEM"); - vectorToColumnName.put(pkTableNameVector, "PKTABLE_NAME"); - vectorToColumnName.put(pkColumnNameVector, "PKCOLUMN_NAME"); - vectorToColumnName.put(fkCatalogNameVector, "FKTABLE_CAT"); - vectorToColumnName.put(fkSchemaNameVector, "FKTABLE_SCHEM"); - vectorToColumnName.put(fkTableNameVector, "FKTABLE_NAME"); - vectorToColumnName.put(fkColumnNameVector, "FKCOLUMN_NAME"); - vectorToColumnName.put(keySequenceVector, "KEY_SEQ"); - vectorToColumnName.put(updateRuleVector, "UPDATE_RULE"); - vectorToColumnName.put(deleteRuleVector, "DELETE_RULE"); - vectorToColumnName.put(fkKeyNameVector, "FK_NAME"); - vectorToColumnName.put(pkKeyNameVector, "PK_NAME"); - - final VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.of( - pkCatalogNameVector, pkSchemaNameVector, pkTableNameVector, pkColumnNameVector, fkCatalogNameVector, - fkSchemaNameVector, fkTableNameVector, fkColumnNameVector, keySequenceVector, fkKeyNameVector, - pkKeyNameVector, updateRuleVector, deleteRuleVector); - - vectorSchemaRoot.allocateNew(); - final int rowCount = saveToVectors(vectorToColumnName, keys, true); - - vectorSchemaRoot.setRowCount(rowCount); - - return vectorSchemaRoot; - } - - @Override - public void getStreamStatement(final TicketStatementQuery ticketStatementQuery, final CallContext context, - final ServerStreamListener listener) { - final ByteString handle = ticketStatementQuery.getStatementHandle(); - final StatementContext statementContext = - Objects.requireNonNull(statementLoadingCache.getIfPresent(handle)); - try (final ResultSet resultSet = statementContext.getStatement().getResultSet()) { - final Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), DEFAULT_CALENDAR); - try (VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.create(schema, rootAllocator)) { - final VectorLoader loader = new VectorLoader(vectorSchemaRoot); - listener.start(vectorSchemaRoot); - - final ArrowVectorIterator iterator = sqlToArrowVectorIterator(resultSet, rootAllocator); - while (iterator.hasNext()) { - final VectorUnloader unloader = new VectorUnloader(iterator.next()); - loader.load(unloader.getRecordBatch()); - listener.putNext(); - vectorSchemaRoot.clear(); - } - - listener.putNext(); - } - } catch (SQLException | IOException e) { - LOGGER.error(format("Failed to getStreamPreparedStatement: <%s>.", e.getMessage()), e); - listener.error(e); - } finally { - listener.completed(); - statementLoadingCache.invalidate(handle); - } - } - - private FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor, - final Schema schema) { - final Ticket ticket = new Ticket(pack(request).toByteArray()); - // TODO Support multiple endpoints. - final List endpoints = singletonList(new FlightEndpoint(ticket, location)); - - return new FlightInfo(schema, descriptor, endpoints, -1, -1); - } - - private static class StatementRemovalListener - implements RemovalListener> { - @Override - public void onRemoval(final RemovalNotification> notification) { - try { - AutoCloseables.close(notification.getValue()); - } catch (final Exception e) { - // swallow - } - } - } -} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/table/RowTest.java b/java/vector/src/test/java/org/apache/arrow/vector/table/RowTest.java deleted file mode 100644 index eb50e866b19f0..0000000000000 --- a/java/vector/src/test/java/org/apache/arrow/vector/table/RowTest.java +++ /dev/null @@ -1,856 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.arrow.vector.table; - -import static org.apache.arrow.vector.table.TestUtils.BIGINT_INT_MAP_VECTOR_NAME; -import static org.apache.arrow.vector.table.TestUtils.FIXEDBINARY_VECTOR_NAME_1; -import static org.apache.arrow.vector.table.TestUtils.INT_LIST_VECTOR_NAME; -import static org.apache.arrow.vector.table.TestUtils.INT_VECTOR_NAME_1; -import static org.apache.arrow.vector.table.TestUtils.STRUCT_VECTOR_NAME; -import static org.apache.arrow.vector.table.TestUtils.UNION_VECTOR_NAME; -import static org.apache.arrow.vector.table.TestUtils.VARBINARY_VECTOR_NAME_1; -import static org.apache.arrow.vector.table.TestUtils.VARCHAR_VECTOR_NAME_1; -import static org.apache.arrow.vector.table.TestUtils.fixedWidthVectors; -import static org.apache.arrow.vector.table.TestUtils.intPlusFixedBinaryColumns; -import static org.apache.arrow.vector.table.TestUtils.intPlusLargeVarBinaryColumns; -import static org.apache.arrow.vector.table.TestUtils.intPlusLargeVarcharColumns; -import static org.apache.arrow.vector.table.TestUtils.intPlusVarBinaryColumns; -import static org.apache.arrow.vector.table.TestUtils.intPlusVarcharColumns; -import static org.apache.arrow.vector.table.TestUtils.simpleDenseUnionVector; -import static org.apache.arrow.vector.table.TestUtils.simpleListVector; -import static org.apache.arrow.vector.table.TestUtils.simpleMapVector; -import static org.apache.arrow.vector.table.TestUtils.simpleStructVector; -import static org.apache.arrow.vector.table.TestUtils.simpleUnionVector; -import static org.apache.arrow.vector.table.TestUtils.timezoneTemporalVectors; -import static org.apache.arrow.vector.table.TestUtils.twoIntColumns; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertFalse; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; - -import java.math.BigDecimal; -import java.nio.charset.StandardCharsets; -import java.time.Duration; -import java.time.LocalDateTime; -import java.time.Period; -import java.util.ArrayList; -import java.util.List; - -import org.apache.arrow.memory.ArrowBuf; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.BitVector; -import org.apache.arrow.vector.DecimalVector; -import org.apache.arrow.vector.DurationVector; -import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.IntervalDayVector; -import org.apache.arrow.vector.IntervalMonthDayNanoVector; -import org.apache.arrow.vector.IntervalYearVector; -import org.apache.arrow.vector.PeriodDuration; -import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.complex.DenseUnionVector; -import org.apache.arrow.vector.complex.ListVector; -import org.apache.arrow.vector.complex.MapVector; -import org.apache.arrow.vector.complex.StructVector; -import org.apache.arrow.vector.complex.UnionVector; -import org.apache.arrow.vector.holders.NullableBigIntHolder; -import org.apache.arrow.vector.holders.NullableBitHolder; -import org.apache.arrow.vector.holders.NullableDecimalHolder; -import org.apache.arrow.vector.holders.NullableDurationHolder; -import org.apache.arrow.vector.holders.NullableFloat4Holder; -import org.apache.arrow.vector.holders.NullableFloat8Holder; -import org.apache.arrow.vector.holders.NullableIntHolder; -import org.apache.arrow.vector.holders.NullableIntervalDayHolder; -import org.apache.arrow.vector.holders.NullableIntervalMonthDayNanoHolder; -import org.apache.arrow.vector.holders.NullableIntervalYearHolder; -import org.apache.arrow.vector.holders.NullableSmallIntHolder; -import org.apache.arrow.vector.holders.NullableTimeMicroHolder; -import org.apache.arrow.vector.holders.NullableTimeMilliHolder; -import org.apache.arrow.vector.holders.NullableTimeNanoHolder; -import org.apache.arrow.vector.holders.NullableTimeSecHolder; -import org.apache.arrow.vector.holders.NullableTimeStampMicroHolder; -import org.apache.arrow.vector.holders.NullableTimeStampMicroTZHolder; -import org.apache.arrow.vector.holders.NullableTimeStampMilliHolder; -import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; -import org.apache.arrow.vector.holders.NullableTimeStampNanoHolder; -import org.apache.arrow.vector.holders.NullableTimeStampNanoTZHolder; -import org.apache.arrow.vector.holders.NullableTimeStampSecHolder; -import org.apache.arrow.vector.holders.NullableTimeStampSecTZHolder; -import org.apache.arrow.vector.holders.NullableTinyIntHolder; -import org.apache.arrow.vector.holders.NullableUInt1Holder; -import org.apache.arrow.vector.holders.NullableUInt2Holder; -import org.apache.arrow.vector.holders.NullableUInt4Holder; -import org.apache.arrow.vector.holders.NullableUInt8Holder; -import org.apache.arrow.vector.types.IntervalUnit; -import org.apache.arrow.vector.types.TimeUnit; -import org.apache.arrow.vector.types.pojo.ArrowType; -import org.apache.arrow.vector.types.pojo.FieldType; -import org.apache.arrow.vector.types.pojo.TestExtensionType; -import org.apache.arrow.vector.util.JsonStringHashMap; -import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; - -class RowTest { - - private BufferAllocator allocator; - - @BeforeEach - public void init() { - allocator = new RootAllocator(Long.MAX_VALUE); - } - - @AfterEach - public void terminate() { - allocator.close(); - } - - @Test - void constructor() { - List vectorList = twoIntColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - assertEquals(StandardCharsets.UTF_8, c.getDefaultCharacterSet()); - } - } - - @Test - void at() { - List vectorList = twoIntColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - assertEquals(c.getRowNumber(), -1); - c.setPosition(1); - assertEquals(c.getRowNumber(), 1); - } - } - - @Test - void getIntByVectorIndex() { - List vectorList = twoIntColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertEquals(2, c.getInt(0)); - } - } - - @Test - void getIntByVectorName() { - List vectorList = twoIntColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertEquals(2, c.getInt(INT_VECTOR_NAME_1)); - } - } - - @Test - void testNameNotFound() { - List vectorList = twoIntColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertThrows(IllegalArgumentException.class, - () -> c.getVarCharObj("wrong name")); - } - } - - @Test - void testWrongType() { - List vectorList = twoIntColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertThrows(ClassCastException.class, - () -> c.getVarCharObj(INT_VECTOR_NAME_1)); - } - } - - @Test - void getDecimal() { - List vectors = new ArrayList<>(); - DecimalVector decimalVector = new DecimalVector("decimal_vector", allocator, 55, 10); - vectors.add(decimalVector); - decimalVector.setSafe(0, new BigDecimal("0.0543278923")); - decimalVector.setSafe(1, new BigDecimal("2.0543278923")); - decimalVector.setValueCount(2); - BigDecimal one = decimalVector.getObject(1); - - NullableDecimalHolder holder1 = new NullableDecimalHolder(); - NullableDecimalHolder holder2 = new NullableDecimalHolder(); - try (Table t = new Table(vectors)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertEquals(one, c.getDecimalObj("decimal_vector")); - assertEquals(one, c.getDecimalObj(0)); - c.getDecimal(0, holder1); - c.getDecimal("decimal_vector", holder2); - assertEquals(holder1.buffer, holder2.buffer); - assertEquals(c.getDecimal(0).memoryAddress(), c.getDecimal("decimal_vector").memoryAddress()); - } - } - - @Test - void getDuration() { - List vectors = new ArrayList<>(); - TimeUnit unit = TimeUnit.SECOND; - final FieldType fieldType = FieldType.nullable(new ArrowType.Duration(unit)); - - DurationVector durationVector = new DurationVector("duration_vector", fieldType, allocator); - NullableDurationHolder holder1 = new NullableDurationHolder(); - NullableDurationHolder holder2 = new NullableDurationHolder(); - - holder1.value = 100; - holder1.unit = TimeUnit.SECOND; - holder1.isSet = 1; - holder2.value = 200; - holder2.unit = TimeUnit.SECOND; - holder2.isSet = 1; - - vectors.add(durationVector); - durationVector.setSafe(0, holder1); - durationVector.setSafe(1, holder2); - durationVector.setValueCount(2); - - Duration one = durationVector.getObject(1); - try (Table t = new Table(vectors)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertEquals(one, c.getDurationObj("duration_vector")); - assertEquals(one, c.getDurationObj(0)); - c.getDuration(0, holder1); - c.getDuration("duration_vector", holder2); - assertEquals(holder1.value, holder2.value); - ArrowBuf durationBuf1 = c.getDuration(0); - ArrowBuf durationBuf2 = c.getDuration("duration_vector"); - assertEquals(durationBuf1.memoryAddress(), durationBuf2.memoryAddress()); - } - } - - @Test - void getIntervalDay() { - List vectors = new ArrayList<>(); - IntervalUnit unit = IntervalUnit.DAY_TIME; - final FieldType fieldType = FieldType.nullable(new ArrowType.Interval(unit)); - - IntervalDayVector intervalDayVector = new IntervalDayVector("intervalDay_vector", fieldType, allocator); - NullableIntervalDayHolder holder1 = new NullableIntervalDayHolder(); - NullableIntervalDayHolder holder2 = new NullableIntervalDayHolder(); - - holder1.days = 100; - holder1.milliseconds = 1000; - holder1.isSet = 1; - holder2.days = 200; - holder2.milliseconds = 2000; - holder2.isSet = 1; - - vectors.add(intervalDayVector); - intervalDayVector.setSafe(0, holder1); - intervalDayVector.setSafe(1, holder2); - intervalDayVector.setValueCount(2); - - Duration one = intervalDayVector.getObject(1); - try (Table t = new Table(vectors)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertEquals(one, c.getIntervalDayObj("intervalDay_vector")); - assertEquals(one, c.getIntervalDayObj(0)); - c.getIntervalDay(0, holder1); - c.getIntervalDay("intervalDay_vector", holder2); - assertEquals(holder1.days, holder2.days); - assertEquals(holder1.milliseconds, holder2.milliseconds); - ArrowBuf intDayBuf1 = c.getIntervalDay(0); - ArrowBuf intDayBuf2 = c.getIntervalDay("intervalDay_vector"); - assertEquals(intDayBuf1.memoryAddress(), intDayBuf2.memoryAddress()); - } - } - - @Test - void getIntervalMonth() { - List vectors = new ArrayList<>(); - IntervalUnit unit = IntervalUnit.MONTH_DAY_NANO; - final FieldType fieldType = FieldType.nullable(new ArrowType.Interval(unit)); - - IntervalMonthDayNanoVector intervalMonthVector = - new IntervalMonthDayNanoVector("intervalMonth_vector", fieldType, allocator); - NullableIntervalMonthDayNanoHolder holder1 = new NullableIntervalMonthDayNanoHolder(); - NullableIntervalMonthDayNanoHolder holder2 = new NullableIntervalMonthDayNanoHolder(); - - holder1.days = 1; - holder1.months = 10; - holder1.isSet = 1; - holder2.days = 2; - holder2.months = 20; - holder2.isSet = 1; - - vectors.add(intervalMonthVector); - intervalMonthVector.setSafe(0, holder1); - intervalMonthVector.setSafe(1, holder2); - intervalMonthVector.setValueCount(2); - - PeriodDuration one = intervalMonthVector.getObject(1); - try (Table t = new Table(vectors)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertEquals(one, c.getIntervalMonthDayNanoObj("intervalMonth_vector")); - assertEquals(one, c.getIntervalMonthDayNanoObj(0)); - c.getIntervalMonthDayNano(0, holder1); - c.getIntervalMonthDayNano("intervalMonth_vector", holder2); - assertEquals(holder1.days, holder2.days); - assertEquals(holder1.months, holder2.months); - ArrowBuf intMonthBuf1 = c.getIntervalMonthDayNano(0); - ArrowBuf intMonthBuf2 = c.getIntervalMonthDayNano("intervalMonth_vector"); - assertEquals(intMonthBuf1.memoryAddress(), intMonthBuf2.memoryAddress()); - } - } - - @Test - void getIntervalYear() { - List vectors = new ArrayList<>(); - IntervalUnit unit = IntervalUnit.YEAR_MONTH; - final FieldType fieldType = FieldType.nullable(new ArrowType.Interval(unit)); - - IntervalYearVector intervalYearVector = new IntervalYearVector("intervalYear_vector", fieldType, allocator); - NullableIntervalYearHolder holder1 = new NullableIntervalYearHolder(); - NullableIntervalYearHolder holder2 = new NullableIntervalYearHolder(); - - holder1.value = 1; - holder1.isSet = 1; - holder2.value = 2; - holder2.isSet = 1; - - vectors.add(intervalYearVector); - intervalYearVector.setSafe(0, holder1); - intervalYearVector.setSafe(1, holder2); - intervalYearVector.setValueCount(2); - - Period one = intervalYearVector.getObject(1); - try (Table t = new Table(vectors)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertEquals(one, c.getIntervalYearObj("intervalYear_vector")); - assertEquals(one, c.getIntervalYearObj(0)); - c.getIntervalYear(0, holder1); - c.getIntervalYear("intervalYear_vector", holder2); - assertEquals(holder1.value, holder2.value); - int intYear1 = c.getIntervalYear(0); - int intYear2 = c.getIntervalYear("intervalYear_vector"); - assertEquals(2, intYear1); - assertEquals(intYear1, intYear2); - } - } - - @Test - void getBit() { - List vectors = new ArrayList<>(); - - BitVector bitVector = new BitVector("bit_vector", allocator); - NullableBitHolder holder1 = new NullableBitHolder(); - NullableBitHolder holder2 = new NullableBitHolder(); - - vectors.add(bitVector); - bitVector.setSafe(0, 0); - bitVector.setSafe(1, 1); - bitVector.setValueCount(2); - - int one = bitVector.get(1); - try (Table t = new Table(vectors)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertEquals(one, c.getBit("bit_vector")); - assertEquals(one, c.getBit(0)); - c.getBit(0, holder1); - c.getBit("bit_vector", holder2); - assertEquals(holder1.value, holder2.value); - } - } - - @Test - void hasNext() { - List vectorList = twoIntColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - assertTrue(c.hasNext()); - c.setPosition(1); - assertFalse(c.hasNext()); - } - } - - @Test - void next() { - List vectorList = twoIntColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(0); - c.next(); - assertEquals(1, c.getRowNumber()); - } - } - - @Test - void isNull() { - List vectorList = twoIntColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertFalse(c.isNull(0)); - } - } - - @Test - void isNullByFieldName() { - List vectorList = twoIntColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertFalse(c.isNull(INT_VECTOR_NAME_1)); - } - } - - @Test - void fixedWidthVectorTest() { - List vectorList = fixedWidthVectors(allocator, 2); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - // integer tests using vector name and index - assertFalse(c.isNull("bigInt_vector")); - assertEquals(c.getInt("int_vector"), c.getInt(0)); - assertEquals(c.getBigInt("bigInt_vector"), c.getBigInt(1)); - assertEquals(c.getSmallInt("smallInt_vector"), c.getSmallInt(2)); - assertEquals(c.getTinyInt("tinyInt_vector"), c.getTinyInt(3)); - - // integer tests using Nullable Holders - NullableIntHolder int4Holder = new NullableIntHolder(); - NullableTinyIntHolder int1Holder = new NullableTinyIntHolder(); - NullableSmallIntHolder int2Holder = new NullableSmallIntHolder(); - NullableBigIntHolder int8Holder = new NullableBigIntHolder(); - c.getInt(0, int4Holder); - c.getBigInt(1, int8Holder); - c.getSmallInt(2, int2Holder); - c.getTinyInt(3, int1Holder); - assertEquals(c.getInt("int_vector"), int4Holder.value); - assertEquals(c.getBigInt("bigInt_vector"), int8Holder.value); - assertEquals(c.getSmallInt("smallInt_vector"), int2Holder.value); - assertEquals(c.getTinyInt("tinyInt_vector"), int1Holder.value); - - c.getInt("int_vector", int4Holder); - c.getBigInt("bigInt_vector", int8Holder); - c.getSmallInt("smallInt_vector", int2Holder); - c.getTinyInt("tinyInt_vector", int1Holder); - assertEquals(c.getInt("int_vector"), int4Holder.value); - assertEquals(c.getBigInt("bigInt_vector"), int8Holder.value); - assertEquals(c.getSmallInt("smallInt_vector"), int2Holder.value); - assertEquals(c.getTinyInt("tinyInt_vector"), int1Holder.value); - - // uint tests using vector name and index - assertEquals(c.getUInt1("uInt1_vector"), c.getUInt1(4)); - assertEquals(c.getUInt2("uInt2_vector"), c.getUInt2(5)); - assertEquals(c.getUInt4("uInt4_vector"), c.getUInt4(6)); - assertEquals(c.getUInt8("uInt8_vector"), c.getUInt8(7)); - - // UInt tests using Nullable Holders - NullableUInt4Holder uInt4Holder = new NullableUInt4Holder(); - NullableUInt1Holder uInt1Holder = new NullableUInt1Holder(); - NullableUInt2Holder uInt2Holder = new NullableUInt2Holder(); - NullableUInt8Holder uInt8Holder = new NullableUInt8Holder(); - // fill the holders using vector index and test - c.getUInt1(4, uInt1Holder); - c.getUInt2(5, uInt2Holder); - c.getUInt4(6, uInt4Holder); - c.getUInt8(7, uInt8Holder); - assertEquals(c.getUInt1("uInt1_vector"), uInt1Holder.value); - assertEquals(c.getUInt2("uInt2_vector"), uInt2Holder.value); - assertEquals(c.getUInt4("uInt4_vector"), uInt4Holder.value); - assertEquals(c.getUInt8("uInt8_vector"), uInt8Holder.value); - - // refill the holders using vector name and retest - c.getUInt1("uInt1_vector", uInt1Holder); - c.getUInt2("uInt2_vector", uInt2Holder); - c.getUInt4("uInt4_vector", uInt4Holder); - c.getUInt8("uInt8_vector", uInt8Holder); - assertEquals(c.getUInt1("uInt1_vector"), uInt1Holder.value); - assertEquals(c.getUInt2("uInt2_vector"), uInt2Holder.value); - assertEquals(c.getUInt4("uInt4_vector"), uInt4Holder.value); - assertEquals(c.getUInt8("uInt8_vector"), uInt8Holder.value); - - // tests floating point - assertEquals(c.getFloat4("float4_vector"), c.getFloat4(8)); - assertEquals(c.getFloat8("float8_vector"), c.getFloat8(9)); - - // floating point tests using Nullable Holders - NullableFloat4Holder float4Holder = new NullableFloat4Holder(); - NullableFloat8Holder float8Holder = new NullableFloat8Holder(); - // fill the holders using vector index and test - c.getFloat4(8, float4Holder); - c.getFloat8(9, float8Holder); - assertEquals(c.getFloat4("float4_vector"), float4Holder.value); - assertEquals(c.getFloat8("float8_vector"), float8Holder.value); - - // refill the holders using vector name and retest - c.getFloat4("float4_vector", float4Holder); - c.getFloat8("float8_vector", float8Holder); - assertEquals(c.getFloat4("float4_vector"), float4Holder.value); - assertEquals(c.getFloat8("float8_vector"), float8Holder.value); - - // test time values using vector name versus vector index - assertEquals(c.getTimeSec("timeSec_vector"), c.getTimeSec(10)); - assertEquals(c.getTimeMilli("timeMilli_vector"), c.getTimeMilli(11)); - assertEquals(c.getTimeMicro("timeMicro_vector"), c.getTimeMicro(12)); - assertEquals(c.getTimeNano("timeNano_vector"), c.getTimeNano(13)); - - // time tests using Nullable Holders - NullableTimeSecHolder timeSecHolder = new NullableTimeSecHolder(); - NullableTimeMilliHolder timeMilliHolder = new NullableTimeMilliHolder(); - NullableTimeMicroHolder timeMicroHolder = new NullableTimeMicroHolder(); - NullableTimeNanoHolder timeNanoHolder = new NullableTimeNanoHolder(); - // fill the holders using vector index and test - c.getTimeSec(10, timeSecHolder); - c.getTimeMilli(11, timeMilliHolder); - c.getTimeMicro(12, timeMicroHolder); - c.getTimeNano(13, timeNanoHolder); - assertEquals(c.getTimeSec("timeSec_vector"), timeSecHolder.value); - assertEquals(c.getTimeMilli("timeMilli_vector"), timeMilliHolder.value); - assertEquals(c.getTimeMicro("timeMicro_vector"), timeMicroHolder.value); - assertEquals(c.getTimeNano("timeNano_vector"), timeNanoHolder.value); - - LocalDateTime milliDT = c.getTimeMilliObj(11); - assertNotNull(milliDT); - assertEquals(milliDT, c.getTimeMilliObj("timeMilli_vector")); - - // refill the holders using vector name and retest - c.getTimeSec("timeSec_vector", timeSecHolder); - c.getTimeMilli("timeMilli_vector", timeMilliHolder); - c.getTimeMicro("timeMicro_vector", timeMicroHolder); - c.getTimeNano("timeNano_vector", timeNanoHolder); - assertEquals(c.getTimeSec("timeSec_vector"), timeSecHolder.value); - assertEquals(c.getTimeMilli("timeMilli_vector"), timeMilliHolder.value); - assertEquals(c.getTimeMicro("timeMicro_vector"), timeMicroHolder.value); - assertEquals(c.getTimeNano("timeNano_vector"), timeNanoHolder.value); - - assertEquals(c.getTimeStampSec("timeStampSec_vector"), c.getTimeStampSec(14)); - assertEquals(c.getTimeStampMilli("timeStampMilli_vector"), c.getTimeStampMilli(15)); - assertEquals(c.getTimeStampMicro("timeStampMicro_vector"), c.getTimeStampMicro(16)); - assertEquals(c.getTimeStampNano("timeStampNano_vector"), c.getTimeStampNano(17)); - - // time stamp tests using Nullable Holders - NullableTimeStampSecHolder timeStampSecHolder = new NullableTimeStampSecHolder(); - NullableTimeStampMilliHolder timeStampMilliHolder = new NullableTimeStampMilliHolder(); - NullableTimeStampMicroHolder timeStampMicroHolder = new NullableTimeStampMicroHolder(); - NullableTimeStampNanoHolder timeStampNanoHolder = new NullableTimeStampNanoHolder(); - // fill the holders using vector index and test - c.getTimeStampSec(14, timeStampSecHolder); - c.getTimeStampMilli(15, timeStampMilliHolder); - c.getTimeStampMicro(16, timeStampMicroHolder); - c.getTimeStampNano(17, timeStampNanoHolder); - assertEquals(c.getTimeStampSec("timeStampSec_vector"), timeStampSecHolder.value); - assertEquals(c.getTimeStampMilli("timeStampMilli_vector"), timeStampMilliHolder.value); - assertEquals(c.getTimeStampMicro("timeStampMicro_vector"), timeStampMicroHolder.value); - assertEquals(c.getTimeStampNano("timeStampNano_vector"), timeStampNanoHolder.value); - - LocalDateTime secDT = c.getTimeStampSecObj(14); - assertNotNull(secDT); - assertEquals(secDT, c.getTimeStampSecObj("timeStampSec_vector")); - - LocalDateTime milliDT1 = c.getTimeStampMilliObj(15); - assertNotNull(milliDT1); - assertEquals(milliDT1, c.getTimeStampMilliObj("timeStampMilli_vector")); - - LocalDateTime microDT = c.getTimeStampMicroObj(16); - assertNotNull(microDT); - assertEquals(microDT, c.getTimeStampMicroObj("timeStampMicro_vector")); - - LocalDateTime nanoDT = c.getTimeStampNanoObj(17); - assertNotNull(nanoDT); - assertEquals(nanoDT, c.getTimeStampNanoObj("timeStampNano_vector")); - - // refill the holders using vector name and retest - c.getTimeStampSec("timeStampSec_vector", timeStampSecHolder); - c.getTimeStampMilli("timeStampMilli_vector", timeStampMilliHolder); - c.getTimeStampMicro("timeStampMicro_vector", timeStampMicroHolder); - c.getTimeStampNano("timeStampNano_vector", timeStampNanoHolder); - assertEquals(c.getTimeStampSec("timeStampSec_vector"), timeStampSecHolder.value); - assertEquals(c.getTimeStampMilli("timeStampMilli_vector"), timeStampMilliHolder.value); - assertEquals(c.getTimeStampMicro("timeStampMicro_vector"), timeStampMicroHolder.value); - assertEquals(c.getTimeStampNano("timeStampNano_vector"), timeStampNanoHolder.value); - } - } - - @Test - void timestampsWithTimezones() { - List vectorList = timezoneTemporalVectors(allocator, 2); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - - assertEquals(c.getTimeStampSecTZ("timeStampSecTZ_vector"), c.getTimeStampSecTZ(0)); - assertEquals(c.getTimeStampMilliTZ("timeStampMilliTZ_vector"), c.getTimeStampMilliTZ(1)); - assertEquals(c.getTimeStampMicroTZ("timeStampMicroTZ_vector"), c.getTimeStampMicroTZ(2)); - assertEquals(c.getTimeStampNanoTZ("timeStampNanoTZ_vector"), c.getTimeStampNanoTZ(3)); - - // time stamp tests using Nullable Holders - NullableTimeStampSecTZHolder timeStampSecHolder = new NullableTimeStampSecTZHolder(); - NullableTimeStampMilliTZHolder timeStampMilliHolder = new NullableTimeStampMilliTZHolder(); - NullableTimeStampMicroTZHolder timeStampMicroHolder = new NullableTimeStampMicroTZHolder(); - NullableTimeStampNanoTZHolder timeStampNanoHolder = new NullableTimeStampNanoTZHolder(); - - // fill the holders using vector index and test - c.getTimeStampSecTZ(0, timeStampSecHolder); - c.getTimeStampMilliTZ(1, timeStampMilliHolder); - c.getTimeStampMicroTZ(2, timeStampMicroHolder); - c.getTimeStampNanoTZ(3, timeStampNanoHolder); - - long tsSec = timeStampSecHolder.value; - long tsMil = timeStampMilliHolder.value; - long tsMic = timeStampMicroHolder.value; - long tsNan = timeStampNanoHolder.value; - - assertEquals(c.getTimeStampSecTZ("timeStampSecTZ_vector"), timeStampSecHolder.value); - assertEquals(c.getTimeStampMilliTZ("timeStampMilliTZ_vector"), timeStampMilliHolder.value); - assertEquals(c.getTimeStampMicroTZ("timeStampMicroTZ_vector"), timeStampMicroHolder.value); - assertEquals(c.getTimeStampNanoTZ("timeStampNanoTZ_vector"), timeStampNanoHolder.value); - - // fill the holders using vector index and test - c.getTimeStampSecTZ("timeStampSecTZ_vector", timeStampSecHolder); - c.getTimeStampMilliTZ("timeStampMilliTZ_vector", timeStampMilliHolder); - c.getTimeStampMicroTZ("timeStampMicroTZ_vector", timeStampMicroHolder); - c.getTimeStampNanoTZ("timeStampNanoTZ_vector", timeStampNanoHolder); - - assertEquals(tsSec, timeStampSecHolder.value); - assertEquals(tsMil, timeStampMilliHolder.value); - assertEquals(tsMic, timeStampMicroHolder.value); - assertEquals(tsNan, timeStampNanoHolder.value); - } - } - - @Test - void getVarChar() { - List vectorList = intPlusVarcharColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertEquals(c.getVarCharObj(1), "two"); - assertEquals(c.getVarCharObj(1), c.getVarCharObj(VARCHAR_VECTOR_NAME_1)); - assertArrayEquals("two".getBytes(), c.getVarChar(VARCHAR_VECTOR_NAME_1)); - assertArrayEquals("two".getBytes(), c.getVarChar(1)); - } - } - - @Test - void getVarBinary() { - List vectorList = intPlusVarBinaryColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertArrayEquals(c.getVarBinary(1), "two".getBytes()); - assertArrayEquals(c.getVarBinary(1), c.getVarBinary(VARBINARY_VECTOR_NAME_1)); - } - } - - @Test - void getLargeVarBinary() { - List vectorList = intPlusLargeVarBinaryColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertArrayEquals(c.getLargeVarBinary(1), "two".getBytes()); - assertArrayEquals(c.getLargeVarBinary(1), c.getLargeVarBinary(VARBINARY_VECTOR_NAME_1)); - } - } - - @Test - void getLargeVarChar() { - List vectorList = intPlusLargeVarcharColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertEquals(c.getLargeVarCharObj(1), "two"); - assertEquals(c.getLargeVarCharObj(1), c.getLargeVarCharObj(VARCHAR_VECTOR_NAME_1)); - assertArrayEquals("two".getBytes(), c.getLargeVarChar(VARCHAR_VECTOR_NAME_1)); - assertArrayEquals("two".getBytes(), c.getLargeVarChar(1)); - } - } - - @Test - void getFixedBinary() { - List vectorList = intPlusFixedBinaryColumns(allocator); - try (Table t = new Table(vectorList)) { - Row c = t.immutableRow(); - c.setPosition(1); - assertArrayEquals(c.getFixedSizeBinary(1), "two".getBytes()); - assertArrayEquals(c.getFixedSizeBinary(1), c.getFixedSizeBinary(FIXEDBINARY_VECTOR_NAME_1)); - } - } - - @Test - void testSimpleListVector1() { - try (ListVector listVector = simpleListVector(allocator); - VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.of(listVector); - Table table = new Table(vectorSchemaRoot)) { - for (Row c : table) { - @SuppressWarnings("unchecked") - List list = (List) c.getList(INT_LIST_VECTOR_NAME); - assertEquals(10, list.size()); - } - } - } - - @Test - void testSimpleListVector2() { - try (ListVector listVector = simpleListVector(allocator); - VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.of(listVector); - Table table = new Table(vectorSchemaRoot)) { - for (Row c : table) { - @SuppressWarnings("unchecked") - List list = (List) c.getList(0); - assertEquals(10, list.size()); - } - } - } - - @Test - void testSimpleStructVector1() { - try (StructVector structVector = simpleStructVector(allocator); - VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.of(structVector); - Table table = new Table(vectorSchemaRoot)) { - for (Row c : table) { - @SuppressWarnings("unchecked") - JsonStringHashMap struct = - (JsonStringHashMap) c.getStruct(STRUCT_VECTOR_NAME); - @SuppressWarnings("unchecked") - JsonStringHashMap struct1 = - (JsonStringHashMap) c.getStruct(0); - int a = (int) struct.get("struct_int_child"); - double b = (double) struct.get("struct_flt_child"); - int a1 = (int) struct1.get("struct_int_child"); - double b1 = (double) struct1.get("struct_flt_child"); - assertNotNull(struct); - assertEquals(a, a1); - assertEquals(b, b1); - assertTrue(a >= 0); - assertTrue(b <= a, String.format("a = %s and b = %s", a, b)); - } - } - } - - @Test - void testSimpleUnionVector() { - try (UnionVector unionVector = simpleUnionVector(allocator); - VectorSchemaRoot vsr = VectorSchemaRoot.of(unionVector); - Table table = new Table(vsr)) { - Row c = table.immutableRow(); - c.setPosition(0); - Object object0 = c.getUnion(UNION_VECTOR_NAME); - Object object1 = c.getUnion(0); - assertEquals(object0, object1); - c.setPosition(1); - assertNull(c.getUnion(UNION_VECTOR_NAME)); - c.setPosition(2); - Object object2 = c.getUnion(UNION_VECTOR_NAME); - assertEquals(100, object0); - assertEquals(100, object2); - } - } - - @Test - void testSimpleDenseUnionVector() { - try (DenseUnionVector unionVector = simpleDenseUnionVector(allocator); - VectorSchemaRoot vsr = VectorSchemaRoot.of(unionVector); - Table table = new Table(vsr)) { - Row c = table.immutableRow(); - c.setPosition(0); - Object object0 = c.getDenseUnion(UNION_VECTOR_NAME); - Object object1 = c.getDenseUnion(0); - assertEquals(object0, object1); - c.setPosition(1); - assertNull(c.getDenseUnion(UNION_VECTOR_NAME)); - c.setPosition(2); - Object object2 = c.getDenseUnion(UNION_VECTOR_NAME); - assertEquals(100, object0); - assertEquals(100, object2); - } - } - - @Test - void testExtensionTypeVector() { - TestExtensionType.LocationVector vector = new TestExtensionType.LocationVector("location", allocator); - vector.allocateNew(); - vector.set(0, 34.073814f, -118.240784f); - vector.setValueCount(1); - - try (VectorSchemaRoot vsr = VectorSchemaRoot.of(vector); - Table table = new Table(vsr)) { - Row c = table.immutableRow(); - c.setPosition(0); - Object object0 = c.getExtensionType("location"); - Object object1 = c.getExtensionType(0); - assertEquals(object0, object1); - @SuppressWarnings("unchecked") - JsonStringHashMap struct0 = - (JsonStringHashMap) object0; - assertEquals(34.073814f, struct0.get("Latitude")); - } - } - - @Test - void testSimpleMapVector1() { - try (MapVector mapVector = simpleMapVector(allocator); - Table table = Table.of(mapVector)) { - - int i = 1; - for (Row c : table) { - @SuppressWarnings("unchecked") - List> list = - (List>) c.getMap(BIGINT_INT_MAP_VECTOR_NAME); - @SuppressWarnings("unchecked") - List> list1 = - (List>) c.getMap(0); - for (int j = 0; j < list1.size(); j++) { - assertEquals(list.get(j), list1.get(j)); - } - if (list != null && !list.isEmpty()) { - assertEquals(i, list.size()); - for (JsonStringHashMap sv : list) { - assertEquals(2, sv.size()); - Long o1 = (Long) sv.get("key"); - Integer o2 = (Integer) sv.get("value"); - assertEquals(o1, o2.longValue()); - } - } - i++; - } - } - } - - @Test - void resetPosition() { - try (ListVector listVector = simpleListVector(allocator); - VectorSchemaRoot vectorSchemaRoot = VectorSchemaRoot.of(listVector); - Table table = new Table(vectorSchemaRoot)) { - Row row = table.immutableRow(); - row.next(); - assertEquals(0, row.rowNumber); - row.resetPosition(); - assertEquals(-1, row.rowNumber); - } - } -}