强化学习DQN 入门小游戏 最简单的Pytorch代码

news/2024/5/18 23:44:18 标签: 强化学习, 深度学习, pytorch

本文目的是用最简单的代码,展示DQN玩游戏的效果,不涉及深度学习原理讲解。

毕竟,入门如此艰难,唯一的动力不过是看个效果,装个biu……

安装OpenAI的游戏库gym

pip install gym

看一下运行效果

import gym

env = gym.make('CartPole-v1')
print('State shape:', env.observation_space.shape)
print('Number of actions:', env.action_space.n)
for _ in range(20):
    observation = env.reset()   # 初始状态
    for t in range(500):
        env.render()    # 显示图像
        action = env.action_space.sample()  # 随机选择一个动作
        observation, reward, done, info = env.step(action)  # 状态,回报,是否结束,信息
        print(observation, reward, done, info)
        if done:
            print("Episode finished after {} timesteps".format(t + 1))
            break
env.close()

在这里插入图片描述
这是个手推车平衡游戏,我们可以采取两个动作(action),向左推或向右推,保持杆子不倒的时间越长,分数越高。游戏的每个时刻都有一个状态(observation),这个游戏中状态是由4个数值描述的,我们其实不必了解这四个值的含义,规律反正是交给网络学习的,这是深度学习很爽的地方,写出AlphaGo的程序员不需要懂围棋规则。每采取一个动作,除了得到下一时刻的状态,最重要的是得到一个回报(reward),在这个游戏中,无论什么动作,回报都是+1,意思是只要活着,就给你加分,直到杆子倒下,游戏结束。

代码

import random
import gym
import torch
from torch import nn, optim

class QNet(nn.Sequential):
    def __init__(self):
        super(QNet, self).__init__(
            nn.Linear(4, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )

class Game:
    def __init__(self, exp_pool_size, explore):
        self.env = gym.make('CartPole-v1')
        self.exp_pool = []
        self.exp_pool_size = exp_pool_size
        self.q_net = QNet()
        self.explore = explore
        self.loss_fn = nn.MSELoss()
        self.opt = optim.Adam(self.q_net.parameters())

    def __call__(self):
        is_render = False
        avg = 0
        while True:
            # 数据采样
            state = self.env.reset()
            R = 0
            while True:
                if is_render:
                    self.env.render()
                if len(self.exp_pool) >= self.exp_pool_size:
                    self.exp_pool.pop(0)
                    self.explore += 1e-7
                    if torch.rand(1) > self.explore:
                        action = self.env.action_space.sample()
                    else:
                        _state = torch.tensor(state, dtype=torch.float32)
                        Qs = self.q_net(_state[None, ...])
                        action = torch.argmax(Qs, 1)[0].item()
                else:
                    action = self.env.action_space.sample()

                next_state, reward, done, _ = self.env.step(action)
                R += reward
                self.exp_pool.append([state, reward, action, next_state, done])
                state = next_state

                if done:
                    avg = 0.95 * avg + 0.05 * R
                    print(avg, R)
                    if avg > 400:
                        is_render = True
                    break
            # 训练
            if len(self.exp_pool) >= self.exp_pool_size:
                exps = random.choices(self.exp_pool, k=100)
                _state = torch.tensor([exp[0] for exp in exps]).float()
                _reward = torch.tensor([[exp[1]] for exp in exps])
                _action = torch.tensor([[exp[2]] for exp in exps])
                _next_state = torch.tensor([exp[3] for exp in exps]).float()
                _done = torch.tensor([[int(exp[4])] for exp in exps])

                # 预测值
                _Qs = self.q_net(_state)
                _Q = torch.gather(_Qs, 1, _action)
                # 目标值
                _next_Qs = self.q_net(_next_state)
                _max_Q = torch.max(_next_Qs, dim=1, keepdim=True)[0]
                _target_Q = _reward + (1 - _done) * 0.9 * _max_Q

                loss = self.loss_fn(_Q, _target_Q.detach())
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()


if __name__ == '__main__':
    g = Game(10000, 0.9)
    g()

代码很简练,慢慢读都能懂,我解释一下几个重点。

  1. 整个流程是先采样,将样本存入经验池self.exp_pool,当样本足够时,从经验池中随机选取100条样本进行训练,而后边更新经验池边训练。
  2. QNet就是我们学习的网络,状态到动作的映射,根据输入状态建议采取的动作。
  3. self.explore探索值,很容易可以看明白它是一个概率,控制动作是随机选取还是由网络推荐。探索可以让我们发现新鲜样本,但随着训练进行,我们见过的样本越来越多,应该逐渐减少探索,也就是降低随机动作的概率。
  4. R是每一局游戏的得分,因为这个值在训练中变化非常大,所以加了个avg滑动平均的操作,这样可以更清晰地看出训练效果。由于游戏限制,每局游戏最多到500分就会结束,所以我们设置当avg>400就开始显示图像。
  5. 其实强化学习DQN真正的精髓是在训练中目标值的确定,所以这块我们跳过~[666]

没错,就这点代码,训练2分钟,DQN就能学会玩这个平衡游戏啦!

在这里插入图片描述

再附带一个小车爬坡的小游戏

在这里插入图片描述
游戏名:MountainCar-v0。小车想到达最高峰,但其引擎强度不足以单程通过,所以要在两个山坡间反复横跳,积蓄力量,一鸣惊人~

这个游戏的区别是reward始终为-1,意思是只有到达终点才有奖励,其他打酱油行为都要扣分,拿到高分的办法就是尽快到达终点。状态有2个值:水平位置和速度。动作有3个值:左、右、不动。所以要把QNet的输入和输出维度改一下。

训练难点在于:游戏只持续200个动作,在随机选择动作的情况下,小车很难靠运气到达终点,也就很少有成功的经验,不容易学习。

next_state, reward, done, _ = self.env.step(action)
position, velocity = next_state
reward = (position + 0.5) ** 2
R += reward

于是我根据状态把reward给改了,position + 0.5的原因是,起始最低点的水平坐标值为-0.5,改完之后的reward含义就是:离最低点越远奖励越高,而且奖励还是成平方增长。

经过这通骚操作,就可以训练了,也只需要训练几分钟。avg的阈值我给的30,同学们看着给。

在这里插入图片描述


http://www.niftyadmin.cn/n/1228462.html

相关文章

阅读源码的好网站

http://lxr.free-electrons.com/source/?aarm 对函数、变量等可以很好的跟踪,非常方便!

init 进程和inittab 引导指令

init 进程和inittab 引导指令 init 进程是系统所有进程的起点,内核在完成内核引导以后,即在本线程(进程)空间内加载init 程序,它的进程号是1。init 程序读取/etc/inittab 文件作为其行为指针,根据initab 描…

YOLOv5 人脸口罩识别 免费提供数据集

本文分享快速使用YOLOv5训练自己的人脸口罩数据集。 第一步是搞数据,并把标注文件处理成YOLOv5格式,这其实是最费时的,此处省略1W字……我给同学们整了一个近8000张的人脸口罩数据集,拿走不蟹~ 下载YOLOv5-3.1版本和模型权重&am…

我与《深入浅出嵌入式底层软件开发》

本文系本站原创,欢迎转载! 转载请注明出处: http://blog.csdn.net/mr_raptor/article/details/6744480 硬件-宝剑,软件-剑法,体系结构-内功 看过太多的武侠与历史故事,拥有宝剑的,不一定登顶,熟读兵法的的&#xff0c…

YOLOv5 完美实现中文标签显示

首先,网上有好几个显示中文标签的教程了,我为什么还要写呢??哼,很显然,是觉得他们实现的不够完美嘛~ YOLOv5在标签显示上,是花了点心思的,标签字体的大小,会根据图片尺寸…

Adam优化器总结

根据李宏毅老师的课程,将梯度下降的常见训练优化思路和Adam的内容做简单的梳理。 梯度下降的基本方法是:朝着参数θ\thetaθ的当前梯度ggg的反方向,以η\etaη为步长迈出一步。 一个常见的问题就是梯度下降容易落入局部最小值,因…

如何从源码包安装软件?

如何从源码包安装软件? 从源码包安装软件最重要的就是仔细阅读README INSTALL等说明文件 它会告诉你怎样才能成功安装 通常从源码包安装软件的步骤是: tar jxvf gtk-2.4.13.tar.bz2 解开源码包 cd gtk-2.4.13/ 进入源码目录 ./configure 似乎在某些环境…

图像梯度 CAM 热力图 Pytorch代码

以Pytorch的VGG预训练模型为例,贴一下CAM(Class Activation Map)的核心代码。 img_path relephant.jpgimg Image.open(img_path).convert(RGB) transforms torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normaliz…