Skip to content

Commit

Permalink
Merge pull request #3211 from amoldavsky/assaf-java-wrapper-api-fix
Browse files Browse the repository at this point in the history
Java Wrapper JSON POST API regression fix
  • Loading branch information
ukclivecox authored and RafalSkolasinski committed May 20, 2021
1 parent 4ca1a71 commit 501732c
Show file tree
Hide file tree
Showing 10 changed files with 384 additions and 87 deletions.
2 changes: 1 addition & 1 deletion incubating/wrappers/java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xs
<groupId>io.seldon.wrapper</groupId>
<artifactId>seldon-core-wrapper</artifactId>
<packaging>jar</packaging>
<version>0.4.1</version>
<version>0.4.2</version>
<name>Seldon Core Java Wrapper</name>
<url>http://maven.apache.org</url>
<description>Wrapper for seldon-core Java prediction models.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
package io.seldon.wrapper.api;

/**
* This is a model prediction API and container everything related to predictions.
*
* NOTE:
* This is NOT a RestFull API, since there is no resource or state (aka there is no prediction object or state)
* involved.
*/

import com.google.protobuf.InvalidProtocolBufferException;
import io.seldon.protos.PredictionProtos.Feedback;
import io.seldon.protos.PredictionProtos.SeldonMessage;
Expand All @@ -11,32 +19,79 @@
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestMethod;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.bind.annotation.*;

@RestController
@ConditionalOnExpression("${seldon.api.model.enabled:false}")
public class ModelRestController {
private static Logger logger = LoggerFactory.getLogger(ModelRestController.class.getName());
public class ModelPredictionController {
private static Logger logger = LoggerFactory.getLogger(ModelPredictionController.class.getName());

@Autowired SeldonPredictionService predictionService;

/**
* Will accept a POST or a GET request with either a query parameter or a FORM parameter.
*
* Examples:
* GET -> /predict?json={ ... }
* curl -s \
* localhost:9000/predict?json={"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \
*
* POST FORM -> /predict
* curl -s -X POST \
* -d 'json={"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \
* localhost:9000/predict
*
* @param json
* @return
* @deprecated
*/
@Deprecated
@RequestMapping(
value = "/predict",
method = {RequestMethod.GET, RequestMethod.POST},
produces = "application/json; charset=utf-8")
public ResponseEntity<String> predictions(@RequestParam("json") String json) {
params = {"json"},
produces = MediaType.APPLICATION_JSON_UTF8_VALUE
)
public ResponseEntity<String> predictLegacy(@RequestParam("json") String json) {
return this.predict(json);
}

/**
* Will accept a POST with a proper JSON body.
*
* Examples:
* POST -> /predict
* curl -s -X POST \
* -d '{"data": {"names": ["a", "b"], "ndarray": [[1.0, 2.0]]}}' \
* localhost:9000/predict
*
* curl -s -X POST \
* -d '{"jsonData": {"foo": "bar"}' \
* localhost:9000/predict
*
* @param jsonStr
* @return
*/
@RequestMapping(
value = "/predict",
method = {RequestMethod.POST},
consumes = {
MediaType.APPLICATION_JSON_VALUE,
MediaType.APPLICATION_JSON_UTF8_VALUE
},
produces = MediaType.APPLICATION_JSON_VALUE
)
public ResponseEntity<String> predict(@RequestBody String jsonStr) {
SeldonMessage request;
try {
SeldonMessage.Builder builder = SeldonMessage.newBuilder();
ProtoBufUtils.updateMessageBuilderFromJson(builder, json);
ProtoBufUtils.updateMessageBuilderFromJson(builder, jsonStr);
request = builder.build();
} catch (InvalidProtocolBufferException e) {
logger.error("Bad request", e);
throw new APIException(ApiExceptionType.WRAPPER_INVALID_MESSAGE, json);
throw new APIException(ApiExceptionType.WRAPPER_INVALID_MESSAGE, jsonStr);
}

try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
package io.seldon.wrapper.api;

import static io.seldon.wrapper.util.TestUtils.readFile;
import static org.hamcrest.Matchers.*;
import static org.junit.Assert.assertNotNull;
import static org.springframework.test.web.servlet.result.MockMvcResultHandlers.print;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath;
import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status;

import java.nio.charset.StandardCharsets;
import org.junit.Assert;
import org.junit.Before;
import org.junit.FixMethodOrder;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.MethodSorters;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.SpringBootTest.WebEnvironment;
import org.springframework.boot.web.server.LocalServerPort;
import org.springframework.http.MediaType;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.MvcResult;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.setup.MockMvcBuilders;
import org.springframework.web.context.WebApplicationContext;

@RunWith(SpringRunner.class)
@SpringBootTest(webEnvironment = WebEnvironment.RANDOM_PORT)
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
// @AutoConfigureMockMvc
public class ModelPredictionControllerTest {

@Autowired private WebApplicationContext context;

@Autowired
ModelPredictionController modelPredictionController;

// @Autowired
private MockMvc mvc;

@Before
public void setup() {
mvc = MockMvcBuilders.webAppContextSetup(context).build();
}

@LocalServerPort private int port;

@Test
public void testPredictLegacyGetQuery() throws Exception {
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.get("/predict")
.accept(MediaType.APPLICATION_JSON_UTF8)
.param("json", predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredictLegacyPostQuery() throws Exception {
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON_UTF8)
.queryParam("json", predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredictLegacyPostForm() throws Exception {
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON_UTF8)
.param("json", predictJson)
.contentType(MediaType.APPLICATION_FORM_URLENCODED))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredictLegacyButNotPredict() throws Exception {
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON_UTF8)
.param("json", predictJson)
.content(predictJson)
.contentType(MediaType.APPLICATION_FORM_URLENCODED))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());

// if we get back a header of "application/json;charset=UTF-8" then we are hitting the legacy predict
Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_UTF8_VALUE);
}

@Test
public void testPredictButNotPredictLegacy() throws Exception {
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON)
.content(predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andReturn();
String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());

// if we get back a header of "application/json" then we are hitting the legacy predict
Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE);
}

@Test
public void testPredict() throws Exception {
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON)
.content(predictJson)
.contentType(MediaType.APPLICATION_JSON))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}

@Test
public void testPredictWithUTF8Header() throws Exception {
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON)
.content(predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());

// if we get back a header of "application/json" then we are hitting the legacy predict
Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE);
}

@Test
public void testPredictWithDefaultData() throws Exception {
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON)
.content(predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8)
)
.andDo(print())
.andExpect(status().isOk())
.andExpect(jsonPath("$.data", is(notNullValue())))
.andReturn();

// if we get back a header of "application/json" then we are hitting the legacy predict
Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE);
}

@Test
public void testPredictWithJsonData_UTF8Header() throws Exception {
final String predictJson = TestMessages.JSON_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.post("/predict")
.accept(MediaType.APPLICATION_JSON)
.content(predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andExpect(status().isOk())
.andExpect(jsonPath("$.jsonData", is(notNullValue())))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);

// if we get back a header of "application/json" then we are hitting the legacy predict
Assert.assertEquals(res.getResponse().getContentType(), MediaType.APPLICATION_JSON_VALUE);
}

@Test
public void testFeedback() throws Exception {
final String predictJson = TestMessages.DEFAULT_DATA;
assertNotNull(predictJson);

MvcResult res =
mvc.perform(
MockMvcRequestBuilders.get("/send-feedback")
.accept(MediaType.APPLICATION_JSON_UTF8)
.param("json", predictJson)
.contentType(MediaType.APPLICATION_JSON_UTF8))
.andReturn();

String response = res.getResponse().getContentAsString();
System.out.println(response);
Assert.assertEquals(200, res.getResponse().getStatus());
}
}
Loading

0 comments on commit 501732c

Please sign in to comment.