Skip to content

Commit

Permalink
Try to fix windows flaky test (#792)
Browse files Browse the repository at this point in the history
Change-Id: I9178a87fce442e9a99fb2181d5a2428ed7b194b4
  • Loading branch information
frankfliu authored Mar 26, 2021
1 parent 7f66ab1 commit ffbf983
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 75 deletions.
4 changes: 0 additions & 4 deletions serving/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@ jar {
exclude "META-INF/MANIFEST*"
}

test {
maxParallelForks = 1
}

application {
mainClassName = System.getProperty("main", "ai.djl.serving.ModelServer")
}
Expand Down
8 changes: 4 additions & 4 deletions serving/src/test/java/ai/djl/serving/ConfigManagerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.testng.Assert;
import org.testng.annotations.Test;

public class ConfigManagerTest {
public final class ConfigManagerTest {

@Test
public void testSsl()
private ConfigManagerTest() {}

public static void testSsl()
throws IOException, GeneralSecurityException, ParseException,
ReflectiveOperationException {
ConfigManager.init(parseArguments(new String[0]));
Expand Down
112 changes: 45 additions & 67 deletions serving/src/test/java/ai/djl/serving/ModelServerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@ public class ModelServerTest {
private ConfigManager configManager;
private ModelServer server;
private byte[] testImage;
CountDownLatch latch;
HttpResponseStatus httpStatus;
String result;
HttpHeaders headers;
volatile CountDownLatch latch;
volatile HttpResponseStatus httpStatus;
volatile String result;
volatile HttpHeaders headers;

static {
try {
Expand Down Expand Up @@ -134,7 +134,8 @@ public void afterSuite() {
@Test
public void test()
throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException,
IOException {
IOException, ParseException, GeneralSecurityException,
ReflectiveOperationException {
Assert.assertTrue(server.isRunning());

Channel channel = null;
Expand Down Expand Up @@ -185,11 +186,12 @@ public void test()
testRegisterModelNotFound();
testRegisterModelConflict();
testServiceUnavailable();

ConfigManagerTest.testSsl();
}

private void testRoot(Channel channel) throws InterruptedException {
result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/");
channel.writeAndFlush(req).sync();
latch.await();
Expand All @@ -198,8 +200,7 @@ private void testRoot(Channel channel) throws InterruptedException {
}

private void testPing(Channel channel) throws InterruptedException {
result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/ping");
channel.writeAndFlush(req);
latch.await();
Expand All @@ -210,8 +211,7 @@ private void testPing(Channel channel) throws InterruptedException {
}

private void testPredictions(Channel channel) throws InterruptedException {
result = null;
latch = new CountDownLatch(1);
reset();
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/mlp");
Expand All @@ -228,8 +228,7 @@ private void testPredictions(Channel channel) throws InterruptedException {
}

private void testInvocations(Channel channel) throws InterruptedException {
result = null;
latch = new CountDownLatch(1);
reset();
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/invocations");
req.content().writeBytes(testImage);
Expand All @@ -246,8 +245,7 @@ private void testInvocations(Channel channel) throws InterruptedException {
private void testInvocationsMultipart(Channel channel)
throws InterruptedException, HttpPostRequestEncoder.ErrorDataEncoderException,
IOException {
result = null;
latch = new CountDownLatch(1);
reset();
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.POST, "/invocations?model_name=mlp");
Expand Down Expand Up @@ -275,9 +273,7 @@ private void testInvocationsMultipart(Channel channel)

private void testRegisterModelAsync(Channel channel)
throws InterruptedException, UnsupportedEncodingException {
result = null;
latch = new CountDownLatch(1);

reset();
String url = "https://resources.djl.ai/test-models/mlp.tar.gz";
HttpRequest req =
new DefaultFullHttpRequest(
Expand All @@ -296,9 +292,7 @@ private void testRegisterModelAsync(Channel channel)
for (int i = 0; i < 5; ++i) {
String token = "";
while (token != null) {
result = null;
latch = new CountDownLatch(1);

reset();
req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1,
Expand All @@ -324,8 +318,7 @@ private void testRegisterModelAsync(Channel channel)

private void testRegisterModel(Channel channel)
throws InterruptedException, UnsupportedEncodingException {
result = null;
latch = new CountDownLatch(1);
reset();

String url = "https://resources.djl.ai/test-models/mlp.tar.gz";
HttpRequest req =
Expand All @@ -342,8 +335,7 @@ private void testRegisterModel(Channel channel)
}

private void testScaleModel(Channel channel) throws InterruptedException {
result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1,
Expand All @@ -359,8 +351,7 @@ private void testScaleModel(Channel channel) throws InterruptedException {
}

private void testDescribeModel(Channel channel) throws InterruptedException {
result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/models/mlp_2");
channel.writeAndFlush(req);
Expand All @@ -384,8 +375,7 @@ private void testDescribeModel(Channel channel) throws InterruptedException {
}

private void testUnregisterModel(Channel channel) throws InterruptedException {
result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.DELETE, "/models/mlp_1");
Expand All @@ -397,8 +387,7 @@ private void testUnregisterModel(Channel channel) throws InterruptedException {
}

private void testDescribeApi(Channel channel) throws InterruptedException {
result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/predictions/mlp");
Expand All @@ -409,8 +398,7 @@ private void testDescribeApi(Channel channel) throws InterruptedException {
}

private void testPredictionsInvalidRequestSize(Channel channel) throws InterruptedException {
result = null;
latch = new CountDownLatch(1);
reset();
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/mlp");
Expand All @@ -429,8 +417,7 @@ private void testInvalidRootRequest() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.INFERENCE);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/");
channel.writeAndFlush(req).sync();
latch.await();
Expand All @@ -447,8 +434,7 @@ private void testInvalidUri() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.INFERENCE);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/InvalidUrl");
channel.writeAndFlush(req).sync();
Expand All @@ -466,8 +452,7 @@ private void testInvalidDescribeModel() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.INFERENCE);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "/predictions/InvalidModel");
Expand All @@ -486,8 +471,7 @@ private void testInvalidPredictionsUri() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.INFERENCE);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/predictions");
channel.writeAndFlush(req).sync();
Expand All @@ -505,8 +489,7 @@ private void testPredictionsModelNotFound() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.INFERENCE);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.GET, "/predictions/InvalidModel");
Expand All @@ -525,8 +508,7 @@ private void testInvalidManagementUri() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.MANAGEMENT);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/InvalidUrl");
channel.writeAndFlush(req).sync();
Expand All @@ -544,8 +526,7 @@ private void testInvalidManagementMethod() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.MANAGEMENT);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "/models");
channel.writeAndFlush(req).sync();
Expand All @@ -563,8 +544,7 @@ private void testInvalidPredictionsMethod() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.MANAGEMENT);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/models/noop");
channel.writeAndFlush(req).sync();
Expand All @@ -582,8 +562,7 @@ private void testDescribeModelNotFound() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.MANAGEMENT);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.GET, "/models/InvalidModel");
Expand All @@ -602,8 +581,7 @@ private void testRegisterModelMissingUrl() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.MANAGEMENT);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/models");
channel.writeAndFlush(req).sync();
Expand All @@ -621,8 +599,7 @@ private void testRegisterModelNotFound() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.MANAGEMENT);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.POST, "/models?url=InvalidUrl");
Expand All @@ -643,8 +620,7 @@ private void testRegisterModelConflict()
Channel channel = connect(Connector.ConnectorType.MANAGEMENT);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
String url = "https://resources.djl.ai/test-models/mlp.tar.gz";
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
Expand All @@ -666,8 +642,7 @@ private void testInvalidScaleModel() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.MANAGEMENT);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1,
Expand All @@ -688,8 +663,7 @@ private void testScaleModelNotFound() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.MANAGEMENT);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "/models/fake");
channel.writeAndFlush(req).sync();
Expand All @@ -707,8 +681,7 @@ private void testUnregisterModelNotFound() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.MANAGEMENT);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
HttpRequest req =
new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.DELETE, "/models/fake");
channel.writeAndFlush(req).sync();
Expand All @@ -726,8 +699,7 @@ private void testServiceUnavailable() throws InterruptedException {
Channel channel = connect(Connector.ConnectorType.MANAGEMENT);
Assert.assertNotNull(channel);

result = null;
latch = new CountDownLatch(1);
reset();
DefaultFullHttpRequest req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1,
Expand All @@ -736,8 +708,7 @@ private void testServiceUnavailable() throws InterruptedException {
channel.writeAndFlush(req);
latch.await();

result = null;
latch = new CountDownLatch(1);
reset();
req =
new DefaultFullHttpRequest(
HttpVersion.HTTP_1_1, HttpMethod.POST, "/predictions/mlp_2");
Expand Down Expand Up @@ -793,6 +764,13 @@ public void initChannel(Channel ch) {
throw new AssertionError("Failed connect to model server.");
}

private void reset() {
result = null;
httpStatus = null;
headers = null;
latch = new CountDownLatch(1);
}

@ChannelHandler.Sharable
private class TestHandler extends SimpleChannelInboundHandler<FullHttpResponse> {

Expand Down

0 comments on commit ffbf983

Please sign in to comment.