diff --git a/extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/config/GreenbidsRealTimeDataConfiguration.java b/extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/config/GreenbidsRealTimeDataConfiguration.java index a19c81a76c7..7223698356f 100644 --- a/extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/config/GreenbidsRealTimeDataConfiguration.java +++ b/extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/config/GreenbidsRealTimeDataConfiguration.java @@ -13,6 +13,7 @@ import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.FilterService; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.ModelCache; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunner; +import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunnerFactory; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunnerWithThresholds; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.ThresholdCache; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.result.GreenbidsInvocationService; @@ -77,7 +78,17 @@ Storage storage(GreenbidsRealTimeDataProperties properties) { } @Bean - ModelCache modelCache(GreenbidsRealTimeDataProperties properties, Vertx vertx, Storage storage) { + OnnxModelRunnerFactory onnxModelRunnerFactory() { + return new OnnxModelRunnerFactory(); + } + + @Bean + ModelCache modelCache( + GreenbidsRealTimeDataProperties properties, + Vertx vertx, + Storage storage, + OnnxModelRunnerFactory onnxModelRunnerFactory) { + final Cache modelCacheWithExpiration = Caffeine.newBuilder() .expireAfterWrite(properties.getCacheExpirationMinutes(), TimeUnit.MINUTES) .build(); @@ -87,7 +98,8 @@ ModelCache modelCache(GreenbidsRealTimeDataProperties properties, Vertx vertx, S properties.getGcsBucketName(), modelCacheWithExpiration, properties.getOnnxModelCacheKeyPrefix(), - vertx); + vertx, + onnxModelRunnerFactory); } @Bean diff --git a/extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/model/predictor/ModelCache.java b/extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/model/predictor/ModelCache.java index 4f9bba3f26e..d061b71f735 100644 --- a/extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/model/predictor/ModelCache.java +++ b/extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/model/predictor/ModelCache.java @@ -12,6 +12,7 @@ import org.prebid.server.log.Logger; import org.prebid.server.log.LoggerFactory; +import java.util.Arrays; import java.util.Objects; import java.util.Optional; import java.util.concurrent.atomic.AtomicBoolean; @@ -32,18 +33,22 @@ public class ModelCache { private final Vertx vertx; + private final OnnxModelRunnerFactory onnxModelRunnerFactory; + public ModelCache( Storage storage, String gcsBucketName, Cache cache, String onnxModelCacheKeyPrefix, - Vertx vertx) { + Vertx vertx, + OnnxModelRunnerFactory onnxModelRunnerFactory) { this.gcsBucketName = Objects.requireNonNull(gcsBucketName); this.cache = Objects.requireNonNull(cache); this.storage = Objects.requireNonNull(storage); this.onnxModelCacheKeyPrefix = Objects.requireNonNull(onnxModelCacheKeyPrefix); this.isFetching = new AtomicBoolean(false); this.vertx = Objects.requireNonNull(vertx); + this.onnxModelRunnerFactory = Objects.requireNonNull(onnxModelRunnerFactory); } public Future get(String onnxModelPath, String pbuid) { @@ -61,8 +66,7 @@ public Future get(String onnxModelPath, String pbuid) { if (isFetching.compareAndSet(false, true)) { try { - fetchAndCacheModelRunner(onnxModelPath, cacheKey); - return Future.failedFuture("ModelRunner is fetched. Skip current request"); + return fetchAndCacheModelRunner(onnxModelPath, cacheKey); } finally { isFetching.set(false); } @@ -71,14 +75,14 @@ public Future get(String onnxModelPath, String pbuid) { return Future.failedFuture("ModelRunner fetching in progress. Skip current request"); } - private void fetchAndCacheModelRunner(String onnxModelPath, String cacheKey) { + private Future fetchAndCacheModelRunner(String onnxModelPath, String cacheKey) { System.out.println( "fetchAndCacheModelRunner: \n" + " onnxModelPath: " + onnxModelPath + "\n" + " cacheKey: " + cacheKey ); - vertx.executeBlocking(() -> getBlob(onnxModelPath)) + return vertx.executeBlocking(() -> getBlob(onnxModelPath)) .map(this::loadModelRunner) .onSuccess(onnxModelRunner -> cache.put(cacheKey, onnxModelRunner)) .onFailure(error -> logger.error("Failed to fetch ONNX model")); @@ -108,8 +112,16 @@ private Blob getBlob(String onnxModelPath) { private OnnxModelRunner loadModelRunner(Blob blob) { try { final byte[] onnxModelBytes = blob.getContent(); - return new OnnxModelRunner(onnxModelBytes); + + System.out.println( + "loadModelRunner: \n" + + " blob: " + blob + "\n" + + " onnxModelBytes: " + Arrays.toString(onnxModelBytes) + "\n" + ); + + return onnxModelRunnerFactory.create(onnxModelBytes); } catch (OrtException e) { + System.out.println("OrtException trigger PreBidException"); throw new PreBidException("Failed to convert blob to ONNX model", e); } } diff --git a/extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/model/predictor/OnnxModelRunnerFactory.java b/extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/model/predictor/OnnxModelRunnerFactory.java new file mode 100644 index 00000000000..c6c0009f336 --- /dev/null +++ b/extra/modules/greenbids-real-time-data/src/main/java/org/prebid/server/hooks/modules/greenbids/real/time/data/model/predictor/OnnxModelRunnerFactory.java @@ -0,0 +1,9 @@ +package org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor; + +import ai.onnxruntime.OrtException; + +public class OnnxModelRunnerFactory { + public OnnxModelRunner create(byte[] bytes) throws OrtException { + return new OnnxModelRunner(bytes); + } +} diff --git a/extra/modules/greenbids-real-time-data/src/test/java/org/prebid/server/hooks/modules/greenbids/real/time/data/v1/GreenbidsRealTimeDataProcessedAuctionRequestHookTest.java b/extra/modules/greenbids-real-time-data/src/test/java/org/prebid/server/hooks/modules/greenbids/real/time/data/v1/GreenbidsRealTimeDataProcessedAuctionRequestHookTest.java index ffc00f2b4c7..874bf936d4a 100644 --- a/extra/modules/greenbids-real-time-data/src/test/java/org/prebid/server/hooks/modules/greenbids/real/time/data/v1/GreenbidsRealTimeDataProcessedAuctionRequestHookTest.java +++ b/extra/modules/greenbids-real-time-data/src/test/java/org/prebid/server/hooks/modules/greenbids/real/time/data/v1/GreenbidsRealTimeDataProcessedAuctionRequestHookTest.java @@ -28,6 +28,7 @@ import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.FilterService; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.ModelCache; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunner; +import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunnerFactory; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunnerWithThresholds; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.ThresholdCache; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.result.AnalyticsResult; @@ -90,12 +91,14 @@ public void setUp() throws IOException { final File database = new File("src/test/resources/GeoLite2-Country.mmdb"); final DatabaseReader dbReader = new DatabaseReader.Builder(database).build(); final FilterService filterService = new FilterService(); + final OnnxModelRunnerFactory onnxModelRunnerFactory = new OnnxModelRunnerFactory(); final ModelCache modelCache = new ModelCache( storage, "test_bucket", modelCacheWithExpiration, "onnxModelRunner_", - Vertx.vertx()); + Vertx.vertx(), + onnxModelRunnerFactory); final ThresholdCache thresholdCache = new ThresholdCache( storage, "test_bucket", diff --git a/extra/modules/greenbids-real-time-data/src/test/java/org/prebid/server/hooks/modules/greenbids/real/time/data/v1/ModelCacheTest.java b/extra/modules/greenbids-real-time-data/src/test/java/org/prebid/server/hooks/modules/greenbids/real/time/data/v1/ModelCacheTest.java index cb000065e61..b14ba4a3b81 100644 --- a/extra/modules/greenbids-real-time-data/src/test/java/org/prebid/server/hooks/modules/greenbids/real/time/data/v1/ModelCacheTest.java +++ b/extra/modules/greenbids-real-time-data/src/test/java/org/prebid/server/hooks/modules/greenbids/real/time/data/v1/ModelCacheTest.java @@ -18,16 +18,14 @@ import org.prebid.server.exception.PreBidException; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.ModelCache; import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunner; +import org.prebid.server.hooks.modules.greenbids.real.time.data.model.predictor.OnnxModelRunnerFactory; import java.lang.reflect.Field; -import java.util.concurrent.Callable; import java.util.concurrent.atomic.AtomicBoolean; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.Assert.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -46,7 +44,6 @@ public class ModelCacheTest { @Mock private Storage storage; - @Mock private Vertx vertx; @Mock @@ -58,18 +55,17 @@ public class ModelCacheTest { @Mock private OnnxModelRunner onnxModelRunner; + @Mock + private OnnxModelRunnerFactory onnxModelRunnerFactory; + @Mock private ModelCache target; @BeforeEach public void setUp() { - //cache = Mockito.mock(Cache.class); - //storage = Mockito.mock(Storage.class); - //vertx = Mockito.mock(Vertx.class); - //blob = Mockito.mock(Blob.class); - //onnxModelRunner = Mockito.mock(OnnxModelRunner.class); - - target = new ModelCache(storage, GCS_BUCKET_NAME, cache, MODEL_CACHE_KEY_PREFIX, vertx); + vertx = Vertx.vertx(); + target = new ModelCache( + storage, GCS_BUCKET_NAME, cache, MODEL_CACHE_KEY_PREFIX, vertx, onnxModelRunnerFactory); } @Test @@ -92,30 +88,13 @@ public void getShouldSkipFetchingWhenFetchingInProgress() throws NoSuchFieldExce // given String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID; - //ModelCache targetWithFetching = new ModelCache( - // storage, GCS_BUCKET_NAME, cache, MODEL_CACHE_KEY_PREFIX, vertx); - // isFetching.set(true); - // Create a spy of the ModelCache class - ModelCache spyModelCache = spy(new ModelCache(storage, GCS_BUCKET_NAME, cache, MODEL_CACHE_KEY_PREFIX, vertx)); + ModelCache spyModelCache = spy(new ModelCache( + storage, GCS_BUCKET_NAME, cache, MODEL_CACHE_KEY_PREFIX, vertx, onnxModelRunnerFactory)); // Mock the cache to simulate that the model is not present when(cache.getIfPresent(eq(cacheKey))).thenReturn(null); - // Mock the vertx executeBlocking to simulate fetching process - /* - when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> { - Callable callable = invocation.getArgument(0); - try { - // Simulate the callable being executed successfully - Object result = callable.call(); - return Future.succeededFuture(result); - } catch (Exception e) { - return Future.failedFuture(e); - } - }); - */ - // Spy the isFetching AtomicBoolean behavior AtomicBoolean mockFetchingState = mock(AtomicBoolean.class); @@ -123,17 +102,6 @@ public void getShouldSkipFetchingWhenFetchingInProgress() throws NoSuchFieldExce when(mockFetchingState.compareAndSet(false, true)).thenReturn(false); when(mockFetchingState.compareAndSet(false, true)).thenReturn(false); - // Use reflection to set the private field 'isFetching' in the spy - - /* - ModelCache modelCacheWithMockedFetching = new ModelCache( - storage, GCS_BUCKET_NAME, cache, MODEL_CACHE_KEY_PREFIX, vertx) { - protected AtomicBoolean getIsFetching() { - return mockFetchingState; - } - }; - */ - // Use reflection to set the private field 'isFetching' in the spy accessible Field isFetchingField = ModelCache.class.getDeclaredField("isFetching"); isFetchingField.setAccessible(true); @@ -151,42 +119,37 @@ protected AtomicBoolean getIsFetching() { assertThat(firstCall.failed()).isTrue(); assertThat(firstCall.cause().getMessage()).isEqualTo( "ModelRunner fetching in progress. Skip current request"); - //verify(vertx).executeBlocking(any(Callable.class)); } @Test - public void getShouldFetchModelWhenNotInCache() { + public void getShouldFetchModelWhenNotInCache() throws OrtException { // given final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID; - // final String onnxModelPath = "models_pbuid=" + PBUUID + ".onnx"; + final byte[] bytes = new byte[]{1, 2, 3}; when(cache.getIfPresent(eq(cacheKey))).thenReturn(null); when(storage.get(GCS_BUCKET_NAME)).thenReturn(bucket); when(bucket.get(ONNX_MODEL_PATH)).thenReturn(blob); - // when(vertx.executeBlocking(any(Callable.class))) - // .thenReturn(Future.succeededFuture(blob)); - - when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> { - Callable callable = invocation.getArgument(0); - try { - Object result = callable.call(); - return Future.succeededFuture(result); - } catch (Exception e) { - return Future.failedFuture(e); - } - }); + lenient().when(blob.getContent()).thenReturn(bytes); + lenient().when(onnxModelRunnerFactory.create(bytes)).thenReturn(onnxModelRunner); // when Future future = target.get(ONNX_MODEL_PATH, PBUUID); - System.out.println( - "future.cause().getMessage(): " + future.cause().getMessage() + "\n" + - "future.succeeded(): " + future.succeeded() - ); // then - assertThat(future.failed()).isTrue(); - assertThat(future.cause().getMessage()).isEqualTo("ModelRunner is fetched. Skip current request"); - verify(vertx).executeBlocking(any(Callable.class)); + future.onComplete(ar -> { + + System.out.println( + "future.onComplete: \n" + + " ar: " + ar + "\n" + + " ar.result(): " + ar.result() + "\n" + + " cache: " + cache + ); + + assertThat(ar.succeeded()).isTrue(); + assertThat(ar.result()).isEqualTo(onnxModelRunner); + verify(cache).put(eq(cacheKey), eq(onnxModelRunner)); + }); } @Test @@ -200,69 +163,40 @@ public void getShouldThrowExceptionWhenStorageFails() { // Simulate an error when accessing the storage bucket when(storage.get(GCS_BUCKET_NAME)).thenThrow(new StorageException(500, "Storage Error")); - // Mock vertx.executeBlocking to simulate the behavior of exception being thrown in getBlob - when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> { - Callable callable = invocation.getArgument(0); - try { - // The callable should throw an exception when called - Object result = callable.call(); - return Future.succeededFuture(result); - } catch (Exception e) { - // Return a failed future when an exception occurs - return Future.failedFuture(e); - } - }); - // when Future future = target.get(ONNX_MODEL_PATH, PBUUID); - /* - Exception exception = assertThrows(PreBidException.class, () -> { - // Call the method and expect the PreBidException to be thrown - Future future = target.get(ONNX_MODEL_PATH, PBUUID); - // Force the future to resolve, triggering the exception - future.result(); - }); - //Future future = target.get(ONNX_MODEL_PATH, PBUUID); - */ // then future.onComplete(ar -> { System.out.println( "future.onComplete: \n" + + " ar: " + ar + "\n" + " ar.failed(): " + ar.failed() + "\n" + " ar.cause(): " + ar.cause() + "\n" + - " ar.cause().getMessage(): " + ar.cause().getMessage() + " ar.cause().getMessage(): " + ar.cause().getMessage() + "\n" + + " ar.result(): " + ar.result() ); assertThat(ar.cause()).isInstanceOf(PreBidException.class); assertThat(ar.cause().getMessage()).contains("Error accessing GCS artefact for model"); }); - - //assertThat(exception.getMessage()).contains("Error accessing GCS artefact for model"); } - /* @Test - public void getShouldThrowExceptionIfOnnxModelFails() { + public void getShouldThrowExceptionWhenOnnxModelFails() throws OrtException { // given final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID; + final byte[] bytes = new byte[]{1, 2, 3}; + + // Mock that the model is not in cache when(cache.getIfPresent(eq(cacheKey))).thenReturn(null); - when(storage.get(GCS_BUCKET_NAME)).thenReturn(bucket); + + // Simulate an error when accessing the storage bucket + when(storage.get(GCS_BUCKET_NAME)).thenReturn(bucket);; when(bucket.get(ONNX_MODEL_PATH)).thenReturn(blob); - when(blob.getContent()).thenThrow(new PreBidException("Failed to convert blob to ONNX model")); - - when(vertx.executeBlocking(any(Callable.class))).thenAnswer(invocation -> { - Callable callable = invocation.getArgument(0); - try { - // The callable should throw an exception when called - Object result = callable.call(); - return Future.succeededFuture(result); - } catch (Exception e) { - // Return a failed future when an exception occurs - return Future.failedFuture(e); - } - }); + lenient().when(blob.getContent()).thenReturn(bytes); + lenient().when(onnxModelRunnerFactory.create(bytes)).thenThrow(new OrtException("Failed to convert blob to ONNX model")); // when Future future = target.get(ONNX_MODEL_PATH, PBUUID); @@ -272,14 +206,54 @@ public void getShouldThrowExceptionIfOnnxModelFails() { System.out.println( "future.onComplete: \n" + + " ar: " + ar + "\n" + " ar.failed(): " + ar.failed() + "\n" + " ar.cause(): " + ar.cause() + "\n" + - " ar.cause().getMessage(): " + ar.cause().getMessage() + " ar.cause().getMessage(): " + ar.cause().getMessage() + "\n" + + " ar.result(): " + ar.result() ); + + assertThat(ar.failed()).isTrue(); assertThat(ar.cause()).isInstanceOf(PreBidException.class); assertThat(ar.cause().getMessage()).contains("Failed to convert blob to ONNX model"); }); } - */ + + @Test + public void getShouldThrowExceptionWhenBucketNotFound() { + // given + final String cacheKey = MODEL_CACHE_KEY_PREFIX + PBUUID; + final byte[] bytes = new byte[]{1, 2, 3}; + + // Mock that the model is not in cache + when(cache.getIfPresent(eq(cacheKey))).thenReturn(null); + + // Simulate an error when accessing the storage bucket + when(storage.get(GCS_BUCKET_NAME)).thenReturn(bucket);; + when(bucket.get(ONNX_MODEL_PATH)).thenReturn(blob); + lenient().when(blob.getContent()).thenThrow(new PreBidException("Bucket not found")); + + // when + Future future = target.get(ONNX_MODEL_PATH, PBUUID); + + // then + future.onComplete(ar -> { + + System.out.println( + "future.onComplete: \n" + + " ar: " + ar + "\n" + + " ar.failed(): " + ar.failed() + "\n" + + " ar.cause(): " + ar.cause() + "\n" + + " ar.cause().getMessage(): " + ar.cause().getMessage() + "\n" + + " ar.result(): " + ar.result() + ); + + + assertThat(ar.failed()).isTrue(); + assertThat(ar.cause()).isInstanceOf(PreBidException.class); + assertThat(ar.cause().getMessage()).contains("Bucket not found"); + }); + } + }