强化学习(五)-Deterministic Policy Gradient (DPG) 算法及公式推导

news/2024/5/18 21:47:37 标签: 算法, 强化学习, RF, DPG, DDPG, actor-critic

针对连续动作空间,策略函数没法预测出每个动作选择的概率。因此使用确定性策略梯度方法。

0 概览

  • 1 actor输出确定动作
  • 2 模型目标:
    actor目标:使critic值最大
    critic目标: 使TD error最大
  • 3 改进:
    使用两个target 网络减少TD error自举估计。

1 actor 和 critic 网络

  • 确定性策略网络
    actor: a= π ( s ; θ ) \pi(s;\theta) π(s;θ) 输出为确定性的动作a
  • 动作价值网络
    critic Q=q(s,a;w) ,用于评估动作a的好坏

2 critic网络训练

  • 观察一组数据 ( s t , a t , r t , s t + 1 ) (s_t,a_t,r_t,s_{t+1}) (st,at,rt,st+1)
    即在状态 s t s_t st时,执行动作 a t a_t at,得到奖励 r t r_t rt,和下一状态 s t + 1 s_{t+1} st+1
  • a t 时刻 Q 值 : q t = q ( s t , a t , w ) a_t时刻Q值: q_t=q(s_t,a_t,w) at时刻Q:qt=q(st,at,w)
  • a t + 1 时刻 Q 值 : q t + 1 = q ( s t + 1 , a t + 1 , w ) a_{t+1}时刻Q值: q_{t+1}=q(s_{t+1},a_{t+1},w) at+1时刻Q:qt+1=q(st+1,at+1,w) ,其中 a t + 1 = π ( s t + 1 ; θ ) a_{t+1}=\pi(s_{t+1};\theta) at+1=π(st+1;θ)
    即TD Target = r t + γ ∗ q t + 1 r_t+\gamma * q_{t+1} rt+γqt+1
  • 目标:使t时刻的TD error最小
    TD error: δ t = q t − ( r t + γ ∗ q t + 1 ) \delta_t=q_t-(r_t+\gamma * q_{t+1}) δt=qt(rt+γqt+1)
    w = w − α ∗ δ t ∗ ∂ q ( s t , a t ; w ) ∂ w w=w-\alpha *\delta_t* \frac{\partial q(s_t,a_t;w)}{\partial w} w=wαδtwq(st,at;w)

3 actor 网络训练

actor 网络目标是时critic值最大,所以要借助critic网络,将actor值带入critic网络,使critic最大。

  • a= π ( s ; θ ) \pi(s;\theta) π(s;θ) ,带入q(s,a;w)中 得到 q(s, π ( s ; θ ) \pi(s;\theta) π(s;θ) ;w)
    即使 q(s, π ( s ; θ ) \pi(s;\theta) π(s;θ) ;w) 最大
    θ \theta θ求导:
    g = ∂ q ( s , π ( s ; θ ) ; w ) ∂ θ = ∂ a ∂ θ ∗ ∂ q ( s , a ; w ) ∂ a g=\frac{\partial q(s,\pi(s;\theta);w)}{\partial \theta}=\frac{\partial a }{\partial \theta} *\frac{\partial q(s,a;w) }{\partial a} g=θq(s,π(s;θ);w)=θaaq(s,a;w)
  • 参数更新
    θ = θ + β ∗ g \theta=\theta + \beta* g θ=θ+βg

4 训练改进

4.1 主网络actor和critic更新

critic 网络更新时,在计算TD error时,使用了自举,会导致数据过高估计或者过低估计。
关键在于 t + 1 t+1 t+1时刻的 a t + 1 和 q t + 1 怎么生成 a_{t+1}和q_{t+1}怎么生成 at+1qt+1怎么生成
和其他方法一样,可以使用两个actor和两个critic网络,减少自举带来的估计。

  • t+1 时的 a t + 1 a_{t+1} at+1使用另一个target 策略网络actor生成
    a t + 1 = π ( s t + 1 ; θ ˉ ) a_{t+1}=\pi(s_{t+1};\bar\theta) at+1=π(st+1;θˉ)
  • 同样t+1时 q t + 1 q_{t+1} qt+1使用另一个target critic网络生成
    q t + 1 = q ( s t + 1 , a t + 1 ; w ˉ ) q_{t+1}=q(s_{t+1},a_{t+1};\bar w) qt+1=q(st+1,at+1;wˉ)

actor 参数更新方式不变。
critic更新方式变化,使用了target网络产生的 a t + 1 和 q t + 1 a_{t+1}和q_{t+1} at+1qt+1
在这里插入图片描述

4.2 target网络actor和critic更新

target 网络初始时来自主网络,后期更新时,部分来自主网络,部分来自自己。
w ˉ = τ ∗ w + ( 1 − τ ) ∗ w ˉ \bar w= \tau *w +(1-\tau) * \bar w wˉ=τw+(1τ)wˉ
θ ˉ = τ ∗ θ + ( 1 − τ ) ∗ θ ˉ \bar \theta= \tau *\theta +(1-\tau) * \bar \theta θˉ=τθ+(1τ)θˉ

5 其他改进措施

  • 添加经验回放, Experience replay buffer
  • 多步TD target
  • target networks

http://www.niftyadmin.cn/n/5282649.html

相关文章

如何预防最新的Faust变种.faust 勒索病毒感染您的计算机?

引言: 在当今数字化的时代,勒索病毒已经成为网络安全的一大威胁。.faust 勒索病毒是其中的一种,它以高效的加密算法和勒索手段而闻名。本文将深入介绍.faust 勒索病毒的特征、威胁,以及针对其攻击的恢复方法和预防策略。数据的重…

从零学算法5

5.给你一个字符串 s,找到 s 中最长的回文子串。 如果字符串的反序与原始字符串相同,则该字符串称为回文字符串。 示例 1: 输入:s “babad” 输出:“bab” 解释:“aba” 同样是符合题意的答案。 示例 2&…

P4 音频知识点——PCM音频原始数据

目录 前言 01 PCM音频原始数据 1.1 频率 1.2 振幅: 1.3 比特率 1.4 采样 1.5 量化 1.6 编码 02. PCM数据有以下重要的参数: 采样率: 采集深度 通道数 ​​​​​​​ PCM比特率 ​​​​​​​ PCM文件大小计算: ​…

【Linux基础开发工具】gcc/g++使用make/Makefile

目录 前言 gcc/g的使用 1. 语言的发展 1.1 语言和编译器自举的过程 1.2 程序翻译的过程: 2. 动静态库的理解 Linux项目自动化构建工具-make/makefile 1. 快速上手使用 2. makefile/make执行顺序的理解 前言 了解完vim编辑器的使用,接下来就可以尝…

Mac OS 13+,Apple Silicon,删除OBS虚拟摄像头(virtual camera),

原文链接: https://www.reddit.com/r/MacOS/comments/142cv OBS为了捕获摄像头视频,将虚拟摄像头插件内置为系统插件了.如下 直接删除没有权限的,要删除他,在mac os 13以后,需要关闭先关闭苹果系统的完整性保护(SIP) Apple 芯片(M1,....)的恢复模式分为两种,回退恢复模式,和…

实用干货:公司规定所有接口都用 POST请求,为什么?

大家好,我是大澈! 本文约1000字,整篇阅读大约需要2分钟。 感谢关注微信公众号:“程序员大澈”,免费领取"面试礼包"一份,然后免费加入问答群,从此让解决问题的你不再孤单&#xff01…

用python对航空公司客户价值进行聚类分析

1.实验目的 1.会用Python创建KMeans聚类分析模型; 2.使用KMeans模型对航空公司客户价值进行聚类分析; 3.会对聚类结果进行分析 2.实验设备 Jupyter notebook 3.实验原理 4.实验内容 使用sklearn.cluester的KMeans类对航空公司客户数据进行聚类分析&…

Web前端复习

一、随堂练习 1.小题 margin vanish:border和inline-block都可以形成bfc二维数组转置:res[i] [];函数的不同声明定义: 有变量名字的函数,即便后面声明了同样的,以函数表达式为主;定义,运行。再…