Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

news/2024/5/19 0:53:18 标签: 强化学习, Transformer

Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略

目录

trl的简介

1、亮点

2、PPO是如何工作的:PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

trl的安装

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

(2)、如何使用库中的RewardTrainer

(3)、如何使用库中的PPOTrainer

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略


trl的简介

          TRL - Transformer Reinforcement Learning使用强化学习的全栈Transformer语言模型。trl 是一个全栈库,其中我们提供一组工具,用于通过强化学习训练Transformer语言模型和稳定扩散模型,从监督微调步骤(SFT)到奖励建模步骤(RM)再到近端策略优化(PPO)步骤。该库建立在Hugging Face 的 transformers 库之上。因此,可以通过 transformers 直接加载预训练语言模型。目前,大多数解码器架构和编码器-解码器架构都得到支持。请参阅文档或示例/文件夹,以查看示例代码片段以及如何运行这些工具。

GitHub地址:GitHub - huggingface/trl: Train transformer language models with reinforcement learning.

1、亮点

>> SFTTrainer:一个轻量级且友好的围绕transformer Trainer的包装器,可以在自定义数据集上轻松微调语言模型或适配器。

>> RewardTrainer: transformer Trainer的一个轻量级包装,可以轻松地微调人类偏好的语言模型(Reward Modeling)。

>> potrainer:用于语言模型的PPO训练器,它只需要(查询、响应、奖励)三元组来优化语言模型。

>> AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个转换器模型,每个令牌有一个额外的标量输出,可以用作强化学习中的值函数。

>> 示例:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j减少毒性,Stack-Llama示例等。

2、PPO是如何工作的PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization

通过PPO对语言模型进行微调大致包括三个步骤:

Rollout

Rollout(展开):语言模型基于查询生成响应或继续,查询可以是句子的开头。

Evaluation

Evaluation(评估):使用一个函数、模型、人类反馈或它们的组合来评估查询和响应。重要的是,此过程应为每个查询/响应对产生一个标量值

Optimization

Optimization(优化):这是最复杂的部分。在优化步骤中,使用查询/响应对来计算序列中token的对数概率。这是通过训练的模型和一个参考模型(通常是微调之前的预训练模型)来完成的。两个输出之间的KL-散度被用作附加奖励信号,以确保生成的响应不会偏离参考语言模型太远。然后,使用PPO训练主动语言模型。

这个过程在下面的示意图中说明。

trl的安装

pip install trl

trl的使用方法

1、基础用法

(1)、如何使用库中的SFTTrainer

以下是如何使用库中的SFTTrainer的基本示例。SFTTrainer是用于轻松微调语言模型或适配器的transformers Trainer的轻量包装器。

# imports
from datasets import load_dataset
from trl import SFTTrainer

# get dataset
dataset = load_dataset("imdb", split="train")

# get trainer
trainer = SFTTrainer(
    "facebook/opt-350m",
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=512,
)

# train
trainer.train()

(2)、如何使用库中的RewardTrainer

以下是如何使用库中的RewardTrainer的基本示例。RewardTrainer是用于轻松微调奖励模型或适配器的transformers Trainer的包装器。

# imports
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer

# load model and dataset - dataset needs to be in a specific format
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

...

# load trainer
trainer = RewardTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
)

# train
trainer.train()

(3)、如何使用库中的PPOTrainer

以下是如何使用库中的PPOTrainer的基本示例。基于查询,语言模型创建响应,然后进行评估。评估可以是人工干预或另一个模型的输出。

# imports
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch

# get models
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)

tokenizer = AutoTokenizer.from_pretrained('gpt2')

# initialize trainer
ppo_config = PPOConfig(
    batch_size=1,
)

# encode a query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")

# get model response
response_tensor  = respond_to_batch(model, query_tensor)

# create a ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)

# define a reward for response
# (this could be any reward such as human feedback or output from another model)
reward = [torch.tensor(1.0)]

# train model for one step with ppo
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

2、进阶用法

LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133865725

LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略

https://yunyaniu.blog.csdn.net/article/details/133873621


http://www.niftyadmin.cn/n/5096152.html

相关文章

Mybatis <where>标签的小问题

我们都知道Mybatis的<where>标签是可以自动去掉子句第一个and或者or的&#xff0c;具体如下&#xff1a; <select id"selectSelective" resultType"com.study.entity.User">select * from t_user<where><if test"name ! null an…

Mini小主机All-in-one搭建教程4-安装Windows11系统

Mini小主机All-in-one搭建教程4-安装Windows11系统 硬件介绍 在狗东买的 极摩客M2 到手价是2799元 具体配置如下&#xff1a; 酷睿英特尔11代标压i7 11390H 64G1TB固态。 以下是 安装Windows11系统的教程。 安装Windows11系统 下载镜像包 首先下Windows系统的懒人镜像包&…

http代理有什么好处,怎么通过http代理服务安全上网呢?

通过http代理上网是一种常见的网络代理方式。http代理是指通过代理服务器进行网络连接&#xff0c;以实现隐藏自己的真实IP地址、保护个人隐私等目的。下面我们将介绍通过http代理上网的好处以及如何使用http代理服务来安全上网。 一、通过http代理上网的好处 1. 保护个人隐私 …

二维码智慧门牌管理系统:地址管理的现代革命

文章目录 前言一、标准地址的革新二、广泛的应用前景 前言 在科技不断发展和社会进步的背景下&#xff0c;高效、精准、智能的管理系统已经成为当今社会的迫切需求。传统的门牌管理系统在应对这一需求方面已显得力不从心&#xff0c;因此&#xff0c;二维码智慧门牌管理系统的…

“岗课赛证”融通的物联网综合实训室建设方案

一、概述 随着5G技术的普及应用和产业经济的革新发展,物联网产业所呈现的广阔前景带来了对创新型技术技能人才的迫切需求。高职院校物联网专业建设也因此转变为面向国家战略性新兴产业发展需求。当前,“岗位课程竞赛证书”融通的培育理念,是高职院校物联网人才培养和专业优化的…

【FreeRTOS】【STM32】06.1 FreeRTOS的使用1(对06的补充)

前后台系统(裸机) 裸机又称前后台系统&#xff0c;在一个while中不停循环处理各个task。 中断服务函数作为前台程序 大循环while(1)作为后台程序 多任务系统 通过任务调度的方式&#xff0c;执行各个任务&#xff0c;优先级高的先执行&#xff0c;执行完了释放CPU使用权&am…

LPWAN产业何时才能真正爆发?

导读&#xff1a; 虽然LPWAN目前还有重重困难&#xff0c;但是我们有充分的理由相信LPWAN即将爆发的趋势不变&#xff0c;当然&#xff0c;因为LPWAN是一个技术流派繁多的市场&#xff0c;除了LoRa、NB-IOT、eMTC还有RPMA、ZETA等等众多的技术流派&#xff0c;对于应用企业而言…

caxa复制尺寸错乱

复制粘贴&#xff0c;有些引出说明会错乱&#xff0c;可以利用快速选择&#xff0c;把引出说明打断