Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix potential hang in VertxInputStream #12568

Merged
merged 1 commit into from
Oct 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package io.quarkus.resteasy.test;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.util.Timer;
import java.util.TimerTask;
import java.util.concurrent.LinkedBlockingDeque;

import javax.annotation.PreDestroy;
import javax.enterprise.event.Observes;
import javax.ws.rs.POST;
import javax.ws.rs.Path;

import io.vertx.core.Handler;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;

@Path("/in")
public class InputStreamResource {

Timer timer = new Timer();

public static final LinkedBlockingDeque<Throwable> THROWABLES = new LinkedBlockingDeque<>();

@PreDestroy
void stop() {
timer.cancel();
}

@POST
public String read(InputStream inputStream) throws IOException {
try {
byte[] buf = new byte[1024];
int r;
ByteArrayOutputStream out = new ByteArrayOutputStream();
while ((r = inputStream.read(buf)) > 0) {
out.write(buf, 0, r);
}
return new String(out.toByteArray(), StandardCharsets.UTF_8);
} catch (IOException e) {
THROWABLES.add(e);
throw e;
}
}

public void delayFilter(@Observes Router router) {
router.route().order(Integer.MIN_VALUE).handler(new Handler<RoutingContext>() {
@Override
public void handle(RoutingContext event) {
timer.schedule(new TimerTask() {
@Override
public void run() {
event.next();
}
}, 1000);
}
});
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package io.quarkus.resteasy.test;

import java.io.IOException;
import java.net.Socket;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.TimeUnit;

import org.hamcrest.Matchers;
import org.jboss.shrinkwrap.api.ShrinkWrap;
import org.jboss.shrinkwrap.api.spec.JavaArchive;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.restassured.RestAssured;

public class VertxIOHangTestCase {

@RegisterExtension
static QuarkusUnitTest runner = new QuarkusUnitTest()
.setArchiveProducer(() -> ShrinkWrap.create(JavaArchive.class)
.addClasses(InputStreamResource.class));

@TestHTTPResource
URI uri;

@Test
public void testDelayFilter() {
// makes sure that everything works as normal
RestAssured.given().body("hello world").post("/in").then().body(Matchers.is("hello world"));
}

@Test
public void testDelayFilterConnectionKilled() throws Exception {
// makes sure that everything works as normal
try (Socket s = new Socket(uri.getHost(), uri.getPort())) {
s.getOutputStream().write(
"POST /in HTTP/1.1\r\nHost:localhost\r\nContent-Length: 100\r\n\r\n".getBytes(StandardCharsets.UTF_8));
s.getOutputStream().flush();
}
Throwable exception = InputStreamResource.THROWABLES.poll(3, TimeUnit.SECONDS);
Assertions.assertTrue(exception instanceof IOException);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.InterruptedIOException;
import java.nio.channels.ClosedChannelException;
import java.util.ArrayDeque;
import java.util.Deque;

Expand All @@ -15,6 +16,7 @@
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.core.net.impl.ConnectionBase;
import io.vertx.ext.web.RoutingContext;

public class VertxInputStream extends InputStream {
Expand All @@ -27,7 +29,6 @@ public class VertxInputStream extends InputStream {
private final long limit;

public VertxInputStream(RoutingContext request, long timeout) throws IOException {

this.exchange = new VertxBlockingInput(request.request(), timeout);
Long limitObj = request.get(VertxHttpRecorder.MAX_REQUEST_SIZE_KEY);
if (limitObj == null) {
Expand Down Expand Up @@ -153,46 +154,51 @@ public static class VertxBlockingInput implements Handler<Buffer> {
public VertxBlockingInput(HttpServerRequest request, long timeout) throws IOException {
this.request = request;
this.timeout = timeout;
if (!request.isEnded()) {
request.pause();
request.handler(this);
request.endHandler(new Handler<Void>() {
@Override
public void handle(Void event) {
synchronized (request.connection()) {
eof = true;
if (waiting) {
request.connection().notify();
final ConnectionBase connection = (ConnectionBase) request.connection();
synchronized (connection) {
if (!connection.channel().isOpen()) {
readException = new ClosedChannelException();
} else if (!request.isEnded()) {
request.pause();
request.handler(this);
request.endHandler(new Handler<Void>() {
@Override
public void handle(Void event) {
synchronized (connection) {
eof = true;
if (waiting) {
connection.notify();
}
}
}
}
});
request.exceptionHandler(new Handler<Throwable>() {
@Override
public void handle(Throwable event) {
synchronized (request.connection()) {
readException = new IOException(event);
if (input1 != null) {
input1.getByteBuf().release();
input1 = null;
}
if (inputOverflow != null) {
Buffer d = inputOverflow.poll();
while (d != null) {
d.getByteBuf().release();
d = inputOverflow.poll();
});
request.exceptionHandler(new Handler<Throwable>() {
@Override
public void handle(Throwable event) {
synchronized (connection) {
readException = new IOException(event);
if (input1 != null) {
input1.getByteBuf().release();
input1 = null;
}
if (inputOverflow != null) {
Buffer d = inputOverflow.poll();
while (d != null) {
d.getByteBuf().release();
d = inputOverflow.poll();
}
}
if (waiting) {
connection.notify();
}
}
if (waiting) {
request.connection().notify();
}
}
}

});
request.fetch(1);
} else {
eof = true;
});
request.fetch(1);
} else {
eof = true;
}
}
}

Expand Down