From b40751ba46fe39e0b0636ce376434bbf598b66be Mon Sep 17 00:00:00 2001 From: Gabriele Cardosi Date: Wed, 6 Oct 2021 13:51:12 +0200 Subject: [PATCH] [DROOLS-6627] Verify Scorecard reasoncode when result is null (#3875) * [DROOLS-6625] Managing missing "required" input data * [DROOLS-6625] Managing best-effort conversion of input data * [DROOLS-6625] Managing invalid values - TODO: integration tests * [DROOLS-6625] Managing invalid values * [DROOLS-6625] Managing missing values * [DROOLS-6625] Validate input data * [DROOLS-6635] Move testing sources in specific files * [DROOLS-6625] Fix merge with base branch * [DROOLS-6625] Fix as per PR suggestion * [DROOLS-6635] Fix merge * [DROOLS-6627] Verify Scorecard reasoncode when result is null * [DROOLS-6635] Fix merge with 7.x * [DROOLS-6627] Fix merge * [DROOLS-6627] Fixed as per PR request --- .../commons/model/KiePMMLDroolsModel.java | 2 ++ .../tests/SimpleScorecardCategoricalTest.java | 20 ++++++++++++++---- .../model/KiePMMLScorecardModel.java | 11 +++++----- .../tests/SimpleScorecardCategoricalTest.java | 21 +++++++++++++++---- 4 files changed, 41 insertions(+), 13 deletions(-) diff --git a/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-drools/kie-pmml-models-drools-common/src/main/java/org/kie/pmml/models/drools/commons/model/KiePMMLDroolsModel.java b/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-drools/kie-pmml-models-drools-common/src/main/java/org/kie/pmml/models/drools/commons/model/KiePMMLDroolsModel.java index 225b80206af..045ee2b7647 100644 --- a/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-drools/kie-pmml-models-drools-common/src/main/java/org/kie/pmml/models/drools/commons/model/KiePMMLDroolsModel.java +++ b/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-drools/kie-pmml-models-drools-common/src/main/java/org/kie/pmml/models/drools/commons/model/KiePMMLDroolsModel.java @@ -49,6 +49,7 @@ public abstract class KiePMMLDroolsModel extends KiePMMLModel implements IsDrool private static final Logger logger = LoggerFactory.getLogger(KiePMMLDroolsModel.class); private static final AgendaEventListener agendaEventListener = getAgendaEventListener(logger); + private static final long serialVersionUID = 5471400949048174357L; /** * Map between the original field name and the generated type. @@ -75,6 +76,7 @@ public Object evaluate(final Object knowledgeBase, Map requestDa String fullClassName = this.getClass().getName(); String packageName = fullClassName.contains(".") ? fullClassName.substring(0, fullClassName.lastIndexOf('.')) : ""; + outputFieldsMap.clear(); KiePMMLSessionUtils.Builder builder = KiePMMLSessionUtils.builder((KieBase) knowledgeBase, name, packageName, toReturn) .withObjectsInSession(requestData, fieldTypeMap) diff --git a/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-drools/kie-pmml-models-drools-scorecard/kie-pmml-models-drools-scorecard-tests/src/test/java/org/kie/pmml/models/drools/scorecard/tests/SimpleScorecardCategoricalTest.java b/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-drools/kie-pmml-models-drools-scorecard/kie-pmml-models-drools-scorecard-tests/src/test/java/org/kie/pmml/models/drools/scorecard/tests/SimpleScorecardCategoricalTest.java index 309a32e5ea4..a2a9e6a586f 100644 --- a/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-drools/kie-pmml-models-drools-scorecard/kie-pmml-models-drools-scorecard-tests/src/test/java/org/kie/pmml/models/drools/scorecard/tests/SimpleScorecardCategoricalTest.java +++ b/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-drools/kie-pmml-models-drools-scorecard/kie-pmml-models-drools-scorecard-tests/src/test/java/org/kie/pmml/models/drools/scorecard/tests/SimpleScorecardCategoricalTest.java @@ -16,7 +16,6 @@ package org.kie.pmml.models.drools.scorecard.tests; -import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; @@ -35,6 +34,8 @@ import org.kie.pmml.api.runtime.PMMLRuntime; import org.kie.pmml.models.tests.AbstractPMMLTest; +import static org.junit.Assert.assertFalse; + @RunWith(Parameterized.class) public class SimpleScorecardCategoricalTest extends AbstractPMMLTest { @@ -43,7 +44,7 @@ public class SimpleScorecardCategoricalTest extends AbstractPMMLTest { private static final String TARGET_FIELD = "Score"; private static final String REASON_CODE1_FIELD = "Reason Code 1"; private static final String REASON_CODE2_FIELD = "Reason Code 2"; - private static final String[] CATEGORY = new String[] { "classA", "classB", "classC", "classD", "classE", "NA" }; + private static final String[] CATEGORY = new String[]{"classA", "classB", "classC", "classD", "classE", "NA"}; private static PMMLRuntime pmmlRuntime; private String input1; @@ -52,7 +53,8 @@ public class SimpleScorecardCategoricalTest extends AbstractPMMLTest { private String reasonCode1; private String reasonCode2; - public SimpleScorecardCategoricalTest(String input1, String input2, double score, String reasonCode1, String reasonCode2) { + public SimpleScorecardCategoricalTest(String input1, String input2, double score, String reasonCode1, + String reasonCode2) { this.input1 = input1; this.input2 = input2; this.score = score; @@ -60,7 +62,7 @@ public SimpleScorecardCategoricalTest(String input1, String input2, double score this.reasonCode2 = reasonCode2; } - @BeforeClass + @BeforeClass public static void setupClass() { pmmlRuntime = getPMMLRuntime(FILE_NAME); } @@ -93,6 +95,16 @@ public void testSimpleScorecardCategoricalVerifyNoException() { getSamples().stream().map(sample -> evaluate(pmmlRuntime, sample, MODEL_NAME)).forEach(Assert::assertNotNull); } + @Test + public void testSimpleScorecardCategoricalVerifyNoReasonCodeWithoutScore() { + getSamples().stream().map(sample -> evaluate(pmmlRuntime, sample, MODEL_NAME)) + .filter(pmml4Result -> pmml4Result.getResultVariables().get(TARGET_FIELD) == null) + .forEach(pmml4Result -> { + assertFalse(pmml4Result.getResultVariables().containsKey(REASON_CODE1_FIELD)); + assertFalse(pmml4Result.getResultVariables().containsKey(REASON_CODE2_FIELD)); + }); + } + private List> getSamples() { return IntStream.range(0, 10).boxed().map(i -> new HashMap() {{ put("input1", CATEGORY[i % CATEGORY.length]); diff --git a/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-scorecard/kie-pmml-models-scorecard-model/src/main/java/org/kie/pmml/models/scorecard/model/KiePMMLScorecardModel.java b/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-scorecard/kie-pmml-models-scorecard-model/src/main/java/org/kie/pmml/models/scorecard/model/KiePMMLScorecardModel.java index 480a3756dba..cde86c6ddf7 100644 --- a/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-scorecard/kie-pmml-models-scorecard-model/src/main/java/org/kie/pmml/models/scorecard/model/KiePMMLScorecardModel.java +++ b/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-scorecard/kie-pmml-models-scorecard-model/src/main/java/org/kie/pmml/models/scorecard/model/KiePMMLScorecardModel.java @@ -66,12 +66,13 @@ public Object evaluate(final Object knowledgeBase, final Map req if (localTransformations != null) { derivedFields.addAll(localTransformations.getDerivedFields()); } + outputFieldsMap.clear(); return characteristics.evaluate(defineFunctions, derivedFields, kiePMMLOutputFields, requestData, - outputFieldsMap, - initialScore, - reasonCodeAlgorithm, - useReasonCodes, - baselineScore).orElse(null); + outputFieldsMap, + initialScore, + reasonCodeAlgorithm, + useReasonCodes, + baselineScore).orElse(null); } @Override diff --git a/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-scorecard/kie-pmml-models-scorecard-tests/src/test/java/org/kie/pmml/models/scorecard/tests/SimpleScorecardCategoricalTest.java b/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-scorecard/kie-pmml-models-scorecard-tests/src/test/java/org/kie/pmml/models/scorecard/tests/SimpleScorecardCategoricalTest.java index 410e3d99a82..12a41e22405 100644 --- a/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-scorecard/kie-pmml-models-scorecard-tests/src/test/java/org/kie/pmml/models/scorecard/tests/SimpleScorecardCategoricalTest.java +++ b/kie-pmml-trusty/kie-pmml-models/kie-pmml-models-scorecard/kie-pmml-models-scorecard-tests/src/test/java/org/kie/pmml/models/scorecard/tests/SimpleScorecardCategoricalTest.java @@ -34,6 +34,8 @@ import org.kie.pmml.api.runtime.PMMLRuntime; import org.kie.pmml.models.tests.AbstractPMMLTest; +import static org.junit.Assert.assertFalse; + @RunWith(Parameterized.class) public class SimpleScorecardCategoricalTest extends AbstractPMMLTest { @@ -42,7 +44,7 @@ public class SimpleScorecardCategoricalTest extends AbstractPMMLTest { private static final String TARGET_FIELD = "Score"; private static final String REASON_CODE1_FIELD = "Reason Code 1"; private static final String REASON_CODE2_FIELD = "Reason Code 2"; - private static final String[] CATEGORY = new String[] { "classA", "classB", "classC", "classD", "classE", "NA" }; + private static final String[] CATEGORY = new String[]{"classA", "classB", "classC", "classD", "classE", "NA"}; private static PMMLRuntime pmmlRuntime; private String input1; @@ -51,7 +53,8 @@ public class SimpleScorecardCategoricalTest extends AbstractPMMLTest { private String reasonCode1; private String reasonCode2; - public SimpleScorecardCategoricalTest(String input1, String input2, double score, String reasonCode1, String reasonCode2) { + public SimpleScorecardCategoricalTest(String input1, String input2, double score, String reasonCode1, + String reasonCode2) { this.input1 = input1; this.input2 = input2; this.score = score; @@ -89,7 +92,18 @@ public void testSimpleScorecardCategorical() { @Test public void testSimpleScorecardCategoricalVerifyNoException() { - getSamples().stream().map(sample -> evaluate(pmmlRuntime, sample, MODEL_NAME)).forEach(Assert::assertNotNull); + getSamples().stream().map(sample -> evaluate(pmmlRuntime, sample, MODEL_NAME)) + .forEach(Assert::assertNotNull); + } + + @Test + public void testSimpleScorecardCategoricalVerifyNoReasonCodeWithoutScore() { + getSamples().stream().map(sample -> evaluate(pmmlRuntime, sample, MODEL_NAME)) + .filter(pmml4Result -> pmml4Result.getResultVariables().get(TARGET_FIELD) == null) + .forEach(pmml4Result -> { + assertFalse(pmml4Result.getResultVariables().containsKey(REASON_CODE1_FIELD)); + assertFalse(pmml4Result.getResultVariables().containsKey(REASON_CODE2_FIELD)); + }); } private List> getSamples() { @@ -98,5 +112,4 @@ private List> getSamples() { put("input2", CATEGORY[Math.abs(CATEGORY.length - i) % CATEGORY.length]); }}).collect(Collectors.toList()); } - }