-
Notifications
You must be signed in to change notification settings - Fork 24
/
Copy pathblocks_extras.py
65 lines (51 loc) · 1.97 KB
/
blocks_extras.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""
This module contains classes that should be merged at some point inside blocks
"""
import copy
import numpy
import theano
from blocks.initialization import NdarrayInitialization, Orthogonal
# should be in blocks.initialization
class GlorotBengio(NdarrayInitialization):
    """Initialize parameters with a Glorot-Bengio-style scaled distribution.

    Samples are zero-mean and scaled by ``w = sqrt(scale / fan)``, where
    ``fan`` is the size of the *last* axis of the requested shape:

    * ``normal=True``: samples are drawn from a normal distribution with
      mean 0 and standard deviation ``w``.
    * ``normal=False`` (default): samples are drawn uniformly from
      ``[-w, w)``.  The standard deviation of that distribution is
      ``w / sqrt(3)``, not ``w``.

    In some circles this method is also called Xavier weight initialization.

    Parameters
    ----------
    scale : float
        Variance scale factor: 1 for linear/tanh/sigmoid units, 2 for ReLU.
    normal : bool
        Sample from a normal distribution. By default, sample from a
        uniform distribution.

    Notes
    -----
    The uniform bound used here is ``sqrt(scale / fan)``. It is not the
    ``sqrt(6 / (fan_in + fan_out))`` "normalized initialization" bound
    from [GLOROT]_.

    NOTE(review): the fan is always read from ``shape[-1]``. Confirm that
    this axis is the fan-in for the weight layout used by callers.

    For more information, see [GLOROT]_.

    .. [GLOROT] Glorot et al. *Understanding the difficulty of training
       deep feedforward neural networks*, International conference on
       artificial intelligence and statistics, 249-256
       http://jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf
    """

    def __init__(self, scale=1, normal=False):
        # float() keeps ``scale / fan`` a true division even when both
        # values are ints (Python 2 integer-division semantics).
        self._scale = float(scale)
        self._normal = normal

    def generate(self, rng, shape):
        """Draw an array of the given shape, cast to ``theano.config.floatX``.

        Parameters
        ----------
        rng : random generator
            Object providing ``normal`` and ``uniform`` with a ``size``
            keyword (presumably a ``numpy.random.RandomState``).
        shape : tuple of int
            Shape of the array to generate. ``shape[-1]`` is used as fan.

        Returns
        -------
        numpy.ndarray
            Array of shape ``shape`` and dtype ``theano.config.floatX``.
        """
        fan = shape[-1]
        width = numpy.sqrt(self._scale / fan)
        if self._normal:
            # ``width`` is the standard deviation of the normal draw.
            values = rng.normal(0., width, size=shape)
        else:
            # ``width`` is the half-range of the uniform draw.
            values = rng.uniform(-width, width, size=shape)
        return values.astype(theano.config.floatX)
class OrthogonalGlorot(GlorotBengio):
    """Orthogonal initialization with support for horizontally stacked blocks.

    * 1-D shapes are initialized by :class:`GlorotBengio`, using the
      ``scale``/``normal`` arguments given to the constructor.
    * 2-D shapes ``(N, M * N)`` with ``M > 1`` are built from ``M``
      independent ``N x N`` matrices from :class:`Orthogonal`, concatenated
      along the last axis. Each ``N x N`` block is therefore orthogonal on
      its own. Presumably this is meant for weight matrices that pack
      several gates side by side (confirm with callers).
    * Every other shape is passed unchanged to :class:`Orthogonal`.

    Constructor arguments are forwarded to :class:`GlorotBengio`. They only
    affect the 1-D case.
    """

    def __init__(self, *args, **kwargs):
        super(OrthogonalGlorot, self).__init__(*args, **kwargs)
        self.orth = Orthogonal()

    def generate(self, rng, shape):
        """Generate an initial value for an array of the given shape.

        Parameters
        ----------
        rng : random generator
            Generator passed through to the underlying initializers.
        shape : tuple of int
            Shape of the array to generate.

        Returns
        -------
        numpy.ndarray
            The generated array: a GlorotBengio draw for 1-D shapes, a
            concatenation of square orthogonal blocks for ``(N, M * N)``
            shapes with ``M > 1``, and ``Orthogonal.generate`` otherwise.
        """
        if len(shape) == 1:
            return super(OrthogonalGlorot, self).generate(rng, shape)
        if len(shape) == 2:
            rows, cols = shape
            # Compute the block count only for 2-D shapes, and guard against
            # rows == 0. Degenerate shapes then reach Orthogonal below
            # instead of failing here with ZeroDivisionError / IndexError.
            n_blocks = cols // rows if rows else 0
            if n_blocks > 1 and cols == n_blocks * rows:
                blocks = [self.orth.generate(rng, (rows, rows))
                          for _ in range(n_blocks)]
                return numpy.concatenate(blocks, axis=-1)
        return self.orth.generate(rng, shape)