强化学习(9):TRPO、PPO以及DPPO算法

news/2024/5/18 21:47:34 标签: 强化学习, TRPO, PPO, DPPO

本文主要讲解有关 TRPO算法、PPO 算法、PPO2算法以及 DPPO 算法的相关内容。

PPO__4">一、PPO 算法

PPO(Proximal Policy Optimization) 是一种解决 PG 算法中学习率不好确定的问题的算法,因为如果学习率过大,则学出来的策略不易收敛, 反之,如果学习率太小,则会花费较长的时间。PPO 算法利用新策略和旧策略的比例,从而限制了新策略的更新幅度,让 PG 算法对于稍微大一点的学习率不那么敏感。

为了判定模型的更新什么时候停止,所以 PPO 在原目标函数的基础上添加了 KL 散度部分,用来表示两个分布之间的差别,差别越大值越大,惩罚也就越大。所以可以使两个分布尽可能的相似。PPO 算法的损失函数如下:
J P P O θ ′ ( θ ) = J θ ′ ( θ ) − β K L ( θ , θ ′ ) J_{PPO}^{\theta'}(\theta)=J^{\theta'}(\theta)-\beta KL(\theta,\theta') JPPOθ(θ)=Jθ(θ)βKL(θ,θ)

J θ ′ ( θ ) = E ( s t , a t ) ∼ π θ ′ [ p θ ( a t ∣ s t ) p θ ′ ( a t ∣ s t ) A θ ′ ( s t , a t ) ] J^{\theta'}(\theta)=E_{(s_t,a_t)\sim\pi_{\theta'}}[\frac{p_\theta(a_t|s_t)}{p_{\theta'}(a_t|s_t)}A^{\theta'}(s_t,a_t)] Jθ(θ)=E(st,at)πθ[pθ(atst)pθ(atst)Aθ(st,at)]

PPO 的前身是 TRPO(Trust Region Policy Optimization),TRPOPPO 之间的区别在于 TRPO 使用了 KL 散度作为约束条件,虽然损失函数是等价的,但是这种表示形式更难计算,所以较少使用。TRPO 损失函数如下:
J T R P O θ ′ ( θ ) = E ( s t , a t ) ∼ π θ ′ [ p θ ( a t ∣ s t ) p θ ′ ( a t ∣ s t ) A θ ′ ( s t , a t ) ] J_{TRPO}^{\theta'}(\theta)=E_{(s_t,a_t)\sim\pi_{\theta'}}[\frac{p_\theta(a_t|s_t)}{p_{\theta'}(a_t|s_t)}A^{\theta'}(s_t,a_t)] JTRPOθ(θ)=E(st,at)πθ[pθ(atst)pθ(atst)Aθ(st,at)]

K L ( θ , θ ′ ) < δ KL(\theta,\theta')<\delta KL(θ,θ)<δ

PPO 在训练时可以采用适应性的 KL 惩罚因子:当 KL 过大时,增大 β \beta β 的值来增加惩罚力度;当 kL 过小时,减小 β \beta β 值来降低惩罚力度。即:
i f K L ( θ , θ ′ ) > K L m a x , i n c r e a s e β if\quad KL(\theta,\theta')>KL_{max},\quad increase\quad\beta ifKL(θ,θ)>KLmax,increaseβ

i f K L ( θ , θ ′ ) < K L m i n , d e c r e a s e β if\quad KL(\theta,\theta')<KL_{min},\quad decrease\quad\beta ifKL(θ,θ)<KLmin,decreaseβ

PPO2__35">二、PPO2 算法

PPO2PPO 的基础上去除了 KL 散度损失函数,但是引入了 Clip 损失函数,当目标函数值低于 1 − ϵ 1-\epsilon 1ϵ 或大于 1 + ϵ 1+\epsilon 1+ϵ 时进行截断。其损失函数为:
J P P O 2 ( θ ) = ∑ ( s t , a t ) min ⁡ [ p θ ( a t ∣ s t ) p θ ′ ( a t ∣ s t ) A θ ′ ( s t , a t ) , c l i p ( p θ ( a t ∣ s t ) p θ ′ ( a t ∣ s t ) , 1 − ϵ , 1 + ϵ ) ⋅ A θ ′ ( s t , a t ) ] J_{PPO2}(\theta)=\sum_{(s_t,a_t)}\min[\frac{p_\theta(a_t|s_t)}{p_{\theta'}(a_t|s_t)}A^{\theta'}(s_t,a_t)\quad,\quad clip(\frac{p_\theta(a_t|s_t)}{p_{\theta'}(a_t|s_t)},1-\epsilon,1+\epsilon)\cdot A^{\theta'}(s_t,a_t)] JPPO2(θ)=(st,at)min[pθ(atst)pθ(atst)Aθ(st,at),clip(pθ(atst)pθ(atst),1ϵ,1+ϵ)Aθ(st,at)]
ppo2
上图中绿色虚线是原始的损失函数,蓝色虚线是 clip 函数,红色实线是实际上的损失函数,当优势函数 A 的值为正数或负数时,实际的损失函数有不同的情况。

PPO__44">三、DPPO 算法

DPPO(Distributed Proximal Policy Optimization)简单来说就是多线程并行版的 PPO。相应的代码是按照莫烦的教程来写的,使用了和 A3C 算法类似的网络结构。但是与 A3C 算法不同的是,A3C 算法是副网络与主网络有着相同的网络结构,并用副网络计算出来的梯度更新主网络的参数,更新完后再将主网络的参数同步给副网络。

而这里的 DPPO 代码是副网络不必拥有和主网络相同网络结构,每个副网络只需要有自己独立的环境就好了。副网络在不同的环境中收集数据,然后交给主网络来更新参数。原本应该是在主网络更新后同步网络参数给副网络,但是这样的时空开销都是比较大的,而该过程的实质其实就是用更新后的主网络来收集数据,所以一开始才说副网络不必拥有和主网络相同的网络结构。这一点和 A3C 算法对比一下,想想为什么会有这种区别。

自己在写代码的时候碰到了很多坑,下面来详细说一下:

坑1. 副网络需要收集的数据有三个,状态值 s、动作值 a 以及 Q-target,但是由于多线程的存在,如果把三者分三步保存在列表(队列也可)中,则三者的维度可能会有所差异,在喂入网络的时候就会出错。为了解决这个问题,必须让原来三步的工作具有原子性,即要么三步都做,要么三步都不做。用代码的角度来看,以下代码是不可行的:

global_s.append(batch_s)
global_a.append(batch_a)
global_q.append(batch_q_target)

而以下代码是可行的:

data.append(np.hstack((batch_s, batch_a, batch_q_target)))

坑2. 在主网络更新参数时,所有的副网络应该停止收集数据,因为这时副网络只会收集到过时的数据。所以要保证在主网络对应的线程运行时,其他所有的线程都停止。而这用锁 lock 是没办法解决的,因为锁是解决的多个线程争夺同一资源的情况,而上述情况显然不属于。代码中用的是事件 event 的方式,后面会详细介绍。

坑3. 莫烦的代码中在更新网络参数时使用了一个循环更新10次,当时我想这样和把学习率提高10倍有什么不同吗?你别说,还真不一样,因为 actor 网络的损失函数和新的策略相关的,当策略改变后网络的梯度也会发生改变。

坑4. 输入的数据的维度问题。这个问题在之前的代码中就遇到过了,当时没太在意。比如输入的数据是二维的,期望的输出也是二维的,但是由于中间某些操作会导致输出多了一个多余的维度,变为了三维。这不仅会导致数据的维度不匹配,好像还可能会影响到最终训练的结果。

多线程中的 Event

在 threading.Event() 中,可以用 flag=threading.Event() 来设置多个内部标识。其作用类似于操作系统中同步和互斥用到的标识,常用的函数有:

  • flag.set():将内部标识设为True,默认为False
  • flag.wait():如果内部标识为False,则线程阻塞,直到内部标识变为True
  • flag.clear():将内部标识设为False
  • flag.is_set():返回内部标识的值

在 DPPO 的代码中,将负责收集数据的副网络看作一类,将负责更新网络参数和采取动作的主网络看作一类,分别设置一个内部标识。一开始主网络的标识为 False,副网络的为 True,此时副网络收集数据;当数据收集到一定数量时,将副网络标识设为 False,主网络的设为 True,此时主网络更新;主网络更新完毕后,将自己的标识设为 False,将副网络的设为 True。不断重复以上步骤就可以让收集数据和更新网络两个步骤交替进行。


http://www.niftyadmin.cn/n/968491.html

相关文章

【论文笔记】U-Net模型-用于医学图像分割的神经网络模型

本文主要是对 U-Net 论文中主要内容的提炼&#xff0c;中间加入了一些自己的理解&#xff0c;有些地方自己不是很懂&#xff0c;所以直接采用了直译的方式。建议大家去阅读原文。 U-Net 的主要优点是可以用更少的训练样本达到更好的效果&#xff0c;并且速度很快&#xff0c;这…

UDP Client《——》UDP Server

#include <stdio.h>#include <winsock2.h>#pragma comment(lib, "WS2_32") // 链接到WS2_32.libclass CInitSock {public:CInitSock(BYTE minorVer 2, BYTE majorVer 2){// 初始化WS2_32.dllWSADATA wsaData;WORD sockVersion MAKEWORD(minorVer, ma…

【论文笔记】递归级联网络(Recursive Cascaded Networks)论文及VTN(Volume Tweening Network)

本文是递归级联网络和 VTN 网络论文&#xff0c;及其代码的一些解读。 一、递归级联网络 递归级联网络论文地址&#xff1a;递归级联网络论文 1. 前人工作 之前的工作尝试通过对一些现有网络进行堆叠来建模的&#xff0c;但是每一层网络的输入和任务各不相同&#xff0c;并且…

Secret的三种形式

Secret ConfigMap这个资源对象是Kubernetes当中非常重要的一个对象&#xff0c;一般情况下ConfigMap是用来存储一些非安全的配置信息&#xff0c;如果涉及到一些安全相关的数据的话用ConfigMap就非常不妥了&#xff0c;因为ConfigMap是名为存储的&#xff0c;我们说这个时候我们…

select模型Client——》Server

//// select.cpp文件#include <stdio.h>#include <winsock2.h>#pragma comment(lib, "WS2_32") // 链接到WS2_32.libclass CInitSock {public:CInitSock(BYTE minorVer 2, BYTE majorVer 2){// 初始化WS2_32.dllWSADATA wsaData;WORD sockVersion MA…

spring boot实现发送文本邮件、html邮件、带附件的邮件

源码url: https://github.com/zhzhair/mail-spring-boot.git 1.发送文本邮件&#xff1b; 2.发送html邮件&#xff1b; 3.发送带附件的邮件。转载于:https://www.cnblogs.com/zhzhair-coding/p/10962258.html

WSAAsyncSelect异步套接字模型Client——》Server

/// // WSAAsyncSelect.cpp文件#include <winsock2.h>#pragma comment(lib, "WS2_32") // 链接到WS2_32.lib#include <stdio.h>#define WM_SOCKET WM_USER 101 // 自定义消息class CInitSock {public:CInitSock(BYTE minorVer 2, BYTE majorVer 2){…

医学图像配准软件 ANTs(Advanced Normalization Tools)的安装和使用说明

本文是关于医学图像配准软件 ANTs&#xff08;Advanced Normalization Tools&#xff09;的安装和使用说明。 ANTs ANTs 是 Advanced Normalization Tools 的缩写&#xff0c;是基于 C 语言的一个医学图像处理的软件&#xff0c;速度比较快。 ANTs 支持 2D 和 3D 的图片&…