Skip to content

Commit

Permalink
#4292 - ollama-based recommender
Browse files Browse the repository at this point in the history
- Steps towards providing context to prompt templates
  • Loading branch information
reckart committed Nov 27, 2023
1 parent 7b88375 commit 881ef5b
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd
private Range predict(PromptBindingsGenerator aGenerator, RecommenderContext aContext, CAS aCas,
int aBegin, int aEnd)
{
aGenerator.generate(aCas, aBegin, aEnd).forEach(promptContext -> {
aGenerator.generate(this, aCas, aBegin, aEnd).forEach(promptContext -> {
try {
var prompt = jinjava.render(traits.getPrompt(), promptContext.getBindings());
var response = query(prompt);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static de.tudarmstadt.ukp.inception.support.WebAnnoConst.SPAN_TYPE;
import static org.apache.uima.cas.CAS.TYPE_NAME_STRING;

import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Collections;
Expand Down Expand Up @@ -121,7 +120,7 @@ private List<Preset> getPresets()
try {
return presets.get().get();
}
catch (IOException e) {
catch (Exception e) {
LOG.error("Unable to load presets", e);
return Collections.emptyList();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,49 @@

import static de.tudarmstadt.ukp.clarin.webanno.api.annotation.util.WebAnnoCasUtil.selectOverlapping;

import java.util.LinkedHashMap;
import java.util.stream.Stream;

import org.apache.uima.cas.CAS;
import org.apache.uima.fit.util.CasUtil;
import org.apache.uima.fit.util.FSUtil;

import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;

public class PerAnnotationBindingsGenerator
implements PromptBindingsGenerator
{

@Override
public Stream<PromptContext> generate(CAS aCas, int aBegin, int aEnd)
public Stream<PromptContext> generate(RecommendationEngine aEngine, CAS aCas, int aBegin,
int aEnd)
{
var candidateType = CasUtil.getAnnotationType(aCas, Sentence.class);
return selectOverlapping(aCas, candidateType, aBegin, aEnd).stream().map(candidate -> {
var predictedFeature = aEngine.getPredictedFeature(aCas);

var candidates = selectOverlapping(aCas, candidateType, aBegin, aEnd);

var examples = new LinkedHashMap<String, String>();
for (var candidate : candidates) {
var text = candidate.getCoveredText();
var label = FSUtil.getFeature(candidate, predictedFeature, String.class);

examples.put(text, label);

if (examples.size() >= 10) {
break;
}
}

return candidates.stream().map(candidate -> {
var sentence = aCas.select(Sentence.class).covering(candidate) //
.map(Sentence::getCoveredText) //
.findFirst().orElse("");
var context = new PromptContext(candidate);
context.set(VAR_TEXT, candidate.getCoveredText());
context.set(VAR_SENTENCE, sentence);
context.set(VAR_EXAMPLES, examples);
return context;

});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@

import org.apache.uima.cas.CAS;

import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;

public class PerDocumentBindingsGenerator
implements PromptBindingsGenerator
{

@Override
public Stream<PromptContext> generate(CAS aCas, int aBegin, int aEnd)
public Stream<PromptContext> generate(RecommendationEngine aEngine, CAS aCas, int aBegin,
int aEnd)
{
var candidate = aCas.getDocumentAnnotation();
var context = new PromptContext(candidate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@
import org.apache.uima.fit.util.CasUtil;

import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;

public class PerSentenceBindingsGenerator
implements PromptBindingsGenerator
{

@Override
public Stream<PromptContext> generate(CAS aCas, int aBegin, int aEnd)
public Stream<PromptContext> generate(RecommendationEngine aEngine, CAS aCas, int aBegin,
int aEnd)
{
var candidateType = CasUtil.getAnnotationType(aCas, Sentence.class);
return selectOverlapping(aCas, candidateType, aBegin, aEnd).stream().map(candidate -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,14 @@

import org.apache.uima.cas.CAS;

import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;

public interface PromptBindingsGenerator
{
static final String VAR_TEXT = "text";
static final String VAR_SENTENCE = "sentence";
static final String VAR_DOCUMENT = "document";
static final String VAR_EXAMPLES = "examples";

Stream<PromptContext> generate(CAS aCas, int aBegin, int aEnd);
Stream<PromptContext> generate(RecommendationEngine aEngine, CAS aCas, int aBegin, int aEnd);
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
public class PromptContext
{
private final AnnotationFS candidate;
private final Map<String, String> bindings;
private final Map<String, Object> bindings;

public PromptContext(AnnotationFS aCandidate)
{
Expand All @@ -38,12 +38,12 @@ public AnnotationFS getCandidate()
return candidate;
}

public void set(String aKey, String aValue)
public void set(String aKey, Object aValue)
{
bindings.put(aKey, aValue);
}

public Map<String, String> getBindings()
public Map<String, Object> getBindings()
{
return bindings;
}
Expand Down

0 comments on commit 881ef5b

Please sign in to comment.