Skip to content

Commit

Permalink
SNOW-1812949 add get object and get bytes support for native arrow st…
Browse files Browse the repository at this point in the history
…ructured types (#1968)

Co-authored-by: Przemyslaw Motacki <[email protected]>
  • Loading branch information
sfc-gh-mkubik and sfc-gh-pmotacki authored Dec 1, 2024
1 parent c62c5e4 commit fad4f12
Show file tree
Hide file tree
Showing 16 changed files with 224 additions and 70 deletions.
54 changes: 33 additions & 21 deletions src/main/java/net/snowflake/client/core/SFArrowResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import java.util.stream.Stream;
import net.snowflake.client.core.arrow.ArrayConverter;
import net.snowflake.client.core.arrow.ArrowVectorConverter;
import net.snowflake.client.core.arrow.StructConverter;
import net.snowflake.client.core.arrow.StructObjectWrapper;
import net.snowflake.client.core.arrow.VarCharConverter;
import net.snowflake.client.core.arrow.VectorTypeConverter;
import net.snowflake.client.core.json.Converters;
Expand Down Expand Up @@ -576,16 +576,19 @@ public Object getObject(int columnIndex) throws SFException {
converter.setSessionTimeZone(sessionTimeZone);
Object obj = converter.toObject(index);
boolean isStructuredType = resultSetMetaData.isStructuredTypeColumn(columnIndex);
if (type == Types.STRUCT && isStructuredType) {
if (converter instanceof VarCharConverter) {
return createJsonSqlInput(columnIndex, obj);
} else if (converter instanceof StructConverter) {
return createArrowSqlInput(columnIndex, (Map<String, Object>) obj);
if (isVarcharConvertedStruct(type, isStructuredType, converter)) {
if (obj != null) {
return new StructObjectWrapper((String) obj, createJsonSqlInput(columnIndex, obj));
}
}
return obj;
}

private boolean isVarcharConvertedStruct(
int type, boolean isStructuredType, ArrowVectorConverter converter) {
return type == Types.STRUCT && isStructuredType && converter instanceof VarCharConverter;
}

private Object createJsonSqlInput(int columnIndex, Object obj) throws SFException {
try {
if (obj == null) {
Expand All @@ -605,15 +608,6 @@ private Object createJsonSqlInput(int columnIndex, Object obj) throws SFExceptio
}
}

private Object createArrowSqlInput(int columnIndex, Map<String, Object> input)
throws SFException {
if (input == null) {
return null;
}
return new ArrowSqlInput(
input, session, converters, resultSetMetaData.getColumnFields(columnIndex));
}

@Override
public Array getArray(int columnIndex) throws SFException {
ArrowVectorConverter converter = currentChunkIterator.getCurrentConverter(columnIndex - 1);
Expand All @@ -625,16 +619,19 @@ public Array getArray(int columnIndex) throws SFException {
}
if (converter instanceof VarCharConverter) {
return getJsonArray((String) obj, columnIndex);
} else if (converter instanceof ArrayConverter) {
return getArrowArray((List<Object>) obj, columnIndex);
} else if (converter instanceof VectorTypeConverter) {
return getArrowArray((List<Object>) obj, columnIndex);
} else if (converter instanceof ArrayConverter || converter instanceof VectorTypeConverter) {
StructObjectWrapper structObjectWrapper = (StructObjectWrapper) obj;
return getArrowArray(
structObjectWrapper.getJsonString(),
(List<Object>) structObjectWrapper.getObject(),
columnIndex);
} else {
throw new SFException(queryId, ErrorCode.INVALID_STRUCT_DATA);
}
}

private SfSqlArray getArrowArray(List<Object> elements, int columnIndex) throws SFException {
private SfSqlArray getArrowArray(String text, List<Object> elements, int columnIndex)
throws SFException {
try {
List<FieldMetadata> fieldMetadataList = resultSetMetaData.getColumnFields(columnIndex);
if (fieldMetadataList.size() != 1) {
Expand All @@ -651,62 +648,74 @@ private SfSqlArray getArrowArray(List<Object> elements, int columnIndex) throws
switch (columnType) {
case Types.INTEGER:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.integerConverter(columnType))
.toArray(Integer[]::new));
case Types.SMALLINT:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.smallIntConverter(columnType))
.toArray(Short[]::new));
case Types.TINYINT:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.tinyIntConverter(columnType))
.toArray(Byte[]::new));
case Types.BIGINT:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.bigIntConverter(columnType)).toArray(Long[]::new));
case Types.DECIMAL:
case Types.NUMERIC:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.bigDecimalConverter(columnType))
.toArray(BigDecimal[]::new));
case Types.CHAR:
case Types.VARCHAR:
case Types.LONGNVARCHAR:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.varcharConverter(columnType, columnSubType, scale))
.toArray(String[]::new));
case Types.BINARY:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.bytesConverter(columnType, scale))
.toArray(Byte[][]::new));
case Types.FLOAT:
case Types.REAL:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.floatConverter(columnType)).toArray(Float[]::new));
case Types.DOUBLE:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.doubleConverter(columnType))
.toArray(Double[]::new));
case Types.DATE:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.dateFromIntConverter(sessionTimeZone))
.toArray(Date[]::new));
case Types.TIME:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.timeFromIntConverter(scale)).toArray(Time[]::new));
case Types.TIMESTAMP:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(
elements,
Expand All @@ -715,13 +724,16 @@ private SfSqlArray getArrowArray(List<Object> elements, int columnIndex) throws
.toArray(Timestamp[]::new));
case Types.BOOLEAN:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, converters.booleanConverter(columnType))
.toArray(Boolean[]::new));
case Types.STRUCT:
return new SfSqlArray(columnSubType, mapAndConvert(elements, e -> e).toArray(Map[]::new));
return new SfSqlArray(
text, columnSubType, mapAndConvert(elements, e -> e).toArray(Map[]::new));
case Types.ARRAY:
return new SfSqlArray(
text,
columnSubType,
mapAndConvert(elements, e -> ((List) e).stream().toArray(Map[]::new))
.toArray(Map[][]::new));
Expand Down
19 changes: 17 additions & 2 deletions src/main/java/net/snowflake/client/core/SFBaseResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ protected SQLInput createJsonSqlInputForColumn(
}

@SnowflakeJdbcInternalApi
protected SfSqlArray getJsonArray(String obj, int columnIndex) throws SFException {
protected SfSqlArray getJsonArray(String arrayString, int columnIndex) throws SFException {
try {
List<FieldMetadata> fieldMetadataList = resultSetMetaData.getColumnFields(columnIndex);
if (fieldMetadataList.size() != 1) {
Expand All @@ -288,73 +288,85 @@ protected SfSqlArray getJsonArray(String obj, int columnIndex) throws SFExceptio
int columnType = ColumnTypeHelper.getColumnType(columnSubType, session);
int scale = fieldMetadata.getScale();

ArrayNode arrayNode = (ArrayNode) OBJECT_MAPPER.readTree(obj);
ArrayNode arrayNode = (ArrayNode) OBJECT_MAPPER.readTree(arrayString);
Iterator<JsonNode> nodeElements = arrayNode.elements();

switch (columnType) {
case Types.INTEGER:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().integerConverter(columnType))
.toArray(Integer[]::new));
case Types.SMALLINT:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().smallIntConverter(columnType))
.toArray(Short[]::new));
case Types.TINYINT:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().tinyIntConverter(columnType))
.toArray(Byte[]::new));
case Types.BIGINT:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().bigIntConverter(columnType))
.toArray(Long[]::new));
case Types.DECIMAL:
case Types.NUMERIC:
return new SfSqlArray(
arrayString,
columnSubType,
convertToFixedArray(
getStream(nodeElements, getConverters().bigDecimalConverter(columnType))));
case Types.CHAR:
case Types.VARCHAR:
case Types.LONGNVARCHAR:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(
nodeElements,
getConverters().varcharConverter(columnType, columnSubType, scale))
.toArray(String[]::new));
case Types.BINARY:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().bytesConverter(columnType, scale))
.toArray(Byte[][]::new));
case Types.FLOAT:
case Types.REAL:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().floatConverter(columnType))
.toArray(Float[]::new));
case Types.DOUBLE:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().doubleConverter(columnType))
.toArray(Double[]::new));
case Types.DATE:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().dateStringConverter(session))
.toArray(Date[]::new));
case Types.TIME:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().timeFromStringConverter(session))
.toArray(Time[]::new));
case Types.TIMESTAMP:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(
nodeElements,
Expand All @@ -364,16 +376,19 @@ protected SfSqlArray getJsonArray(String obj, int columnIndex) throws SFExceptio
.toArray(Timestamp[]::new));
case Types.BOOLEAN:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().booleanConverter(columnType))
.toArray(Boolean[]::new));
case Types.STRUCT:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().structConverter(OBJECT_MAPPER))
.toArray(Map[]::new));
case Types.ARRAY:
return new SfSqlArray(
arrayString,
columnSubType,
getStream(nodeElements, getConverters().arrayConverter(OBJECT_MAPPER))
.toArray(Map[][]::new));
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/net/snowflake/client/core/SFJsonResultSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.sql.Types;
import java.util.List;
import java.util.TimeZone;
import net.snowflake.client.core.arrow.StructObjectWrapper;
import net.snowflake.client.core.json.Converters;
import net.snowflake.client.jdbc.ErrorCode;
import net.snowflake.client.jdbc.FieldMetadata;
Expand Down Expand Up @@ -87,7 +88,7 @@ public Object getObject(int columnIndex) throws SFException {

case Types.STRUCT:
if (resultSetMetaData.isStructuredTypeColumn(columnIndex)) {
return getSqlInput((String) obj, columnIndex);
return new StructObjectWrapper((String) obj, getSqlInput((String) obj, columnIndex));
} else {
throw new SFException(ErrorCode.FEATURE_UNSUPPORTED, "data type: " + type);
}
Expand Down
25 changes: 24 additions & 1 deletion src/main/java/net/snowflake/client/core/SfSqlArray.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package net.snowflake.client.core;

import static net.snowflake.client.core.FieldSchemaCreator.buildBindingSchemaForType;
import static net.snowflake.client.core.FieldSchemaCreator.logger;

import com.fasterxml.jackson.core.JsonProcessingException;
import java.sql.Array;
Expand All @@ -16,14 +17,21 @@
@SnowflakeJdbcInternalApi
public class SfSqlArray implements Array {

private String text;
private int baseType;
private Object elements;
private String jsonStringFromElements;

public SfSqlArray(int baseType, Object elements) {
public SfSqlArray(String text, int baseType, Object elements) {
this.text = text;
this.baseType = baseType;
this.elements = elements;
}

public SfSqlArray(int baseType, Object elements) {
this(null, baseType, elements);
}

@Override
public String getBaseTypeName() throws SQLException {
return JDBCType.valueOf(baseType).getName();
Expand Down Expand Up @@ -81,7 +89,22 @@ public ResultSet getResultSet(long index, int count, Map<String, Class<?>> map)
@Override
public void free() throws SQLException {}

public String getText() {
if (text == null) {
logger.warn("Text field wasn't initialized. Should never happen.");
}
return text;
}

public String getJsonString() throws SQLException {
if (jsonStringFromElements == null) {
jsonStringFromElements = buildJsonStringFromElements(elements);
}

return jsonStringFromElements;
}

private static String buildJsonStringFromElements(Object elements) throws SQLException {
try {
return SnowflakeUtil.mapJson(elements);
} catch (JsonProcessingException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,12 @@ public ArrayConverter(ListVector valueVector, int vectorIndex, DataConversionCon

@Override
public Object toObject(int index) throws SFException {
return vector.getObject(index);
return isNull(index) ? null : new StructObjectWrapper(toString(index), vector.getObject(index));
}

@Override
public byte[] toBytes(int index) throws SFException {
return isNull(index) ? null : toString(index).getBytes();
}

@Override
Expand Down
Loading

0 comments on commit fad4f12

Please sign in to comment.