强化学习的Sarsa与Q-Learning的Cliff-Walking对比实验

news/2024/5/19 1:12:16 标签: AI, Sarsa, Q-Learning, 强化学习, python

强化学习SarsaQ-Learning的Cliff-Walking对比实验

  • Cliff-Walking问题的描述
  • SarsaQ-Learning算法对比
  • 代码分享
  • 需要改进的地方
  • 引用和写在最后

Cliff-Walking问题的描述

在这里插入图片描述

悬崖行走:从S走到G,其中灰色部分是悬崖不可到达,求可行方案
建模中,掉下悬崖的奖励是-100,G的奖励是10,原地不动的奖励-1,到达非终点位置的奖励是0(与图中的示意图不一致,不过大差不差),分别使用同轨策略的Sarsa与离轨策略的Q-learning算法,经过20000幕进化迭代得出safe path,optimal path,最后根据Q值来得出最终的策略,以此来对上图进行复现

SarsaQLearning_7">SarsaQ-Learning算法对比

Sarsa算法
在这里插入图片描述
Q-Learning算法

在这里插入图片描述首先要介绍的是什么是ε-greedy,即ε-贪心算法,一般取定ε为一个较小的0-1之间的值(比如0.2)
在算法进行的时候,用计算机产生一个伪随机数,当随机数小于ε时采取任意等概率选择的原则,大于ε时则取最优的动作。

在介绍完两个算法和ε-贪心算法之后,一言概之就是,Sarsa对于当前状态s的a的选择是ε-贪心的,对于s’的a‘的选择也是ε-贪心的Q-Learning与sarsa一样,只是对于s’的a‘的选择是直接取最大的。

代码分享

python">import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches  # 图形类

np.random.seed(2022)


class Agent():
    terminal_state = np.arange(36, 48)  # 终止状态

    def __init__(self, board_rows, board_cols, actions_num, epsilon=0.2, gamma=0.9, alpha=0.1):
        self.board_rows = board_rows
        self.board_cols = board_cols
        self.states_num = board_rows * board_cols
        self.actions_num = actions_num
        self.epsilon = epsilon
        self.gamma = gamma
        self.alpha = alpha
        self.board = self.create_board()
        self.rewards = self.create_rewards()
        self.qtable = self.create_qtable()

    def create_board(self):  # 创建面板
        board = np.zeros((self.board_rows, self.board_cols))
        board[3][11] = 1
        board[3][1:11] = -1
        return board

    def create_rewards(self):  # 创建奖励表
        rewards = np.zeros((self.board_rows, self.board_cols))
        rewards[3][11] = 10
        rewards[3][1:11] = -100
        return rewards

    def create_qtable(self):  # 创建Q值
        qtable = np.zeros((self.states_num, self.actions_num))
        return qtable

    def change_axis_to_state(self, axis):  # 将坐标转化为状态
        return axis[0] * self.board_cols + axis[1]

    def change_state_to_axis(self, state):  # 将状态转化为坐标
        return state // self.board_cols, state % self.board_cols

    def choose_action(self, state):  # 选择动作并返回下一个状态
        if np.random.uniform(0, 1) <= self.epsilon:
            action = np.random.choice(self.actions_num)
        else:
            p = self.qtable[state, :]
            action = np.random.choice(np.where(p == p.max())[0])

        r, c = self.change_state_to_axis(state)
        new_r = r
        new_c = c

        flag = 0

        #状态未改变
        if action == 0:  # 上
            new_r = max(r - 1, 0)
            if new_r == r:
                flag = 1
        elif action == 1:  # 下
            new_r = min(r + 1, self.board_rows - 1)
            if new_r == r:
                flag = 1
        elif action == 2:  # 左
            new_c = max(c - 1, 0)
            if new_c == c:
                flag = 1
        elif action == 3:  # 右
            new_c = min(c + 1, self.board_cols - 1)
            if new_c == c:
                flag = 1

        r = new_r
        c = new_c
        if flag:
            reward = -1 + self.rewards[r,c]
        else:
            reward = self.rewards[r, c]

        next_state = self.change_axis_to_state((r, c))
        return action, next_state, reward


    def learn(self, s, r, a, s_,sarsa_or_q):
        # s状态,a动作,r即时奖励,s_演化的下一个动作
        q_old = self.qtable[s, a]
        # row,col = self.change_state_to_axis(s_)
        done = False
        if s_ in self.terminal_state:
            q_new = r
            done = True
        else:
            if sarsa_or_q == 0:
                if np.random.uniform(0.1) <= self.epsilon:
                    s_a = np.random.choice(self.actions_num)
                    q_new = r + self.gamma * self.qtable[s_, s_a]
                else:
                    q_new = r + self.gamma * max(self.qtable[s_, :])
            else:
                q_new = r + self.gamma * max(self.qtable[s_, :])
                # print(q_new)
        self.qtable[s, a] += self.alpha * (q_new - q_old)
        return done


    def initilize(self):
        start_pos = (3, 0)  # 从左下角出发
        self.cur_state = self.change_axis_to_state(start_pos)  # 当前状态
        return self.cur_state


    def show(self,sarsa_or_q):
        fig_size = (12, 8)
        fig, ax0 = plt.subplots(1, 1, figsize=fig_size)
        a_shift = [(0, 0.3), (0, -.4),(-.3, 0),(0.4, 0)]
        ax0.axis('off')  # 把横坐标关闭
        # 画网格线
        for i in range(self.board_cols + 1):  # 按列画线
            if i == 0 or i == self.board_cols:
                ax0.plot([i, i], [0, self.board_rows], color='black')
            else:
                ax0.plot([i, i], [0, self.board_rows], alpha=0.7,
                     color='grey', linestyle='dashed')

        for i in range(self.board_rows + 1):  # 按行画线
            if i == 0 or i == self.board_rows:
                ax0.plot([0, self.board_cols], [i, i], color='black')
            else:
                ax0.plot([0, self.board_cols], [i, i], alpha=0.7,
                         color='grey', linestyle='dashed')

        for i in range(self.board_rows):
            for j in range(self.board_cols):

                y = (self.board_rows - 1 - i)
                x = j

                if self.board[i, j] == -1:
                    rect = patches.Rectangle(
                        (x, y), 1, 1, edgecolor='none', facecolor='black', alpha=0.6)
                    ax0.add_patch(rect)
                elif self.board[i, j] == 1:
                    rect = patches.Rectangle(
                        (x, y), 1, 1, edgecolor='none', facecolor='red', alpha=0.6)
                    ax0.add_patch(rect)
                    ax0.text(x + 0.4, y + 0.5, "r = +10")

                else:
                    # qtable
                    s = self.change_axis_to_state((i, j))
                    qs = agent.qtable[s, :]
                    for a in range(len(qs)):
                        dx, dy = a_shift[a]
                        c = 'k'
                        q = qs[a]
                        if q > 0:
                            c = 'r'
                        elif q < 0:
                            c = 'g'
                        ax0.text(x + dx + 0.3, y + dy + 0.5,
                                 "{:.1f}".format(qs[a]), c=c)

        if sarsa_or_q == 0:
            ax0.set_title("Sarsa")
        else:
            ax0.set_title("Q-learning")
        if sarsa_or_q == 0:
            plt.savefig("Sarsa")
        else:
            plt.savefig("Q-Learning")
        plt.show(block=False)
        plt.pause(5)
        plt.close()

加上下面这一段,就可以使程序跑起来啦!

python">agent = Agent(4, 12, 4)
maxgen = 20000
gen = 1
sarsa_or_q = 0
while gen < maxgen:
    current_state = agent.initilize()
    while True:
        action, next_state, reward = agent.choose_action(current_state)
        done = agent.learn(current_state, reward, action, next_state,sarsa_or_q)
        current_state = next_state
        if done:
            break

    gen += 1

agent.show(sarsa_or_q)
print(agent.qtable)

设置sarsa_or_q分别为0和1可以查看采用不同方法计算得的结果示意图
根据Q值就可以得到最后的收敛策略
在这里插入图片描述
在这里插入图片描述

需要改进的地方

代码迭代的收敛太慢,笔者写的代码迭代了20000才收敛,这与课程中的100幕左右就收敛的结果是不一致的,算法的效率上还需要改进。值得补充的是,100幕左右收敛在迭代最大代数中并没有做到,所以在模拟仿真的时候,索性就选择了20000次,说不定提前就收敛了。
可以改进的地方:对模型进行建立,因为之前代码是无模型的,设立模型对策略进行引导会得到更好的结果,当然也有可能使问题陷入局部探索之中,这是继续深入学习需要讨论的。
与科研科研结合的地方:在研究方向上,如果要结合的话,需要学习多个个体在环境下同时学习时的处理方法
在这里插入图片描述

引用和写在最后

Cliff-Walking仿真的是Reinforcement Learning Course by David Silver中第五讲课中的例子
课程的地址给在这里
记录一下强化学习课程的学习暂时完结,完结撒花,哒哒!


http://www.niftyadmin.cn/n/23871.html

相关文章

论文投稿指南——中文核心期刊推荐(地球物理学)

【前言】 &#x1f680; 想发论文怎么办&#xff1f;手把手教你论文如何投稿&#xff01;那么&#xff0c;首先要搞懂投稿目标——论文期刊 &#x1f384; 在期刊论文的分布中&#xff0c;存在一种普遍现象&#xff1a;即对于某一特定的学科或专业来说&#xff0c;少数期刊所含…

[单调栈][st表]Max GEQ Sum Codeforces1691D

You are given an array aa of nn integers. You are asked to find out if the inequality max(ai,ai1,…,aj−1,aj)≥aiai1⋯aj−1ajmax(ai,ai1,…,aj−1,aj)≥aiai1⋯aj−1aj holds for all pairs of indices (i,j)(i,j), where 1≤i≤j≤n1≤i≤j≤n. Input Each test …

第21章 SQL RIGHT JOIN 关键字教程

RIGHT JOIN 关键字从右表&#xff08;table2&#xff09;return 所有的行&#xff0c;即使左表&#xff08;table1&#xff09;中没有匹配。如果左表中没有匹配&#xff0c;则结果为 NULL。 SQL RIGHT JOIN 语法 SELECT column_name(s)FROM table1RIGHT JOIN table2ON table1.…

Springboot扩展点之BeanDefinitionRegistryPostProcessor

前言通过这篇文章来大家分享一下&#xff0c;另外一个Springboot的扩展点BeanDefinitionRegistryPostProcessor&#xff0c;一般称这类扩展点为容器级后置处理器&#xff0c;另外一类是Bean级的后置处理器&#xff1b;容器级的后置处理器会在Spring容器初始化后、刷新前这个时间…

【算法刷题】哈希表题型及方法归纳

哈希表特点 常见的三种哈希结构&#xff1a; 1、数组&#xff1a;操作简单&#xff0c;方便快捷&#xff0c;但不适于进行一些更复杂的操作。 注&#xff1a;适用于用set或map的情景&#xff1a;&#xff08;1&#xff09;当数组大小受限&#xff1b;&#xff08;2&#xff0…

信息学奥赛一本通 1916:【01NOIP普及组】求先序排列 | 洛谷 P1030 [NOIP2001 普及组] 求先序排列

【题目链接】 ybt 1916&#xff1a;【01NOIP普及组】求先序排列 洛谷 P1030 [NOIP2001 普及组] 求先序排列 【题目考点】 1. 二叉树 【解题思路】 已知中序、后序遍历序列&#xff0c;构建二叉树&#xff0c;而后对该二叉树做先序遍历&#xff0c;得到先序遍历序列。 该题…

Keil MDK 配置详解与调试技术

工程配置介绍① 通用配置选项&#xff1b;② 操作系统选项&#xff1b;③ 勾选后可以减小镜像尺寸&#xff0c;加快运行速度&#xff1b;④ 浮点配置&#xff1b;⑤ 加载简要配置&#xff0c;分散加载情况需要配置&#xff1b;编译器输出选项&#xff1b;可执行…

rabbitmq+netcore6 【2】Work Queues:一个生产者两个消费者

文章目录1&#xff09;准备工作2&#xff09;新建消费者13&#xff09;新建消费者24&#xff09;生产者5&#xff09;知识点解读1、autoAck: true2、重复声明/前后不一致3、Message durability 消息持久化4、Fair Dispatch 公平调度5、综合以上知识点的代码&#xff1a;官网参考…