Skip to content

Commit

Permalink
#1861 - Add viewport offsets to external recommender
Browse files Browse the repository at this point in the history
- Send viewport offsets to external recommender
- Bit of cleaning up
  • Loading branch information
reckart committed Aug 22, 2023
1 parent 4bc20e6 commit 9047b78
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import static org.springframework.http.MediaType.APPLICATION_JSON_VALUE;

import java.io.IOException;
import java.io.InputStream;
import java.io.StringWriter;
import java.lang.invoke.MethodHandles;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
Expand All @@ -52,11 +52,8 @@
import org.springframework.http.HttpHeaders;
import org.xml.sax.SAXException;

import com.fasterxml.jackson.databind.ObjectMapper;

import de.tudarmstadt.ukp.clarin.webanno.api.annotation.util.WebAnnoCasUtil;
import de.tudarmstadt.ukp.clarin.webanno.api.type.CASMetadata;
import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer;
import de.tudarmstadt.ukp.clarin.webanno.support.JSONUtil;
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.DataSplitter;
import de.tudarmstadt.ukp.inception.recommendation.api.evaluation.EvaluationResult;
Expand All @@ -80,7 +77,7 @@ public class ExternalRecommender
{
public static final Key<Boolean> KEY_TRAINING_COMPLETE = new Key<>("training_complete");

private static final Logger LOG = LoggerFactory.getLogger(ExternalRecommender.class);
private static final Logger LOG = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());

private static final int HTTP_TOO_MANY_REQUESTS = 429;
private static final int HTTP_BAD_REQUEST = 400;
Expand All @@ -96,7 +93,9 @@ public ExternalRecommender(ExternalRecommenderProperties aProperties, Recommende

properties = aProperties;
traits = aTraits;
client = HttpClient.newBuilder().connectTimeout(properties.getConnectTimeout()).build();
client = HttpClient.newBuilder() //
.connectTimeout(properties.getConnectTimeout()) //
.build();
}

@Override
Expand All @@ -105,18 +104,18 @@ public boolean isReadyForPrediction(RecommenderContext aContext)
if (traits.isTrainable()) {
return aContext.get(KEY_TRAINING_COMPLETE).orElse(false);
}
else {
return true;
}

return true;
}

@Override
public void train(RecommenderContext aContext, List<CAS> aCasses) throws RecommendationException
{
TrainingRequest trainingRequest = new TrainingRequest();
var trainingRequest = new TrainingRequest();

// We assume that the type system for all CAS are the same
String typeSystem = serializeTypeSystem(aCasses.get(0));
var representativeCas = aCasses.get(0);
var typeSystem = serializeTypeSystem(representativeCas);
trainingRequest.setTypeSystem(typeSystem);

// Fill in metadata. We use the type system of the first CAS in the list
Expand All @@ -126,30 +125,31 @@ public void train(RecommenderContext aContext, List<CAS> aCasses) throws Recomme
// of the other CAS. This should happen really rarely, therefore this potential
// error is neglected.

trainingRequest.setMetadata(buildMetadata(aCasses.get(0)));
trainingRequest.setMetadata(
buildMetadata(representativeCas, Range.rangeCoveringDocument(representativeCas)));

List<Document> documents = new ArrayList<>();
var documents = new ArrayList<Document>();
for (CAS cas : aCasses) {
documents.add(buildDocument(cas));
}
trainingRequest.setDocuments(documents);

HttpRequest request = HttpRequest.newBuilder() //
var request = HttpRequest.newBuilder() //
.uri(URI.create(appendIfMissing(traits.getRemoteUrl(), "/")).resolve("train")) //
.header(HttpHeaders.CONTENT_TYPE, APPLICATION_JSON_VALUE) //
.timeout(properties.getReadTimeout())
.POST(BodyPublishers.ofString(toJson(trainingRequest), UTF_8)).build();

HttpResponse<String> response = sendRequest(request);
var response = sendRequest(request);
if (response.statusCode() == HTTP_TOO_MANY_REQUESTS) {
LOG.info("External recommender is already training");
}

// If the response indicates that the request was not successful,
// then it does not make sense to go on and try to decode the XMI
else if (response.statusCode() >= HTTP_BAD_REQUEST) {
String responseBody = getResponseBody(response);
String msg = format("Request was not successful: [%d] - [%s]", response.statusCode(),
var responseBody = getResponseBody(response);
var msg = format("Request was not successful: [%d] - [%s]", response.statusCode(),
responseBody);
throw new RecommendationException(msg);
}
Expand All @@ -161,35 +161,35 @@ else if (response.statusCode() >= HTTP_BAD_REQUEST) {
public Range predict(RecommenderContext aContext, CAS aCas, int aBegin, int aEnd)
throws RecommendationException
{
String typeSystem = serializeTypeSystem(aCas);
var typeSystem = serializeTypeSystem(aCas);

PredictionRequest predictionRequest = new PredictionRequest();
var predictionRequest = new PredictionRequest();
predictionRequest.setTypeSystem(typeSystem);
predictionRequest.setDocument(buildDocument(aCas));

// Fill in metadata
predictionRequest.setMetadata(buildMetadata(aCas));
predictionRequest.setMetadata(buildMetadata(aCas, new Range(aBegin, aEnd)));

HttpRequest request = HttpRequest.newBuilder() //
var request = HttpRequest.newBuilder() //
.uri(URI.create(appendIfMissing(traits.getRemoteUrl(), "/")).resolve("predict")) //
.header(HttpHeaders.CONTENT_TYPE, APPLICATION_JSON_VALUE) //
.timeout(properties.getReadTimeout()) //
.POST(BodyPublishers.ofString(toJson(predictionRequest), UTF_8)) //
.build();

HttpResponse<String> response = sendRequest(request);
var response = sendRequest(request);
// If the response indicates that the request was not successful,
// then it does not make sense to go on and try to decode the XMI
if (response.statusCode() >= HTTP_BAD_REQUEST) {
String responseBody = getResponseBody(response);
String msg = format("Request was not successful: [%d] - [%s]", response.statusCode(),
var responseBody = getResponseBody(response);
var msg = format("Request was not successful: [%d] - [%s]", response.statusCode(),
responseBody);
throw new RecommendationException(msg);
}

PredictionResponse predictionResponse = deserializePredictionResponse(response);
var predictionResponse = deserializePredictionResponse(response);

try (InputStream is = IOUtils.toInputStream(predictionResponse.getDocument(), UTF_8)) {
try (var is = IOUtils.toInputStream(predictionResponse.getDocument(), UTF_8)) {
XmiCasDeserializer.deserialize(is, WebAnnoCasUtil.getRealCas(aCas), true);
}
catch (SAXException | IOException e) {
Expand Down Expand Up @@ -226,7 +226,7 @@ private String serializeCas(CAS aCas) throws RecommendationException
try (var out = new StringWriter()) {
// Passing "null" as the type system to the XmiCasSerializer means that we want
// to serialize all types (i.e. no filtering for a specific target type system).
XmiCasSerializer xmiCasSerializer = new XmiCasSerializer(null);
var xmiCasSerializer = new XmiCasSerializer(null);
var contentHandler = new XMLSerializer(out, true).getContentHandler();
contentHandler = new IllegalXmlCharacterSanitizingContentHandler(contentHandler);
xmiCasSerializer.serialize(getRealCas(aCas), contentHandler, null, null, null);
Expand All @@ -239,10 +239,10 @@ private String serializeCas(CAS aCas) throws RecommendationException

private Document buildDocument(CAS aCas) throws RecommendationException
{
CASMetadata casMetadata = getCasMetadata(aCas);
String xmi = serializeCas(aCas);
long documentId = casMetadata.getSourceDocumentId();
String userId = casMetadata.getUsername();
var casMetadata = getCasMetadata(aCas);
var xmi = serializeCas(aCas);
var documentId = casMetadata.getSourceDocumentId();
var userId = casMetadata.getUsername();

return new Document(xmi, documentId, userId);
}
Expand All @@ -257,21 +257,20 @@ private CASMetadata getCasMetadata(CAS aCas) throws RecommendationException
}
}

private Metadata buildMetadata(CAS aCas) throws RecommendationException
private Metadata buildMetadata(CAS aCas, Range aRange) throws RecommendationException
{
CASMetadata casMetadata = getCasMetadata(aCas);
AnnotationLayer layer = recommender.getLayer();
var casMetadata = getCasMetadata(aCas);
var layer = recommender.getLayer();
return new Metadata(layer.getName(), recommender.getFeature().getName(),
casMetadata.getProjectId(), layer.getAnchoringMode().getId(),
layer.isCrossSentence());
layer.isCrossSentence(), aRange);
}

private PredictionResponse deserializePredictionResponse(HttpResponse<String> response)
throws RecommendationException
{
ObjectMapper objectMapper = new ObjectMapper();
try {
return objectMapper.readValue(response.body(), PredictionResponse.class);
return JSONUtil.fromJsonString(PredictionResponse.class, response.body());
}
catch (IOException e) {
throw new RecommendationException("Error while deserializing prediction response!", e);
Expand All @@ -298,14 +297,13 @@ private HttpResponse<String> sendRequest(HttpRequest aRequest) throws Recommenda
}
}

private String getResponseBody(HttpResponse<String> response) throws RecommendationException
private String getResponseBody(HttpResponse<String> response)
{
if (response.body() != null) {
return response.body();
}
else {
if (response.body() == null) {
return "";
}

return response.body();
}

@Override
Expand All @@ -317,7 +315,7 @@ public int estimateSampleCount(List<CAS> aCasses)
@Override
public EvaluationResult evaluate(List<CAS> aCasses, DataSplitter aDataSplitter)
{
EvaluationResult result = new EvaluationResult();
var result = new EvaluationResult();
result.setEvaluationSkipped(true);
result.setErrorMsg("ExternalRecommender does not support evaluation.");
return result;
Expand All @@ -326,14 +324,12 @@ public EvaluationResult evaluate(List<CAS> aCasses, DataSplitter aDataSplitter)
@Override
public TrainingCapability getTrainingCapability()
{
if (traits.isTrainable()) {
//
// return TRAINING_SUPPORTED;
// We need to get at least one training CAS because we need to extract the type system
return TRAINING_REQUIRED;
}
else {
if (!traits.isTrainable()) {
return TRAINING_NOT_SUPPORTED;
}

// return TRAINING_SUPPORTED;
// We need to get at least one training CAS because we need to extract the type system
return TRAINING_REQUIRED;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

import com.fasterxml.jackson.annotation.JsonProperty;

import de.tudarmstadt.ukp.inception.rendering.model.Range;

public class Metadata
{

private final Range range;
private final String layer;
private final String feature;
private final long projectId;
Expand All @@ -32,13 +34,15 @@ public Metadata(@JsonProperty(value = "layer", required = true) String aLayer,
@JsonProperty(value = "feature", required = true) String aFeature,
@JsonProperty(value = "projectId", required = true) long aProjectId,
@JsonProperty(value = "anchoringMode", required = true) String aAnchoringMode,
@JsonProperty(value = "crossSentence", required = true) boolean aCrossSentence)
@JsonProperty(value = "crossSentence", required = true) boolean aCrossSentence,
@JsonProperty(value = "range", required = true) Range aRange)
{
layer = aLayer;
feature = aFeature;
projectId = aProjectId;
anchoringMode = aAnchoringMode;
crossSentence = aCrossSentence;
range = aRange;
}

public String getLayer()
Expand All @@ -65,4 +69,9 @@ public boolean isCrossSentence()
{
return crossSentence;
}

public Range getRange()
{
return range;
}
}
Loading

0 comments on commit 9047b78

Please sign in to comment.