Test Time Augmentation Wrapper

This module wraps an existing PyTorch model, runs inference on several augmented copies of each image, and then merges the predictions into one.

The wrapper adds augmentation layers around your model as follows:

            Input
              |           # input batch; shape B, C, H, W
         / / / \ \ \      # duplicate image for augmentation; shape N*B, C, H, W
        | | |   | | |     # apply augmentations (flips, rotation, shifts)
     your nn.Module model
        | | |   | | |     # reverse transformations (this part is skipped for classification)
         \ \ \ / / /      # merge predictions (mean, max, gmean)
              |           # output mask; shape B, C, H, W
            Output
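The flow in the diagram can be sketched in plain PyTorch for the simplest case: a single horizontal flip on a segmentation model. This is a minimal sketch under stated assumptions, not the wrapper's own code. `flip_tta` is a hypothetical helper, inputs are assumed to be NCHW tensors, and only the three merge modes named above are implemented.

```python
import torch
from torch import nn


def flip_tta(model: nn.Module, x: torch.Tensor, merge: str = "mean") -> torch.Tensor:
    """Hypothetical horizontal-flip TTA for a segmentation model (NCHW tensors)."""
    b = x.shape[0]
    # Duplicate the batch: original images followed by horizontally flipped copies.
    batch = torch.cat([x, torch.flip(x, dims=[-1])], dim=0)  # shape 2*B, C, H, W
    with torch.no_grad():
        out = model(batch)
    original, flipped = out[:b], out[b:]
    # Reverse the flip so both predictions are spatially aligned again.
    flipped = torch.flip(flipped, dims=[-1])
    preds = torch.stack([original, flipped], dim=0)  # shape 2, B, C, H, W
    # Merge the aligned predictions into one output per input image.
    if merge == "mean":
        return preds.mean(dim=0)
    if merge == "max":
        return preds.max(dim=0).values
    if merge == "gmean":
        # Geometric mean; only meaningful for strictly positive outputs such as probabilities.
        return preds.log().mean(dim=0).exp()
    raise ValueError(f"unknown merge mode: {merge}")


# Usage with a toy segmentation model that outputs one probability map per image.
model = nn.Sequential(nn.Conv2d(3, 1, kernel_size=3, padding=1), nn.Sigmoid()).eval()
x = torch.rand(4, 3, 32, 32)
mask = flip_tta(model, x)
print(mask.shape)  # torch.Size([4, 1, 32, 32])
```

For classification the "reverse transformations" step would be dropped, since class scores have no spatial layout to realign; the merge step stays the same.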

Example

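import torch
from pytorch_tools.tta_wrapper import TTA

# `model` is any nn.Module and `loader` yields input batches (both defined elsewhere)
# 2 (flip or not) x 3 (shift 0, 5, -5) x 3 (mul 1.0, 0.9, 1.1) = 18 augmentations per image!
model.eval()
tta_model = TTA(model, h_flip=True, h_shift=[5, -5], mul=[0.9, 1.1])
with torch.no_grad():
    for batch in loader:
        prediction = tta_model(batch)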
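The count of 18 in the example comment follows from taking every combination of the options, assuming each option also includes its identity setting (no flip, zero shift, multiplier 1.0). The short sketch below enumerates those combinations; the option lists are hypothetical stand-ins for the keyword arguments above, not the wrapper's internal code.

```python
from itertools import product

# Hypothetical option lists mirroring the example call; each includes the identity
# setting, which is why the per-option counts are 2, 3 and 3.
h_flip_options = [False, True]
h_shift_options = [0, 5, -5]
mul_options = [1.0, 0.9, 1.1]

combinations = list(product(h_flip_options, h_shift_options, mul_options))
print(len(combinations))  # 18
```

Because every image is duplicated once per combination (the N*B batch in the diagram), each batch of B images costs roughly 18 forward passes' worth of compute, so adding options grows inference time multiplicatively.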