-
Notifications
You must be signed in to change notification settings - Fork 9
/
iterators.py
42 lines (30 loc) · 1.15 KB
/
iterators.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#https://github.com/hvy/chainer-wasserstein-gan/blob/master/iterators.py
import numpy
from chainer.dataset import iterator
def to_tuple(x):
    """Normalize a shape specification to a tuple.

    Args:
        x: Either a single dimension (e.g. an ``int`` or a NumPy integer
            scalar) or an iterable of dimensions (tuple, list, ...).

    Returns:
        tuple: ``tuple(x)`` when ``x`` is iterable, otherwise the 1-tuple
        ``(x,)``. The result can always be concatenated onto another tuple,
        e.g. ``(batch_size,) + to_tuple(size)``.
    """
    # Iterate instead of testing for ``__getitem__``. Lists must become real
    # tuples so that tuple concatenation works for callers. NumPy scalars
    # define ``__getitem__`` without being sequences, so that test misclassifies them.
    try:
        return tuple(x)
    except TypeError:
        # Not iterable: treat it as a single dimension.
        return (x,)
class UniformNoiseGenerator(object):
    """Callable that draws float32 noise from a uniform distribution.

    Each call returns an array of shape ``(batch_size,) + size`` whose
    entries are sampled with ``numpy.random.uniform`` between ``low`` and
    ``high`` and then cast to ``numpy.float32``.
    """

    def __init__(self, low, high, size):
        # Distribution bounds, forwarded verbatim to numpy.random.uniform.
        self.low = low
        self.high = high
        # Per-sample shape, normalized by to_tuple so it can be prefixed
        # with the batch dimension in __call__.
        self.size = to_tuple(size)

    def __call__(self, batch_size):
        """Return a float32 batch of shape ``(batch_size,) + self.size``."""
        shape = (batch_size,) + self.size
        samples = numpy.random.uniform(low=self.low, high=self.high, size=shape)
        return samples.astype(numpy.float32)
class GaussianNoiseGenerator(object):
    """Callable that draws float32 noise from a normal distribution.

    Each call returns an array of shape ``(batch_size,) + size`` whose
    entries are sampled with ``numpy.random.normal`` using mean ``loc`` and
    standard deviation ``scale``, then cast to ``numpy.float32``.
    """

    def __init__(self, loc, scale, size):
        # Distribution parameters, forwarded verbatim to numpy.random.normal.
        self.loc = loc
        self.scale = scale
        # Per-sample shape, normalized by to_tuple so it can be prefixed
        # with the batch dimension in __call__.
        self.size = to_tuple(size)

    def __call__(self, batch_size):
        """Return a float32 batch of shape ``(batch_size,) + self.size``."""
        shape = (batch_size,) + self.size
        samples = numpy.random.normal(loc=self.loc, scale=self.scale, size=shape)
        return samples.astype(numpy.float32)
class RandomNoiseIterator(iterator.Iterator):
    """Chainer iterator that yields a freshly generated noise batch per step.

    ``__next__`` never signals exhaustion. Every call simply asks
    ``noise_generator`` for ``batch_size`` new samples. This class keeps no
    epoch or position state of its own.
    """

    def __init__(self, noise_generator, batch_size):
        # Any callable taking a batch size and returning a batch, e.g. an
        # instance of one of the noise generator classes in this module.
        self.noise_generator = noise_generator
        # Number of samples requested from the generator on each step.
        self.batch_size = batch_size

    def __next__(self):
        """Return the batch produced by ``noise_generator(batch_size)``."""
        return self.noise_generator(self.batch_size)