Skip to content

Commit

Permalink
fix review9: fix/ debug ModelCacheTest
Browse files Browse the repository at this point in the history
  • Loading branch information
EvgeniiMunin committed Oct 9, 2024
1 parent e6f2045 commit b51e6c2
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.FilterService;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.ModelCache;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunner;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunnerFactory;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunnerWithThresholds;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.ThresholdCache;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.result.GreenbidsInvocationService;
Expand Down Expand Up @@ -77,7 +78,17 @@ Storage storage(GreenbidsRealTimeDataProperties properties) {
}

@Bean
ModelCache modelCache(GreenbidsRealTimeDataProperties properties, Vertx vertx, Storage storage) {
OnnxModelRunnerFactory onnxModelRunnerFactory() {
return new OnnxModelRunnerFactory();
}

@Bean
ModelCache modelCache(
GreenbidsRealTimeDataProperties properties,
Vertx vertx,
Storage storage,
OnnxModelRunnerFactory onnxModelRunnerFactory) {

final Cache<String, OnnxModelRunner> modelCacheWithExpiration = Caffeine.newBuilder()
.expireAfterWrite(properties.getCacheExpirationMinutes(), TimeUnit.MINUTES)
.build();
Expand All @@ -87,7 +98,8 @@ ModelCache modelCache(GreenbidsRealTimeDataProperties properties, Vertx vertx, S
properties.getGcsBucketName(),
modelCacheWithExpiration,
properties.getOnnxModelCacheKeyPrefix(),
vertx);
vertx,
onnxModelRunnerFactory);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.prebid.server.log.Logger;
import org.prebid.server.log.LoggerFactory;

import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -32,18 +33,22 @@ public class ModelCache {

private final Vertx vertx;

private final OnnxModelRunnerFactory onnxModelRunnerFactory;

public ModelCache(
Storage storage,
String gcsBucketName,
Cache<String, OnnxModelRunner> cache,
String onnxModelCacheKeyPrefix,
Vertx vertx) {
Vertx vertx,
OnnxModelRunnerFactory onnxModelRunnerFactory) {
this.gcsBucketName = Objects.requireNonNull(gcsBucketName);
this.cache = Objects.requireNonNull(cache);
this.storage = Objects.requireNonNull(storage);
this.onnxModelCacheKeyPrefix = Objects.requireNonNull(onnxModelCacheKeyPrefix);
this.isFetching = new AtomicBoolean(false);
this.vertx = Objects.requireNonNull(vertx);
this.onnxModelRunnerFactory = Objects.requireNonNull(onnxModelRunnerFactory);
}

public Future<OnnxModelRunner> get(String onnxModelPath, String pbuid) {
Expand All @@ -61,8 +66,7 @@ public Future<OnnxModelRunner> get(String onnxModelPath, String pbuid) {

if (isFetching.compareAndSet(false, true)) {
try {
fetchAndCacheModelRunner(onnxModelPath, cacheKey);
return Future.failedFuture("ModelRunner is fetched. Skip current request");
return fetchAndCacheModelRunner(onnxModelPath, cacheKey);
} finally {
isFetching.set(false);
}
Expand All @@ -71,14 +75,14 @@ public Future<OnnxModelRunner> get(String onnxModelPath, String pbuid) {
return Future.failedFuture("ModelRunner fetching in progress. Skip current request");
}

private void fetchAndCacheModelRunner(String onnxModelPath, String cacheKey) {
private Future<OnnxModelRunner> fetchAndCacheModelRunner(String onnxModelPath, String cacheKey) {
System.out.println(
"fetchAndCacheModelRunner: \n" +
" onnxModelPath: " + onnxModelPath + "\n" +
" cacheKey: " + cacheKey
);

vertx.executeBlocking(() -> getBlob(onnxModelPath))
return vertx.executeBlocking(() -> getBlob(onnxModelPath))
.map(this::loadModelRunner)
.onSuccess(onnxModelRunner -> cache.put(cacheKey, onnxModelRunner))
.onFailure(error -> logger.error("Failed to fetch ONNX model"));
Expand Down Expand Up @@ -108,8 +112,16 @@ private Blob getBlob(String onnxModelPath) {
private OnnxModelRunner loadModelRunner(Blob blob) {
try {
final byte[] onnxModelBytes = blob.getContent();
return new OnnxModelRunner(onnxModelBytes);

System.out.println(
"loadModelRunner: \n" +
" blob: " + blob + "\n" +
" onnxModelBytes: " + Arrays.toString(onnxModelBytes) + "\n"
);

return onnxModelRunnerFactory.create(onnxModelBytes);
} catch (OrtException e) {
System.out.println("OrtException trigger PreBidException");
throw new PreBidException("Failed to convert blob to ONNX model", e);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor;

import ai.onnxruntime.OrtException;

public class OnnxModelRunnerFactory {
public OnnxModelRunner create(byte[] bytes) throws OrtException {
return new OnnxModelRunner(bytes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.FilterService;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.ModelCache;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunner;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunnerFactory;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunnerWithThresholds;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.ThresholdCache;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.result.AnalyticsResult;
Expand Down Expand Up @@ -90,12 +91,14 @@ public void setUp() throws IOException {
final File database = new File("src/test/resources/GeoLite2-Country.mmdb");
final DatabaseReader dbReader = new DatabaseReader.Builder(database).build();
final FilterService filterService = new FilterService();
final OnnxModelRunnerFactory onnxModelRunnerFactory = new OnnxModelRunnerFactory();
final ModelCache modelCache = new ModelCache(
storage,
"test_bucket",
modelCacheWithExpiration,
"onnxModelRunner_",
Vertx.vertx());
Vertx.vertx(),
onnxModelRunnerFactory);
final ThresholdCache thresholdCache = new ThresholdCache(
storage,
"test_bucket",
Expand Down
Loading

0 comments on commit b51e6c2

Please sign in to comment.