-
Notifications
You must be signed in to change notification settings - Fork 4
/
vis3.py
30 lines (26 loc) · 1023 Bytes
/
vis3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import gzip
import sys

try:
    import cPickle as pickle  # Python 2: C-accelerated implementation
except ImportError:  # Python 3: the C accelerator is built into `pickle`
    import pickle

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.cbook import get_sample_data
from matplotlib.offsetbox import AnnotationBbox, OffsetImage

from utils import tile_raster_images
# Plot every MNIST test digit as a small thumbnail placed at its 2-D code
# coordinates, and save the figure as clusters.eps.

# TODO(review): `test_code` must hold one 2-D code per test image, i.e. an
# array of shape (len(test_set[0]), 2). Nothing in this script fills it in,
# so the shape check below fails until it is supplied. Confirm where these
# codes are supposed to come from.
test_code = []
dataset_name = 'MNIST/2'  # NOTE(review): not used below
dataset = '../data/mnist.pkl.gz'


def _load_mnist(path):
    """Unpickle the gzip-compressed dataset at *path* and return its contents.

    The caller expects a ``(train_set, valid_set, test_set)`` triple.

    Under Python 3 the file is read with ``encoding='latin1'``. NOTE(review):
    this assumes the pickle was written by Python 2, as the original use of
    ``cPickle`` suggests. That encoding is what Python 3 needs to read NumPy
    arrays from Python 2 pickles.

    Security: ``pickle.load`` can execute arbitrary code, so only point *path*
    at a trusted file.
    """
    with gzip.open(path, 'rb') as datafile:
        if sys.version_info[0] >= 3:
            return pickle.load(datafile, encoding='latin1')
        return pickle.load(datafile)


def _digit_thumbnail(pixels):
    """Return the image array that ``tile_raster_images`` builds for one digit.

    *pixels* is one row of the test images, presumably a flattened 28x28
    digit (matching ``img_shape`` below). The values are inverted
    (``1 - pixels``) and used for the first three channels; the fourth
    channel is passed as ``None``.
    """
    # Add a leading axis so the single example is passed as a one-row matrix.
    channel = 1 - pixels[np.newaxis]
    return tile_raster_images(X=(channel, channel, channel, None),
                              img_shape=(28, 28), tile_shape=(1, 1))


train_set, valid_set, test_set = _load_mnist(dataset)
datax = test_set[0]  # image rows of the test split

# Accept either a list of pairs or an ndarray; the 2-D indexing below needs
# an ndarray.
test_code = np.asarray(test_code, dtype=float)
if (test_code.ndim != 2 or test_code.shape[0] != len(datax)
        or test_code.shape[1] < 2):
    raise ValueError(
        'test_code must have %d rows (one per test image) and at least 2 '
        'columns; got shape %r' % (len(datax), test_code.shape))

zoom = 0.1
artists = []
fig, ax = plt.subplots()
for i, pixels in enumerate(datax):
    # Hand the thumbnail array straight to OffsetImage. The old version wrote
    # test.png and read it back from a hard-coded, user-specific path.
    im = OffsetImage(_digit_thumbnail(pixels), zoom=zoom)
    ab = AnnotationBbox(im, (test_code[i, 0], test_code[i, 1]),
                        xycoords='data', frameon=False)
    artists.append(ax.add_artist(ab))

# Artists added with add_artist are not included in autoscaling, so register
# the code coordinates in the data limits explicitly first.
ax.update_datalim(test_code[:, :2])
ax.autoscale()
fig.savefig('clusters.eps', format='eps', dpi=1000)