-
Notifications
You must be signed in to change notification settings - Fork 0
/
4iii_dispatch.py
113 lines (91 loc) · 3.9 KB
/
4iii_dispatch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon May 30 01:48:39 2022
Chapter 4 (active vision memory, ch3 in latex) experiment dispatch script.
!!! DEPRECATED !!!
Experiment 2: evaluating different recurrent memory variants.
WW_LSTM vs WW_RNN vs LSTM vs RNN
Script trains all models on 3 timesteps. 4iii_eval.py evaluates them
appropriately.
Notes:
- More justice: varying where_dim
- Old WW stuff was fovonly, so it performed better.
- RAM_ch4 has been revised since, so this script might not run now.
Result: training memory with T=3 and evaluating at T=1,2,3 was a mistake, and
so the results from this script were not used in the dissertation in their
entirety.
@author: piotr
"""
import os
from functools import partial
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.utils.data.sampler import RandomSampler
from model import RAM_ch4
from modules import crude_retina, classification_network_short, ResNetEncoder
from modules import lstm_, WW_LSTM, vanilla_rnn, WW_rnn
from modules import WW_module, WhereMix
from ch3_trainer import Trainer
from CUB_loader import CUBDataset, collate_pad, seed_worker
from utils import get_ymlconfig, set_seed
def main(config):
    """Build and train one memory variant of the (deprecated) ch4iii experiment.

    Reads from ``config``: ``seed``, ``gpu``, ``batch_size``, ``num_workers``,
    ``retina.foveal_size``, ``retina.n_patches``, ``retina.scaling`` and
    ``vars.variant``. Overwrites ``config.name`` with the run name
    ``ch4iii-s<seed>VAR<variant>``.

    ``config.vars.variant`` selects the recurrent memory:
    0: WW_LSTM, 1: WW_rnn, 2: lstm_, 3: vanilla_rnn.

    Returns nothing; the function ends by calling ``Trainer.train()``.
    """
    set_seed(config.seed, gpu=config.gpu)

    transform = Compose([
        ToTensor(),
        Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    variant = config.vars.variant

    # Build name based on architecture variant for current experiment
    config.name = "ch4iii-s{}VAR{}".format(config.seed, int(variant))

    dataset = CUBDataset(transform=transform, shuffle=True)
    generator = torch.Generator()
    generator.manual_seed(config.seed)  # separate seed for more control
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.batch_size,
        sampler=RandomSampler(dataset, generator=generator),
        collate_fn=collate_pad,
        num_workers=config.num_workers,
        pin_memory=config.gpu,
        worker_init_fn=seed_worker,
    )

    # NOTE: the modules below are instantiated in the same order as before
    # (retina, FE, WWfov, WWper, memory, classifier); with a seeded RNG the
    # construction order can affect the initial weights, so keep it stable.

    # Sensor and FE
    retina = crude_retina(config.retina.foveal_size, config.retina.n_patches,
                          config.retina.scaling, config.gpu, clamp=False)
    FE = ResNetEncoder(retina.out_shape)

    # WW stuff
    where_dim = 10
    if variant <= 1:  # WW variants (0: WW_LSTM, 1: WW_rnn)
        WWfov = WW_module(FE.out_shape, where_dim)
        WWper = WW_module(FE.out_shape, where_dim)
        # Same shape as one WW output, but with the last dimension doubled.
        mem_in = WWfov.out_shape[:-1] + (WWfov.out_shape[-1] * 2,)
        mem_shape = mem_in
        classifier_in = mem_shape.numel()
    else:  # non-WW variants (2: lstm_, 3: vanilla_rnn)
        WWfov = nn.AdaptiveAvgPool2d((1, 1))
        WWper = nn.AdaptiveAvgPool2d((1, 1))
        mem_in = FE.out_shape[0] * 2
        mem_shape = FE.out_shape[0]  # small, for closer param # to WW variants
        classifier_in = mem_shape

    # Memory: only the selected constructor is called; the partials merely
    # bind arguments, so the unused ones instantiate nothing.
    mem_variants = [
        partial(WW_LSTM, mem_in, FE.out_shape[0], where_dim * 2,
                gate_op=WhereMix, in_op=WhereMix),
        partial(WW_rnn, mem_in, mem_shape),
        partial(lstm_, mem_in, mem_shape),
        partial(vanilla_rnn, mem_in, mem_shape),
    ]
    memory = mem_variants[variant]()

    # Classifier
    classifier = classification_network_short(classifier_in, 200)

    # Use the configured device flag rather than a hard-coded True so the
    # model agrees with set_seed, the retina and the loader's pin_memory,
    # which all read config.gpu.
    model = RAM_ch4(config.name, retina, FE, WWfov, WWper, memory, classifier,
                    gpu=config.gpu)

    trainer = Trainer(config, loader, model)
    trainer.train()
if __name__ == '__main__':
    # Full sweep: seeds in the outer position, variants in the inner one, so
    # all four variants are trained for a seed before moving to the next seed.
    # Variant index -> 0:WW_LSTM, 1:WW_RNN, 2:LSTM 3:RNN
    sweep = [(s, v) for s in (1, 9, 919) for v in range(4)]

    for seed, variant in sweep:
        # The YAML is loaded again for every run rather than reused.
        config = get_ymlconfig('./4iii_dispatch.yml')
        config.vars.variant = variant
        config.seed = seed
        # Resuming was left disabled in the original:
        # config.training.resume = True
        main(config)