Skip to content

Commit

Permalink
ARROW-17431: [Java] MapBinder to bind Arrow Map type to DB column (ap…
Browse files Browse the repository at this point in the history
…ache#13941)

Typical real life Arrow datasets contain map of primitive type. This PR introduce MapBinder mapping of primitive types map entries

Authored-by: igor.suhorukov <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
igor-suhorukov authored and zagto committed Oct 7, 2022
1 parent 31e3e39 commit bac735b
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.types.pojo.ArrowType;

/**
Expand Down Expand Up @@ -99,7 +100,7 @@ public ColumnBinder visit(ArrowType.Union type) {

@Override
public ColumnBinder visit(ArrowType.Map type) {
throw new UnsupportedOperationException("No column binder implemented for type " + type);
return new MapBinder((MapVector) vector);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.adapter.jdbc.binder;

import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Types;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Objects;

import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.complex.impl.UnionMapReader;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.util.JsonStringHashMap;

/**
* A column binder for map of primitive values.
*/
public class MapBinder extends BaseColumnBinder<MapVector> {

private UnionMapReader reader;
private final boolean isTextKey;
private final boolean isTextValue;

public MapBinder(MapVector vector) {
this(vector, Types.VARCHAR);
}

/**
* Init MapBinder and determine type of data vector.
*
* @param vector corresponding data vector from arrow buffer for binding
* @param jdbcType parameter jdbc type
*/
public MapBinder(MapVector vector, int jdbcType) {
super(vector, jdbcType);
reader = vector.getReader();
List<Field> structField = Objects.requireNonNull(vector.getField()).getChildren();
if (structField.size() != 1) {
throw new IllegalArgumentException("Expected Struct field metadata inside Map field");
}
List<Field> keyValueFields = Objects.requireNonNull(structField.get(0)).getChildren();
if (keyValueFields.size() != 2) {
throw new IllegalArgumentException("Expected two children fields " +
"inside nested Struct field in Map");
}
ArrowType keyType = Objects.requireNonNull(keyValueFields.get(0)).getType();
ArrowType valueType = Objects.requireNonNull(keyValueFields.get(1)).getType();
isTextKey = ArrowType.Utf8.INSTANCE.equals(keyType);
isTextValue = ArrowType.Utf8.INSTANCE.equals(valueType);
}

@Override
public void bind(PreparedStatement statement,
int parameterIndex, int rowIndex) throws SQLException {
reader.setPosition(rowIndex);
LinkedHashMap<Object, Object> tags = new JsonStringHashMap<>();
while (reader.next()) {
Object key = reader.key().readObject();
Object value = reader.value().readObject();
tags.put(isTextKey && key != null ? key.toString() : key,
isTextValue && value != null ? value.toString() : value);
}
switch (jdbcType) {
case Types.VARCHAR:
statement.setString(parameterIndex, tags.toString());
break;
case Types.OTHER:
default:
statement.setObject(parameterIndex, tags);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;

import org.apache.arrow.adapter.jdbc.binder.ColumnBinder;
Expand Down Expand Up @@ -69,13 +70,15 @@
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.complex.ListVector;
import org.apache.arrow.vector.complex.MapVector;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.JsonStringHashMap;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -473,6 +476,81 @@ void listOfString() throws SQLException {
testListType((ArrowType) new ArrowType.Utf8(), setValue, ListVector::setNull, values);
}

@Test
void mapOfString() throws SQLException {
TriConsumer<MapVector, Integer, Map<String, String>> setValue = (mapVector, index, values) -> {
org.apache.arrow.vector.complex.impl.UnionMapWriter mapWriter = mapVector.getWriter();
mapWriter.setPosition(index);
mapWriter.startMap();
values.entrySet().forEach(mapValue -> {
if (mapValue != null) {
byte[] keyBytes = mapValue.getKey().getBytes(StandardCharsets.UTF_8);
byte[] valueBytes = mapValue.getValue().getBytes(StandardCharsets.UTF_8);
try (
ArrowBuf keyBuf = allocator.buffer(keyBytes.length);
ArrowBuf valueBuf = allocator.buffer(valueBytes.length);
) {
mapWriter.startEntry();
keyBuf.writeBytes(keyBytes);
valueBuf.writeBytes(valueBytes);
mapWriter.key().varChar().writeVarChar(0, keyBytes.length, keyBuf);
mapWriter.value().varChar().writeVarChar(0, valueBytes.length, valueBuf);
mapWriter.endEntry();
}
} else {
mapWriter.writeNull();
}
});
mapWriter.endMap();
};

JsonStringHashMap<String, String> value1 = new JsonStringHashMap<String, String>();
value1.put("a", "b");
value1.put("c", "d");
JsonStringHashMap<String, String> value2 = new JsonStringHashMap<String, String>();
value2.put("d", "e");
value2.put("f", "g");
value2.put("k", "l");
JsonStringHashMap<String, String> value3 = new JsonStringHashMap<String, String>();
value3.put("y", "z");
value3.put("arrow", "cool");
List<Map<String, String>> values = Arrays.asList(value1, value2, value3, Collections.emptyMap());
testMapType(new ArrowType.Map(true), setValue, MapVector::setNull, values, new ArrowType.Utf8());
}

@Test
void mapOfInteger() throws SQLException {
TriConsumer<MapVector, Integer, Map<Integer, Integer>> setValue = (mapVector, index, values) -> {
org.apache.arrow.vector.complex.impl.UnionMapWriter mapWriter = mapVector.getWriter();
mapWriter.setPosition(index);
mapWriter.startMap();
values.entrySet().forEach(mapValue -> {
if (mapValue != null) {
mapWriter.startEntry();
mapWriter.key().integer().writeInt(mapValue.getKey());
mapWriter.value().integer().writeInt(mapValue.getValue());
mapWriter.endEntry();
} else {
mapWriter.writeNull();
}
});
mapWriter.endMap();
};

JsonStringHashMap<Integer, Integer> value1 = new JsonStringHashMap<Integer, Integer>();
value1.put(1, 2);
value1.put(3, 4);
JsonStringHashMap<Integer, Integer> value2 = new JsonStringHashMap<Integer, Integer>();
value2.put(5, 6);
value2.put(7, 8);
value2.put(9, 1024);
JsonStringHashMap<Integer, Integer> value3 = new JsonStringHashMap<Integer, Integer>();
value3.put(Integer.MIN_VALUE, Integer.MAX_VALUE);
value3.put(0, 4096);
List<Map<Integer, Integer>> values = Arrays.asList(value1, value2, value3, Collections.emptyMap());
testMapType(new ArrowType.Map(true), setValue, MapVector::setNull, values, new ArrowType.Int(32, true));
}

@FunctionalInterface
interface TriConsumer<T, U, V> {
void accept(T value1, U value2, V value3);
Expand Down Expand Up @@ -672,4 +750,110 @@ <T, V extends FieldVector> void testListType(ArrowType arrowType, TriConsumer<V,
assertThat(binder.next()).isFalse();
}
}

<T, V extends FieldVector> void testMapType(ArrowType arrowType, TriConsumer<V, Integer, T> setValue,
BiConsumer<V, Integer> setNull, List<T> values,
ArrowType elementType) throws SQLException {
int jdbcType = Types.VARCHAR;
FieldType keyType = new FieldType(false, elementType, null, null);
FieldType mapType = new FieldType(false, ArrowType.Struct.INSTANCE, null, null);
Schema schema = new Schema(Collections.singletonList(new Field("field", FieldType.nullable(arrowType),
Collections.singletonList(new Field(MapVector.KEY_NAME, mapType,
Arrays.asList(new Field(MapVector.KEY_NAME, keyType, null),
new Field(MapVector.VALUE_NAME, keyType, null)))))));
try (final MockPreparedStatement statement = new MockPreparedStatement();
final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
final JdbcParameterBinder binder =
JdbcParameterBinder.builder(statement, root).bindAll().build();
assertThat(binder.next()).isFalse();

@SuppressWarnings("unchecked")
final V vector = (V) root.getVector(0);
final ColumnBinder columnBinder = ColumnBinder.forVector(vector);
assertThat(columnBinder.getJdbcType()).isEqualTo(jdbcType);

setValue.accept(vector, 0, values.get(0));
setValue.accept(vector, 1, values.get(1));
setNull.accept(vector, 2);
root.setRowCount(3);

assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(0).toString());
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1).toString());
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isNull();
assertThat(statement.getParamType(1)).isEqualTo(jdbcType);
assertThat(binder.next()).isFalse();

binder.reset();

setNull.accept(vector, 0);
setValue.accept(vector, 1, values.get(3));
setValue.accept(vector, 2, values.get(0));
setValue.accept(vector, 3, values.get(2));
setValue.accept(vector, 4, values.get(1));
root.setRowCount(5);

assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isNull();
assertThat(statement.getParamType(1)).isEqualTo(jdbcType);
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(3).toString());
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(0).toString());
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(2).toString());
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1).toString());
assertThat(binder.next()).isFalse();
}

// Non-nullable (since some types have a specialized binder)
schema = new Schema(Collections.singletonList(new Field("field", FieldType.notNullable(arrowType),
Collections.singletonList(new Field(MapVector.KEY_NAME, mapType,
Arrays.asList(new Field(MapVector.KEY_NAME, keyType, null),
new Field(MapVector.VALUE_NAME, keyType, null)))))));
try (final MockPreparedStatement statement = new MockPreparedStatement();
final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
@SuppressWarnings("unchecked")
final V vector = (V) root.getVector(0);

final JdbcParameterBinder binder =
JdbcParameterBinder.builder(statement, root).bind(1,
new org.apache.arrow.adapter.jdbc.binder.MapBinder((MapVector) vector, Types.OTHER)).build();
assertThat(binder.next()).isFalse();

setValue.accept(vector, 0, values.get(0));
setValue.accept(vector, 1, values.get(1));
root.setRowCount(2);

assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(0));
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();

binder.reset();

setValue.accept(vector, 0, values.get(0));
setValue.accept(vector, 1, values.get(2));
setValue.accept(vector, 2, values.get(0));
setValue.accept(vector, 3, values.get(2));
setValue.accept(vector, 4, values.get(1));
root.setRowCount(5);

assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(0));
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(2));
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(0));
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(2));
assertThat(binder.next()).isTrue();
assertThat(statement.getParamValue(1)).isEqualTo(values.get(1));
assertThat(binder.next()).isFalse();
}
}
}

0 comments on commit bac735b

Please sign in to comment.