Chatglm-6B+Deepspeed+PTuningv2 多卡高效微调

21 天前 · 来自专栏 简单的算法笔记

前言


是目前中文自主开源LLM中很优秀的项目,官方给的文档及教程突出非常简便上手,使用过程体验很不错。

Chatglm官方给出的

以及代码路径 ChatGLM-6B/ptuning/ 脚本中 train.sh train_chat.sh 都是未使用 deepspeed 下进行 p-tuning v2 的微调。同时给出 ds_train_finetune.sh 是使用 deepspeed 但未使用 p-tuning v2 进行的 全部参数 微调。

网上目前微调文档教程基本参照官方使用指南进行补充,关于使用deepspeed微调都是全参数进行。然而一般人资源有限,无法随便进行全参数微调。本人在尝试过程中发现只需要基于官方给出的脚本进行简单修改就可以实现在deepspeed框架上进行ptuningv2 高效微调, 节省资源。

因此这里写一个笔记分享以下两种情况:

  • 未使用deepspeed的多卡ptuning微调。
  • 使用 deepspeed的多卡ptuning微调

未使用deepspeed的多卡ptuning微调

未使用deepspeed进行多卡ptuning微调方式很简单,只需要将 train.sh train_chat.sh 中参数 CUDA_VISIBLE_DEVICES=0 增加对应的显卡编号,例如使用 4 张显卡进行训练:

PRE_SEQ_LEN=128
LR=2e-2
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 main.py \
--do_train \
--train_file AdvertiseGen/train.json \
--validation_file AdvertiseGen/dev.json \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path THUDM/chatglm-6b \
--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 16 \
--predict_with_generate \
--max_steps 3000 \
--logging_steps 10 \
--save_steps 1000 \
--learning_rate $LR \
--pre_seq_len $PRE_SEQ_LEN \
--quantization_bit 4

使用 deepspeed的多卡ptuning微调:

使用 deepspeed的进行ptuning微调修改的脚本是 ds_train_ptuning.sh ,多卡的方式只需要通过参数 num_gpus 控制微调使用的显卡个数。

然而原脚本 ds_train_ptuning.sh 是对模型全参数进行微调,这里通过查看脚本中实际调用的训练文件 main.py 中可以看出,这里的逻辑是如果参数 pre_seq_len 不为空,则模型会采用 p-tuning v2 的方式,冻结全部的模型参数,只训练通过 pre_seq_len 进行长度设置的 soft promt 部分的参数。

main.py中决定微调方式的逻辑

那么不同脚本即可以通过参数 pre_seq_len 来决定训练过程中main.py使用什么方式(全参数、ptuningv2)进行微调。 因此只需要在 ds_train_ptuning.sh 中类似 train_chat.sh 增加一个参数 pre_seq_len 即可将全参数微调修改为ptuning-v2微调。

LR=1e-4
MASTER_PORT=$(shuf -n 1 -i 10000-65535)
deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \
--deepspeed deepspeed.json \
--do_train \
--train_file AdvertiseGen/train.json \
--test_file AdvertiseGen/dev.json \
--prompt_column content \
--response_column summary \
--overwrite_cache \
--model_name_or_path THUDM/chatglm-6b \
--output_dir ./output/adgen-chatglm-6b-ft-$LR \
--overwrite_output_dir \
--max_source_length 64 \
--max_target_length 64 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--predict_with_generate \