针对多轮推理分类问题的软标签构造方法

news/2024/5/18 23:44:04 标签: 强化学习, 分类, 机器学习

Motivation

在非对称博弈中,我们常常要对对手的状态(如持有的手牌类型)进行推理。此类推理问题有两个特点:(1) 虽然存在正确结果,但正确结果往往无法经过一次推理得到。因为随着游戏的进行,才能获得足够的信息 (2) 虽然无法一次性获得正确的结果,但可以基于现有信息推理获得更正确的分布。更正确的分布会有益于我们在接下来的游戏中做出正确决策。

基于这两个特点,我们考虑一个简化的问题:我们想要知道对手持有的(唯一)一张手牌的类型。该问题可以被理解为多分类问题。考虑使用神经网络的情况,由于需要进行多轮推理,参照针对MDP的DQN模型结构,网络的输入将是我们先前推理得到的关于手牌所属类型的概率分布(当前状态),以及得到的新信息(使得状态转移的动作);网络输出基于新信息更新的概率分布(新状态)。训练这个网络需要使用从真实游戏中采集的数据,包括每轮的玩家行为,以及最重要的对手手牌真实值

对于多分类,真实值应当是一个ont-hot向量。它代表着当我们已经拥有足够信息时应当得到的正确推理结果。但当还没有足够信息时,模型在理论上是没有能力推理出这个结果的(除非过拟合)。考虑一般的推理过程,在没有任何信息时,先验分布应当是一个均匀分布或正态分布。随着新信息的获得,这个分布被逐步修正,直至收敛为一个one-hot向量。如果在一开始就想要将它直接变换为one-hot向量,这种dynamics将破坏该分布应有的均匀或正态性,这将得到不准确的新分布,不不利于下一步决策的进行。

Contribution

本文提出了一个基于experience replay的两步过程,以构造适应于当前推理进程的软标签。模型在训练时使用本方法构造的软标签(而非直接使用真实ont-hot标签)以防止训练时产生不合时宜的dynamics。两步过程简述如下:

1. 采集阶段:使用该轮游戏数据中的推理结果和对应真实值(one-hot形式),构造软标签(每步构造一个)

2. 训练阶段:损失函数中分别对软标签和one-hot标签计算loss,并取min作为损失函数值。这样做可以在模型生成尚未确定的结果时使用软标签优化,避免破坏分布的均匀或正态性。在生成相对确定结果时使用one-hot标签优化,避免软标签带来的不确定性。

Method

下面详细介绍构造软标签的过程。为了方便下面的推导,首先说明一下分布的表示:假设我们要推断对手持有的一张手牌,该手牌可能有四种类型,那么对应的概率分布为一个长度为4的向量。每个维度(下标)代表一个随机变量值,对应元素代表该随机变量值出现的概率。

我们考虑该手牌不同类型是否相互独立两种情况。“不相互独立”指的是相邻的随机变量值之间存在概率密度的连续性,如手牌分为1、2、3、4四种,且1的个数最多,此后逐个减小。那么先验分布应当是一个从X=1开始概率值逐渐减小的分布。相反,若手牌不同类型相互独立,即说明类型下标相邻手牌间没有实质联系(包括数量或者其它方面的连续性)。那么先验分布将是一个均匀分布。

我们使用前一轮训练得到的模型(在当前步)的推理结果和one-hot真实标签两个分布来构造该步对应的软标签。该步推理结果(分布)的均值和方差作为当前推理进程的度量,那么一种简单的思路是使构造的软标签均值方差与当前推理结果相同。下面分别介绍是否相互独立两种情况的不同软标签构造方法:

相互独立情况:构造均匀分布

我们需要保证真实值(对应的随机变量值)对应的概率值比其它位置更大,且其它位置概率值不能为0。由于各个位置相互独立,且没有其它信息,我们只能认为其它各个位置的概率值都相等。处理这种简单的形式,我们可以直接让新分布向量的均值方差与推理结果(分布向量)的均值方差相等。那么设真实值位置对应概率值为x,其它位置概率值为y,以均值方差相等为条件,可以直接得到方程:

(x+(n-1)y)/n=\mu

((x-\mu)^2+(n-1)(y-\mu)^2)/n=\sigma^2

其中 n 为分布向量长度。

不相互独立情况:构造正态分布

在各随机变量值不相互独立时,与真实值(对应的随机变量值)相近的位置较更远位置有更大的概率密度,针对这种情况应当构造以真实值位置为峰的正态分布。该过程分为两步:首先构造一个与推理结果接近的正态分布,然后再将其峰的位置与真实值位置对齐。

基于推理结果构造正态分布

正态分布的形态由方差决定,位置由均值决定。但需要注意的是,这里的均值方差与上一种情况使用的分布向量均值方差不同。因为这种情况的下标具有语义,所以我们必须将分布向量作为离散概率密度函数(分布列)来理解。同时还需要分别处理所求的分布均值与峰的位置(正态分布均值即为峰的位置,但这与我们期望的不一定相同,因此要对其进行平移来得到我们期望的峰位置)。

设推理结果分布为 D,其随机变量值为 x_i,对应概率密度为 y_i。其均值方差为:

\mu=(\sum x_iy_i)/\sum y_i

\sigma^2=(\sum y_i(x_i-\mu)^2)/\sum y_i

为了方便下一步进行对齐,我们尽可能把将要求出的正态分布峰(上式中 \mu)平移到 x=0 位置,但显然平移会影响 \mu 与 \sigma^2 的值。因此先以 argmax(D) 作为峰位置的估计值平移:

x'_i=x_i-argmax(D)

再对上式中 x_i 代入 x'_i ( y_i 不变)计算 \mu 与 \sigma^2 。

对齐峰位置构造软标签

使用 x'_i 计算的 \mu (峰位置)会比直接使用 x_i 处于一个相对更接近 x=0 的位置。这步我们需要构造对 x_i 的转换,使得真实值位置 t 与峰位置 \mu 完全对齐。即当 x_i=t 时,x'_{2i}=\mu。显然有:

 x'_{2i}=x_i-t+\mu

这样可以使用 x'_{2i} 代入分布 N(\mu,\sigma^2) 的累计密度函数 cdf(x) 计算软标签各个维度值:位置 X=x_i 的对应密度为 cdf((x_{i+1}+x_i)/2)-cdf((x_i+x_{i-1})/2)

至此,软标签构造完成。


http://www.niftyadmin.cn/n/1427337.html

相关文章

设计模式15-模板模式

一、模板模式 定义一个操作的算法的框架,而将一些步骤延迟到子类中。使得子类可以不改变一个算法的结构即可重定义该算法的某些特定步骤。 实现方案:将算法/逻辑框架放在抽象基类中,并定义好实现接口,在子类中实现细节接口。 二…

java xml dtd 不符合提示_java解析带有dtd验证的xml文件出错?

我的xml文档是&#xff1a;<?xmlversion "1.0"?>G:\java\configuration.dtd">dasdasdhelvetiaca36我的xml文档是&#xff1a;dasdasdhelvetiaca36njknk我的dtd文档是&#xff1a;我用的是DOM解析&#xff0c;我检查了好几遍xml和dtd没发现错误&#…

视图与包装器

1、视图与包装器 Demo Map<Integer,String> map new HashMap<>();map.put(1,"aaaaa");map.put(2,"bbbbb");map.put(3,"cccc");System.out.println(map);Set<Integer> keySet map.keySet();System.out.println("键"…

java多线程 sycophantic_java多线程3种方式

Java多线程实现方式主要有三种&#xff1a;继承Thread类、实现Runnable接口、使用ExecutorService、Callable、Future实现有返回结果的多线程。其中前两种方式线程执行完后都没有返回值&#xff0c;只有最后一种是带返回值的。1、继承Thread类实现多线程继承Thread类的方法尽管…

c++并发编程02-什么是I/O

相信对于程序员来说I/O操作是最为熟悉不过的了&#xff1a; 当我们使用C语言中的printf、C中的"<<"&#xff0c;Python中的print&#xff0c;Java中的System.out.println等时&#xff0c;这是I/O&#xff1b;当我们使用各种语言读写文件时&#xff0c;这也是I…

Java容器中算法

1、算法 本部分介绍在容器体系中的一些常规算法&#xff1a; ​ 这些算法常常封装到一些工具类中;l例如Collections&#xff0c;Arrays. 2、求最大值 //参数为接口&#xff1a;所有实现给接口的类的对象都可以作为参数传递进去 public static <T extends Object & Comp…

Policy Evaluation收敛性、炼丹与数学家

完美的学习算法 昨天和同学在群里讨论DRL里bad case的问题。突然有同学提出观点&#xff1a;“bad case其实并不存在&#xff0c;因为一些算法已经理论证明了具有唯一极值点&#xff0c;再加上一些平滑技巧指导优化器&#xff0c;就必然可以收敛。” 当听到这个观点时&#x…

java必知必会_Java必知必会-Spring Security

一、Spring Security介绍1、框架介绍Spring 是一个非常流行和成功的 Java 应用开发框架。Spring Security 基于 Spring 框架&#xff0c;提供了一套 Web 应用安全性的完整解决方案。一般来说&#xff0c;Web 应用的安全性包括用户认证(Authentication)和用户授权(Authorization…