diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/binder/ColumnBinderArrowTypeVisitor.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/binder/ColumnBinderArrowTypeVisitor.java index f790b6a541153..dc708724043d0 100644 --- a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/binder/ColumnBinderArrowTypeVisitor.java +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/binder/ColumnBinderArrowTypeVisitor.java @@ -45,6 +45,7 @@ import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.types.pojo.ArrowType; /** @@ -99,7 +100,7 @@ public ColumnBinder visit(ArrowType.Union type) { @Override public ColumnBinder visit(ArrowType.Map type) { - throw new UnsupportedOperationException("No column binder implemented for type " + type); + return new MapBinder((MapVector) vector); } @Override diff --git a/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/binder/MapBinder.java b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/binder/MapBinder.java new file mode 100644 index 0000000000000..07391eb7cbfb4 --- /dev/null +++ b/java/adapter/jdbc/src/main/java/org/apache/arrow/adapter/jdbc/binder/MapBinder.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.adapter.jdbc.binder; + +import java.sql.PreparedStatement; +import java.sql.SQLException; +import java.sql.Types; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Objects; + +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.impl.UnionMapReader; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.util.JsonStringHashMap; + +/** + * A column binder for map of primitive values. + */ +public class MapBinder extends BaseColumnBinder { + + private UnionMapReader reader; + private final boolean isTextKey; + private final boolean isTextValue; + + public MapBinder(MapVector vector) { + this(vector, Types.VARCHAR); + } + + /** + * Init MapBinder and determine type of data vector. + * + * @param vector corresponding data vector from arrow buffer for binding + * @param jdbcType parameter jdbc type + */ + public MapBinder(MapVector vector, int jdbcType) { + super(vector, jdbcType); + reader = vector.getReader(); + List structField = Objects.requireNonNull(vector.getField()).getChildren(); + if (structField.size() != 1) { + throw new IllegalArgumentException("Expected Struct field metadata inside Map field"); + } + List keyValueFields = Objects.requireNonNull(structField.get(0)).getChildren(); + if (keyValueFields.size() != 2) { + throw new IllegalArgumentException("Expected two children fields " + + "inside nested Struct field in Map"); + } + ArrowType keyType = Objects.requireNonNull(keyValueFields.get(0)).getType(); + ArrowType valueType = Objects.requireNonNull(keyValueFields.get(1)).getType(); + isTextKey = ArrowType.Utf8.INSTANCE.equals(keyType); + isTextValue = ArrowType.Utf8.INSTANCE.equals(valueType); + } + + @Override + public void bind(PreparedStatement statement, + int parameterIndex, int rowIndex) throws SQLException { + reader.setPosition(rowIndex); + LinkedHashMap tags = new JsonStringHashMap<>(); + while (reader.next()) { + Object key = reader.key().readObject(); + Object value = reader.value().readObject(); + tags.put(isTextKey && key != null ? key.toString() : key, + isTextValue && value != null ? value.toString() : value); + } + switch (jdbcType) { + case Types.VARCHAR: + statement.setString(parameterIndex, tags.toString()); + break; + case Types.OTHER: + default: + statement.setObject(parameterIndex, tags); + } + } +} diff --git a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinderTest.java b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinderTest.java index fb4e6e5eb8896..15b9ab0386159 100644 --- a/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinderTest.java +++ b/java/adapter/jdbc/src/test/java/org/apache/arrow/adapter/jdbc/JdbcParameterBinderTest.java @@ -30,6 +30,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import java.util.function.BiConsumer; import org.apache.arrow.adapter.jdbc.binder.ColumnBinder; @@ -69,6 +70,7 @@ import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.types.DateUnit; import org.apache.arrow.vector.types.FloatingPointPrecision; import org.apache.arrow.vector.types.TimeUnit; @@ -76,6 +78,7 @@ import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.JsonStringHashMap; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -473,6 +476,81 @@ void listOfString() throws SQLException { testListType((ArrowType) new ArrowType.Utf8(), setValue, ListVector::setNull, values); } + @Test + void mapOfString() throws SQLException { + TriConsumer> setValue = (mapVector, index, values) -> { + org.apache.arrow.vector.complex.impl.UnionMapWriter mapWriter = mapVector.getWriter(); + mapWriter.setPosition(index); + mapWriter.startMap(); + values.entrySet().forEach(mapValue -> { + if (mapValue != null) { + byte[] keyBytes = mapValue.getKey().getBytes(StandardCharsets.UTF_8); + byte[] valueBytes = mapValue.getValue().getBytes(StandardCharsets.UTF_8); + try ( + ArrowBuf keyBuf = allocator.buffer(keyBytes.length); + ArrowBuf valueBuf = allocator.buffer(valueBytes.length); + ) { + mapWriter.startEntry(); + keyBuf.writeBytes(keyBytes); + valueBuf.writeBytes(valueBytes); + mapWriter.key().varChar().writeVarChar(0, keyBytes.length, keyBuf); + mapWriter.value().varChar().writeVarChar(0, valueBytes.length, valueBuf); + mapWriter.endEntry(); + } + } else { + mapWriter.writeNull(); + } + }); + mapWriter.endMap(); + }; + + JsonStringHashMap value1 = new JsonStringHashMap(); + value1.put("a", "b"); + value1.put("c", "d"); + JsonStringHashMap value2 = new JsonStringHashMap(); + value2.put("d", "e"); + value2.put("f", "g"); + value2.put("k", "l"); + JsonStringHashMap value3 = new JsonStringHashMap(); + value3.put("y", "z"); + value3.put("arrow", "cool"); + List> values = Arrays.asList(value1, value2, value3, Collections.emptyMap()); + testMapType(new ArrowType.Map(true), setValue, MapVector::setNull, values, new ArrowType.Utf8()); + } + + @Test + void mapOfInteger() throws SQLException { + TriConsumer> setValue = (mapVector, index, values) -> { + org.apache.arrow.vector.complex.impl.UnionMapWriter mapWriter = mapVector.getWriter(); + mapWriter.setPosition(index); + mapWriter.startMap(); + values.entrySet().forEach(mapValue -> { + if (mapValue != null) { + mapWriter.startEntry(); + mapWriter.key().integer().writeInt(mapValue.getKey()); + mapWriter.value().integer().writeInt(mapValue.getValue()); + mapWriter.endEntry(); + } else { + mapWriter.writeNull(); + } + }); + mapWriter.endMap(); + }; + + JsonStringHashMap value1 = new JsonStringHashMap(); + value1.put(1, 2); + value1.put(3, 4); + JsonStringHashMap value2 = new JsonStringHashMap(); + value2.put(5, 6); + value2.put(7, 8); + value2.put(9, 1024); + JsonStringHashMap value3 = new JsonStringHashMap(); + value3.put(Integer.MIN_VALUE, Integer.MAX_VALUE); + value3.put(0, 4096); + List> values = Arrays.asList(value1, value2, value3, Collections.emptyMap()); + testMapType(new ArrowType.Map(true), setValue, MapVector::setNull, values, new ArrowType.Int(32, true)); + } + @FunctionalInterface interface TriConsumer { void accept(T value1, U value2, V value3); @@ -672,4 +750,110 @@ void testListType(ArrowType arrowType, TriConsumer void testMapType(ArrowType arrowType, TriConsumer setValue, + BiConsumer setNull, List values, + ArrowType elementType) throws SQLException { + int jdbcType = Types.VARCHAR; + FieldType keyType = new FieldType(false, elementType, null, null); + FieldType mapType = new FieldType(false, ArrowType.Struct.INSTANCE, null, null); + Schema schema = new Schema(Collections.singletonList(new Field("field", FieldType.nullable(arrowType), + Collections.singletonList(new Field(MapVector.KEY_NAME, mapType, + Arrays.asList(new Field(MapVector.KEY_NAME, keyType, null), + new Field(MapVector.VALUE_NAME, keyType, null))))))); + try (final MockPreparedStatement statement = new MockPreparedStatement(); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + final JdbcParameterBinder binder = + JdbcParameterBinder.builder(statement, root).bindAll().build(); + assertThat(binder.next()).isFalse(); + + @SuppressWarnings("unchecked") + final V vector = (V) root.getVector(0); + final ColumnBinder columnBinder = ColumnBinder.forVector(vector); + assertThat(columnBinder.getJdbcType()).isEqualTo(jdbcType); + + setValue.accept(vector, 0, values.get(0)); + setValue.accept(vector, 1, values.get(1)); + setNull.accept(vector, 2); + root.setRowCount(3); + + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(0).toString()); + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(1).toString()); + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isNull(); + assertThat(statement.getParamType(1)).isEqualTo(jdbcType); + assertThat(binder.next()).isFalse(); + + binder.reset(); + + setNull.accept(vector, 0); + setValue.accept(vector, 1, values.get(3)); + setValue.accept(vector, 2, values.get(0)); + setValue.accept(vector, 3, values.get(2)); + setValue.accept(vector, 4, values.get(1)); + root.setRowCount(5); + + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isNull(); + assertThat(statement.getParamType(1)).isEqualTo(jdbcType); + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(3).toString()); + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(0).toString()); + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(2).toString()); + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(1).toString()); + assertThat(binder.next()).isFalse(); + } + + // Non-nullable (since some types have a specialized binder) + schema = new Schema(Collections.singletonList(new Field("field", FieldType.notNullable(arrowType), + Collections.singletonList(new Field(MapVector.KEY_NAME, mapType, + Arrays.asList(new Field(MapVector.KEY_NAME, keyType, null), + new Field(MapVector.VALUE_NAME, keyType, null))))))); + try (final MockPreparedStatement statement = new MockPreparedStatement(); + final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { + @SuppressWarnings("unchecked") + final V vector = (V) root.getVector(0); + + final JdbcParameterBinder binder = + JdbcParameterBinder.builder(statement, root).bind(1, + new org.apache.arrow.adapter.jdbc.binder.MapBinder((MapVector) vector, Types.OTHER)).build(); + assertThat(binder.next()).isFalse(); + + setValue.accept(vector, 0, values.get(0)); + setValue.accept(vector, 1, values.get(1)); + root.setRowCount(2); + + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(0)); + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(1)); + assertThat(binder.next()).isFalse(); + + binder.reset(); + + setValue.accept(vector, 0, values.get(0)); + setValue.accept(vector, 1, values.get(2)); + setValue.accept(vector, 2, values.get(0)); + setValue.accept(vector, 3, values.get(2)); + setValue.accept(vector, 4, values.get(1)); + root.setRowCount(5); + + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(0)); + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(2)); + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(0)); + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(2)); + assertThat(binder.next()).isTrue(); + assertThat(statement.getParamValue(1)).isEqualTo(values.get(1)); + assertThat(binder.next()).isFalse(); + } + } }