From b4dbc445e235260945daaf0b7369e49990bb91ed Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 21 May 2019 10:40:44 -0400 Subject: [PATCH] Enable non-nested dictionary batches in Flight integration tests --- cpp/src/arrow/flight/client.cc | 3 +- .../arrow/flight/test-integration-server.cc | 4 +- cpp/src/arrow/flight/test-util.cc | 6 +- integration/integration_test.py | 2 +- .../org/apache/arrow/flight/ArrowMessage.java | 26 +++++- .../apache/arrow/flight/DictionaryUtils.java | 80 +++++++++++++++++++ .../org/apache/arrow/flight/FlightClient.java | 29 ++++++- .../apache/arrow/flight/FlightProducer.java | 8 ++ .../apache/arrow/flight/FlightService.java | 28 ++++++- .../org/apache/arrow/flight/FlightStream.java | 64 +++++++++++++-- .../apache/arrow/flight/GenericOperation.java | 42 ---------- .../org/apache/arrow/flight/PutResult.java | 3 + .../arrow/flight/example/FlightHolder.java | 14 ++-- .../arrow/flight/example/InMemoryStore.java | 13 +-- .../apache/arrow/flight/example/Stream.java | 14 +++- .../integration/IntegrationTestClient.java | 43 +++++----- 16 files changed, 277 insertions(+), 102 deletions(-) create mode 100644 java/flight/src/main/java/org/apache/arrow/flight/DictionaryUtils.java delete mode 100644 java/flight/src/main/java/org/apache/arrow/flight/GenericOperation.java diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 654e1e7c48750..070316718a1b0 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -308,7 +308,8 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { } RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor)); first_payload_ = false; - } else if (stream_writer_->app_metadata_) { + } else if (ipc_payload.type == ipc::Message::RECORD_BATCH && + stream_writer_->app_metadata_) { payload.app_metadata = std::move(stream_writer_->app_metadata_); } diff --git a/cpp/src/arrow/flight/test-integration-server.cc b/cpp/src/arrow/flight/test-integration-server.cc index 6d04588c190b7..9ceabbcafb209 100644 --- a/cpp/src/arrow/flight/test-integration-server.cc +++ b/cpp/src/arrow/flight/test-integration-server.cc @@ -106,7 +106,9 @@ class FlightIntegrationTestServer : public FlightServerBase { RETURN_NOT_OK(reader->ReadWithMetadata(&chunk, &metadata)); if (chunk == nullptr) break; retrieved_chunks.push_back(chunk); - RETURN_NOT_OK(writer->WriteMetadata(*metadata)); + if (metadata) { + RETURN_NOT_OK(writer->WriteMetadata(*metadata)); + } } std::shared_ptr retrieved_data; RETURN_NOT_OK(arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks, diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc index c84b2eda81c82..4408801a97e25 100644 --- a/cpp/src/arrow/flight/test-util.cc +++ b/cpp/src/arrow/flight/test-util.cc @@ -271,8 +271,10 @@ Status NumberingStream::GetSchemaPayload(FlightPayload* payload) { Status NumberingStream::Next(FlightPayload* payload) { RETURN_NOT_OK(stream_->Next(payload)); - payload->app_metadata = Buffer::FromString(std::to_string(counter_)); - counter_++; + if (payload && payload->ipc_message.type == ipc::Message::RECORD_BATCH) { + payload->app_metadata = Buffer::FromString(std::to_string(counter_)); + counter_++; + } return Status::OK(); } diff --git a/integration/integration_test.py b/integration/integration_test.py index a4763c98c8c08..aca05747c72f7 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -1152,7 +1152,7 @@ def _temp_path(): generate_interval_case(), generate_map_case(), generate_nested_case(), - generate_dictionary_case().skip_category(SKIP_FLIGHT), + generate_dictionary_case(), generate_nested_dictionary_case().skip_category(SKIP_ARROW) .skip_category(SKIP_FLIGHT), ] diff --git a/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java b/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java index be63dcb4f1ac3..253f0d4c3a84e 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/ArrowMessage.java @@ -35,6 +35,7 @@ import org.apache.arrow.flight.impl.Flight.FlightDescriptor; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.ipc.message.MessageSerializer; import org.apache.arrow.vector.types.pojo.Schema; @@ -147,6 +148,21 @@ public ArrowMessage(ArrowRecordBatch batch, byte[] appMetadata) { this.appMetadata = appMetadata; } + public ArrowMessage(ArrowDictionaryBatch batch) { + FlatBufferBuilder builder = new FlatBufferBuilder(); + int batchOffset = batch.writeTo(builder); + ByteBuffer serializedMessage = MessageSerializer + .serializeMessage(builder, MessageHeader.DictionaryBatch, batchOffset, + batch.computeBodyLength()); + serializedMessage = serializedMessage.slice(); + this.message = Message.getRootAsMessage(serializedMessage); + // asInputStream will free the buffers implicitly, so increment the reference count + batch.getDictionary().getBuffers().forEach(buf -> buf.getReferenceManager().retain()); + this.bufs = ImmutableList.copyOf(batch.getDictionary().getBuffers()); + this.descriptor = null; + this.appMetadata = null; + } + private ArrowMessage(FlightDescriptor descriptor, Message message, byte[] appMetadata, ArrowBuf buf) { this.message = message; this.descriptor = descriptor; @@ -184,6 +200,13 @@ public ArrowRecordBatch asRecordBatch() throws IOException { return batch; } + public ArrowDictionaryBatch asDictionaryBatch() throws IOException { + Preconditions.checkArgument(bufs.size() == 1, "A batch can only be consumed if it contains a single ArrowBuf."); + Preconditions.checkArgument(getMessageType() == HeaderType.DICTIONARY_BATCH); + ArrowBuf underlying = bufs.get(0); + return MessageSerializer.deserializeDictionaryBatch(message, underlying); + } + public Iterable getBufs() { return Iterables.unmodifiableIterable(bufs); } @@ -277,7 +300,8 @@ private InputStream asInputStream(BufferAllocator allocator) { return NO_BODY_MARSHALLER.stream(builder.build()); } - Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH); + Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH || + getMessageType() == HeaderType.DICTIONARY_BATCH); Preconditions.checkArgument(!bufs.isEmpty()); Preconditions.checkArgument(descriptor == null, "Descriptor should only be included in the schema message."); diff --git a/java/flight/src/main/java/org/apache/arrow/flight/DictionaryUtils.java b/java/flight/src/main/java/org/apache/arrow/flight/DictionaryUtils.java new file mode 100644 index 0000000000000..cae5c2e71f546 --- /dev/null +++ b/java/flight/src/main/java/org/apache/arrow/flight/DictionaryUtils.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.function.Consumer; +import java.util.stream.Stream; + +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.DictionaryUtility; + +/** + * Utilities to work with dictionaries in Flight. + */ +final class DictionaryUtils { + + private DictionaryUtils() { + throw new UnsupportedOperationException("Do not instantiate this class."); + } + + /** + * Generate all the necessary Flight messages to send a schema and associated dictionaries. + */ + static Schema generateSchemaMessages(final Schema originalSchema, final FlightDescriptor descriptor, + final DictionaryProvider provider, final Consumer messageCallback) { + final List fields = new ArrayList<>(originalSchema.getFields().size()); + final Set dictionaryIds = new HashSet<>(); + for (final Field field : originalSchema.getFields()) { + fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIds)); + } + final Schema schema = new Schema(fields, originalSchema.getCustomMetadata()); + // Send the schema message + messageCallback.accept(new ArrowMessage(descriptor == null ? null : descriptor.toProtocol(), schema)); + // Create and write dictionary batches + for (Long id : dictionaryIds) { + final Dictionary dictionary = provider.lookup(id); + final FieldVector vector = dictionary.getVector(); + final int count = vector.getValueCount(); + // Do NOT close this root, as it does not actually own the vector. + final VectorSchemaRoot dictRoot = new VectorSchemaRoot( + Collections.singletonList(vector.getField()), + Collections.singletonList(vector), + count); + final VectorUnloader unloader = new VectorUnloader(dictRoot); + try (final ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch( + id, unloader.getRecordBatch())) { + messageCallback.accept(new ArrowMessage(dictionaryBatch)); + } + } + return schema; + } +} diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java index 73351e267cd38..41bb428885ede 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -19,7 +19,12 @@ import java.io.InputStream; import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; import java.util.Iterator; +import java.util.List; +import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; @@ -37,9 +42,17 @@ import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceBlockingStub; import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; +import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.DictionaryUtility; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; @@ -164,6 +177,19 @@ public void authenticate(ClientAuthHandler handler, CallOption... options) { */ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, StreamListener metadataListener, CallOption... options) { + return startPut(descriptor, root, new MapDictionaryProvider(), metadataListener, options); + } + + /** + * Create or append a descriptor with another stream. + * @param descriptor FlightDescriptor the descriptor for the data + * @param root VectorSchemaRoot the root containing data + * @param metadataListener A handler for metadata messages from the server. + * @param options RPC-layer hints for this call. + * @return ClientStreamListener an interface to control uploading data + */ + public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, DictionaryProvider provider, + StreamListener metadataListener, CallOption... options) { Preconditions.checkNotNull(descriptor); Preconditions.checkNotNull(root); @@ -173,8 +199,7 @@ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRo ClientCalls.asyncBidiStreamingCall( authInterceptor.interceptCall(doPutDescriptor, callOptions, channel), resultObserver); // send the schema to start. - ArrowMessage message = new ArrowMessage(descriptor.toProtocol(), root.getSchema()); - observer.onNext(message); + DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, provider, observer::onNext); return new PutObserver(new VectorUnloader( root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */), observer, resultObserver.getFuture()); diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java index 8b9df4b9fa8e9..316fbf97fb9cb 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightProducer.java @@ -18,6 +18,7 @@ package org.apache.arrow.flight; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; /** * API to Implement an Arrow Flight producer. @@ -100,6 +101,13 @@ public interface ServerStreamListener { */ void start(VectorSchemaRoot root); + /** + * Start sending data, using the schema of the given {@link VectorSchemaRoot}. + * + *

This method must be called before all others. + */ + void start(VectorSchemaRoot root, DictionaryProvider dictionaries); + /** * Send the current contents of the associated {@link VectorSchemaRoot}. */ diff --git a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java index dc055c1f0d758..8f8c2050d5af8 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/FlightService.java @@ -17,6 +17,11 @@ package org.apache.arrow.flight; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.function.BooleanSupplier; @@ -32,8 +37,17 @@ import org.apache.arrow.flight.impl.Flight.HandshakeResponse; import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceImplBase; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.VectorSchemaRoot; import org.apache.arrow.vector.VectorUnloader; +import org.apache.arrow.vector.dictionary.Dictionary; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; +import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.DictionaryUtility; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -137,9 +151,14 @@ public boolean isCancelled() { @Override public void start(VectorSchemaRoot root) { - responseObserver.onNext(new ArrowMessage(null, root.getSchema())); - // [ARROW-4213] We must align buffers to be compatible with other languages. + start(root, new MapDictionaryProvider()); + } + + @Override + public void start(VectorSchemaRoot root, DictionaryProvider provider) { unloader = new VectorUnloader(root, true, true); + + DictionaryUtils.generateSchemaMessages(root.getSchema(), null, provider, responseObserver::onNext); } @Override @@ -171,7 +190,7 @@ public StreamObserver doPutCustom(final StreamObserver responseObserver.request(count)); + FlightStream fs = new FlightStream(allocator, PENDING_REQUESTS, null, responseObserver::request); executors.submit(() -> { try { producer.acceptPut(makeContext(responseObserver), fs, @@ -179,6 +198,9 @@ public StreamObserver doPutCustom(final StreamObserver fields = new ArrayList<>(); + final Map dictionaryMap = new HashMap<>(); + for (final Field originalField : schema.getFields()) { + final Field updatedField = DictionaryUtility.toMemoryFormat(originalField, allocator, dictionaryMap); + fields.add(updatedField); + } + for (final Map.Entry entry : dictionaryMap.entrySet()) { + dictionaries.put(entry.getValue()); + } + schema = new Schema(fields, schema.getCustomMetadata()); fulfilledRoot = VectorSchemaRoot.create(schema, allocator); loader = new VectorLoader(fulfilledRoot); descriptor = msg.getDescriptor() != null ? new FlightDescriptor(msg.getDescriptor()) : null; root.set(fulfilledRoot); break; + } case RECORD_BATCH: queue.add(msg); break; - case NONE: case DICTIONARY_BATCH: + queue.add(msg); + break; + case NONE: case TENSOR: default: queue.add(DONE_EX); - ex = new UnsupportedOperationException("Unable to handle message of type: " + msg); + ex = new UnsupportedOperationException("Unable to handle message of type: " + msg.getMessageType()); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/GenericOperation.java b/java/flight/src/main/java/org/apache/arrow/flight/GenericOperation.java deleted file mode 100644 index 03a1e92af12c0..0000000000000 --- a/java/flight/src/main/java/org/apache/arrow/flight/GenericOperation.java +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.arrow.flight; - -/** - * Unused?. - */ -class GenericOperation { - - private final String type; - private final byte[] body; - - public GenericOperation(String type, byte[] body) { - super(); - this.type = type; - this.body = body == null ? new byte[0] : body; - } - - public String getType() { - return type; - } - - public byte[] getBody() { - return body; - } - -} diff --git a/java/flight/src/main/java/org/apache/arrow/flight/PutResult.java b/java/flight/src/main/java/org/apache/arrow/flight/PutResult.java index dffafbdc1e8f0..7cf615ebf69bd 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/PutResult.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/PutResult.java @@ -36,6 +36,9 @@ private PutResult(ByteBuffer metadata) { /** Create a PutResult with application-specific metadata. */ public static PutResult metadata(byte[] metadata) { + if (metadata == null) { + return empty(); + } return new PutResult(ByteBuffer.wrap(metadata)); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java b/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java index 91ed04e7ffaba..cf3eb154ed7e1 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/FlightHolder.java @@ -28,6 +28,7 @@ import org.apache.arrow.flight.Location; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.util.AutoCloseables; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.types.pojo.Schema; import com.google.common.base.Preconditions; @@ -43,19 +44,22 @@ public class FlightHolder implements AutoCloseable { private final FlightDescriptor descriptor; private final Schema schema; private final List streams = new CopyOnWriteArrayList<>(); + private final DictionaryProvider dictionaryProvider; /** * Creates a new instance. - * - * @param allocator The allocator to use for allocating buffers to store data. + * @param allocator The allocator to use for allocating buffers to store data. * @param descriptor The descriptor for the streams. * @param schema The schema for the stream. + * @param dictionaryProvider The dictionary provider for the stream. */ - public FlightHolder(BufferAllocator allocator, FlightDescriptor descriptor, Schema schema) { + public FlightHolder(BufferAllocator allocator, FlightDescriptor descriptor, Schema schema, + DictionaryProvider dictionaryProvider) { Preconditions.checkArgument(!descriptor.isCommand()); this.allocator = allocator.newChildAllocator(descriptor.toString(), 0, Long.MAX_VALUE); this.descriptor = descriptor; this.schema = schema; + this.dictionaryProvider = dictionaryProvider; } /** @@ -72,8 +76,8 @@ public Stream getStream(ExampleTicket ticket) { * Adds a new streams which clients can populate via the returned object. */ public Stream.StreamCreator addStream(Schema schema) { - Preconditions.checkArgument(schema.equals(schema), "Stream schema inconsistent with existing schema."); - return new Stream.StreamCreator(schema, allocator, t -> { + Preconditions.checkArgument(this.schema.equals(schema), "Stream schema inconsistent with existing schema."); + return new Stream.StreamCreator(schema, dictionaryProvider, allocator, t -> { synchronized (streams) { streams.add(t); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java index 1d9eb889d743d..73d448e3ab1e1 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/InMemoryStore.java @@ -79,17 +79,6 @@ public Stream getStream(Ticket t) { return h.getStream(example); } - /** - * Create a new {@link Stream} with the given schema and descriptor. - */ - public StreamCreator putStream(final FlightDescriptor descriptor, final Schema schema) { - final FlightHolder h = holders.computeIfAbsent( - descriptor, - t -> new FlightHolder(allocator, t, schema)); - - return h.addStream(schema); - } - @Override public void listFlights(CallContext context, Criteria criteria, StreamListener listener) { @@ -123,7 +112,7 @@ public Runnable acceptPut(CallContext context, try (VectorSchemaRoot root = flightStream.getRoot()) { final FlightHolder h = holders.computeIfAbsent( flightStream.getDescriptor(), - t -> new FlightHolder(allocator, t, flightStream.getSchema())); + t -> new FlightHolder(allocator, t, flightStream.getSchema(), flightStream.getDictionaryProvider())); creator = h.addStream(flightStream.getSchema()); diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java b/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java index b79525caab1fb..1139e8c592f7d 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/Stream.java @@ -29,6 +29,7 @@ import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.VectorLoader; import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; import org.apache.arrow.vector.types.pojo.Schema; @@ -41,6 +42,7 @@ public class Stream implements AutoCloseable, Iterable { private final String uuid = UUID.randomUUID().toString(); + private final DictionaryProvider dictionaryProvider; private final List batches; private final Schema schema; private final long recordCount; @@ -54,9 +56,11 @@ public class Stream implements AutoCloseable, Iterable { */ public Stream( final Schema schema, + final DictionaryProvider dictionaryProvider, List batches, long recordCount) { this.schema = schema; + this.dictionaryProvider = dictionaryProvider; this.batches = ImmutableList.copyOf(batches); this.recordCount = recordCount; } @@ -83,7 +87,7 @@ public String getUuid() { */ public void sendTo(BufferAllocator allocator, ServerStreamListener listener) { try (VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) { - listener.start(root); + listener.start(root, dictionaryProvider); final VectorLoader loader = new VectorLoader(root); int counter = 0; for (ArrowRecordBatch batch : batches) { @@ -121,18 +125,22 @@ public static class StreamCreator { private final List batches = new ArrayList<>(); private final Consumer committer; private long recordCount = 0; + private DictionaryProvider dictionaryProvider; /** * Creates a new instance. * * @param schema The schema for batches in the stream. + * @param dictionaryProvider The dictionary provider for the stream. * @param allocator The allocator used to copy data permanently into the stream. * @param committer A callback for when the the stream is ready to be finalized (no more batches). */ - public StreamCreator(Schema schema, BufferAllocator allocator, Consumer committer) { + public StreamCreator(Schema schema, DictionaryProvider dictionaryProvider, + BufferAllocator allocator, Consumer committer) { this.allocator = allocator; this.committer = committer; this.schema = schema; + this.dictionaryProvider = dictionaryProvider; } /** @@ -155,7 +163,7 @@ public void add(ArrowRecordBatch batch) { * Complete building the stream (no more batches can be added). */ public void complete() { - Stream stream = new Stream(schema, batches, recordCount); + Stream stream = new Stream(schema, dictionaryProvider, batches, recordCount); committer.accept(stream); } diff --git a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java index 47acbe9f6a46e..75a01f8633cc3 100644 --- a/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java +++ b/java/flight/src/main/java/org/apache/arrow/flight/example/integration/IntegrationTestClient.java @@ -96,27 +96,28 @@ private void run(String[] args) throws ParseException, IOException { jsonRoot = VectorSchemaRoot.create(root.getSchema(), allocator); VectorUnloader unloader = new VectorUnloader(root); VectorLoader jsonLoader = new VectorLoader(jsonRoot); - FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, new StreamListener() { - int counter = 0; - - @Override - public void onNext(PutResult val) { - final String metadata = StandardCharsets.UTF_8.decode(val.getApplicationMetadata()).toString(); - if (!Integer.toString(counter).equals(metadata)) { - throw new RuntimeException( - String.format("Invalid ACK from server. Expected '%d' but got '%s'.", counter, metadata)); - } - counter++; - } - - @Override - public void onError(Throwable t) { - } - - @Override - public void onCompleted() { - } - }); + FlightClient.ClientStreamListener stream = client.startPut(descriptor, root, reader, + new StreamListener() { + int counter = 0; + + @Override + public void onNext(PutResult val) { + final String metadata = StandardCharsets.UTF_8.decode(val.getApplicationMetadata()).toString(); + if (!Integer.toString(counter).equals(metadata)) { + throw new RuntimeException( + String.format("Invalid ACK from server. Expected '%d' but got '%s'.", counter, metadata)); + } + counter++; + } + + @Override + public void onError(Throwable t) { + } + + @Override + public void onCompleted() { + } + }); int counter = 0; while (reader.read(root)) { stream.putNext(Integer.toString(counter).getBytes(StandardCharsets.UTF_8));