ML-Agents与自己的强化学习算法通讯——PPO篇

news/2024/5/19 0:11:15 标签: 算法, python, 人工智能, 强化学习

在上一篇文章ML-Agents与python的Low Level API通信中,我简要介绍了Python与Unity端的ML-Agents插件的通讯代码,如何正确运行一个能够进行强化学习训练的Unity环境,并获取到响应的信息,接下来将介绍如何利用自己的强化学习算法进行训练。

介绍

这里我利用了强化学习库PARL来训练ML-Agents的3DBall,使用的是PPO算法
关于PPO的具体代码细节可以参考我以前的文章强化学习PPO代码讲解,这里不再讲述PPO的代码细节(之所以选择PARL,是因为感觉其代码通俗易懂)
PARL主要将代码分为了几个部分,首先是Model脚本,主要用来编写Actor,Critic等神经网络模型。然后是Algorithm脚本,主要编写具体的算法细节,主要有sample,predict,learn函数。还有storage脚本,主要用来存放经验池(reply buffer)。还有Config脚本,存放训练使用的超参数。Agent脚本,用来对Algorithm脚本进行进一步封装,是与环境交互的接口。最后才是训练入口脚本,调用agent脚本和环境进行交互。

主要源码分析

对于PPO算法,我们可以将其分为两个阶段。第一个是收集数据阶段,一个是训练模型阶段。
和SAC,DDPG等off-policy算法类型,PPO也有经验池,但是PPO是on-policy算法,所以收集数据和训练不能同时进行,每一次训练过后,我们都需要把旧的数据丢弃,重新用训练后的模型采集训练数据。
因此,大致流程是这样的:

  1. 所有智能体采集n个step的数据,存放到经验池中。
  2. 采集完成后,计算各个step的advantage,logprob等数据,同样存放起来。
  3. 利用经验池的数据进行m次PPO的更新
  4. 清空经验池数据,重新采样
python">from mlagents_envs.environment import UnityEnvironment
import numpy as np
from mlagents_envs.environment import ActionTuple
import argparse
import numpy as np
from parl.utils import logger, summary

from storage import RolloutStorage
from parl.algorithms import PPO
from agent import PPOAgent
from genenal_model import GenenalModel_Continuous_Divide
from genenal_config import genenal_config_continuous
from mlagents_envs.side_channel.engine_configuration_channel import EngineConfigurationChannel

# 创建环境
channel = EngineConfigurationChannel()
env = UnityEnvironment(file_name="UnityEnvironment", seed=1, side_channels=[channel])
channel.set_configuration_parameters(time_scale = 3.0)
env.reset()
# 获取环境信息
behavior_names = list(env.behavior_specs.keys())
behavior_value = list(env.behavior_specs.values())
for i in range(len(behavior_names)):
    print(behavior_names[i])
    print("obs:",behavior_value[i].observation_specs[0], "   act:", behavior_value[0].action_spec)
discrete_actions = None
total_steps = 0
stepsNum = 0
obs_space = behavior_value[i].observation_specs[0]
act_space = behavior_value[i].action_spec.continuous_size
# 建立Actor Critic模型
model = GenenalModel_Continuous_Divide(obs_space, act_space, [256,128], [256,128])
config = genenal_config_continuous
config['batch_size'] = int(config['env_num'] * config['step_nums'])
config['num_updates'] = int(
    config['train_total_steps'] // config['batch_size'])
# 建立PPO算法
ppo = PPO(
        model,
        clip_param=config['clip_param'],
        entropy_coef=config['entropy_coef'],
        initial_lr=config['initial_lr'],
        continuous_action=config['continuous_action'])
agent = PPOAgent(ppo, config)
# 建立经验池
rollout = RolloutStorage(config['step_nums'], config['env_num'], obs_space,
                         act_space)

DecisionSteps, TerminalSteps = env.get_steps(behavior_names[0])
obs = DecisionSteps.obs[0]
agentsNum = len(DecisionSteps)
done = np.zeros(agentsNum, dtype='float32')
total_reward = np.zeros(agentsNum, dtype='float32')
this_action = np.zeros((agentsNum, act_space), dtype='float32')
next_obs = np.zeros((agentsNum, obs_space.shape[0]), dtype='float32')
for update in range(1, config['num_updates'] + 1):
    # 数据收集
    for step in range(0, config['step_nums']):
        value, action, logprob, _ = agent.sample(obs)
        agentsNumNow = len(DecisionSteps)
        if agentsNumNow == 0:
            action = np.random.rand(0, 2)
        else:
            action = action.reshape(agentsNumNow, act_space)
            this_action = action
        actions = ActionTuple(action, discrete_actions)
        env.set_actions(behavior_names[0], actions)
        env.step()
        DecisionSteps, TerminalSteps = env.get_steps(behavior_names[0])
        next_obs_Decision = DecisionSteps.obs[0]
        next_obs_Terminal = TerminalSteps.obs[0]
        if(len(next_obs_Terminal) != 0):
            next_obs = np.zeros((agentsNum, obs_space.shape[-1]))
            rewards = np.zeros(agentsNum, dtype=float)
            next_done = np.zeros(agentsNum, dtype=bool)
            j = 0
            for i in TerminalSteps.agent_id:
                next_obs[i] = next_obs_Terminal[j]
                rewards[i] = TerminalSteps.reward[j]
                next_done[i] = True
                j += 1
            rollout.append(obs, this_action, logprob, rewards, done, value.flatten())
            obs, done = next_obs, next_done
            total_reward += rewards

        if(len(next_obs_Decision) != 0):
            step += 1
            next_obs = next_obs_Decision
            rewards = DecisionSteps.reward
            next_done = np.zeros(agentsNum, dtype=bool)

            rollout.append(obs, this_action, logprob, rewards, done, value.flatten())
            obs, done = next_obs, next_done
            total_reward += rewards

        total_steps += 1
        stepsNum += 1
        if(stepsNum % 200 == 199):
            arv_reward = total_reward / 200
            print("total_steps:{0}".format(total_steps))
            print("arv_reward:", arv_reward)
            stepsNum = 0
            total_reward = 0
    # PPO训练模型
    value = agent.value(obs)
    rollout.compute_returns(value, done)    
    value_loss, action_loss, entropy_loss, lr = agent.learn(rollout)


env.close()


源码链接

https://github.com/tianjuehai/mlagents-ppo


http://www.niftyadmin.cn/n/235851.html

相关文章

生成与指定数组具有相同形状的全1数组np.ones_like()方法

【小白从小学Python、C、Java】 【计算机等级考试500强双证书】 【Python-数据分析】 生成与指定数组A形状相同的全1数组 np.ones_like() 选择题 关于以下代码说法错误的一项是? import numpy as np a np.array([[0,1],[2,3]]) print("【显示】a\n",a) print(&qu…

Java每日一练(20230418)

目录 1. N皇后 II 🌟🌟🌟 2. 字符串相乘 🌟🌟 3. 买卖股票的最佳时机 🌟 🌟 每日一练刷题专栏 🌟 Golang每日一练 专栏 Python每日一练 专栏 C/C每日一练 专栏 Java每日一…

matplotlib设置中文字体为微软雅黑

matplotlib无法设置任何中文字体怎么办? 如何在linux系统下让matplotlib显示中文? 下载微软雅黑字体,把它放在某个目录下。 链接: https://pan.baidu.com/s/1SCLYpH_MzY7vn0HA0wxxAw?pwdft2j 提取码:ft2j 在代码中加…

我发现了PMP通关密码!这14页纸直接背!

一周就能背完的PMP考试技巧只有14页纸 共分成了4大模块 完全不用担心看不懂 01关键词篇 第1章引论 1.看到“驱动变革”--选项中找“将来状态” 2.看到“依赖关系”--选项中找“项目集管理” 3.看到“价值最大化”--选项中找“项目组合管理” 4.看到“可行性研究”--选项中…

STM32 LCD-ADC-DMA实验

目录 1.1 STM32 DMA简介 1.2 STM32 DMA的操作 1.DMA的初始化 2. 初始化代码 3. 主函数代码 本文将向大家介绍 STM32 的 DMA。(如有错误,欢迎批评指正) 在本章中,我们将利用 STM32 的 DMA 来实现ADC1通道1内数据传送,并在 TFTLCD 模块上显…

java轻量级框架MiniDao的详解

MiniDao是一款基于Java语言开发的轻量级持久层框架,它的目标是简化数据库操作流程,提高开发效率,减少代码量。MiniDao采用简单的注解配置方式,可以很容易地与Spring等常用框架集成使用。 MiniDao的主要特点包括: 简单…

FluxMQ—物联网高性能MQTT网关

FluxMQ—物联网高性能MQTT网关 随着物联网技术的快速发展,人们越来越意识到实时、可靠、安全的数据传输对于智能化的生产与生活的重要性。因此,市场对于高性能的物联网数据传输解决方案有着强烈的需求。FluxMQ正是为满足这一需求而诞生的一款高性能、可…

PLATO: Pre-trained Dialogue Generation Model with Discrete Latent Variable论文学习

一、大纲内容 二、详细内容 Abstract ○ 对话生成模型可以用于闲聊、知识对话、对话问题生成 ○ 本文 ■ 构建了一个灵活的attention机制,充分的促进了单向和双向的语言生成模型 ■ 介绍了一个离散的潜变量,较好的解决了一问多答的问题 ■ 上述两个结构…