From 1eeb4e57e6ff4cd5e08cda9f76c4cd49eebe50e6 Mon Sep 17 00:00:00 2001 From: Johannes Oswald Date: Wed, 4 May 2022 14:20:56 +0200 Subject: [PATCH] Update utils.py Remove not supported Omniglot code --- utils/utils.py | 49 ++----------------------------------------------- 1 file changed, 2 insertions(+), 47 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 4a1f4d9..70d06ca 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -107,53 +107,8 @@ def load_data(args): input_channels=3 if args.dataset=="Omniglot": - dataset_transform = ClassSplitter(shuffle=True, - num_train_per_class=args.num_shots_train, - num_test_per_class=args.num_shots_test) - transform = Compose([Resize(28), ToTensor()]) - - meta_train_dataset = Omniglot("data", - transform=transform, - target_transform=Categorical(args.num_ways), - num_classes_per_task=args.num_ways, - meta_train=True, - class_augmentations=class_augmentations, - dataset_transform=dataset_transform, - download=True) - - meta_val_dataset = Omniglot("data", - transform=transform, - target_transform=Categorical(args.num_ways), - num_classes_per_task=args.num_ways, - meta_val=True, - class_augmentations=class_augmentations, - dataset_transform=dataset_transform) - meta_test_dataset = Omniglot("data", - transform=transform, - target_transform=Categorical(args.num_ways), - num_classes_per_task=args.num_ways, - meta_test=True, - dataset_transform=dataset_transform) - - meta_dataloader["train"] = BatchMetaDataLoader(meta_train_dataset, - batch_size=args.batch_size, - shuffle=True, - num_workers=args.num_workers, - pin_memory=True) - - meta_dataloader["val"] = BatchMetaDataLoader(meta_val_dataset, - batch_size=args.batch_size, - shuffle=True, - num_workers=args.num_workers, - pin_memory=True) - - meta_dataloader["test"]=BatchMetaDataLoader(meta_test_dataset, - batch_size=args.batch_size, - shuffle=True, - num_workers=args.num_workers, - pin_memory=True) - feature_size=args.hidden_size - input_channels=1 + print("Omniglot not supported.") + exit() if args.dataset=="TieredImagenet": dataset_transform = ClassSplitter(shuffle=True,