强化学习用 Sarsa 算法与 Q-learning 算法实现FrozenLake-v0

news/2024/5/18 22:27:24 标签: 强化学习, python

基础知识

关于Q-learning 和 Sarsa 算法, 详情参见博客 强化学习(Q-Learning,Sarsa)
Sarsa 算法框架为Sarsa算法
Q-learning 算法框架为
在这里插入图片描述

关于FrozenLake-v0环境介绍, 请参见https://copyfuture.com/blogs-details/20200320113725944awqrghbojzsr9ce
在这里插入图片描述
此图来自 强化学习FrozenLake求解

需要注意的细节

训练时

  • 采用 ϵ \epsilon ϵ 贪心算法;
# 贪婪动作选择,含嗓声干扰
a = np.argmax(Q_all[s, :] + np.random.randn(1, env.action_space.n) * (1. / (i + 1)))
  • 对 Q-learning 算法
# 更新Q表
# Q-learning
Q_all[s, a] = Q_all[s, a] + alpha * (r + gamma * np.max(Q_all[s1, :]) - Q_all[s, a])
  • 对 Sarsa 算法
# sarsa
# 更新Q表
a_ = np.argmax(Q_all[s1, :] + np.random.randn(1, env.action_space.n) * (1. / (i + 1)))
Q_all[s, a] = Q_all[s, a] + alpha * (r + gamma * Q_all[s1, a_] - Q_all[s, a])

测试时

  • 不采用 ϵ \epsilon ϵ 贪心算法;
a = np.argmax(Q_all[s, :])
  • 不更新Q表
# # 不更新Q表
# Q_all[s, a] = Q_all[s, a] + alpha * (r + gamma * np.max(Q_all[s1, :]) - Q_all[s, a])

寻找模型中最优的 α \alpha α, γ \gamma γ

我们计算一下不同参数下的学习率, 如下图所示
在这里插入图片描述
在这里插入图片描述
比较两种算法的准确率, 我们用Q-learning算法的准确率减掉Sarsa的准确率, 得到

从图中可以看到, 大于0的点均表明在此点对应的 α , γ \alpha,\gamma α,γ下, Q-learning 准确率高于Sarsa.

Python代码

import gym
import numpy as np
import random
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# gym创建冰湖环境
env = gym.make('FrozenLake-v0')
env.render()  # 显示初始environment
# 初始化Q表格,矩阵维度为【S,A】,即状态数*动作数
Q_all = np.zeros([env.observation_space.n, env.action_space.n])
# 设置参数,
# 其中α\alpha 为学习速率(learning rate),γ\gamma为折扣因子(discount factor)
alpha = 0.8
gamma = 0.95
num_episodes = 2000
#
Alpha = np.arange(0.75, 1, 0.02)
Gamma = np.arange(0.1, 1, 0.05)
#Alpha = np.ones_like(Gamma)*0.97
# Training
correct_train = np.zeros([len(Alpha), len(Gamma)])
correct_test = np.zeros([len(Alpha), len(Gamma)])
for k in range(len(Alpha)):
    for p in range(len(Gamma)):
        alpha = Alpha[k]
        gamma = Gamma[p]

        # training
        rList = []
        for i in range(num_episodes):
            # 初始化环境,并开始观察

            s = env.reset()
            rAll = 0
            d = False
            j = 0
            # 最大步数
            while j < 99:
                j += 1
                # 贪婪动作选择,含嗓声干扰
                a = np.argmax(Q_all[s, :] + np.random.randn(1, env.action_space.n) * (1. / (i + 1)))

                # 从环境中得到新的状态和回报
                s1, r, d, _ = env.step(a)
                # 更新Q表
                # Q-learning
                Q_all[s, a] = Q_all[s, a] + alpha * (r + gamma * np.max(Q_all[s1, :]) - Q_all[s, a])
                # sarsa
                a_ = np.argmax(Q_all[s1, :] + np.random.randn(1, env.action_space.n) * (1. / (i + 1)))
                Q_all[s, a] = Q_all[s, a] + alpha * (r + gamma * Q_all[s1, a_] - Q_all[s, a])
                # 累加回报
                rAll += r
                # 更新状态
                s = s1
                # Game Over
                if d:
                    break
            rList.append(rAll)
        correct_train[k, p] = (sum(rList) / num_episodes)
        # test
        rList = []
        for i in range(num_episodes):
            # 初始化环境,并开始观察
            s = env.reset()
            rAll = 0
            d = False
            j = 0
            # 最大步数
            while j < 99:
                j += 1
                # 贪婪动作选择,含嗓声干扰
                a = np.argmax(Q_all[s, :])
                # 从环境中得到新的状态和回报
                s1, r, d, _ = env.step(a)
                # # 更新Q表
                # Q_all[s, a] = Q_all[s, a] + alpha * (r + gamma * np.max(Q_all[s1, :]) - Q_all[s, a])
                # 累加回报
                rAll += r
                # 更新状态
                s = s1
                # Game Over
                if d:
                    break
            rList.append(rAll)
        correct_test[k, p] = sum(rList) / num_episodes

# print("Score over time:" + str(sum(rList) / num_episodes))
# print("打印Q表:", Q_all)

# Test
plt.figure()
ax = plt.subplot(1, 1, 1)
h = plt.imshow(correct_train, interpolation='nearest', cmap='rainbow',
               extent=[0.75, 1, 0, 1],
               origin='lower', aspect='auto')
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
plt.colorbar(h, cax=cax)
plt.show()

参考文献

【1】https://blog.csdn.net/kyolxs/article/details/86693085
【2】 强化学习(Q-Learning,Sarsa)
【3】 强化学习FrozenLake求解
【4】https://copyfuture.com/blogs-details/20200320113725944awqrghbojzsr9ce


http://www.niftyadmin.cn/n/1394635.html

相关文章

半正定规划简介

本文主要内容来自 Vandenberghe, L., & Boyd, S. (1996). Semidefinite Programming. SIAM Review, 38(1), 49–95. 考虑如下优化问题 min⁡cTxs.t.F(x)≥0,(1)\begin{array}{ll} \min& c^Tx\\ s.t.&F(x)\geq 0, \end{array} \tag{1}mins.t.​cTxF(x)≥0,​(1) 其…

PhaseDNN: 采用相位技术与神经网络来求解高频波问题

大概从18年开始&#xff0c; 许志钦老师写了一系列有关DeepLearning 与Frequency有关的文章&#xff0c; 里面讲述了深度学习与傅里叶变换之间的关系&#xff0c; 具体可以参见许志钦老师的主页&#xff08;现在在交大工作&#xff09; https://ins.sjtu.edu.cn/people/xuzhiqi…

php采集新浪视频hlv格式源地址

为什么80%的码农都做不了架构师&#xff1f;>>> 比如&#xff1a; http://video.sina.com.cn/v/b/101268120-2809258130.html 这个页面的视频&#xff0c;可以采集到&#xff1a; http://edge.v.iask.com/101268120.hlv?KIDsina,viask&Expires1366473600&a…

Data driven governing equations approximations using DNN

本文内容源自 DATA DRIVEN GOVERNING EQUATIONS APPROXIMATION USING DEEP NEURAL NETWORKS。他的通讯作者为Ohio State University 的 Prof DONGBIN XIU&#xff0c;印象中好像是 JCP 的 Associate Editor。 我们提供了一个使用观察数据和深度神经网络来近似未知控制方程的数…

第一章练习题——加密算法(字母循环,取往后的第四个)

要求是&#xff1a;对字母进行加密运算&#xff0c;加密规则是所有的字母循环向后取后面第四个字母&#xff0c;即&#xff1a;A->E、b->f、v->z、w->a、Z->D 一开始还没能运行通过&#xff1a; #include<iostream> using namespace std; int main(){char…

用PCA,LDA,KNN对MNIST数据集分类(Python)

主成分分析 对于高维空间中xxx, 我们寻求线性变换&#xff0c; yWTx,wherex∈Rn,W∈Rmd,y∈Rd,d<m.yW^Tx,\quad where \; x\in \mathbb{R}^n,W\in \mathbb{R}^{m\times d},y\in\mathbb{R}^d,d<m.yWTx,wherex∈Rn,W∈Rmd,y∈Rd,d<m. 我们采用可重构性观点&#xff0c; …

5. Hive的三种去重方法

文章目录 Hive的三种去重方法1. distinct2. group by3. row_number()4. 三者的效率对比参考链接 Hive的三种去重方法 1. distinct -- 语法SELECT DISTINCT column1, column2, ... FROM table_name;注意事项&#xff1a; distinct 不能单独用于指定某一列&#xff0c;必须放在…

用全连接神经网络解决Letter Recognition分类任务(Python)

首先&#xff0c;我们要下载Letter Recognition数据。 Letter Recognition 是字符识别任务&#xff0c;有20000个数据&#xff0c;每个数据17维&#xff0c;其中有一维是给定标签&#xff08;26个英文字母&#xff09;。 我们首先下载Letter Recognition 数据集&#xff0c;见…