Skip to content

Commit

Permalink
Merge pull request #4417 from inception-project/feature/4399-Allow-OpenNLP-Multi-Token-Sequence-Classifier-to-work-for-cross-sentence-layers
Browse files Browse the repository at this point in the history

#4399 - Allow open nlp multi token sequence classifier to work for cross sentence layers
  • Loading branch information
reckart authored Dec 31, 2023
2 parents 6375f93 + d4582b0 commit 9bdc272
Show file tree
Hide file tree
Showing 12 changed files with 377 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,14 @@ public void thatSlotFeatureInConditionWorks() throws Exception
.withFeature("links", asList(
buildFS(cas, "webanno.custom.ComplexLinkType")
.withFeature("target", buildAnnotation(cas, "webanno.custom.Span")
.on("ACME")
.withFeature("value", "PER")
.on("ACME") //
.withFeature("value", "PER") //
.buildAndAddToIndexes())
.buildWithoutAddingToIndexes(),
buildFS(cas, "webanno.custom.ComplexLinkType")
.withFeature("target", buildAnnotation(cas, "webanno.custom.Span")
.on("Foobar")
.withFeature("value", "LOC")
.on("Foobar") //
.withFeature("value", "LOC") //
.buildAndAddToIndexes())
.buildWithoutAddingToIndexes()))
.buildAndAddToIndexes();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ void testGenerateExamples()
{
String text = "John likes Mary.";
cas.setDocumentText(text);
buildAnnotation(cas, Sentence.class).on(text) //
buildAnnotation(cas, Sentence.class).onMatch(text) //
.buildAndAddToIndexes();
buildAnnotation(cas, NamedEntity.class).on("John") //
.withFeature(NamedEntity._FeatName_value, "PER") //
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Objects;

Expand Down Expand Up @@ -64,13 +66,17 @@ public class OpenNlpNerRecommender
extends RecommendationEngine
{
public static final Key<TokenNameFinderModel> KEY_MODEL = new Key<>("opennlp_ner_model");

private static final Logger LOG = LoggerFactory.getLogger(OpenNlpNerRecommender.class);

private static final String NO_NE_TAG = "O";

private static final Class<Sentence> SAMPLE_UNIT = Sentence.class;
private static final Class<Token> DATAPOINT_UNIT = Token.class;

private static final int DEFAULT_WINDOW_SIZE = 300;
private static final int MIN_WINDOW_SIZE = 30;

private static final int MIN_TRAINING_SET_SIZE = 2;
private static final int MIN_TEST_SET_SIZE = 2;

Expand All @@ -92,7 +98,7 @@ public boolean isReadyForPrediction(RecommenderContext aContext)
@Override
public void train(RecommenderContext aContext, List<CAS> aCasses) throws RecommendationException
{
var nameSamples = extractNameSamples(aCasses);
var nameSamples = extractSamples(aContext, aCasses);

if (nameSamples.size() < 2) {
aContext.log(LogMessage.warn(getRecommender().getName(),
Expand Down Expand Up @@ -140,9 +146,11 @@ public Range predict(PredictionContext aContext, CAS aCas, int aBegin, int aEnd)
var predictionCount = 0;

for (var unit : units) {
if (predictionCount >= traits.getPredictionLimit()) {
int predictionsLimit = traits.getPredictionLimit();
if (predictionsLimit > 0 && predictionCount >= predictionsLimit) {
break;
}

predictionCount++;

var tokenAnnotations = selectCovered(tokenType, unit);
Expand Down Expand Up @@ -178,14 +186,17 @@ public Range predict(PredictionContext aContext, CAS aCas, int aBegin, int aEnd)
/**
 * Estimates how many training samples can be extracted from the given CASes.
 * <p>
 * Passes {@code null} as the context to {@code extractSamples} — no log messages are
 * emitted during estimation, only during actual training.
 */
@Override
public int estimateSampleCount(List<CAS> aCasses)
{
    return extractSamples(null, aCasses).size();
}

@Override
public EvaluationResult evaluate(List<CAS> aCasses, DataSplitter aDataSplitter)
throws RecommendationException
{
var data = extractNameSamples(aCasses);
// We use sentence-based samples here even if the layer allows cross-sentence annotations
// because with the overlapping sliding window, the evaluation would otherwise train on test
// data.
var data = extractSamplesFromSentences(aCasses);
var trainingSet = new ArrayList<NameSample>();
var testSet = new ArrayList<NameSample>();

Expand All @@ -209,7 +220,7 @@ public EvaluationResult evaluate(List<CAS> aCasses, DataSplitter aDataSplitter)
var trainRatio = (overallTrainingSize > 0) ? trainingSetSize / overallTrainingSize : 0.0;

if (trainingSetSize < MIN_TRAINING_SET_SIZE || testSetSize < MIN_TEST_SET_SIZE) {
String msg = String.format(
var msg = String.format(
"Not enough evaluation data: training set size [%d] (min. %d), test set size [%d] (min. %d) of total [%d] (min. %d)",
trainingSetSize, MIN_TRAINING_SET_SIZE, testSetSize, MIN_TEST_SET_SIZE,
data.size(), (MIN_TRAINING_SET_SIZE + MIN_TEST_SET_SIZE));
Expand Down Expand Up @@ -243,7 +254,7 @@ public EvaluationResult evaluate(List<CAS> aCasses, DataSplitter aDataSplitter)
var predictedNames = nameFinder.find(sampleTokens);
var goldNames = sample.getNames();

labelPairs.addAll(determineLabelsForASentence(sampleTokens, predictedNames, goldNames));
labelPairs.addAll(determineLabelsForSample(sampleTokens, predictedNames, goldNames));
}

return labelPairs.stream().collect(toEvaluationResult(DATAPOINT_UNIT.getSimpleName(),
Expand All @@ -254,7 +265,7 @@ public EvaluationResult evaluate(List<CAS> aCasses, DataSplitter aDataSplitter)
* Extract AnnotatedTokenPairs with info on predicted and gold label for each token of the given
* sentence.
*/
private List<LabelPair> determineLabelsForASentence(String[] sentence, Span[] predictedNames,
private List<LabelPair> determineLabelsForSample(String[] sentence, Span[] predictedNames,
Span[] goldNames)
{
int predictedNameIdx = 0;
Expand Down Expand Up @@ -307,7 +318,20 @@ private String determineLabel(Span aName, int aTokenIdx)
return label;
}

private List<NameSample> extractNameSamples(Iterable<CAS> aCasses)
private List<NameSample> extractSamples(RecommenderContext aContext, Iterable<CAS> aCasses)
{
if (getRecommender().getLayer().isCrossSentence()) {
if (aContext != null) {
aContext.log(LogMessage.info(getRecommender().getName(),
"Training using sliding-window since layer permits cross-sentence annotations."));
}
return extractSamplesUsingSlidingWindow(aCasses);
}

return extractSamplesFromSentences(aCasses);
}

private List<NameSample> extractSamplesFromSentences(Iterable<CAS> aCasses)
{
var nameSamples = new ArrayList<NameSample>();

Expand All @@ -317,7 +341,8 @@ private List<NameSample> extractNameSamples(Iterable<CAS> aCasses)

var firstSampleInCas = true;
for (var sampleUnit : cas.<Annotation> select(sampleUnitType)) {
if (nameSamples.size() >= traits.getTrainingSetSizeLimit()) {
int trainingSetSizeLimit = traits.getTrainingSetSizeLimit();
if (trainingSetSizeLimit > 0 && nameSamples.size() >= trainingSetSizeLimit) {
break nextCas;
}

Expand All @@ -328,7 +353,7 @@ private List<NameSample> extractNameSamples(Iterable<CAS> aCasses)
var tokens = cas.<Annotation> select(tokenType).coveredBy(sampleUnit).asList();
var tokenTexts = tokens.stream().map(AnnotationFS::getCoveredText)
.toArray(String[]::new);
var annotatedSpans = extractAnnotatedSpans(cas, sampleUnit, tokens);
var annotatedSpans = extractAnnotatedSpans(cas, tokens);
if (annotatedSpans.length == 0) {
continue;
}
Expand All @@ -342,18 +367,143 @@ private List<NameSample> extractNameSamples(Iterable<CAS> aCasses)
return nameSamples;
}

private Span[] extractAnnotatedSpans(CAS aCas, AnnotationFS aSampleUnit,
Collection<? extends AnnotationFS> aTokens)
private List<NameSample> extractSamplesUsingSlidingWindow(Iterable<CAS> aCasses)
{
var nameSamples = new ArrayList<NameSample>();

nextCas: for (var cas : aCasses) {
var windowSize = getWindowSize(cas);
var windowOverlap = windowSize / 2;

var firstSampleInCas = true;
var tokenIterator = cas.select(Token.class).iterator();
var tokens = makeSample(tokenIterator, new LinkedList<Token>(), windowSize,
windowOverlap);

while (!tokens.isEmpty()) {
int trainingSetSizeLimit = traits.getTrainingSetSizeLimit();
if (trainingSetSizeLimit > 0 && nameSamples.size() >= trainingSetSizeLimit) {
// Generated maximum number of samples
break nextCas;
}

var tokenTexts = tokens.stream() //
.map(AnnotationFS::getCoveredText) //
.toArray(String[]::new);
var annotatedSpans = extractAnnotatedSpans(cas, tokens);
if (annotatedSpans.length > 0) {
var nameSample = new NameSample(tokenTexts, annotatedSpans, firstSampleInCas);
nameSamples.add(nameSample);
firstSampleInCas = false;
}

tokens = makeSample(tokenIterator, tokens, windowSize, windowOverlap);
}
}

return nameSamples;
}

/**
 * Determines the sliding-window size (in characters of document text) for the given CAS.
 * Starts from the configured window size — or a built-in default when none is configured —
 * and shrinks it for short documents so that enough samples can still be generated, but
 * never returns less than the minimum window size.
 */
private int getWindowSize(CAS aCas)
{
    var textLength = aCas.getDocumentText().length();

    var configuredSize = traits.getWindowSize();
    var windowSize = configuredSize > 0 ? configuredSize : DEFAULT_WINDOW_SIZE;

    // If the document is short try scaling down the window size to get a
    // few more samples.
    var minDesiredSamples = 10;
    if (windowSize * minDesiredSamples > textLength) {
        windowSize = textLength / minDesiredSamples;
    }

    // If the document is too short to accommodate the minimum training set size
    // with the current window size, scale the window size down.
    if (windowSize * MIN_TRAINING_SET_SIZE > textLength) {
        windowSize = textLength / MIN_TRAINING_SET_SIZE;
    }

    // Never go below the minimum usable window size
    return Math.max(windowSize, MIN_WINDOW_SIZE);
}

/**
 * Assembles the next sample window: first carries over up to {@code aOverlap} characters
 * worth of tokens from the end of the previous sample ({@code aTokens}), then fills the
 * window with fresh tokens from the iterator up to {@code aMaxLength} characters. Blank
 * tokens are skipped entirely.
 *
 * @return the tokens of the next window, or an empty list when no fresh (non-blank) tokens
 *         remain — the caller uses this to terminate the sliding-window loop.
 */
private List<Token> makeSample(Iterator<Token> aFreshTokenIterator, List<Token> aTokens,
        int aMaxLength, int aOverlap)
{
    if (!aFreshTokenIterator.hasNext()) {
        // No fresh material left - signal end of iteration
        return Collections.emptyList();
    }

    var result = new LinkedList<Token>();

    // Add tokens overlapping with previous sample. BUGFIX: iterate the PREVIOUS sample
    // (aTokens) backwards from its end - the original iterated the freshly created (and
    // therefore still empty) result list, so no overlap was ever carried over.
    var size = 0;
    if (aOverlap > 0) {
        var overlapIterator = aTokens.listIterator(aTokens.size());
        while (overlapIterator.hasPrevious()) {
            var token = overlapIterator.previous();
            var tokenText = token.getCoveredText();

            if (isBlank(tokenText)) {
                continue;
            }

            size += tokenText.length();
            if (size >= aOverlap && !result.isEmpty()) {
                // Overlap size reached
                break;
            }
            // Prepend to keep the tokens in document order
            result.add(0, token);
        }
    }

    // Add fresh tokens
    var freshTokenAdded = false;
    while (aFreshTokenIterator.hasNext()) {
        var token = aFreshTokenIterator.next();
        var tokenText = token.getCoveredText();

        if (isBlank(tokenText)) {
            continue;
        }

        size += tokenText.length();
        if (size >= aMaxLength && freshTokenAdded) {
            // Maximum sample size reached.
            // NOTE(review): the token consumed here is discarded (the iterator has already
            // advanced), so it never appears in any window - confirm whether that loss at
            // window boundaries is acceptable or whether a push-back/peek is needed.
            break;
        }

        result.add(token);
        freshTokenAdded = true;
    }

    if (!freshTokenAdded) {
        // Only blank fresh tokens remained - treat the iterator as exhausted
        return Collections.emptyList();
    }

    return result;
}

private Span[] extractAnnotatedSpans(CAS aCas, List<? extends AnnotationFS> aTokens)
{
if (aTokens.isEmpty()) {
return new Span[0];
}
// Create spans from target annotations

// Collect relevant annotations
var annotationType = getType(aCas, layerName);
var feature = annotationType.getFeatureByBaseName(featureName);
var annotations = selectCovered(annotationType, aSampleUnit);

var windowBegin = aTokens.get(0).getBegin();
var windowEnd = aTokens.get(aTokens.size() - 1).getEnd();
var annotations = aCas.<Annotation> select(annotationType) //
.coveredBy(windowBegin, windowEnd) //
.asList();
if (annotations.isEmpty()) {
return new Span[0];
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@
*/
package de.tudarmstadt.ukp.inception.recommendation.imls.opennlp.ner;

import static de.tudarmstadt.ukp.clarin.webanno.model.AnchoringMode.SENTENCES;
import static de.tudarmstadt.ukp.clarin.webanno.model.AnchoringMode.SINGLE_TOKEN;
import static de.tudarmstadt.ukp.clarin.webanno.model.AnchoringMode.TOKENS;
import static de.tudarmstadt.ukp.inception.support.WebAnnoConst.SPAN_TYPE;
import static java.util.Arrays.asList;

import org.apache.uima.cas.CAS;

import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationFeature;
import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer;
import de.tudarmstadt.ukp.inception.annotation.layer.span.SpanLayerSupport;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngine;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngineFactoryImplBase;
Expand Down Expand Up @@ -67,8 +68,8 @@ public boolean accepts(AnnotationLayer aLayer, AnnotationFeature aFeature)
return false;
}

return (asList(SINGLE_TOKEN, TOKENS).contains(aLayer.getAnchoringMode()))
&& !aLayer.isCrossSentence() && SPAN_TYPE.equals(aLayer.getType())
return (asList(SINGLE_TOKEN, TOKENS, SENTENCES).contains(aLayer.getAnchoringMode()))
&& SpanLayerSupport.TYPE.equals(aLayer.getType())
&& (CAS.TYPE_NAME_STRING.equals(aFeature.getType()) || aFeature.isVirtualFeature());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;

import opennlp.tools.util.TrainingParameters;

Expand All @@ -30,8 +32,9 @@ public class OpenNlpNerRecommenderTraits
{
private static final long serialVersionUID = 7717316701623340670L;

private int trainingSetSizeLimit = Integer.MAX_VALUE;
private int predictionLimit = Integer.MAX_VALUE;
private @JsonInclude(Include.NON_DEFAULT) int trainingSetSizeLimit = 0;
private @JsonInclude(Include.NON_DEFAULT) int predictionLimit = 0;
private @JsonInclude(Include.NON_DEFAULT) int windowSize = 0;

private int numThreads = 1;

Expand Down Expand Up @@ -65,6 +68,16 @@ public void setPredictionLimit(int aPredictionLimit)
predictionLimit = aPredictionLimit;
}

/**
 * Sets the sliding-window size (in characters) used for cross-sentence layers. A value of
 * {@code 0} or less means "use the recommender's built-in default window size".
 */
public void setWindowSize(int aWindowSize)
{
    windowSize = aWindowSize;
}

/**
 * Returns the configured sliding-window size in characters; {@code 0} (the default) means
 * that no explicit size has been configured and a default should be applied.
 */
public int getWindowSize()
{
    return windowSize;
}

@JsonIgnore
public TrainingParameters getParameters()
{
Expand Down
Loading

0 comments on commit 9bdc272

Please sign in to comment.