diff --git a/incubating/wrappers/java/pom.xml b/incubating/wrappers/java/pom.xml index 6b0bdcbbf8..cd47fac7b2 100644 --- a/incubating/wrappers/java/pom.xml +++ b/incubating/wrappers/java/pom.xml @@ -6,7 +6,7 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xs io.seldon.wrapper seldon-core-wrapper jar - 0.4.1 + 0.4.2 Seldon Core Java Wrapper http://maven.apache.org Wrapper for seldon-core Java prediction models. diff --git a/incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelRestController.java b/incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelPredictionController.java similarity index 55% rename from incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelRestController.java rename to incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelPredictionController.java index 0db3625b0e..6609fa9bb5 100644 --- a/incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelRestController.java +++ b/incubating/wrappers/java/src/main/java/io/seldon/wrapper/api/ModelPredictionController.java @@ -1,5 +1,13 @@ package io.seldon.wrapper.api; +/** + * This is a model prediction API and container everything related to predictions. + * + * NOTE: + * This is NOT a RestFull API, since there is no resource or state (aka there is no prediction object or state) + * involved. + */ + import com.google.protobuf.InvalidProtocolBufferException; import io.seldon.protos.PredictionProtos.Feedback; import io.seldon.protos.PredictionProtos.SeldonMessage; @@ -11,32 +19,79 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression; import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; -import org.springframework.web.bind.annotation.RequestMapping; -import org.springframework.web.bind.annotation.RequestMethod; -import org.springframework.web.bind.annotation.RequestParam; -import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.bind.annotation.*; @RestController @ConditionalOnExpression("${seldon.api.model.enabled:false}") -public class ModelRestController { - private static Logger logger = LoggerFactory.getLogger(ModelRestController.class.getName()); +public class ModelPredictionController { + private static Logger logger = LoggerFactory.getLogger(ModelPredictionController.class.getName()); @Autowired SeldonPredictionService predictionService; + /** + * Will accept a POST or a GET request with either a query parameter or a FORM parameter. + * + * Examples: + * GET -> /predict?json={ ... } + * curl -s \ + * localhost:9000/predict?json={"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \ + * + * POST FORM -> /predict + * curl -s -X POST \ + * -d 'json={"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \ + * localhost:9000/predict + * + * @param json + * @return + * @deprecated + */ + @Deprecated @RequestMapping( value = "/predict", method = {RequestMethod.GET, RequestMethod.POST}, - produces = "application/json; charset=utf-8") - public ResponseEntity predictions(@RequestParam("json") String json) { + params = {"json"}, + produces = MediaType.APPLICATION_JSON_UTF8_VALUE + ) + public ResponseEntity predictLegacy(@RequestParam("json") String json) { + return this.predict(json); + } + + /** + * Will accept a POST with a proper JSON body. + * + * Examples: + * POST -> /predict + * curl -s -X POST \ + * -d '{"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \ + * localhost:9000/predict + * + * curl -s -X POST \ + * -d '{"jsonData": {"foo": "bar"}' \ + * localhost:9000/predict + * + * @param jsonStr + * @return + */ + @RequestMapping( + value = "/predict", + method = {RequestMethod.POST}, + consumes = { + MediaType.APPLICATION_JSON_VALUE, + MediaType.APPLICATION_JSON_UTF8_VALUE + }, + produces = MediaType.APPLICATION_JSON_VALUE + ) + public ResponseEntity predict(@RequestBody String jsonStr) { SeldonMessage request; try { SeldonMessage.Builder builder = SeldonMessage.newBuilder(); - ProtoBufUtils.updateMessageBuilderFromJson(builder, json); + ProtoBufUtils.updateMessageBuilderFromJson(builder, jsonStr); request = builder.build(); } catch (InvalidProtocolBufferException e) { logger.error("Bad request", e); - throw new APIException(ApiExceptionType.WRAPPER_INVALID_MESSAGE, json); + throw new APIException(ApiExceptionType.WRAPPER_INVALID_MESSAGE, jsonStr); } try { diff --git a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/ModelPredictionControllerTest.java b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/ModelPredictionControllerTest.java new file mode 100644 index 0000000000..35c14907c9 --- /dev/null +++ b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/ModelPredictionControllerTest.java @@ -0,0 +1,245 @@ +package io.seldon.wrapper.api; + +import static io.seldon.wrapper.util.TestUtils.readFile; +import static org.hamcrest.Matchers.*; +import static org.junit.Assert.assertNotNull; +import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +import java.nio.charset.StandardCharsets; +import org.junit.Assert; +import org.junit.Before; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.MethodSorters; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; +import org.springframework.boot.web.server.LocalServerPort; +import org.springframework.http.MediaType; +import org.springframework.test.context.junit4.SpringRunner; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.MvcResult; +import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; +import org.springframework.web.context.WebApplicationContext; + +@RunWith(SpringRunner.class) +@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT) +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +// @AutoConfigureMockMvc +public class ModelPredictionControllerTest { + + @Autowired private WebApplicationContext context; + + @Autowired + ModelPredictionController modelPredictionController; + + // @Autowired + private MockMvc mvc; + + @Before + public void setup() { + mvc = MockMvcBuilders.webAppContextSetup(context).build(); + } + + @LocalServerPort private int port; + + @Test + public void testPredictLegacyGetQuery() throws Exception { + final String predictJson = TestMessages.DEFAULT_DATA; + assertNotNull(predictJson); + + MvcResult res = + mvc.perform( + MockMvcRequestBuilders.get("/predict") + .accept(MediaType.APPLICATION_JSON_UTF8) + .param("json", predictJson) + .contentType(MediaType.APPLICATION_JSON_UTF8)) + .andReturn(); + + String response = res.getResponse().getContentAsString(); + System.out.println(response); + Assert.assertEquals(200, res.getResponse().getStatus()); + } + + @Test + public void testPredictLegacyPostQuery() throws Exception { + final String predictJson = TestMessages.DEFAULT_DATA; + assertNotNull(predictJson); + + MvcResult res = + mvc.perform( + MockMvcRequestBuilders.post("/predict") + .accept(MediaType.APPLICATION_JSON_UTF8) + .queryParam("json", predictJson) + .contentType(MediaType.APPLICATION_JSON_UTF8)) + .andReturn(); + + String response = res.getResponse().getContentAsString(); + System.out.println(response); + Assert.assertEquals(200, res.getResponse().getStatus()); + } + + @Test + public void testPredictLegacyPostForm() throws Exception { + final String predictJson = TestMessages.DEFAULT_DATA; + assertNotNull(predictJson); + + MvcResult res = + mvc.perform( + MockMvcRequestBuilders.post("/predict") + .accept(MediaType.APPLICATION_JSON_UTF8) + .param("json", predictJson) + .contentType(MediaType.APPLICATION_FORM_URLENCODED)) + .andReturn(); + + String response = res.getResponse().getContentAsString(); + System.out.println(response); + Assert.assertEquals(200, res.getResponse().getStatus()); + } + + @Test + public void testPredictLegacyButNotPredict() throws Exception { + final String predictJson = TestMessages.DEFAULT_DATA; + assertNotNull(predictJson); + + MvcResult res = + mvc.perform( + MockMvcRequestBuilders.post("/predict") + .accept(MediaType.APPLICATION_JSON_UTF8) + .param("json", predictJson) + .content(predictJson) + .contentType(MediaType.APPLICATION_FORM_URLENCODED)) + .andReturn(); + + String response = res.getResponse().getContentAsString(); + System.out.println(response); + Assert.assertEquals(200, res.getResponse().getStatus()); + + // if we get back a header of "application/json;charset=UTF-8" then we are hitting the legacy predict + Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_UTF8_VALUE); + } + + @Test + public void testPredictButNotPredictLegacy() throws Exception { + final String predictJson = TestMessages.DEFAULT_DATA; + assertNotNull(predictJson); + + MvcResult res = + mvc.perform( + MockMvcRequestBuilders.post("/predict") + .accept(MediaType.APPLICATION_JSON) + .content(predictJson) + .contentType(MediaType.APPLICATION_JSON_UTF8)) + .andReturn(); + String response = res.getResponse().getContentAsString(); + System.out.println(response); + Assert.assertEquals(200, res.getResponse().getStatus()); + + // if we get back a header of "application/json" then we are hitting the legacy predict + Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE); + } + + @Test + public void testPredict() throws Exception { + final String predictJson = TestMessages.DEFAULT_DATA; + assertNotNull(predictJson); + + MvcResult res = + mvc.perform( + MockMvcRequestBuilders.post("/predict") + .accept(MediaType.APPLICATION_JSON) + .content(predictJson) + .contentType(MediaType.APPLICATION_JSON)) + .andReturn(); + + String response = res.getResponse().getContentAsString(); + System.out.println(response); + Assert.assertEquals(200, res.getResponse().getStatus()); + } + + @Test + public void testPredictWithUTF8Header() throws Exception { + final String predictJson = TestMessages.DEFAULT_DATA; + assertNotNull(predictJson); + + MvcResult res = + mvc.perform( + MockMvcRequestBuilders.post("/predict") + .accept(MediaType.APPLICATION_JSON) + .content(predictJson) + .contentType(MediaType.APPLICATION_JSON_UTF8)) + .andReturn(); + + String response = res.getResponse().getContentAsString(); + System.out.println(response); + Assert.assertEquals(200, res.getResponse().getStatus()); + + // if we get back a header of "application/json" then we are hitting the legacy predict + Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE); + } + + @Test + public void testPredictWithDefaultData() throws Exception { + final String predictJson = TestMessages.DEFAULT_DATA; + assertNotNull(predictJson); + + MvcResult res = + mvc.perform( + MockMvcRequestBuilders.post("/predict") + .accept(MediaType.APPLICATION_JSON) + .content(predictJson) + .contentType(MediaType.APPLICATION_JSON_UTF8) + ) + .andDo(print()) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.data", is(notNullValue()))) + .andReturn(); + + // if we get back a header of "application/json" then we are hitting the legacy predict + Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE); + } + + @Test + public void testPredictWithJsonData_UTF8Header() throws Exception { + final String predictJson = TestMessages.JSON_DATA; + assertNotNull(predictJson); + + MvcResult res = + mvc.perform( + MockMvcRequestBuilders.post("/predict") + .accept(MediaType.APPLICATION_JSON) + .content(predictJson) + .contentType(MediaType.APPLICATION_JSON_UTF8)) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.jsonData", is(notNullValue()))) + .andReturn(); + + String response = res.getResponse().getContentAsString(); + System.out.println(response); + + // if we get back a header of "application/json" then we are hitting the legacy predict + Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE); + } + + @Test + public void testFeedback() throws Exception { + final String predictJson = TestMessages.DEFAULT_DATA; + assertNotNull(predictJson); + + MvcResult res = + mvc.perform( + MockMvcRequestBuilders.get("/send-feedback") + .accept(MediaType.APPLICATION_JSON_UTF8) + .param("json", predictJson) + .contentType(MediaType.APPLICATION_JSON_UTF8)) + .andReturn(); + + String response = res.getResponse().getContentAsString(); + System.out.println(response); + Assert.assertEquals(200, res.getResponse().getStatus()); + } +} diff --git a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/RestControllerTest.java b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/RestControllerTest.java deleted file mode 100644 index 1b7e4b064c..0000000000 --- a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/RestControllerTest.java +++ /dev/null @@ -1,70 +0,0 @@ -package io.seldon.wrapper.api; - -import static io.seldon.wrapper.util.TestUtils.readFile; - -import java.nio.charset.StandardCharsets; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.boot.test.context.SpringBootTest; -import org.springframework.boot.test.context.SpringBootTest.WebEnvironment; -import org.springframework.boot.web.server.LocalServerPort; -import org.springframework.http.MediaType; -import org.springframework.test.context.junit4.SpringRunner; -import org.springframework.test.web.servlet.MockMvc; -import org.springframework.test.web.servlet.MvcResult; -import org.springframework.test.web.servlet.request.MockMvcRequestBuilders; -import org.springframework.test.web.servlet.setup.MockMvcBuilders; -import org.springframework.web.context.WebApplicationContext; - -@RunWith(SpringRunner.class) -@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT) -// @AutoConfigureMockMvc -public class RestControllerTest { - - @Autowired private WebApplicationContext context; - - @Autowired ModelRestController modelRestController; - - // @Autowired - private MockMvc mvc; - - @Before - public void setup() { - mvc = MockMvcBuilders.webAppContextSetup(context).build(); - } - - @LocalServerPort private int port; - - @Test - public void testPredict() throws Exception { - final String predictJson = readFile("src/test/resources/request.json", StandardCharsets.UTF_8); - MvcResult res = - mvc.perform( - MockMvcRequestBuilders.get("/predict") - .accept(MediaType.APPLICATION_JSON_UTF8) - .param("json", predictJson) - .contentType(MediaType.APPLICATION_JSON_UTF8)) - .andReturn(); - String response = res.getResponse().getContentAsString(); - System.out.println(response); - Assert.assertEquals(200, res.getResponse().getStatus()); - } - - @Test - public void testFeedback() throws Exception { - final String predictJson = readFile("src/test/resources/feedback.json", StandardCharsets.UTF_8); - MvcResult res = - mvc.perform( - MockMvcRequestBuilders.get("/send-feedback") - .accept(MediaType.APPLICATION_JSON_UTF8) - .param("json", predictJson) - .contentType(MediaType.APPLICATION_JSON_UTF8)) - .andReturn(); - String response = res.getResponse().getContentAsString(); - System.out.println(response); - Assert.assertEquals(200, res.getResponse().getStatus()); - } -} diff --git a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestMessages.java b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestMessages.java new file mode 100644 index 0000000000..ba5df1552d --- /dev/null +++ b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestMessages.java @@ -0,0 +1,21 @@ +package io.seldon.wrapper.api; + +import static io.seldon.wrapper.util.TestUtils.readFileFromAbsolutePathOrResources; + +final public class TestMessages { + + /** + * All possible fields based on the SeldonMessage Proto: + * https://docs.seldon.io/projects/seldon-core/en/v1.6.0/reference/apis/prediction.html + */ + public static final String TF_DATA = readFile("requests/defaultData.json"); + public static final String DEFAULT_DATA = TF_DATA; + public static final String JSON_DATA = readFile("requests/jsonData.json"); + // TODO: add binData + // TODO: add strData + // TODO: add customData + + private static String readFile(String file) { + return readFileFromAbsolutePathOrResources(file); + } +} diff --git a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestPredictionService.java b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestPredictionService.java index 2877b085fb..2fad1b4847 100644 --- a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestPredictionService.java +++ b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/api/TestPredictionService.java @@ -3,14 +3,14 @@ import io.seldon.protos.PredictionProtos.DefaultData; import io.seldon.protos.PredictionProtos.SeldonMessage; import io.seldon.protos.PredictionProtos.Tensor; +import io.seldon.wrapper.pb.ProtoBufUtils; import org.springframework.stereotype.Component; @Component public class TestPredictionService implements SeldonPredictionService { @Override public SeldonMessage predict(SeldonMessage payload) { - return SeldonMessage.newBuilder() - .setData(DefaultData.newBuilder().setTensor(Tensor.newBuilder().addShape(1).addValues(1.0))) - .build(); + // echo payload back + return payload.toBuilder().build(); } } diff --git a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/util/TestUtils.java b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/util/TestUtils.java index 048761effb..87f25978aa 100644 --- a/incubating/wrappers/java/src/test/java/io/seldon/wrapper/util/TestUtils.java +++ b/incubating/wrappers/java/src/test/java/io/seldon/wrapper/util/TestUtils.java @@ -1,14 +1,51 @@ package io.seldon.wrapper.util; -import java.io.IOException; +import java.io.*; import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Paths; public class TestUtils { + private static final ClassLoader classLoader = TestUtils.class.getClassLoader(); + public static String readFile(String path, Charset encoding) throws IOException { byte[] encoded = Files.readAllBytes(Paths.get(path)); return new String(encoded, encoding); } + + /** + * Will load file from either an absolute path of a relative path from "target/test-classes" + * @param file file path (ex: "requests/jsonData.json", "/dev/null") + * @return + */ + public static String readFileFromAbsolutePathOrResources(String file) { + try { + InputStream is = getInputStreamFromAbsolutePathOrResources(file, classLoader); + byte[] bytes = is.readAllBytes(); + return new String(bytes, StandardCharsets.UTF_8); + } catch(Throwable t) { + System.out.println(t); + t.printStackTrace(); + // nothing + } + return null; + } + + public static InputStream getInputStreamFromAbsolutePathOrResources(String file, ClassLoader classLoader) { + InputStream is = null; + + // try loading assuming an absolute path + try { + is = new FileInputStream(file); + } catch ( FileNotFoundException fne ) { + // Nothing + } + if( is == null ) { + is = classLoader.getResourceAsStream(file); + } + + return is; + } } diff --git a/incubating/wrappers/java/src/test/resources/feedback.json b/incubating/wrappers/java/src/test/resources/feedback.json index eefde6f0af..134aa50c19 100644 --- a/incubating/wrappers/java/src/test/resources/feedback.json +++ b/incubating/wrappers/java/src/test/resources/feedback.json @@ -40,4 +40,4 @@ } }, "reward":1 -} \ No newline at end of file +} diff --git a/incubating/wrappers/java/src/test/resources/request.json b/incubating/wrappers/java/src/test/resources/requests/defaultData.json similarity index 96% rename from incubating/wrappers/java/src/test/resources/request.json rename to incubating/wrappers/java/src/test/resources/requests/defaultData.json index c02d8f332d..7d818e0e62 100644 --- a/incubating/wrappers/java/src/test/resources/request.json +++ b/incubating/wrappers/java/src/test/resources/requests/defaultData.json @@ -7,4 +7,4 @@ ] ] } -} \ No newline at end of file +} diff --git a/incubating/wrappers/java/src/test/resources/requests/jsonData.json b/incubating/wrappers/java/src/test/resources/requests/jsonData.json new file mode 100644 index 0000000000..73970991c1 --- /dev/null +++ b/incubating/wrappers/java/src/test/resources/requests/jsonData.json @@ -0,0 +1,9 @@ +{ + "jsonData": { + "data": { + "subject": "helpful message", + "body": "nothing strange, good, rewarding" + }, + "foo": "bar" + } +}