Add support to reprepare cached queries.
We now reprepare cached queries that were invalidated, for example by schema changes.

[closes #382]

Signed-off-by: Mark Paluch <[email protected]>
mp911de committed Oct 26, 2021
1 parent dd07bc4 commit 379f32e
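
To illustrate the behavior this change enables, here is a minimal sketch of the failure scenario (the connection settings, the person table, and the execute helper are illustrative assumptions, not part of this commit):

import io.r2dbc.postgresql.PostgresqlConnectionConfiguration;
import io.r2dbc.postgresql.PostgresqlConnectionFactory;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.Statement;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

public class ReprepareScenario {

    public static void main(String[] args) {

        PostgresqlConnectionFactory connectionFactory = new PostgresqlConnectionFactory(
                PostgresqlConnectionConfiguration.builder()
                        .host("localhost")     // assumed
                        .database("test")      // assumed
                        .username("postgres")  // assumed
                        .password("postgres")  // assumed
                        .build());

        Mono.usingWhen(connectionFactory.create(), connection ->

                // 1. The first execution prepares a named server-side statement
                //    and stores it in the connection's statement cache.
                execute(connection, "SELECT * FROM person WHERE id = $1")

                        // 2. A schema change invalidates the cached plan on the server.
                        .thenMany(execute(connection, "ALTER TABLE person ADD COLUMN nickname VARCHAR(255)"))

                        // 3. Re-running the cached query previously failed with SQLSTATE 0A000
                        //    ("cached plan must not change result type"). With this commit, the
                        //    driver evicts the stale cache entry, re-prepares the statement, and
                        //    completes the query without surfacing the error.
                        .thenMany(execute(connection, "SELECT * FROM person WHERE id = $1"))
                        .then(),
                Connection::close)
                .block();
    }

    // Assumed helper: run a statement, binding $1 when present, and drain the rows.
    private static Flux<Object> execute(Connection connection, String sql) {
        Statement statement = connection.createStatement(sql);
        if (sql.contains("$1")) {
            statement.bind("$1", 1);
        }
        return Flux.from(statement.execute())
                .flatMap(result -> Flux.from(result.map((row, metadata) -> row.get(0))));
    }
}

The detection and retry logic lives in handleReprepare(...) and requiresReprepare(...) in ExtendedFlowDelegate below; recovery is triggered by SQLSTATE 26000 (invalid SQL statement name) and 0A000 (cached plan must not change result type).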
Showing 9 changed files with 385 additions and 85 deletions.
23 changes: 22 additions & 1 deletion src/main/java/io/r2dbc/postgresql/BoundedStatementCache.java
@@ -92,9 +92,30 @@ public void put(Binding binding, String sql, String name) {

Map.Entry<CacheKey, String> lastAccessedStatement = getAndRemoveEldest();
ExceptionFactory factory = ExceptionFactory.withSql(lastAccessedStatement.getKey().sql);
+    String statementName = lastAccessedStatement.getValue();

+    close(lastAccessedStatement, factory, statementName);
+ }

+ @Override
+ public void evict(String sql) {
+
+    synchronized (this.cache) {
+
+        List<CacheKey> toRemove = new ArrayList<>();
+        for (Map.Entry<CacheKey, String> entry : this.cache.entrySet()) {
+            if (entry.getKey().sql.equals(sql)) {
+                toRemove.add(entry.getKey());
+            }
+        }
+
+        toRemove.forEach(this.cache::remove);
+    }
+ }

+ private void close(Map.Entry<CacheKey, String> lastAccessedStatement, ExceptionFactory factory, String statementName) {
ExtendedQueryMessageFlow
-        .closeStatement(this.client, lastAccessedStatement.getValue())
+        .closeStatement(this.client, statementName)
.handle(factory::handleErrorResponse)
.subscribe(it -> {
}, err -> LOGGER.warn(String.format("Cannot close statement %s (%s)", lastAccessedStatement.getValue(), lastAccessedStatement.getKey().sql), err));
4 changes: 4 additions & 0 deletions src/main/java/io/r2dbc/postgresql/DisabledStatementCache.java
@@ -39,6 +39,10 @@ public boolean requiresPrepare(Binding binding, String sql) {
public void put(Binding binding, String sql, String name) {
}

+ @Override
+ public void evict(String sql) {
+ }

@Override
public String toString() {
return "DisabledStatementCache";
244 changes: 192 additions & 52 deletions src/main/java/io/r2dbc/postgresql/ExtendedFlowDelegate.java
@@ -19,6 +19,7 @@
import io.netty.buffer.ByteBuf;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;
+ import io.r2dbc.postgresql.api.ErrorDetails;
import io.r2dbc.postgresql.client.Binding;
import io.r2dbc.postgresql.client.Client;
import io.r2dbc.postgresql.client.ExtendedQueryMessageFlow;
@@ -45,14 +46,18 @@
import io.r2dbc.postgresql.util.Operators;
import reactor.core.publisher.Flux;
import reactor.core.publisher.FluxSink;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SynchronousSink;
import reactor.core.publisher.UnicastProcessor;
import reactor.util.annotation.Nullable;
import reactor.util.concurrent.Queues;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiConsumer;
import java.util.function.Predicate;

import static io.r2dbc.postgresql.message.frontend.Execute.NO_LIMIT;
@@ -87,92 +92,81 @@ public static Flux<BackendMessage> runQuery(ConnectionResources resources, Excep
StatementCache cache = resources.getStatementCache();
Client client = resources.getClient();

-    String name = cache.getName(binding, query);
     String portal = resources.getPortalNameSupplier().get();
-    boolean prepareRequired = cache.requiresPrepare(binding, query);
-
-    List<FrontendMessage.DirectEncoder> messagesToSend = new ArrayList<>(6);
-
-    if (prepareRequired) {
-        messagesToSend.add(new Parse(name, binding.getParameterTypes(), query));
-    }
-
-    Bind bind = new Bind(portal, binding.getParameterFormats(), values, ExtendedQueryMessageFlow.resultFormat(resources.getConfiguration().isForceBinary()), name);
-
-    messagesToSend.add(bind);
-    messagesToSend.add(new Describe(portal, PORTAL));

     Flux<BackendMessage> exchange;
     boolean compatibilityMode = resources.getConfiguration().isCompatibilityMode();
     boolean implicitTransactions = resources.getClient().getTransactionStatus() == TransactionStatus.IDLE;

+    ExtendedFlowOperator operator = new ExtendedFlowOperator(query, binding, cache, values, portal, resources.getConfiguration().isForceBinary());

     if (compatibilityMode) {

         if (fetchSize == NO_LIMIT || implicitTransactions) {
-            exchange = fetchAll(messagesToSend, client, portal);
+            exchange = fetchAll(operator, client, portal);
         } else {
-            exchange = fetchCursoredWithSync(messagesToSend, client, portal, fetchSize);
+            exchange = fetchCursoredWithSync(operator, client, portal, fetchSize);
         }
     } else {

         if (fetchSize == NO_LIMIT) {
-            exchange = fetchAll(messagesToSend, client, portal);
+            exchange = fetchAll(operator, client, portal);
         } else {
-            exchange = fetchCursoredWithFlush(messagesToSend, client, portal, fetchSize);
+            exchange = fetchCursoredWithFlush(operator, client, portal, fetchSize);
         }
     }

-    if (prepareRequired) {
-
-        exchange = exchange.doOnNext(message -> {
-
-            if (message == ParseComplete.INSTANCE) {
-                cache.put(binding, query, name);
-            }
-        });
-    }
+    exchange = exchange.doOnNext(message -> {
+
+        if (message == ParseComplete.INSTANCE) {
+            operator.hydrateStatementCache();
+        }
+    });

return exchange.doOnSubscribe(it -> QueryLogger.logQuery(client.getContext(), query)).doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release).filter(RESULT_FRAME_FILTER).handle(factory::handleErrorResponse);
}

/**
* Execute the query and indicate to fetch all rows with the {@link Execute} message.
*
- * @param messagesToSend the initial bind flow
- * @param client         client to use
- * @param portal         the portal
+ * @param operator the flow operator
+ * @param client   client to use
+ * @param portal   the portal
* @return the resulting message stream
*/
- private static Flux<BackendMessage> fetchAll(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal) {
+ private static Flux<BackendMessage> fetchAll(ExtendedFlowOperator operator, Client client, String portal) {

-    messagesToSend.add(new Execute(portal, NO_LIMIT));
-    messagesToSend.add(new Close(portal, PORTAL));
-    messagesToSend.add(Sync.INSTANCE);
+    UnicastProcessor<FrontendMessage> requestsProcessor = UnicastProcessor.create(Queues.<FrontendMessage>small().get());
+    FluxSink<FrontendMessage> requestsSink = requestsProcessor.sink();
+    MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, NO_LIMIT), new Close(portal, PORTAL), Sync.INSTANCE));

-    return client.exchange(Mono.just(new CompositeFrontendMessage(messagesToSend)))
+    return client.exchange(operator.takeUntil(), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requestsProcessor))
+        .handle(handleReprepare(requestsSink, operator, factory))
+        .doFinally(ignore -> operator.close(requestsSink))
.as(Operators::discardOnCancel);
}

/**
* Execute a chunked query and indicate to fetch rows in chunks with the {@link Execute} message.
*
- * @param messagesToSend the messages to send
- * @param client         client to use
- * @param portal         the portal
- * @param fetchSize      fetch size per roundtrip
+ * @param operator  the flow operator
+ * @param client    client to use
+ * @param portal    the portal
+ * @param fetchSize fetch size per roundtrip
* @return the resulting message stream
*/
- private static Flux<BackendMessage> fetchCursoredWithSync(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal, int fetchSize) {
+ private static Flux<BackendMessage> fetchCursoredWithSync(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {

UnicastProcessor<FrontendMessage> requestsProcessor = UnicastProcessor.create(Queues.<FrontendMessage>small().get());
FluxSink<FrontendMessage> requestsSink = requestsProcessor.sink();
AtomicBoolean isCanceled = new AtomicBoolean(false);
AtomicBoolean done = new AtomicBoolean(false);

-    messagesToSend.add(new Execute(portal, fetchSize));
-    messagesToSend.add(Sync.INSTANCE);
+    MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Sync.INSTANCE));
+    Predicate<BackendMessage> takeUntil = operator.takeUntil();

-    return client.exchange(it -> done.get() && it instanceof ReadyForQuery, Flux.<FrontendMessage>just(new CompositeFrontendMessage(messagesToSend)).concatWith(requestsProcessor))
+    return client.exchange(it -> done.get() && takeUntil.test(it), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requestsProcessor))
+        .handle(handleReprepare(requestsSink, operator, factory))
.handle((BackendMessage message, SynchronousSink<BackendMessage> sink) -> {

if (message instanceof CommandComplete) {
@@ -211,30 +205,30 @@ private static Flux<BackendMessage> fetchCursoredWithSync(List<FrontendMessage.D
} else {
sink.next(message);
}
-    }).doFinally(ignore -> requestsSink.complete())
+    }).doFinally(ignore -> operator.close(requestsSink))
.as(flux -> Operators.discardOnCancel(flux, () -> isCanceled.set(true)));
}

/**
* Execute a contiguous query and indicate to fetch rows in chunks with the {@link Execute} message. Uses {@link Flush}-based synchronization that creates a cursor. Note that flushing keeps the
* cursor open even with implicit transactions and this method may not work with newer pgpool implementations.
*
- * @param messagesToSend the messages to send
- * @param client         client to use
- * @param portal         the portal
- * @param fetchSize      fetch size per roundtrip
+ * @param operator  the flow operator
+ * @param client    client to use
+ * @param portal    the portal
+ * @param fetchSize fetch size per roundtrip
* @return the resulting message stream
*/
- private static Flux<BackendMessage> fetchCursoredWithFlush(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal, int fetchSize) {
+ private static Flux<BackendMessage> fetchCursoredWithFlush(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {

UnicastProcessor<FrontendMessage> requestsProcessor = UnicastProcessor.create(Queues.<FrontendMessage>small().get());
FluxSink<FrontendMessage> requestsSink = requestsProcessor.sink();
AtomicBoolean isCanceled = new AtomicBoolean(false);

-    messagesToSend.add(new Execute(portal, fetchSize));
-    messagesToSend.add(Flush.INSTANCE);
+    MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Flush.INSTANCE));

-    return client.exchange(Flux.<FrontendMessage>just(new CompositeFrontendMessage(messagesToSend)).concatWith(requestsProcessor))
+    return client.exchange(operator.takeUntil(), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requestsProcessor))
+        .handle(handleReprepare(requestsSink, operator, factory))
.handle((BackendMessage message, SynchronousSink<BackendMessage> sink) -> {

if (message instanceof CommandComplete) {
@@ -258,8 +252,154 @@ private static Flux<BackendMessage> fetchCursoredWithFlush(List<FrontendMessage.
} else {
sink.next(message);
}
-    }).doFinally(ignore -> requestsSink.complete())
+    }).doFinally(ignore -> operator.close(requestsSink))
.as(flux -> Operators.discardOnCancel(flux, () -> isCanceled.set(true)));
}

private static BiConsumer<BackendMessage, SynchronousSink<BackendMessage>> handleReprepare(FluxSink<FrontendMessage> requests, ExtendedFlowOperator operator, MessageFactory messageFactory) {

AtomicBoolean reprepared = new AtomicBoolean();

return (message, sink) -> {

if (message instanceof ErrorResponse && requiresReprepare((ErrorResponse) message) && reprepared.compareAndSet(false, true)) {

operator.evictCachedStatement();

List<FrontendMessage.DirectEncoder> messages = messageFactory.createMessages();
if (!messages.contains(Sync.INSTANCE)) {
messages.add(0, Sync.INSTANCE);
}
requests.next(new CompositeFrontendMessage(messages));
} else {
sink.next(message);
}
};
}

private static boolean requiresReprepare(ErrorResponse errorResponse) {

ErrorDetails details = new ErrorDetails(errorResponse.getFields());
String code = details.getCode();

// "prepared statement \"S_2\" does not exist"
// INVALID_SQL_STATEMENT_NAME
if ("26000".equals(code)) {
return true;
}
// NOT_IMPLEMENTED
if (!"0A000".equals(code)) {
return false;
}

String routine = details.getRoutine().orElse(null);
// "cached plan must not change result type"
return "RevalidateCachedQuery".equals(routine) // 9.2+
|| "RevalidateCachedPlan".equals(routine); // <= 9.1
}

interface MessageFactory {

List<FrontendMessage.DirectEncoder> createMessages();

}

/**
* Operator to encapsulate common activity around the extended flow. Subclasses {@link AtomicInteger} to capture the number of ReadyForQuery frames.
*/
static class ExtendedFlowOperator extends AtomicInteger {

private final String sql;

private final Binding binding;

@Nullable
private volatile String name;

private final StatementCache cache;

private final List<ByteBuf> values;

private final String portal;

private final boolean forceBinary;

public ExtendedFlowOperator(String sql, Binding binding, StatementCache cache, List<ByteBuf> values, String portal, boolean forceBinary) {
this.sql = sql;
this.binding = binding;
this.cache = cache;
this.values = values;
this.portal = portal;
this.forceBinary = forceBinary;
set(1);
}

public void close(FluxSink<FrontendMessage> requests) {
requests.complete();
this.values.forEach(ReferenceCountUtil::release);
}

public void evictCachedStatement() {

incrementAndGet();

synchronized (this) {
this.name = null;
}
this.cache.evict(this.sql);
}

public void hydrateStatementCache() {
this.cache.put(this.binding, this.sql, getStatementName());
}

public Predicate<BackendMessage> takeUntil() {
return m -> {

if (m instanceof ReadyForQuery) {
return decrementAndGet() <= 0;
}

return false;
};
}

private boolean isPrepareRequired() {
return this.cache.requiresPrepare(this.binding, this.sql);
}

public String getStatementName() {
synchronized (this) {

if (this.name == null) {
this.name = this.cache.getName(this.binding, this.sql);
}
return this.name;
}
}

public List<FrontendMessage.DirectEncoder> getMessages(Collection<FrontendMessage.DirectEncoder> append) {
List<FrontendMessage.DirectEncoder> messagesToSend = new ArrayList<>(6);

if (isPrepareRequired()) {
messagesToSend.add(new Parse(getStatementName(), this.binding.getParameterTypes(), this.sql));
}

for (ByteBuf value : this.values) {
value.readerIndex(0);
value.touch("ExtendedFlowOperator").retain();
}

Bind bind = new Bind(this.portal, this.binding.getParameterFormats(), this.values, ExtendedQueryMessageFlow.resultFormat(this.forceBinary), getStatementName());

messagesToSend.add(bind);
messagesToSend.add(new Describe(this.portal, PORTAL));
messagesToSend.addAll(append);

return messagesToSend;
}

}

}