【强化学习】Deep Q Learning

Deep Q Learning

在前两篇文章中,我们发现RL模型的目标是基于观察空间 (observations) 和最大化奖励和 (maximumize sum rewards) 的。

如果我们能够拟合出一个函数 (function) 来解决上述问题,那就可以避免存储一个 (在Double Q-Learning中甚至是两个) 巨大的Q_table。

Tabular -> Function

  • Continous Observation: 函数能够让我们处理连续的观察空间,而表只能处理离散的。
  • Saving the space: 不用存储 len(state) * len(action) 大小的Q_table

在早期人们试过使用核函数或者线性函数等各种方法去拟合这个function,但后来深度神经网络出现后人们纷纷开始研究如何用DNN来拟合。

然而以上的拟合方式不免存在一个问题,我们期望得到一个DNN,使得DNN(state)->Q-value

可是强化学习中,最好的Q-value在开始时是不知道的 (这也是强化学习和机器学习不一样的地方:我们不知道能否训练到一个Q值,直到有人把它训练出来),这就导致我们在训练过程中没有目标函数。

Natural Deep Q Learning

所有的第一步必须从高维的感官输入中获得对环境的有效表示

深度Q网络(DQN)是一种将深度学习和Q学习相结合的强化学习方法。DQN由DeepMind于2015年提出,并在玩Atari视频游戏方面取得了显著的成功。DQN的核心原理是使用深度神经网络来近似Q函数,即在给定状态下采取某一动作的预期累积奖励。

DQN_25">DQN的关键创新

  1. 使用神经网络近似Q函数

    • 传统的Q学习使用表格(Q表)来存储每个状态-动作对的Q值。当状态空间很大或连续时,这变得不切实际。
    • DQN通过使用深度神经网络来近似Q函数,克服了这一限制。网络输入是状态,输出是该状态下所有可能动作的Q值。
  2. 经验回放

    • DQN引入了经验回放机制,即将代理的经验(状态、动作、奖励、新状态)存储在回放缓冲区中。

      image-20231114211049019
    • 训练时,从这个缓冲区中随机抽取小批量经验进行学习。这增加了数据的多样性,减少了样本之间的相关性,从而稳定了训练。

  3. 目标网络

    • DQN使用两个结构相同但参数不同的网络:一个是在线网络 (dqn_model),用于当前Q值的估计;另一个是目标网络 (target_model),用于计算目标Q值。
    • 目标网络的参数定期从在线网络复制过来,但不是每个训练步骤都更新。这减少了学习过程中的震荡,提高了稳定性。
    image-20231114211236348

训练过程

  • 在每个时间步,代理根据当前的Q值(通常结合探索策略,如ε-贪婪)选择一个动作,接收环境的反馈(新状态和奖励),并将这个转换存储在经验回放缓冲区中。
  • 训练神经网络时,从缓冲区中随机抽取一批经验,然后使用贝尔曼方程计算目标Q值和预测Q值,通过最小化这两者之间的差异来更新网络参数。

DQN_49">DQN解决月球着陆问题

导入环境

import time
from collections import defaultdict

import gymnasium as gym
import numpy as np
import random

from matplotlib import pyplot as plt, animation
from IPython.display import display, clear_output
env = gym.make("LunarLander-v2", continuous=False, render_mode='rgb_array')

定义经验池

class ExperienceBuffer:
    def __init__(self, size=0):
        self.states = []
        self.actions = []
        self.rewards = []
        self.states_next = []
        self.actions_next = []
        self.size = 0

    def clear(self):
        self.__init__()

    def append(self, s, a, r, s_n, a_n):
        self.states.append(s)
        self.actions.append(a)
        self.rewards.append(r)
        self.states_next.append(s_n)
        self.actions_next.append(a_n)
        self.size += 1

    def batch(self, batch_size=128):
        indices = np.random.choice(self.size, size=batch_size, replace=True)
        return  (
            np.array(self.states)[indices],
            np.array(self.actions)[indices],
            np.array(self.rewards)[indices],
            np.array(self.states_next)[indices],
            np.array(self.actions_next)[indices],
        )
import torch

from torch import nn
from torch.nn.functional import relu
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

DQN_115">定义DQN

class DQN(nn.Module):
    def __init__(self, state_size, action_size):
        super().__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = 32
        self.linear_1 = nn.Linear(self.state_size, self.hidden_size)
        self.linear_2 = nn.Linear(self.hidden_size, self.action_size)

        nn.init.uniform_(self.linear_1.weight, a=-0.1, b=0.1)
        nn.init.uniform_(self.linear_2.weight, a=-0.1, b=0.1)

    def forward(self, state):
        if not isinstance(state, torch.Tensor):
            state = torch.tensor([state], dtype=torch.float)
        state = state.to(device)
        return self.linear_2(relu(self.linear_1(state)))

定义policy

def policy(model, state, eval=False):
    eps = 0.1

    if not eval and random.random() < eps:
        return random.randint(0, model.action_size - 1)
    else:
        q_values = model(torch.tensor([state], dtype=torch.float))
        action = torch.multinomial(F.softmax(q_values), num_samples=1)
        return int(action[0])

collect

dqn_model = DQN(state_size=8, action_size=4).to(device)
target_model = DQN(state_size=8, action_size=4).to(device)
from tqdm.notebook import tqdm
# 学习率
alpha = 0.9
# 折扣因子
gamma = 0.95
# 训练次数
episode = 1000
experience_buffer = ExperienceBuffer()

eval_iter = 100
eval_num = 100

# collect
def collect():
    for e in tqdm(range(episode)):
        state, info = env.reset()
        action = policy(dqn_model, state)

        sum_reward = 0

        while True:
            state_next, reward, terminated, truncated, info_next = env.step(action)
            action_next= policy(dqn_model, state_next)

            sum_reward += reward

            experience_buffer.append(
                state, action, reward, state_next, action_next
            )

            if terminated or truncated:
                break

            state = state_next
            info = info_next
            action = action_next

learning

## learning
from torch.optim import Adam

loss_fn = nn.MSELoss()
optimizer = Adam(lr=1e-5, params=dqn_model.parameters())

losses = []
target_fix_period = 5
epoch = 3

def train():
    for e in range(epoch):
        batch_size = 128
        for i in range(experience_buffer.size // batch_size):
            s, a, r, s_n, a_n = experience_buffer.batch(batch_size)

            s = torch.tensor(s, dtype=torch.float).to(device)
            s_n = torch.tensor(s_n, dtype=torch.float).to(device)
            r = torch.tensor(r, dtype=torch.float).to(device)
            a = torch.tensor(a, dtype=torch.long).to(device)
            a_n = torch.tensor(a_n, dtype=torch.long).to(device)

            y = r + target_model(s_n).gather(1, a_n.unsqueeze(1)).squeeze(1)
            y_hat = dqn_model(s).gather(1, a.unsqueeze(1)).squeeze(1)

            loss = loss_fn(y, y_hat)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if i % 500 == 0:
                print(f'i == {i}, loss = {loss} ')

            if i % target_fix_period == 0:
                target_model.load_state_dict(dqn_model.state_dict())

a_n:动作
s_n:状态

image-20231205221613164

image-20231205221643890

将状态 s_n 作为输入,target_model的输出是针对每个可能动作的 Q 值;如果 s_n 包含多个状态(比如一个批量),输出将是一个批量的 Q 值

image-20231205221710717

image-20231205221746045

image-20231205221827050

训练

for i in range(10):
    print(f'collect/train: {i}')
    experience_buffer.clear()
    collect()
    train()

结果

task_num = 10
frames = []

for _ in range(10):
    state, _ = env.reset()
    while True:
        action = policy(dqn_model, state, eval=True)
        state_next, reward, terminated, truncated, info_next = env.step(action)
        frames.append(env.render())

        if terminated or truncated:
            break

output


http://www.niftyadmin.cn/n/5276316.html

相关文章

Springboot管理系统数据权限过滤(三)——0业务入侵实现部门数据权限过滤

上一章&#xff0c;讲解了SQL拦截器的&#xff0c;做了一个简单的SQL改造。本章要实现&#xff1a; 仅对指定的service方法内执行的SQL进行处理&#xff1b;完成对部门权限的过滤&#xff1b; 以下简单用一个图说明了整个处理过程&#xff0c;红色框住的部分&#xff0c;就是…

Linux Docker本地部署WBO在线协作白板结合内网穿透远程访问

文章目录 前言1. 部署WBO白板2. 本地访问WBO白板3. Linux 安装cpolar4. 配置WBO公网访问地址5. 公网远程访问WBO白板6. 固定WBO白板公网地址 前言 WBO在线协作白板是一个自由和开源的在线协作白板&#xff0c;允许多个用户同时在一个虚拟的大型白板上画图。该白板对所有线上用…

element ui el-avatar 源码解析零基础逐行解析

avatar功能介绍 快捷配置头像的样式 avatar 的参数配置 属性说明参数size尺寸type string 类型 &#xff08;‘large’,‘medium’,‘small’&#xff09;number类型 validator 校验shape形状circle (原型) square&#xff08;方形&#xff09;icon传入的iconsrc传入的图片st…

开启服务,跨服务器传送文件

一、目的 将文件从一个服务器&#xff08;源服务器&#xff09;传到另一个服务器&#xff08;目标服务器&#xff09; 二、步骤 源服务器&#xff1a; 1、检查端口号&#xff08;8088是随机指定的&#xff0c;只要端口空闲即可&#xff09; lsof -i:8088 2、开启端口&…

org.springframework.boot.autoconfigure.AutoConfiguration.imports新版自动配置

文章目录 场景 场景 springboot2.7.0之后的版本 自动配置方式有了变化, 新版兼容旧版 旧版新版META-INF/spring.factoriesMETA-INF/spring/org.springframework.boot.autoconfigure.AutoConfiguration.importsConfigurationAutoConfiguration

软文营销未来的发展有哪些特点?媒介盒子分享

随着社交媒体和数字媒体的兴起&#xff0c;软文营销已经成为企业宣传的重要手段。然而&#xff0c;人们的消费观念和购物方式不断改变&#xff0c;软文营销需要根据市场和消费者的需求&#xff0c;不断改进和创新&#xff0c;媒介盒子接下来和大家聊聊软文营销未来的发展有哪些…

C语言—每日选择题—Day56

指针相关博客 打响指针的第一枪&#xff1a;指针家族-CSDN博客 深入理解&#xff1a;指针变量的解引用 与 加法运算-CSDN博客 第一题 1. 以下叙述中正确的是&#xff08;&#xff09; A&#xff1a;\0 表示字符 0 B&#xff1a;"a" 表示一个字符常量 C&#xff1a;表…

Java解决比特维位计数

Java解决比特维位计数 01 题目 给你一个整数 n &#xff0c;对于 0 < i < n 中的每个 i &#xff0c;计算其二进制表示中 1 的个数 &#xff0c;返回一个长度为 n 1 的数组 ans 作为答案。 示例 1&#xff1a; 输入&#xff1a;n 2 输出&#xff1a;[0,1,1] 解释&a…