Skip to content

Commit

Permalink
#5176 - Ability to prevent string matcher from learning things
Browse files Browse the repository at this point in the history
- Added minimum term length setting for learning
- Added regular expression to exclude certain kinds of terms
  • Loading branch information
reckart committed Nov 21, 2024
1 parent 34ae130 commit ccb85d7
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
import java.util.stream.IntStream;
import java.util.stream.Stream;

Expand All @@ -65,6 +67,7 @@
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommenderContext.Key;
import de.tudarmstadt.ukp.inception.recommendation.imls.stringmatch.span.gazeteer.GazeteerService;
import de.tudarmstadt.ukp.inception.recommendation.imls.stringmatch.span.gazeteer.model.GazeteerEntry;
import de.tudarmstadt.ukp.inception.recommendation.imls.stringmatch.span.trie.KeySanitizerFactory;
import de.tudarmstadt.ukp.inception.recommendation.imls.stringmatch.span.trie.Trie;
import de.tudarmstadt.ukp.inception.recommendation.imls.stringmatch.span.trie.WhitespaceNormalizingSanitizer;
import de.tudarmstadt.ukp.inception.rendering.model.Range;
Expand All @@ -87,6 +90,11 @@ public class StringMatchingRecommender

private final GazeteerService gazeteerService;

private final KeySanitizerFactory keySanitizerFactory;

private Pattern excludePattern;
private String excludePatternError;

public StringMatchingRecommender(Recommender aRecommender,
StringMatchingRecommenderTraits aTraits)
{
Expand All @@ -100,6 +108,19 @@ public StringMatchingRecommender(Recommender aRecommender,

traits = aTraits;
gazeteerService = aGazeteerService;
keySanitizerFactory = WhitespaceNormalizingSanitizer.factory();

if (traits != null && traits.getExcludePattern() != null) {
try {
excludePattern = Pattern.compile(traits.getExcludePattern());
}
catch (PatternSyntaxException e) {
excludePatternError = e.getMessage();
}
}
else {
excludePattern = null;
}
}

@Override
Expand Down Expand Up @@ -136,7 +157,7 @@ public void pretrain(List<GazeteerEntry> aData, RecommenderContext aContext)

if (aData != null) {
for (var entry : aData) {
learn(dict, entry.text, entry.label);
learn(dict, entry.text, entry.label, true);
}

aContext.log(LogMessage.info(getRecommender().getName(),
Expand All @@ -148,12 +169,17 @@ public void pretrain(List<GazeteerEntry> aData, RecommenderContext aContext)

private <T> Trie<T> createTrie()
{
return new Trie<>(WhitespaceNormalizingSanitizer.factory());
return new Trie<>(keySanitizerFactory);
}

@Override
public void train(RecommenderContext aContext, List<CAS> aCasses) throws RecommendationException
{
if (excludePatternError != null) {
aContext.log(LogMessage.error(getRecommender().getName(),
"Ignoring bad exclude pattern: %s", excludePatternError));
}

// Pre-load the gazeteers into the model
if (gazeteerService != null) {
for (var gaz : gazeteerService.listGazeteers(recommender)) {
Expand Down Expand Up @@ -185,13 +211,13 @@ public void train(RecommenderContext aContext, List<CAS> aCasses) throws Recomme
var labels = FSUtil.getFeature(ann, predictedFeature, String[].class);
if (labels != null) {
for (var label : labels) {
learn(dict, ann.getCoveredText(), label);
learn(dict, ann.getCoveredText(), label, false);
}
}
}
else {
learn(dict, ann.getCoveredText(),
ann.getFeatureValueAsString(predictedFeature));
learn(dict, ann.getCoveredText(), ann.getFeatureValueAsString(predictedFeature),
false);
}
}
}
Expand Down Expand Up @@ -356,7 +382,7 @@ public EvaluationResult evaluate(List<CAS> aCasses, DataSplitter aDataSplitter)
Trie<DictEntry> dict = createTrie();
for (var sample : trainingSet) {
for (var span : sample.getSpans()) {
learn(dict, span.text(), span.label());
learn(dict, span.text(), span.label(), false);
}
}

Expand Down Expand Up @@ -391,12 +417,22 @@ public EvaluationResult evaluate(List<CAS> aCasses, DataSplitter aDataSplitter)
SAMPLE_UNIT.getSimpleName(), trainingSetSize, testSetSize, trainRatio, NO_LABEL));
}

private void learn(Trie<DictEntry> aDict, String aText, String aLabel)
private void learn(Trie<DictEntry> aDict, String aText, String aLabel, boolean aBypassLimits)
{
if (isBlank(aText)) {
return;
}

if (!aBypassLimits && traits != null) {
if (excludePattern != null && excludePattern.matcher(aText).matches()) {
return;
}

if (keySanitizerFactory.create().sanitize(aText).length() < traits.getMinLength()) {
return;
}
}

var label = isBlank(aLabel) ? BLANK_LABEL : aLabel;

var text = aText;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,16 @@
public class StringMatchingRecommenderTraits
implements Serializable
{
private static final long serialVersionUID = -7433406243352691789L;
private static final long serialVersionUID = -7329491581513178640L;

private boolean ignoreCase;

private String excludePattern;

private int minLength = 3;

// private int maxLength = 255;

public boolean isIgnoreCase()
{
return ignoreCase;
Expand All @@ -42,4 +48,35 @@ public void setIgnoreCase(boolean aIgnoreCase)
{
ignoreCase = aIgnoreCase;
}

/**
 * Returns the minimum length a term must have to be learned. The recommender compares this
 * value against the length of the sanitized term text; shorter terms are skipped unless the
 * limits are bypassed. Defaults to {@code 3}.
 *
 * @return the minimum term length.
 */
public int getMinLength()
{
    return this.minLength;
}

/**
 * Sets the minimum length a term must have to be learned.
 *
 * @param aMinLength
 *            the minimum term length.
 */
public void setMinLength(int aMinLength)
{
    this.minLength = aMinLength;
}

/**
 * Returns the uncompiled regular expression describing terms that should not be learned.
 *
 * @return the regular expression source, or {@code null} if no exclude pattern has been set.
 */
public String getExcludePattern()
{
    return this.excludePattern;
}

/**
 * Sets the regular expression describing terms that should not be learned. The value is stored
 * as-is; it is not compiled or validated here.
 *
 * @param aExcludePattern
 *            the regular expression source, or {@code null} to disable the exclusion.
 */
public void setExcludePattern(String aExcludePattern)
{
    this.excludePattern = aExcludePattern;
}

// public int getMaxLength()
// {
// return maxLength;
// }
//
// public void setMaxLength(int aMaxLength)
// {
// maxLength = aMaxLength;
// }

}
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,39 @@
<form wicket:id="form">
<div class="row form-row" wicket:enclosure="ignoreCase">
<div class="offset-sm-3 col-sm-9">
<div class="form-check">
<div class="form-check form-switch">
<input wicket:id="ignoreCase" class="form-check-input" type="checkbox"/>
<label wicket:for="ignoreCase" class="form-check-label">
<wicket:label key="ignoreCase"/>
</label>
</div>
</div>
</div>
<div class="row form-row" wicket:enclosure="excludePattern">
  <label wicket:for="excludePattern" class="col-sm-3 col-form-label">
    <wicket:label key="excludePattern"/>
  </label>
  <div class="col-sm-9">
    <div class="input-group">
      <input wicket:id="excludePattern" class="form-control"/>
      <span class="input-group-text">(.*)</span>
    </div>
    <div class="form-text">
      Regular expression for terms to exclude. An annotated text is not added to the
      dictionary if the expression matches it in full. Gazeteer entries are not filtered.
    </div>
  </div>
</div>
<div class="row form-row" wicket:enclosure="minLength">
<label wicket:for="minLength" class="col-sm-3 col-form-label">
<wicket:label key="minLength"/>
</label>
<div class="col-sm-9">
<input wicket:id="minLength" type="number" class="form-control"/>
<div class="form-text">
Minimum length for dictionary entries. Shorter entries will not be added.
</div>
</div>
</div>
</form>
<wicket:remove>
<!--
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@
import org.apache.wicket.markup.html.basic.Label;
import org.apache.wicket.markup.html.form.CheckBox;
import org.apache.wicket.markup.html.form.Form;
import org.apache.wicket.markup.html.form.NumberTextField;
import org.apache.wicket.markup.html.form.TextField;
import org.apache.wicket.markup.html.link.DownloadLink;
import org.apache.wicket.markup.html.list.ListItem;
import org.apache.wicket.markup.html.list.ListView;
Expand All @@ -51,7 +53,7 @@

import de.agilecoders.wicket.extensions.markup.html.bootstrap.form.fileinput.FileInputConfig;
import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationFeature;
import de.tudarmstadt.ukp.inception.bootstrap.BootstrapFileInput;
import de.tudarmstadt.ukp.inception.bootstrap.BootstrapFileInputField;
import de.tudarmstadt.ukp.inception.recommendation.api.model.Recommender;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.DefaultTrainableRecommenderTraitsEditor;
import de.tudarmstadt.ukp.inception.recommendation.api.recommender.RecommendationEngineFactory;
Expand All @@ -76,7 +78,7 @@ public class StringMatchingRecommenderTraitsEditor
private final StringMatchingRecommenderTraits traits;

private GazeteerList gazeteers;
private BootstrapFileInput uploadField;
private BootstrapFileInputField uploadField;

public StringMatchingRecommenderTraitsEditor(String aId, IModel<Recommender> aRecommender)
{
Expand All @@ -87,7 +89,7 @@ public StringMatchingRecommenderTraitsEditor(String aId, IModel<Recommender> aRe
var form = new Form<StringMatchingRecommenderTraits>(MID_FORM,
CompoundPropertyModel.of(Model.of(traits)))
{
private static final long serialVersionUID = -3109239605742291123L;
private static final long serialVersionUID = -1L;

@Override
protected void onSubmit()
Expand All @@ -96,14 +98,21 @@ protected void onSubmit()
toolFactory.writeTraits(aRecommender.getObject(), traits);
}
};
add(form);

var ignoreCase = new CheckBox("ignoreCase");
ignoreCase.setOutputMarkupId(true);
ignoreCase.add(LambdaBehavior.visibleWhen(getModel() //
.map(Recommender::getFeature) //
.map(this::isStringBasedFeature)));
form.add(ignoreCase);
add(form);

form.add(new NumberTextField<>("minLength", Integer.class) //
.setMinimum(1) //
.setMaximum(500) //
.setStep(1));

form.add(new TextField<>("excludePattern", String.class));

gazeteers = new GazeteerList("gazeteers", LoadableDetachableModel.of(this::listGazeteers));
gazeteers.add(visibleWhen(getModel().map(Recommender::getId).isPresent()));
Expand All @@ -114,9 +123,9 @@ protected void onSubmit()
config.allowedFileExtensions(asList("txt"));
config.showPreview(false);
config.showUpload(true);
uploadField = new BootstrapFileInput("upload", new ListModel<>(), config)
uploadField = new BootstrapFileInputField("upload", new ListModel<>(), config)
{
private static final long serialVersionUID = -7072183979425490246L;
private static final long serialVersionUID = -1L;

@Override
protected void onSubmit(AjaxRequestTarget aTarget)
Expand Down Expand Up @@ -195,7 +204,7 @@ private List<Gazeteer> listGazeteers()
public class GazeteerList
extends WebMarkupContainer
{
private static final long serialVersionUID = -2049981253344229438L;
private static final long serialVersionUID = -1L;

private ListView<Gazeteer> gazeteerList;

Expand All @@ -207,7 +216,7 @@ public GazeteerList(String aId, IModel<? extends List<Gazeteer>> aChoices)

gazeteerList = new ListView<Gazeteer>("gazeteer", aChoices)
{
private static final long serialVersionUID = 2827701590781214260L;
private static final long serialVersionUID = -1L;

@Override
protected void populateItem(ListItem<Gazeteer> aItem)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@
# limitations under the License.
gazeteers=Gazeteers
ignoreCase=Case insensitive
excludePattern=Exclude pattern
minLength=Minimum length

0 comments on commit ccb85d7

Please sign in to comment.