Python-DQN代码阅读(9)

目录

1.代码阅读

1.1 代码总括

1.2 代码分解

1.2.1 replay_memory.pop(0)

1.2.2 replay_memory.append(Transition(state, action, reward, next_state, done))

1.2.3 samples = random.sample(replay_memory, batch_size)

1.2.4 q_values_next = target_net.predict(sess, next_states_batch)

1.2.5 greedy_q = np.amax(q_values_next, axis=1)

1.2.6 targets_batch = reward_batch + np.invert(done_batch).astype(np.float32) * gamma * greedy_q

1.2.7 loss = q_net.update(sess, states_batch, action_batch, targets_batch)


1.代码阅读

1.1 代码总括

python">if (train_or_test == 'train'):
    # 如果回放记忆满了,弹出第一个元素
    if len(replay_memory) == replay_memory_size:
        replay_memory.pop(0)

    # 将转换(transition)保存到回放记忆中
    # 对于每一次生命损失(loss of life),将 done = True 记录到回放记忆中
    if (ale_lives == info_ale_lives):
        replay_memory.append(Transition(state, action, reward, next_state, done))
    else:
        replay_memory.append(Transition(state, action, reward, next_state, True))

    # 从回放记忆中随机采样一个小批次样本
    samples = random.sample(replay_memory, batch_size)
    states_batch, action_batch, reward_batch, next_states_batch, done_batch = map(np.array, zip(*samples))

    # 计算 Q 值和目标值
    q_values_next = target_net.predict(sess, next_states_batch)
    greedy_q = np.amax(q_values_next, axis=1)
    targets_batch = reward_batch + np.invert(done_batch).astype(np.float32) * gamma * greedy_q

    # 更新网络
    if (total_t % 4 == 0):
        states_batch = np.array(states_batch)
        loss = q_net.update(sess, states_batch, action_batch, targets_batch)

这段代码的功能是将转换(transition)数据保存到回放记忆(replay memory)中,然后从回放记忆中随机采样一个小批次样本,计算 Q 值和目标值,最后使用 Q 网络(q_net)更新网络参数。这是一种使用经验回放(experience replay)的方法,用于训练强化学习智能体,提高训练的稳定性和样本利用率。

1.2 代码分解

1.2.1 replay_memory.pop(0)

python">                if len(replay_memory) == replay_memory_size:
                    replay_memory.pop(0)

这段代码用于控制回放记忆池的大小。回放记忆池是在强化学习中用于存储Agent与环境交互过程中的经验样本(称为转换或者记忆),用于训练神经网络。

len(replay_memory) 表示当前回放记忆池中的样本数量,replay_memory_size 是设定的回放记忆池的最大容量。

这段代码中的条件 len(replay_memory) == replay_memory_size 检查当前回放记忆池的长度是否达到了最大容量。如果达到了最大容量,就执行 replay_memory.pop(0) 操作,从回放记忆池的最前面(即索引为0的位置)弹出第一个元素,以保持回放记忆池的大小不超过设定的最大容量。

这样做的目的通常是为了控制回放记忆池的大小,防止其无限增长,从而限制训练过程中的内存占用和计算资源消耗。当回放记忆池达到最大容量时,新的经验样本会替代最早的样本,从而保持回放记忆池的容量在一个固定的范围内。

1.2.2 replay_memory.append(Transition(state, action, reward, next_state, done))

python">if (ale_lives == info_ale_lives):
        replay_memory.append(Transition(state, action, reward, next_state, done))
    else:
        replay_memory.append(Transition(state, action, reward, next_state, True))

这段代码用于将当前的转换(Transition)添加到回放记忆池(replay_memory)中。

ale_livesinfo_ale_lives 是用于记录游戏中剩余生命值的变量,其值相等时表示游戏中的生命值没有发生变化。

如果 ale_livesinfo_ale_lives 相等,即当前的生命值没有发生变化,那么将当前的转换添加到回放记忆池中,并将 done 设置为 False,表示游戏未结束。

如果 ale_livesinfo_ale_lives 不相等,即发生了生命值的变化,那么将当前的转换添加到回放记忆池中,并将 done 设置为 True,表示游戏已经结束。

这样做的目的通常是为了将游戏中每次生命值的变化视为一个独立的转换,以便在训练过程中更好地处理游戏中的生命值变化情况。这可以帮助Agent更好地学习处理生命值变化对游戏进程和策略的影响。

(1)

python">replay_memory.append(Transition(state, action, reward, next_state, done))

这段代码将一个完整的转换(Transition)对象添加到回放记忆池(replay_memory)中

state 是当前状态的表示,可以是游戏画面、环境状态等; action 是Agent选择的动作; reward 是执行动作后获得的奖励; next_state 是执行动作后的下一个状态; done 是一个布尔值,表示当前转换是否是一个终止状态(例如游戏结束状态)。

通过将这些信息封装成一个转换对象(例如一个自定义的Transition类),可以将Agent在环境中的经验存储到回放记忆池中,以便在训练过程中进行经验回放,从而提高训练的效果。在训练过程中,Agent可以从回放记忆池中随机抽样一批转换,并用于更新其神经网络模型,从而进行优化和改进。

(2)

python">replay_memory.append(Transition(state, action, reward, next_state, True))

这段代码将一个完整的转换(Transition)对象添加到回放记忆池(replay_memory)中,并设置 done 参数为 True

state 是当前状态的表示,可以是游戏画面、环境状态等; action 是Agent选择的动作; reward 是执行动作后获得的奖励; next_state 是执行动作后的下一个状态; done 是一个布尔值,表示当前转换是否是一个终止状态(例如游戏结束状态)。

通过将这些信息封装成一个转换对象(例如一个自定义的Transition类),可以将Agent在环境中的经验存储到回放记忆池中,以便在训练过程中进行经验回放,从而提高训练的效果。当一个转换被设置为终止状态时,done 参数应该被设置为 True,以便在训练过程中正确处理终止状态的情况,例如更新目标Q值的计算等。

1.2.3 samples = random.sample(replay_memory, batch_size)

python">samples = random.sample(replay_memory, batch_size)
states_batch, action_batch, reward_batch, next_states_batch, done_batch = map(np.array, zip(*samples))

这段代码从回放记忆池(replay_memory)中随机采样得到 batch_size 个样本,并将这些样本解压缩成不同的变量。

replay_memory 是一个存储着多个转换(Transition)对象的列表,其中每个转换包含了一个状态转移过程中的信息,如上一个回答所述。

random.sample 函数用于从 replay_memory随机采样指定数量的样本,即 batch_size 个样本。这样的采样方式可以打破样本之间的时序关联性,从而减少样本之间的相关性,有助于提高训练的效果。

解压缩的过程中,zip(*samples) 将转换对象中对应的属性(如状态、动作、奖励、下一个状态、是否为终止状态)按照属性的维度进行组合返回一个包含多个元组的迭代器。然后通过 map(np.array, ...) 将每个元组中的属性转换为 NumPy 数组,得到 states_batchaction_batchreward_batchnext_states_batchdone_batch,它们分别表示状态、动作、奖励、下一个状态和是否为终止状态的批量数据。这些数据可以用于后续的训练操作,例如计算Q值和更新神经网络参数等。

1.2.4 q_values_next = target_net.predict(sess, next_states_batch)

python">q_values_next = target_net.predict(sess, next_states_batch)

这段代码通过调用 target_net 对象的 predict 方法,输入 sessnext_states_batch,获取下一个状态批量数据 next_states_batch 对应的 Q 值估计值。

target_net 是一个目标网络(Target Network),通常用于在训练过程中稳定目标估计。

强化学习的深度 Q 网络(DQN)算法中,使用两个神经网络,一个是主网络(Policy Network),用于选择动作和计算 Q 值,另一个就是目标网络,用于计算目标 Q 值。

predict 方法是用于进行预测的方法,接受输入数据 next_states_batch,并返回对应的预测结果,即下一个状态批量数据 next_states_batch 对应的 Q 值估计值 q_values_next这个 Q 值估计值可以作为训练过程中更新 Q 值的目标值,用于计算损失并进行反向传播更新网络参数。

1.2.5 greedy_q = np.amax(q_values_next, axis=1)

python">greedy_q = np.amax(q_values_next, axis=1)

这段代码使用 np.amax 函数计算 q_values_next 中每一行的最大值,即在每个状态下可选动作的最大 Q 值。

q_values_next 是通过目标网络 target_net 对下一个状态批量数据 next_states_batch 进行预测得到的 Q 值估计值。axis=1 参数表示在每一行中查找最大值。

计算出的 greedy_q 是一个一维数组,其中的每个元素表示在对应状态下的最大 Q 值,即选择最优动作的 Q 值。这些最大 Q 值将用于计算训练过程中的目标 Q 值,用于更新网络参数。

1.2.6 targets_batch = reward_batch + np.invert(done_batch).astype(np.float32) * gamma * greedy_q

python">targets_batch = reward_batch + np.invert(done_batch).astype(np.float32) * gamma * greedy_q

这段代码计算训练过程中的目标 Q 值,用于更新网络参数。

reward_batch 是从回放内存中取出的当前批次的奖励值,表示当前状态下选择的动作的即时奖励。

done_batch 是从回放内存中取出的当前批次的完成状态标志,表示当前状态是否为一个终止状态。done_batch 为 True 表示当前状态为终止状态,反之为 False。

np.invert(done_batch) 是对 done_batch 进行按位取反操作,将 True 转换为 False,将 False 转换为 True。

astype(np.float32) 是将 done_batch 数组中的数据类型转换为 float32 类型,以便后续的计算。

gamma强化学习中的折扣因子,用于控制未来奖励的重要性。在计算目标 Q 值时,乘以 gamma 可以降低未来奖励的权重。

greedy_q 是在前面的代码中计算得到的在下一个状态下选择最优动作的 Q 值估计值。

通过以上的计算,targets_batch 将得到当前状态下的目标 Q 值,用于更新网络参数。

具体而言,reward_batch 会被加到目标 Q 值中,如果当前状态为终止状态(done_batch 为 True),则目标 Q 值不再受未来奖励影响;如果当前状态不是终止状态(done_batch 为 False),则目标 Q 值会受到未来奖励的影响,乘以 gamma 并加上 greedy_q这样计算得到的 targets_batch 将作为训练过程中的目标 Q 值,用于更新网络参数。

1.2.7 loss = q_net.update(sess, states_batch, action_batch, targets_batch)

python">if (total_t % 4 == 0):
        states_batch = np.array(states_batch)
        loss = q_net.update(sess, states_batch, action_batch, targets_batch)

这段代码用于控制网络的更新频率,每隔4个时间步更新一次网络参数。

total_t 是一个计数器,用于记录训练过程中的总时间步数。

if (total_t % 4 == 0): 判断当前时间步是否是4的倍数,如果是则执行下面的代码块。

states_batch 是当前批次的状态值。包含了当前批次的状态值的列表。通过 np.array(states_batch) 将其转换为 NumPy 数组,便于后续在深度学习模型中进行处理。

action_batch 是当前批次的动作值。包含了当前批次的动作选择的列表,其中每个元素是一个整数,表示代理在当前状态下选择的动作。

targets_batch 是前面计算得到的目标 Q 值。包含了当前批次的目标 Q 值的列表,其中每个元素是一个浮点数,表示代理在当前状态下根据当前策略预测的 Q 值目标。

q_net.update(sess, states_batch, action_batch, targets_batch) 是调用 Q 网络的 update 方法来更新网络参数。具体的更新算法依赖于具体的深度 Q 网络实现,可能使用梯度下降、优化器等方法进行参数的更新。这里将当前批次的状态、动作和目标 Q 值传入网络的 update 方法,以实现网络参数的更新。

通过这段代码的控制,网络的参数更新频率被限制在每隔4个时间步更新一次,从而控制网络的学习速度,平衡训练速度和稳定性之间的关系。

python">   loss = q_net.update(sess, states_batch, action_batch, targets_batch)

q_net.update() 是一个用于更新 Q 网络权重的方法,其中 sess 是 TensorFlow 会话对象,states_batch 是输入的状态批次,action_batch 是动作选择批次,targets_batch 是目标 Q 值批次。

强化学习中,Q 网络的更新通常通过最小化损失函数来完成,损失函数度量了当前策略和目标 Q 值之间的差异。具体而言,对于每一个状态,Q 网络预测了每个动作的 Q 值,而目标 Q 值是通过贝尔曼方程计算得出的更新的目标是使预测的 Q 值与目标 Q 值尽可能接近。

q_net.update() 方法会计算损失函数,并使用优化算法(如梯度下降)来更新 Q 网络的权重,使其向着更优的策略逐步优化。损失函数的计算通常包括了预测的 Q 值和目标 Q 值之间的差异,以及其他的正则化项或优化目标。更新过程中使用的输入数据包括了当前状态批次、动作选择批次和目标 Q 值批次,用于计算损失函数和更新权重。


http://www.niftyadmin.cn/n/223611.html

相关文章

【Linux】页表的深入分析

上一篇文章介绍了线程的基本概念 而本篇文章我们来深入理解一下, CPU再调度我们以往理解的进程和如今的线程都会涉及到的一个内容: 页表 文章目录深入理解页表 *页表的实际组成*什么是page?深入理解页表 * 在介绍进程时, 博主没有深入介绍过页表. 只是简单说了 页…

Vue——组件注册

目录 全局注册​ 局部注册​ 组件名格式​ 一个 Vue 组件在使用前需要先被“注册”,这样 Vue 才能在渲染模板时找到其对应的实现。组件注册有两种方式:全局注册和局部注册。 全局注册​ 我们可以使用 Vue 应用实例的 app.component() 方法&#xff…

( “树” 之 DFS) 437. 路径总和 III ——【Leetcode每日一题】

437. 路径总和 III 给定一个二叉树的根节点 root ,和一个整数 targetSum ,求该二叉树里节点值之和等于 targetSum 的 路径 的数目。 路径 不需要从根节点开始,也不需要在叶子节点结束,但是路径方向必须是向下的(只能…

Flutter macOS 13.0环境配置

1、配置环境变量 export ANDROID_HOME“/Users/mac/Documents/Android/SDK” export PATH${PATH}:${ANDROID_HOME}/tools export PATH${PATH}:${ANDROID_HOME}/platform-tools export PUB_HOSTED_URLhttps://pub.flutter-io.cn export FLUTTER_STORAGE_BASE_URLhttps://storag…

【面试复盘】猿辅导计算机视觉算法工程师一面

来源:投稿 作者:LSC 编辑:学姐 总共时长45分钟,首先是自我介绍,然后开始面试。 1.旋转数组[1,2,3,4,5,6,7]--->[5,6,7,1,2,3,4],找出k是否存在一个旋转数组中,旋转数组一开始是有序的。 二分…

Linux / Centos Stream 9安装 Skywalking 9.4.0 记录

链路追踪框架 官网:http://skywalking.apache.org/ 下载:http://skywalking.apache.org/downloads/ Github:https://github.com/apache/skywalking 文档:https://skywalking.apache.org/docs/main/v9.4.0/readme/ 中文文档&#x…

Python程序异常处理

一、什么是异常 异常就是程序运行时发生错误的信号,在程序由于某些原因出现错误的时候,若程序没有处理它,则会抛出异常,程序也的运行也会随之终止; 程序异常带来的问题: 1.程序终止,无法运行…

在Android中监听网络连接的简单方法

在Android中监听网络连接的简单方法 要使用 Kotlin 监控 Android 中的互联网连接,您可以使用该类ConnectivityManager,这是一个允许您查询网络状态的系统服务。以下是如何使用它的示例: 将以下权限添加到您的 AndroidManifest.xml 文件&…