From 2d5fd3fa6054e77341c1b032b3b1ff635565aa08 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 10 Nov 2023 15:28:24 -0800 Subject: [PATCH] Clarify which models use which normalization type --- trapdata/ml/models/classification.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/trapdata/ml/models/classification.py b/trapdata/ml/models/classification.py index d213fef5..7e06d977 100644 --- a/trapdata/ml/models/classification.py +++ b/trapdata/ml/models/classification.py @@ -6,7 +6,7 @@ from trapdata.db.models.detections import save_classified_objects from trapdata.db.models.queue import DetectedObjectQueue, UnclassifiedObjectQueue -from .base import InferenceBaseClass, imagenet_normalization +from .base import InferenceBaseClass, imagenet_normalization, tensorflow_normalization class ClassificationIterableDatabaseDataset(torch.utils.data.IterableDataset): @@ -43,6 +43,7 @@ def transform(self, cropped_image): class EfficientNetClassifier(InferenceBaseClass): input_size = 300 + normalization = tensorflow_normalization def get_model(self): num_classes = len(self.category_map) @@ -108,6 +109,7 @@ def forward(self, x): class Resnet50Classifier(InferenceBaseClass): input_size = 300 + normalization = imagenet_normalization def get_model(self): num_classes = len(self.category_map) @@ -122,12 +124,11 @@ def get_model(self): return model def get_transforms(self): - mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] return torchvision.transforms.Compose( [ torchvision.transforms.Resize((self.input_size, self.input_size)), torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize(mean, std), + self.normalization, ] ) @@ -275,6 +276,9 @@ class UKDenmarkMothSpeciesClassifierMixedResolution( name = "UK & Denmark Species Classifier" description = "Trained on April 3, 2023 using mix of low & med resolution images." + + normalization = imagenet_normalization + weights_path = ( "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/" "uk-denmark-moths-mixedres-20230403_140131_30.pth" @@ -288,6 +292,9 @@ class UKDenmarkMothSpeciesClassifierMixedResolution( class PanamaMothSpeciesClassifierMixedResolution(SpeciesClassifier, Resnet50Classifier): name = "Panama Species Classifier" description = "Trained on December 22, 2022 using a mix of low & med resolution images. 148 species." + + normalization = imagenet_normalization + weights_path = ( "https://object-arbutus.cloud.computecanada.ca/ami-models/moths/classification/" "panama_moth-model_v01_resnet50_2023-01-24-09-51.pt"