Cell '20 | The Tolman-Eichenbaum machine: Unifying space and relational memory through generalization in the hippocampal formation. #16

NorbertZheng opened this issue Jan 13, 2022 · 23 comments
Whittington J C R, Muller T H, Mark S, et al. The Tolman-Eichenbaum machine: Unifying space and relational memory through generalization in the hippocampal formation.

The prioritized sweeping proposed by Mattar et al. 2018 explains the ordering of replayed content in resting rodents fairly well, but human offline replay is not well predicted by it, as shown by the data in Antonov et al. 2022 and Eldar et al. 2020. This suggests that humans are not driven purely by expected return. Meanwhile, online sampling is heavily policy-dependent, so the hippocampal representations learned by TEM are also policy-dependent. Beyond online experience, however, humans can also spontaneously reorganize experience at rest into simulated experience (e.g. Liu et al. 2019) with rich internal sequence structure. Would using such sequences, instead of TEM's reactivation of the Hopfield memory, help with learning the environment model?

NorbertZheng commented Jan 13, 2022

This is also the drawback of learning only p(s'|s) from experience rather than p(s'|s,a): including actions would make storage blow up exponentially, and there is no information from motor cortex anyway. How should we balance correcting the model against updating values? After all, replay should not be only about value; what is replayed are state transitions. If replay is dopamine-driven, does the backward replay associated with reward amount to adding a preference over state transitions?
Another point: with the cognitive map as a model of the environment, we can then use it for RL tasks, solve the control problem, and optimize the policy.

Also, is the information compression in replay related to information-bottleneck RL sampling (Zhu et al. 2020)?

NorbertZheng commented Jan 23, 2022

Tolman-Eichenbaum Machine

Here we walk through the PyTorch version of the TEM code, focusing on the computational logic and architecture of the model. Many thanks to Jacob Bakermans for providing the PyTorch implementation, which is far more readable than James Whittington's TensorFlow version of the TEM code.
image

NorbertZheng commented Jan 23, 2022

Algorithm

forward

def forward(self, walk, prev_iter = None, prev_M = None):
    # The previous iteration may contain walks without action.
    # These are new walks, for which some parameters need to be reset.
    steps = self.init_walks(prev_iter)
    # Forward pass: perform a TEM iteration for each set of [place, observation, action],
    # and produce inferred and generated variables for each step.
    for g, x, a in walk:
        # If there is no previous iteration at all: all walks
        # are new, initialise a whole new iteration object
        if steps is None:
            # Use an Iteration object to set initial values before any real iterations,
            # initialising M, x_inf as zero. Set actions to None blank
            # to indicate there was no previous action
            steps = [self.init_iteration(g, x, [None for _ in range(len(a))], prev_M)]
        # Perform TEM iteration using transition from previous iteration
        L, M, g_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf = self.iteration(
            x, g, steps[-1].a, steps[-1].M, steps[-1].x_inf, steps[-1].g_inf
        )
        # Store this iteration in iteration object in steps list
        steps.append(Iteration(g, x, a, L, M, g_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf))    
    # The first step is either a step from a previous walk or initialisiation rubbish, so remove it
    steps = steps[1:]
    # Return steps, which is a list of Iteration objects
    return steps

Here we can see that each call to tem requires a walk and a prev_iter (i.e. steps) as input. The input data structures are as follows:

  • walk: (20, batch_size(16)), where 20 is the length of the walk and batch_size indexes different training environments that share the same structure (i.e. the same struct but different sensory distributions). Note that the struct here has a fixed form (e.g. a square whose size may vary; the main goal is for tem to learn the basic state-transition pattern of squares) and cannot be an arbitrary graph. In other words, the sub-graphs within the graph need to be as similar as possible for tem to extract a common rule within the restricted struct-set; extracting transition rules from arbitrary graphs is currently out of reach, as Tim Behrens also explained at the BNU IDG/McGovern 10th-anniversary event.
  • prev_iter: (20, batch_size(16)), the steps from the previous iteration. It is cut down immediately on entry, keeping only a single Iteration, which is either freshly initialised or the last Iteration of the previous call.

The model then iterates over the walk; each resulting Iteration is appended to steps, which serves as prev_iter the next time forward is called.

NorbertZheng commented Jan 23, 2022

iteration

def iteration(self, x, locations, a_prev, M_prev, x_prev, g_prev):
    # First, do the transition step, as it will be necessary for both
    # the inference and generative part of the model
    gt_gen, gt_inf = self.gen_g(a_prev, g_prev, locations)
    # Run inference model: infer grounded location p_inf (hippocampus),
    # abstract location g_inf (entorhinal). Also keep filtered sensory observation (x_inf),
    # and retrieved grounded location p_inf_x
    x_inf, g_inf, p_inf_x, p_inf = self.inference(x, locations, M_prev, x_prev, gt_inf)
    # Run generative model: since generative model is only used
    # for training purposes, it will generate from *inferred* variables
    # instead of *generated* variables (as it would when used for generation)
    x_gen, x_logits, p_gen = self.generative(M_prev, p_inf, g_inf, gt_gen)
    # Update generative memory with generated and inferred grounded location.
    M = [self.hebbian(M_prev[0], torch.cat(p_inf,dim=1), torch.cat(p_gen,dim=1))]
    # If using memory for grounded location inference: append inference memory
    if self.hyper['use_p_inf']:
        # Inference memory is identical to generative memory
        # if using common memory, and updated separatedly if not
        M.append(M[0] if self.hyper['common_memory'] else\
            self.hebbian(
                M_prev[1], torch.cat(p_inf,dim=1), torch.cat(p_inf_x,dim=1),
                do_hierarchical_connections=False
            )
        )
    # Calculate loss of this step
    L = self.loss(gt_gen, p_gen, x_logits, x, g_inf, p_inf, p_inf_x, M_prev)
    # Return all iteration values
    return L, M, gt_gen, p_gen, x_gen, x_logits, x_inf, g_inf, p_inf

iteration is the only function called from forward. Its main steps are:

  • gen_g: this is also the only place where gen_g is called. The function first converts the action into the corresponding grid-cell projection matrix D_a, and then transitions the previous abstract location into the current abstract location. Variable dependencies:
    gt_gen:
        (a_prev,g_prev)->gt_gen
    gt_inf:
        (a_prev,g_prev)->gt_inf
    
  • inference: the inference part of tem. It infers the abstract location g_inf and the grounded location p_inf from the sensory input x, and also outputs intermediate variables such as the grounded location p_inf_x recovered from the sensory input and the filtered sensory observation x_inf. Variable dependencies:
    x_inf:
        x->x_c
        (x_prev,x_c)->x_f(x_inf)
    p_inf_x:
        x_f->x_
        (x_,M_prev[inf])->p_x(p_inf_x)
    g_inf:
        (p_x,gt_inf,x)->g(g_inf)
    p_inf:
        g->g_
        (x_,g_)->p(p_inf)
    
  • generative: the generative part of tem. It is run only for training, so it generates from inferred variables rather than from generated variables: it takes the abstract location g_inf and grounded location p_inf produced by inference, together with the gt_gen produced earlier by gen_g, and generates x_gen, which is compared against x to produce the training error. Variable dependencies:
    p_gen:
        (g_inf,M_prev[gen])->p_g_inf(p_gen)
    x_gen:
        p_inf[0]->x_p
        p_g_inf[0]->x_g
        (gt_gen,M_prev[gen])->p_g_gen[0]->x_gt
    x_logits:
        p_inf[0]->x_p_logits
        p_g_inf[0]->x_g_logits
        (gt_gen,M_prev[gen])->p_g_gen[0]->x_gt_logits
    
  • hebbian: updates M using the results of the three preceding steps of the current iteration. The update formula is:
    image
    Here p_t corresponds to p_inf, which is obtained by feeding g_inf through an MLP; its counterpart p_t_ is p_gen, obtained by using g_inf as the memory query, so this update targets M[gen]. The other update formula (used if use_p_inf) is:
    image
    Here p_t again corresponds to p_inf, while its counterpart p_xt is p_inf_x, the variable computed before g_inf is inferred (and hence before p_inf), so this update targets M[inf]. Variable dependencies:
    M[gen]:
        (p_inf,p_gen)->M[gen]
    M[inf]:
        (p_inf,p_inf_x)->M[inf]
    
  • loss: finally, the loss function, which produces the various error terms (L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p) used for training. Variable dependencies:
    L_p:
        (p_inf,p_gen)->L_p_g
        (p_inf,p_inf_x)->L_p_x
    L_g:
        (g_inf,gt_gen)->L_g
    L_x:
        (x,x_gt_logits)->L_x_gen
        (x,x_g_logits)->L_x_g
        (x,x_p_logits)->L_x_p
    L_reg:
        g_inf->L_reg_g
        p_inf->L_reg_p
    

NorbertZheng commented Jan 23, 2022

generative

The generative model can be factorised as follows, with the factors corresponding to the three parts of the generative process, gen_g, gen_p, and gen_x:
image

def generative(self, M_prev, p_inf, g_inf, g_gen):
    # M_prev - (1/2, batch_size(16), sum(n_p)(400), sum(n_p)(400))
    # p_inf - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    # g_inf - (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18))
    # g_gen - (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18))
    # Generate observation from inferred grounded location, using only
    # the highest frequency. Also keep non-softmaxed logits which are used in the loss later
    # x_p - (batch_size(16), n_x(45))
    # x_p_logits - (batch_size(16), n_x(45))
    x_p, x_p_logits = self.gen_x(p_inf[0])
    # Retrieve grounded location from memory by pattern completion on inferred abstract location
    # p_g_inf - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    p_g_inf = self.gen_p(g_inf, M_prev[0]) # was p_mem_gen
    # And generate observation from the grounded location retrieved from inferred abstract location
    # x_g - (batch_size(16), n_x(45))
    # x_g_logits - (batch_size(16), n_x(45))
    x_g, x_g_logits = self.gen_x(p_g_inf[0])
    # Retreive grounded location from memory by pattern completion on abstract location by transitioning
    # p_g_gen - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    p_g_gen = self.gen_p(g_gen, M_prev[0])
    # Generate observation from sampled grounded location
    # x_gt - (batch_size(16), n_x(45))
    # x_gt_logits - (batch_size(16), n_x(45))
    x_gt, x_gt_logits = self.gen_x(p_g_gen[0])
    # Return all generated observations and their corresponding logits
    return (x_p, x_g, x_gt), (x_p_logits, x_g_logits, x_gt_logits), p_g_inf

The schematic of the generative process is as follows:
image
The update equations of the generative process are as follows:
image
generative is called from the iteration function. Its main modules are:

  • gen_x: the approach here is quite blunt: the first frequency module of p_inf, of shape (batch_size(16), n_p[0](100)), is taken directly as input and passed straight to the f_x function (if do_sample were enabled there would be an extra correction, but the source code never modifies the sample, so for now it is just a placeholder; a reparameterisation trick should be used there). Since n_p[0](100) equals n_g_subsampled[0](10)*n_x_c(10), p_inf[0] is summed over the g_subsampled components of each sensory unit, giving an intermediate value of shape (batch_size(16), n_x_c(10)); a weight and bias then give x of shape (batch_size(16), n_x_c(10)); the trainable f_c_star function (an MLP) decompresses it into logits of shape (batch_size(16), n_x(45)); finally a softmax normalises the logits into a probability distribution. This corresponds to the right half of the schematic and to the Observation Sample equation (except that only the probability distribution is produced here, without the final sampling of a one-hot vector); see the sketch at the end of this comment. Details:
    image

  • gen_p: this function takes two inputs: g of shape (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18)) (the normal, uncompressed representation) and M_prev[gen] of shape (batch_size(16), sum(n_p)(400), sum(n_p)(400)). Its steps are:

    • g2g_: takes g, first compresses it via the f_g function down to (n_f(5), batch_size(16), n_g_subsampled(10,10,8,6,6)), then stretches it back out to g_ of shape (n_f(5), batch_size(16), n_p(100,100,80,60,60)). This g_ is the p-index produced by g, which will be iterated through the Hopfield network below to obtain the converged p.
    • attractor: takes the g_ obtained above (the p_query), first reshapes it to (batch_size(16), sum(n_p)(400)), then applies the f_p function, which clamps values to (-1,1) and applies leaky_relu. It is then updated according to the following formula:
      image
      Finally a reshape gives p of shape (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60)), which serves as mu_p.
    • f_sigma_p: takes the resulting mu_p as input and converts it into a sigma_p of the same shape via a trainable MLP.

    Finally, p is sampled from N(μ=mu_p, σ=sigma_p) (or mu_p is used directly if do_sample is False). This corresponds to the lower-left part of the schematic and to the Retrieve Memory equation.

Here, (x_p, x_g, x_gt) are obtained by passing different p's through gen_x at different stages; the dependencies are as follows:
image
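As a rough illustration of the gen_x read-out described above (sum the grid components per sensory unit, apply a weight and bias, decompress with an MLP, then softmax), here is a minimal self-contained sketch. The sizes follow the shape comments above, but f_c_star's hidden width, the scalar weight/bias, and the reshape order are illustrative assumptions rather than the repo's exact code.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes taken from the shape comments above.
batch_size, n_g_sub, n_x_c, n_x = 16, 10, 10, 45

# f_c_star stands in for the trainable "decompression" MLP; its hidden size is an assumption.
f_c_star = nn.Sequential(nn.Linear(n_x_c, 20), nn.ELU(), nn.Linear(20, n_x))
w, b = torch.ones(1, requires_grad=True), torch.zeros(1, requires_grad=True)  # scalar weight/bias

p0 = torch.rand(batch_size, n_g_sub * n_x_c)                   # p_inf[0]: highest-frequency module
# Sum the grid components belonging to each sensory unit (assumed memory layout).
x_compressed = p0.view(batch_size, n_g_sub, n_x_c).sum(dim=1)  # (batch_size, n_x_c)
logits = f_c_star(w * x_compressed + b)                        # decompress to observation logits
x_prob = F.softmax(logits, dim=-1)                             # distribution over the 45 sensory identities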

NorbertZheng commented Jan 23, 2022

gen_g

This also belongs to the generative part of tem. It corresponds to the upper-left part of the schematic, and the equations correspond to Transition Sample:
image
It is natural to split gen_g into two parts, f_mu_g_path and f_sigma_g_path:

  • f_mu_g_path: takes two inputs, a_prev and g_prev. Here a_prev of shape (batch_size(16),) is first converted into a one-hot vector a of shape (batch_size(16), n_actions). A trainable MLP then maps it to D_a of shape (n_f(5), batch_size(16), [sum(n_g_from), n_g_to]([120,30],[90,30],[60,24],[36,18],[18,18])); because of the shiny setting, all batches specified by no_direc have D_a replaced by the trainable default weights D_no_a. Multiplying D_a with the corresponding g_in then gives delta:
    image
    This is not the end, though: for stability, delta is not the final g_step. The true g_step is obtained by adding g_prev and delta and clamping values outside (-1,1) with the f_g_clamp function; see the sketch at the end of this comment.
  • f_sigma_g_path: feeds g_prev directly into a trainable MLP to obtain from_g of shape (n_f(5), batch_size(16), n_g(30, 30, 24, 18, 18)). For environments whose walk has just started, the corresponding sigma_g is obtained by exponentiating the trainable parameter logsig_g_init.

Finally, g_inf is sampled from N(μ=mu_g, σ=sigma_g) (or mu_g is used directly if do_sample is False). g_gen involves one extra step compared with g_inf: if shiny objects are present, D_a in f_mu_g_path is replaced by D_no_a, and the recomputed mu_g is used directly as g_gen.
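Below is a minimal sketch of the f_mu_g_path computation for a single frequency module, assuming a bias-free two-layer MLP with tanh after the first layer; the module size and hidden width are illustrative, and the real model additionally feeds lower-frequency modules into higher ones ([sum(n_g_from), n_g_to]) and handles the shiny/D_no_a case, which is omitted here.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes for one frequency module; the real model repeats this per module.
batch_size, n_actions, n_g = 16, 4, 30

# Bias-free MLP that maps a one-hot action to a flattened action-specific transition matrix D_a.
to_D_a = nn.Sequential(nn.Linear(n_actions, 20, bias=False), nn.Tanh(),
                       nn.Linear(20, n_g * n_g, bias=False))

a_prev = torch.randint(0, n_actions, (batch_size,))
g_prev = torch.rand(batch_size, n_g) * 2 - 1

a_onehot = F.one_hot(a_prev, n_actions).float()            # (batch_size, n_actions)
D_a = to_D_a(a_onehot).view(batch_size, n_g, n_g)          # action-specific transition matrix
delta = torch.bmm(D_a, g_prev.unsqueeze(-1)).squeeze(-1)   # delta = D_a @ g_{t-1}
mu_g = torch.clamp(g_prev + delta, min=-1, max=1)          # f_g_clamp: keep g within (-1, 1)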

NorbertZheng commented Jan 23, 2022

inference

The inference model can be factorised as follows, with the factors corresponding to the two parts of the inference process, inf_p and inf_g:
image
The first factor can be decomposed further, so that inf_g in the inference process can reuse gen_g from the generative process (i.e. to generate gt_inf):
image

def inference(self, x, locations, M_prev, x_prev, g_gen):
    # Compress sensory observation from one-hot
    # to two-hot (or alternatively, whatever an MLP makes of it)
    # x - (batch_size(16), n_x(45))
    # x_c - (batch_size(16), n_x_c(10))
    x_c = self.f_c(x)
    # Temporally filter sensory observation by mixing it with previous experience
    # x_prev - (n_f(5), batch_size(16), n_x_c(10))
    # x_f - (n_f(5), batch_size(16), n_x_c(10))
    x_f = self.x_prev2x(x_prev, x_c)
    # Prepare sensory experience for input to memory by normalisation and weighting
    # x_ - (n_f(5), batch_size(16), n_p(100,100,80,60,60))
    x_ = self.x2x_(x_f)
    # Retrieve grounded location from memory by
    # doing pattern completion on current sensory experience
    # p_x - (n_f(5), batch_size(16), n_p(100,100,80,60,60))
    p_x = self.attractor(x_, M_prev[1], retrieve_it_mask=self.hyper['p_retrieve_mask_inf'])\
        if self.hyper['use_p_inf'] else None
    # Infer abstract location by combining previous abstract location and
    # grounded location retrieved from memory by current sensory experience
    # g_gen - (n_f(5), batch_size(16), self.g_init[f](30, 30, 24, 18, 18))
    # x - (batch_size(16), n_x(45)), one-hot code, each element corresponds to a sensory
    # locations - (batch_size(16),), each element is dict, e.g. {'id': 24, 'shiny': None}
    # g - (n_f(5), batch_size(16), self.g_init[f](30, 30, 24, 18, 18)),
    # use x & g_gen(for inv_var_weight) to infer g
    g = self.inf_g(p_x, g_gen, x, locations)
    # Prepare abstract location for input to memory by downsampling and weighting
    # g_ - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    g_ = self.g2g_(g)
    # Infer grounded location from sensory experience and inferred abstract location
    # p - (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60))
    p = self.inf_p(x_, g_)
    # Return variables in order that they were created
    return x_f, g, p_x, p

The schematic of the inference process is as follows:
test_00
The update equations of the inference process are as follows:
image
inference is called from the iteration function. Its main modules are:

  • inf_g: this part corresponds to the upper half of the schematic. From the equations above, its inputs are (p_inf_x, gt_inf, x, locations). Consider first how p_inf_x is produced:

    • x of shape (batch_size(16), n_x(45)) is compressed by a lookup through the f_c function into x_c of shape (batch_size(16), n_x_c(10)); this is the Compress sensory observation step of the update equations.
    • The x_prev2x function temporally filters it into x_f of shape (n_f(5), batch_size(16), n_x_c(10)); this is the Temporally filter sensorium step, where α of shape (n_f(5),) is a trainable parameter.
    • The x2x_ function normalises x_f, then simply repeats and weights it to give x_ of shape (n_f(5), batch_size(16), n_x_c*n_g_subsampled(100,100,80,60,60)); this is the Sensory input to hippocampus step, where w_p of shape (n_f(5),) is a trainable parameter.
    • The attractor function iterates using M_prev[inf] to obtain the asymptotically converged p_inf_x of shape (n_f(5), batch_size(16), n_p(100,100,80,60,60)); this is the Retrieve memory step.

    If use_p_inf is disabled, only gt_inf is used when computing g (mu_g = gt_inf[0], which is gt_gen before the shiny operation; sigma_g = gt_inf[1], the same σ as gt_gen before the shiny operation), and N(μ=mu_g, σ=sigma_g) is sampled just as when generating gt_gen (or mu_g is used directly if do_sample is False). If use_p_inf is enabled, p_inf_x is first crudely sum-downsampled to g_downsampled of shape (n_f(5), batch_size(16), n_g_subsampled(10,10,8,6,6)), then passed through the trainable MLP f_mu_g_mem and clamped to (-1,1) to give mu_g_mem of shape (n_f(5), batch_size(16), n_g(30,30,24,18,18)). Since p_inf_x effectively corresponds to the mixture of x_{<=t} and M_{t-1} in the second equation of this section, we can feed it to the gen_x function to reconstruct x_hat, compute its MSE against x, and concatenate that with g**2 as the input to the trainable MLP f_sigma_g_mem, which outputs sigma_g_mem; the product of the hyperparameters p2g_scale_offset and p2g_sig_val is added to reduce the influence of the inferred p early in training. The inv_var_weight function then combines (mu_g_path, mu_g_mem) with (sigma_g_path, sigma_g_mem); after accounting for the effect of shiny, sampling from N(μ=mu_g, σ=sigma_g) gives g_inf. This is the Infer entorhinal step of the update equations.

  • inf_p: this part corresponds to the lower half of the schematic. Beforehand, g_inf is crudely truncated (downsampled) and repeated to give g_ of shape (n_f(5), batch_size(16), n_x*n_g_subsampled(100,100,80,60,60)); this is the Entorhinal input to hippocampus step. Then g_ and x_ are multiplied element-wise, clamped to (-1,1), and passed through a leaky-relu to give mu_p; note that no attractor is used here. Because the paper does not make clear which function sigma_p should be, it is set to 0, and p is obtained by sampling from N(μ=mu_p, σ=sigma_p). This is the Infer hippocampus step; a minimal sketch follows below.
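A minimal sketch of the inf_p step for one frequency module, assuming x_ and g_ have already been prepared as described above; shapes are illustrative.

import torch
import torch.nn.functional as F

# One frequency module: n_p = n_x_c * n_g_subsampled.
batch_size, n_p = 16, 100
x_ = torch.rand(batch_size, n_p)          # normalised, weighted sensory input to hippocampus
g_ = torch.rand(batch_size, n_p) * 2 - 1  # downsampled-and-repeated abstract location

# Infer hippocampus: conjunction of "what" (x_) and "where" (g_).
mu_p = F.leaky_relu(torch.clamp(g_ * x_, min=-1, max=1))
p = mu_p  # sigma_p is set to 0 in the code, so sampling reduces to the mean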

NorbertZheng commented Jan 23, 2022

hebbian

This is the learning algorithm of the Memory component; its update formula is:
image

def hebbian(self, M_prev, p_inferred, p_generated, do_hierarchical_connections=True):
    # Create new ground memory for attractor network by setting weights to outer product of learned vectors
    # p_inferred corresponds to p in the paper, and p_generated corresponds to p^. 
    # The order of p + p^ and p - p^ is reversed since these are row vectors,
    # instead of column vectors in the paper.
    # M_new - (batch_size(16), sum(n_p)(400), sum(n_p)(400)),
    # calculated from (16,400,1) matmul (16,1,400), not element-wise mul
    M_new = torch.squeeze(torch.matmul(
        torch.unsqueeze(p_inferred + p_generated, 2),torch.unsqueeze(p_inferred - p_generated,1)
    ))
    # Multiply by connection vector, e.g. only keeping weights
    # from low to high frequencies for hierarchical retrieval.
    if do_hierarchical_connections:
        M_new = M_new * self.hyper['p_update_mask']
    # Store grounded location in attractor network memory with weights M by Hebbian learning of pattern
    M = torch.clamp(self.hyper['lambda'] * M_prev + self.hyper['eta'] * M_new, min=-1, max=1)
    return M

After applying the update formula, values outside (-1,1) are clamped. A toy sketch of how such an outer-product memory stores and retrieves patterns is given below.
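To build intuition for what this outer-product memory does, here is a toy, self-contained sketch of storing patterns with a Hebbian matrix and retrieving one from a partial cue by iterating, as the attractor function does. It is not the model's data path: the prediction term p̂ is assumed to be zero, so the (p+p̂)(p−p̂) form reduces to a plain outer product, and the patterns are made up.

import torch

torch.manual_seed(0)
n_p, eta, lam = 400, 0.5, 0.99

# Toy binary patterns standing in for grounded locations p.
patterns = torch.sign(torch.randn(3, n_p))

# Hebbian storage: decay old memories, add the outer product of the new pattern, clamp.
M = torch.zeros(n_p, n_p)
for p in patterns:
    M = torch.clamp(lam * M + eta * torch.outer(p, p), min=-1, max=1)

# Attractor retrieval: start from a corrupted cue and iterate through the memory weights.
cue = patterns[0].clone()
cue[: n_p // 2] = 0                        # zero out half of the pattern
h = cue
for _ in range(5):
    h = torch.clamp(h @ M, min=-1, max=1)  # one retrieval step
print((torch.sign(h) == patterns[0]).float().mean())  # fraction of recovered entries, close to 1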

NorbertZheng commented Jan 23, 2022

loss

This is the learning algorithm of the Cortex component; the loss is:
image

def loss(self, g_gen, p_gen, x_logits, x, g_inf, p_inf, p_inf_x, M_prev):
    # Calculate loss function, separately for each component
    # because you might want to reweight contributions later.
    # L_p_gen is squared error loss between inferred grounded location
    # and grounded location retrieved from inferred abstract location
    L_p_g = torch.sum(torch.stack(utils.squared_error(p_inf, p_gen), dim=0), dim=0)
    # L_p_inf is squared error loss between inferred grounded location
    # and grounded location retrieved from sensory experience
    L_p_x = torch.sum(torch.stack(utils.squared_error(p_inf, p_inf_x), dim=0), dim=0)\
        if self.hyper['use_p_inf'] else torch.zeros_like(L_p_g)
    # L_g is squared error loss between generated abstract location and inferred abstract location
    L_g = torch.sum(torch.stack(utils.squared_error(g_inf, g_gen), dim=0), dim=0)         
    # L_x is a cross-entropy loss between sensory experience and different model predictions.
    # First get true labels from sensory experience
    labels = torch.argmax(x, 1)
    # L_x_gen: losses generated by generative model from g_prev -> g -> p -> x
    L_x_gen = utils.cross_entropy(x_logits[2], labels)
    # L_x_g: Losses generated by generative model from g_inf -> p -> x
    L_x_g = utils.cross_entropy(x_logits[1], labels)
    # L_x_p: Losses generated by generative model from p_inf -> x
    L_x_p = utils.cross_entropy(x_logits[0], labels)
    # L_reg are regularisation losses, L_reg_g on L2 norm of g
    L_reg_g = torch.sum(torch.stack([torch.sum(g ** 2, dim=1) for g in g_inf], dim=0), dim=0)
    # And L_reg_p regularisation on L1 norm of p
    L_reg_p = torch.sum(torch.stack([torch.sum(torch.abs(p), dim=1) for p in p_inf], dim=0), dim=0)
    # Return total loss as list of losses, so you can possibly reweight them
    L = [L_p_g, L_p_x, L_x_gen, L_x_g, L_x_p, L_g, L_reg_g, L_reg_p]
    return L

This loss is optimised with BPTT + Adam; a schematic of how the eight components might be combined is sketched below.
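Schematically, a training update might weight the eight loss components, sum them over the walk, and take an Adam step through the unrolled computation (BPTT). The stand-in module, learning rate, and per-component weights below are placeholders, not the repo's values.

import torch

tem = torch.nn.Linear(4, 4)                        # stand-in module so the sketch runs
optimizer = torch.optim.Adam(tem.parameters(), lr=1e-3)
loss_weights = [1.0] * 8                           # one weight per component (L_p_g, L_p_x, ...)

# Stand-in for the per-step loss lists produced by iteration(); each entry is (batch,)-shaped.
base = tem(torch.randn(16, 4)).pow(2).mean(dim=1)
walk_losses = [[base * 1.0 for _ in range(8)] for _ in range(20)]

# Weight each component, average over the batch, and sum over the walk before backprop.
total_loss = sum(wgt * L.mean() for step in walk_losses for wgt, L in zip(loss_weights, step))
optimizer.zero_grad()
total_loss.backward()
optimizer.step()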

NorbertZheng commented Jan 23, 2022

Discussion

Successor Representation

TEM's learning is entirely offline; there is no way for it to interact with the environment and update its policy on the fly, so the policy used to sample the environment shapes the representations it forms, whether grid cells or place cells, which naturally connects to the discussion of SR in #11. In TEM's simulations, non-diffusive transitions are modelled with a policy that likes to spend time near boundaries and to approach objects. On this, James Whittington writes in his PhD thesis:

  • Because the transition statistics change, so do the optimal representations for predicting future location.
  • In order to make next-state predictions TEM learns predictive representations, with object vector cells predicting the next transition is towards the object.

This really concerns how p(s'|s) is represented. Naturally, under non-diffusive transitions p(s'|s) is not uniform, so the state-transition matrix T learned by the agent reflects the policy and has the basic ingredients for forming an SR. The SR is precisely an account of the policy-induced bias in p(s'|s); since TEM learns to predict the next state and the corresponding sensory observation, and the only preferences here are for boundaries and objects, only object-vector cells emerge, which is one form of SR. If we added more kinds of non-diffusive transitions, we would expect to observe more kinds of SR-like representations.
image
Of course, TEM also learns the environment's own transition structure: in the gen_g part of the generative model, D_a*g_{t-1} represents the (s,a)-pair and is used to generate g_{t}, which effectively represents p(s'|s,a). Since full observability of the MDP is assumed here, there is no POMDP issue, so the environment dynamics can be written as p(s',r|s,a); and since no reward information is provided, this simplifies to p(s'|s,a). We can therefore say TEM has learned the environment's transition structure, represented in μ=D_a*g_{t-1}. σ=f_sigma_g(g_{t-1}) expresses the uncertainty about μ, but since the environment is deterministic this distribution reflects neither stochasticity nor uncertainty and should in principle be 0; see #13 for the discussion of stochasticity versus uncertainty. The point is that changes in the sensory distribution of x cannot affect the ideal graph, so the value of g_{t} should be deterministic, with no distribution induced by stochasticity or uncertainty.
In addition, I think that a model being able to represent states and to perform inference on the state graph does not restrict its representations to diffusive transitions; the two leave each other plenty of freedom, so this does not contradict TEM being able to form an SR. This view mainly comes from the constraints on states mentioned in the TEM paper:

  • Each location in the map has a different g representation (so a unique memory can be built).
  • Arriving at the same location after different actions causes the same g representation (so the same memory can be retrieved) - a form of path integration for arbitrary graph structures.

NorbertZheng commented Jan 23, 2022

Hierarchies in the Map

When introducing the Algorithm, we mentioned that the sub-graphs within a graph need to be as similar as possible for tem to extract a common rule within the restricted struct-set, and that extracting transition rules from arbitrary graphs is currently out of reach. The main reason is:

  • When representing tasks that have self repeating structure (as ours do), it becomes efficient to hierarchically organise your cognitive map.

This is also why both HPC and ERC contain modules of different frequencies: we can represent the world at different scales, and this hierarchical organisation is extremely efficient (akin to independent binary coding), helping us extract the world's rules in what we consider the normal way. Saxe et al. 2019 also showed that learning unfolds in a non-linear way, first filling in the vector associated with the first (largest) singular value and only then refining the vectors associated with the later singular values. Given that hierarchy is essentially also an eigenvalue problem, this hints at a basic principle of how people come to understand things: difficulty should increase gradually, solving the unifying problem first and the details later, so that we can grasp the common rule behind them. As for simultaneously learning graphs that contain different subgraphs, it is unclear whether the learned grid-cell firing patterns would still have a hierarchical structure, or whether TEM can predict grid-cell firing under such an experimental paradigm. Note that this is simultaneous learning, not learning one environment and then remapping in a new one; whether TEM can learn it at all is a separate question, and it is genuinely hard to say right now whether TEM has this capability.
A further conjecture: during learning, TEM should also first form the large-scale grid cells (the eigenvectors corresponding to the large eigenvalues) and then gradually fill in the rest, again reflecting the unifying-problem-first principle. This brings us to generalization: it is unclear whether this matches the generalization problem that Deep RL studies in PCG environments. Jiang et al. 2021 tackles that class of problems from the angle of sample efficiency, and it is unclear whether it helps explain how human offline replay balances generalization against reward maximization; otherwise it is hard to formalize the generalization problem humans face when performing tasks. Another point is that the problems humans solve in daily life are usually POMDPs, which gives rise to belief-state representations; for example, Gershman et al. 2019 proposes mPFC, among other regions, as a candidate locus for computing belief states. Give the belief to TEM~
In inf_g, sigma_g_input is generated from p_inf_x. For a TEM that has just entered a new environment this amounts to not knowing the true parameters, a form of epistemic uncertainty, which turns the task into an epistemic POMDP. TEM solves this problem well: after visiting all nodes of a new environment once, it can immediately infer all edges, achieving generalization in the epistemic POMDP sense.

# Not in paper, but this greatly improves zero-shot inference: provide
# the uncertainty function of the inferred abstract location with measures of memory quality
with torch.no_grad():
    # For the first measure, use the grounded location inferred from memory to generate an observation
    x_hat, x_hat_logits = self.gen_x(p_x[0])
    # Then calculate the error between the generated observation and the actual observation:
    # if the memory is working well, this error should be small
    err = utils.squared_error(x, x_hat)
# The second measure is the vector norm of the inferred abstract location; good memories should have
# similar vector norms. Concatenate the two measures as input for the abstract location uncertainty function
# sigma_g_input - (n_f(5), batch_size(16), n_measure(2))
sigma_g_input = [torch.cat(
    (torch.sum(g ** 2, dim=1, keepdim=True), torch.unsqueeze(err, dim=1)), dim=1
) for g in mu_g_mem]
...
# And get standard deviation/uncertainty of inferred abstract location by
# providing uncertainty function with memory quality measures
# sigma_g_mem - (n_f(5), batch_size(16), n_g(30,30,24,18,18))
sigma_g_mem = self.f_sigma_g_mem(sigma_g_input)

Questions

  • Regarding human replay aiding generalization: can it be mapped onto the generalization problem defined for PCG environments? If so, that is the angle from which I will read the computational RL literature from now on.
  • The problems humans solve in daily life are usually POMDPs. Is it worth adding a belief-state representation to TEM so that σ=f_sigma_g(g_{t-1}) actually gets used, or should we borrow from other RL work on POMDPs? Does the PCG-env generalization problem count as a special case of extracting the environment dynamics of a POMDP (i.e. real life is more complex than this kind of simulated environment)?
  • The metrics proposed for generalization in deep learning all have some issues (frankly, none are as useful as the generalization error itself); even explicit regularization, flat minima, etc. cannot be linked directly to generalization, whereas implicit regularization generally improves generalization to some degree yet cannot be fully expressed in the language of explicit regularization; see Zhang et al. 2021. Sun et al. 2021 is meant to demonstrate the advantages of the Complementary Learning idea, but why keep an early stop? Given that Arora et al. 2019 shows that deep linear networks trained with SGD have an implicit regularization effect, approximately constraining rank(W), why not build a deep-linear-network version of a Complementary Learning system? Does pursuing this direction lead straight into ML theory?
  • Adding replay to TEM would make it an online system, and an effective replay-sampling policy would help balance generalization and reward maximization; these seem to fit together. But what is the relationship between deep linear networks and all of this? The main work on deep linear networks sits on the ML-theory side, which feels somewhat disconnected from the lab's other directions.

Summary

  • Algorithm: Wake-Sleep Algorithm.
  • Architecture: g as MEC, p as HPC, x as LEC, combine x and g to get p.
  • Objective Function: prediction error (can be changed to other objective functions).

NorbertZheng commented Jan 24, 2022

Future Plan

  • Read the TEM code.
  • Add replay to p and observe the simulation result.
  • Modify objective function to get the optimal replay, which is similar to that in the second step.

NorbertZheng commented Mar 5, 2022

Questions

TEM

  • In the original TEM code, there are some hyper-parameters, e.g. logsig_offset, logsig_ratio, which are used to modify the logsig generated by MLP as follows:
    logsigmas = [self.p2g_logsig[i](x) for i, x in enumerate(logsig_input)]
    logsigma = tf.concat(logsigmas, axis=1) * self.par.logsig_ratio + self.par.logsig_offset
    Is that used to ensure the regularity of the latent space generated by the MLP (see Ph.D. '17 | Variational Inference and Deep Learning: A New Synthesis. #21)? In the _loss part of the original model, there are some regularization terms:
    # Calculate L_reg* losses.
    L_reg_g = tf.reduce_sum(tf.stack([tf.reduce_sum(g ** 2, axis=1)\
        for g in g_inf], axis=0), axis=0)
    L_reg_p = tf.reduce_sum(tf.stack([tf.reduce_sum(tf.abs(p), axis=1)\
        for p in p_inf], axis=0), axis=0)
    It seems we don't regularize the sigma term of distributions, so we need logsig_offset, logsig_ratio, and p2g_sig_val to ensure the regularity of the latent space?
  • In the _inf_g part of the original model, there are also some regularization items:
    # Then calculate the error between the generated observation and the actual observation:
    # if the memory is working well, this error should be small. `x_hat` as target.
    err = self._loss_mse(x, x_hat)
    # The second measure is the vector norm of the inferred abstract location;
    # good memories should have similar vector norms. Concatenate the two measures
    # as input for the abstract location uncertainty function.
    # logsig_g_in - (n_f[list], batch_size, 2)
    logsig_g_in = [tf.concat((tf.reduce_sum(g ** 2, axis=1, keepdims=True),
        tf.expand_dims(err, axis=1)), axis=1) for g in mu_g_mem]
    We use the regularization of mu_g, e.g. tf.reduce_sum(g ** 2, axis=1, keepdims=True), and err as the inputs to the MLP_logsig_g_mem. What is the meaning of the input here?
  • Why separate _hebbian and _final_hebbian? Want a batch-update of M?
  • Why always clamp values to the range (-1,1)? Why use leaky_relu to modify the activation of p; is that more bio-plausible?
  • It seems the p part of the model doesn't have inner dynamics of its own. For example, when we use x_ to retrieve memory we get one attractor; we ignore the previous state of the memory, and the attractor can only reflect the dynamics of x rather than of p. Specifically, we use x_prev to retrieve p_inf_x_prev, and update x_prev:
    # x - (n_f[list], batch_size, n_x_c)
    x = [(1 - alpha[f]) * x_prev[f] + alpha[f] * x_c\
        for f in range(self.model_params.neuron.n_f)]
    And then we use x_ to retrieve p_inf_x, so the transition from p_inf_x_prev to p_inf_x only reflects the recurrent dynamics of x, and the same holds for g-indexed retrieval. So the attractor space of p is determined by both x and g, or both LEC and MEC. Of course, this is perfectly suitable for the tasks we want to explain, but most of the time we don't know which brain areas a new task is related to, so we should set up multiple brain areas responsible for different graph structures, and use the degree to which each matches the task as the weight of the index. After all, the Hopfield model can be regarded as a transformer (see Ramsauer et al. 2020 for more details), and that enables the hippocampus to dynamically organize modules, as Lewis 2021 mentioned. Can we cast the problem of exploration as inferring a graph structure? This step comes before the so-called planning, which can only be executed after the map has been formed in the hippocampus. During exploration, we keep using the sensory input x to modify the synaptic connections within the hippocampus, and form a map in the hippocampus just like the map in the sensory cortex, e.g. LEC. The hippocampus then uses this activation as an index to call other cortices, e.g. MEC. If the dynamics of the hippocampus matches the dynamics of one cortex, then the activation of that cortex will reinforce the dynamics in the hippocampus, or we can say that cortex is attended, since its activation has a high similarity with the activation of the hippocampus; at the same time, the synaptic connections within the hippocampus will also be modified. It seems that two (maybe more) maps in different cortices are trying to reach a consensus. In this case, the Levy flight can be seen as trying to find the most consistent map (maybe the composition of these maps), stored in other cortices, as quickly as possible. Local movement corresponds to the local map (in one cortex), and a long jump corresponds to the transition from one module to another. If the surrounding modules correspond to the same map (in the cortex), this will reinforce these modules to be clustered in the activation of the hippocampus. The hippocampus then forms a large-scale map whose components correspond to clusters.
    When encountering a new graph that is completely different from the learned graphs, the hippocampus forms a new map by itself, which means this map can only match the dynamics of the sensory cortex and cannot match other cortices. The hippocampus will then form a fine-grained map, in which even context is reflected, and then transfer it to the cortex, maybe by replay in random-walk mode?
  • When the corresponding abstract map is not found, the map in the sensory cortex will dominate the activation of the hippocampus. And due to the setting of the psychological experiment, two objects that are temporally adjacent may not be represented adjacently in the sensory cortex. In the original model, p doesn't have its own dynamics; without the position code from g (which has not been matched yet), it will take a long time for p (and also g) to learn the underlying sequence. This is not the case where g is already learned and provides the position coding. Should we consider the inner dynamics of the hippocampus itself? Maybe this will help form an attractor space in which the representations of temporally adjacent objects are represented adjacently in the hippocampus (see Schapiro et al. 2016 for more details), and help generate replay from the hippocampus, not from the cortex.
  • Hopfield networks can be used to model CA1 of the hippocampus, while CANNs are often used to model CA3, which is exactly the origin of theta sequences and replay. Can we add a CANN to TEM to support replay, or replace the Hopfield network with a CANN? After all, the only difference between a CANN and a Hopfield network is that a CANN supports a continuous attractor valley, which means a CANN is smoother than a Hopfield network, helping to generate replay (or samples).

TEM-OVC

  • Not sure what some variables refer to:

    | variable | my understanding |
    | --- | --- |
    | r | (purely?) reward tuning cell in hippocampus |
    | d | (purely?) direction tuning cell in hippocampus |
  • In the inf_l part of the ovc model, l is inferred from the activation of both g and ovc, without the activation of x; is this to match the bio-anatomical evidence?
  • In this ovc model, g and ovc are separated; does this correspond to the result of Obenhaus et al. 2022? That paper states that most cell types were intermingled, but grid and object-vector cells exhibited little overlap. So g mainly accounts for generalization, and ovc mainly accounts for reward maximization?
  • It seems that once we encounter a shiny object, we enter sleep mode (after one replay it re-enters awake mode and starts from the end location of the replay, not the end location of the previous awake mode) and generate a replay whose strategy is hard-coded and is not a reversal of the path previously traveled. What is the role of replay here?

NorbertZheng commented Jan 28, 2023

Report

The following is a report on the Tolman-Eichenbaum Machine (TEM); the corresponding slides can be downloaded from here.

NorbertZheng commented Jan 28, 2023

TEM as Transformer/GNN

image
Here we can see the architecture of the key component of the Transformer, i.e. the self-attention block. Given a number of entities (with no sequential order, though each entity may carry its own position embedding), we use $W_{q}$, $W_{k}$, $W_{v}$ to calculate the $q$, $k$, $v$ corresponding to each entity.

$$ Q=HW_{q} \quad K=HW_{k} \quad V=HW_{v}, $$

where $H=[X,E]$ contains the feature embeddings $X$ and the position embeddings $E$. After calculating $Q$, $K$, $V$, we use each item of $Q$ to query $K$ and obtain a probability over the value items. This process is like aggregating information from all nodes in the graph, but it differs from a Graph Neural Network (GNN): it does not exploit the graph adjacency matrix explicitly, and the probability is calculated from correlations.

$$ Prob_{l,i}=softmax(\frac{q_{l,i}K_{l}^{T}}{\sqrt{d_{k}}}). $$

Then we directly multiply $Prob_{l,i}$ with $V$ to get the updated entity $h_{l+1,i}$.

$$ h_{l+1,i}=Prob_{l,i}V_{l}. $$
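For reference, a minimal single-head self-attention block implementing exactly these three equations (all sizes are arbitrary):

import math
import torch
import torch.nn as nn

n_entities, d_model, d_k = 8, 32, 16
H = torch.randn(n_entities, d_model)                      # entity embeddings H = [X, E]

W_q, W_k, W_v = (nn.Linear(d_model, d_k, bias=False) for _ in range(3))
Q, K, V = W_q(H), W_k(H), W_v(H)

Prob = torch.softmax(Q @ K.T / math.sqrt(d_k), dim=-1)    # attention weights per query
H_next = Prob @ V                                         # updated entities h_{l+1,i}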

image
Now let's see how we can re-formulate TEM as such a self-attention block, i.e. what $W_{*}$ and $H$ actually correspond to. Let's look at the generative process of TEM. Obviously, the input entity embeddings contain feature embeddings $x$ (e.g. sensory observations or the neural activities of LEC) and position embeddings $g$ (e.g. the neural activities of MEC). Here we use an RNN to update the abstract location $g$, which models the neural activities of MEC. We use an action-specific weight matrix $W_{a}$ (the output of an MLP with no bias in either layer, whose input is the one-hot version of $a$) to get $\Delta g$.

  • We should note that when staying still, i.e. $a=0$, the one-hot version of $a$ is exactly all zeros. Because the MLP has no bias in either layer and only a $tanh$ activation after the first layer (which outputs 0 for input 0), staying still ($a=0$) always gives an all-zero $W_{a}$, which forces $\Delta g$ to be zero. This is a strong inductive bias!

$$ \Delta g_{t-1}=W_{a}g_{t-1}. $$

Then we can update $g_{t-1}$ with $\Delta g_{t-1}$, clamp the value of updated value to get $g_{t}$.

$$ g_{t}=f_{g}(g_{t-1}+\Delta g_{t-1}). $$

Both $g_{t-1}$ and $g_{t}$ are the position embeddings of the corresponding entity. So, how could we get the corresponding query vector $\tilde{g}$? How to express the query matrix $W_{q}$? We first downsample $g$ then repeat the downsampled value to get the query vector $\tilde{g}$. Therefore, the query vector $\tilde{g}$ and query matrix $W_{q}$ can be formulated as follows:

$$
\tilde{g}_{t}=W_{repeat}f_{down}(g_{t}) \quad W_{q}=W_{repeat}f_{down}(\cdot).
$$

We can easily see that the query matrix $W_{q}$ of TEM is not like that of the self-attention block in the Transformer: the query matrix $W_{q}$ of TEM is hand-coded rather than learnable.

image
Now, let's focus on the memory retrieval part of the generative process. First, we have to understand what the Hebbian memory $M_{t-1}$ (hippocampus) is actually doing. The calculation of $p$ and the update of $M$ are as follows:

$$ p=flatten(x^{T}g) \quad M_{t}=\sum_{\tau=0}^{t}p_{\tau}^{T}p_{\tau}. $$

From the above equation, we can see that $M_{t}$ is binding every $g$ with every $x$. We should note that we ignore some computational details in the memory update process:

  • Memory decay when updating memory items. In TEM, every stored memory item will exponentially decay over time, which aligns with the original Hebbian learning rule.
  • Update memory items according to the memory access rate. TEM doesn't directly store the current memory item without weighting. Instead, TEM uses past experience to calculate a memory access rate and down-weights frequently-accessed memory items, which avoids assigning high probability to those items during memory retrieval. This is not part of the original Hebbian learning rule. The following is the implementation in TEM-baseline; there is another implementation in TEM-transformer.

$$
M_{t}=\lambda M_{t-1}+\eta (p_{t}-\hat{p}_{t})(p_{t}+\hat{p}_{t})^{T}.
$$

Now, we take one attractor step in TEM with no non-linearity as an example,

$$
\begin{aligned}
\tilde{x}_{t}^{retrieved}&=sum(unflatten(p_{t}^{retrieved}), 1)\\
\tilde{p}_{t}^{retrieved}&=q_{t}M_{t-1}=q_{t}\sum_{\tau=0}^{t-1}p_{\tau}^{T}p_{\tau}=\sum_{\tau=0}^{t-1}[q_{t}p_{\tau}^{T}]p_{\tau}.
\end{aligned}
$$

Given that

$$
\begin{aligned}
&[q_{t}p_{\tau}^{T}]=\bar{\tilde{x}}_{t}[\tilde{g}_{t} \cdot \tilde{g}_{\tau}],\\
&where \quad \bar{\tilde{x}}_{\tau}=\sum_{i}(\tilde{x}_{\tau})_{i}.
\end{aligned}
$$

Then we have

$$
p_{t}^{retrieved}=\tilde{g}_{t}\tilde{G}^{T}\Lambda_{x}P.
$$

Finally, we get the following equation:

$$
\tilde{x}_{t}^{retrieved}=(\alpha \tilde{g}_{t}\tilde{G}^{T})\tilde{X} \quad c.f. \quad softmax(\frac{\tilde{g}_{t}\tilde{G}^{T}}{\sqrt{d_{k}}})\tilde{X}.
$$

Therefore, the key matrix $W_{k}$ can be formulated as follows:

$$ W_{k}=W_{q}=W_{repeat}f_{down}(\cdot). $$

We also find that the key matrix $W_{k}$ of TEM is not like that of the self-attention block in the Transformer: the key matrix $W_{k}$ of TEM is hand-coded rather than learnable.

We can see that the generative process in TEM is doing exactly self-attention (though it can only attend to past experience, not future experience).
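The correspondence can be seen in a few lines: below, a toy linear memory read-out of the TEM form is placed next to a softmax attention read-out over the same stored keys and values. G_past, X_past, and alpha are made-up stand-ins for $\tilde{G}$, $\tilde{X}$, and the retrieval gain.

import math
import torch

T, d_g, d_x, alpha = 10, 12, 8, 0.1
G_past = torch.randn(T, d_g)        # \tilde{G}: past abstract locations (keys)
X_past = torch.randn(T, d_x)        # \tilde{X}: past filtered observations (values)
g_query = torch.randn(1, d_g)       # \tilde{g}_t: current query

x_linear = (alpha * g_query @ G_past.T) @ X_past                              # TEM-style read-out
x_softmax = torch.softmax(g_query @ G_past.T / math.sqrt(d_g), -1) @ X_past   # Transformer read-out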

image
Now, let's focus on the inference process of TEM. The Hebbian memory in the inference process does exactly the same thing as the Hebbian memory in the generative process; they may use different components (either $g$ or $x$; here it is $x$) to calculate the update weights. All we need to know here is that $M_{t}$ in the inference process of TEM also binds every $x$ with every $g$. First, we use the $\gamma$s (different modules may have different $\gamma$) to filter the original $x$:

$$
\tilde{x}_{t}=filter(x_{t})=(1-\gamma)\tilde{x}_{t-1}+\gamma x_{t}.
$$

Similarly, one step attractor can be formulated as follows:

$$
\tilde{g}_{t}^{retrieved}=(\alpha \tilde{x}_{t}\tilde{X}^{T})\tilde{G} \quad c.f. \quad softmax(\frac{\tilde{x}_{t}\tilde{X}^{T}}{\sqrt{d_{k}}})\tilde{G}.
$$

We can easily see that the inference process of TEM is also doing self-attention, i.e.

$$ W_{q}=filter(\cdot) \quad W_{k}=filter(\cdot) \quad W_{v}=MLP(\cdot) $$

Unlike in the generative process of TEM, there are learnable parameters in $W_{q}$, $W_{k}$, $W_{v}$ (although most parts are hand-coded in $filter(\cdot)$, which is also part of $W_{v}$ in the generative process).

image
Now, we can conclude that

  • TEM is a causal Transformer with RNN position encodings.

NorbertZheng commented Jan 28, 2023

Relation to other models

image
In the Successor Representation (SR) model, the rodent navigation task is tackled with the problem formulation used for non-spatial reinforcement learning tasks; after all, spatial navigation is just a special case of a reinforcement learning task. We all know that in RL the value of a state can be expressed as follows:

$$
V(s)=\mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty}\gamma^{t}R(s_{t})|s_{0}=s\right].
$$

But we do not care about the reward function $R(s)$ (after all, the standard place cells do not care about the reward), so we can decompose the value function $V(s)$ into two parts as follows:

$$
V(s)=\sum_{s'}M(s,s')R(s') \quad M(s,s')=\mathbb{E}_{\pi}\left[\sum_{t=0}^{\infty}\gamma^{t}\mathbb{I}(s_{t}=s')|s_{0}=s\right].
$$

$M(s,s')$ is a function that predicts how likely we are to reach the target location $s'$ from the current location $s$, with exponential discounting over time. The SR model uses $M(s,s')$ to describe a place cell whose maximal firing location is the target location $s'$: when the rodent is at location $s$, the firing rate of place cell $s'$ predicts the discounted accumulated likelihood of reaching location $s'$ in the future.
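Concretely, for a fixed policy with transition matrix T, the successor matrix has the closed form M = (I − γT)^{-1}, and V = MR. A toy sketch on a 4-state ring (the environment and reward vector are made up):

import torch

gamma = 0.9
T = torch.tensor([[0.0, 0.5, 0.0, 0.5],
                  [0.5, 0.0, 0.5, 0.0],
                  [0.0, 0.5, 0.0, 0.5],
                  [0.5, 0.0, 0.5, 0.0]])          # diffusive walk on a 4-cycle

M = torch.linalg.inv(torch.eye(4) - gamma * T)    # M[s, s']: discounted expected visits to s'
R = torch.tensor([0.0, 0.0, 1.0, 0.0])            # reward only at state 2
V = M @ R                                         # V(s) = sum_{s'} M(s, s') R(s')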

Pretty simple? But there are some details in the equation we have to notice.

  • The reward function $R(s)$ is a sort of representation of the object, which can be independent of the abstract location, the walking policy, etc. Of course, the representation of an object (its feature vector) may contain many aspects of that object, including value, sensory observation, etc., but we may only care about a subset of its attributes. For example, in TEM the representation of an object only concerns the sensory observation, so the object's feature vector reduces to the sensory observation $x$, whereas in TEM-ovc-replay the object's feature vector concatenates the sensory observation $x$ and the reward value $r$. In the original RL equation, the representation of the object is only about the reward value $r$; after all, RL does not care about an exact description of the object, only about the reinforcement learning task, which is a subset of all tasks. Here, in the SR model, we heuristically remove the representation of the object from the original value function, which means we remove the ability to predict the reward value from the original model.
  • $M(s,s')$ is calculated under a specified walking policy $\pi$, which means that $M(s,s')$ is not the likelihood of getting from $s$ to $s'$ under diffusion; i.e. $M(s,s')$ is not a pure representation of the graph, but a biased one. Still, $M(s,s')$ retains the ability to predict the next location given the current location and action; i.e. $M(s,s')$ is a representation that combines the space representation with something else (an action representation, or we could call it a policy representation).

image
Now we can see that claim clearly. $M(s,s')$ is a conjunctive representation of the global transition structure $g$ (the space representation factor) and the local transition structure $o$ (the action representation factor). As we can see, the action representation factor $o$ is associated with the abstract value $r$, but the action representation factor $o$ and the feature representation factor $f$ can still be separated, given that we can change the [location, reward value] of the object, so the action representation factor exists independently (orthogonally) of the object representation. From that view, we can understand why $M(s,s')$ is policy-dependent. Besides, $M(s,s')$ is also limited to the specified task, i.e. neither the location nor the reward value of the object is ever changed, which means that $M(s,s')$ learns the conjunctive representation directly instead of composing the two representation factors. For example, when the walking policy is strict, e.g. no action that moves the agent away from the object can be chosen, the agent will predict that there is no way to walk away from the goal location, which differs greatly from the prediction made using the space representation factor alone. Therefore, the SR model cannot provide a generalizable graph representation (the space representation factor).

image
Now we can agree on one point: the SR model decomposes the conjunctive representation into multiple representation factors heuristically, not automatically. So how could we decompose the conjunctive representation (the task representation) into multiple representation factors (orthogonal task factors in the task representation space) automatically? Here is a work from James Whittington. He finds that adding some biological constraints (namely nonnegativity and energy efficiency in both activity and weights) to the loss function leads to disentangled representations, i.e. it decomposes the task representation into multiple representation factors automatically.

$$
\begin{aligned}
\mathcal{L}&=\underbrace{\mathcal{L}_{non-neg}+\mathcal{L}_{activity}+\mathcal{L}_{weight}}_{Biological\ constraints}+\underbrace{\mathcal{L}_{prediction}}_{Functional\ constraints}\\
&\mathcal{L}_{non-neg}=\beta_{non-neg}\sum_{i}max(-a_{i},0)\\
&\mathcal{L}_{activity}=\beta_{activity}\sum_{l}||a_{l}||^{2}\\
&\mathcal{L}_{weight}=\beta_{weight}\sum_{l}||W_{l}||^{2}\\
&\mathcal{L}_{prediction}=\beta_{prediction}||\hat{y}-y||^{2}
\end{aligned}
$$

In the first figure, the biological constraints are not added to the model, and we can see that each hidden neuron may respond to changes in multiple task factors. But if we add the biological constraints to the model, each hidden neuron responds to the change of only one task factor, i.e. a disentangled representation. This means we automatically decompose the task representation into orthogonal task factors, with each hidden neuron representing at most one task factor. We can therefore conclude that the SR model is just a subset of TEM-disentangled.
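A minimal sketch of such a biologically constrained objective for a toy linear read-out network; the β values and layer sizes are illustrative, not the paper's settings.

import torch
import torch.nn as nn

beta_nonneg, beta_act, beta_w, beta_pred = 1.0, 1e-3, 1e-3, 1.0

W_in, W_out = nn.Linear(20, 64), nn.Linear(64, 5)
x, y = torch.randn(32, 20), torch.randn(32, 5)

a = W_in(x)                                  # hidden activity a_i
y_hat = W_out(a)

L_nonneg = beta_nonneg * torch.clamp(-a, min=0).sum(dim=1).mean()    # penalise negative activity
L_activity = beta_act * (a ** 2).sum(dim=1).mean()                   # activity energy cost
L_weight = beta_w * sum((p ** 2).sum() for p in list(W_in.parameters()) + list(W_out.parameters()))
L_prediction = beta_pred * ((y_hat - y) ** 2).sum(dim=1).mean()      # functional constraint
loss = L_nonneg + L_activity + L_weight + L_prediction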

image
Now we come to the Clone-Structured Cognitive Graph (CSCG). This model can also learn a graph representation, but just like the SR model, the learned graph representation is not generalizable. And because the walking policy is diffusive by default, the graph representation learned by CSCG is exactly the space representation (without being combined with the action representation). Despite its non-generalizable graph representation, CSCG can learn such a representation quickly. And when first entering an environment with a brand-new structure, the hippocampus serves as a graph rather than as an associative memory. Therefore, we could use CSCG to model the hippocampus, and then generate replays from CSCG to facilitate TEM's structure abstraction process. Cheers~

image
The Spatial Memory Pipeline (SMP) is another model built to explain the computational mechanisms of the hippocampal-entorhinal system. The SMP model has a similar architecture to TEM, i.e. a VAE, but it is much more complex. SMP uses a machine-learning-style memory bank to model the hippocampus, which is not as bio-plausible as an attractor network (after all, the attractor network was developed by the computational neuroscience community), and its path-integration part is also complex. Of course, such a complex model provides much more powerful information processing than TEM: it can directly process raw visual sequences, and we can observe similar neural representations in the SMP model. However, because of its complexity, the SMP model is not as well suited as TEM to explaining the key principles of the hippocampal-entorhinal system.

More about disentangled representation

We use a discrete $16\times 16$ world (so 256 locations; $n_{l}=256$) and optimize an independent representation, $z(x) \in \mathbb{R}^{n_{c}}$, at each location. We now detail each component of the following loss

$$
\mathcal{L}=\underbrace{\mathcal{L}_{nonneg}+\mathcal{L}_{activity}+\mathcal{L}_{weight}}_{Biological\ constraints}+\underbrace{\mathcal{L}_{location}+\mathcal{L}_{actions}+\mathcal{L}_{objects}}_{Functional\ constraints}+\underbrace{\mathcal{L}_{path\ integration}}_{Structure\ constraints}.
$$

  • Biological losses. These are exactly the same as above, but we must also average over all locations, $x$

$$
\begin{aligned}
&\mathcal{L}_{nonneg}=\frac{\beta_{nonneg}}{n_{l}}\sum_{x}\sum_{i}max(-z_{i}(x),0)\\
&\mathcal{L}_{weight}=\beta_{weight}\sum_{t}||W_{t}||^{2}\\
&\mathcal{L}_{activity}=\frac{\beta_{activity}}{n_{l}}\sum_{x}||z(x)||^{2},
\end{aligned}
$$

where $z_{i}(x)$ is a neuron in representation $z(x)$, $t$ indexes the task (i.e. object, action, location, prediction), and the $\beta$ determines the regularization strength.

  • Location loss. The representation predicts location (a one-hot encoding describing each of the 256 locations) via a linear transformation, which is then fed into a softmax cross-entropy loss. In particular, the logits for each location, $x$, are $W_{l}z(x)$, where $W_{l} \in \mathbb{R}^{n_{l}\times n_{c}}$. If we denote each row of $W_{l}$ as $l_{x}$, noting that this row "corresponds" to location $x$ in the one-hot encoding, then the loss is as follows:

$$
\mathcal{L}_{location}=-\frac{\beta_{location}}{n_{l}}\sum_{x}ln\frac{e^{l_{x}\cdot z(x)}}{\sum_{x'}e^{l_{x'}\cdot z(x)}}.
$$

  • Object loss. An object is either present or not present at each location, so we use a sigmoid cross-entropy loss. In particular, the logits for each location $x$ are $W_{o}z(x)$, where $W_{o} \in \mathbb{R}^{1\times n_{c}}$. $\mathbb{I}(object\ at\ x)$ returns $1$ if an object is present at location $x$ and $0$ otherwise, so the loss is as follows:

$$
\mathcal{L}_{object}=-\frac{\beta_{object}}{n_{l}}\sum_{x}\left[\mathbb{I}(object\ at\ x)ln\sigma(W_{o}z(x))+\mathbb{I}(object\ not\ at\ x)ln(1- \sigma(W_{o}z(x)))\right].
$$

  • Action loss. More than one action can be correct at a location, so we use a sigmoid cross-entropy loss for each of the $4$ actions. In particular, the logits at a location, $x$, are $W_{t}\cdot z(x)$, where $W_{t} \in \mathbb{R}^{n_{a}\times n_{c}}$, where $n_{a}$ is the number of actions ($4$ in our case - North, South, East, West). We denote each row of $W_{t}$ as $t_{j}$, noting this row "corresponds" to action $j$ in the action encoding. $\mathbb{I}(a=a(x))$ returns a $1$ if action, $a$, is a correct action at location $x$, and $0$ otherwise, then the loss is as follows:

$$
\mathcal{L}_{action}=-\frac{\beta_{action}}{n_{l}}\sum_{x}\sum_{a}\left[\mathbb{I}(a=a(x))ln\sigma(t_{a}\cdot z(x))+(\mathbb{I}(a\neq a(x)))ln(1-\sigma(t_{a}\cdot z(x)))\right].
$$

  • Structural loss. We use a squared error loss for the structural constraint, which asks for neighboring representations to be related to each other by an action matrix $W_{a}$ for each action $a$. This is just like a path integration loss. This loss is done for every location, $x$, and each of the $4$ actions, $a$.

$$
\mathcal{L}_{path\ integration}=\frac{\beta_{path\ integration}}{n_{l}}\sum_{x}\sum_{a}||z(x)-f(W_{a}z(x-d_{a}))||^{2},
$$

where $W_{a} \in \mathbb{R}^{n_{c}\times n_{c}}$ is a weight matrix that depends on the action $a$ (i.e. there are $4$ trainable weight matrices, one per action), and $d_{a}$ is the displacement in the underlying space (the space of $x$) that the action $a$ corresponds to.

NorbertZheng commented Jan 28, 2023

Pattern forming dynamics. The overall loss can be optimized with respect to the weights. However, it can also be optimized directly with respect to $z$. This is particularly interesting for us, as it

  • allows our representation to be dynamic and change rapidly for a single task, and not just slowly via learning over many tasks.

This is a necessity for us as we need to represent objects which may move between tasks. Planning? To optimize both $z$ (task particularities) and weights (task generalities), we do so in two stages. First, we optimize with respect to $z$ to "infer" a representation for the current task. Second, we optimize with respect to the weights to learn parameters that are general across tasks.

When optimizing with respect to $z$ we only optimize two terms in the loss: $\mathcal{L}_{object}$ and $\mathcal{L}_{path\ integration}$. We optimize the first term so the system has the ability to know where the objects are. We optimize the second term so that information can be propagated around (effectively via path integration).

The dynamics of the $\mathcal{L}_{object}$ are:

$$
\frac{d\mathcal{L}_{object}}{dz(x)}=-W_{objects}^{T}(\mathbb{I}(object\ at\ x)-\sigma(W_{o}z(x))).
$$

This says if you get the object prediction wrong, then update $z$ to better predict the object. We restrict this update to only take place where the object is, so it is just an object signal. This update is equivalent to a rodent observing that it is at an object.

The dynamics of the $\mathcal{L}_{path\ integration}$ are:

$$
\begin{aligned}
\frac{d\mathcal{L}_{path\ integration}}{dz(x)}=&\sum_{a}[-(z(x)-f(W_{a}z(x-d_{a})))\\
&+W_{a}^{T}(z(x+d_{a})-f(W_{a}z(x)))\odot f'(W_{a}z(x))].
\end{aligned}
$$

The two terms in the above equation are easy to understand. The first says that the representation at each location, $z(x)$, should be updated according to what its neighbors think it should be (this is the same update rule as path integration!). The second term says the representation at each location, $z(x)$, should be updated if it did not predict its neighbors correctly. This equation tells representations to update based on their neighbors. This is just like a cellular automaton, but instead of a discrete value being updated on the basis of its neighbors, it is a whole population vector whose elements are continuous. Indeed, just like a cellular automaton, it is also possible to initialize a single "cell" (location) and have that representation propagate throughout the space. In this case, it is just like path integration, but spreading through all of space at once. We note, however, that in our simulations we initialize representations at all locations (for each walk).
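A small sketch of this "optimize z directly" idea, using autograd on only the object and path-integration terms with frozen weights. The grid size, the action matrices, the toroidal neighbourhood, and all constants are made-up assumptions for illustration.

import torch
import torch.nn.functional as F

n_side, n_c, n_actions = 8, 16, 4
z = torch.zeros(n_side * n_side, n_c, requires_grad=True)       # one representation per location
W_a = [torch.randn(n_c, n_c) * 0.1 for _ in range(n_actions)]   # frozen action matrices
W_o = torch.randn(1, n_c) * 0.1                                 # frozen object read-out
object_at = torch.zeros(n_side * n_side); object_at[27] = 1.0   # a single object location

def shifted(t, a):
    # Neighbour lookup for action a on the grid (toroidal shift, for brevity).
    grid = t.view(n_side, n_side, -1)
    shifts = [(0, 1), (0, -1), (1, 0), (-1, 0)]
    return torch.roll(grid, shifts=shifts[a], dims=(0, 1)).reshape(n_side * n_side, -1)

opt = torch.optim.Adam([z], lr=0.1)
for _ in range(50):
    L_obj = F.binary_cross_entropy_with_logits((z @ W_o.T).squeeze(-1), object_at)
    L_path = sum(((z - torch.relu(shifted(z, a) @ W_a[a].T)) ** 2).mean() for a in range(n_actions))
    loss = L_obj + L_path
    opt.zero_grad(); loss.backward(); opt.step()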

We note that while we simulated this on a discrete grid, the same principles apply to continuous cases. In this case the sums over location/actions need to be replaced with integrals.

This is a very general approach for understanding representations. The structural loss does not have to be related to the rules of path integration. It can be anything. It could be the rules of a graph. It could be rules of topology. It could have one set of rules at some locations and another set of rules at other locations.

  • The rules don't have to be neighboring representations telling each other what to be; they could also be long-range rules.

If there is structure or there are rules in the world or in behavior, our constraints say that the representations should reflect that structure or those rules. In mathematics this is known as a homeomorphism. In sum, understanding representations via constraints is very general.
