Skip to content

Commit

Permalink
Enable non-nested dictionary batches in Flight integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Jun 26, 2019
1 parent f7631a2 commit b4dbc44
Show file tree
Hide file tree
Showing 16 changed files with 277 additions and 102 deletions.
3 changes: 2 additions & 1 deletion cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
}
RETURN_NOT_OK(Buffer::FromString(str_descr, &payload.descriptor));
first_payload_ = false;
} else if (stream_writer_->app_metadata_) {
} else if (ipc_payload.type == ipc::Message::RECORD_BATCH &&
stream_writer_->app_metadata_) {
payload.app_metadata = std::move(stream_writer_->app_metadata_);
}

Expand Down
4 changes: 3 additions & 1 deletion cpp/src/arrow/flight/test-integration-server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ class FlightIntegrationTestServer : public FlightServerBase {
RETURN_NOT_OK(reader->ReadWithMetadata(&chunk, &metadata));
if (chunk == nullptr) break;
retrieved_chunks.push_back(chunk);
RETURN_NOT_OK(writer->WriteMetadata(*metadata));
if (metadata) {
RETURN_NOT_OK(writer->WriteMetadata(*metadata));
}
}
std::shared_ptr<arrow::Table> retrieved_data;
RETURN_NOT_OK(arrow::Table::FromRecordBatches(reader->schema(), retrieved_chunks,
Expand Down
6 changes: 4 additions & 2 deletions cpp/src/arrow/flight/test-util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,10 @@ Status NumberingStream::GetSchemaPayload(FlightPayload* payload) {

// Fetch the next payload from the wrapped stream and, when it is a record
// batch, attach the current counter value as application metadata.
//
// Only record batch payloads are numbered: the counter is converted to a
// string, stored in payload->app_metadata, and then incremented. Payloads of
// any other message type are passed through without metadata and do not
// advance the counter, so the numbering stays consecutive across the record
// batches of the stream.
//
// Returns the wrapped stream's error status if it fails, Status::OK() otherwise.
Status NumberingStream::Next(FlightPayload* payload) {
  RETURN_NOT_OK(stream_->Next(payload));
  // Number the payload exactly once, and only if it is a record batch.
  // Assigning unconditionally as well would tag non-record-batch messages
  // and advance the counter twice per record batch.
  if (payload && payload->ipc_message.type == ipc::Message::RECORD_BATCH) {
    payload->app_metadata = Buffer::FromString(std::to_string(counter_));
    counter_++;
  }
  return Status::OK();
}

Expand Down
2 changes: 1 addition & 1 deletion integration/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,7 +1152,7 @@ def _temp_path():
generate_interval_case(),
generate_map_case(),
generate_nested_case(),
generate_dictionary_case().skip_category(SKIP_FLIGHT),
generate_dictionary_case(),
generate_nested_dictionary_case().skip_category(SKIP_ARROW)
.skip_category(SKIP_FLIGHT),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.arrow.flight.impl.Flight.FlightDescriptor;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.ipc.message.MessageSerializer;
import org.apache.arrow.vector.types.pojo.Schema;
Expand Down Expand Up @@ -147,6 +148,21 @@ public ArrowMessage(ArrowRecordBatch batch, byte[] appMetadata) {
this.appMetadata = appMetadata;
}

/**
 * Create a message that carries a dictionary batch.
 *
 * <p>The message has neither a descriptor nor application metadata. The batch header is
 * serialized as a DictionaryBatch message and the batch's buffers are kept as the message body.
 *
 * @param batch the dictionary batch to wrap
 */
public ArrowMessage(ArrowDictionaryBatch batch) {
  this.descriptor = null;
  this.appMetadata = null;

  final FlatBufferBuilder fbb = new FlatBufferBuilder();
  final int dictionaryOffset = batch.writeTo(fbb);
  final ByteBuffer header = MessageSerializer
      .serializeMessage(fbb, MessageHeader.DictionaryBatch, dictionaryOffset, batch.computeBodyLength())
      .slice();
  this.message = Message.getRootAsMessage(header);

  // asInputStream will free the buffers implicitly, so increment the reference count
  for (final ArrowBuf buf : batch.getDictionary().getBuffers()) {
    buf.getReferenceManager().retain();
  }
  this.bufs = ImmutableList.copyOf(batch.getDictionary().getBuffers());
}

private ArrowMessage(FlightDescriptor descriptor, Message message, byte[] appMetadata, ArrowBuf buf) {
this.message = message;
this.descriptor = descriptor;
Expand Down Expand Up @@ -184,6 +200,13 @@ public ArrowRecordBatch asRecordBatch() throws IOException {
return batch;
}

/**
 * Interpret this message as a dictionary batch.
 *
 * @return the dictionary batch deserialized from the message header and its single body buffer
 * @throws IllegalArgumentException if the message does not hold exactly one buffer, or is not
 *     a dictionary batch message
 * @throws IOException declared for the deserialization call
 */
public ArrowDictionaryBatch asDictionaryBatch() throws IOException {
  final boolean hasSingleBuffer = bufs.size() == 1;
  Preconditions.checkArgument(hasSingleBuffer, "A batch can only be consumed if it contains a single ArrowBuf.");
  final boolean isDictionaryBatch = getMessageType() == HeaderType.DICTIONARY_BATCH;
  Preconditions.checkArgument(isDictionaryBatch);
  return MessageSerializer.deserializeDictionaryBatch(message, bufs.get(0));
}

/**
 * Get a read-only view of the buffers held by this message.
 *
 * @return an unmodifiable iterable over the message's buffers
 */
public Iterable<ArrowBuf> getBufs() {
  final Iterable<ArrowBuf> readOnlyView = Iterables.unmodifiableIterable(bufs);
  return readOnlyView;
}
Expand Down Expand Up @@ -277,7 +300,8 @@ private InputStream asInputStream(BufferAllocator allocator) {
return NO_BODY_MARSHALLER.stream(builder.build());
}

Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH);
Preconditions.checkArgument(getMessageType() == HeaderType.RECORD_BATCH ||
getMessageType() == HeaderType.DICTIONARY_BATCH);
Preconditions.checkArgument(!bufs.isEmpty());
Preconditions.checkArgument(descriptor == null, "Descriptor should only be included in the schema message.");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.flight;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.function.Consumer;
import java.util.stream.Stream;

import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;

/**
 * Helpers for sending dictionary-encoded data over Flight.
 */
final class DictionaryUtils {

  private DictionaryUtils() {
    throw new UnsupportedOperationException("Do not instantiate this class.");
  }

  /**
   * Emit the schema message followed by one dictionary batch message per dictionary used by the schema.
   *
   * <p>Every field of the original schema is converted with {@link DictionaryUtility#toMessageFormat},
   * which also collects the IDs of the dictionaries in use. The converted schema is sent first,
   * then each collected dictionary is looked up in the provider and sent as a dictionary batch.
   *
   * @param originalSchema the schema to convert and send
   * @param descriptor the descriptor to attach to the schema message, or null to attach none
   * @param provider the provider in which the dictionaries are looked up by ID
   * @param messageCallback the consumer that receives each generated message, schema first
   * @return the converted schema that was sent
   */
  static Schema generateSchemaMessages(final Schema originalSchema, final FlightDescriptor descriptor,
      final DictionaryProvider provider, final Consumer<ArrowMessage> messageCallback) {
    final Set<Long> usedDictionaryIds = new HashSet<>();
    final List<Field> messageFields = new ArrayList<>(originalSchema.getFields().size());
    for (final Field original : originalSchema.getFields()) {
      messageFields.add(DictionaryUtility.toMessageFormat(original, provider, usedDictionaryIds));
    }
    final Schema messageSchema = new Schema(messageFields, originalSchema.getCustomMetadata());

    // The schema message goes out before any dictionary batch.
    final ArrowMessage schemaMessage =
        new ArrowMessage(descriptor == null ? null : descriptor.toProtocol(), messageSchema);
    messageCallback.accept(schemaMessage);

    for (final Long dictionaryId : usedDictionaryIds) {
      emitDictionaryBatch(dictionaryId, provider, messageCallback);
    }
    return messageSchema;
  }

  /**
   * Wrap the vector of a single dictionary in a dictionary batch message and pass it to the callback.
   */
  private static void emitDictionaryBatch(final long dictionaryId, final DictionaryProvider provider,
      final Consumer<ArrowMessage> messageCallback) {
    final Dictionary dictionary = provider.lookup(dictionaryId);
    final FieldVector values = dictionary.getVector();
    final int valueCount = values.getValueCount();
    // Do NOT close this root, as it does not actually own the vector.
    final VectorSchemaRoot borrowedRoot = new VectorSchemaRoot(
        Collections.singletonList(values.getField()),
        Collections.singletonList(values),
        valueCount);
    final VectorUnloader dictionaryUnloader = new VectorUnloader(borrowedRoot);
    try (final ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(
        dictionaryId, dictionaryUnloader.getRecordBatch())) {
      messageCallback.accept(new ArrowMessage(dictionaryBatch));
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,12 @@

import java.io.InputStream;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
Expand All @@ -37,9 +42,17 @@
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceBlockingStub;
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceStub;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;

import com.google.common.base.Preconditions;
import com.google.common.base.Throwables;
Expand Down Expand Up @@ -164,6 +177,19 @@ public void authenticate(ClientAuthHandler handler, CallOption... options) {
*/
public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root,
    StreamListener<PutResult> metadataListener, CallOption... options) {
  // No dictionaries supplied by the caller: delegate with a freshly created, empty provider.
  final DictionaryProvider noDictionaries = new MapDictionaryProvider();
  return startPut(descriptor, root, noDictionaries, metadataListener, options);
}

/**
* Create or append a descriptor with another stream, sending any dictionaries the schema uses.
* @param descriptor FlightDescriptor the descriptor for the data
* @param root VectorSchemaRoot the root containing data
* @param provider DictionaryProvider the provider used to look up the dictionaries referenced by the root's schema
* @param metadataListener A handler for metadata messages from the server.
* @param options RPC-layer hints for this call.
* @return ClientStreamListener an interface to control uploading data
*/
public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRoot root, DictionaryProvider provider,
StreamListener<PutResult> metadataListener, CallOption... options) {
Preconditions.checkNotNull(descriptor);
Preconditions.checkNotNull(root);

Expand All @@ -173,8 +199,7 @@ public ClientStreamListener startPut(FlightDescriptor descriptor, VectorSchemaRo
ClientCalls.asyncBidiStreamingCall(
authInterceptor.interceptCall(doPutDescriptor, callOptions, channel), resultObserver);
// send the schema to start.
ArrowMessage message = new ArrowMessage(descriptor.toProtocol(), root.getSchema());
observer.onNext(message);
DictionaryUtils.generateSchemaMessages(root.getSchema(), descriptor, provider, observer::onNext);
return new PutObserver(new VectorUnloader(
root, true /* include # of nulls in vectors */, true /* must align buffers to be C++-compatible */),
observer, resultObserver.getFuture());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.arrow.flight;

import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.dictionary.DictionaryProvider;

/**
* API to Implement an Arrow Flight producer.
Expand Down Expand Up @@ -100,6 +101,13 @@ public interface ServerStreamListener {
*/
void start(VectorSchemaRoot root);

/**
 * Start sending data, using the schema of the given {@link VectorSchemaRoot}.
 *
 * <p>This method must be called before all others.
 *
 * @param root the root whose schema describes the data to be sent
 * @param dictionaries the provider of the dictionaries for the data; presumably consulted for
 *     the dictionary-encoded fields of the root's schema (confirm against the implementation)
 */
void start(VectorSchemaRoot root, DictionaryProvider dictionaries);

/**
* Send the current contents of the associated {@link VectorSchemaRoot}.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

package org.apache.arrow.flight;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.BooleanSupplier;
Expand All @@ -32,8 +37,17 @@
import org.apache.arrow.flight.impl.Flight.HandshakeResponse;
import org.apache.arrow.flight.impl.FlightServiceGrpc.FlightServiceImplBase;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowDictionaryBatch;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.arrow.vector.util.DictionaryUtility;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand Down Expand Up @@ -137,9 +151,14 @@ public boolean isCancelled() {

@Override
public void start(VectorSchemaRoot root) {
  // Start a stream that has no dictionaries by delegating with an empty provider.
  // The two-argument overload sends the schema message and creates the unloader
  // (with buffer alignment enabled, see ARROW-4213). The schema must not also be
  // sent here, or the client would receive it twice.
  start(root, new MapDictionaryProvider());
}

@Override
public void start(VectorSchemaRoot root, DictionaryProvider provider) {
  // Create the unloader for the root, then send the schema and dictionary messages.
  final boolean includeNullCount = true;
  // [ARROW-4213] We must align buffers to be compatible with other languages.
  final boolean alignBuffers = true;
  unloader = new VectorUnloader(root, includeNullCount, alignBuffers);

  // No descriptor is attached to the schema message here.
  final Schema rootSchema = root.getSchema();
  DictionaryUtils.generateSchemaMessages(rootSchema, null, provider, responseObserver::onNext);
}

@Override
Expand Down Expand Up @@ -171,14 +190,17 @@ public StreamObserver<ArrowMessage> doPutCustom(final StreamObserver<Flight.PutR
responseObserver.disableAutoInboundFlowControl();
responseObserver.request(1);

FlightStream fs = new FlightStream(allocator, PENDING_REQUESTS, null, (count) -> responseObserver.request(count));
FlightStream fs = new FlightStream(allocator, PENDING_REQUESTS, null, responseObserver::request);
executors.submit(() -> {
try {
producer.acceptPut(makeContext(responseObserver), fs,
StreamPipe.wrap(responseObserver, PutResult::toProtocol)).run();
responseObserver.onCompleted();
} catch (Exception ex) {
responseObserver.onError(ex);
// The client may have terminated, so the exception here is effectively swallowed.
// Log the error as well so -something- makes it to the developer.
logger.error("Exception handling DoPut", ex);
}
});

Expand Down
Loading

0 comments on commit b4dbc44

Please sign in to comment.