ChatGLM-6B-PT
發(fā)布日期:2023/6/4 22:35:06 瀏覽量:
本倉庫實現(xiàn)了對于 ChatGLM-6B 模型基于 P-Tuning v2 的微調(diào)。P-Tuning v2 將需要微調(diào)的參數(shù)量減少到原來的 0.1%,再通過模型量化、Gradient Checkpoint 等方法,最低只需要 7GB 顯存即可運行。
下面以 ADGEN (廣告生成) 數(shù)據(jù)集為例介紹代碼的使用方法。
*Read this in English.
軟件依賴
運行微調(diào)需要4.27.1版本的transformers。除 ChatGLM-6B 的依賴之外,還需要安裝以下依賴
pip install rouge_chinese nltk jieba datasets
使用方法
下載數(shù)據(jù)集
ADGEN 數(shù)據(jù)集任務(wù)為根據(jù)輸入(content)生成一段廣告詞(summary)。
{ "content": "類型#上衣*版型#寬松*版型#顯瘦*圖案#線條*衣樣式#襯衫*衣袖型#泡泡袖*衣款式#抽繩", "summary": "這件襯衫的款式非常的寬松,利落的線條可以很好的隱藏身材上的小缺點,穿在身上有著很好的顯瘦效果。領(lǐng)口裝飾了一個可愛的抽繩,漂亮的繩結(jié)展現(xiàn)出了十足的個性,配合時尚的泡泡袖型,盡顯女性甜美可愛的氣息。" }
從 Google Drive 或者 Tsinghua Cloud 下載處理好的 ADGEN 數(shù)據(jù)集,將解壓后的 AdvertiseGen 目錄放到本目錄下。
訓(xùn)練
P-Tuning v2
運行以下指令進(jìn)行訓(xùn)練:
bash train.sh
train.sh 中的 PRE_SEQ_LEN 和 LR 分別是 soft prompt 長度和訓(xùn)練的學(xué)習(xí)率,可以進(jìn)行調(diào)節(jié)以取得最佳的效果。P-Tuning-v2 方法會凍結(jié)全部的模型參數(shù),可通過調(diào)整 quantization_bit 來被原始模型的量化等級,不加此選項則為 FP16 精度加載。
在默認(rèn)配置 quantization_bit=4、per_device_train_batch_size=1、gradient_accumulation_steps=16 下,INT4 的模型參數(shù)被凍結(jié),一次訓(xùn)練迭代會以 1 的批處理大小進(jìn)行 16 次累加的前后向傳播,等效為 16 的總批處理大小,此時最低只需 6.7G 顯存。若想在同等批處理大小下提升訓(xùn)練效率,可在二者乘積不變的情況下,加大 per_device_train_batch_size 的值,但也會帶來更多的顯存消耗,請根據(jù)實際情況酌情調(diào)整。
如果你想要從本地加載模型,可以將 train.sh 中的 THUDM/chatglm-6b 改為你本地的模型路徑。
Finetune
如果需要進(jìn)行全參數(shù)的 Finetune,需要安裝 Deepspeed,然后運行以下指令:
bash ds_train_finetune.sh
推理
在 P-tuning v2 訓(xùn)練時模型只保存 PrefixEncoder 部分的參數(shù),所以在推理時需要同時加載原 ChatGLM-6B 模型以及 PrefixEncoder 的權(quán)重,因此需要指定 evaluate.sh 中的參數(shù):
--model_name_or_path THUDM/chatglm-6b
--ptuning_checkpoint $CHECKPOINT_PATH
仍然兼容舊版全參保存的 Checkpoint,只需要跟之前一樣設(shè)定 model_name_or_path:
--model_name_or_path $CHECKPOINT_PATH
評測指標(biāo)為中文 Rouge score 和 BLEU-4。生成的結(jié)果保存在 ./output/adgen-chatglm-6b-pt-8-1e-2/generated_predictions.txt。
例子
示例1
- Input: 類型#上衣*材質(zhì)#牛仔布*顏色#白色*風(fēng)格#簡約*圖案#刺繡*衣樣式#外套*衣款式#破洞
- Label: 簡約而不簡單的牛仔外套,白色的衣身十分百搭。衣身多處有做舊破洞設(shè)計,打破單調(diào)乏味,增加一絲造型看點。衣身后背處有趣味刺繡裝飾,豐富層次感,彰顯別樣時尚。
- Output[微調(diào)前]: 這件上衣的材質(zhì)是牛仔布,顏色是白色,風(fēng)格是簡約,圖案是刺繡,衣樣式是外套,衣款式是破洞。
- Output[微調(diào)后]: 這是一款簡約的牛仔外套,破洞設(shè)計,將牛仔布破洞,帶來一種隨意與個性。破洞的牛仔外套,展現(xiàn)出時尚氣息,帶來一種休閑感。同時,刺繡圖案,讓整件外套更加立體。
示例2
- Input: 類型#裙*版型#顯瘦*風(fēng)格#文藝*風(fēng)格#簡約*圖案#印花*圖案#撞色*裙下擺#壓褶*裙長#連衣裙*裙領(lǐng)型#圓領(lǐng)
- Label: 文藝個性的印花連衣裙,藏青色底蘊,低調(diào)又大氣,撞色太陽花分布整個裙身,絢麗而美好,帶來時尚減齡的氣質(zhì)?;A(chǔ)款的舒適圓領(lǐng),簡約不失大方,勾勒精致臉龐。領(lǐng)后是一粒包布扣固定,穿脫十分方便。前片立體的打褶設(shè)計,搭配后片壓褶的做工,增添層次和空間感,顯瘦又有型。
- Output[微調(diào)前]: 類型#裙版型#顯瘦風(fēng)格#文藝風(fēng)格#簡約圖案#印花圖案#撞色裙下擺#壓褶裙長#連衣裙裙領(lǐng)型#圓領(lǐng) 1. 連衣裙:簡約風(fēng)格,裙長為膝蓋以上,裙領(lǐng)型為圓領(lǐng)。2. 裙下擺:壓褶設(shè)計,使裙擺呈現(xiàn)出流暢的褶皺效果。3. 裙領(lǐng)型:裙領(lǐng)型為圓領(lǐng),使穿上連衣裙后更加有型。4. 版型:采用顯瘦設(shè)計,讓連衣裙看起來更加苗條。5. 風(fēng)格:文藝風(fēng)格,讓連衣裙更加有內(nèi)涵和品味。6. 圖案:印花設(shè)計,在連衣裙上印有獨特的圖案。7. 撞色:采用撞色設(shè)計,讓連衣裙在色彩上更加鮮明、富有層次感。
- Output[微調(diào)后]: 這是一款文藝范的連衣裙,以印花為元素,采用簡約的印花,既能夠突出文藝氣質(zhì),又能夠展現(xiàn)簡約風(fēng)。在印花的同時又有領(lǐng)子和裙擺的壓褶設(shè)計,更加凸顯文藝氣質(zhì)。簡約而不會過于單調(diào),搭配出街,穿著十分舒適。
評估結(jié)果
| Finetune | P-tuning v2 | LoRA | |
|---|---|---|---|
| BLEU-4 | 8.01 | 8.10 | 7.62 |
| Rouge-1 | 31.23 | 31.12 | 30.60 |
| Rouge-2 | 7.36 | 7.11 | 6.96 |
| Rouge-l | 25.08 | 24.97 | 24.80 |
| Training Loss | 3.00 | 3.74 | 3.32 |
實驗設(shè)置
max_source_length=64 max_target_length=64 max_steps=3000
P-tuning v2
pre_seq_len=128 learning_rate=2e-2 quantization_bit=4 per_device_train_batch_size=16 gradient_accumulation_steps=1
Finetune
learning_rate=1e-4 fp16 num_gpus=4 per_device_train_batch_size=4 gradient_accumulation_steps=1
LoRA
實現(xiàn)采用的是 simple_thu_chatglm6b
learning_rate=5e-4 per_device_train_batch_size=16 gradient_accumulation_steps=1
模型部署
首先載入Tokenizer:
from transformers import AutoConfig, AutoModel, AutoTokenizer # 載入Tokenizer tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
- 如果需要加載的是新 Checkpoint(只包含 PrefixEncoder 參數(shù)):
config = AutoConfig.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True, pre_seq_len=128) model = AutoModel.from_pretrained("THUDM/chatglm-6b", config=config, trust_remote_code=True) prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin")) new_prefix_state_dict = {} for k, v in prefix_state_dict.items(): if k.startswith("transformer.prefix_encoder."): new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
注意你可能需要將 pre_seq_len 改成你訓(xùn)練時的實際值。如果你是從本地加載模型的話,需要將 THUDM/chatglm-6b 改成本地的模型路徑(注意不是checkpoint路徑)。
- 如果需要加載的是舊 Checkpoint(包含 ChatGLM-6B 以及 PrefixEncoder 參數(shù)),或者進(jìn)行的是全參數(shù)微調(diào),則直接加載整個 Checkpoint:
model = AutoModel.from_pretrained(CHECKPOINT_PATH, trust_remote_code=True)
之后根據(jù)需求可以進(jìn)行量化,也可以直接使用:
# Comment out the following line if you don’t use quantization model = model.quantize(4) model = model.half().cuda() model.transformer.prefix_encoder.float() model = model.eval() response, history = model.chat(tokenizer, "你好", history=[])
[23/04/19] 你也可以直接運行支持加載 P-Tuning v2 checkpoint 的 web demo
bash web_demo.sh
可能需要修改 web_demo.sh 的內(nèi)容以符合你實際的 checkpoint 情況。
使用自己的數(shù)據(jù)集
修改 train.sh 和 evaluate.sh 中的 train_file、validation_file和test_file為你自己的 JSON 格式數(shù)據(jù)集路徑,并將 prompt_column 和 response_column 改為 JSON 文件中輸入文本和輸出文本對應(yīng)的 KEY??赡苓€需要增大 max_source_length 和 max_target_length 來匹配你自己的數(shù)據(jù)集中的最大輸入輸出長度。
對話數(shù)據(jù)集
如需要使用多輪對話數(shù)據(jù)對模型進(jìn)行微調(diào),可以提供聊天歷史,例如以下是一個三輪對話的訓(xùn)練數(shù)據(jù):
{"prompt": "長城h3風(fēng)扇不轉(zhuǎn)。繼電器好的。保險絲好的傳感器新的風(fēng)扇也新的這是為什么。就是繼電器缺一個信號線", "response": "用電腦能讀數(shù)據(jù)流嗎?水溫多少", "history": []}
{"prompt": "95", "response": "上下水管溫差怎么樣啊?空氣是不是都排干凈了呢?", "history": [["長城h3風(fēng)扇不轉(zhuǎn)。繼電器好的。保險絲好的傳感器新的風(fēng)扇也新的這是為什么。就是繼電器缺一個信號線", "用電腦能讀數(shù)據(jù)流嗎?水溫多少"]]}
{"prompt": "是的。上下水管都好的", "response": "那就要檢查線路了,一般風(fēng)扇繼電器是由電腦控制吸合的,如果電路存在斷路,或者電腦壞了的話會出現(xiàn)繼電器不吸合的情況!", "history": [["長城h3風(fēng)扇不轉(zhuǎn)。繼電器好的。保險絲好的傳感器新的風(fēng)扇也新的這是為什么。就是繼電器缺一個信號線", "用電腦能讀數(shù)據(jù)流嗎?水溫多少"], ["95", "上下水管溫差怎么樣?。靠諝馐遣皇嵌寂鸥蓛袅四??"]]}
訓(xùn)練時需要指定 --history_column 為數(shù)據(jù)中聊天歷史的 key(在此例子中是 history),將自動把聊天歷史拼接。要注意超過輸入長度 max_source_length 的內(nèi)容會被截斷。
可以參考以下指令:
bash train_chat.sh
引用
@inproceedings{liu2022p,
title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
pages={61--68},
year={2022}
}
馬上咨詢: 如果您有業(yè)務(wù)方面的問題或者需求,歡迎您咨詢!我們帶來的不僅僅是技術(shù),還有行業(yè)經(jīng)驗積累。
QQ: 39764417/308460098 Phone: 13 9800 1 9844 / 135 6887 9550 聯(lián)系人:石先生/雷先生