用强化学习神包trl轻松实现GPT2可控文本生成

news/2024/5/18 21:47:32 标签: 文本生成, gpt2, 迁移强化学习, 强化学习

来源:投稿 作者:Sally can wait
编辑:学姐

模型github: lvwerra/trl: Train transformer language models with reinforcement learning. (github.com)https://github.com/lvwerra/trl

这个项目是复现 ”Fine-Tuning Language Models from Human Preferences" by D. Ziegler et al一文的paper, code,因为觉得它非常好用,所以跟着跑通这个项目,并加上自己的理解介绍给大家。

理论基础

什么是可控文本生成

虽然GPT2已经能生成流畅的句子,但是在特定话题的控制和逻辑性上仍然和期望相去甚远。我们希望一个文本生成模型可以一贯地围绕一个话题进行续写,而不是漫无目的地续写下去,这就是可控文本生成的研究目标。

在特定的运用场景中,我们有时需要用文本生成的方式增广数据。这时候可控文本生成可以在保证标签不变的前提下产生出大量的“伪数据”。

而大模型如GPT3、chatGPT效果较好,但是并不开源,而且由于巨大的参数量,微调起来也是浩大的工程。所以大部分的可控文本生成研究还是围绕GPT2做文章。

强化学习和PPO

强化学习不同于监督学习。监督学习只是对给定的、封闭的训练-验证数据集做参数优化,再用优化后的参数指导模型做出正确的输出。而强化学习的特点表现在强化信号上,强化信号是对产生动作的好坏作一种评价 (通常为标量),因此模型在不断产出输出的同时也在不断获得针对该输出的反馈,用这个反馈来更新模型参数。只要反馈机制是合理的,那么强化学习就可以一直进行下去,而不会面临训练数据匮乏的问题。

PPO(近端策略优化,Proximal Policy Optimisation)强化学习目前最有效的一种算法。和先前的强化学习算法相比,PPO它在每一步迭代中都会尝试计算新的策略,这样可以让损失函数最小化,同时还能保证与上一步迭代的策略间的偏差相对较小。

PPO 里面有两项:一项是优化的奖励,另一项是一个约束。约束是为了防止模型被微调得过于离谱,失去了原有的语言模型做流畅的文字生成的能力。

How it works?

用PPO算法优化GPT2大致分以下三个步骤:

  1. 续写:GPT2先根据当前权重,续写给出的句子。

  2. 评估:GPT2续写的结果会经过一个分类层,或者也可以采用人工的打分,重要的是最终产生出一个数值型的分数。

  3. 优化:上一步对生成句子的打分会用于更新序列中token的对数概率。除此之外,还需要引入一个新的奖惩机制:KL散度。这需要用一个参考模型(通常是微调前的预训练模型,如GPT2-base)计算微调模型的输出和参考模型的输出之间的KL散度,把它和之前步骤的打分加在一起作为奖励函数,目的是确保生成的句子不会过多地偏离参考语言模型。然后使用PPO算法进一步训练语言模型。

实战:强化学习让GPT2产生正向IMDB影评

我们用强化学习的方法微调英文版 GPT2(small),让它基于 IMDB 数据集生成正面电影评论。该模型先是读取数据集中真实的影评,用GPT2续写。为了奖励情感倾向为正的续写,我们使用BERT影评分类器来分析生成的句子的情绪,把分类器的输出作为PPO训练的奖励。如果GPT2的续写经过分类器判别为正向情感,那么直接将分类器在正向情感上的置信度作为奖励加到ppo_trainer里面;反之,如果GPT2的续写经过分类器判别为负面情感,那么它在分类器输出层,正向情感对应的置信度会是负数或者很低,同样地将这个置信度加入ppo_trainer,可以提示模型减少对此输出的学习。

1.安装依赖包

conda install pytorch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0 cudatoolkit=11.1 -c pytorch -c conda-forge --yes

2.读取包

#torch==1.8, transformers==4.15.0
import torch
import wandb
import time
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
tqdm.pandas()
from datasets import load_dataset
from transformers import AutoTokenizer, pipeline

from trl.gpt2 import GPT2HeadWithValueModel, respond_to_batch
from trl.ppo import PPOTrainer
from trl.core import build_bert_batch_from_txt, listify_batch

3.设置需要用到的超参数

config = {
    "model_name": "lvwerra/gpt2-imdb",
    "cls_model_name": "lvwerra/distilbert-imdb",
    "steps": 20000,
    "batch_size": 256,
    "forward_batch_size": 16,
    "ppo_epochs": 4,   
    "txt_in_min_len": 2,
    "txt_in_max_len": 8,
    "txt_out_min_len": 4,
    "txt_out_max_len": 16,
    "lr": 1.41e-5,
    "init_kl_coef":0.2,
    "target": 6,
    "horizon":10000,
    "gamma":1,
    "lam":0.95,
    "cliprange": .2,
    "cliprange_value":.2,
    "vf_coef":.1, 
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe_device = 0 if torch.cuda.is_available() else -1

4.用wandb仪表盘监控训练过程中的各项指标、中间变量。首次使用需要注册一下。

wandb.init(name='run-42', project='gpt2-test', config=config, )

图:在训练过程中可以观察到训练过程中的中间变量。“query”和“response”分别表示IMDB的原始句子prompt(经过随机截断)和GPT2生成的续写,“reward”表示经过情感分类器之后的正向情感分值,越大表示情感越积极

5.加载IMDB数据集

ds = load_dataset('imdb', split='train')
ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
ds = ds.filter(lambda x: len(x["review"])>200, batched=False)

6.load一个集成在transformers pipeline里的影评分类器(此处也可以替换成别的分类器,只要有打分就行)

sent_kwargs = {
    "return_all_scores": True,
    "function_to_apply": "none",
    "batch_size": config["forward_batch_size"]
}#指定分类器输出的格式

sentiment_pipe = pipeline("sentiment-analysis","lvwerra/distilbert-imdb", device=pipe_device)
text = 'this movie was really bad!!'
sentiment_pipe(text, **sent_kwargs)

#一条分类后的结果长这样:我们需要的是score
#[[{'label': 'NEGATIVE', 'score': 2.3350484371185303}, {'label': 'POSITIVE', 'score': -2.726576089859009}]]

这里注意必须要确保transformers版本是4.15.0,不同版本的Pipeline输出大有不同

7.加载预训练GPT2-small

gpt2_model = GPT2HeadWithValueModel.from_pretrained(config['model_name'])
gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained(config['model_name'])

gpt2_tokenizer = AutoTokenizer.from_pretrained(config['model_name'])
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token

wandb.watch(gpt2_model, log='all') #观察模型

gpt2_model.to(device);
gpt2_model_ref.to(device);

#设置文本生成的参数
gen_kwargs = {
    "min_length":-1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": gpt2_tokenizer.eos_token_id
}

8.长度控制+tokenize

class LengthSampler:
    def __init__(self, min_value, max_value):
        self.values = list(range(min_value, max_value))
    def __call__(self):
        return np.random.choice(self.values)
    
input_size = LengthSampler(config["txt_in_min_len"], config["txt_in_max_len"])
output_size = LengthSampler(config["txt_out_min_len"], config["txt_out_max_len"])
# 在tokenize之前,随机截断输入数据作为待续写的prompt,也随机确定续写长度,防止输入输出的长度过于单一

def tokenize(sample):
    sample["tokens"] = gpt2_tokenizer.encode(sample["review"])[:input_size()]
    sample["query"] = gpt2_tokenizer.decode(sample["tokens"])
    return sample

ds = ds.map(tokenize, batched=False)

def collater(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

dataloader = torch.utils.data.DataLoader(ds, batch_size=config['batch_size'], collate_fn=collater)

9.正式训练

ppo_trainer = PPOTrainer(gpt2_model, gpt2_model_ref, gpt2_tokenizer, **config)

total_ppo_epochs = int(np.ceil(config["steps"]/config['batch_size']))

for epoch, batch in tqdm(zip(range(total_ppo_epochs), iter(dataloader))):
    logs, timing = dict(), dict()
    t0 = time.time()
    query_tensors = [torch.tensor(t).long().to(device) for t in batch["tokens"]]
    
    #### Get response from gpt2
    t = time.time()
    response_tensors = []
    for i in range(config['batch_size']):
        gen_len = output_size()
        response = gpt2_model.generate(query_tensors[i].unsqueeze(dim=0),
                                       max_new_tokens=gen_len, **gen_kwargs)
        response_tensors.append(response.squeeze()[-gen_len:])
    batch['response'] = [gpt2_tokenizer.decode(r.squeeze()) for r in response_tensors]
    timing['time/get_response'] = time.time()-t

    #### Compute sentiment score
    t = time.time()
    texts = [q + r for q,r in zip(batch['query'], batch['response'])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    #[[{'label': 'NEGATIVE', 'score': 0.27862095832824707}, {'label': 'POSITIVE', 'score': -0.5044471621513367}]]
    rewards = torch.tensor([output[1]["score"] for output in pipe_outputs]).to(device) #each output has negative score(output[0]) and positive score(output[1])
    #如果一个prompt目前是negative,它的positive score是-0.5,那么加到奖励里面,相当于让它少学这个
    timing['time/get_sentiment_preds'] = time.time()-t
    
    #### Run PPO step 
    t = time.time()
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    timing['time/optimization'] = time.time()-t
     
    #### Log everything
    timing['time/epoch'] = time.time()-t0
    table_rows = [list(r) for r in zip(batch['query'], batch['response'], rewards.cpu().tolist())]
    logs.update({'game_log': wandb.Table(columns=['query', 'response', 'reward'], rows=table_rows)})
    logs.update(timing)
    logs.update(stats)
    logs['env/reward_mean'] = torch.mean(rewards).cpu().numpy()
    logs['env/reward_std'] = torch.std(rewards).cpu().numpy()
    logs['env/reward_dist'] = rewards.cpu().numpy()
    wandb.log(logs)

在训练过程中观察仪表盘,发现reward是上升的,说明训练是有效的。

10.保存模型

gpt2_model.save_pretrained('gpt2-imdb-pos-v2', push_to_hub=True)
gpt2_tokenizer.save_pretrained('gpt2-imdb-pos-v2', push_to_hub=True)

11.模型评估

通过比较原始GPT产生的无限制的文本和微调后产生的受控的文本,我们发现微调过程明显地让模型产生出了正面情感倾向的影评。同样地,在合适的位置添加负号就可以重新训练出一个会产生负面情绪文本的GPT2。针对指定label产生伪数据,这在数据增强上具有很高的应用价值。

此外,本实验的奖励机制是情感倾向值,也可以把奖励机制换成任何你喜欢的评价指标,运用在其他话题的生成任务上,看看模型是否会按照这个方向来学习。

关注下方《学姐带你玩AI》🚀🚀🚀

回复“ACL

免费获取文本生成&机器学习顶会高分论文PDF

码字不易,欢迎大家点赞评论收藏!


http://www.niftyadmin.cn/n/128152.html

相关文章

【独家】华为OD机试 - 需要广播的服务器数量(C 语言解题)

最近更新的博客 华为od 2023 | 什么是华为od,od 薪资待遇,od机试题清单华为OD机试真题大全,用 Python 解华为机试题 | 机试宝典【华为OD机试】全流程解析+经验分享,题型分享,防作弊指南)华为od机试,独家整理 已参加机试人员的实战技巧文章目录 最近更新的博客使用说明本期…

2023年全国最新食品安全管理员精选真题及答案14

百分百题库提供食品安全管理员考试试题、食品安全员考试预测题、食品安全管理员考试真题、食品安全员证考试题库等,提供在线做题刷题,在线模拟考试,助你考试轻松过关。 131.食品生产企业在一年内()次因违反《中华人民共…

C语言之循环双链表

循环双链表&#xff1a;/***********循环双链表**********/ #include <stdio.h> #include <stdlib.h>typedef struct Dnode{ //循环双链表的数据结构 int data;struct Dnode *prior,*next; }LinkList;LinkList *LinkList_Create(void)…

Python3 JSON 数据解析 【Python学习连续,请关注】

Python3 JSON 数据解析JSON (JavaScript Object Notation) 是一种轻量级的数据交换格式。它基于 ECMAScript 的一个子集。Python3 中可以使用 json 模块来对 JSON 数据进行编解码&#xff0c;它包含了两个函数&#xff1a;json.dumps(): 对数据进行编码。json.loads(): 对数据进…

【信号与系统笔记】第一章 绪论

1.1信号传输系统 信息传输的任务 将带有信息的信号&#xff0c;通过某种系统由发送者传送给接收者。 通信系统的组成 转换器&#xff1a;把消息转换为电信号或者把电信号还原成消息信道&#xff1a;信号传输的通道&#xff0c;广义上来说。发射机和接收机也可以是信道的一部分…

MongoDB与MySQL有区别吗?用一个表格跟你说明

MongoDB MySQL 数据库模型 非关系型 关系型 存储方式 虚拟内存持久化 不同引擎有不同存储方式 查询语句 独特MongoDB查询方式 传统SQL语句 架构特点 可通过副本集和分片实现高可用 常见有单点、M-S、MHA、MMM、Cluster等架构方式 数据处理方式 基于内存&#xf…

PHP实现简单爬虫的方法

PHP实现简单爬虫的方法&#xff0c;php实现爬虫 本文实例讲述了PHP实现简单爬虫的方法。分享给大家供大家参考。具体如下&#xff1a; <?php /** * 爬虫程序 -- 原型 * * 从给定的url获取html内容 * * param string $url * return string */ function _getUrlContent($url…

数据模型(上):模型分类和模型组成

1.模型分类 ​ 数据模型是一种由符号、文本组成的集合,用以准确表达信息景观,达到有效交流、沟通的目的。数据建模者要求能与来自不同部门,具有不同技术背景,不同业务经验,不同技术水平的人员交流、沟通。数据建模者要了解每个人员的观点,并通过反馈证明理解无误,最终作…