Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OMNIGLOT Dataset #46

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ The following dataset loaders are available:
- `ImageFolder <#imagefolder>`__
- `Imagenet-12 <#imagenet-12>`__
- `CIFAR10 and CIFAR100 <#cifar>`__

- `OMNIGLOT <#omniglot>`__

Datasets have the API: - ``__getitem__`` - ``__len__`` They all subclass
from ``torch.utils.data.Dataset`` Hence, they can all be multi-threaded
(python multiprocessing) using standard torch.utils.data.DataLoader.
Expand Down Expand Up @@ -187,6 +187,16 @@ here <https://github.com/facebook/fb.resnet.torch/blob/master/INSTALL.md#downloa
`Here is an
example <https://github.com/pytorch/examples/blob/27e2a46c1d1505324032b1d94fc6ce24d5b67e97/imagenet/main.py#L48-L62>`__.

OMNIGLOT
~~~~~~~~

``dset.OMNIGLOT(root, [train=True, transform=None, target_transform=None, download=False])``

The dataset is composed of pairs: ``(image, category index)``. Each category corresponds to one character in one alphabet. The mapping from class names to class indices is available through ``dataset.idx_classes``.


From: `Lake, B. M., Salakhutdinov, R., and Tenenbaum, J. B. (2015). Human-level concept learning through probabilistic program induction. Science, 350(6266), 1332-1338.`

Models
======

Expand Down
3 changes: 2 additions & 1 deletion torchvision/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from .coco import CocoCaptions, CocoDetection
from .cifar import CIFAR10, CIFAR100
from .mnist import MNIST
from .omniglot import OMNIGLOT

__all__ = ('LSUN', 'LSUNClass',
'ImageFolder',
'CocoCaptions', 'CocoDetection',
'CIFAR10', 'CIFAR100',
'MNIST')
'MNIST','OMNIGLOT')
104 changes: 104 additions & 0 deletions torchvision/datasets/omniglot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import torch
import json
import codecs
import numpy as np
from PIL import Image

class OMNIGLOT(data.Dataset):
    """Omniglot handwritten-character dataset.

    Every item is a pair ``(image, class_index)``. A class is identified by
    the last two directory names that contain the image; ``idx_classes``
    maps that class name to its integer index, and ``all_items`` holds one
    ``(filename, class_name, directory)`` tuple per image.

    Args:
        root: Directory under which the ``raw`` and ``processed`` folders
            live.
        train: Accepted for API compatibility; it is currently not used, and
            the images of both archives are always loaded together.
        transform: Optional callable applied to the RGB PIL image.
        target_transform: Optional callable applied to the class index.
        download: If true, fetch and unpack the archives when they are not
            already present under ``root``.

    Raises:
        RuntimeError: If the extracted archives are not found under ``root``.
    """
    urls = [
        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
    ]
    raw_folder = 'raw'
    processed_folder = 'processed'
    training_file = 'training.pt'
    test_file = 'test.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found.'
                               + ' You can use download=True to download it')

        self.all_items = find_classes(os.path.join(self.root, self.processed_folder))
        self.idx_classes = index_classes(self.all_items)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the item at position ``index``."""
        filename, class_name, directory = self.all_items[index][:3]
        path = os.path.join(directory, filename)
        img = Image.open(path).convert('RGB')
        target = self.idx_classes[class_name]
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of images that were found."""
        return len(self.all_items)

    def _check_exists(self):
        """Tell whether both extracted archive folders are present."""
        processed_dir = os.path.join(self.root, self.processed_folder)
        return os.path.exists(os.path.join(processed_dir, "images_evaluation")) and \
            os.path.exists(os.path.join(processed_dir, "images_background"))

    def download(self):
        """Download and unpack the archives unless they are already there."""
        if self._check_exists():
            return

        # Imported lazily: only an actual download needs these modules.
        import contextlib
        import shutil
        import zipfile
        from six.moves import urllib

        raw_dir = os.path.join(self.root, self.raw_folder)
        processed_dir = os.path.join(self.root, self.processed_folder)

        # Create each folder on its own: with a single try block, an already
        # existing raw folder would skip the creation of the processed one.
        for directory in (raw_dir, processed_dir):
            try:
                os.makedirs(directory)
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise

        for url in self.urls:
            print('== Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(raw_dir, filename)
            # closing() guarantees the connection is released (the Python 2
            # response object is not a context manager by itself), and
            # copyfileobj streams the archive instead of buffering it whole.
            with contextlib.closing(urllib.request.urlopen(url)) as response:
                with open(file_path, 'wb') as f:
                    shutil.copyfileobj(response, f)
            print("== Unzip from " + file_path + " to " + processed_dir)
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(processed_dir)
        print("Download finished.")

def find_classes(root_dir):
    """Collect every ``.png`` file found below ``root_dir``.

    Args:
        root_dir: Directory that is walked recursively.

    Returns:
        A list of ``(filename, class_name, directory)`` tuples, where
        ``class_name`` is ``"<parent dir>/<dir>"`` built from the last two
        components of the directory holding the file. Directories and files
        are visited in sorted order, so the list is the same on every run
        and every filesystem.
    """
    items = []
    for (root, dirs, files) in os.walk(root_dir):
        # Sorting in place makes os.walk descend in a stable order; the class
        # indices derived from this list depend on that order.
        dirs.sort()
        # os.path handles the platform separator; the class name itself
        # always uses "/" so the keys look the same everywhere.
        class_name = os.path.basename(os.path.dirname(root)) + "/" + os.path.basename(root)
        for f in sorted(files):
            if f.endswith(".png"):
                items.append((f, class_name, root))
    print("Found %d items " % len(items))
    return items

def index_classes(items):
    """Assign a consecutive integer index to every distinct class name.

    Args:
        items: Iterable of sequences whose element at position 1 is the
            class name.

    Returns:
        A dict mapping each class name to its index, numbered from 0 in
        order of first appearance.
    """
    class_to_idx = {}
    for entry in items:
        # setdefault leaves an already-registered class untouched; a new one
        # receives the next free index (the current size of the dict).
        class_to_idx.setdefault(entry[1], len(class_to_idx))
    print("Found %d classes" % len(class_to_idx))
    return class_to_idx