input_pipeline.py
"""
Data pipeline for the diffusion model.
"""
import tensorflow_datasets as tfds
import tensorflow as tf
import numpy as np
from typing import NamedTuple, Iterator
from tensorflow.keras.applications import resnet_v2
class Batch(NamedTuple):
images: np.ndarray
labels: np.ndarray


def load_dataset(name: str,
                 split: str,
                 shuffle: bool = False,
                 batch_size: int = 128) -> Iterator:
    """Loads the dataset.

    Args:
      name: Name of the dataset.
      split: The split of the dataset to load.
      shuffle: Whether to shuffle the dataset files.
      batch_size: The batch size.

    (deprecated)

    Returns:
      An iterator over the dataset.

    >>> dataset = load_dataset(name='mnist', split='train')
    >>> print(next(dataset).images.shape)
    (128, 28, 28, 1)
    """
    ds = tfds.load(name,
                   split=split,
                   shuffle_files=shuffle,
                   as_supervised=True).cache().repeat()

    def normalize(images, labels):
        # resnet_v2.preprocess_input scales pixel values to [-1, 1].
        images = tf.cast(images, dtype='float32')
        images = resnet_v2.preprocess_input(images)
        return images, labels

    ds = ds.batch(batch_size, drop_remainder=True)
    ds = ds.map(normalize)
    ds = ds.map(lambda x, y: Batch(x, y))
    return iter(tfds.as_numpy(ds))


def create_split(name: str,
                 split: str,
                 shuffle: bool = False,
                 batch_size: int = 128,
                 cache: bool = False) -> tf.data.Dataset:
    """Creates a batched, repeating dataset for the given split.

    Args:
      name: Name of the dataset.
      split: The split of the dataset to load.
      shuffle: Whether to shuffle the dataset files.
      batch_size: The batch size.
      cache: Whether to cache the dataset in memory.

    Returns:
      A batched dataset.
    """
    ds = tfds.load(name,
                   split=split,
                   shuffle_files=shuffle,
                   as_supervised=True)
    if cache:
        ds = ds.cache()
    ds = ds.repeat()
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds


def preprocess_image_dataset(dataset, image_size: int,
                             dtype: str = 'float32') -> tf.data.Dataset:
    """Resizes images to `image_size` and normalizes them to [-1, 1]
    with the ResNetV2 preprocessing function."""

    def resize(images, labels):
        images = tf.image.resize(images, [image_size, image_size])
        return images, labels

    def normalize(images, labels):
        images = tf.cast(images, dtype=dtype)
        images = resnet_v2.preprocess_input(images)
        return images, labels

    dataset = dataset.map(resize)
    dataset = dataset.map(normalize)
    return dataset


def make_denoising_dataset(dataset):
    """Maps (image, label) pairs to (noised image, clean image) pairs."""

    def corrupt(images, labels):
        # Use the dynamic shape so this also works when static
        # dimensions are unknown at trace time.
        noise = tf.random.normal(shape=tf.shape(images), mean=0.0, stddev=1.0)
        noised = images + noise
        return noised, images

    dataset = dataset.map(corrupt)
    return dataset


def convert2iterator(ds):
    """Wraps a dataset into an iterator of numpy `Batch`es."""
    ds = ds.map(lambda x, y: Batch(x, y))
    return iter(tfds.as_numpy(ds))
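

# A hypothetical end-to-end example of how the helpers above compose into
# a denoising data iterator; 'mnist', image_size=32, and batch_size=32 are
# illustrative choices, not part of the pipeline itself.
def make_mnist_denoising_iterator(batch_size: int = 32) -> Iterator:
    ds = create_split('mnist', split='train', shuffle=True,
                      batch_size=batch_size)
    ds = preprocess_image_dataset(ds, image_size=32)
    ds = make_denoising_dataset(ds)
    # Each Batch then carries the noised inputs in `images` and the clean
    # targets in `labels`, both of shape (batch_size, 32, 32, 1).
    return convert2iterator(ds)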


def prepare_tf_data(xs):
    # TODO: implement this function
    pass
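

# A possible implementation of the TODO above, kept as a separate sketch.
# It assumes the training loop follows the common Flax/JAX convention of
# sharding each batch across local devices for `jax.pmap`; note that `jax`
# is an assumed dependency not present in the original imports.
def prepare_tf_data_sketch(xs):
    import jax

    local_device_count = jax.local_device_count()

    def _prepare(x):
        # `_numpy()` converts a TF EagerTensor to numpy without a copy.
        x = x._numpy()
        # Fold the leading batch dim into (local_device_count, per-device batch).
        return x.reshape((local_device_count, -1) + x.shape[1:])

    return jax.tree_util.tree_map(_prepare, xs)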