Reinforcement Learning with DQN: Tetris


Contents

  • Reinforcement Learning with DQN: Tetris
    • Algorithm Flow
    • File Structure
    • Model Architecture
    • Game Environment
    • Training Code
    • Test Code
    • Results

Reinforcement Learning with DQN: Tetris

Algorithm Flow

The goal of this project is to train an agent that plays Tetris using deep reinforcement learning (DQN). Concretely, the training code works through the following steps:

  • Set the random seeds first so that later training runs are reproducible.
  • Create a Tetris environment instance; this environment simulates the game and handles the interaction between the agent and the game.
  • Create a DeepQNetwork model, a deep-learning-based value network used to score candidate placements and predict the best next action.
  • Create an optimizer and a loss function (criterion) for training the model.
  • In each training epoch, enumerate all reachable next states (next_steps) from the current state, choose an action according to an exploration-or-exploitation policy (epsilon-greedy), and obtain the reward for that action and whether the episode terminates (done).
  • Append the current state, reward, next state, and done flag to the replay memory.
  • If the episode has terminated, reset the environment and record the final score (final_score), the number of tetrominoes placed (final_tetrominoes), and the number of cleared lines (final_cleared_lines).
  • Randomly sample a batch from the replay memory and use it to train the model. Specifically, unpack the state batch (state_batch), reward batch (reward_batch), next-state batch (next_state_batch), and done batch (done_batch), convert them to tensors, compute the target value y_batch for each sample, use it to compute the loss, backpropagate the gradients, and update the model parameters with the optimizer (see the sketch after this list).
  • Print information about the current epoch and log the score, tetromino count, and cleared lines to TensorBoard.
  • Whenever the current epoch is a multiple of the save interval, save the model to disk.
  • Repeat the above steps until the specified number of epochs is reached.
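
The two core computations in this loop are the epsilon-greedy action choice and the one-step Bellman target. The snippet below is a minimal sketch of both, not code from the repository: the helpers select_action and bellman_target are illustrative, and the names model, next_states, epsilon, and gamma simply mirror the variables used in the training script later in this post.

import torch
from random import random, randint

# Sketch of epsilon-greedy selection over the candidate next states.
# next_states: FloatTensor of shape (num_candidates, 4), one row per placement.
def select_action(model, next_states, epsilon):
    if random() <= epsilon:                          # explore
        return randint(0, next_states.shape[0] - 1)
    with torch.no_grad():
        values = model(next_states)[:, 0]            # exploit: highest predicted value
    return torch.argmax(values).item()

# Sketch of the one-step Bellman target used as the regression label:
# y = r for terminal transitions, otherwise y = r + gamma * V(s').
def bellman_target(reward, done, next_value, gamma=0.99):
    return reward if done else reward + gamma * next_value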

File Structure

├── output.mp4
├── src
│   ├── deep_q_network.py                                    model architecture
│   └── tetris.py                                            game environment
├── tensorboard
│   └── events.out.tfevents.1676879249.aifs3-worker-2
├── test.py                                                  test script
├── trained_models                                           saved model checkpoints
│   ├── tetris
│   ├── tetris_1000
│   ├── tetris_1500
│   ├── tetris_2000
│   └── tetris_500
└── train.py                                                 training script

Model Architecture

import torch.nn as nn


class DeepQNetwork(nn.Module):
    """Fully connected value network: maps the 4 state features
    (lines cleared, holes, bumpiness, height) to a single scalar value.
    The layers are named conv1/conv2/conv3 even though they are Linear
    layers; the original naming is kept unchanged here."""

    def __init__(self):
        super(DeepQNetwork, self).__init__()

        self.conv1 = nn.Sequential(nn.Linear(4, 64), nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Linear(64, 64), nn.ReLU(inplace=True))
        self.conv3 = nn.Sequential(nn.Linear(64, 1))

        self._create_weights()

    def _create_weights(self):
        # Xavier-initialize every linear layer and zero its bias.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        return x
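
As a quick illustration (a hypothetical snippet, not part of the repository), the network consumes the four state features produced by the game environment below and returns one scalar value per candidate state:

import torch

# Hypothetical usage: score two candidate placements, each described by
# [lines_cleared, holes, bumpiness, height], and pick the more promising one.
model = DeepQNetwork()
candidates = torch.FloatTensor([
    [0.0, 2.0, 5.0, 12.0],   # candidate placement A
    [1.0, 0.0, 3.0, 10.0],   # candidate placement B
])
with torch.no_grad():
    values = model(candidates)[:, 0]     # shape: (2,)
best = torch.argmax(values).item()       # index of the higher-valued placement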

Game Environment

import numpy as np
from PIL import Image
import cv2
from matplotlib import style
import torch
import random

style.use("ggplot")


class Tetris:
    piece_colors = [
        (0, 0, 0),
        (255, 255, 0),
        (147, 88, 254),
        (54, 175, 144),
        (255, 0, 0),
        (102, 217, 238),
        (254, 151, 32),
        (0, 0, 255)
    ]

    pieces = [
        [[1, 1],
         [1, 1]],

        [[0, 2, 0],
         [2, 2, 2]],

        [[0, 3, 3],
         [3, 3, 0]],

        [[4, 4, 0],
         [0, 4, 4]],

        [[5, 5, 5, 5]],

        [[0, 0, 6],
         [6, 6, 6]],

        [[7, 0, 0],
         [7, 7, 7]]
    ]

    def __init__(self, height=20, width=10, block_size=20):
        self.height = height
        self.width = width
        self.block_size = block_size
        self.extra_board = np.ones((self.height * self.block_size, self.width * int(self.block_size / 2), 3),
                                   dtype=np.uint8) * np.array([204, 204, 255], dtype=np.uint8)
        self.text_color = (200, 20, 220)
        self.reset()
    #---------------------------------------------------------------------------------------
    #  Reset the game and return the initial state features
    #---------------------------------------------------------------------------------------
    def reset(self):
        self.board = [[0] * self.width for _ in range(self.height)]
        self.score = 0
        self.tetrominoes = 0
        self.cleared_lines = 0
        self.bag = list(range(len(self.pieces)))
        random.shuffle(self.bag)
        self.ind = self.bag.pop()
        self.piece = [row[:] for row in self.pieces[self.ind]]
        self.current_pos = {"x": self.width // 2 - len(self.piece[0]) // 2, "y": 0}
        self.gameover = False
        return self.get_state_properties(self.board)
    #---------------------------------------------------------------------------------------
    #  Rotate a piece 90 degrees
    #---------------------------------------------------------------------------------------
    def rotate(self, piece):
        num_rows_orig = num_cols_new = len(piece)
        num_rows_new = len(piece[0])
        rotated_array = []

        for i in range(num_rows_new):
            new_row = [0] * num_cols_new
            for j in range(num_cols_new):
                new_row[j] = piece[(num_rows_orig - 1) - j][i]
            rotated_array.append(new_row)
        return rotated_array
    #---------------------------------------------------------------------------------------
    #  Extract the state features of a board
    #---------------------------------------------------------------------------------------
    def get_state_properties(self, board):
        lines_cleared, board = self.check_cleared_rows(board)
        holes = self.get_holes(board)
        bumpiness, height = self.get_bumpiness_and_height(board)

        return torch.FloatTensor([lines_cleared, holes, bumpiness, height])
    #---------------------------------------------------------------------------------------
    #  Count the holes in the board
    #---------------------------------------------------------------------------------------
    def get_holes(self, board):
        num_holes = 0
        for col in zip(*board):
            row = 0
            while row < self.height and col[row] == 0:
                row += 1
            num_holes += len([x for x in col[row + 1:] if x == 0])
        return num_holes
    #---------------------------------------------------------------------------------------
    #  Compute the bumpiness and total height of the board
    #---------------------------------------------------------------------------------------
    def get_bumpiness_and_height(self, board):
        board = np.array(board)
        mask = board != 0
        invert_heights = np.where(mask.any(axis=0), np.argmax(mask, axis=0), self.height)
        heights = self.height - invert_heights
        total_height = np.sum(heights)
        currs = heights[:-1]
        nexts = heights[1:]
        diffs = np.abs(currs - nexts)
        total_bumpiness = np.sum(diffs)
        return total_bumpiness, total_height
    #---------------------------------------------------------------------------------------
    #  Enumerate all possible next states (placements) for the current piece
    #---------------------------------------------------------------------------------------
    def get_next_states(self):
        states = {}
        piece_id = self.ind
        curr_piece = [row[:] for row in self.piece]
        if piece_id == 0:  # O piece
            num_rotations = 1
        elif piece_id == 2 or piece_id == 3 or piece_id == 4:
            num_rotations = 2
        else:
            num_rotations = 4

        for i in range(num_rotations):
            valid_xs = self.width - len(curr_piece[0])
            for x in range(valid_xs + 1):
                piece = [row[:] for row in curr_piece]
                pos = {"x": x, "y": 0}
                while not self.check_collision(piece, pos):
                    pos["y"] += 1
                self.truncate(piece, pos)
                board = self.store(piece, pos)
                states[(x, i)] = self.get_state_properties(board)
            curr_piece = self.rotate(curr_piece)
        return states
    #---------------------------------------------------------------------------------------
    #  Get the current board state with the falling piece drawn in
    #---------------------------------------------------------------------------------------
    def get_current_board_state(self):
        board = [x[:] for x in self.board]
        for y in range(len(self.piece)):
            for x in range(len(self.piece[y])):
                board[y + self.current_pos["y"]][x + self.current_pos["x"]] = self.piece[y][x]
        return board
    #---------------------------------------------------------------------------------------
    #  Spawn a new piece
    #---------------------------------------------------------------------------------------
    def new_piece(self):
        if not len(self.bag):
            self.bag = list(range(len(self.pieces)))
            random.shuffle(self.bag)
        self.ind = self.bag.pop()
        self.piece = [row[:] for row in self.pieces[self.ind]]
        self.current_pos = {"x": self.width // 2 - len(self.piece[0]) // 2,
                            "y": 0
                            }
        if self.check_collision(self.piece, self.current_pos):
            self.gameover = True
    #---------------------------------------------------------------------------------------
    #  Collision check. Inputs: piece shape and position
    #---------------------------------------------------------------------------------------
    def check_collision(self, piece, pos):
        future_y = pos["y"] + 1
        for y in range(len(piece)):
            for x in range(len(piece[y])):
                if future_y + y > self.height - 1 or self.board[future_y + y][pos["x"] + x] and piece[y][x]:
                    return True
        return False

    def truncate(self, piece, pos):
        gameover = False
        last_collision_row = -1
        for y in range(len(piece)):
            for x in range(len(piece[y])):
                if self.board[pos["y"] + y][pos["x"] + x] and piece[y][x]:
                    if y > last_collision_row:
                        last_collision_row = y

        if pos["y"] - (len(piece) - last_collision_row) < 0 and last_collision_row > -1:
            while last_collision_row >= 0 and len(piece) > 1:
                gameover = True
                last_collision_row = -1
                del piece[0]
                for y in range(len(piece)):
                    for x in range(len(piece[y])):
                        if self.board[pos["y"] + y][pos["x"] + x] and piece[y][x] and y > last_collision_row:
                            last_collision_row = y
        return gameover

    def store(self, piece, pos):
        board = [x[:] for x in self.board]
        for y in range(len(piece)):
            for x in range(len(piece[y])):
                if piece[y][x] and not board[y + pos["y"]][x + pos["x"]]:
                    board[y + pos["y"]][x + pos["x"]] = piece[y][x]
        return board

    def check_cleared_rows(self, board):
        to_delete = []
        for i, row in enumerate(board[::-1]):
            if 0 not in row:
                to_delete.append(len(board) - 1 - i)
        if len(to_delete) > 0:
            board = self.remove_row(board, to_delete)
        return len(to_delete), board

    def remove_row(self, board, indices):
        for i in indices[::-1]:
            del board[i]
            board = [[0 for _ in range(self.width)]] + board
        return board

    def step(self, action, render=True, video=None):
        x, num_rotations = action
        self.current_pos = {"x": x, "y": 0}
        for _ in range(num_rotations):
            self.piece = self.rotate(self.piece)

        while not self.check_collision(self.piece, self.current_pos):
            self.current_pos["y"] += 1
            if render:
                self.render(video)

        overflow = self.truncate(self.piece, self.current_pos)
        if overflow:
            self.gameover = True

        self.board = self.store(self.piece, self.current_pos)

        lines_cleared, self.board = self.check_cleared_rows(self.board)
        score = 1 + (lines_cleared ** 2) * self.width
        self.score += score
        self.tetrominoes += 1
        self.cleared_lines += lines_cleared
        if not self.gameover:
            self.new_piece()
        if self.gameover:
            self.score -= 2

        return score, self.gameover

    def render(self, video=None):
        if not self.gameover:
            img = [self.piece_colors[p] for row in self.get_current_board_state() for p in row]
        else:
            img = [self.piece_colors[p] for row in self.board for p in row]
        img = np.array(img).reshape((self.height, self.width, 3)).astype(np.uint8)
        img = img[..., ::-1]
        img = Image.fromarray(img, "RGB")

        img = img.resize((self.width * self.block_size, self.height * self.block_size))
        img = np.array(img)
        img[[i * self.block_size for i in range(self.height)], :, :] = 0
        img[:, [i * self.block_size for i in range(self.width)], :] = 0

        img = np.concatenate((img, self.extra_board), axis=1)


        cv2.putText(img, "Score:", (self.width * self.block_size + int(self.block_size / 2), self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)
        cv2.putText(img, str(self.score),
                    (self.width * self.block_size + int(self.block_size / 2), 2 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)

        cv2.putText(img, "Pieces:", (self.width * self.block_size + int(self.block_size / 2), 4 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)
        cv2.putText(img, str(self.tetrominoes),
                    (self.width * self.block_size + int(self.block_size / 2), 5 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)

        cv2.putText(img, "Lines:", (self.width * self.block_size + int(self.block_size / 2), 7 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)
        cv2.putText(img, str(self.cleared_lines),
                    (self.width * self.block_size + int(self.block_size / 2), 8 * self.block_size),
                    fontFace=cv2.FONT_HERSHEY_DUPLEX, fontScale=1.0, color=self.text_color)

        if video:
            video.write(img)

        cv2.imshow("Deep Q-Learning Tetris", img)
        cv2.waitKey(1)
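
The value network never sees the raw board; it only sees the four hand-crafted features returned by get_state_properties. The following toy example (a hypothetical illustration, not part of the repository) shows how column heights, bumpiness, and holes are derived from a small board using the same NumPy operations as in get_bumpiness_and_height and get_holes above:

import numpy as np

# Toy 4x3 board (0 = empty, non-zero = filled). The real environment uses a
# 20x10 board, but the feature computation is identical.
board = np.array([
    [0, 0, 0],
    [0, 2, 0],
    [3, 2, 0],
    [3, 0, 4],   # the 0 below the 2s in the middle column is a hole
])

mask = board != 0
rows = board.shape[0]
invert_heights = np.where(mask.any(axis=0), np.argmax(mask, axis=0), rows)
heights = rows - invert_heights                    # column heights: [2, 3, 1]
total_height = heights.sum()                       # 6
bumpiness = np.abs(np.diff(heights)).sum()         # |2-3| + |3-1| = 3
holes = sum(int((col[np.argmax(col != 0):] == 0).sum())
            for col in board.T if col.any())       # 1 covered empty cell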

Training Code

import argparse
import os
import shutil
from random import random, randint, sample

import numpy as np
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
import time
from src.deep_q_network import DeepQNetwork
from src.tetris import Tetris
from collections import deque


def get_args():
    parser = argparse.ArgumentParser(
        """Implementation of Deep Q Network to play Tetris""")
    parser.add_argument("--width", type=int, default=10, help="The common width for all images")
    parser.add_argument("--height", type=int, default=20, help="The common height for all images")
    parser.add_argument("--block_size", type=int, default=30, help="Size of a block")
    parser.add_argument("--batch_size", type=int, default=512, help="The number of images per batch")
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--initial_epsilon", type=float, default=1)
    parser.add_argument("--final_epsilon", type=float, default=1e-3)
    parser.add_argument("--num_decay_epochs", type=float, default=2000)
    parser.add_argument("--num_epochs", type=int, default=3000)
    parser.add_argument("--save_interval", type=int, default=500)
    parser.add_argument("--replay_memory_size", type=int, default=30000,
                        help="Maximum number of transitions kept in the replay memory")
    parser.add_argument("--log_path", type=str, default="tensorboard")
    parser.add_argument("--saved_path", type=str, default="trained_models")

    args = parser.parse_args()
    return args


def train(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    if os.path.isdir(opt.log_path):
        shutil.rmtree(opt.log_path)
    os.makedirs(opt.log_path)
    writer = SummaryWriter(opt.log_path)
    env = Tetris(width=opt.width, height=opt.height, block_size=opt.block_size)
    model = DeepQNetwork()
    optimizer = torch.optim.Adam(model.parameters(), lr=opt.lr)
    criterion = nn.MSELoss()

    state = env.reset()
    if torch.cuda.is_available():
        model.cuda()
        state = state.cuda()

    replay_memory = deque(maxlen=opt.replay_memory_size)
    epoch = 0
    t1 = time.time()
    total_time = 0
    best_score = 1000
    while epoch < opt.num_epochs:
        start_time = time.time()
        next_steps = env.get_next_states()
        # Exploration or exploitation
        epsilon = opt.final_epsilon + (max(opt.num_decay_epochs - epoch, 0) * (
                opt.initial_epsilon - opt.final_epsilon) / opt.num_decay_epochs)
        u = random()
        random_action = u <= epsilon
        next_actions, next_states = zip(*next_steps.items())
        next_states = torch.stack(next_states)
        if torch.cuda.is_available():
            next_states = next_states.cuda()
        model.eval()
        with torch.no_grad():
            predictions = model(next_states)[:, 0]
        model.train()
        if random_action:
            index = randint(0, len(next_steps) - 1)
        else:
            index = torch.argmax(predictions).item()

        next_state = next_states[index, :]
        action = next_actions[index]

        reward, done = env.step(action, render=True)

        if torch.cuda.is_available():
            next_state = next_state.cuda()
        replay_memory.append([state, reward, next_state, done])
        if done:
            final_score = env.score
            final_tetrominoes = env.tetrominoes
            final_cleared_lines = env.cleared_lines
            state = env.reset()
            if torch.cuda.is_available():
                state = state.cuda()
        else:
            state = next_state
            continue
        if len(replay_memory) < opt.replay_memory_size / 10:
            continue
        epoch += 1
        batch = sample(replay_memory, min(len(replay_memory), opt.batch_size))
        state_batch, reward_batch, next_state_batch, done_batch = zip(*batch)
        state_batch = torch.stack(tuple(state for state in state_batch))
        reward_batch = torch.from_numpy(np.array(reward_batch, dtype=np.float32)[:, None])
        next_state_batch = torch.stack(tuple(state for state in next_state_batch))

        if torch.cuda.is_available():
            state_batch = state_batch.cuda()
            reward_batch = reward_batch.cuda()
            next_state_batch = next_state_batch.cuda()

        print("state_batch",state_batch.shape)
        q_values = model(state_batch)
        model.eval()
        with torch.no_grad():
            next_prediction_batch = model(next_state_batch)
        model.train()

        y_batch = torch.cat(
            tuple(reward if done else reward + opt.gamma * prediction for reward, done, prediction in
                  zip(reward_batch, done_batch, next_prediction_batch)))[:, None]

        optimizer.zero_grad()
        loss = criterion(q_values, y_batch)
        loss.backward()
        optimizer.step()
        end_time = time.time()
        use_time = end_time - t1 - total_time
        total_time = end_time - t1
        print("Epoch: {}/{}, Action: {}, Score: {}, Tetrominoes {}, Cleared lines: {}, Used time: {}, total used time: {}".format(
            epoch,
            opt.num_epochs,
            action,
            final_score,
            final_tetrominoes,
            final_cleared_lines,
            use_time,
            total_time))
        writer.add_scalar('Train/Score', final_score, epoch - 1)
        writer.add_scalar('Train/Tetrominoes', final_tetrominoes, epoch - 1)
        writer.add_scalar('Train/Cleared lines', final_cleared_lines, epoch - 1)

        if epoch > 0 and epoch % opt.save_interval == 0:
            print("save interval model: {}".format(epoch))
            torch.save(model, "{}/tetris_{}".format(opt.saved_path, epoch))
        elif final_score>best_score:
            best_score = final_score
            print("save best model: {}".format(best_score))
            torch.save(model, "{}/tetris_{}".format(opt.saved_path, best_score))


if __name__ == "__main__":
    opt = get_args()
    train(opt)
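
Assuming the directory layout shown earlier, a training run can be launched from the project root roughly as follows (all flags and their defaults are defined in get_args(); the values below just restate the defaults):

python train.py --num_epochs 3000 --lr 1e-3 --batch_size 512 --saved_path trained_models

Note that train() calls env.step(action, render=True), which opens an OpenCV window every step; on a headless machine you may want to change that call to render=False.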

Test Code

import argparse
import torch
import cv2
from src.tetris import Tetris


def get_args():
    parser = argparse.ArgumentParser(
        """Implementation of Deep Q Network to play Tetris""")

    parser.add_argument("--width", type=int, default=10, help="The common width for all images")
    parser.add_argument("--height", type=int, default=20, help="The common height for all images")
    parser.add_argument("--block_size", type=int, default=30, help="Size of a block")
    parser.add_argument("--fps", type=int, default=300, help="frames per second")
    parser.add_argument("--saved_path", type=str, default="trained_models")
    parser.add_argument("--output", type=str, default="output.mp4")

    args = parser.parse_args()
    return args


def test(opt):
    if torch.cuda.is_available():
        torch.cuda.manual_seed(123)
    else:
        torch.manual_seed(123)
    if torch.cuda.is_available():
        model = torch.load("{}/tetris_2000".format(opt.saved_path))
    else:
        model = torch.load("{}/tetris_2000".format(opt.saved_path), map_location=lambda storage, loc: storage)
    model.eval()
    env = Tetris(width=opt.width, height=opt.height, block_size=opt.block_size)
    env.reset()
    if torch.cuda.is_available():
        model.cuda()
    out = cv2.VideoWriter(opt.output, cv2.VideoWriter_fourcc(*"MJPG"), opt.fps,
                          (int(1.5*opt.width*opt.block_size), opt.height*opt.block_size))
    while True:
        next_steps = env.get_next_states()
        next_actions, next_states = zip(*next_steps.items())
        next_states = torch.stack(next_states)
        if torch.cuda.is_available():
            next_states = next_states.cuda()
        predictions = model(next_states)[:, 0]
        index = torch.argmax(predictions).item()
        action = next_actions[index]
        _, done = env.step(action, render=True, video=out)

        if done:
            out.release()
            break
        


if __name__ == "__main__":
    opt = get_args()
    test(opt)
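
The test script hard-codes the checkpoint name tetris_2000, so that file must exist under --saved_path. A typical invocation, assuming the layout above, would be:

python test.py --saved_path trained_models --output output.mp4 --fps 300

Because the checkpoints were written with torch.save(model) (a pickled module rather than a state_dict), the DeepQNetwork class in src/deep_q_network.py must be importable when the model is loaded.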

Results

(Screenshot of the trained agent playing Tetris; a full game recording is written to output.mp4 by the test script.)

