Skip to content

Commit

Permalink
Issue #10933 - Fix AsyncIOServlet test issues (#10949)
Browse files Browse the repository at this point in the history
* Call ServletChannelState.asyncFailure from error listener. Fix #10933
* Separate invokers for read side and write side
* document async error issues
* updates from review
* updates from review
  • Loading branch information
gregw authored Dec 14, 2023
1 parent 2812023 commit f776d3e
Show file tree
Hide file tree
Showing 12 changed files with 169 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -212,10 +212,6 @@ static String getPathInContext(Request request)
/**
* {@inheritDoc}
* @param demandCallback the demand callback to invoke when there is a content chunk available.
* In addition to the invocation guarantees of {@link Content.Source#demand(Runnable)},
* this implementation serializes the invocation of the {@code Runnable} with
* invocations of any {@link Response#write(boolean, ByteBuffer, Callback)}
* {@code Callback} invocations.
* @see Content.Source#demand(Runnable)
*/
@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,6 @@ public interface Response extends Content.Sink
* has returned.</p>
* <p>Thus a {@code Callback} should not block waiting for a callback
* of a future call to this method.</p>
* <p>Furthermore, the invocation of the passed callback is serialized
* with invocations of the {@link Runnable} demand callback passed to
* {@link Request#demand(Runnable)}.</p>
*
* @param last whether the ByteBuffer is the last to write
* @param byteBuffer the ByteBuffer to write
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ private enum StreamSendState
private final AutoLock _lock = new AutoLock();
private final HandlerInvoker _handlerInvoker = new HandlerInvoker();
private final ConnectionMetaData _connectionMetaData;
private final SerializedInvoker _serializedInvoker;
private final SerializedInvoker _readInvoker;
private final SerializedInvoker _writeInvoker;
private final ResponseHttpFields _responseHeaders = new ResponseHttpFields();
private Thread _handling;
private boolean _handled;
Expand All @@ -122,7 +123,8 @@ public HttpChannelState(ConnectionMetaData connectionMetaData)
{
_connectionMetaData = connectionMetaData;
// The SerializedInvoker is used to prevent infinite recursion of callbacks calling methods calling callbacks etc.
_serializedInvoker = new HttpChannelSerializedInvoker();
_readInvoker = new HttpChannelSerializedInvoker();
_writeInvoker = new HttpChannelSerializedInvoker();
}

@Override
Expand Down Expand Up @@ -298,7 +300,7 @@ public Runnable onContentAvailable()
onContent = _onContentAvailable;
_onContentAvailable = null;
}
return _serializedInvoker.offer(onContent);
return _readInvoker.offer(onContent);
}

@Override
Expand Down Expand Up @@ -341,13 +343,13 @@ public Runnable onIdleTimeout(TimeoutException t)

// If there was a pending IO operation, deliver the idle timeout via them.
if (invokeOnContentAvailable != null || invokeWriteFailure != null)
return _serializedInvoker.offer(invokeOnContentAvailable, invokeWriteFailure);
return Invocable.combine(_readInvoker.offer(invokeOnContentAvailable), _writeInvoker.offer(invokeWriteFailure));

// Otherwise, if there are idle timeout listeners, ask them whether we should call onFailure.
Predicate<TimeoutException> onIdleTimeout = _onIdleTimeout;
if (onIdleTimeout != null)
{
return _serializedInvoker.offer(() ->
return () ->
{
if (onIdleTimeout.test(t))
{
Expand All @@ -356,7 +358,7 @@ public Runnable onIdleTimeout(TimeoutException t)
if (task != null)
task.run();
}
});
};
}
}

Expand Down Expand Up @@ -426,7 +428,7 @@ public Runnable onFailure(Throwable x)
};

// Serialize all the error actions.
task = _serializedInvoker.offer(invokeOnContentAvailable, invokeWriteFailure, invokeOnFailureListeners);
task = Invocable.combine(_readInvoker.offer(invokeOnContentAvailable), _writeInvoker.offer(invokeWriteFailure), invokeOnFailureListeners);
}
}

Expand Down Expand Up @@ -912,7 +914,7 @@ public void demand(Runnable demandCallback)

if (error)
{
httpChannelState._serializedInvoker.run(demandCallback);
httpChannelState._readInvoker.run(demandCallback);
}
else if (interimCallback == null)
{
Expand Down Expand Up @@ -1189,14 +1191,14 @@ else if (last && !(totalWritten == 0 && HttpMethod.HEAD.is(_request.getMethod())

if (writeFailure == NOTHING_TO_SEND)
{
httpChannelState._serializedInvoker.run(callback::succeeded);
httpChannelState._writeInvoker.run(callback::succeeded);
return;
}
// Have we failed in some way?
if (writeFailure != null)
{
Throwable failure = writeFailure;
httpChannelState._serializedInvoker.run(() -> callback.failed(failure));
httpChannelState._writeInvoker.run(() -> callback.failed(failure));
return;
}

Expand Down Expand Up @@ -1235,7 +1237,7 @@ public void succeeded()
httpChannel.lockedStreamSendCompleted(true);
}
if (callback != null)
httpChannel._serializedInvoker.run(callback::succeeded);
httpChannel._writeInvoker.run(callback::succeeded);
}

/**
Expand Down Expand Up @@ -1263,7 +1265,7 @@ public void failed(Throwable x)
httpChannel.lockedStreamSendCompleted(false);
}
if (callback != null)
httpChannel._serializedInvoker.run(() -> callback.failed(x));
httpChannel._writeInvoker.run(() -> callback.failed(x));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1220,15 +1220,15 @@ public boolean handle(Request request, Response response, Callback callback)
assertThat(chunk.getFailure(), sameInstance(failure));

CountDownLatch demand = new CountDownLatch(1);
// Demand callback serialized until after onFailure listeners.
// Demand callback not serialized until after onFailure listeners.
rq.demand(demand::countDown);
assertThat(demand.getCount(), is(1L));
assertThat(demand.getCount(), is(0L));

FuturePromise<Throwable> callback = new FuturePromise<>();
// Write callback serialized until after onFailure listeners.
// Write callback not serialized until after onFailure listeners.
handling.get().write(false, null, Callback.from(() ->
{}, callback::succeeded));
assertFalse(callback.isDone());
assertTrue(callback.isDone());

// Process onFailure task.
try (StacklessLogging ignore = new StacklessLogging(Response.class))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public boolean handle(Request request, Response response, Callback callback)

int expectedStatus = succeedCallback ? HttpStatus.OK_200 : HttpStatus.INTERNAL_SERVER_ERROR_500;
assertEquals(expectedStatus, response.getStatus());
assertThat(failureLatch.await(1, TimeUnit.SECONDS), is(failIdleTimeout));
assertThat(failureLatch.await(idleTimeout + 500, TimeUnit.MILLISECONDS), is(failIdleTimeout && !succeedCallback));
}

@ParameterizedTest
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public String toString()
* @param task the Runnable
* @return a new Task
*/
public static Task from(InvocationType type, Runnable task)
static Task from(InvocationType type, Runnable task)
{
return new ReadyTask(type, task);
}
Expand Down Expand Up @@ -202,4 +202,43 @@ default InvocationType getInvocationType()
{
return InvocationType.BLOCKING;
}

/**
* Combine {@link Runnable}s into a single {@link Runnable} that sequentially calls the others.
* @param runnables the {@link Runnable}s to combine
* @return the combined {@link Runnable} with a combined {@link InvocationType}.
*/
static Runnable combine(Runnable... runnables)
{
Runnable result = null;
for (Runnable runnable : runnables)
{
if (runnable == null)
continue;
if (result == null)
{
result = runnable;
}
else
{
Runnable first = result;
result = new Task()
{
@Override
public void run()
{
first.run();
runnable.run();
}

@Override
public InvocationType getInvocationType()
{
return combine(Invocable.getInvocationType(first), Invocable.getInvocationType(runnable));
}
};
}
}
return result;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
//
// ========================================================================
// Copyright (c) 1995 Mort Bay Consulting Pty Ltd and others.
//
// This program and the accompanying materials are made available under the
// terms of the Eclipse Public License v. 2.0 which is available at
// https://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
// which is available at https://www.apache.org/licenses/LICENSE-2.0.
//
// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
// ========================================================================
//

package org.eclipse.jetty.util.thread;

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

import org.junit.jupiter.api.Test;

import static org.eclipse.jetty.util.thread.Invocable.InvocationType.BLOCKING;
import static org.eclipse.jetty.util.thread.Invocable.InvocationType.EITHER;
import static org.eclipse.jetty.util.thread.Invocable.InvocationType.NON_BLOCKING;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.hamcrest.Matchers.sameInstance;

public class InvocableTest
{
@Test
public void testCombineType()
{
assertThat(Invocable.combine(null, null), is(BLOCKING));
assertThat(Invocable.combine(null, BLOCKING), is(BLOCKING));
assertThat(Invocable.combine(null, NON_BLOCKING), is(BLOCKING));
assertThat(Invocable.combine(null, EITHER), is(BLOCKING));

assertThat(Invocable.combine(BLOCKING, null), is(BLOCKING));
assertThat(Invocable.combine(BLOCKING, BLOCKING), is(BLOCKING));
assertThat(Invocable.combine(BLOCKING, NON_BLOCKING), is(BLOCKING));
assertThat(Invocable.combine(BLOCKING, EITHER), is(BLOCKING));

assertThat(Invocable.combine(NON_BLOCKING, null), is(BLOCKING));
assertThat(Invocable.combine(NON_BLOCKING, BLOCKING), is(BLOCKING));
assertThat(Invocable.combine(NON_BLOCKING, NON_BLOCKING), is(NON_BLOCKING));
assertThat(Invocable.combine(NON_BLOCKING, EITHER), is(NON_BLOCKING));

assertThat(Invocable.combine(EITHER, null), is(BLOCKING));
assertThat(Invocable.combine(EITHER, BLOCKING), is(BLOCKING));
assertThat(Invocable.combine(EITHER, NON_BLOCKING), is(NON_BLOCKING));
assertThat(Invocable.combine(EITHER, EITHER), is(EITHER));
}

@Test
public void testCombineRunnable()
{
Queue<String> history = new ConcurrentLinkedQueue<>();

assertThat(Invocable.combine(), nullValue());
assertThat(Invocable.combine((Runnable)null), nullValue());
assertThat(Invocable.combine(null, (Runnable)null), nullValue());

Runnable r1 = () -> history.add("R1");
Runnable r2 = () -> history.add("R2");
Runnable r3 = () -> history.add("R3");

assertThat(Invocable.combine(r1, null, null), sameInstance(r1));
assertThat(Invocable.combine(null, r2, null), sameInstance(r2));
assertThat(Invocable.combine(null, null, r3), sameInstance(r3));

Runnable r13 = Invocable.combine(r1, null, r3);
history.clear();
r13.run();
assertThat(history, contains("R1", "R3"));

Runnable r123 = Invocable.combine(r1, r2, r3);
history.clear();
r123.run();
assertThat(history, contains("R1", "R2", "R3"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,13 @@ public void handle()
// be dispatched to an error page, so we delegate this responsibility to the ErrorHandler.
reopen();
_state.errorHandling();

// TODO We currently directly call the errorHandler here, but this is not correct in the case of async errors,
// because since a failure has already occurred, the errorHandler is unable to write a response.
// Instead, we should fail the callback, so that it calls Response.writeError(...) with an ErrorResponse
// that ignores existing failures. However, the error handler needs to be able to call servlet pages,
// so it will need to do a new call to associate(req,res,callback) or similar, to make the servlet request and
// response wrap the error request and response. Have to think about what callback is passed.
errorHandler.handle(getServletContextRequest(), getServletContextResponse(), Callback.from(_state::errorHandlingComplete));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ public enum Action
private long _timeoutMs = DEFAULT_TIMEOUT;
private AsyncContextEvent _event;
private Thread _onTimeoutThread;
private boolean _failureListener;

protected ServletChannelState(ServletChannel servletChannel)
{
Expand Down Expand Up @@ -511,6 +512,11 @@ public void startAsync(AsyncContextEvent event)
if (_state != State.HANDLING || (_requestState != RequestState.BLOCKING && _requestState != RequestState.ERRORING))
throw new IllegalStateException(this.getStatusStringLocked());

if (!_failureListener)
{
_failureListener = true;
_servletChannel.getRequest().addFailureListener(this::asyncError);
}
_requestState = RequestState.ASYNC;
_event = event;
lastAsyncListeners = _asyncListeners;
Expand Down Expand Up @@ -1099,6 +1105,7 @@ protected void recycle()
_asyncWritePossible = false;
_timeoutMs = DEFAULT_TIMEOUT;
_event = null;
_failureListener = false;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

@Disabled
public class AsyncIOServletTest extends AbstractTest
{
private static final ThreadLocal<RuntimeException> scope = new ThreadLocal<>();
Expand Down Expand Up @@ -1081,6 +1080,7 @@ public void onComplete(Result result)

@ParameterizedTest
@MethodSource("transportsNoFCGI")
@Disabled // TODO Cannot write response from onError as failure has occurred
public void testAsyncReadEarlyEOF(Transport transport) throws Exception
{
// SSLEngine receives the close alert from the client, and when
Expand Down Expand Up @@ -1197,8 +1197,8 @@ public void onError(Throwable x)
}

@ParameterizedTest
@MethodSource("transportsNoFCGI")
public void testAsyncEcho(Transport transport) throws Exception
@MethodSource("transports")
public void testAsyncReadEcho(Transport transport) throws Exception
{
// TODO: investigate why H3 does not work.
Assumptions.assumeTrue(transport != Transport.H3);
Expand All @@ -1208,8 +1208,6 @@ public void testAsyncEcho(Transport transport) throws Exception
@Override
protected void service(HttpServletRequest request, HttpServletResponse response) throws IOException
{
System.err.println("service " + request);

AsyncContext asyncContext = request.startAsync();
ServletInputStream input = request.getInputStream();
input.setReadListener(new ReadListener()
Expand All @@ -1222,7 +1220,6 @@ public void onDataAvailable() throws IOException
int b = input.read();
if (b >= 0)
{
// System.err.printf("0x%2x %s %n", b, Character.isISOControl(b)?"?":(""+(char)b));
response.getOutputStream().write(b);
}
else
Expand Down
Loading

0 comments on commit f776d3e

Please sign in to comment.