Skip to content

TD3 in MADDPG

David Yunhao Liu edited this page Mar 7, 2021 · 2 revisions

TD3是一种DDPG算法的进阶实现,由藤本、斯科特、赫科·冯·霍夫和戴夫米格共同发表在 Addressing Function Approximation Error in Actor-Critic Methods(2018).。 对其论文的解读可以参考这篇知乎文档。

Stable Baseline 3中将其作为Off Policy Algorithm的一个子类实现,而主要的 Actor-Critic的神经网络,都被封装进了TD3 Policy类。尽管从算法意义上,critic 网络是基于Q的,因此不算policy的一部分,但是Stable Baseline 3应该是出于代码 重用的角度,仍将Q一并封装进入了policy。以下的所有policy,均指代码里这种 actor-critic的整体抽象,而非单独actor所用来决定action的部分。

简而言之,从代码实现的角度,TD3作为一个Actor-Critic的算法,以共使用了 六个神经网络,即actor target(下文用AT代指)、actor(下文用A代指)、 两个critic target(下文用CT1、CT2分别代指)和两个critic(下文用C1、 C2分别代指)。此外,TD3 Policy还提供一个成为Q1的接口,这个接口内部 运行的就是C1神经网络。

这篇文章接下来将主要介绍TD3的training过程,和调用、封装TD3过程中的注意事项。 有关TD3的基类——Off Policy Algorithm和Base Algorithm的分析,可以参考右侧 Navigation bar里面的其他页面。


1. Training in TD3

Training由TD3中的Train函数实现,被Learn函数调用。每一次Training会连续进行 gradient step步,每一gradient step会并行地处理batch size个sample, 称为一个batch。

对于每一个batch,TD3都会用它来更新C1或C2网络。每隔delay步,才会更新一次A网络。 同时会更新CT1、CT2、AT网络。

方便起见,Batch size以下简写为 B ,Action Space对应的shape以下简写为 Sa ,Observation Space对应的shape以下简写为 So

1.1 更新C1、C2网络

  1. 调用AT网络生成一个(BSa)的Tensor,称为next action;
  2. 将next action和observation合并成一个(BSa+So)的Tensor;
  3. 把这个合并的tensor输入CT1、CT2,这样,对于tensor中_B_个sample来说, 每一个都会获得两个Q值,取其中较小的一个,如果对应的sample会导致env的结束,把Q置0;
  4. 上面获得这个Q值,乘以Gamma之后加上reward,作为Q target
  5. 把sample里的action和observation直接输入C1、C2,获得一个(B,2)的Tensor, 跟Q target取loss,反向传播更新C1、C2.

注意

  • 第1、2、3步是在th.no_grad中计算的,所以gradient不会反向传播到这里。
  • 第5步中的action不是第1步的next action,而是直接来自env的action

1.2 更新A网络

从sample中拿到observation,喂给A网络,得到(BSa)的Tensor, 跟observation concatenate之后,把这个(BSa+So)的 Tensor喂入Q1接口(内部调用也就是C1),计算Loss,反向传递。

1.3 同步C1、C2到CT1、CT2,和A到AT

通过polyak update方法。

def polyak_update(params: Iterable[th.nn.Parameter], target_params: Iterable[th.nn.Parameter], tau: float) -> None:
    with th.no_grad():
        # zip does not raise an exception if length of parameters does not match.
        for param, target_param in zip_strict(params, target_params):
            target_param.data.mul_(1 - tau)
            th.add(target_param.data, param.data, alpha=tau, out=target_param.data)

其中C1、C2以及A分别作为params,对应地,target param是CT1、CT2和AT。

2. 使用中的注意点

2.1 创建时传入Policy的类而非Policy的实例

在初始化TD3的时候,传入的应该是Policy这个类,内部会自动通过这个类初始化一个policy。

所以,也必须保证Policy这个类的constructor的前三个参数为:obs-space、action-space和 lr-scheduler。其余的参数会通过kwargs的形式传入。

其中lr-scheduler是一个Callable,是一个输入当前进行了百分之多少,输出learning rate的函数。 在TD3中,直接是一个constant function,即不论输入百分之几,输出都是同样的learning rate。

2.2 保证predict的第一个输出的shape与env的step的输入一致

只要保证policy输出的action和env所接纳的action一致,包括dimension、范围、是否离散等。 如果不一致,需要wrapper来进行变幻和映射。

2.3 Callback

对于我们的情况可能未必需要

会在训练的不同环节调用,比如_on_step方法会在每一个training step之后调用,如果它 返回false,就会立即终止训练。

2.4 应当封装使用原生的ActorContinuesCritic类来实现自己的TD3Policy

原因是原生的action和critic里面比较好地实现了序列化、反序列化(通过pickle包), 因此可以中断训练、从中断恢复、保存训练结果以备后续使用等。

2.4.1 储存和重用网络参数

OffPolicyAlgorithm默认会把policy作为被保存对象之一。因此policy应当支持 序列化和反序列化。如果封装原生的actor和critic,那么直接调用super._get_data

2.4.2 Q1 Forward就是C1

不需要额外实现Q1 Forward的NN!