From 0c6bfecb386f7e937ffcca8a1319d6cdaed11da7 Mon Sep 17 00:00:00 2001 From: Richard Eckart de Castilho Date: Sat, 25 Nov 2023 23:23:23 +0100 Subject: [PATCH] #4292 - ollama-based recommender - Refactored prompt generation and response extraction --- .../imls/ollama/OllamaRecommender.java | 260 +++--------------- .../imls/ollama/OllamaRecommenderTraits.java | 1 + .../recommendation/imls/ollama/Preset.java | 1 + .../imls/ollama/PromptingModeSelect.java | 2 + .../PerAnnotationBindingsGenerator.java | 48 ++++ .../prompt/PerDocumentBindingsGenerator.java | 36 +++ .../prompt/PerSentenceBindingsGenerator.java | 43 +++ .../prompt/PromptBindingsGenerator.java | 31 +++ .../imls/ollama/prompt/PromptContext.java | 50 ++++ .../ollama/{ => prompt}/PromptingMode.java | 2 +- .../response/MentionsFromJsonExtractor.java | 158 +++++++++++ .../response/ResponseAsLabelExtractor.java | 52 ++++ .../ollama/response/ResponseExtractor.java | 29 ++ .../imls/ollama/OllamaRecommenderTest.java | 1 + .../ollama/response/JsonExtractorTest.java | 100 +++++++ .../api/recommender/RecommendationEngine.java | 6 +- 16 files changed, 590 insertions(+), 230 deletions(-) create mode 100644 inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerAnnotationBindingsGenerator.java create mode 100644 inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerDocumentBindingsGenerator.java create mode 100644 inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerSentenceBindingsGenerator.java create mode 100644 inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PromptBindingsGenerator.java create mode 100644 inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PromptContext.java rename inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/{ => prompt}/PromptingMode.java (93%) create mode 100644 inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/MentionsFromJsonExtractor.java create mode 100644 inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/ResponseAsLabelExtractor.java create mode 100644 inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/ResponseExtractor.java create mode 100644 inception/inception-imls-ollama/src/test/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/JsonExtractorTest.java diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommender.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommender.java index 4bb3f2fe5f5..cb82d80a99a 100644 --- a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommender.java +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommender.java @@ -17,20 +17,12 @@ */ package de.tudarmstadt.ukp.inception.recommendation.imls.ollama; -import static de.tudarmstadt.ukp.clarin.webanno.api.annotation.util.WebAnnoCasUtil.selectOverlapping; - import java.io.IOException; import java.lang.invoke.MethodHandles; import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; import org.apache.commons.lang3.exception.ExceptionUtils; -import org.apache.commons.lang3.tuple.Pair; import org.apache.uima.cas.CAS; -import org.apache.uima.cas.text.AnnotationFS; -import org.apache.uima.fit.util.CasUtil; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -40,23 +32,24 @@ import com.hubspot.jinjava.loader.ResourceLocator; import com.hubspot.jinjava.loader.ResourceNotFoundException; -import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence; import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender; import de.tudarmstadt.ukp.inception.recommendation.api.recommender.NonTrainableRecommenderEngineImplBase; import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationException; import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext; import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaClient; import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaGenerateRequest; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PerAnnotationBindingsGenerator; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PerDocumentBindingsGenerator; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PerSentenceBindingsGenerator; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptBindingsGenerator; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptContext; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response.MentionsFromJsonExtractor; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response.ResponseAsLabelExtractor; import de.tudarmstadt.ukp.inception.rendering.model.Range; -import de.tudarmstadt.ukp.inception.support.json.JSONUtil; public class OllamaRecommender extends NonTrainableRecommenderEngineImplBase { - private static final String VAR_TEXT = "text"; - private static final String VAR_SENTENCE = "sentence"; - private static final String VAR_DOCUMENT = "document"; - private final static Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); private final OllamaRecommenderTraits traits; @@ -93,239 +86,39 @@ public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd { switch (traits.getPromptingMode()) { case PER_ANNOTATION: - return predictPerAnnotation(aContext, aCas, aBegin, aEnd); + return predict(new PerAnnotationBindingsGenerator(), aContext, aCas, aBegin, aEnd); case PER_SENTENCE: - return predictPerSentence(aContext, aCas, aBegin, aEnd); + return predict(new PerSentenceBindingsGenerator(), aContext, aCas, aBegin, aEnd); case PER_DOCUMENT: - return predictPerDocument(aContext, aCas, aBegin, aEnd); + return predict(new PerDocumentBindingsGenerator(), aContext, aCas, aBegin, aEnd); default: throw new RecommendationException( "Unsupported mode [" + traits.getPromptingMode() + "]"); } } - private Range predictPerDocument(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd) - { - var bindings = Map.of(VAR_TEXT, aCas.getDocumentText()); - var prompt = jinjava.render(traits.getPrompt(), bindings); - - try { - var candidate = aCas.getDocumentAnnotation(); - - var response = generate(prompt); - - extractPredictions(candidate, response); - } - catch (IOException e) { - LOG.error("Ollama [{}] failed to respond: {}", traits.getModel(), - ExceptionUtils.getRootCauseMessage(e)); - } - - return new Range(aBegin, aEnd); - } - - private void extractPredictions(AnnotationFS aCandidate, String aResponse) - { - switch (traits.getExtractionMode()) { - case RESPONSE_AS_LABEL: - predictResultAsLabel(aCandidate, aResponse); - break; - case MENTIONS_FROM_JSON: - var mentions = extractMentionFromJson(aCandidate, aResponse); - mentionsToPredictions(aCandidate, mentions); - break; - default: - throw new IllegalArgumentException( - "Unsupported extraction mode [" + traits.getExtractionMode() + "]"); - } - } - - private ArrayList> extractMentionFromJson(AnnotationFS aCandidate, - String aResponse) - { - var mentions = new ArrayList>(); - try { - // Ollama JSON mode always returns a JSON object - // See: - // https://github.com/jmorganca/ollama/commit/5cba29b9d666854706a194805c9d66518fe77545#diff-a604f7ba9b7f66dd7b59a9e884d3c82c96e5269fee85c906a7cca5f0c3eff7f8R30-R57 - var rootNode = JSONUtil.getObjectMapper().readTree(aResponse); - - var fieldIterator = rootNode.fields(); - while (fieldIterator.hasNext()) { - var fieldEntry = fieldIterator.next(); - if (fieldEntry.getValue().isArray()) { - for (var item : fieldEntry.getValue()) { - if (item.isTextual()) { - // Looks like this - // "Person": ["John"], - // "Location": ["diner", "Starbucks"] - mentions.add(Pair.of(item.asText(), fieldEntry.getKey())); - } - if (item.isObject()) { - // Looks like this - // "politicians": [ - // { "name": "President Livingston" }, - // { "name": "John" }, - // { "name": "Don Horny" } - // ] - var subFieldIterator = item.fields(); - while (subFieldIterator.hasNext()) { - var subEntry = subFieldIterator.next(); - if (subEntry.getValue().isTextual()) { - mentions.add(Pair.of(subEntry.getValue().asText(), - fieldEntry.getKey())); - } - // We assume that the first item is the most relevant one (the - // mention) so we do not get a bad mention in cases like this: - // { - // "name": "Don Horny", - // "affiliation": "Lord of Darkness" - // } - break; - } - } - } - } - - // Looks like this - // "John": {"type": "PERSON"}, - // "diner": {"type": "LOCATION"}, - // "Starbucks": {"type": "LOCATION"} - if (fieldEntry.getValue().isObject()) { - mentions.add(Pair.of(fieldEntry.getKey(), null)); - } - - // Looks like this - // "John": "politician", - // "President Livingston": "politician", - // "minister of foreign affairs": "politician", - // "Don Horny": "politician" - if (fieldEntry.getValue().isTextual()) { - mentions.add(Pair.of(fieldEntry.getKey(), fieldEntry.getValue().asText())); - } - } - } - catch (IOException e) { - LOG.error("Unable to extract mentions - not valid JSON: [" + aResponse + "]"); - } - return mentions; - } - - private void mentionsToPredictions(AnnotationFS aCandidate, List> mentions) - { - var cas = aCandidate.getCAS(); - var text = aCandidate.getCoveredText(); - var predictedType = getPredictedType(cas); - var predictedFeature = getPredictedFeature(cas); - var isPredictionFeature = getIsPredictionFeature(cas); - - for (var entry : mentions) { - var mention = entry.getKey(); - if (mention.isBlank()) { - LOG.debug("Blank mention ignored"); - continue; - } - - var label = entry.getValue(); - var lastIndex = 0; - var index = text.indexOf(mention, lastIndex); - var hitCount = 0; - while (index >= 0) { - int begin = aCandidate.getBegin() + index; - var prediction = cas.createAnnotation(predictedType, begin, - begin + mention.length()); - prediction.setBooleanValue(isPredictionFeature, true); - if (label != null) { - prediction.setStringValue(predictedFeature, label); - } - cas.addFsToIndexes(prediction); - LOG.debug("Prediction generated [{}] -> [{}]", mention, label); - hitCount++; - - lastIndex = index + mention.length(); - index = text.indexOf(mention, lastIndex); - - if (hitCount > text.length() / mention.length()) { - LOG.error( - "Mention detection seems to have entered into an endless loop - aborting"); - break; - } - } - - if (hitCount == 0) { - LOG.debug("Mention [{}] not found", mention); - } - } - } - - private void predictResultAsLabel(AnnotationFS aCandidate, String aResponse) + private Range predict(PromptBindingsGenerator aGenerator, RecommenderContext aContext, CAS aCas, + int aBegin, int aEnd) { - var aCas = aCandidate.getCAS(); - - var predictedType = getPredictedType(aCas); - var predictedFeature = getPredictedFeature(aCas); - var isPredictionFeature = getIsPredictionFeature(aCas); - - var prediction = aCas.createAnnotation(predictedType, aCandidate.getBegin(), - aCandidate.getEnd()); - prediction.setFeatureValueFromString(predictedFeature, aResponse); - prediction.setBooleanValue(isPredictionFeature, true); - aCas.addFsToIndexes(prediction); - - LOG.debug("Prediction generated [{}] -> [{}]", prediction.getCoveredText(), aResponse); - } - - private Range predictPerSentence(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd) - { - var candidateType = CasUtil.getAnnotationType(aCas, Sentence.class); - - for (var candidate : selectOverlapping(aCas, candidateType, aBegin, aEnd)) { - var bindings = Map.of(VAR_TEXT, candidate.getCoveredText()); - var prompt = jinjava.render(traits.getPrompt(), bindings); - + aGenerator.generate(aCas, aBegin, aEnd).forEach(promptContext -> { try { - var response = generate(prompt); + var prompt = jinjava.render(traits.getPrompt(), promptContext.getBindings()); + var response = query(prompt); - extractPredictions(candidate, response); + extractPredictions(aCas, promptContext, response); } catch (IOException e) { LOG.error("Ollama [{}] failed to respond: {}", traits.getModel(), ExceptionUtils.getRootCauseMessage(e)); } - } - - return new Range(aBegin, aEnd); - } - - private Range predictPerAnnotation(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd) - { - var predictedType = getPredictedType(aCas); - - for (var candidate : selectOverlapping(aCas, predictedType, aBegin, aEnd)) { - String sentence = aCas.select(Sentence.class).covering(candidate) - .map(Sentence::getCoveredText).findFirst().orElse(""); - var bindings = Map.of( // - VAR_TEXT, candidate.getCoveredText(), // - VAR_SENTENCE, sentence); - var prompt = jinjava.render(traits.getPrompt(), bindings); - - try { - var response = generate(prompt); - - extractPredictions(candidate, response); - } - catch (IOException e) { - LOG.error("Ollama [{}] failed to respond: {}", traits.getModel(), - ExceptionUtils.getRootCauseMessage(e)); - } - } + }); return new Range(aBegin, aEnd); } - private String generate(String prompt) throws IOException + private String query(String prompt) throws IOException { - LOG.trace("Asking ollama [{}]: [{}]", traits.getModel(), prompt); + LOG.trace("Querying ollama [{}]: [{}]", traits.getModel(), prompt); var request = OllamaGenerateRequest.builder() // .withModel(traits.getModel()) // .withPrompt(prompt) // @@ -337,4 +130,19 @@ private String generate(String prompt) throws IOException LOG.trace("Ollama [{}] responds: [{}]", traits.getModel(), response); return response; } + + private void extractPredictions(CAS aCas, PromptContext aContext, String aResponse) + { + switch (traits.getExtractionMode()) { + case RESPONSE_AS_LABEL: + new ResponseAsLabelExtractor().extract(this, aCas, aContext, aResponse); + break; + case MENTIONS_FROM_JSON: + new MentionsFromJsonExtractor().extract(this, aCas, aContext, aResponse); + break; + default: + throw new IllegalArgumentException( + "Unsupported extraction mode [" + traits.getExtractionMode() + "]"); + } + } } diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommenderTraits.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommenderTraits.java index 6d223c73c64..49db69809b0 100644 --- a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommenderTraits.java +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommenderTraits.java @@ -27,6 +27,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaGenerateResponseFormat; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptingMode; @JsonIgnoreProperties(ignoreUnknown = true) public class OllamaRecommenderTraits diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/Preset.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/Preset.java index 57b282eece5..57377ed3aef 100644 --- a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/Preset.java +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/Preset.java @@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaGenerateResponseFormat; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptingMode; @JsonIgnoreProperties(ignoreUnknown = true) public class Preset diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/PromptingModeSelect.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/PromptingModeSelect.java index 81a41afd5ea..2bb7d56f8ea 100644 --- a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/PromptingModeSelect.java +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/PromptingModeSelect.java @@ -23,6 +23,8 @@ import org.apache.wicket.markup.html.form.EnumChoiceRenderer; import org.apache.wicket.model.IModel; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptingMode; + public class PromptingModeSelect extends DropDownChoice { diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerAnnotationBindingsGenerator.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerAnnotationBindingsGenerator.java new file mode 100644 index 00000000000..7c3a955a149 --- /dev/null +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerAnnotationBindingsGenerator.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Technische Universität Darmstadt under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The Technische Universität Darmstadt + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt; + +import static de.tudarmstadt.ukp.clarin.webanno.api.annotation.util.WebAnnoCasUtil.selectOverlapping; + +import java.util.stream.Stream; + +import org.apache.uima.cas.CAS; +import org.apache.uima.fit.util.CasUtil; + +import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence; + +public class PerAnnotationBindingsGenerator + implements PromptBindingsGenerator +{ + + @Override + public Stream generate(CAS aCas, int aBegin, int aEnd) + { + var candidateType = CasUtil.getAnnotationType(aCas, Sentence.class); + return selectOverlapping(aCas, candidateType, aBegin, aEnd).stream().map(candidate -> { + var sentence = aCas.select(Sentence.class).covering(candidate) // + .map(Sentence::getCoveredText) // + .findFirst().orElse(""); + var context = new PromptContext(candidate); + context.set(VAR_TEXT, candidate.getCoveredText()); + context.set(VAR_SENTENCE, sentence); + return context; + + }); + } +} diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerDocumentBindingsGenerator.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerDocumentBindingsGenerator.java new file mode 100644 index 00000000000..732237c239c --- /dev/null +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerDocumentBindingsGenerator.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Technische Universität Darmstadt under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The Technische Universität Darmstadt + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt; + +import java.util.stream.Stream; + +import org.apache.uima.cas.CAS; + +public class PerDocumentBindingsGenerator + implements PromptBindingsGenerator +{ + + @Override + public Stream generate(CAS aCas, int aBegin, int aEnd) + { + var candidate = aCas.getDocumentAnnotation(); + var context = new PromptContext(candidate); + context.set(VAR_TEXT, aCas.getDocumentText()); + return Stream.of(context); + } +} diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerSentenceBindingsGenerator.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerSentenceBindingsGenerator.java new file mode 100644 index 00000000000..1e32be1ed02 --- /dev/null +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PerSentenceBindingsGenerator.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Technische Universität Darmstadt under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The Technische Universität Darmstadt + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt; + +import static de.tudarmstadt.ukp.clarin.webanno.api.annotation.util.WebAnnoCasUtil.selectOverlapping; + +import java.util.stream.Stream; + +import org.apache.uima.cas.CAS; +import org.apache.uima.fit.util.CasUtil; + +import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence; + +public class PerSentenceBindingsGenerator + implements PromptBindingsGenerator +{ + + @Override + public Stream generate(CAS aCas, int aBegin, int aEnd) + { + var candidateType = CasUtil.getAnnotationType(aCas, Sentence.class); + return selectOverlapping(aCas, candidateType, aBegin, aEnd).stream().map(candidate -> { + var context = new PromptContext(candidate); + context.set(VAR_TEXT, candidate.getCoveredText()); + return context; + }); + } +} diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PromptBindingsGenerator.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PromptBindingsGenerator.java new file mode 100644 index 00000000000..705516c2f69 --- /dev/null +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PromptBindingsGenerator.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Technische Universität Darmstadt under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The Technische Universität Darmstadt + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt; + +import java.util.stream.Stream; + +import org.apache.uima.cas.CAS; + +public interface PromptBindingsGenerator +{ + static final String VAR_TEXT = "text"; + static final String VAR_SENTENCE = "sentence"; + static final String VAR_DOCUMENT = "document"; + + Stream generate(CAS aCas, int aBegin, int aEnd); +} diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PromptContext.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PromptContext.java new file mode 100644 index 00000000000..4e9a09211d2 --- /dev/null +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PromptContext.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Technische Universität Darmstadt under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The Technische Universität Darmstadt + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.uima.cas.text.AnnotationFS; + +public class PromptContext +{ + private final AnnotationFS candidate; + private final Map bindings; + + public PromptContext(AnnotationFS aCandidate) + { + candidate = aCandidate; + bindings = new HashMap<>(); + } + + public AnnotationFS getCandidate() + { + return candidate; + } + + public void set(String aKey, String aValue) + { + bindings.put(aKey, aValue); + } + + public Map getBindings() + { + return bindings; + } +} diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/PromptingMode.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PromptingMode.java similarity index 93% rename from inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/PromptingMode.java rename to inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PromptingMode.java index 9930ed990dd..56ec5a40e90 100644 --- a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/PromptingMode.java +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/prompt/PromptingMode.java @@ -15,7 +15,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package de.tudarmstadt.ukp.inception.recommendation.imls.ollama; +package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/MentionsFromJsonExtractor.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/MentionsFromJsonExtractor.java new file mode 100644 index 00000000000..b93b2e68d60 --- /dev/null +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/MentionsFromJsonExtractor.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Technische Universität Darmstadt under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The Technische Universität Darmstadt + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response; + +import java.io.IOException; +import java.lang.invoke.MethodHandles; +import java.util.ArrayList; +import java.util.List; + +import org.apache.commons.lang3.tuple.Pair; +import org.apache.uima.cas.CAS; +import org.apache.uima.cas.text.AnnotationFS; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptContext; +import de.tudarmstadt.ukp.inception.support.json.JSONUtil; + +public class MentionsFromJsonExtractor + implements ResponseExtractor +{ + private final static Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + @Override + public void extract(RecommendationEngine aEngine, CAS aCas, PromptContext aContext, + String aResponse) + { + var mentions = extractMentionFromJson(aResponse); + mentionsToPredictions(aEngine, aCas, aContext.getCandidate(), mentions); + } + + List> extractMentionFromJson(String aResponse) + { + var mentions = new ArrayList>(); + try { + // Ollama JSON mode always returns a JSON object + // See: + // https://github.com/jmorganca/ollama/commit/5cba29b9d666854706a194805c9d66518fe77545#diff-a604f7ba9b7f66dd7b59a9e884d3c82c96e5269fee85c906a7cca5f0c3eff7f8R30-R57 + var rootNode = JSONUtil.getObjectMapper().readTree(aResponse); + + var fieldIterator = rootNode.fields(); + while (fieldIterator.hasNext()) { + var fieldEntry = fieldIterator.next(); + if (fieldEntry.getValue().isArray()) { + for (var item : fieldEntry.getValue()) { + if (item.isTextual()) { + // Looks like this + // "Person": ["John"], + // "Location": ["diner", "Starbucks"] + mentions.add(Pair.of(item.asText(), fieldEntry.getKey())); + } + if (item.isObject()) { + // Looks like this + // "politicians": [ + // { "name": "President Livingston" }, + // { "name": "John" }, + // { "name": "Don Horny" } + // ] + var subFieldIterator = item.fields(); + while (subFieldIterator.hasNext()) { + var subEntry = subFieldIterator.next(); + if (subEntry.getValue().isTextual()) { + mentions.add(Pair.of(subEntry.getValue().asText(), + fieldEntry.getKey())); + } + break; + } + } + } + } + + // Looks like this + // "John": {"type": "PERSON"}, + // "diner": {"type": "LOCATION"}, + // "Starbucks": {"type": "LOCATION"} + if (fieldEntry.getValue().isObject()) { + mentions.add(Pair.of(fieldEntry.getKey(), null)); + } + + // Looks like this + // "John": "politician", + // "President Livingston": "politician", + // "minister of foreign affairs": "politician", + // "Don Horny": "politician" + if (fieldEntry.getValue().isTextual()) { + mentions.add(Pair.of(fieldEntry.getKey(), fieldEntry.getValue().asText())); + } + } + } + catch (IOException e) { + LOG.error("Unable to extract mentions - not valid JSON: [" + aResponse + "]"); + } + return mentions; + } + + private void mentionsToPredictions(RecommendationEngine aEngine, CAS aCas, + AnnotationFS aCandidate, List> mentions) + { + var text = aCandidate.getCoveredText(); + var predictedType = aEngine.getPredictedType(aCas); + var predictedFeature = aEngine.getPredictedFeature(aCas); + var isPredictionFeature = aEngine.getIsPredictionFeature(aCas); + + for (var entry : mentions) { + var mention = entry.getKey(); + if (mention.isBlank()) { + LOG.debug("Blank mention ignored"); + continue; + } + + var label = entry.getValue(); + var lastIndex = 0; + var index = text.indexOf(mention, lastIndex); + var hitCount = 0; + while (index >= 0) { + int begin = aCandidate.getBegin() + index; + var prediction = aCas.createAnnotation(predictedType, begin, + begin + mention.length()); + prediction.setBooleanValue(isPredictionFeature, true); + if (label != null) { + prediction.setStringValue(predictedFeature, label); + } + aCas.addFsToIndexes(prediction); + LOG.debug("Prediction generated [{}] -> [{}]", mention, label); + hitCount++; + + lastIndex = index + mention.length(); + index = text.indexOf(mention, lastIndex); + + if (hitCount > text.length() / mention.length()) { + LOG.error( + "Mention detection seems to have entered into an endless loop - aborting"); + break; + } + } + + if (hitCount == 0) { + LOG.debug("Mention [{}] not found", mention); + } + } + } +} diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/ResponseAsLabelExtractor.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/ResponseAsLabelExtractor.java new file mode 100644 index 00000000000..7a49f572aab --- /dev/null +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/ResponseAsLabelExtractor.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Technische Universität Darmstadt under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The Technische Universität Darmstadt + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response; + +import java.lang.invoke.MethodHandles; + +import org.apache.uima.cas.CAS; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptContext; + +public class ResponseAsLabelExtractor + implements ResponseExtractor +{ + private final static Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass()); + + @Override + public void extract(RecommendationEngine aEngine, CAS aCas, PromptContext aContext, + String aResponse) + { + var candidate = aContext.getCandidate(); + + var predictedType = aEngine.getPredictedType(aCas); + var predictedFeature = aEngine.getPredictedFeature(aCas); + var isPredictionFeature = aEngine.getIsPredictionFeature(aCas); + + var prediction = aCas.createAnnotation(predictedType, candidate.getBegin(), + candidate.getEnd()); + prediction.setFeatureValueFromString(predictedFeature, aResponse); + prediction.setBooleanValue(isPredictionFeature, true); + aCas.addFsToIndexes(prediction); + + LOG.debug("Prediction generated [{}] -> [{}]", prediction.getCoveredText(), aResponse); + } +} diff --git a/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/ResponseExtractor.java b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/ResponseExtractor.java new file mode 100644 index 00000000000..39fb5fa94bf --- /dev/null +++ b/inception/inception-imls-ollama/src/main/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/ResponseExtractor.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Technische Universität Darmstadt under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The Technische Universität Darmstadt + * licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response; + +import org.apache.uima.cas.CAS; + +import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptContext; + +public interface ResponseExtractor +{ + void extract(RecommendationEngine aEngine, CAS aCas, PromptContext aCandidate, + String aResponse); +} diff --git a/inception/inception-imls-ollama/src/test/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommenderTest.java b/inception/inception-imls-ollama/src/test/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommenderTest.java index b867e59b7d6..a8deeee3f24 100644 --- a/inception/inception-imls-ollama/src/test/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommenderTest.java +++ b/inception/inception-imls-ollama/src/test/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/OllamaRecommenderTest.java @@ -41,6 +41,7 @@ import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext; import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaClientImpl; import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaGenerateResponseFormat; +import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptingMode; @Disabled("Requires locally running ollama") class OllamaRecommenderTest diff --git a/inception/inception-imls-ollama/src/test/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/JsonExtractorTest.java b/inception/inception-imls-ollama/src/test/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/JsonExtractorTest.java new file mode 100644 index 00000000000..104ed96cc3b --- /dev/null +++ b/inception/inception-imls-ollama/src/test/java/de/tudarmstadt/ukp/inception/recommendation/imls/ollama/response/JsonExtractorTest.java @@ -0,0 +1,100 @@ +package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +class JsonExtractorTest +{ + private MentionsFromJsonExtractor sut = new MentionsFromJsonExtractor(); + + @Test + void testExtractMentionFromJson_variant1() + { + var json = """ + { + "Person": ["John"], + "Location": ["diner", "Starbucks"] + } + """; + assertThat(sut.extractMentionFromJson(json)) // + .containsExactly( // + Pair.of("John", "Person"), // + Pair.of("diner", "Location"), // + Pair.of("Starbucks", "Location")); + } + + @Test + void testExtractMentionFromJson_variant2() + { + var json = """ + { + "politicians": [ + { "name": "President Livingston" }, + { "name": "John" }, + { "name": "Don Horny" } + ] + } + """; + assertThat(sut.extractMentionFromJson(json)) // + .containsExactly( // + Pair.of("President Livingston", "politicians"), // + Pair.of("John", "politicians"), // + Pair.of("Don Horny", "politicians")); + } + + @Test + void testExtractMentionFromJson_variant3() + { + var json = """ + { + "John": {"type": "PERSON"}, + "diner": {"type": "LOCATION"}, + "Starbucks": {"type": "LOCATION"} + } + """; + assertThat(sut.extractMentionFromJson(json)) // + .containsExactly( // + Pair.of("John", null), // + Pair.of("diner", null), // + Pair.of("Starbucks", null)); + } + + @Test + void testExtractMentionFromJson_variant4() + { + var json = """ + { + "John": "politician", + "President Livingston": "politician", + "minister of foreign affairs": "politician", + "Don Horny": "politician" + } + """; + assertThat(sut.extractMentionFromJson(json)) // + .containsExactly( // + Pair.of("John", "politician"), // + Pair.of("President Livingston", "politician"), // + Pair.of("minister of foreign affairs", "politician"), // + Pair.of("Don Horny", "politician")); + } + + @Disabled("Cannot really tell this one apart from variant 4") + @Test + void testExtractMentionFromJson_variant5() + { + // We assume that the first item is the most relevant one (the + // mention) so we do not get a bad mention in cases like this: + var json = """ + { + "name": "Don Horny", + "affiliation": "Lord of Darkness" + } + """; + assertThat(sut.extractMentionFromJson(json)) // + .containsExactly( // + Pair.of("Don Horny", null)); + } +} diff --git a/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java b/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java index fbcfa173ebf..d2494c4362a 100644 --- a/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java +++ b/inception/inception-recommendation-api/src/main/java/de/tudarmstadt/ukp/inception/recommendation/api/recommender/RecommendationEngine.java @@ -191,12 +191,12 @@ public RecommenderContext newContext(RecommenderContext aContext) */ public abstract int estimateSampleCount(List aCasses); - protected Type getPredictedType(CAS aCas) + public Type getPredictedType(CAS aCas) { return getType(aCas, layerName); } - protected Feature getPredictedFeature(CAS aCas) + public Feature getPredictedFeature(CAS aCas) { return getPredictedType(aCas).getFeatureByBaseName(featureName); } @@ -219,7 +219,7 @@ protected Feature getModeFeature(CAS aCas) return getPredictedType(aCas).getFeatureByBaseName(scoreExplanationFeature); } - protected Feature getIsPredictionFeature(CAS aCas) + public Feature getIsPredictionFeature(CAS aCas) { return getPredictedType(aCas).getFeatureByBaseName(FEATURE_NAME_IS_PREDICTION); }