# Cell '20 | The Tolman-Eichenbaum machine: Unifying space and relational memory through generalization in the hippocampal formation. #16
## Related literature
The prioritized-sweeping account proposed by Mattar et al. 2018 explains fairly well the ordering of replay content in resting mice, but human offline replay is not well predicted by it, as shown by the data in Antonov et al. 2022 and Eldar et al. 2020. This suggests that humans are not driven purely by expected return. Meanwhile, online sampling depends heavily on the policy, so the hippocampal representation that TEM learns is also policy-dependent. Yet beyond online experience, humans can also spontaneously reorganize simulated experience at rest (e.g. Liu et al. 2019), and this simulated experience carries rich sequence structure. Would using it in place of the Hopfield reactivation in TEM help the learning of the environment model?
This is also a drawback of learning only p(s'|s) rather than p(s'|s,a) from experience: including actions would make storage blow up exponentially, and there is in any case no information arriving from motor cortex. How should the trade-off between model correction and value updating be balanced? After all, replay should not be related only to value; what is replayed are state transitions. If replay is dopamine-driven, does backward replay from reward amount to adding a preference over state transitions?
Also, is the information compression in replay related to the information-bottleneck view of RL sampling (Zhu et al. 2020)?
## Tolman-Eichenbaum Machine

Here we walk through the PyTorch version of the TEM code, focusing on the computational logic and architecture of the model. Many thanks to Jacob Bakermans for providing the PyTorch implementation, which is far more readable than James Whittington's TensorFlow version of the TEM code.
## Algorithm

### forward

```python
def forward(self, walk, prev_iter=None, prev_M=None):
    # The previous iteration may contain walks without action.
    # These are new walks, for which some parameters need to be reset.
    steps = self.init_walks(prev_iter)
    # Forward pass: perform a TEM iteration for each set of [place, observation, action],
    # and produce inferred and generated variables for each step.
    for g, x, a in walk:
        # If there is no previous iteration at all: all walks
        # are new, so initialise a whole new iteration object
        if steps is None:
            # Use an Iteration object to set initial values before any real iterations,
            # initialising M and x_inf as zero. Set actions to None
            # to indicate there was no previous action
            steps = [self.init_iteration(g, x, [None for _ in range(len(a))], prev_M)]
        # Perform TEM iteration using transition from previous iteration
        L, M, g_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf = self.iteration(
            x, g, steps[-1].a, steps[-1].M, steps[-1].x_inf, steps[-1].g_inf
        )
        # Store this iteration as an Iteration object in the steps list
        steps.append(Iteration(g, x, a, L, M, g_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf))
    # The first step is either a step from a previous walk or initialisation rubbish, so remove it
    steps = steps[1:]
    # Return steps, which is a list of Iteration objects
    return steps
```

Here we can see that each call to the TEM must provide a `walk`, i.e. a sequence of [place, observation, action] sets, optionally together with the previous iteration and the previous memory matrix. The model then performs one TEM iteration per step of the walk.
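The bookkeeping in `forward` can be sketched in isolation. Below is a minimal pure-Python mock of that pattern; the `Iteration` tuple and `fake_iteration` step are stand-ins for the real model, not TEM code:

```python
from collections import namedtuple

# Minimal stand-in for the Iteration container used in forward()
Iteration = namedtuple('Iteration', ['g', 'x', 'a', 'out'])

def fake_iteration(x, g, prev_a):
    # Stand-in for self.iteration(): just record what the step "saw"
    return ('out', x, g, prev_a)

def forward_sketch(walk, prev_steps=None):
    steps = prev_steps
    for g, x, a in walk:
        if steps is None:
            # No previous walk: seed with an initial Iteration whose action is None
            steps = [Iteration(g, x, None, None)]
        out = fake_iteration(x, g, steps[-1].a)
        steps.append(Iteration(g, x, a, out))
    # Drop the seed (or the carried-over last step of the previous walk)
    return steps[1:]

walk = [('g0', 'x0', 'a0'), ('g1', 'x1', 'a1'), ('g2', 'x2', 'a2')]
steps = forward_sketch(walk)
```

This reproduces the key property of `forward`: each returned step used the *previous* step's action, and exactly one `Iteration` is returned per element of the walk.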
### generative

The generative model can be decomposed as follows, corresponding to the three parts of the generative process:

```python
def generative(self, M_prev, p_inf, g_inf, g_gen):
    # M_prev - (1/2, batch_size(16), sum(n_p)(400), sum(n_p)(400))
    # p_inf - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    # g_inf - (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18))
    # g_gen - (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18))
    # Generate observation from inferred grounded location, using only
    # the highest frequency. Also keep non-softmaxed logits, which are used in the loss later
    # x_p - (batch_size(16), n_x(45))
    # x_p_logits - (batch_size(16), n_x(45))
    x_p, x_p_logits = self.gen_x(p_inf[0])
    # Retrieve grounded location from memory by pattern completion on inferred abstract location
    # p_g_inf - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    p_g_inf = self.gen_p(g_inf, M_prev[0])  # was p_mem_gen
    # And generate observation from the grounded location retrieved from inferred abstract location
    # x_g - (batch_size(16), n_x(45))
    # x_g_logits - (batch_size(16), n_x(45))
    x_g, x_g_logits = self.gen_x(p_g_inf[0])
    # Retrieve grounded location from memory by pattern completion on the abstract location obtained by transitioning
    # p_g_gen - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    p_g_gen = self.gen_p(g_gen, M_prev[0])
    # Generate observation from sampled grounded location
    # x_gt - (batch_size(16), n_x(45))
    # x_gt_logits - (batch_size(16), n_x(45))
    x_gt, x_gt_logits = self.gen_x(p_g_gen[0])
    # Return all generated observations and their corresponding logits
    return (x_p, x_g, x_gt), (x_p_logits, x_g_logits, x_gt_logits), p_g_inf
```

A diagram of the generative process is shown below:
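The core of `gen_p` is pattern completion in a Hopfield-style attractor: a partial cue is settled toward a stored pattern by repeatedly applying the memory matrix M and a nonlinearity. A minimal NumPy sketch of that retrieval idea follows; the dimensions, number of iterations, `leaky_relu` choice, and single-pattern memory are illustrative assumptions, not the exact TEM hyperparameters:

```python
import numpy as np

def leaky_relu(v, alpha=0.01):
    return np.where(v > 0, v, alpha * v)

def attractor_retrieve(query, M, n_iters=5):
    # Iteratively settle the attractor: each pass completes more of the
    # stored pattern from the partial cue
    p = leaky_relu(query)
    for _ in range(n_iters):
        p = leaky_relu(M @ p)
    return p

rng = np.random.default_rng(0)
# A non-negative "grounded location" pattern (place-cell rates are non-negative)
stored = np.maximum(rng.standard_normal(400), 0.0)
# Store the single pattern by a (normalised) outer product
M = np.outer(stored, stored) / (stored @ stored)
# Corrupt the pattern with noise, then retrieve
noisy_cue = stored + 0.5 * rng.standard_normal(400)
recovered = attractor_retrieve(noisy_cue, M)
```

After retrieval, `recovered` is far better aligned with `stored` than the noisy cue was, which is exactly the behaviour `gen_p` relies on when completing a grounded location from an abstract-location cue.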
### hebbian

```python
def hebbian(self, M_prev, p_inferred, p_generated, do_hierarchical_connections=True):
    # Create new ground memory for attractor network by setting weights to outer product of learned vectors
    # p_inferred corresponds to p in the paper, and p_generated corresponds to p^.
    # The order of p + p^ and p - p^ is reversed since these are row vectors,
    # instead of column vectors as in the paper.
    # M_new - (batch_size(16), sum(n_p)(400), sum(n_p)(400)),
    # calculated from (16,400,1) matmul (16,1,400), not element-wise mul
    M_new = torch.squeeze(torch.matmul(
        torch.unsqueeze(p_inferred + p_generated, 2), torch.unsqueeze(p_inferred - p_generated, 1)
    ))
    # Multiply by connection vector, e.g. only keeping weights
    # from low to high frequencies for hierarchical retrieval.
    if do_hierarchical_connections:
        M_new = M_new * self.hyper['p_update_mask']
    # Store grounded location in attractor network memory with weights M by Hebbian learning of pattern
    M = torch.clamp(self.hyper['lambda'] * M_prev + self.hyper['eta'] * M_new, min=-1, max=1)
    return M
```

After applying the update rule, `clamp` truncates any weights that fall outside (-1, 1).
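The same update is easy to verify in NumPy for a single batch element and without the hierarchical mask; the `lam` and `eta` values here are illustrative, not the trained hyperparameters:

```python
import numpy as np

def hebbian_update(M_prev, p_inferred, p_generated, lam=0.9999, eta=0.5):
    # Outer product of (p + p_hat) and (p - p_hat): the memory is written as
    # the discrepancy between inferred and generated grounded locations
    M_new = np.outer(p_inferred + p_generated, p_inferred - p_generated)
    # Decay old memories, add the new one, and clamp weights to [-1, 1]
    return np.clip(lam * M_prev + eta * M_new, -1.0, 1.0)

rng = np.random.default_rng(1)
n_p = 8
M = np.zeros((n_p, n_p))
p, p_hat = rng.standard_normal(n_p), rng.standard_normal(n_p)
M = hebbian_update(M, p, p_hat)
```

One consequence worth noting: when the generated grounded location already matches the inferred one (p = p^), the second factor is zero and nothing new is written, so memory is only spent on prediction errors.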
### loss

```python
def loss(self, g_gen, p_gen, x_logits, x, g_inf, p_inf, p_inf_x, M_prev):
    # Calculate loss function, separately for each component
    # because you might want to reweight contributions later.
    # L_p_g is squared error loss between inferred grounded location
    # and grounded location retrieved from inferred abstract location
    L_p_g = torch.sum(torch.stack(utils.squared_error(p_inf, p_gen), dim=0), dim=0)
    # L_p_x is squared error loss between inferred grounded location
    # and grounded location retrieved from sensory experience
    L_p_x = torch.sum(torch.stack(utils.squared_error(p_inf, p_inf_x), dim=0), dim=0) \
        if self.hyper['use_p_inf'] else torch.zeros_like(L_p_g)
    # L_g is squared error loss between generated abstract location and inferred abstract location
    L_g = torch.sum(torch.stack(utils.squared_error(g_inf, g_gen), dim=0), dim=0)
    # L_x is a cross-entropy loss between sensory experience and different model predictions.
    # First get true labels from sensory experience
    labels = torch.argmax(x, 1)
    # L_x_gen: losses generated by generative model from g_prev -> g -> p -> x
    L_x_gen = utils.cross_entropy(x_logits[2], labels)
    # L_x_g: losses generated by generative model from g_inf -> p -> x
    L_x_g = utils.cross_entropy(x_logits[1], labels)
    # L_x_p: losses generated by generative model from p_inf -> x
    L_x_p = utils.cross_entropy(x_logits[0], labels)
    # L_reg are regularisation losses, L_reg_g on L2 norm of g
    L_reg_g = torch.sum(torch.stack([torch.sum(g ** 2, dim=1) for g in g_inf], dim=0), dim=0)
    # And L_reg_p regularisation on L1 norm of p
    L_reg_p = torch.sum(torch.stack([torch.sum(torch.abs(p), dim=1) for p in p_inf], dim=0), dim=0)
    # Return total loss as a list of losses, so you can reweight them later
    L = [L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p]
    return L
```

With this loss, the model is updated using BPTT + Adam.
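Because `loss` returns the components as a list, reweighting at the training step is a one-liner; the numbers below are illustrative placeholders, not the paper's actual loss values or weight schedule:

```python
import numpy as np

# Suppose each entry is the per-batch mean of one loss component, in the order
# [L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p]
components = np.array([0.8, 0.6, 2.1, 1.9, 1.7, 0.4, 0.05, 0.03])

# Hypothetical weights, e.g. down-weighting the two regularisers
weights = np.array([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.01, 0.02])

# Scalar objective actually backpropagated through time
total_loss = float(np.dot(weights, components))
```

Keeping the components separate like this also makes it easy to log each term individually and to anneal individual weights over training.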
## Discussion

### Successor Representation

TEM's learning is entirely offline; there is no possibility of interacting with the environment and dynamically updating the policy. The policy used to sample the environment therefore shapes the representations TEM forms, for grid cells and place cells alike, which may connect to the discussion of the successor representation in #11. In its simulations, TEM models non-diffusive transitions with a policy that prefers spending time near boundaries and approaching objects. On this point, James Whittington writes in his PhD thesis:
This in fact touches on
### Hierarchies in the Map

When introducing the Algorithm, we mentioned that the sub-graphs within a graph need to be as similar as possible for TEM to extract a general rule within a restricted structure set; it currently cannot extract transition rules from arbitrary graphs. The main reason is:
This is also why both HPC and ERC contain modules of different frequencies: we can represent the world at different scales, and this hierarchical organisation is extremely efficient (akin to independent binary coding), helping us extract rules from the world in what we would consider a natural way. Meanwhile, Saxe et al. 2019 showed that learning exhibits a non-linear progression: the vector corresponding to the first (largest) singular value is fitted first, and only then are the vectors for the subsequent singular values corrected. Given that hierarchy is essentially also an eigenvalue problem, this hints at a basic principle of how people understand things: difficulty should increase gradually, solving the unifying problem first and the details later, before the unified rule behind them can be grasped. As for simultaneously learning graphs that contain different subgraphs, we do not know whether the learned grid-cell firing patterns would still have hierarchical structure, nor whether TEM can predict grid-cell firing under such a paradigm. Note that this is simultaneous learning, not learning in one environment and then remapping in a new one; it may simply fail to learn, and it is genuinely unclear whether TEM has this capability.

Separately, the following snippet from the inference step shows a trick that is not in the paper: measures of memory quality are fed to the uncertainty function of the inferred abstract location.

```python
# Not in paper, but this greatly improves zero-shot inference: provide
# the uncertainty function of the inferred abstract location with measures of memory quality
with torch.no_grad():
    # For the first measure, use the grounded location inferred from memory to generate an observation
    x_hat, x_hat_logits = self.gen_x(p_x[0])
    # Then calculate the error between the generated observation and the actual observation:
    # if the memory is working well, this error should be small
    err = utils.squared_error(x, x_hat)
# The second measure is the vector norm of the inferred abstract location; good memories should have
# similar vector norms. Concatenate the two measures as input for the abstract location uncertainty function
# sigma_g_input - (n_f(5), batch_size(16), n_measure(2))
sigma_g_input = [torch.cat(
    (torch.sum(g ** 2, dim=1, keepdim=True), torch.unsqueeze(err, dim=1)), dim=1
) for g in mu_g_mem]
...
# And get standard deviation/uncertainty of inferred abstract location by
# providing uncertainty function with memory quality measures
# sigma_g_mem - (n_f(5), batch_size(16), n_g(30,30,24,18,18))
sigma_g_mem = self.f_sigma_g_mem(sigma_g_input)
```
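The two memory-quality measures themselves are simple to compute in isolation; in this NumPy sketch the observation and abstract-location vectors are random stand-ins, and the 45/30 dimensions merely echo the shapes in the snippet above:

```python
import numpy as np

rng = np.random.default_rng(2)
x = rng.standard_normal(45)                 # actual observation (stand-in)
x_hat = x + 0.1 * rng.standard_normal(45)   # observation generated from memory
g = rng.standard_normal(30)                 # inferred abstract location (stand-in)

# Measure 1: reconstruction error; small when the memory is working well
err = float(np.sum((x - x_hat) ** 2))
# Measure 2: squared vector norm of the abstract location; good memories
# should have similar norms
g_norm_sq = float(np.sum(g ** 2))
# Concatenate the two measures as input to the uncertainty function
sigma_g_input = np.array([g_norm_sq, err])
```

A learned uncertainty function can then map these two scalars (per frequency module, per batch element) to a standard deviation for the inferred abstract location.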
## Questions

## Summary

## Future Plan

### TEM

### TEM-OVC
## Report

The following is a report about the Tolman-Eichenbaum Machine (TEM); the corresponding slides can be downloaded from here.
## More about disentangled representation

We use a discrete state space in which the transition for each action is carried out by $W_{a} \in \mathbb{R}^{n_{c}\times n_{c}}$, a weight matrix that depends on the action.
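An action-dependent transition of this kind is straightforward to sketch; the dimensions, random weights, and ReLU nonlinearity below are illustrative assumptions rather than trained TEM parameters:

```python
import numpy as np

rng = np.random.default_rng(4)
n_c, n_actions = 6, 4

# One n_c x n_c transition matrix per discrete action (hypothetical values)
W = rng.standard_normal((n_actions, n_c, n_c)) * 0.1

def transition(g, a):
    # Move in abstract space: apply the weight matrix selected by action a,
    # then a nonlinearity (ReLU here, purely illustrative)
    return np.maximum(W[a] @ g, 0.0)

g = np.abs(rng.standard_normal(n_c))   # current abstract representation
g_next = transition(g, a=2)
```

The key point is that the *same* matrix $W_a$ is reused at every state, which is what forces the representation to respect the shared transition structure of the space.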
**Pattern forming dynamics.** The overall loss can be optimized with respect to the weights; however, it can also be optimized directly with respect to the representations. This is a necessity for us, as we need to represent objects which may move between tasks. The dynamics induced by the object term say that if you get the object prediction wrong, then you update the representation accordingly. The two terms in the dynamics of $\mathcal{L}_{path\ integration}$ can be easily understood; the first concerns the representation at each location. We note that while we simulated this on a discrete grid, the same principles apply to continuous cases; there, the sums over locations/actions need to be replaced with integrals.

This is a very general approach for understanding representations. The structural loss does not have to be related to the rules of path integration. It can be anything: it could be the rules of a graph, it could be rules of topology, it could have one set of rules at some locations and another set of rules at other locations.

If there are structure or rules in the world or behavior, our constraints say that representations should reflect those structures or rules; in mathematics this is known as a homeomorphism. In sum, understanding representations via constraints is very general.
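The idea of optimizing a structural loss directly with respect to the representations (rather than the weights) can be sketched on a 1-D ring with a single "step right" action. Everything here — the ring size, learning rate, and the identity-matrix form of the transition — is an illustrative assumption:

```python
import numpy as np

# A 1-D ring of n locations; each location holds a representation of dim d.
n, d = 10, 4
rng = np.random.default_rng(3)
G = rng.standard_normal((n, d))

# Path-integration rule for "step right": the representation at location i+1
# should equal W applied to the representation at location i.
W = np.eye(d)  # illustrative fixed transition; only G is optimized

def structural_loss(G):
    pred = G @ W.T                            # predicted next-location representations
    diff = G[(np.arange(n) + 1) % n] - pred   # mismatch with the actual neighbours
    return float(np.sum(diff ** 2))

# Gradient descent directly on the representations G (weights W stay fixed)
lr = 0.05
for _ in range(500):
    pred = G @ W.T
    diff = G[(np.arange(n) + 1) % n] - pred
    grad = np.zeros_like(G)
    # Gradient of sum_i ||G[i+1] - G[i] W^T||^2 with respect to G
    grad[(np.arange(n) + 1) % n] += 2 * diff  # contribution as the "target" of step i
    grad += -2 * diff @ W                     # contribution as the "source" of step i
    G = G - lr * grad
```

After optimization the representations themselves satisfy the path-integration constraint, with no weight learning involved, which is the sense in which the constraints, rather than the network, determine the representation.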
Whittington J C R, Muller T H, Mark S, et al. The Tolman-Eichenbaum machine: Unifying space and relational memory through generalization in the hippocampal formation. Cell, 2020.