Skip to content

Commit

Permalink
Merge pull request #4392 from inception-project/feature/4292-ollama-b…
Browse files Browse the repository at this point in the history
…ased-recommender

#4292 - ollama-based recommender
  • Loading branch information
reckart authored Dec 24, 2023
2 parents fa21655 + 43346f4 commit bdaeb95
Show file tree
Hide file tree
Showing 42 changed files with 349 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.DataSplitter;
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.EvaluationResult;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationException;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
Expand Down Expand Up @@ -90,7 +91,7 @@ public void train(RecommenderContext aContext, List<CAS> aCasList)
}

@Override
public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd)
public Range predict(PredictionContext aContext, CAS aCas, int aBegin, int aEnd)
throws RecommendationException
{
Type predictedType = getPredictedType(aCas);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import de.tudarmstadt.ukp.inception.kb.graph.KBHandle;
import de.tudarmstadt.ukp.inception.kb.model.KnowledgeBase;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
import de.tudarmstadt.ukp.inception.schema.api.feature.FeatureSupport;
import de.tudarmstadt.ukp.inception.schema.api.feature.FeatureSupportRegistry;
Expand Down Expand Up @@ -146,7 +147,7 @@ public void thatPredictionWorks() throws Exception
sut.train(context, Collections.singletonList(cas));
RecommenderTestHelper.addScoreFeature(cas, NamedEntity.class, "value");

sut.predict(context, cas);
sut.predict(new PredictionContext(context), cas);

List<NamedEntity> predictions = getPredictions(cas, NamedEntity.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.EvaluationResult;
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.LabelPair;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationException;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
Expand Down Expand Up @@ -133,7 +134,7 @@ private DataMajorityModel trainModel(List<Annotation> aAnnotations)

// tag::predict1[]
@Override
public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd)
public Range predict(PredictionContext aContext, CAS aCas, int aBegin, int aEnd)
throws RecommendationException
{
DataMajorityModel model = aContext.get(KEY_MODEL).orElseThrow(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.IncrementalSplitter;
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.PercentageBasedSplitter;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
import de.tudarmstadt.ukp.inception.support.test.recommendation.DkproTestHelper;
import de.tudarmstadt.ukp.inception.support.test.recommendation.RecommenderTestHelper;
Expand Down Expand Up @@ -104,7 +105,7 @@ public void thatPredictionWorks() throws Exception

sut.train(context, asList(cas));

sut.predict(context, cas);
sut.predict(new PredictionContext(context), cas);

Collection<NamedEntity> predictions = getPredictions(cas, NamedEntity.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@

import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.NonTrainableRecommenderEngineImplBase;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationException;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
import de.tudarmstadt.ukp.inception.recommendation.imls.elg.model.ElgAnnotation;
import de.tudarmstadt.ukp.inception.recommendation.imls.elg.model.ElgAnnotationsResponse;
import de.tudarmstadt.ukp.inception.recommendation.imls.elg.model.ElgServiceResponse;
Expand Down Expand Up @@ -67,7 +67,7 @@ public ElgRecommender(Recommender aRecommender, ElgRecommenderTraits aTraits,
}

@Override
public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd)
public Range predict(PredictionContext aContext, CAS aCas, int aBegin, int aEnd)
throws RecommendationException
{
ElgServiceResponse response;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.DataSplitter;
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.EvaluationResult;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationException;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
Expand Down Expand Up @@ -216,7 +217,7 @@ else if (response.statusCode() >= HTTP_BAD_REQUEST) {
}

@Override
public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd)
public Range predict(PredictionContext aContext, CAS aCas, int aBegin, int aEnd)
throws RecommendationException
{
var client = getClient();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
import de.tudarmstadt.ukp.inception.annotation.storage.CasMetadataUtils;
import de.tudarmstadt.ukp.inception.annotation.storage.CasStorageSession;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
import de.tudarmstadt.ukp.inception.recommendation.imls.external.v1.config.ExternalRecommenderPropertiesImpl;
import de.tudarmstadt.ukp.inception.recommendation.imls.external.v1.messages.PredictionRequest;
Expand Down Expand Up @@ -133,7 +134,7 @@ public void thatPredictingWorks() throws Exception

var cas = casses.get(0);
RecommenderTestHelper.addScoreFeature(cas, NamedEntity.class, "value");
sut.predict(context, cas);
sut.predict(new PredictionContext(context), cas);

var predictions = getPredictions(cas, NamedEntity.class);

Expand Down Expand Up @@ -180,7 +181,7 @@ public void thatPredictingSendsCorrectRequest() throws Exception

var cas = casses.get(0);
RecommenderTestHelper.addScoreFeature(cas, NamedEntity.class, "value");
sut.predict(context, cas);
sut.predict(new PredictionContext(context), cas);

var request = fromJsonString(PredictionRequest.class, requestBodies.get(1));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import org.xml.sax.SAXException;

import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationException;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
import de.tudarmstadt.ukp.inception.recommendation.imls.external.v1.messages.PredictionRequest;
Expand Down Expand Up @@ -93,7 +94,7 @@ public String predict(String aPredictionRequestJson)
}
}

recommendationEngine.predict(context, cas);
recommendationEngine.predict(new PredictionContext(context), cas);

return buildPredictionResponse(cas);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@

import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.NonTrainableRecommenderEngineImplBase;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationException;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
import de.tudarmstadt.ukp.inception.recommendation.imls.hf.client.HfInferenceClient;
import de.tudarmstadt.ukp.inception.recommendation.imls.hf.model.HfEntityGroup;
import de.tudarmstadt.ukp.inception.rendering.model.Range;
Expand All @@ -53,7 +53,7 @@ public HfRecommender(Recommender aRecommender, HfRecommenderTraits aTraits,
}

@Override
public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd)
public Range predict(PredictionContext aContext, CAS aCas, int aBegin, int aEnd)
throws RecommendationException
{
List<HfEntityGroup> response;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.DataSplitter;
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.EvaluationResult;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationException;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
Expand Down Expand Up @@ -67,7 +68,7 @@ public void train(RecommenderContext aContext, List<CAS> aCasses)
}

@Override
public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd)
public Range predict(PredictionContext aContext, CAS aCas, int aBegin, int aEnd)
throws RecommendationException
{
// FIXME: Ignores begin/end - always fetches predictions for the entire CAS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer;
import de.tudarmstadt.ukp.dkpro.core.api.lexmorph.type.pos.POS;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
import de.tudarmstadt.ukp.inception.recommendation.imls.lapps.traits.LappsGridRecommenderTraits;
import okhttp3.mockwebserver.MockResponse;
Expand Down Expand Up @@ -83,7 +84,7 @@ public void thatPredictingPosWorks() throws Exception
RecommenderContext context = new RecommenderContext();
CAS cas = loadData();

sut.predict(context, cas);
sut.predict(new PredictionContext(context), cas);

Collection<POS> predictions = JCasUtil.select(cas.getJCas(), POS.class);

Expand Down
12 changes: 12 additions & 0 deletions inception/inception-imls-ollama/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-api-render</artifactId>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-api-annotation</artifactId>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-layer-docmetadata</artifactId>
</dependency>
<dependency>
<groupId>de.tudarmstadt.ukp.inception.app</groupId>
<artifactId>inception-support</artifactId>
Expand Down Expand Up @@ -96,6 +104,10 @@
<groupId>org.apache.wicket</groupId>
<artifactId>wicket-spring</artifactId>
</dependency>
<dependency>
<groupId>com.googlecode.wicket-jquery-ui</groupId>
<artifactId>wicket-kendo-ui</artifactId>
</dependency>
<dependency>
<groupId>org.danekja</groupId>
<artifactId>jdk-serializable-functional</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,24 @@

import com.fasterxml.jackson.annotation.JsonProperty;

import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer;
import de.tudarmstadt.ukp.inception.annotation.layer.span.SpanLayerSupport;

public enum ExtractionMode
{
@JsonProperty("response-as-label")
RESPONSE_AS_LABEL, //

@JsonProperty("mentions-from-json")
MENTIONS_FROM_JSON
MENTIONS_FROM_JSON;

public boolean accepts(AnnotationLayer aLayer)
{
if (this == MENTIONS_FROM_JSON) {
// Mention extraction only makes sense for span layers
return SpanLayerSupport.TYPE.equals(aLayer.getType());
}

return true;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,32 +19,55 @@

import static java.util.Arrays.asList;

import java.util.Collections;

import org.apache.wicket.markup.html.form.DropDownChoice;
import org.apache.wicket.markup.html.form.EnumChoiceRenderer;
import org.apache.wicket.model.IModel;

import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;

public class ExtractionModeSelect
extends DropDownChoice<ExtractionMode>
{
private static final long serialVersionUID = 1789605828488016006L;

public ExtractionModeSelect(String aId)
{
super(aId);
}
private IModel<Recommender> recommender;

public ExtractionModeSelect(String aId, IModel<ExtractionMode> aModel)
public ExtractionModeSelect(String aId, IModel<ExtractionMode> aModel,
IModel<Recommender> aRecommender)
{
super(aId);
setModel(aModel);
recommender = aRecommender;
}

@Override
protected void onInitialize()
{
super.onInitialize();

setChoiceRenderer(new EnumChoiceRenderer<>(this));
setChoices(asList(ExtractionMode.values()));
}

@Override
protected void onConfigure()
{
super.onConfigure();

if (!recommender.isPresent().getObject()) {
setChoices(Collections.emptyList());
return;
}

var validChoices = asList(ExtractionMode.values()).stream() //
.filter(e -> e.accepts(recommender.getObject().getLayer())) //
.toList();
setChoices(validChoices);

if (validChoices.size() == 1) {
setModelObject(validChoices.get(0));
}

setVisible(validChoices.size() > 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.nio.charset.Charset;
import java.util.List;

import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.uima.cas.CAS;
Expand All @@ -37,8 +36,8 @@

import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.NonTrainableRecommenderEngineImplBase;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.PredictionContext;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationException;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaClient;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaGenerateRequest;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PerAnnotationContextGenerator;
Expand All @@ -47,7 +46,6 @@
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptContext;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.prompt.PromptContextGenerator;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response.MentionsFromJsonExtractor;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response.MentionsSample;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response.ResponseAsLabelExtractor;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.response.ResponseExtractor;
import de.tudarmstadt.ukp.inception.rendering.model.Range;
Expand Down Expand Up @@ -88,12 +86,11 @@ public String getString(String aFullName, Charset aEncoding,
}

@Override
public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd)
public Range predict(PredictionContext aContext, CAS aCas, int aBegin, int aEnd)
throws RecommendationException
{
var responseExtractor = getResponseExtractor();
List<MentionsSample> examples = responseExtractor.generate(this, aCas,
MAX_FEW_SHOT_EXAMPLES);
var examples = responseExtractor.generate(this, aCas, MAX_FEW_SHOT_EXAMPLES);

getPromptContextGenerator().generate(this, aCas, aBegin, aEnd).forEach(promptContext -> {
try {
Expand All @@ -106,6 +103,8 @@ public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd
responseExtractor.extract(this, aCas, promptContext, response);
}
catch (IOException e) {
aContext.error("Ollama [%s] failed to respond: %s", traits.getModel(),
ExceptionUtils.getRootCauseMessage(e));
LOG.error("Ollama [{}] failed to respond: {}", traits.getModel(),
ExceptionUtils.getRootCauseMessage(e));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
*/
package de.tudarmstadt.ukp.inception.recommendation.imls.ollama;

import static de.tudarmstadt.ukp.inception.support.WebAnnoConst.SPAN_TYPE;
import static org.apache.uima.cas.CAS.TYPE_NAME_STRING;

import java.lang.invoke.MethodHandles;
Expand All @@ -34,12 +33,14 @@

import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationFeature;
import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer;
import de.tudarmstadt.ukp.inception.annotation.layer.span.SpanLayerSupport;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngineFactoryImplBase;
import de.tudarmstadt.ukp.inception.recommendation.imls.ollama.client.OllamaClient;
import de.tudarmstadt.ukp.inception.support.io.WatchedResourceFile;
import de.tudarmstadt.ukp.inception.support.yaml.YamlUtil;
import de.tudarmstadt.ukp.inception.ui.core.docanno.layer.DocumentMetadataLayerSupport;

public class OllamaRecommenderFactory
extends RecommendationEngineFactoryImplBase<OllamaRecommenderTraits>
Expand Down Expand Up @@ -87,7 +88,8 @@ public RecommendationEngine build(Recommender aRecommender)
@Override
public boolean accepts(AnnotationLayer aLayer, AnnotationFeature aFeature)
{
return SPAN_TYPE.equals(aFeature.getLayer().getType())
return (SpanLayerSupport.TYPE.equals(aFeature.getLayer().getType())
|| DocumentMetadataLayerSupport.TYPE.equals(aFeature.getLayer().getType()))
&& TYPE_NAME_STRING.equals(aFeature.getType());
}

Expand Down
Loading

0 comments on commit bdaeb95

Please sign in to comment.