打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略
Py之trl:trl(一款采用强化学习训练Transformer语言模型和稳定扩散模型的全栈库)的简介、安装、使用方法之详细攻略
trl的简介
TRL - Transformer Reinforcement Learning使用强化学习的全栈Transformer语言模型。trl 是一个全栈库,其中我们提供一组工具,用于通过强化学习训练Transformer语言模型和稳定扩散模型,从监督微调步骤(SFT)到奖励建模步骤(RM)再到近端策略优化(PPO)步骤。该库建立在Hugging Face 的 transformers 库之上。因此,可以通过 transformers 直接加载预训练语言模型。目前,大多数解码器架构和编码器-解码器架构都得到支持。请参阅文档或示例/文件夹,以查看示例代码片段以及如何运行这些工具。
GitHub地址:GitHub - huggingface/trl: Train transformer language models with reinforcement learning.
1、亮点
>> SFTTrainer:一个轻量级且友好的围绕transformer Trainer的包装器,可以在自定义数据集上轻松微调语言模型或适配器。
>> RewardTrainer: transformer Trainer的一个轻量级包装,可以轻松地微调人类偏好的语言模型(Reward Modeling)。
>> potrainer:用于语言模型的PPO训练器,它只需要(查询、响应、奖励)三元组来优化语言模型。
>> AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个转换器模型,每个令牌有一个额外的标量输出,可以用作强化学习中的值函数。
>> 示例:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j减少毒性,Stack-Llama示例等。
2、PPO是如何工作的:PPO对语言模型微调三步骤,Rollout→Evaluation→Optimization
通过PPO对语言模型进行微调大致包括三个步骤:
Rollout
Rollout(展开):语言模型基于查询生成响应或继续,查询可以是句子的开头。
Evaluation
Evaluation(评估):使用一个函数、模型、人类反馈或它们的组合来评估查询和响应。重要的是,此过程应为每个查询/响应对产生一个标量值。
Optimization
Optimization(优化):这是最复杂的部分。在优化步骤中,使用查询/响应对来计算序列中token的对数概率。这是通过训练的模型和一个参考模型(通常是微调之前的预训练模型)来完成的。两个输出之间的KL-散度被用作附加奖励信号,以确保生成的响应不会偏离参考语言模型太远。然后,使用PPO训练主动语言模型。
这个过程在下面的示意图中说明。
trl的安装
pip install trl trl的使用方法
1、基础用法
(1)、如何使用库中的SFTTrainer
以下是如何使用库中的SFTTrainer的基本示例。SFTTrainer是用于轻松微调语言模型或适配器的transformers Trainer的轻量包装器。
# importsfrom datasets import load_datasetfrom trl import SFTTrainer# get datasetdataset = load_dataset("imdb", split="train")# get trainertrainer = SFTTrainer( "facebook/opt-350m", train_dataset=dataset, dataset_text_field="text", max_seq_length=512,)# traintrainer.train() (2)、如何使用库中的RewardTrainer
以下是如何使用库中的RewardTrainer的基本示例。RewardTrainer是用于轻松微调奖励模型或适配器的transformers Trainer的包装器。
# importsfrom transformers import AutoModelForSequenceClassification, AutoTokenizerfrom trl import RewardTrainer# load model and dataset - dataset needs to be in a specific formatmodel = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)tokenizer = AutoTokenizer.from_pretrained("gpt2")...# load trainertrainer = RewardTrainer( model=model, tokenizer=tokenizer, train_dataset=dataset,)# traintrainer.train() (3)、如何使用库中的PPOTrainer
以下是如何使用库中的PPOTrainer的基本示例。基于查询,语言模型创建响应,然后进行评估。评估可以是人工干预或另一个模型的输出。
# importsimport torchfrom transformers import AutoTokenizerfrom trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_modelfrom trl.core import respond_to_batch# get modelsmodel = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')model_ref = create_reference_model(model)tokenizer = AutoTokenizer.from_pretrained('gpt2')# initialize trainerppo_config = PPOConfig( batch_size=1,)# encode a queryquery_txt = "This morning I went to the "query_tensor = tokenizer.encode(query_txt, return_tensors="pt")# get model responseresponse_tensor = respond_to_batch(model, query_tensor)# create a ppo trainerppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)# define a reward for response# (this could be any reward such as human feedback or output from another model)reward = [torch.tensor(1.0)]# train model for one step with ppotrain_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) 2、进阶用法
LLMs之BELLE:源码解读(ppo_train.py文件)训练一个基于强化学习的自动对话生成模型—解析命令行参数→加载数据集(datasets库)→初始化模型分词器和PPOConfig配置参数(trl库)→模型训练(accelerate分布式训练+DeepSpeed推理加速,生成对话→计算奖励【评估生成质量】→执行PPO算法更新【改善生成文本的质量】)→模型保存之详细攻略
https://yunyaniu.blog.csdn.net/article/details/133865725
LLMs之BELLE:源码解读(dpo_train.py文件)训练一个基于强化学习的自动对话生成模型(DPO算法微调预训练语言模型)—解析命令行参数与初始化→加载数据集(json格式)→模型训练与评估之详细攻略
https://yunyaniu.blog.csdn.net/article/details/133873621
本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
使用QLoRA对Llama 2进行微调的详细笔记
社区供稿 | RLHF 实践中的框架使用与一些坑 (TRL, LMFlow)
Trapper: Transformer模型都在此!
☀️机器学习入门☀️(二) KNN分类算法 | 附加小练习
Python遇见机器学习 ---- 线性回归算法 Linear Regression
一口气发布1008种机器翻译模型,GitHub最火NLP项目大更新
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服