Skip to content

Commit

Permalink
explainability-addon to use spring web api
Browse files Browse the repository at this point in the history
  • Loading branch information
tiagodolphine committed Nov 16, 2020
1 parent 764d797 commit 93bfcb2
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@

<dependencies>
<dependency>
<groupId>org.jboss.resteasy</groupId>
<artifactId>resteasy-spring-boot-starter</artifactId>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>org.kie.kogito</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,20 @@

import java.util.List;

import javax.ws.rs.Consumes;
import javax.ws.rs.POST;
import javax.ws.rs.Path;
import javax.ws.rs.Produces;
import javax.ws.rs.core.MediaType;
import javax.ws.rs.core.Response;

import org.kie.kogito.Application;
import org.kie.kogito.explainability.model.PredictInput;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;

@Path("/predict")
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

@RestController
@RequestMapping("/predict")
public class SpringBootExplainableResource {

private static final Logger LOGGER = LoggerFactory.getLogger(SpringBootExplainableResource.class);
Expand All @@ -44,15 +44,13 @@ public SpringBootExplainableResource(Application application) {
this.application = application;
}

@POST
@Consumes(MediaType.APPLICATION_JSON)
@Produces(MediaType.APPLICATION_JSON)
public Response predict(List<PredictInput> inputs) {
@PostMapping(consumes = MediaType.APPLICATION_JSON_VALUE, produces = MediaType.APPLICATION_JSON_VALUE)
public ResponseEntity predict(@RequestBody List<PredictInput> inputs) {
try {
return Response.ok(explainabilityService.processRequest(application, inputs)).build();
return ResponseEntity.ok(explainabilityService.processRequest(application, inputs));
} catch (Exception e) {
LOGGER.warn("An Exception occurred processing the predict request", e);
return Response.status(Response.Status.BAD_REQUEST).entity(e.getMessage()).build();
return ResponseEntity.badRequest().body(e.getMessage());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,12 @@
import java.util.List;
import java.util.Map;

import javax.ws.rs.core.GenericType;
import javax.ws.rs.core.Response;

import org.junit.jupiter.api.Test;
import org.kie.kogito.explainability.model.ModelIdentifier;
import org.kie.kogito.explainability.model.PredictInput;
import org.kie.kogito.explainability.model.PredictOutput;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;

import static java.util.Arrays.asList;
import static java.util.Collections.emptyList;
Expand All @@ -49,8 +47,7 @@ class SpringBootExplainableResourceTest {
void explainServiceTest() {
List<PredictInput> inputs = singletonList(createInput(40));

List<PredictOutput> outputs = resource.predict(inputs).readEntity(new GenericType<List<PredictOutput>>() {
});
List<PredictOutput> outputs = (List<PredictOutput>) resource.predict(inputs).getBody();

assertNotNull(outputs);
assertEquals(1, outputs.size());
Expand All @@ -75,8 +72,7 @@ void explainServiceTest() {
void explainServiceTestMultipleInputs() {
List<PredictInput> inputs = asList(createInput(40), createInput(120));

List<PredictOutput> outputs = resource.predict(inputs).readEntity(new GenericType<List<PredictOutput>>() {
});
List<PredictOutput> outputs = (List<PredictOutput>) resource.predict(inputs).getBody();

assertNotNull(outputs);
assertEquals(2, outputs.size());
Expand All @@ -95,8 +91,7 @@ void explainServiceTestMultipleInputs() {

@Test
void explainServiceTestNoInputs() {
List<PredictOutput> outputs = resource.predict(emptyList()).readEntity(new GenericType<List<PredictOutput>>() {
});
List<PredictOutput> outputs = (List<PredictOutput>) resource.predict(emptyList()).getBody();

assertNotNull(outputs);
assertEquals(0, outputs.size());
Expand All @@ -107,10 +102,10 @@ void explainServiceFail() {
String unknownwResourceId = "unknown:model";
PredictInput input = createInput(10);
input.getModelIdentifier().setResourceId(unknownwResourceId);
Response responseEntity = resource.predict(singletonList(input));
ResponseEntity responseEntity = resource.predict(singletonList(input));

assertEquals(HttpStatus.BAD_REQUEST.value(), responseEntity.getStatus());
assertEquals("Model " + unknownwResourceId + " not found.", responseEntity.readEntity(String.class));
assertEquals(HttpStatus.BAD_REQUEST.value(), responseEntity.getStatusCodeValue());
assertEquals("Model " + unknownwResourceId + " not found.", responseEntity.getBody());
}

private PredictInput createInput(int speedLimit) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
server.address=0.0.0.0

#static content generated by swagger-maven-plugin
spring.mvc.servlet.path=/docs

resteasy.jaxrs.scan-packages=org.kie.**,\${groupId},http*

kogito.service.url=http://localhost:8080
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public DecisionRestResourceGenerator(DMNModel model, String appCanonicalName) {
this.packageName = CodegenStringUtil.escapeIdentifier(model.getNamespace());
this.decisionId = model.getDefinitions().getId();
this.decisionName = CodegenStringUtil.escapeIdentifier(model.getName());
this.nameURL = URLEncoder.encode(model.getName()).replace("+", "%20");
this.nameURL = URLEncoder.encode(model.getName()).replace("+", " ");
this.appCanonicalName = appCanonicalName;
String classPrefix = StringUtils.ucFirst(decisionName);
this.resourceClazzName = classPrefix + "Resource";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
import org.kie.kogito.codegen.GeneratorContext;
import org.kie.kogito.codegen.KogitoPackageSources;
import org.kie.kogito.codegen.DashboardGeneratedFileUtils;
import org.kie.kogito.codegen.di.CDIDependencyInjectionAnnotator;
import org.kie.kogito.codegen.di.DependencyInjectionAnnotator;
import org.kie.kogito.codegen.io.CollectedResource;
import org.kie.kogito.codegen.rules.config.NamedRuleUnitConfig;
Expand Down Expand Up @@ -317,7 +318,7 @@ private void generateRuleUnits( List<DroolsError> errors, List<org.kie.kogito.co
RuleUnitHelper ruleUnitHelper = new RuleUnitHelper();

//TODO: use template to support spring
if (annotator != null) {
if (annotator != null && annotator instanceof CDIDependencyInjectionAnnotator) {
generatedFiles.add( new org.kie.kogito.codegen.GeneratedFile( org.kie.kogito.codegen.GeneratedFile.Type.JSON_MAPPER,
packageName.replace('.', '/') + "/KogitoObjectMapper.java", annotator.objectMapperInjectorSource(packageName) ) );
}
Expand Down

0 comments on commit 93bfcb2

Please sign in to comment.