diff --git a/README.rst b/README.rst
index 3da007a5d2e..9193e8f6827 100644
--- a/README.rst
+++ b/README.rst
@@ -46,12 +46,12 @@ The following dataset loaders are available:
 -  `ImageFolder <#imagefolder>`__
 -  `Imagenet-12 <#imagenet-12>`__
 -  `CIFAR10 and CIFAR100 <#cifar>`__
-
+-  `OMNIGLOT <#omniglot>`__
 
 Datasets have the API:
 
 -  ``__getitem__``
 -  ``__len__``
    They all subclass from ``torch.utils.data.Dataset``
    Hence, they can all be multi-threaded (python multiprocessing) using
    standard torch.utils.data.DataLoader.
@@ -187,6 +187,16 @@
 here `__.
 
+OMNIGLOT
+~~~~~~~~
+
+``dset.OMNIGLOT(root_dir, [transform=None, target_transform=None, download=False])``
+
+The dataset is composed of ``(filename, category_idx)`` pairs. Each category corresponds to one character in one alphabet. The mapping between class indexes and real classes can be accessed through ``dataset.idx_classes``.
+The dataset can be used with ``transform=transforms.FilenameToPILImage`` to obtain ``(PIL Image, category_idx)`` pairs.
+
+From: `Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. Science, 350(6266), 1332-1338.`
+
 
 Models
 ======
 
diff --git a/test/test_omniglot.py b/test/test_omniglot.py
new file mode 100644
index 00000000000..82e2e6a4bf7
--- /dev/null
+++ b/test/test_omniglot.py
@@ -0,0 +1,11 @@
+import torchvision.datasets as dset
+import torchvision.transforms as transforms
+
+print('Omniglot')
+# FilenameToPILImage loads the stored filename as a PIL image,
+# ToTensor then converts the image to a tensor.
+a = dset.OMNIGLOT("../data", download=True,
+                  transform=transforms.Compose([transforms.FilenameToPILImage(),
+                                                transforms.ToTensor()]))
+print(a.idx_classes)
+print(a[3])
diff --git a/torchvision/datasets/__init__.py b/torchvision/datasets/__init__.py
index e9c4b0e7184..56622b679ca 100644
--- a/torchvision/datasets/__init__.py
+++ b/torchvision/datasets/__init__.py
@@ -3,9 +3,10 @@
 from .coco import CocoCaptions, CocoDetection
 from .cifar import CIFAR10, CIFAR100
 from .mnist import MNIST
+from .omniglot import OMNIGLOT
 
 __all__ = ('LSUN', 'LSUNClass',
            'ImageFolder',
            'CocoCaptions', 'CocoDetection',
            'CIFAR10', 'CIFAR100',
-           'MNIST')
+           'MNIST', 'OMNIGLOT')
diff --git a/torchvision/datasets/omniglot.py b/torchvision/datasets/omniglot.py
new file mode 100644
index 00000000000..e6999cea832
--- /dev/null
+++ b/torchvision/datasets/omniglot.py
@@ -0,0 +1,114 @@
+from __future__ import print_function
+import torch.utils.data as data
+import os
+import os.path
+import errno
+
+
+class OMNIGLOT(data.Dataset):
+    '''
+    The items are (filename, category).
+    The index of all the categories can be found in self.idx_classes.
+
+    Args:
+
+    - root: the directory where the dataset will be stored
+    - transform: how to transform the input
+    - target_transform: how to transform the target
+    - download: whether to download the dataset
+    '''
+    urls = [
+        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
+        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
+    ]
+    raw_folder = 'raw'
+    processed_folder = 'processed'
+    training_file = 'training.pt'
+    test_file = 'test.pt'
+
+    def __init__(self, root, transform=None, target_transform=None, download=False):
+        self.root = root
+        self.transform = transform
+        self.target_transform = target_transform
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError('Dataset not found.' +
+                               ' You can use download=True to download it')
+
+        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
+        self.idx_classes = index_classes(self.all_items)
+
+    def __getitem__(self, index):
+        filename = self.all_items[index][0]
+        img = os.path.join(self.all_items[index][2], filename)
+
+        target = self.idx_classes[self.all_items[index][1]]
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        return img, target
+
+    def __len__(self):
+        return len(self.all_items)
+
+    def _check_exists(self):
+        return os.path.exists(os.path.join(self.root, self.processed_folder, "images_evaluation")) and \
+            os.path.exists(os.path.join(self.root, self.processed_folder, "images_background"))
+
+    def download(self):
+        from six.moves import urllib
+        import zipfile
+
+        if self._check_exists():
+            return
+
+        # create the raw and processed folders if they do not exist yet
+        try:
+            os.makedirs(os.path.join(self.root, self.raw_folder))
+            os.makedirs(os.path.join(self.root, self.processed_folder))
+        except OSError as e:
+            if e.errno == errno.EEXIST:
+                pass
+            else:
+                raise
+
+        for url in self.urls:
+            print('== Downloading ' + url)
+            data = urllib.request.urlopen(url)
+            filename = url.rpartition('/')[2]
+            file_path = os.path.join(self.root, self.raw_folder, filename)
+            with open(file_path, 'wb') as f:
+                f.write(data.read())
+            file_processed = os.path.join(self.root, self.processed_folder)
+            print("== Unzip from " + file_path + " to " + file_processed)
+            zip_ref = zipfile.ZipFile(file_path, 'r')
+            zip_ref.extractall(file_processed)
+            zip_ref.close()
+        print("Download finished.")
+
+
+def find_classes(root_dir):
+    # collect a (filename, "alphabet/character", directory) triple for every PNG
+    retour = []
+    for (root, dirs, files) in os.walk(root_dir):
+        for f in files:
+            if f.endswith("png"):
+                r = root.split(os.sep)
+                retour.append((f, r[-2] + "/" + r[-1], root))
+    print("== Found %d items" % len(retour))
+    return retour
+
+
+def index_classes(items):
+    # assign a unique integer index to every "alphabet/character" class
+    idx = {}
+    for i in items:
+        if i[1] not in idx:
+            idx[i[1]] = len(idx)
+    print("== Found %d classes" % len(idx))
+    return idx
diff --git a/torchvision/transforms.py b/torchvision/transforms.py
index 68ce23a1b1a..62af39903c9 100644
--- a/torchvision/transforms.py
+++ b/torchvision/transforms.py
@@ -27,6 +27,13 @@ def __call__(self, img):
             img = t(img)
         return img
 
+class FilenameToPILImage(object):
+    """
+    Load a PIL RGB Image from a filename.
+    """
+    def __call__(self, filename):
+        img = Image.open(filename).convert('RGB')
+        return img
 
 class ToTensor(object):
     """Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range
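
For reviewers, a minimal usage sketch showing how the new dataset composes with a standard ``torch.utils.data.DataLoader``, as the README promises for all dataset loaders. The ``root`` path, batch size, and worker count are illustrative, and the snippet assumes the download succeeds:

    import torch.utils.data
    import torchvision.datasets as dset
    import torchvision.transforms as transforms

    # Each item starts out as (filename, category_idx); FilenameToPILImage
    # loads the file and ToTensor converts the PIL image to a float tensor.
    omniglot = dset.OMNIGLOT(
        "./data", download=True,
        transform=transforms.Compose([transforms.FilenameToPILImage(),
                                      transforms.ToTensor()]))

    # Multi-process batch loading via the standard DataLoader.
    loader = torch.utils.data.DataLoader(omniglot, batch_size=32,
                                         shuffle=True, num_workers=2)

    # Omniglot glyphs are all 105x105, so the default collate function can
    # stack them: images is 32x3x105x105 (RGB), targets holds class indexes.
    images, targets = next(iter(loader))
    print(images.size(), targets.size())

Note that ``__getitem__`` applies ``transform`` to the filename itself, so ``FilenameToPILImage`` must come first in the ``Compose`` chain; transforms that expect a PIL image or a tensor go after it.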