一文读懂「RLHF,Reinforcement Learning from Human Feedback」基于人类反馈的进行强化学习

news/2024/5/18 21:47:45 标签: 人工智能, RLHF, 强化学习, GPT, 大语言模型, LLM

一、背景由来

过去几年里,以ChatGPT为代表的基于prompt范式的大型语言模型 (Large Language Model,LLM) 取得了巨大的成功。然而,对生成结果的评估是主观和依赖上下文的,这些结果难以用现有的基于规则的文本生成指标 (如 BLUE 和 ROUGE) 来衡量。除了评估指标,现有的模型通常以预测下一个单词的方式和简单的损失函数 (如交叉熵) 来建模,没有显式地引入人的偏好和主观意见。

一致性关注的是我们实际上希望模型做什么。它提出的问题是“目标函数是否符合我们的意图”,并且基于模型的目标和行为,在多大程度上与我们人类的价值观和和期望一致。举个简单的例子,假设我们要训练一个鸟类分类器,把鸟类分类为“麻雀”或“知更鸟”,并使用对数损失为训练目标,尽管我们的最终目标是很高的分类精度。该模型可能具有较低的对数损失,即模型的能力较强,但精度较差,这就是一个不一致的例子。模型可以优化培训目标,但与我们的最终目标不一致。

然而,在实际应用中,这些模型的目的是执行某种形式的有价值的认知工作,这些模型的训练方式与我们希望使用它们的方式之间存在着明显的分歧。尽管从数学上讲,机器计算的单词序列的统计分布可能是一种高效的选择,但实际上,我们会通过选择最适合给定情境的文本序列来生成语言,并使用我们的背景知识和常识来指导这一过程。当语言模型用于需要高度信任或可靠性的应用程序(如对话系统或智能个人助理)时,这可能会成为一个问题。

虽然在过去几年里,这些基于大量数据训练的模型变得极为复杂、强大,但当应用于实际人们生活生产时,它们往往无法发挥出潜力。大型语言模型中的一致性问题通常表现为:

  • 缺乏有效帮助:没有遵循用户的明确指示。

  • 虚构幻象:模型会虚构不存在或错误的事实。

  • 缺乏可解读性:人们很难理解模型是如何得出特定决策或预测结果的。

  • 训练内容偏见有害:经过有偏见、有害数据训练的语言模型,可能会在输出中重现这些数据,即使没有明确指示这样做。

但具体来说,一致性问题是从何处来的?

因此,训练阶段,如果直接用人的偏好(或者说人的反馈)来对模型整体的输出结果计算reward或loss,显然是要比传统的“给定上下文,预测下一个词”的损失函数合理的多。基于这个思想,便引出了本文要讨论的对象——RLHF(Reinforcement Learning from Human Feedback):即使用强化学习的方法,利用人类反馈信号直接优化语言模型。

RLHF_21">二、什么是RLHF

RLHF就是基于人类反馈(Human Feedback)对语言模型进行强化学习(Reinforcement Learning),和一般的fine-tune过程乃至prompt tuning自然也不同。

RLHF的训练过程可以分解为三个核心步骤:
在这里插入图片描述

  1. 多种策略产生样本并收集人类反馈
  2. 训练奖励模型
  3. 训练强化学习策略,微调 LM

二、原理介绍

在这里插入图片描述

Step 1:预训练语言模型+有标签数据微调(可选)

在这里插入图片描述

首先需要一个预训练语言模型,通过大量的语料去训练出基础模型,对于ChatGPT来说就是GPT-3。还有一个可选的Human Augmented Text,又叫Fine-tune。这里说直白点就是招人给问题(prompt)写示范回答(demonstration),然后给GPT-3上学习。Fine-tune又叫有标签数据微调,概念比较简单,就是给到标准答案让模型去学习,可能有同学好奇,这不是和一开始的例子一样吗?是的没错,但实际想要用人工去撰写答案的方式来训练模型,那成本是不可想象的,所以需要引入强化学习。后面会继续讲。

Step 2:训练奖励模型

在这里插入图片描述
我们需要一个模型来定量评判模型输出的回答在人类看来是否质量不错,即输入 [提示(prompt),模型生成的回答] ,奖励模型输出一个能表示回答质量的标量数字。

  1. 把大量的prompt(Open AI使用调用GPT-3用户的真实数据)输入给第一步得到的语言模型,对同一个问题,可以让一个模型生成多个回答,也可以让不同的微调(fine-tune)版本回答。

  2. 让标注人员对同一个问题的不同回答排序,有人可能会好奇为啥不打分?这是因为实验发现发现不同的标注员,打分的偏好会有很大的差异,而这种差异就会导致出现大量的噪声样本。排序的话能获得大大提升一致性。
    在这里插入图片描述
    在这里插入图片描述

  3. 这些不同的排序结果会通过某种归一化的方式变成定量的数据丢给模型训练,从而获得一个奖励模型。也就是一个裁判员。

Step 3:通过强化学习微调语言模型

在这里插入图片描述

基于强化学习(RL)去优化调整语言模型

  • policy是给GPT输入文本后输出结果的过程(输出文本的概率分布)

  • Action Space是词表所有token(可以简单理解为词语)在所有输出位置的排列组合

  • Observation space是可能输入的token序列,也就是Prompt

  • Reward Function则是基于上面第二步得到的奖励模型,配合一些策略层面的约束

  • 将初始语言模型的微调任务建模为强化学习问题,因此需要定义策略(policy)、动作空间(action space)和奖励函数(reward function)等基本要素

  • 具体怎么计算得到奖励Chat GPT是基于梯度下降,Open AI用的是 Proximal Policy Optimization (PPO) 算法。

☀️ 什么是 PPO?该算法的主要特点如下:

PPO 是一种用于在强化学习中训练代理 的算法。它被称为「on-policy」算法,因为它直接学习和更新当前策略,而不是像 DQN 的「off-policy」算法那样从过去的经验中学习。这意味着PPO正在根据代理人所采取的行动和所收到的奖励,不断的调整策略;

PPO 使用信任域优化方法来训练策略,这意味着它将策略的变化限制在与前一策略的一定范围内,以确保稳定性。这与其它策略梯度方法不同,后者有时会对策略进行大规模更新,从而破坏学习的稳定性;

PPO 使用价值函数,来估计给定状态或操作的预期返回。价值函数用于计算优势函数,它代表预期回报和当前回报之间的差异。然后,通过比较当前策略采取的操作与前一个策略本应采取的操作,使用优势函数更新策略。这使 PPO 可以根据所采取行动的预估值,对策略进行更智能的更新。

在这一步中,PPO 模型经由 SFT 模型初始化,且价值函数经由奖励模型初始化。该环境是一个「bandit environment」,它会产生随机显示提示, 并期望对提示做出响应。给出提示和响应后,它会产生奖励(由奖励模型决定)。SFT 模型会对每个 token 添加 KL 惩罚因子,以免奖励模型的过度优化。

三、缺点

局限性:在使语言模型与人类意图保持一致的过程中,用于调优的模型数据会受到各种复杂的主观因素的影响,主要包括:

  1. 生成 demo 数据的人工标注者的偏好;

  2. 设计研究并编写标签说明的研究人员

  3. 由开发人员编写或由OpenAI客户提供的可选的提示。

  4. 在模型评估中,奖励模型培训时所包含的标注者偏差(通过排名输出)

特别是ChatGPT作者指出的一个明显的事实,即参与培训过程的标注人员和研发人员,可能并不代表语言模型的所有潜在最终用户。此外,还有的一些其它缺点和需要解决的问题:

  • 缺乏对照研究

  • 比较数据缺乏基本事实

  • 人类的偏好并不一致

  • 奖励模型(RM)的即时稳定性测试

以及其它一些问题。

四、知识拓展

ChatGPT区别于其他模型的亮点以及没法复现的原因:

  1. 独特的基础模型:openai雇佣了上百号人产生了几万条打分数据,RLHF的本质类似于激发了基础模型的能力,基础模型已经很好了,只不过RLHF通过很少的计算量把模型能力激发出来了;
  2. 优秀的Reward模型:收集数据的时候,数据的质检很重要,这些数据需要真的反应你的想法,尽可能少的噪声;

另外一篇Deepmind的同类研究Sparrow,提供了每个阶段的实现细节:
在这里插入图片描述
在这里插入图片描述

四、资源

  1. 🔥 RLHF论文:Reinforcement Learning from Human Feedback
  2. 💡 一个讲的非常好的视频:清华博后带你走进ChatGPT——ChatGPTRLHF(3)
  3. 吴恩达RLHF课程:https://space.bilibili.com/253734135/channel/collectiondetail?sid=2028210&ctype=0
  4. Sparrow论文:https://arxiv.org/pdf/2209.14375.pdf
  5. ChatGPT 原来是这样工作的(上):https://www.jianshu.com/p/1987f73613ed

http://www.niftyadmin.cn/n/5332650.html

相关文章

基于YOLOv8深度学习的智能肺炎诊断系统【python源码+Pyqt5界面+数据集+训练代码】深度学习实战

《博主简介》 小伙伴们好,我是阿旭。专注于人工智能、AIGC、python、计算机视觉相关分享研究。 ✌更多学习资源,可关注公-仲-hao:【阿旭算法与机器学习】,共同学习交流~ 👍感谢小伙伴们点赞、关注! 《------往期经典推…

C++ //练习 1.7 编译一个包含不正确的嵌套注释的程序,观察编译器返回的错误信息。

C Primer(第5版) 练习 1.7 练习 1.7 编译一个包含不正确的嵌套注释的程序,观察编译器返回的错误信息。 环境:Linux Ubuntu(云服务器) 工具:vim 代码块 /**********************************…

C语言经典练习3——[NOIP2008]ISBN号码与圣诞树

前言 在学习C语言的过程中刷题是很重要的,俗话说眼看千遍不如手动一遍因为在真正动手去刷题的时候会暴露出更多你没有意识到的问题接下来我就为各位奉上两道我认为比较有代表性的题 1. [NOIP2008]ISBN号码 1.1 题目描述 每一本正式出版的图书都有一个ISBN号码与之对…

循环依赖和三级缓存

循环依赖: 是指一个或多个Bean实例之间存在直接或简介的依赖关系,这种依赖关系构成了环形调用(类与类之间的依赖关系形成了闭环)。 循环依赖的表现形式 eg1: 自己依赖自己的直接依赖 eg2: 两个对象之间的直接依赖 eg3: 多个对象之间的间接依赖 Spirng 框…

MyBatis XML 映射文件中的 SQL 语句可以分为动态语句和静态语句

静态查询&#xff1a; 静态查询是指在 SQL 语句中执行固定的查询操作&#xff0c;查询的条件和内容是预先确定的&#xff0c;不会随着用户输入或其他条件的改变而改变。以下是一个静态查询的示例&#xff1a; <select id"selectUsersByRole" resultType"co…

Spring MVC学习之——RequestMapping注解

RequestMapping注解 作用 用于建立请求URL和处理请求方法之间的对应关系。 属性 value&#xff1a;指定请求的实际地址&#xff0c;可以是一个字符串或者一个字符串列表。 value可以不写&#xff0c;直接在括号中写&#xff0c;默认就是value值 RequestMapping(value“/hel…

spark sql实践开发后端引擎

写在前面&#xff1a; 一转眼的时间&#xff0c;2024年了&#xff0c;翻看了一下博客首页&#xff0c;已有8年的码领&#xff0c;自从去年开启博客关注才能预览&#xff0c;至今已有1500个粉丝&#xff0c;比其他短视频平台的粉丝还要多&#xff0c;经年累月&#xff0c;8年一瞬…

android List,Set,Map区别和介绍

List 元素存放有序&#xff0c;元素可重复 1.LinkedList 链表&#xff0c;插入删除&#xff0c;非线性安全&#xff0c;插入和删除操作是双向链表操作&#xff0c;增加删除快&#xff0c;查找慢 add(E e)//添加元素 addFirst(E e)//向集合头部添加元素 addList(E e)//向集合…