Skip to content

Commit

Permalink
#4292 - ollama-based recommender
Browse files Browse the repository at this point in the history
- Steps towards few-shot predictions
  • Loading branch information
reckart committed Nov 28, 2023
1 parent 881ef5b commit 5162dae
Show file tree
Hide file tree
Showing 18 changed files with 455 additions and 94 deletions.
5 changes: 5 additions & 0 deletions inception/inception-imls-ollama/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,10 @@
<artifactId>dkpro-core-api-ner-asl</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-testing</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
*/
package de.tudarmstadt.ukp.inception.recommendation.imls.ollama;

import static de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptContextGenerator.VAR_EXAMPLES;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.nio.charset.Charset;
import java.util.List;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.uima.cas.CAS;
Expand All @@ -38,18 +41,22 @@
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaClient;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaGenerateRequest;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PerAnnotationBindingsGenerator;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PerDocumentBindingsGenerator;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PerSentenceBindingsGenerator;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptBindingsGenerator;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PerAnnotationContextGenerator;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PerDocumentContextGenerator;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PerSentenceContextGenerator;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptContext;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptContextGenerator;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response.MentionsFromJsonExtractor;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response.MentionsSample;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response.ResponseAsLabelExtractor;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response.ResponseExtractor;
import de.tudarmstadt.ukp.inception.rendering.model.Range;

public class OllamaRecommender
extends NonTrainableRecommenderEngineImplBase
{
private static final int MAX_FEW_SHOT_EXAMPLES = 10;

private final static Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

private final OllamaRecommenderTraits traits;
Expand Down Expand Up @@ -84,28 +91,18 @@ public String getString(String aFullName, Charset aEncoding,
public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd)
throws RecommendationException
{
switch (traits.getPromptingMode()) {
case PER_ANNOTATION:
return predict(new PerAnnotationBindingsGenerator(), aContext, aCas, aBegin, aEnd);
case PER_SENTENCE:
return predict(new PerSentenceBindingsGenerator(), aContext, aCas, aBegin, aEnd);
case PER_DOCUMENT:
return predict(new PerDocumentBindingsGenerator(), aContext, aCas, aBegin, aEnd);
default:
throw new RecommendationException(
"Unsupported mode [" + traits.getPromptingMode() + "]");
}
}
var responseExtractor = getResponseExtractor();
List<MentionsSample> examples = responseExtractor.generate(this, aCas, MAX_FEW_SHOT_EXAMPLES);

private Range predict(PromptBindingsGenerator aGenerator, RecommenderContext aContext, CAS aCas,
int aBegin, int aEnd)
{
aGenerator.generate(this, aCas, aBegin, aEnd).forEach(promptContext -> {
getPromptContextGenerator().generate(this, aCas, aBegin, aEnd).forEach(promptContext -> {
try {
var prompt = jinjava.render(traits.getPrompt(), promptContext.getBindings());
var response = query(prompt);
var bindings = promptContext.getBindings();

bindings.put(VAR_EXAMPLES, examples);

var response = query(promptContext);

extractPredictions(aCas, promptContext, response);
responseExtractor.extract(this, aCas, promptContext, response);
}
catch (IOException e) {
LOG.error("Ollama [{}] failed to respond: {}", traits.getModel(),
Expand All @@ -116,30 +113,47 @@ private Range predict(PromptBindingsGenerator aGenerator, RecommenderContext aCo
return new Range(aBegin, aEnd);
}

private String query(String prompt) throws IOException
private String query(PromptContext aContext) throws IOException
{
var prompt = jinjava.render(traits.getPrompt(), aContext.getBindings());

LOG.trace("Querying ollama [{}]: [{}]", traits.getModel(), prompt);
var request = OllamaGenerateRequest.builder() //
.withModel(traits.getModel()) //
.withPrompt(prompt) //
.withFormat(traits.getFormat()) //
.withRaw(traits.isRaw()) //
.withStream(false) //
// FIXME: Make NUM_PREDICT accessible in UI
.withOption(OllamaGenerateRequest.NUM_PREDICT, 300) //
.build();
var response = client.generate(traits.getUrl(), request).trim();
LOG.trace("Ollama [{}] responds: [{}]", traits.getModel(), response);
return response;
}

private void extractPredictions(CAS aCas, PromptContext aContext, String aResponse)
private PromptContextGenerator getPromptContextGenerator()
{
switch (traits.getPromptingMode()) {
case PER_ANNOTATION:
return new PerAnnotationContextGenerator();
case PER_SENTENCE:
return new PerSentenceContextGenerator();
case PER_DOCUMENT:
return new PerDocumentContextGenerator();
default:
throw new IllegalArgumentException(
"Unsupported mode [" + traits.getPromptingMode() + "]");
}
}

private ResponseExtractor getResponseExtractor()
{
switch (traits.getExtractionMode()) {
case RESPONSE_AS_LABEL:
new ResponseAsLabelExtractor().extract(this, aCas, aContext, aResponse);
break;
return new ResponseAsLabelExtractor();
case MENTIONS_FROM_JSON:
new MentionsFromJsonExtractor().extract(this, aCas, aContext, aResponse);
break;
return new MentionsFromJsonExtractor();
default:
throw new IllegalArgumentException(
"Unsupported extraction mode [" + traits.getExtractionMode() + "]");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,11 @@
public class OllamaRecommenderTraits
implements Serializable
{
public static final String DEFAULT_OLLAMA_URL = "http://localhost:11434/";

private static final long serialVersionUID = -8760059914187478368L;

private String url = "http://localhost:11434/";
private String url = DEFAULT_OLLAMA_URL;

private String model;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ private OllamaGenerateRequest(Builder builder)
format = builder.format;
stream = builder.stream;
raw = builder.raw;
options = builder.options;
}

public OllamaGenerateResponseFormat getFormat()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,48 +19,29 @@

import static de.tudarmstadt.ukp.clarin.webanno.api.annotation.util.WebAnnoCasUtil.selectOverlapping;

import java.util.LinkedHashMap;
import java.util.stream.Stream;

import org.apache.uima.cas.CAS;
import org.apache.uima.fit.util.CasUtil;
import org.apache.uima.fit.util.FSUtil;

import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;

public class PerAnnotationBindingsGenerator
implements PromptBindingsGenerator
public class PerAnnotationContextGenerator
implements PromptContextGenerator
{
@Override
public Stream<PromptContext> generate(RecommendationEngine aEngine, CAS aCas, int aBegin,
int aEnd)
{
var candidateType = CasUtil.getAnnotationType(aCas, Sentence.class);
var predictedFeature = aEngine.getPredictedFeature(aCas);

var candidates = selectOverlapping(aCas, candidateType, aBegin, aEnd);

var examples = new LinkedHashMap<String, String>();
for (var candidate : candidates) {
var text = candidate.getCoveredText();
var label = FSUtil.getFeature(candidate, predictedFeature, String.class);

examples.put(text, label);

if (examples.size() >= 10) {
break;
}
}

var predictedType = aEngine.getPredictedType(aCas);
var candidates = selectOverlapping(aCas, predictedType, aBegin, aEnd);
return candidates.stream().map(candidate -> {
var sentence = aCas.select(Sentence.class).covering(candidate) //
.map(Sentence::getCoveredText) //
.findFirst().orElse("");
var context = new PromptContext(candidate);
context.set(VAR_TEXT, candidate.getCoveredText());
context.set(VAR_SENTENCE, sentence);
context.set(VAR_EXAMPLES, examples);
return context;
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@

import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;

public class PerDocumentBindingsGenerator
implements PromptBindingsGenerator
public class PerDocumentContextGenerator
implements PromptContextGenerator
{

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,18 @@
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;

public class PerSentenceBindingsGenerator
implements PromptBindingsGenerator
public class PerSentenceContextGenerator
implements PromptContextGenerator
{
@Override
public Stream<PromptContext> generate(RecommendationEngine aEngine, CAS aCas, int aBegin,
int aEnd)
{
var candidateType = CasUtil.getAnnotationType(aCas, Sentence.class);
return selectOverlapping(aCas, candidateType, aBegin, aEnd).stream().map(candidate -> {
var sentenceType = CasUtil.getAnnotationType(aCas, Sentence.class);

var candidates = selectOverlapping(aCas, sentenceType, aBegin, aEnd);

return candidates.stream().map(candidate -> {
var context = new PromptContext(candidate);
context.set(VAR_TEXT, candidate.getCoveredText());
return context;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;

public interface PromptBindingsGenerator
public interface PromptContextGenerator
{
static final String VAR_TEXT = "text";
static final String VAR_SENTENCE = "sentence";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,21 @@
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.apache.commons.lang3.tuple.Pair;
import org.apache.uima.cas.CAS;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.util.FSUtil;
import org.apache.uima.jcas.tcas.Annotation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.nimbusds.oauth2.sdk.util.StringUtils;

import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptContext;
import de.tudarmstadt.ukp.inception.support.json.JSONUtil;
Expand All @@ -37,6 +44,53 @@ public class MentionsFromJsonExtractor
{
private final static Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

@Override
public List<MentionsSample> generate(RecommendationEngine aEngine, CAS aCas, int aNum)
{
var examples = generateSamples(aEngine, aCas, aNum);

return new ArrayList<>(examples.values());
}

Map<String, MentionsSample> generateSamples(RecommendationEngine aEngine, CAS aCas, int aNum)
{
var predictedType = aEngine.getPredictedType(aCas);
var predictedFeature = aEngine.getPredictedFeature(aCas);

var examples = new LinkedHashMap<String, MentionsSample>();
for (var candidate : aCas.<Annotation> select(predictedType)) {
var sentence = aCas.select(Sentence.class).covering(candidate) //
.map(Sentence::getCoveredText) //
.findFirst().orElse("");

// Skip mentions for which we did not find a sentence
if (StringUtils.isBlank(sentence)) {
continue;
}

// Stop once we have sufficient samples
if (!examples.containsKey(sentence) && examples.size() > aNum) {
break;
}

var example = examples.computeIfAbsent(sentence, MentionsSample::new);
var text = candidate.getCoveredText();
var label = FSUtil.getFeature(candidate, predictedFeature, String.class);
example.addMention(text, label);
}
return examples;
}

private String toJson(Object aObject)
{
try {
return JSONUtil.toJsonString(aObject);
}
catch (IOException e) {
return null;
}
}

@Override
public void extract(RecommendationEngine aEngine, CAS aCas, PromptContext aContext,
String aResponse)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Technische Universität Darmstadt under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The Technische Universität Darmstadt
* licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response;

import java.util.LinkedHashMap;
import java.util.Map;

public class MentionsSample
{
private final String text;
private final Map<String, String> labelledMentions = new LinkedHashMap<>();

public MentionsSample(String aText)
{
text = aText;
}

public void addMention(String aMention, String aLabel)
{
labelledMentions.put(aMention, aLabel);
}

public String getText()
{
return text;
}

public Map<String, String> getLabelledMentions()
{
return labelledMentions;
}
}
Loading

0 comments on commit 5162dae

Please sign in to comment.