diff --git a/incubating/wrappers/java/pom.xml b/incubating/wrappers/java/pom.xml
index 6b0bdcbbf8..cd47fac7b2 100644
--- a/incubating/wrappers/java/pom.xml
+++ b/incubating/wrappers/java/pom.xml
@@ -6,7 +6,7 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xs
io.seldon.wrapper
seldon-core-wrapper
jar
- 0.4.1
+ 0.4.2
Seldon Core Java Wrapper
http://maven.apache.org
Wrapper for seldon-core Java prediction models.
diff --git a/incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelRestController.java b/incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelPredictionController.java
similarity index 55%
rename from incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelRestController.java
rename to incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelPredictionController.java
index 0db3625b0e..6609fa9bb5 100644
--- a/incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelRestController.java
+++ b/incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelPredictionController.java
@@ -1,5 +1,13 @@
package io.seldon.wrapper.api;
+/**
+ * This is a model prediction API and container everything related to predictions.
+ *
+ * NOTE:
+ * This is NOT a RestFull API, since there is no resource or state (aka there is no prediction object or state)
+ * involved.
+ */
+
import com.google.protobuf.InvalidProtocolBufferException;
import io.seldon.protos.PredictionProtos.Feedback;
import io.seldon.protos.PredictionProtos.SeldonMessage;
@@ -11,32 +19,79 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression;
import org.springframework.http.HttpStatus;
+import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
-import org.springframework.web.bind.annotation.RequestMapping;
-import org.springframework.web.bind.annotation.RequestMethod;
-import org.springframework.web.bind.annotation.RequestParam;
-import org.springframework.web.bind.annotation.RestController;
+import org.springframework.web.bind.annotation.*;
@RestController
@ConditionalOnExpression("${seldon.api.model.enabled:false}")
-public class ModelRestController {
- private static Logger logger = LoggerFactory.getLogger(ModelRestController.class.getName());
+public class ModelPredictionController {
+ private static Logger logger = LoggerFactory.getLogger(ModelPredictionController.class.getName());
@Autowired SeldonPredictionService predictionService;
+ /**
+ * Will accept a POST or a GET request with either a query parameter or a FORM parameter.
+ *
+ * Examples:
+ * GET -> /predict?json={ ... }
+ * curl -s \
+ * localhost:9000/predict?json={"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \
+ *
+ * POST FORM -> /predict
+ * curl -s -X POST \
+ * -d 'json={"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \
+ * localhost:9000/predict
+ *
+ * @param json
+ * @return
+ * @deprecated
+ */
+ @Deprecated
@RequestMapping(
value = "/predict",
method = {RequestMethod.GET, RequestMethod.POST},
- produces = "application/json; charset=utf-8")
- public ResponseEntity predictions(@RequestParam("json") String json) {
+ params = {"json"},
+ produces = MediaType.APPLICATION_JSON_UTF8_VALUE
+ )
+ public ResponseEntity predictLegacy(@RequestParam("json") String json) {
+ return this.predict(json);
+ }
+
+ /**
+ * Will accept a POST with a proper JSON body.
+ *
+ * Examples:
+ * POST -> /predict
+ * curl -s -X POST \
+ * -d '{"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \
+ * localhost:9000/predict
+ *
+ * curl -s -X POST \
+ * -d '{"jsonData": {"foo": "bar"}' \
+ * localhost:9000/predict
+ *
+ * @param jsonStr
+ * @return
+ */
+ @RequestMapping(
+ value = "/predict",
+ method = {RequestMethod.POST},
+ consumes = {
+ MediaType.APPLICATION_JSON_VALUE,
+ MediaType.APPLICATION_JSON_UTF8_VALUE
+ },
+ produces = MediaType.APPLICATION_JSON_VALUE
+ )
+ public ResponseEntity predict(@RequestBody String jsonStr) {
SeldonMessage request;
try {
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
- ProtoBufUtils.updateMessageBuilderFromJson(builder, json);
+ ProtoBufUtils.updateMessageBuilderFromJson(builder, jsonStr);
request = builder.build();
} catch (InvalidProtocolBufferException e) {
logger.error("Bad request", e);
- throw new APIException(ApiExceptionType.WRAPPER_INVALID_MESSAGE, json);
+ throw new APIException(ApiExceptionType.WRAPPER_INVALID_MESSAGE, jsonStr);
}
try {
diff --git a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/ModelPredictionControllerTest.java b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/ModelPredictionControllerTest.java
new file mode 100644
index 0000000000..35c14907c9
--- /dev/null
+++ b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/ModelPredictionControllerTest.java
@@ -0,0 +1,245 @@
+package io.seldon.wrapper.api;
+
+import static io.seldon.wrapper.util.TestUtils.readFile;
+import static org.hamcrest.Matchers.*;
+import static org.junit.Assert.assertNotNull;
+import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
+import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;
+
+import java.nio.charset.StandardCharsets;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.FixMethodOrder;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.MethodSorters;
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.boot.test.context.SpringBootTest;
+import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
+import org.springframework.boot.web.server.LocalServerPort;
+import org.springframework.http.MediaType;
+import org.springframework.test.context.junit4.SpringRunner;
+import org.springframework.test.web.servlet.MockMvc;
+import org.springframework.test.web.servlet.MvcResult;
+import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
+import org.springframework.test.web.servlet.setup.MockMvcBuilders;
+import org.springframework.web.context.WebApplicationContext;
+
+@RunWith(SpringRunner.class)
+@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT)
+@FixMethodOrder(MethodSorters.NAME_ASCENDING)
+// @AutoConfigureMockMvc
+public class ModelPredictionControllerTest {
+
+ @Autowired private WebApplicationContext context;
+
+ @Autowired
+ ModelPredictionController modelPredictionController;
+
+ // @Autowired
+ private MockMvc mvc;
+
+ @Before
+ public void setup() {
+ mvc = MockMvcBuilders.webAppContextSetup(context).build();
+ }
+
+ @LocalServerPort private int port;
+
+ @Test
+ public void testPredictLegacyGetQuery() throws Exception {
+ final String predictJson = TestMessages.DEFAULT_DATA;
+ assertNotNull(predictJson);
+
+ MvcResult res =
+ mvc.perform(
+ MockMvcRequestBuilders.get("/predict")
+ .accept(MediaType.APPLICATION_JSON_UTF8)
+ .param("json", predictJson)
+ .contentType(MediaType.APPLICATION_JSON_UTF8))
+ .andReturn();
+
+ String response = res.getResponse().getContentAsString();
+ System.out.println(response);
+ Assert.assertEquals(200, res.getResponse().getStatus());
+ }
+
+ @Test
+ public void testPredictLegacyPostQuery() throws Exception {
+ final String predictJson = TestMessages.DEFAULT_DATA;
+ assertNotNull(predictJson);
+
+ MvcResult res =
+ mvc.perform(
+ MockMvcRequestBuilders.post("/predict")
+ .accept(MediaType.APPLICATION_JSON_UTF8)
+ .queryParam("json", predictJson)
+ .contentType(MediaType.APPLICATION_JSON_UTF8))
+ .andReturn();
+
+ String response = res.getResponse().getContentAsString();
+ System.out.println(response);
+ Assert.assertEquals(200, res.getResponse().getStatus());
+ }
+
+ @Test
+ public void testPredictLegacyPostForm() throws Exception {
+ final String predictJson = TestMessages.DEFAULT_DATA;
+ assertNotNull(predictJson);
+
+ MvcResult res =
+ mvc.perform(
+ MockMvcRequestBuilders.post("/predict")
+ .accept(MediaType.APPLICATION_JSON_UTF8)
+ .param("json", predictJson)
+ .contentType(MediaType.APPLICATION_FORM_URLENCODED))
+ .andReturn();
+
+ String response = res.getResponse().getContentAsString();
+ System.out.println(response);
+ Assert.assertEquals(200, res.getResponse().getStatus());
+ }
+
+ @Test
+ public void testPredictLegacyButNotPredict() throws Exception {
+ final String predictJson = TestMessages.DEFAULT_DATA;
+ assertNotNull(predictJson);
+
+ MvcResult res =
+ mvc.perform(
+ MockMvcRequestBuilders.post("/predict")
+ .accept(MediaType.APPLICATION_JSON_UTF8)
+ .param("json", predictJson)
+ .content(predictJson)
+ .contentType(MediaType.APPLICATION_FORM_URLENCODED))
+ .andReturn();
+
+ String response = res.getResponse().getContentAsString();
+ System.out.println(response);
+ Assert.assertEquals(200, res.getResponse().getStatus());
+
+ // if we get back a header of "application/json;charset=UTF-8" then we are hitting the legacy predict
+ Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_UTF8_VALUE);
+ }
+
+ @Test
+ public void testPredictButNotPredictLegacy() throws Exception {
+ final String predictJson = TestMessages.DEFAULT_DATA;
+ assertNotNull(predictJson);
+
+ MvcResult res =
+ mvc.perform(
+ MockMvcRequestBuilders.post("/predict")
+ .accept(MediaType.APPLICATION_JSON)
+ .content(predictJson)
+ .contentType(MediaType.APPLICATION_JSON_UTF8))
+ .andReturn();
+ String response = res.getResponse().getContentAsString();
+ System.out.println(response);
+ Assert.assertEquals(200, res.getResponse().getStatus());
+
+ // if we get back a header of "application/json" then we are hitting the legacy predict
+ Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE);
+ }
+
+ @Test
+ public void testPredict() throws Exception {
+ final String predictJson = TestMessages.DEFAULT_DATA;
+ assertNotNull(predictJson);
+
+ MvcResult res =
+ mvc.perform(
+ MockMvcRequestBuilders.post("/predict")
+ .accept(MediaType.APPLICATION_JSON)
+ .content(predictJson)
+ .contentType(MediaType.APPLICATION_JSON))
+ .andReturn();
+
+ String response = res.getResponse().getContentAsString();
+ System.out.println(response);
+ Assert.assertEquals(200, res.getResponse().getStatus());
+ }
+
+ @Test
+ public void testPredictWithUTF8Header() throws Exception {
+ final String predictJson = TestMessages.DEFAULT_DATA;
+ assertNotNull(predictJson);
+
+ MvcResult res =
+ mvc.perform(
+ MockMvcRequestBuilders.post("/predict")
+ .accept(MediaType.APPLICATION_JSON)
+ .content(predictJson)
+ .contentType(MediaType.APPLICATION_JSON_UTF8))
+ .andReturn();
+
+ String response = res.getResponse().getContentAsString();
+ System.out.println(response);
+ Assert.assertEquals(200, res.getResponse().getStatus());
+
+ // if we get back a header of "application/json" then we are hitting the legacy predict
+ Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE);
+ }
+
+ @Test
+ public void testPredictWithDefaultData() throws Exception {
+ final String predictJson = TestMessages.DEFAULT_DATA;
+ assertNotNull(predictJson);
+
+ MvcResult res =
+ mvc.perform(
+ MockMvcRequestBuilders.post("/predict")
+ .accept(MediaType.APPLICATION_JSON)
+ .content(predictJson)
+ .contentType(MediaType.APPLICATION_JSON_UTF8)
+ )
+ .andDo(print())
+ .andExpect(status().isOk())
+ .andExpect(jsonPath("$.data", is(notNullValue())))
+ .andReturn();
+
+ // if we get back a header of "application/json" then we are hitting the legacy predict
+ Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE);
+ }
+
+ @Test
+ public void testPredictWithJsonData_UTF8Header() throws Exception {
+ final String predictJson = TestMessages.JSON_DATA;
+ assertNotNull(predictJson);
+
+ MvcResult res =
+ mvc.perform(
+ MockMvcRequestBuilders.post("/predict")
+ .accept(MediaType.APPLICATION_JSON)
+ .content(predictJson)
+ .contentType(MediaType.APPLICATION_JSON_UTF8))
+ .andExpect(status().isOk())
+ .andExpect(jsonPath("$.jsonData", is(notNullValue())))
+ .andReturn();
+
+ String response = res.getResponse().getContentAsString();
+ System.out.println(response);
+
+ // if we get back a header of "application/json" then we are hitting the legacy predict
+ Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE);
+ }
+
+ @Test
+ public void testFeedback() throws Exception {
+ final String predictJson = TestMessages.DEFAULT_DATA;
+ assertNotNull(predictJson);
+
+ MvcResult res =
+ mvc.perform(
+ MockMvcRequestBuilders.get("/send-feedback")
+ .accept(MediaType.APPLICATION_JSON_UTF8)
+ .param("json", predictJson)
+ .contentType(MediaType.APPLICATION_JSON_UTF8))
+ .andReturn();
+
+ String response = res.getResponse().getContentAsString();
+ System.out.println(response);
+ Assert.assertEquals(200, res.getResponse().getStatus());
+ }
+}
diff --git a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/RestControllerTest.java b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/RestControllerTest.java
deleted file mode 100644
index 1b7e4b064c..0000000000
--- a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/RestControllerTest.java
+++ /dev/null
@@ -1,70 +0,0 @@
-package io.seldon.wrapper.api;
-
-import static io.seldon.wrapper.util.TestUtils.readFile;
-
-import java.nio.charset.StandardCharsets;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.boot.test.context.SpringBootTest;
-import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
-import org.springframework.boot.web.server.LocalServerPort;
-import org.springframework.http.MediaType;
-import org.springframework.test.context.junit4.SpringRunner;
-import org.springframework.test.web.servlet.MockMvc;
-import org.springframework.test.web.servlet.MvcResult;
-import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
-import org.springframework.test.web.servlet.setup.MockMvcBuilders;
-import org.springframework.web.context.WebApplicationContext;
-
-@RunWith(SpringRunner.class)
-@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT)
-// @AutoConfigureMockMvc
-public class RestControllerTest {
-
- @Autowired private WebApplicationContext context;
-
- @Autowired ModelRestController modelRestController;
-
- // @Autowired
- private MockMvc mvc;
-
- @Before
- public void setup() {
- mvc = MockMvcBuilders.webAppContextSetup(context).build();
- }
-
- @LocalServerPort private int port;
-
- @Test
- public void testPredict() throws Exception {
- final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8);
- MvcResult res =
- mvc.perform(
- MockMvcRequestBuilders.get("/predict")
- .accept(MediaType.APPLICATION_JSON_UTF8)
- .param("json", predictJson)
- .contentType(MediaType.APPLICATION_JSON_UTF8))
- .andReturn();
- String response = res.getResponse().getContentAsString();
- System.out.println(response);
- Assert.assertEquals(200, res.getResponse().getStatus());
- }
-
- @Test
- public void testFeedback() throws Exception {
- final String predictJson = readFile("src/test/resources/feedback.json", StandardCharsets.UTF_8);
- MvcResult res =
- mvc.perform(
- MockMvcRequestBuilders.get("/send-feedback")
- .accept(MediaType.APPLICATION_JSON_UTF8)
- .param("json", predictJson)
- .contentType(MediaType.APPLICATION_JSON_UTF8))
- .andReturn();
- String response = res.getResponse().getContentAsString();
- System.out.println(response);
- Assert.assertEquals(200, res.getResponse().getStatus());
- }
-}
diff --git a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestMessages.java b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestMessages.java
new file mode 100644
index 0000000000..ba5df1552d
--- /dev/null
+++ b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestMessages.java
@@ -0,0 +1,21 @@
+package io.seldon.wrapper.api;
+
+import static io.seldon.wrapper.util.TestUtils.readFileFromAbsolutePathOrResources;
+
+final public class TestMessages {
+
+ /**
+ * All possible fields based on the SeldonMessage Proto:
+ * https://docs.seldon.io/projects/seldon-core/en/v1.6.0/reference/apis/prediction.html
+ */
+ public static final String TF_DATA = readFile("requests/defaultData.json");
+ public static final String DEFAULT_DATA = TF_DATA;
+ public static final String JSON_DATA = readFile("requests/jsonData.json");
+ // TODO: add binData
+ // TODO: add strData
+ // TODO: add customData
+
+ private static String readFile(String file) {
+ return readFileFromAbsolutePathOrResources(file);
+ }
+}
diff --git a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestPredictionService.java b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestPredictionService.java
index 2877b085fb..2fad1b4847 100644
--- a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestPredictionService.java
+++ b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestPredictionService.java
@@ -3,14 +3,14 @@
import io.seldon.protos.PredictionProtos.DefaultData;
import io.seldon.protos.PredictionProtos.SeldonMessage;
import io.seldon.protos.PredictionProtos.Tensor;
+import io.seldon.wrapper.pb.ProtoBufUtils;
import org.springframework.stereotype.Component;
@Component
public class TestPredictionService implements SeldonPredictionService {
@Override
public SeldonMessage predict(SeldonMessage payload) {
- return SeldonMessage.newBuilder()
- .setData(DefaultData.newBuilder().setTensor(Tensor.newBuilder().addShape(1).addValues(1.0)))
- .build();
+ // echo payload back
+ return payload.toBuilder().build();
}
}
diff --git a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/util/TestUtils.java b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/util/TestUtils.java
index 048761effb..87f25978aa 100644
--- a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/util/TestUtils.java
+++ b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/util/TestUtils.java
@@ -1,14 +1,51 @@
package io.seldon.wrapper.util;
-import java.io.IOException;
+import java.io.*;
import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
public class TestUtils {
+ private static final ClassLoader classLoader = TestUtils.class.getClassLoader();
+
public static String readFile(String path, Charset encoding) throws IOException {
byte[] encoded = Files.readAllBytes(Paths.get(path));
return new String(encoded, encoding);
}
+
+ /**
+ * Will load file from either an absolute path of a relative path from "target/test-classes"
+ * @param file file path (ex: "requests/jsonData.json", "/dev/null")
+ * @return
+ */
+ public static String readFileFromAbsolutePathOrResources(String file) {
+ try {
+ InputStream is = getInputStreamFromAbsolutePathOrResources(file, classLoader);
+ byte[] bytes = is.readAllBytes();
+ return new String(bytes, StandardCharsets.UTF_8);
+ } catch(Throwable t) {
+ System.out.println(t);
+ t.printStackTrace();
+ // nothing
+ }
+ return null;
+ }
+
+ public static InputStream getInputStreamFromAbsolutePathOrResources(String file, ClassLoader classLoader) {
+ InputStream is = null;
+
+ // try loading assuming an absolute path
+ try {
+ is = new FileInputStream(file);
+ } catch ( FileNotFoundException fne ) {
+ // Nothing
+ }
+ if( is == null ) {
+ is = classLoader.getResourceAsStream(file);
+ }
+
+ return is;
+ }
}
diff --git a/incubating/wrappers/java/src/test/resources/feedback.json b/incubating/wrappers/java/src/test/resources/feedback.json
index eefde6f0af..134aa50c19 100644
--- a/incubating/wrappers/java/src/test/resources/feedback.json
+++ b/incubating/wrappers/java/src/test/resources/feedback.json
@@ -40,4 +40,4 @@
}
},
"reward":1
-}
\ No newline at end of file
+}
diff --git a/incubating/wrappers/java/src/test/resources/request.json b/incubating/wrappers/java/src/test/resources/requests/defaultData.json
similarity index 96%
rename from incubating/wrappers/java/src/test/resources/request.json
rename to incubating/wrappers/java/src/test/resources/requests/defaultData.json
index c02d8f332d..7d818e0e62 100644
--- a/incubating/wrappers/java/src/test/resources/request.json
+++ b/incubating/wrappers/java/src/test/resources/requests/defaultData.json
@@ -7,4 +7,4 @@
]
]
}
-}
\ No newline at end of file
+}
diff --git a/incubating/wrappers/java/src/test/resources/requests/jsonData.json b/incubating/wrappers/java/src/test/resources/requests/jsonData.json
new file mode 100644
index 0000000000..73970991c1
--- /dev/null
+++ b/incubating/wrappers/java/src/test/resources/requests/jsonData.json
@@ -0,0 +1,9 @@
+{
+ "jsonData": {
+ "data": {
+ "subject": "helpful message",
+ "body": "nothing strange, good, rewarding"
+ },
+ "foo": "bar"
+ }
+}