Skip to content

Commit

Permalink
Fix Parallel Execution of CLIP Feature (#334)
Browse files Browse the repository at this point in the history
  • Loading branch information
silvanheller authored Oct 17, 2022
1 parent 2c0e1c7 commit 183dfd1
Show file tree
Hide file tree
Showing 12 changed files with 64 additions and 76 deletions.
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ allprojects {
group = 'org.vitrivr'

/* Our current version, on dev branch this should always be release+1-SNAPSHOT */
version = '3.12.2'
version = '3.12.4'

apply plugin: 'java-library'
apply plugin: 'maven-publish'
Expand Down
14 changes: 7 additions & 7 deletions cineast-api/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ plugins {
id 'application'
}


application {
getMainClass().set('org.vitrivr.cineast.api.Main')
applicationDefaultJvmArgs = ["-Xms1G", "-Xmx2G"]
Expand Down Expand Up @@ -86,11 +85,6 @@ signing {
sign publishing.publications.mavenJava
}

configurations.all {
// Check for updates every build
resolutionStrategy.cacheChangingModulesFor 0, 'seconds'
}

jar {
manifest {
attributes 'Main-Class': 'org.vitrivr.cineast.api.Main'
Expand All @@ -103,8 +97,14 @@ shadowJar {
mergeServiceFiles()
}

configurations.all {
// Check for updates every build
resolutionStrategy.cacheChangingModulesFor 0, 'seconds'
}

dependencies {
implementation project(':cineast-runtime')
api project(':cineast-runtime')

implementation("io.javalin:javalin-bundle:$version_javalin") {
exclude group: 'ch.qos.logback', module: 'logback-classic'
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ public void execute(Session session, QueryConfig qconf, TemporalQuery message, S
.collect(Collectors.toList());

if (results.isEmpty()) {
LOGGER.warn("No results found for category {} and qt {} in stage with id {}. Full component: {}", category, qt.type(), lambdaFinalContainerIdx, stage);
LOGGER.info("No results found for category {} and qt {} in stage {} for query {}. Full component: {}", category, qt.type(), lambdaFinalContainerIdx, uuid, stage);
}
if (cache.get(stageIndex).containsKey(category)) {
LOGGER.error("Category {} was used twice in stage {}. This erases the results of the previous category... ", category, stageIndex);
LOGGER.error("Category {} was used twice in stage {} for query {}. This erases the results of the previous category... ", category, stageIndex, uuid);
}

cache.get(stageIndex).put(category, results);
Expand Down
3 changes: 3 additions & 0 deletions cineast-core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ import org.gradle.nativeplatform.platform.internal.DefaultNativePlatform
repositories {
maven {
url "https://oss.sonatype.org/content/repositories/snapshots/"
mavenContent {
snapshotsOnly()
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.types.TFloat16;
import org.vitrivr.cineast.core.config.QueryConfig;
import org.vitrivr.cineast.core.config.ReadableQueryConfig;
Expand Down Expand Up @@ -50,9 +49,7 @@ public CLIPImage() {
}

private static float[] prepareImage(BufferedImage img) {
return ImagePreprocessingHelper.imageToCHWArray(
ImagePreprocessingHelper.squaredScaleCenterCrop(img, IMAGE_SIZE),
MEAN, STD);
return ImagePreprocessingHelper.imageToCHWArray(ImagePreprocessingHelper.squaredScaleCenterCrop(img, IMAGE_SIZE), MEAN, STD);
}

@Override
Expand Down Expand Up @@ -105,13 +102,10 @@ private float[] embedImage(BufferedImage img) {
try (TFloat16 imageTensor = TFloat16.tensorOf(Shape.of(1, 3, IMAGE_SIZE, IMAGE_SIZE), DataBuffers.of(rgb))) {
HashMap<String, Tensor> inputMap = new HashMap<>();
inputMap.put(EMBEDDING_INPUT, imageTensor);

Map<String, Tensor> resultMap = model.call(inputMap);

try (TFloat16 encoding = (TFloat16) resultMap.get(EMBEDDING_OUTPUT)) {

float[] embeddingArray = new float[EMBEDDING_SIZE];
FloatDataBuffer floatBuffer = DataBuffers.of(embeddingArray);
var embeddingArray = new float[EMBEDDING_SIZE];
var floatBuffer = DataBuffers.of(embeddingArray);
encoding.read(floatBuffer);

return embeddingArray;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import static org.vitrivr.cineast.core.util.CineastConstants.FEATURE_COLUMN_QUALIFIER;
import static org.vitrivr.cineast.core.util.CineastConstants.GENERIC_ID_COLUMN_QUALIFIER;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -51,19 +52,21 @@ public class CLIPText implements Retriever {

private static final CorrespondenceFunction CORRESPONDENCE = CorrespondenceFunction.linear(1);

private static SavedModelBundle model;

private DBSelector selector;
private ClipTokenizer ct = new ClipTokenizer();
private final ClipTokenizer ct = new ClipTokenizer();

private static SavedModelBundle bundle;

public CLIPText() {
loadModel();
ensureModelLoaded();
}

private synchronized static void loadModel() {
if (model == null) {
model = SavedModelBundle.load(RESOURCE_PATH + EMBEDDING_MODEL);
public static synchronized void ensureModelLoaded() {
if (bundle != null) {
return;
}
LOGGER.debug("Loading CLIP Model");
bundle = SavedModelBundle.load(RESOURCE_PATH + EMBEDDING_MODEL);
}

@Override
Expand All @@ -84,40 +87,41 @@ public void init(DBSelectorSupplier selectorSupply) {

@Override
public List<ScoreElement> getSimilar(SegmentContainer sc, ReadableQueryConfig qc) {

String text = sc.getText();

if (text == null || text.isBlank()) {
return Collections.emptyList();
}
LOGGER.debug("Querying for: \"{}\"", text);

return getSimilar(new FloatArrayTypeProvider(embedText(text)), qc);
try {
return getSimilar(new FloatArrayTypeProvider(embedText(text)), qc);
} catch (Exception e) {
LOGGER.error("error during CLIPText execution", e);
return new ArrayList<>();
}
}

public float[] embedText(String text) {

long[] tokens = ct.clipTokenize(text);

LongNdArray arr = NdArrays.ofLongs(Shape.of(1, tokens.length));
for (int i = 0; i < tokens.length; i++) {
arr.setLong(tokens[i], 0, i);
}

try (TInt64 textTensor = TInt64.tensorOf(arr)) {

HashMap<String, Tensor> inputMap = new HashMap<>();
inputMap.put(EMBEDDING_INPUT, textTensor);
return exec(inputMap);
}
}

Map<String, Tensor> resultMap = model.call(inputMap);

try (TFloat16 embedding = (TFloat16) resultMap.get(EMBEDDING_OUTPUT)) {

float[] embeddingArray = new float[EMBEDDING_SIZE];
FloatDataBuffer floatBuffer = DataBuffers.of(embeddingArray);
embedding.read(floatBuffer);
return embeddingArray;
private static float[] exec(HashMap<String, Tensor> inputMap) {
var resultMap = bundle.call(inputMap);

}
try (TFloat16 embedding = (TFloat16) resultMap.get(EMBEDDING_OUTPUT)) {
var embeddingArray = new float[EMBEDDING_SIZE];
var floatBuffer = DataBuffers.of(embeddingArray);
embedding.read(floatBuffer);
return embeddingArray;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.types.TFloat32;
import org.vitrivr.cineast.core.config.QueryConfig;
import org.vitrivr.cineast.core.config.ReadableQueryConfig;
Expand Down Expand Up @@ -79,14 +77,11 @@ public static float[] encodeImage(BufferedImage image) {
HashMap<String, Tensor> inputMap = new HashMap<>();
inputMap.put(INPUT, imageTensor);

Map<String, Tensor> resultMap = model.call(inputMap);

var resultMap = model.call(inputMap);
try (TFloat32 encoding = (TFloat32) resultMap.get(OUTPUT)) {

float[] embeddingArray = new float[ENCODING_SIZE];
FloatDataBuffer floatBuffer = DataBuffers.of(embeddingArray);
var embeddingArray = new float[ENCODING_SIZE];
var floatBuffer = DataBuffers.of(embeddingArray);
encoding.read(floatBuffer);

return embeddingArray;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.buffer.DataBuffers;
import org.tensorflow.ndarray.buffer.FloatDataBuffer;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TString;
import org.vitrivr.cineast.core.config.QueryConfig;
Expand Down Expand Up @@ -90,9 +89,7 @@ public void processSegment(SegmentContainer sc) {

// Case: segment contains video frames
if (!sc.getVideoFrames().isEmpty() && sc.getVideoFrames().get(0) != VideoFrame.EMPTY_VIDEO_FRAME) {
List<MultiImage> frames = sc.getVideoFrames().stream()
.map(VideoFrame::getImage)
.collect(Collectors.toList());
List<MultiImage> frames = sc.getVideoFrames().stream().map(VideoFrame::getImage).collect(Collectors.toList());

float[] embeddingArray = embedVideo(frames);
this.persist(sc.getId(), new FloatVectorImpl(embeddingArray));
Expand Down Expand Up @@ -179,19 +176,15 @@ private float[] embedText(String text) {

HashMap<String, Tensor> inputMap = new HashMap<>();
inputMap.put(TEXT_EMBEDDING_INPUT, textTensor);

Map<String, Tensor> resultMap = textEmbedding.call(inputMap);

try (TFloat32 intermediaryEmbedding = (TFloat32) resultMap.get(TEXT_EMBEDDING_OUTPUT)) {

inputMap.clear();
inputMap.put(TEXT_CO_EMBEDDING_INPUT, intermediaryEmbedding);

resultMap = textCoEmbedding.call(inputMap);
try (TFloat32 embedding = (TFloat32) resultMap.get(TEXT_CO_EMBEDDING_OUTPUT)) {

float[] embeddingArray = new float[EMBEDDING_SIZE];
FloatDataBuffer floatBuffer = DataBuffers.of(embeddingArray);
var embeddingArray = new float[EMBEDDING_SIZE];
var floatBuffer = DataBuffers.of(embeddingArray);
// Beware TensorFlow allows tensor writing to buffers through the function read rather than write
embedding.read(floatBuffer);

Expand All @@ -211,17 +204,13 @@ private float[] embedImage(BufferedImage image) {
inputMap.put(InceptionResnetV2.INPUT, imageTensor);

Map<String, Tensor> resultMap = visualEmbedding.call(inputMap);

try (TFloat32 intermediaryEmbedding = (TFloat32) resultMap.get(InceptionResnetV2.OUTPUT)) {

inputMap.clear();
inputMap.put(VISUAL_CO_EMBEDDING_INPUT, intermediaryEmbedding);

resultMap = visualCoEmbedding.call(inputMap);
try (TFloat32 embedding = (TFloat32) resultMap.get(VISUAL_CO_EMBEDDING_OUTPUT)) {

float[] embeddingArray = new float[EMBEDDING_SIZE];
FloatDataBuffer floatBuffer = DataBuffers.of(embeddingArray);
var embeddingArray = new float[EMBEDDING_SIZE];
var floatBuffer = DataBuffers.of(embeddingArray);
// Beware TensorFlow allows tensor writing to buffers through the function read rather than write
embedding.read(floatBuffer);

Expand All @@ -240,12 +229,10 @@ private float[] embedVideo(List<MultiImage> frames) {
HashMap<String, Tensor> inputMap = new HashMap<>();

inputMap.put(VISUAL_CO_EMBEDDING_INPUT, encoding);

Map<String, Tensor> resultMap = visualCoEmbedding.call(inputMap);
try (TFloat32 embedding = (TFloat32) resultMap.get(VISUAL_CO_EMBEDDING_OUTPUT)) {

float[] embeddingArray = new float[EMBEDDING_SIZE];
FloatDataBuffer floatBuffer = DataBuffers.of(embeddingArray);
var embeddingArray = new float[EMBEDDING_SIZE];
var floatBuffer = DataBuffers.of(embeddingArray);
// Beware TensorFlow allows tensor writing to buffers through the function read rather than write
embedding.read(floatBuffer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ public class ClipTokenizer {
for (int i : byte_encoder.keySet()) {
byte_decoder.put(byte_encoder.get(i), i);
}
init();
}

private HashMap<String, String> cache = new HashMap<>();
private Pattern pat = Pattern.compile("<\\|startoftext\\|>|<\\|endoftext\\|>|'s|'t|'re|'ve|'m|'ll|'d|[\\p{L}]+|[\\p{N}]|[^\\s\\p{L}\\p{N}]+", Pattern.CASE_INSENSITIVE);
public ClipTokenizer() {
init();
cache.put("<|startoftext|>", "<|startoftext|>");
cache.put("<|endoftext|>", "<|endoftext|>");
}
Expand Down
3 changes: 0 additions & 3 deletions cineast-runtime/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@ shadowJar {
mergeServiceFiles()
}

/**
* See build.gradle in cineast-core on why we cannot use implementation
*/
dependencies {

api project(':cineast-core')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ public static void retrieveAndLog(List<Retriever> retrievers, ContinuousRetrieva
System.out.println("Only printing the first " + limit + " results, change with --limit parameter");
DBSelector selector = Config.sharedConfig().getDatabase().getSelectorSupplier().get();
retrievers.forEach(retriever -> {

AtomicBoolean entityExists = new AtomicBoolean(true);
retriever.getTableNames().forEach(table -> {
if (!selector.existsEntity(table)) {
Expand All @@ -104,7 +103,7 @@ public static void retrieveAndLog(List<Retriever> retrievers, ContinuousRetrieva
}
System.out.println("Retrieving for " + retriever.getClass().getSimpleName());
long start = System.currentTimeMillis();
List<SegmentScoreElement> results = retrieval.retrieveByRetriever(qc, retriever, new ConstrainedQueryConfig().setMaxResults(limit));
List<SegmentScoreElement> results = retrieval.retrieveByRetriever(qc, retriever, new ConstrainedQueryConfig().setMaxResults(limit).setResultsPerModule(limit));
long stop = System.currentTimeMillis();
System.out.println("Results for " + retriever.getClass().getSimpleName() + ":, retrieved in " + (stop - start) + "ms");

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,21 @@ public class TextRetrievalCommand extends AbstractCineastCommand {
@Option(name = {"--detail"}, title = "detailed results", description = "also list detailed results for retrieved segments.")
private Boolean printDetail = false;

@Option(name = {"--category"}, title = "category", description = "specific category of retrievers")
private String category;

@Override
public void execute() {
final ContinuousRetrievalLogic retrieval = new ContinuousRetrievalLogic(Config.sharedConfig().getDatabase());
System.out.println("Querying for text " + text);
TextQueryTermContainer qc = new TextQueryTermContainer(text);
List<Retriever> retrievers = new ArrayList<>();
if (category != null) {
Config.sharedConfig().getRetriever().getRetrieversByCategory(category).forEach((ObjectDoubleProcedure<? super Retriever>) (retriever, weight) -> {
CliUtils.retrieveAndLog(Lists.newArrayList(retriever), retrieval, limit, printDetail, qc);
});
return;
}
Config.sharedConfig().getRetriever().getRetrieversByCategory("ocr").forEach((ObjectDoubleProcedure<? super Retriever>) (retriever, weight) -> {
CliUtils.retrieveAndLog(Lists.newArrayList(retriever), retrieval, limit, printDetail, qc);
});
Expand Down

0 comments on commit 183dfd1

Please sign in to comment.