-
Notifications
You must be signed in to change notification settings - Fork 0
TD3 in MADDPG
TD3 in Stable Baseline 3
TD3是一种DDPG算法的进阶实现,由藤本、斯科特、赫科·冯·霍夫和戴夫米格共同发表在 Addressing Function Approximation Error in Actor-Critic Methods(2018).。 对其论文的解读可以参考这篇知乎文档。
Stable Baseline 3中将其作为Off Policy Algorithm的一个子类实现,而主要的 Actor-Critic的神经网络,都被封装进了TD3 Policy类。尽管从算法意义上,critic 网络是基于Q的,因此不算policy的一部分,但是Stable Baseline 3应该是出于代码 重用的角度,仍将Q一并封装进入了policy。以下的所有policy,均指代码里这种对 actor-critic的整体抽象,而非单独actor所用来决定action的部分。
简而言之,从代码实现的角度,TD3作为一个Actor-Critic的算法,以共使用了
六个神经网络,即actor target
(下文用AT代指)、actor
(下文用A代指)、
两个critic target
(下文用CT1、CT2分别代指)和两个critic
(下文用C1、
C2分别代指)。此外,TD3 Policy还提供一个成为Q1的接口,这个接口内部
运行的就是C1神经网络。
这篇文章接下来将主要介绍TD3的training过程,和调用、封装TD3过程中的注意事项。 有关TD3的基类——Off Policy Algorithm和Base Algorithm的分析,可以参考右侧 Navigation bar里面的其他页面。
Training由TD3中的Train函数实现,被Learn函数调用。每一次Training会连续进行
gradient step
步,每一gradient step
会并行地处理batch size
个sample,
称为一个batch。
对于每一个batch,TD3都会用它来更新C1或C2网络。每隔delay
步,才会更新一次A网络。
同时会更新CT1、CT2、AT网络。
方便起见,Batch size以下简写为 B ,Action Space对应的shape以下简写为 Sa ,Observation Space对应的shape以下简写为 So
- 调用AT网络生成一个(B,Sa)的Tensor,称为next action;
- 将next action和observation合并成一个(B,Sa+So)的Tensor;
- 把这个合并的tensor输入CT1、CT2,这样,对于tensor中_B_个sample来说, 每一个都会获得两个Q值,取其中较小的一个,如果对应的sample会导致env的结束,把Q置0;
- 上面获得这个Q值,乘以
Gamma
之后加上reward,作为Q target - 把sample里的action和observation直接输入C1、C2,获得一个(B,2)的Tensor, 跟Q target取loss,反向传播更新C1、C2.
- 第1、2、3步是在
th.no_grad
中计算的,所以gradient不会反向传播到这里。 - 第5步中的action不是第1步的next action,而是直接来自env的action
从sample中拿到observation,喂给A网络,得到(B,Sa)的Tensor, 跟observation concatenate之后,把这个(B,Sa+So)的 Tensor喂入Q1接口(内部调用也就是C1),计算Loss,反向传递。
通过polyak update方法。
def polyak_update(params: Iterable[th.nn.Parameter], target_params: Iterable[th.nn.Parameter], tau: float) -> None:
with th.no_grad():
# zip does not raise an exception if length of parameters does not match.
for param, target_param in zip_strict(params, target_params):
target_param.data.mul_(1 - tau)
th.add(target_param.data, param.data, alpha=tau, out=target_param.data)
其中C1、C2以及A分别作为params,对应地,target param是CT1、CT2和AT。
在初始化TD3的时候,传入的应该是Policy这个类,内部会自动通过这个类初始化一个policy。
所以,也必须保证Policy这个类的constructor的前三个参数为:obs-space、action-space和 lr-scheduler。其余的参数会通过kwargs的形式传入。
其中lr-scheduler是一个Callable,是一个输入当前进行了百分之多少,输出learning rate的函数。 在TD3中,直接是一个constant function,即不论输入百分之几,输出都是同样的learning rate。
只要保证policy输出的action和env所接纳的action一致,包括dimension、范围、是否离散等。 如果不一致,需要wrapper来进行变幻和映射。
对于我们的情况可能未必需要
会在训练的不同环节调用,比如_on_step
方法会在每一个training step之后调用,如果它
返回false,就会立即终止训练。
原因是原生的action和critic里面比较好地实现了序列化、反序列化(通过pickle包), 因此可以中断训练、从中断恢复、保存训练结果以备后续使用等。
OffPolicyAlgorithm
默认会把policy作为被保存对象之一。因此policy应当支持
序列化和反序列化。如果封装原生的actor和critic,那么直接调用super._get_data
不需要额外实现Q1 Forward的NN!