Skip to content

Commit

Permalink
#5003 - Allow exporting models from the OpenNLP recommenders
Browse files Browse the repository at this point in the history
- Added export option for the OpenNLP models
  • Loading branch information
reckart committed Aug 17, 2024
1 parent cf93467 commit b58f8da
Show file tree
Hide file tree
Showing 14 changed files with 181 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -95,23 +95,31 @@ public boolean equals(final Object other)
return false;
}
CasKey castOther = (CasKey) other;
return new EqualsBuilder().append(projectId, castOther.projectId)
.append(documentId, castOther.documentId).append(userId, castOther.userId)
return new EqualsBuilder() //
.append(projectId, castOther.projectId) //
.append(documentId, castOther.documentId) //
.append(userId, castOther.userId) //
.isEquals();
}

@Override
public int hashCode()
{
return new HashCodeBuilder().append(projectId).append(documentId).append(userId)
return new HashCodeBuilder() //
.append(projectId) //
.append(documentId) //
.append(userId) //
.toHashCode();
}

@Override
public String toString()
{
return new ToStringBuilder(this, ToStringStyle.NO_CLASS_NAME_STYLE).append("p", projectId)
.append("d", documentId).append("u", userId).toString();
return new ToStringBuilder(this, ToStringStyle.NO_CLASS_NAME_STYLE) //
.append("p", projectId) //
.append("d", documentId) //
.append("u", userId) //
.toString();
}

public static CasKey matchingAllFromProject(Project aProject)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,16 @@ public OpenNlpDoccatRecommenderTraitsEditor createTraitsEditor(String aId,
{
return new OpenNlpDoccatRecommenderTraitsEditor(aId, aModel);
}

@Override
public boolean isModelExportSupported()
{
return true;
}

@Override
public String getExportModelName(Recommender aRecommender)
{
return "model.bin";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import static org.apache.uima.fit.util.CasUtil.selectCovered;

import java.io.IOException;
import java.io.OutputStream;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -312,4 +313,13 @@ private DoccatModel train(List<DocumentSample> aSamples, TrainingParameters aPar
"Exception during training the OpenNLP Document Categorizer model", e);
}
}

@Override
public void exportModel(RecommenderContext aContext, OutputStream aOutput) throws IOException
{
var model = aContext.get(KEY_MODEL)
.orElseThrow(() -> new IOException("No model trained yet."));

model.serialize(aOutput);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import static org.apache.uima.fit.util.CasUtil.getType;

import java.io.IOException;
import java.io.OutputStream;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.LinkedHashMap;
Expand Down Expand Up @@ -560,4 +561,13 @@ private TokenNameFinderModel train(List<NameSample> aNameSamples,
throw new RecommendationException("Error while training OpenNLP pos", e);
}
}

@Override
public void exportModel(RecommenderContext aContext, OutputStream aOutput) throws IOException
{
var model = aContext.get(KEY_MODEL)
.orElseThrow(() -> new IOException("No model trained yet."));

model.serialize(aOutput);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public String getId()
@Override
public RecommendationEngine build(Recommender aRecommender)
{
OpenNlpNerRecommenderTraits traits = new OpenNlpNerRecommenderTraits();
var traits = new OpenNlpNerRecommenderTraits();
return new OpenNlpNerRecommender(aRecommender, traits);
}

Expand Down Expand Up @@ -79,4 +79,16 @@ public OpenNlpNerRecommenderTraits createTraits()
{
return new OpenNlpNerRecommenderTraits();
}

@Override
public boolean isModelExportSupported()
{
return true;
}

@Override
public String getExportModelName(Recommender aRecommender)
{
return "model.bin";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,16 @@ public boolean isMultipleRecommendationProvider()
{
return true;
}

@Override
public boolean isModelExportSupported()
{
return true;
}

@Override
public String getExportModelName(Recommender aRecommender)
{
return "model.bin";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package de.tudarmstadt.ukp.inception.recommendation.imls.stringmatch.span.gazeteer;

import static de.tudarmstadt.ukp.inception.project.api.ProjectService.withProjectLogger;
import static java.lang.String.join;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.apache.commons.lang3.StringUtils.trimToNull;

Expand All @@ -30,12 +31,11 @@
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.io.IOUtils;
import org.apache.commons.io.LineIterator;
import org.apache.commons.lang3.Validate;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -58,7 +58,7 @@
public class GazeteerServiceImpl
implements GazeteerService
{
private final Logger log = LoggerFactory.getLogger(getClass());
private static final Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

private final EntityManager entityManager;

Expand All @@ -76,13 +76,14 @@ public GazeteerServiceImpl(RepositoryProperties aRepositoryProperties,
@Transactional
public List<Gazeteer> listGazeteers(Recommender aRecommender)
{
String query = String.join("\n", //
var query = join("\n", //
"FROM Gazeteer", //
"WHERE recommender = :recommender ", //
"ORDER BY name ASC");

return entityManager.createQuery(query, Gazeteer.class)
.setParameter("recommender", aRecommender).getResultList();
return entityManager.createQuery(query, Gazeteer.class) //
.setParameter("recommender", aRecommender) //
.getResultList();
}

@Override
Expand All @@ -93,14 +94,14 @@ public void createOrUpdateGazeteer(Gazeteer aGazeteer)
if (aGazeteer.getId() == null) {
entityManager.persist(aGazeteer);

log.info("Created gazeteer [{}] for recommender {} in project {}",
LOG.info("Created gazeteer [{}] for recommender {} in project {}",
aGazeteer.getName(), aGazeteer.getRecommender(),
aGazeteer.getRecommender().getProject());
}
else {
entityManager.merge(aGazeteer);

log.info("Updated gazeteer [{}] for recommender {} in project {}",
LOG.info("Updated gazeteer [{}] for recommender {} in project {}",
aGazeteer.getName(), aGazeteer.getRecommender(),
aGazeteer.getRecommender().getProject());
}
Expand All @@ -111,51 +112,55 @@ public void createOrUpdateGazeteer(Gazeteer aGazeteer)
@Transactional
public void importGazeteerFile(Gazeteer aGazeteer, InputStream aStream) throws IOException
{
File gazFile = getGazeteerFile(aGazeteer);
var gazFile = getGazeteerFile(aGazeteer);

if (!gazFile.getParentFile().exists()) {
gazFile.getParentFile().mkdirs();
}

try (OutputStream os = new FileOutputStream(gazFile)) {
try (var os = new FileOutputStream(gazFile)) {
IOUtils.copyLarge(aStream, os);
}
}

@Override
public File getGazeteerFile(Gazeteer aGazeteer) throws IOException
{
return repositoryProperties.getPath().toPath().resolve("project")
.resolve(String.valueOf(aGazeteer.getRecommender().getProject().getId()))
.resolve("gazeteer").resolve(aGazeteer.getId() + ".txt").toFile();
return repositoryProperties.getPath().toPath() //
.resolve("project") //
.resolve(String.valueOf(aGazeteer.getRecommender().getProject().getId())) //
.resolve("gazeteer") //
.resolve(aGazeteer.getId() + ".txt") //
.toFile();
}

@Override
@Transactional
public void deleteGazeteers(Gazeteer aGazeteer) throws IOException
{
try (var logCtx = withProjectLogger(aGazeteer.getRecommender().getProject())) {
entityManager.remove(
entityManager.contains(aGazeteer) ? aGazeteer : entityManager.merge(aGazeteer));
entityManager.remove(entityManager.contains(aGazeteer) //
? aGazeteer //
: entityManager.merge(aGazeteer));

File gaz = getGazeteerFile(aGazeteer);
var gaz = getGazeteerFile(aGazeteer);
if (gaz.exists()) {
gaz.delete();
}

log.info("Removed gazeteer [{}] for recommender {} in project {}", aGazeteer.getName(),
LOG.info("Removed gazeteer [{}] for recommender {} in project {}", aGazeteer.getName(),
aGazeteer.getRecommender(), aGazeteer.getRecommender().getProject());
}
}

@Override
public List<GazeteerEntry> readGazeteerFile(Gazeteer aGaz) throws IOException
{
File file = getGazeteerFile(aGaz);
var file = getGazeteerFile(aGaz);

List<GazeteerEntry> data = new ArrayList<>();
var data = new ArrayList<GazeteerEntry>();

try (InputStream is = new FileInputStream(file)) {
try (var is = new FileInputStream(file)) {
parseGazeteer(aGaz, is, data);
}

Expand All @@ -166,20 +171,20 @@ public void parseGazeteer(Gazeteer aGaz, InputStream aStream, List<GazeteerEntry
throws IOException
{
int lineNumber = 0;
LineIterator i = IOUtils.lineIterator(aStream, UTF_8);
var i = IOUtils.lineIterator(aStream, UTF_8);
while (i.hasNext()) {
lineNumber++;
String line = i.nextLine().trim();
var line = i.nextLine().trim();

if (line.isEmpty() || line.startsWith("#")) {
// Ignore comment lines and empty lines
continue;
}

String[] fields = line.split("\t");
var fields = line.split("\t");
if (fields.length >= 2) {
String text = trimToNull(fields[0]);
String label = trimToNull(fields[1]);
var text = trimToNull(fields[0]);
var label = trimToNull(fields[1]);
if (label != null && text != null) {
aTarget.add(new GazeteerEntry(text, label));
}
Expand All @@ -198,11 +203,10 @@ public boolean existsGazeteer(Recommender aRecommender, String aName)
Validate.notNull(aRecommender, "Recommender must be specified");
Validate.notNull(aName, "Gazeteer name must be specified");

String query = "SELECT COUNT(*) " + //
var query = "SELECT COUNT(*) " + //
"FROM Gazeteer " + //
"WHERE recommender = :recommender AND name = :name";

long count = entityManager.createQuery(query, Long.class)
var count = entityManager.createQuery(query, Long.class)
.setParameter("recommender", aRecommender).setParameter("name", aName)
.getSingleResult();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ public class GazeteerEntry

public GazeteerEntry(String aText, String aLabel)
{
super();
text = aText;
label = aLabel;
}
Expand All @@ -44,19 +43,27 @@ public boolean equals(final Object other)
return false;
}
GazeteerEntry castOther = (GazeteerEntry) other;
return new EqualsBuilder().append(text, castOther.text).append(label, castOther.label)
return new EqualsBuilder() //
.append(text, castOther.text) //
.append(label, castOther.label) //
.isEquals();
}

@Override
public int hashCode()
{
return new HashCodeBuilder().append(text).append(label).toHashCode();
return new HashCodeBuilder() //
.append(text) //
.append(label) //
.toHashCode();
}

@Override
public String toString()
{
return new ToStringBuilder(this).append("text", text).append("label", label).toString();
return new ToStringBuilder(this) //
.append("text", text) //
.append("label", label) //
.toString();
}
}
Loading

0 comments on commit b58f8da

Please sign in to comment.