【强化学习】Q-Learning 案例分析

news/2024/5/18 21:44:08 标签: 强化学习, q-learning, 案例, 路径寻优

前期知识可查看:

案例介绍

寻路案例(强烈建议学习上述前期知识里的【强化学习】 Q-Learning 尤其是看懂前面的小案例

  • 红色为可移动的寻路个体
  • 黑色为惩罚位置【奖励= -1】
  • 黄色为目标位置【奖励= +1】
  • 其他区域为常规状态【奖励= 0】

寻路个体其实位置如图中所示的左上角,目标是移动到黄色位置,采用Q-Learning算法,能够让个体自主探索,最后找到最好的可以从起始点到终点位置的路径,同时绕过黑色区域

程序

案例的程序分为三个部分:

  • maze_env.py : 该案例环境部分,即:该图片以及这些颜色块的搭建,采用了Tkinter
  • RL_brain.py:该案例的Q-Learning的大脑,智能体的大脑部分,所有决策都在这部分
  • run_this.py:该案例的主要实施流程以及更新

1. maze_env.py

环境部分采用了Tkinter搭建这些图片以及这些颜色块,有兴趣的可以仔细分析代码,暂时不细说

import numpy as np
import time
import sys

if sys.version_info.major == 2:
    import Tkinter as tk
else:
    import tkinter as tk

UNIT = 40  # pixels
MAZE_H = 4  # grid height
MAZE_W = 4  # grid width


class Maze(tk.Tk, object):
    def __init__(self):
        super(Maze, self).__init__()
        self.action_space = ['u', 'd', 'l', 'r']
        self.n_actions = len(self.action_space)
        self.title('maze')
        self.geometry('{0}x{1}'.format(MAZE_H * UNIT, MAZE_H * UNIT))
        self._build_maze()

    def _build_maze(self):
        self.canvas = tk.Canvas(self, bg='white',
                                height=MAZE_H * UNIT,
                                width=MAZE_W * UNIT)

        # create grids
        for c in range(0, MAZE_W * UNIT, UNIT):
            x0, y0, x1, y1 = c, 0, c, MAZE_H * UNIT
            self.canvas.create_line(x0, y0, x1, y1)
        for r in range(0, MAZE_H * UNIT, UNIT):
            x0, y0, x1, y1 = 0, r, MAZE_W * UNIT, r
            self.canvas.create_line(x0, y0, x1, y1)

        # create origin
        origin = np.array([20, 20])

        # hell
        hell1_center = origin + np.array([UNIT * 2, UNIT])
        self.hell1 = self.canvas.create_rectangle(
            hell1_center[0] - 15, hell1_center[1] - 15,
            hell1_center[0] + 15, hell1_center[1] + 15,
            fill='black')
        # hell
        hell2_center = origin + np.array([UNIT, UNIT * 2])
        self.hell2 = self.canvas.create_rectangle(
            hell2_center[0] - 15, hell2_center[1] - 15,
            hell2_center[0] + 15, hell2_center[1] + 15,
            fill='black')

        # create oval
        oval_center = origin + UNIT * 2
        self.oval = self.canvas.create_oval(
            oval_center[0] - 15, oval_center[1] - 15,
            oval_center[0] + 15, oval_center[1] + 15,
            fill='yellow')

        # create red rect
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')

        # pack all
        self.canvas.pack()

    def reset(self):
        self.update()
        time.sleep(0.5)
        self.canvas.delete(self.rect)
        origin = np.array([20, 20])
        self.rect = self.canvas.create_rectangle(
            origin[0] - 15, origin[1] - 15,
            origin[0] + 15, origin[1] + 15,
            fill='red')
        # return observation
        return self.canvas.coords(self.rect)

    def step(self, action):
        s = self.canvas.coords(self.rect)
        base_action = np.array([0, 0])
        if action == 0:  # up
            if s[1] > UNIT:
                base_action[1] -= UNIT
        elif action == 1:  # down
            if s[1] < (MAZE_H - 1) * UNIT:
                base_action[1] += UNIT
        elif action == 2:  # right
            if s[0] < (MAZE_W - 1) * UNIT:
                base_action[0] += UNIT
        elif action == 3:  # left
            if s[0] > UNIT:
                base_action[0] -= UNIT

        self.canvas.move(self.rect, base_action[0], base_action[1])  # move agent

        s_ = self.canvas.coords(self.rect)  # next state

        # reward function
        if s_ == self.canvas.coords(self.oval):
            reward = 1
            done = True
            s_ = 'terminal'
        elif s_ in [self.canvas.coords(self.hell1), self.canvas.coords(self.hell2)]:
            reward = -1
            done = True
            s_ = 'terminal'
        else:
            reward = 0
            done = False

        return s_, reward, done

    def render(self):
        time.sleep(0.1)
        self.update()


def update():
    for t in range(10):
        s = env.reset()
        while True:
            env.render()
            a = 1
            s, r, done = env.step(a)
            if done:
                break


if __name__ == '__main__':
    env = Maze()
    env.after(100, update)
    env.mainloop()

2. RL_brain.py

该部分为Q-Learning的大脑部分,所有的巨册函数都在这儿

(1)参数初始化,包括算法用到的所有参数:行为、学习率、衰减率、决策率、以及q-table

(2)方法1:选择动作:随机数与决策率做对比,决策率为0.9,90%情况选择下一个反馈最大的奖励的行为,10%情况选择随机行为

(3)方法2:学习更新q-table:通过数据参数(该状态、该行为、该行为对该状态的奖励、下一个状态),计算该行为在该状态下的真实值与估计值,然后更新q-table里的预估值

(4)方法3:用来将新的状态作为索引添加在q-table里

import numpy as np
import pandas as pd


class QLearningTable:
    def __init__(self, actions, learning_rate=0.01, reward_decay=0.9, e_greedy=0.9):
        self.actions = actions  # a list
        self.lr = learning_rate  # 学习率
        self.gamma = reward_decay   # 衰减率
        self.epsilon = e_greedy     # 决策率
        self.q_table = pd.DataFrame(columns=self.actions, dtype=np.float64) # 初始化q-table

    # 选择动作
    def choose_action(self, observation):
        self.check_state_exist(observation)     # 检查该状态是否在q-table中存在,如不存在则添加
        # 行为选择
        if np.random.uniform() < self.epsilon:
            # 如果随机数小于0.9 则选择最优行为
            state_action = self.q_table.loc[observation, :]
            # 有一些行为可能存在相同的最大预期值,则在最大预期值行为里随机选择
            action = np.random.choice(state_action[state_action == np.max(state_action)].index)
        else:
            # 如果随机数大于0.9 则随机选择
            action = np.random.choice(self.actions)
        return action

    def learn(self, s, a, r, s_):
        self.check_state_exist(s_) # 检擦新的状态是否存在与q-table
        q_predict = self.q_table.loc[s, a] # 获取该行为在该状态下的估计的奖励预期值
        if s_ != 'terminal':    # 如果新的状态不是最终目的地
            q_target = r + self.gamma * self.q_table.loc[s_, :].max()  # 真实值=该行为对该状态的奖励+衰减率*下一个状态下行为的最大反馈奖励
        else: # 如果到达了最终目的地,没有下一个行为,因此不需要学习了,真实值=该行为对该状态的奖励
            q_target = r  # next state is terminal
        self.q_table.loc[s, a] += self.lr * (q_target - q_predict)  # 更新q-table

    # 检查状态是否存在,若不存在将作为索引添加在 q-table中,行为的值初始化为0
    def check_state_exist(self, state):
        if state not in self.q_table.index:
            # 如果该状态不在q-table的索引里存在,则将该状态添加到q-table中,
            # q-table是dataframe类型,字典的索引为状态,值的表头有四种【0,1,2,3】,分别代表前、后、左、右的行为
            self.q_table = self.q_table.append(
                pd.Series(
                    [0]*len(self.actions),
                    index=self.q_table.columns,
                    name=state,
                )
            )

3. run_this.py

该脚本就是算法实施的主要流程:

q4.png

(1)引入了环境(maze_env)和大脑(RL_brain)

(2)环境reset()给出了初始状态(Initialize s)

(3)循环开始(Repeat)

  • 刷新环境
  • 挑选动作(Choose a from s)
  • 从环境中获取该动作对该状态的反馈(下一个状态、该行为对该状态的奖励、是否到达终点)(observe r, s')
  • 开始学习,输入参数(该状态、该行为、该行为对该状态的奖励、下一个状态)——对比估计值和现实值,学习完之后将更新q-table(Q(s,a))
  • 更新个体状态到新的状态(s<--s')
  • 判断是否到达终点:跳出循环(until s is terminal)
from maze_env import Maze           # 环境模块
from RL_brain import QLearningTable # 大脑


def update(): # (Repeat)
    for episode in range(100):  # 100个回合

        observation = env.reset()   # 初始化观察值 (Initialize s)

        while True:
            env.render() # 刷新环境
            action = RL.choose_action(str(observation)) # 挑选动作 (Choose a from s)
            observation_, reward, done = env.step(action) # 在环境里施加动作,获取下一个状态、行为对于该状态的奖励、是否完成 (observe r, s')
            # 一个回合之后的学习,将该状态,该行为,该行为对于该状态的奖励,以及下一个状态   输入到学习的方法中 
            RL.learn(str(observation), action, reward, str(observation_)) # Q(s,a))
            observation = observation_ # 更新个体状态 (s<--s')
            if done:
                break # 如果反馈达到目的地,跳出循环 (until s is terminal)

    # end of game
    print('game over')
    env.destroy()


if __name__ == "__main__":
    env = Maze()
    RL = QLearningTable(actions=list(range(env.n_actions)))

    env.after(100, update)
    env.mainloop()

完成之后,在run_this.py里运行就可以看到学习探索路径的过程了

 

代码以及学习过程来源:莫烦Python教学(十分感谢莫烦大佬的教学视频)


http://www.niftyadmin.cn/n/1193388.html

相关文章

娱乐之神兽羊驼(原创)

重要的事情说三遍 : 可以转载&#xff01;可以转载&#xff01;可以转载&#xff01; 我看网上很难找到羊驼的代码&#xff0c;所以决定自己写一个玩&#xff0c;代码没什么科技含量&#xff01; 仅供娱乐哦&#xff01; 代码如下: 1 #include "stdafx.h"2 #include&…

leetcode:linked_list_cycle_II

一、 题目 给定一个链表&#xff0c;假设链表中有环则返回环的開始节点&#xff0c;否则返回NULL。要求不用额外的空间完毕。 二、 分析 在I中&#xff0c;我们推断环的存在&#xff0c;即用slow和fast两个指针&#xff0c;设定步长fast2;slow1;假设两个指针能够相遇则…

【强化学习】Sarsa

【强化学习】相关基本概念【强化学习】 Q-Learning【强化学习】 Q-Learning 案例分析【强化学习】 Sarsa【强化学习】 Sarsa&#xff08;lambda&#xff09;Sarsa概述 首先可以回顾一下之前说的Q-Learning算法&#xff0c;Sarsa算法与Q-Learning算法很相似&#xff0c; Q-Lear…

Activity与Fragment之间的通信

由于Fragment的生命周期完全依赖宿主Activity&#xff0c;所以当我们在使用Fragment时难免出现Activity和Fragment间的传值通信操作。 1、Activity向Fragment&#xff0c;通过声明的Fragment对象的setArguments(bundle)方法来实现Activity到Fragment的传递 1 Override2 …

HTML标题

HTML 标题 在 HTML 文档中&#xff0c;标题很重要。 HTML 标题 标题&#xff08;Heading&#xff09;是通过 <h1> - <h6> 标签进行定义的. <h1> 定义最大的标题。 <h6> 定义最小的标题。 实例 <!DOCTYPE html> <html> <head> <…

【强化学习】Sarsa(lambda)

【强化学习】相关基本概念【强化学习】 Q-Learning【强化学习】 Q-Learning 案例分析【强化学习】 Sarsa【强化学习】 Sarsa&#xff08;lambda&#xff09;Sarsa(λ) 1. Sarsa(λ) 是基于Sarsa算法的一种提速算法&#xff0c;为什么是提速呢&#xff1f; Sarsa算法&#xff…

我所经历的Android面试|掘金技术征文

概述时隔一个多月&#xff0c;我又回来了。这段时间有不少人问我最近在干嘛&#xff0c;面经什么时候写&#xff0c;怎么这么久没更文了等等等等。当然了&#xff0c;最近我一直在执行了一次我计划了近半年的跳槽。总得而言还不错。说下我的情况。我是去年九月份开始计划&#…

windows 系统上安装 tensorflow

首先保证已经安装有Anaconda和Python 1. 打开Anaconda prompt 前提环境 2. 在Anaconda Promt依次输入&#xff1a; python 版本查看可以先输入python --version conda config --add channels https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/conda config --set …