Getting Started with Reinforcement Learning Code | A Tic-Tac-Toe Walkthrough


This is an easy-to-follow demo of roughly 300 lines. It works well as introductory RL code and helps ground the basic formulas.

These are my own study notes. Like, favorite, and follow, then leave your email, and I'll send you the fully annotated code file~

If you spot any mistakes, please point them out! It's been a while since I last blogged; my previous RL post was almost a year and a half ago, and I'd been away from this area for a long time. I've just started my master's and, after some detours, have come back around to RL. Happy to chat~

Tic-tac-toe rules: essentially three-in-a-row. On a 3 * 3 board, the first player to line up three of their marks wins.

Code Overview

🌱 Class definitions: State, Judger, Player, and HumanPlayer represent the board state, the referee that runs a game, the AI player, and the human player respectively

🌱 State s: every board configuration is a state, identified by a unique hash. There are 3^9 possible configurations

🌱 Training process: first, two AI players play against each other to gradually refine the policy (the state-value function). Once trained, the AI plays against a human

🌱 Training the AI players: initially, winning terminal states get value 1, losing ones get 0, and everything else gets 0.5. The values are then corrected by backup updates, V(s) ← V(s) + α[V(s') − V(s)], until they gradually converge; α is the step size (a minimal sketch of this update follows below)
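To make the update concrete, here is a minimal runnable sketch of one backward pass over a finished game. The names (values, episode_hashes, alpha) are illustrative only; the walkthrough code below implements the same sweep inside Player.backup().

alpha = 0.1
values = {0: 0.5, 1: 0.5, 2: 1.0}   # toy value table; state 2 is a winning terminal state
episode_hashes = [0, 1, 2]          # hashes of the states visited in one game

# Sweep backwards so each state pulls its value toward its successor's:
# V(s) <- V(s) + alpha * [V(s') - V(s)]
for t in reversed(range(len(episode_hashes) - 1)):
    s, s_next = episode_hashes[t], episode_hashes[t + 1]
    values[s] += alpha * (values[s_next] - values[s])

print(values)   # the values of states 0 and 1 have moved toward the terminal value 1.0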

Code Walkthrough

The main function

Let's start with the main function:

python">if __name__ == '__main__':
    train(int(1e5))   # 1e5是浮点型。是epoch  # 2AI对战,完善value function
    compete(int(1e3))   # 2AI对战训练完后,再对战自测胜率 
    play()   # 人类和AI对战

In short: AI vs. AI first, then human vs. machine.

The State class

Import the packages and define the 3 * 3 board:

python">import numpy as np
import pickle

BOARD_ROWS = 3
BOARD_COLS = 3
BOARD_SIZE = BOARD_ROWS * BOARD_COLS

Next comes the State class. It contains the following methods:

🌻 __init__() initializes a state,

🌻 hash() computes a hash value that uniquely indexes each state,

🌻 is_end() checks whether the game is over,

🌻 next_state() places a player's symbol at a given position and returns the resulting state,

🌻 print() prints the current 3 * 3 board

Detailed comments are in the code:

python"># 每一个state是棋盘的整个状态,共3^9个状态
class State:
    def __init__(self):
        # symbol 1: the player who moves first
        # symbol -1: the player who moves second
        # symbol 0: an empty position
        self.data = np.zeros((BOARD_ROWS, BOARD_COLS))  # the board itself
        self.winner = None
        self.hash_val = None    # each state is identified by its hash value
        self.end = None

    # Compute each state's hash value (the encoding rule itself is arbitrary)
    def hash(self):
        if self.hash_val is None:
            self.hash_val = 0
            # map -1 to the digit 2 so the flattened board reads as a 9-digit base-3 number
            for i in self.data.reshape(BOARD_COLS * BOARD_ROWS):
                if i == -1:
                    i = 2
                self.hash_val = self.hash_val * 3 + i
        return int(self.hash_val)
    
    # Check whether the game has been won or lost, or is a tie
    def is_end(self):
        if self.end is not None:
            return self.end
        
        results = []
        # check rows
        for i in range(0, BOARD_ROWS):
            results.append(np.sum(self.data[i, :]))
        # check columns
        for i in range(0, BOARD_COLS):
            results.append(np.sum(self.data[:, i]))
        # check diagonals
        results.append(0)
        for i in range(0, BOARD_ROWS):
            results[-1] += self.data[i, i]
        results.append(0)
        for i in range(0, BOARD_COLS):
            results[-1] += self.data[i, BOARD_ROWS -1 - i]

        for result in results:
            if result == 3:
                self.end = True
                self.winner = 1
                return self.end
            if result == -3:
                self.end = True
                self.winner = -1
                return self.end
            
        # check tie
        sum_values = np.sum(np.abs(self.data))
        if sum_values == BOARD_COLS * BOARD_ROWS:
            self.end = True
            self.winner = 0  # tie
            return self.end
        
        # no win and no tie: the game continues
        self.end = False
        return self.end

    # Next state:
    # place the player's symbol at board position (i, j) and return the new state
    def next_state(self, i, j, symbol):
        new_state = State()
        new_state.data = np.copy(self.data)
        new_state.data[i, j] = symbol
        return new_state
    
    # Print the board
    def print(self):
        for i in range(0, BOARD_ROWS):
            print('-----------------')
            out = '| '
            for j in range(0, BOARD_COLS):
                if self.data[i, j] == 1:
                    token = '*'
                if self.data[i, j] == -1:
                    token = 'x'
                if self.data[i, j] == 0:
                    token = '0'
                out += token + ' | '
            print(out)
        print('-----------------')
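As a quick sanity check of hash(): the board is read row-major, and each cell becomes one digit of a base-3 number (0 for empty, 1 for the first player, 2 standing in for -1). A hypothetical usage snippet, assuming the State class above has been defined:

s = State()
s.data[0, 0] = 1             # digits (row-major): 1 0 0 0 0 0 0 0 0
print(s.hash())              # 1 * 3**8 = 6561

s2 = s.next_state(1, 1, -1)  # -1 is stored as the digit 2
print(s2.hash())             # 1 * 3**8 + 2 * 3**4 = 6723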
python"># 检索当前状态下,所有下一个可能动作带来的状态变换
def get_all_states_impl(current_state, current_symbol, all_states):
    for i in range(0, BOARD_ROWS):
        for j in range(0, BOARD_COLS):
            if current_state.data[i][j] == 0:   # scan every currently empty cell
                newState = current_state.next_state(i, j, current_symbol)
                newHash = newState.hash()
                if newHash not in all_states.keys():
                    isEnd = newState.is_end()
                    all_states[newHash] = (newState, isEnd)
                    if not isEnd:   # if this move doesn't end the game, the other player moves next
                        get_all_states_impl(newState, -current_symbol, all_states)

def get_all_states():
    current_symbol = 1
    current_state = State()
    all_states = dict()
    all_states[current_state.hash()] = (current_state, current_state.is_end())
    get_all_states_impl(current_state, current_symbol, all_states)
    return all_states

# the all_states dict: each key is a state's unique hash; each value is the tuple (state, isEnd)
all_states = get_all_states()
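Note that far from all 3^9 = 19683 encodings appear in the dict: the recursion only visits positions reachable by legal alternating play, and it stops expanding at terminal states. A quick check, as a hypothetical snippet (5478 is the commonly cited count of legal tic-tac-toe positions; treat it as an expectation, not a guarantee):

# Only positions reachable by legal alternating play are enumerated,
# far fewer than the 3**9 = 19683 raw base-3 encodings.
print(len(all_states))   # expected: 5478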

The Judger class

The Judger class is the referee: it simply has the two players move in turn. It contains:

🌻 __init__() initialization

🌻 reset() resets both players

🌻 alternate() alternates whose turn it is

🌻 play() runs one full game between the two players

Detailed comments are in the code:

python">class Judger:
    def __init__(self, player1, player2):
        self.p1 = player1
        self.p2 = player2
        self.current_player = None
        self.p1_symbol = 1
        self.p2_symbol = -1
        self.p1.set_symbol(self.p1_symbol)
        self.p2.set_symbol(self.p2_symbol)
    
    def reset(self):
        self.p1.reset()
        self.p2.reset()

    def alternate(self):
        while True:
            yield self.p1
            yield self.p2
    
    # play() has the two players take turns (here, two AI players);
    # act() picks the highest-value move for the current player
    # @print_state: if True, print the board after each move
    def play(self, print_state = False):
        alternator = self.alternate()
        self.reset()
        current_state = State()
        self.p1.set_state(current_state)
        self.p2.set_state(current_state)
        while True: # loop until the game ends and we return
            player = next(alternator)   # the two players take turns
            if print_state:
                current_state.print()
            [i, j, symbol] = player.act()    # pick the player's next move
            next_state_hash = current_state.next_state(i, j, symbol).hash()
            current_state, is_end = all_states[next_state_hash]
            self.p1.set_state(current_state)
            self.p2.set_state(current_state)
            if is_end:
                if print_state:
                    current_state.print()
                return current_state.winner
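The alternate() generator deserves a quick note: each next() call resumes execution at the following yield, so it hands back p1, p2, p1, p2, ... forever. A standalone toy version with hypothetical names:

def alternate(a, b):
    # infinite generator: yields the two arguments in turn
    while True:
        yield a
        yield b

turns = alternate('p1', 'p2')
print([next(turns) for _ in range(5)])   # ['p1', 'p2', 'p1', 'p2', 'p1']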

The Player (AI) class

Player is the AI player class. It contains:

🌻 __init__() initialization

🌻 reset() resets the episode history

🌻 set_state() records a state and whether the move at that state was greedy

🌻 set_symbol() initializes the state values

🌻 backup() updates the state values backwards: V(s) ← V(s) + α[V(s') − V(s)]

🌻 act() picks the best action in the current state

🌻 save_policy() saves the policy (i.e. the estimations, the state values)

🌻 load_policy() loads a saved policy

Detailed comments are in the code:

python"># AI player
# 关于value function(state的函数)解释:https://face2ai.com/RL-RSAB-1-5-An-Extended-Example/
class Player:
    # @step_size: the step size to update estimation
    # (好像就是value function),back up更新里的α,详见上方链接解释
    # @epsilon: the probability to explore
    def __init__(self, step_size = 0.1, epsilon = 0.1):
        self.estimations = dict()
        self.step_size = step_size
        self.epsilon = epsilon
        self.states = []
        self.greedy = []

    def reset(self):
        self.states = []
        self.greedy = []

    def set_state(self, state):
        self.states.append(state)
        self.greedy.append(True)    # assume the move is greedy (exploit) until act() marks it as exploratory

    # Assign a value to every state (the estimations dict: key = hash of a state, value = its value).
    # For terminal states: 1 if this player wins, 0.5 for a tie, 0 for a loss.
    # Every in-progress state gets 0.5.
    def set_symbol(self, symbol):
        self.symbol = symbol
        for hash_val in all_states.keys():
            (state, is_end) = all_states[hash_val]
            if is_end:
                if state.winner == self.symbol:
                    self.estimations[hash_val] = 1.0
                elif state.winner == 0:
                    self.estimations[hash_val] = 0.5
                else:
                    self.estimations[hash_val] = 0
            else:
                self.estimations[hash_val] = 0.5
    
    # Backward update of the state-value estimates
    def backup(self):
        states = [state.hash() for state in self.states]

        # backward sweep: V(s) ← V(s) + α[V(s') − V(s)]
        # see also: https://face2ai.com/RL-RSAB-1-5-An-Extended-Example/
        for i in reversed(range(len(states) - 1)):
            state = states[i]
            # greedy[i] is a bool: exploratory moves (False, i.e. 0) contribute no update
            td_error = self.greedy[i] * (self.estimations[states[i + 1]] - self.estimations[state])
            self.estimations[state] += self.step_size * td_error

    # In the current state, pick the best next action.
    # act() returns the move position plus the player's symbol: (i, j, symbol)
    def act(self):
        state = self.states[-1]
        next_states = []
        next_positions = []
        
        # find every empty cell in the current state
        for i in range(BOARD_ROWS):
            for j in range(BOARD_COLS):
                if state.data[i, j] == 0:
                    next_positions.append([i,j])
                    next_states.append(state.next_state(i, j, self.symbol).hash())
        
        # explore
        if np.random.rand() < self.epsilon:    # uniform random number in (0, 1)
            action = next_positions[np.random.randint(len(next_positions))]
            action.append(self.symbol)
            self.greedy[-1] = False
            return action

        # 否则exploit
        values = []
        for hash, pos in zip(next_states, next_positions):
            values.append((self.estimations[hash], pos))
        np.random.shuffle(values)   # shuffle first so that ties are broken randomly by the stable sort
        values.sort(key = lambda x:x[0], reverse = True)    # sort by state value, descending
        action = values[0][1]    # take the position of the highest-value next state
        action.append(self.symbol)  # append the player's symbol to the move
        return action
    
    def save_policy(self):
        # the policy is just the estimations dict, pickled to a binary .bin file
        with open('policy_%s.bin' % ('first' if self.symbol == 1 else 'second'), 'wb') as f:
            pickle.dump(self.estimations, f)    # serialize the estimations
        
    def load_policy(self):
        with open('policy_%s.bin' % ('first' if self.symbol == 1 else 'second'), 'rb') as f:
            self.estimations = pickle.load(f)
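After training, the two policies live in policy_first.bin and policy_second.bin. A minimal, hypothetical snippet for reloading the first player's trained values outside of compete()/play():

p = Player(epsilon = 0)   # epsilon = 0: always exploit, never explore
p.set_symbol(1)           # symbol 1, so the file name resolves to 'policy_first.bin'
p.load_policy()           # replaces the default 0.5 estimates with the trained ones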

The HumanPlayer class

HumanPlayer is the human player class. It contains:

🌻 __init__() initialization

🌻 set_state() records the current state

🌻 set_symbol() sets the player's symbol

🌻 act() lets the human enter a move from the keyboard

Detailed comments are in the code:

python"># human interface
# input a number to put a chessman
# | q | w | e |
# | a | s | d |
# | z | x | c |
class HumanPlayer:
    def __init__(self, **kwargs):
        self.symbol = None
        self.keys = ['q', 'w', 'e', 'a', 's', 'd', 'z', 'x', 'c']
        self.state = None
        return

    def reset(self):
        return
    
    def set_state(self, state):
        self.state = state

    def set_symbol(self, symbol):
        self.symbol = symbol
        return
    
    def backup(self):
        return

    def act(self):
        self.state.print()
        key = input("Input your position:")  # show the prompt and read the user's input into key
        # the nine keys map onto the left-hand 3x3 block of the keyboard
        data = self.keys.index(key)  # index of the pressed key
        i = data // BOARD_COLS
        j = data % BOARD_COLS
        return (i, j, self.symbol)
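For example, pressing 's' (index 4 in the key list) lands on the center cell: 4 // 3 = 1 and 4 % 3 = 1. A quick standalone check of the index arithmetic:

keys = ['q', 'w', 'e', 'a', 's', 'd', 'z', 'x', 'c']
data = keys.index('s')        # 4
print(data // 3, data % 3)    # 1 1, i.e. row 1, column 1: the center of the board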

The train function

train() just has two AI players play each other, steadily refining the value function until it converges. This is the training phase for the AI players.

python"># 2个AI间训练
# 两个AI player打,不断完善value function(即estimations)
# 训练结束后, Epoch 10000, player 1 win 0.08, player 2 win 0.03
def train(epochs):
    player1 = Player(epsilon = 0.01)
    player2 = Player(epsilon = 0.01)
    judger = Judger(player1, player2)
    player1_win = 0.0
    player2_win = 0.0
    for i in range(1, epochs + 1):
        winner = judger.play(print_state = False)   # play one game
        if winner == 1:
            player1_win += 1
        if winner == -1:
            player2_win += 1
        print('Epoch %d, player1 win %.02f, player2 win %.02f' % (i, player1_win / i, player2_win / i))
        player1.backup()
        player2.backup()
        judger.reset()
    player1.save_policy()
    player2.save_policy()

The compete function

This is effectively the test phase after train(). The game is simple enough that, once trained, the two AIs always draw.

python"># 2个AI间测试
# 训练结束后的测试,AI之间就没有输赢了,全是平局:
# 1000 turns, player 1 win 0.00, player 2 win 0.00
# 计算下turns次棋,两个AI棋手的分别胜率
def compete(turns):
    player1 = Player(epsilon = 0)
    player2 = Player(epsilon = 0)
    judger = Judger(player1, player2)
    player1.load_policy()
    player2.load_policy()
    player1_win = 0.0
    player2_win = 0.0
    for i in range(0, turns):
        winner = judger.play()
        if winner == 1:
            player1_win += 1
        if winner == -1:
            player2_win += 1
        judger.reset()
    print('%d turns, player1 win %.02f, player2 win %.02f' % (turns, player1_win / turns, player2_win / turns))

The play function

A human plays against the AI.

python"># 人类和AI下棋
def play():
    while True:
        player1 = HumanPlayer()
        player2 = Player(epsilon = 0)
        judger = Judger(player1, player2)
        player2.load_policy()   # the AI uses the policy (i.e. the value function) saved during AI-vs-AI training; since epsilon = 0 here, it greedily picks the action leading to the highest-value state
        winner = judger.play()
        if winner == player2.symbol:
            print("You lose!")
        elif winner == player1.symbol:
            print("You win!")
        else:
            print("It is a tie!")

