diff --git a/build_feature_index.py b/build_feature_index.py index 60b5a79..f75bf6d 100644 --- a/build_feature_index.py +++ b/build_feature_index.py @@ -2,6 +2,7 @@ import numpy as np from generic_utils import load_image +from PIL import Image as pil_image from app_id_utils import freeze_app_ids, list_app_ids, app_id_to_image_filename from data_utils import get_label_database_filename @@ -14,7 +15,9 @@ ) -def build_feature_index(is_horizontal_banner=False, resolution=None): +def build_feature_index( + is_horizontal_banner=False, resolution=None, apply_flip=False, apply_mirror=False +): pooling = "avg" feature_filename = get_label_database_filename(pooling) @@ -44,6 +47,10 @@ def build_feature_index(is_horizontal_banner=False, resolution=None): image_filename = app_id_to_image_filename(app_id, is_horizontal_banner) image = load_image(image_filename, target_size=target_model_size) + if apply_flip: + image = pil_image.flip(image) + if apply_mirror: + image = pil_image.mirror(image) features = convert_image_to_features(image, model, preprocess=preprocess) Y_hat[counter, :] = features @@ -58,4 +65,9 @@ def build_feature_index(is_horizontal_banner=False, resolution=None): if __name__ == "__main__": - build_feature_index(is_horizontal_banner=False, resolution=None) + build_feature_index( + is_horizontal_banner=False, + resolution=None, + apply_flip=False, + apply_mirror=False, + )