diff --git a/README.md b/README.md index be19cb06f9..9b840d3476 100644 --- a/README.md +++ b/README.md @@ -206,13 +206,13 @@ from flash.image import ImageEmbedder download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/") # 2. Create an ImageEmbedder with resnet50 trained on imagenet. -embedder = ImageEmbedder(backbone="resnet50", embedding_dim=128) +embedder = ImageEmbedder(backbone="resnet50") # 3. Generate an embedding from an image path. embeddings = embedder.predict("data/hymenoptera_data/predict/153783656_85f9c3ac70.jpg") # 4. Print embeddings shape -print(embeddings.shape) +print(embeddings[0].shape) ``` diff --git a/flash_examples/integrations/fiftyone/image_embedding.py b/flash_examples/integrations/fiftyone/image_embedding.py index b9d1651ceb..019bd9cffe 100644 --- a/flash_examples/integrations/fiftyone/image_embedding.py +++ b/flash_examples/integrations/fiftyone/image_embedding.py @@ -28,7 +28,7 @@ ) # 3 Load model -embedder = ImageEmbedder(backbone="resnet101", embedding_dim=128) +embedder = ImageEmbedder(backbone="resnet101") # 4 Generate embeddings filepaths = dataset.values("filepath")