Skip to content

Commit

Permalink
#34 - Improve and refactor Active Learning code
Browse files Browse the repository at this point in the history
- Moved some of the methods from ActiveLearningRecommender into a new ActiveLearningService(Impl) for easier access to other backend services.
  • Loading branch information
reckart committed Apr 11, 2018
1 parent 87c25db commit 17690ac
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 78 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2018
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.inception.active.learning;

import java.util.List;

import org.apache.uima.jcas.JCas;

import de.tudarmstadt.ukp.clarin.webanno.api.annotation.model.AnnotatorState;
import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer;
import de.tudarmstadt.ukp.inception.recommendation.imls.core.dataobjects.AnnotationObject;
import de.tudarmstadt.ukp.inception.recommendation.model.Predictions;

public interface ActiveLearningService
{
List<List<AnnotationObject>> getRecommendationsForWholeProject(Predictions model,
AnnotationLayer aLayer);

List<List<AnnotationObject>> getRecommendationFromRecommendationModel(AnnotatorState aState,
AnnotationLayer aLayer);

List<AnnotationObject> getFlattenedRecommendationsFromRecommendationModel(JCas aJcas,
AnnotatorState aState, AnnotationLayer aSelectedLayer);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
/*
* Copyright 2018
* Ubiquitous Knowledge Processing (UKP) Lab
* Technische Universität Darmstadt
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.inception.active.learning;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.uima.jcas.JCas;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import de.tudarmstadt.ukp.clarin.webanno.api.DocumentService;
import de.tudarmstadt.ukp.clarin.webanno.api.annotation.model.AnnotatorState;
import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer;
import de.tudarmstadt.ukp.inception.recommendation.imls.core.dataobjects.AnnotationObject;
import de.tudarmstadt.ukp.inception.recommendation.model.Predictions;
import de.tudarmstadt.ukp.inception.recommendation.service.RecommendationService;

@Component
public class ActiveLearningServiceImpl
implements ActiveLearningService
{
private final DocumentService documentService;
private final RecommendationService recommendationService;

@Autowired
public ActiveLearningServiceImpl(DocumentService aDocumentService,
RecommendationService aRecommendationService)
{
documentService = aDocumentService;
recommendationService = aRecommendationService;
}

@Override
public List<List<AnnotationObject>> getRecommendationFromRecommendationModel(
AnnotatorState aState, AnnotationLayer aLayer)
{
Predictions model = recommendationService.getPredictions(aState.getUser(),
aState.getProject());

if (model == null) {
return new ArrayList<>();
}

// getRecommendationsForThisDocument(model);
return getRecommendationsForWholeProject(model, aLayer);
}

// private List<List<AnnotationObject>> getRecommendationsForThisDocument(AnnotatorState aState,
// Predictions model, JCas aJcas, AnnotationLayer aLayer)
// {
// int windowBegin = 0;
// int windowEnd = aJcas.getDocumentText().length() - 1;
// // TODO #176 use the document Id once it it available in the CAS
// return model.getPredictions(aState.getDocument().getName(), aLayer, windowBegin,
// windowEnd, aJcas);
// }

@Override
public List<List<AnnotationObject>> getRecommendationsForWholeProject(Predictions model,
AnnotationLayer aLayer)
{
List<List<AnnotationObject>> result = new ArrayList<>();

Map<String, List<List<AnnotationObject>>> recommendationsMap = model
.getPredictionsForWholeProject(aLayer, documentService);

Set<String> documentNameSet = recommendationsMap.keySet();

for (String documentName : documentNameSet) {
result.addAll(recommendationsMap.get(documentName));
}

return result;
}

public List<AnnotationObject> getFlattenedRecommendationsFromRecommendationModel(JCas aJcas,
AnnotatorState aState, AnnotationLayer aSelectedLayer)
{
int windowBegin = 0;
int windowEnd = aJcas.getDocumentText().length() - 1;
Predictions model = recommendationService.getPredictions(aState.getUser(),
aState.getProject());
// TODO #176 use the document Id once it it available in the CAS
return model.getFlattenedPredictions(aState.getDocument().getName(), aSelectedLayer,
windowBegin, windowEnd, aJcas);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,21 @@
import java.util.Date;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

import org.apache.uima.cas.Type;
import org.apache.uima.cas.text.AnnotationFS;
import org.apache.uima.fit.util.CasUtil;
import org.apache.uima.jcas.JCas;

import de.tudarmstadt.ukp.clarin.webanno.api.DocumentService;
import de.tudarmstadt.ukp.clarin.webanno.api.annotation.model.AnnotatorState;
import de.tudarmstadt.ukp.clarin.webanno.model.AnnotationLayer;
import de.tudarmstadt.ukp.inception.active.learning.ActiveLearningService;
import de.tudarmstadt.ukp.inception.recommendation.imls.core.dataobjects.AnnotationObject;
import de.tudarmstadt.ukp.inception.recommendation.model.LearningRecord;
import de.tudarmstadt.ukp.inception.recommendation.model.LearningRecordUserAction;
import de.tudarmstadt.ukp.inception.recommendation.model.Predictions;
import de.tudarmstadt.ukp.inception.recommendation.service.LearningRecordService;
import de.tudarmstadt.ukp.inception.recommendation.service.RecommendationService;

public class ActiveLearningRecommender
implements Serializable
Expand All @@ -61,11 +57,11 @@ public ActiveLearningRecommender(AnnotatorState aState, AnnotationLayer aLayer)
}

public RecommendationDifference generateRecommendationWithLowestDifference(
LearningRecordService aRecordService, RecommendationService aRecommendationService,
DocumentService aDocumentService, Date learnSkippedRecommendationTime)
LearningRecordService aRecordService, ActiveLearningService aActiveLearningService,
Date learnSkippedRecommendationTime)
{
listOfRecommendationsForEachToken = getRecommendationFromRecommendationModel(
aDocumentService, aRecommendationService, annotatorState, selectedLayer);
listOfRecommendationsForEachToken = aActiveLearningService
.getRecommendationFromRecommendationModel(annotatorState, selectedLayer);

// remove recommendations with Null Annotation
listOfRecommendationsForEachToken.forEach(recommendationList ->
Expand All @@ -85,56 +81,14 @@ public RecommendationDifference generateRecommendationWithLowestDifference(
}

public boolean hasRecommendationWhichIsSkipped(LearningRecordService aRecordService,
DocumentService aDocumentService, RecommendationService aRecommendationService)
ActiveLearningService aActiveLearningService)
{
listOfRecommendationsForEachToken = getRecommendationFromRecommendationModel(
aDocumentService, aRecommendationService, annotatorState, selectedLayer);
listOfRecommendationsForEachToken = aActiveLearningService
.getRecommendationFromRecommendationModel(annotatorState, selectedLayer);
removeRejectedOrSkippedAnnotations(aRecordService, false, null);
return !listOfRecommendationsForEachToken.isEmpty();
}

private List<List<AnnotationObject>> getRecommendationFromRecommendationModel(
DocumentService aDocumentService, RecommendationService aRecommendationService,
AnnotatorState aState, AnnotationLayer aLayer)
{
Predictions model = aRecommendationService.getPredictions(aState.getUser(),
aState.getProject());

if (model == null) {
return new ArrayList<>();
}

// getRecommendationsForThisDocument(model);
return getRecommendationsForWholeProject(aDocumentService, model, aLayer);
}

private List<List<AnnotationObject>> getRecommendationsForThisDocument(Predictions model,
JCas aJcas, AnnotationLayer aLayer)
{
int windowBegin = 0;
int windowEnd = aJcas.getDocumentText().length() - 1;
// TODO #176 use the document Id once it it available in the CAS
return model.getPredictions(annotatorState.getDocument().getName(), aLayer, windowBegin,
windowEnd, aJcas);
}

private static List<List<AnnotationObject>> getRecommendationsForWholeProject(
DocumentService aDocumentService, Predictions model, AnnotationLayer aLayer)
{
List<List<AnnotationObject>> result = new ArrayList<>();

Map<String, List<List<AnnotationObject>>> recommendationsMap = model
.getPredictionsForWholeProject(aLayer, aDocumentService);

Set<String> documentNameSet = recommendationsMap.keySet();

for (String documentName : documentNameSet) {
result.addAll(recommendationsMap.get(documentName));
}

return result;
}

private static void removeRecommendationsWithNullAnnotation(
List<AnnotationObject> recommendationsList)
{
Expand Down Expand Up @@ -363,9 +317,9 @@ private static void sortDifferencesAscending(
}

public Optional<AnnotationObject> generateRecommendationWithLowestConfidence(
RecommendationService aRecommendationService, JCas aJcas)
ActiveLearningService aActiveLearningService, JCas aJcas)
{
recommendations = getFlattenedRecommendationsFromRecommendationModel(aRecommendationService,
recommendations = aActiveLearningService.getFlattenedRecommendationsFromRecommendationModel(
aJcas, annotatorState, selectedLayer);
removeRecommendationsWithNullAnnotation(recommendations);
removeExistingAnnotations(aJcas, selectedLayer, recommendations);
Expand All @@ -374,19 +328,6 @@ public Optional<AnnotationObject> generateRecommendationWithLowestConfidence(
return recommendations.stream().findFirst();
}

private static List<AnnotationObject> getFlattenedRecommendationsFromRecommendationModel(
RecommendationService aRecommendationService, JCas aJcas, AnnotatorState aState,
AnnotationLayer aSelectedLayer)
{
int windowBegin = 0;
int windowEnd = aJcas.getDocumentText().length() - 1;
Predictions model = aRecommendationService.getPredictions(aState.getUser(),
aState.getProject());
// TODO #176 use the document Id once it it available in the CAS
return model.getFlattenedPredictions(aState.getDocument().getName(), aSelectedLayer,
windowBegin, windowEnd, aJcas);
}

private static void removeExistingAnnotations(JCas aJcas,
AnnotationLayer aLayer, List<AnnotationObject> aRecommendations)
{
Expand Down Expand Up @@ -418,11 +359,11 @@ private static List<Integer> mapToBeginOffsets(
return existingAnnotationsSpanBegin;
}

public boolean checkRecommendationExist(DocumentService aDocumentService,
RecommendationService aRecommendationService, LearningRecord aRecord)
public boolean checkRecommendationExist(ActiveLearningService aActiveLearningService,
LearningRecord aRecord)
{
listOfRecommendationsForEachToken = getRecommendationFromRecommendationModel(
aDocumentService, aRecommendationService, annotatorState, selectedLayer);
listOfRecommendationsForEachToken = aActiveLearningService
.getRecommendationFromRecommendationModel(annotatorState, selectedLayer);
return containSuggestion(listOfRecommendationsForEachToken, aRecord);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
import de.tudarmstadt.ukp.clarin.webanno.support.spring.ApplicationEventPublisherHolder;
import de.tudarmstadt.ukp.clarin.webanno.ui.annotation.AnnotationPage;
import de.tudarmstadt.ukp.clarin.webanno.ui.annotation.sidebar.AnnotationSidebar_ImplBase;
import de.tudarmstadt.ukp.inception.active.learning.ActiveLearningService;
import de.tudarmstadt.ukp.inception.active.learning.event.ActiveLearningRecommendationEvent;
import de.tudarmstadt.ukp.inception.active.learning.event.ActiveLearningSessionCompletedEvent;
import de.tudarmstadt.ukp.inception.active.learning.event.ActiveLearningSessionStartedEvent;
Expand Down Expand Up @@ -113,6 +114,7 @@ public class ActiveLearningSidebar
private static final String ANNOTATION_MARKER = "VAnnotationMarker";
private static final String TEXT_MARKER = "VTextMarker";

private @SpringBean ActiveLearningService activeLearningService;
private @SpringBean AnnotationSchemaService annotationService;
private @SpringBean RecommendationService recommendationService;
private @SpringBean LearningRecordService learningRecordService;
Expand Down Expand Up @@ -254,7 +256,7 @@ private void showAndHighlightRecommendationAndJumpToRecommendationLocation(
else if (learnSkippedRecommendationTime == null) {
hasUnseenRecommendation = false;
hasSkippedRecommendation = activeLearningRecommender.hasRecommendationWhichIsSkipped(
learningRecordService, documentService, recommendationService);
learningRecordService, activeLearningService);
}
else {
hasUnseenRecommendation = false;
Expand Down Expand Up @@ -476,8 +478,7 @@ private void moveToNextRecommendation(AjaxRequestTarget aTarget)
annotationPage.actionRefreshDocument(aTarget);
currentDifference = activeLearningRecommender
.generateRecommendationWithLowestDifference(learningRecordService,
recommendationService, documentService,
learnSkippedRecommendationTime);
activeLearningService, learnSkippedRecommendationTime);
showAndHighlightRecommendationAndJumpToRecommendationLocation(aTarget);
}

Expand Down Expand Up @@ -535,8 +536,8 @@ private void jumpAndHighlightFromLearningHistory(AjaxRequestTarget aTarget,
highlightTextAndDisplayMessage(aTarget, record);
}
// if the suggestion still exists, highlight that suggestion.
else if (activeLearningRecommender.checkRecommendationExist(documentService,
recommendationService, record)) {
else if (activeLearningRecommender.checkRecommendationExist(activeLearningService,
record)) {
highlightRecommendation(aTarget, record.getOffsetCharacterBegin(),
record.getOffsetCharacterEnd(), record.getTokenText(), record.getAnnotation());
}
Expand Down

0 comments on commit 17690ac

Please sign in to comment.