强化学习和torchrl

news/2024/5/18 23:44:07 标签: 人工智能, 强化学习

torchrl是一个基于pytorch的强化学习库,我发现根据torchrl的结构可以对强化学习知识点有更加深入的理解,下面将我的理解记录如下:

torchrl中将强化学习的过程分为了几个部分:

  • 环境,需要实现reset, step两个方法
  • replay buffer:用来存放collector采集的数据,训练时从replay buffer中采样
  • collector: 将actor和环境交互的过程抽象为collector, collector需要一个env对象,一个policy
  • loss function: 需要分别计算actor loss和critic loss,其实很多强化学习方法改进的部分就是这个地方
  • 训练过程trainer:分为外循环和内循环,外循环收集数据,而内循环更新参数

下面是使用torchrl实现了PPO算法:

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import TensorDictReplayBuffer, \
    LazyTensorStorage, SamplerWithoutReplacement
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import ProbabilisticActor, ValueOperator, TanhNormal
from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

env = GymEnv("Pendulum-v1")
model = TensorDictModule(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 2),
        NormalParamExtractor()
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"]
)
critic = ValueOperator(
    nn.Sequential(
        nn.Linear(3, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 128), nn.Tanh(),
        nn.Linear(128, 1),
    ),
    in_keys=["observation"],
)
actor = ProbabilisticActor(
    model,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={"min": -1.0, "max": 1.0},
    return_log_prob=True
    )
buffer = TensorDictReplayBuffer(
    LazyTensorStorage(1000),
    SamplerWithoutReplacement()
    )
collector = SyncDataCollector(
    env,
    actor,
    frames_per_batch=1000,
    total_frames=1_000_000
    )
loss_fn = ClipPPOLoss(actor, critic, gamma=0.99)
optim = torch.optim.Adam(loss_fn.parameters(), lr=2e-4)
adv_fn = GAE(value_network=critic, gamma=0.99, lmbda=0.95, average_gae=True)
for data in collector:  # collect data
    for epoch in range(10):
        adv_fn(data)  # compute advantage
        buffer.extend(data.view(-1))
        for i in range(20):  # consume data
            sample = buffer.sample(50)  # mini-batch
            loss_vals = loss_fn(sample)
            loss_val = sum(
                value for key, value in loss_vals.items() if
                key.startswith("loss")
                )
            loss_val.backward()
            optim.step()
            optim.zero_grad()
    print(f"avg reward: {data['next', 'reward'].mean().item(): 4.4f}")

可以看出torchrl对强化学习的抽象做的比较好,另外也提供了一个叫做Trainer的类,可以不需要手动迭代了。
torchrl的另外一个特点是输入输出都是一个tensordict,这个tensordict可以看成一个字典,里面都是tensor,这样每个模块的输入输出都是一个字典,可以自己指定使用哪个key对应的数据,和传统的方法固定输入的个数比起来,增加了灵活性。

其实很多时候我们可能会有以下的困惑

  1. PPO和DDPG的区别?
  2. 可以同时使用PPO和DDPG吗?
  3. 什么时候应该使用2个value net,什么时候应该使用了2个action net?
  4. 哪些算法可以使用importance sampling?
  5. 哪些算法可以使用gae?
  6. 哪些算法可以使用replay buffer?
  7. TD0, TD1, TD lambda的区别是什么,哪些算法可以使用?
  8. on-policy和off-policy的区别是什么,如何判断是否是on-policy?

如果你明白这些问题,那么下面的内容可以跳过了

原理

其实很多强化学习算法的区别在于修改了整个Actor-Critic架构的某一部分,也就是上面提到的torchrl区分出来的这些。而很多算法修改的部分就是loss。loss分为actor网络的loss和crtic网络的loss。

critic 网络的loss:

critic网络是对一个s,a价值的估计,通常是网络的输出和目标s,a价值的均方误差,
l o s s ( s , a ) = ( c r i t i c ( s , a ) − t a r g e t _ v a l u e ( s , a ) ) 2 loss(s,a) = (critic(s,a) - target\_value(s, a))^2 loss(s,a)=(critic(s,a)target_value(s,a))2
这里就涉及到了如何对s, a的目标价值进行估计了,假如采样得到了一个序列: s1, a1, r1, s2, a2, r2, …, sn, an, rn,现在想算sn, an对应的目标价值。可以从rn开始往后取若干个值进行计算:

t a r g e t _ v a l u e ( s , a ) = r n + r n − 1 + . . . + r n − t + c r i t i c ( s n − t ) target\_value(s, a) = r_n + r_{n-1} + ... + r_{n-t} + critic(s_{n-t}) target_value(s,a)=rn+rn1+...+rnt+critic(snt)

这里面使用几个r是可以自己确定的,如果说只使用当前s, a对应的rn,那么就是TD0算法,而如果一直使用到了s1,那么就是TD1算法,如果取了若干个,那么就是TD lambda算法。

actor网络的loss

actor网络的loss分为两种,一种输出的是action的分布,一种是输出action的值。对于输出action分布的情况来说,目标函数中包含一个s,a的分布,要求s,a满足当前策略的分布。对于输出action的值的情况来说,目标函数中s,a的分布不做要求。

因此当策略更新之后,无法使用原来策略采集的数据更新当前策略,因为数据的s,a分布和当前策略是不同的。这个问题是输出action分布的方法才有的,而输出一个确定的action,是不存在这个问题的。简单来说DDPG输出的是一个确定策略,它是不存在这个问题的,因此DDPG可以用之前策略采集的数据。而例如PPO之类的,输出的是一个action的分布,因此它是受到这个问题影响的。这就区分了on-policy和off-policy的算法。

有些人认为,对于DQN来说,根本没有actor函数,直接通过critic选择策略,因此action的分布永远是固定的,也没有这个问题。上面这个看法是错误的,DQN的action分布是会改变的,选择某个动作的概率有时候是0,有时候是1,怎么能说概率分布不变呢。DQN是off-policy的原因是DQN的损失函数中不对s,a的分布做要求,因此s,a分布改变也没有关系。

这时使用了一个importance sampling解决了分布不同的问题,根据上面的推理可以得到,只有on-policy的算法需要使用importance sampling了。而replay buffer中数据的分布是不唯一的,off-policy算法天生可以使用replay buffer,而on-policy算法经过了importance sampling以后也可以使用replay buffer.

问题一: PPO和DDPG的区别?

PPO修改的是critic loss,对critic网络的loss进行了裁剪,主要有两种方法,对应了两个PPO loss,一个方法是裁剪loss 函数,另一个方法是在损失函数中加入KL散度进行调整,两种方法都是希望损失函数不要变化太大,从而更新太多引起模型不稳定。

DDPG修改的是actor loss,将随机动作变为确定动作。

问题二:可以同时使用PPO和DDPG吗?

可以,PPO裁剪的是critic的损失,而DDPG是修改为确定的动作,如果希望PPO输出的是一个确定的动作,那么就是PPO和DDPG结合了。结合之后的PPO变为了off policy的算法

问题三:什么时候应该使用2个value net,什么时候应该使用了2个action net?

这是为了防止value net或者action net更新太快导致模型不稳定。这个不是必须的,可以酌情使用。在torchrl的损失函数中专门有个参数:delay_actor和delay_value可以控制是否需要暂缓更新。

具体请看https://github.com/pytorch/rl/blob/bf264e0e24971fc05ec42b571de7b8df84043a51/torchrl/objectives/ddpg.py:

class DDPGLoss(LossModule):
    """The DDPG Loss class.

    Args:
        actor_network (TensorDictModule): a policy operator.
        value_network (TensorDictModule): a Q value operator.
        loss_function (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1".
        delay_actor (bool, optional): whether to separate the target actor networks from the actor networks used for
            data collection. Default is ``False``.
        delay_value (bool, optional): whether to separate the target value networks from the value networks used for
            data collection. Default is ``True``.
        separate_losses (bool, optional): if ``True``, shared parameters between
            policy and critic will only be trained on the policy loss.
            Defaults to ``False``, ie. gradients are propagated to shared
            parameters for both policy and critic losses.

问题四:哪些算法可以使用importance sampling?

只有on-policy算法需要,比如PPO, A2C之类的,对于DDPG,DQN是不需要的
换句话说,输出的是一个确定的策略,而不是一个分布,那么不需要,否则需要。

问题五:哪些算法可以使用gae?

算法输出是一个分布才可以使用gae,否则无法使用,因为无法计算状态的价值,只能获得状态动作对的价值。因此DDPG无法使用,而PPO, A2C是可以使用的

从torchrl的实现中也可以看出,DDPG是不支持gae的
https://github.com/pytorch/rl/blob/bf264e0e24971fc05ec42b571de7b8df84043a51/torchrl/objectives/ddpg.py

if value_type == ValueEstimators.TD1:
	self._value_estimator = TD1Estimator(value_network=self.actor_critic, **hp)
elif value_type == ValueEstimators.TD0:
    self._value_estimator = TD0Estimator(value_network=self.actor_critic, **hp)
elif value_type == ValueEstimators.GAE:
    raise NotImplementedError(
        f"Value type {value_type} it not implemented for loss {type(self)}."
    )
elif value_type == ValueEstimators.TDLambda:
    self._value_estimator = TDLambdaEstimator(
        value_network=self.actor_critic, **hp
    )
else:
    raise NotImplementedError(f"Unknown value type {value_type}")

问题六:哪些算法可以使用replay buffer?

输出确定策略的都能用,输出随机策略的,如果用了Importance sampling也能用。

问题七:TD0, TD1, TD lambda的区别是什么,哪些算法可以使用?

TD1, TD0, TD lambda都能用,而gae需要能算state value的方法才能用,一般来说只有输出动作分布的才能算state value,因此gae只能在输出随机分布的算法中使用,对于DDPG无法使用

问题八:on-policy和off-policy的区别是什么,如何判断是否是on-policy?

最准确的回答是:看actor或者critic的损失函数,如果损失函数中有对s,a的分布有要求,那么就是on-policy的,否则是off-policy的

一般来说,如果使用了输出随机动作,那么actor的损失函数大概率是对s,a分布有要求的,因此是on-policy的,如果使用了输出确定动作的,比如DDPG,那么actor损失函数大概率是对s,a分布无要求的,因此是off-policy的。

另外不要根据动作是否连续进行判断,因为有时候输出的是高斯分布的均值和方差,然后在这个高斯分布中采样,这种虽然获得的也是连续的动作空间,但是输出的仍然是一个分布,因此是一个on-policy的。


http://www.niftyadmin.cn/n/5075141.html

相关文章

根据二叉树创建字符串--力扣

🎈个人主页:🎈 :✨✨✨初阶牛✨✨✨ 🐻强烈推荐优质专栏: 🍔🍟🌯C的世界(持续更新中) 🐻推荐专栏1: 🍔🍟🌯C语言初阶 🐻推荐专栏2: 🍔…

计算机组成与设计的一些概念扫盲

一、术语 超标量架构 早期的单发射架构微处理器的流水线设计目标是做到平均每个时钟周期能执行一条指令,但这一目标不能满足提高处理器性能的要求。为了提高处理器的性能,处理器要具有每个时钟周期发射执行多条指令的能力。超标量体系结构可描述一种微处…

SLAM从入门到精通(基于传感器的闭环控制仿真)

【 声明:版权所有,欢迎转载,请勿用于商业用途。 联系信箱:feixiaoxing 163.com】 前面我们对底盘做了控制,对传感器数据也进行了读取,但是目前为止还没有做过一个完整的ros仿真程序。在这仿真中&#xff0c…

设计模式 - 创建型模式考点篇:工厂模式、建造者模式

目录 一、创建型模式 一句话概括 1.1、工厂模式 1.1.1、简单工厂模式(非 23 种经典设计模式) 概述 案例 1.1.2、静态工厂(扩展) 1.1.3、工厂方法模式 概念 案例 1.2、建造者模式 1.2.1、概念 1.2.2、案例 1.2.3、建…

Python —— UI自动化之八大元素定位

1、基础元素定位 1、id定位 使用html中标签的id元素去定位,在一般定位中优先选择,举例: from time import sleep from selenium import webdriver from selenium.webdriver.common.by import Bydriver webdriver.Firefox() driver.get(&q…

动画圆圈文字标志效果

效果展示 CSS 知识点 实现圆圈文字animation 属性回顾 实现思路 从效果的实现思路很简单,其实就是两个圆圈就可以实现。外层大圆(灰色)用于圆圈文字的展示,而内圆(藏青色)主要用于存放 Logo 图片。布局采…

浏览器自动化神器:Automa 轻松实现任务编排 | 开源日报 No.52

usememos/memos Stars: 13.8k License: MIT memos,一个轻量级的、自托管的备忘录中心。开源且永久免费。 开源且永久免费使用 Docker 可以在几秒钟内完成自我托管支持 Markdown 格式可定制和共享提供 RESTful API 用于自助服务 mamoe/mirai Stars: 12.6k Licen…

内网穿透方法汇总

内网穿透 1.ddns原理和使用条件 电信宽带:通过难过光猫拨号,得到的如果是私网的IP,可以通过10000号投诉,从而得到公网IP(浮动IP,每次拨号会分配一个IP,可以通过ddns实现通过域名绑定&#xff…