[AI] An AI Model for Summarizing Chinese Articles
Introduction
While browsing GitHub for AI models that summarize articles, I came across this one:
```bibtex
@misc{alpaca,
  author = {Ziang Leng and Qiyuan Chen and Cheng Li},
  title = {Luotuo: An Instruction-following Chinese Language model, LoRA tuning on LLaMA},
  year = {2023},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/LC1332/Luotuo-Chinese-LLM}},
}
```
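The project is two years old (2023), but when I ran it the article summaries were still quite good.
The steps I used to test it are recorded below.
For more details, see the original project: https://github.com/LC1332/Luotuo-Silk-Road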
Install the libraries
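```
# Run in a notebook (Jupyter/Colab); "!" executes the line in a shell.
!git clone https://github.com/LC1332/luotuo-silk-road.git ./luotuo_silk_road
!wget https://github.com/LC1332/Luotuo-Chinese-LLM/raw/main/notebook/utils.py
# No "cd" here: "!cd" runs in a throwaway subshell and has no effect, and the code
# below imports utils.py and luotuo_silk_road.* from the current directory anyway.
!pip install bitsandbytes transformers==4.27.1 peft==0.4.0 sentencepiece cpm_kernels mdtex2html protobuf torch
```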
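Note: the project's own requirements file targets the peft version that was current back then, which is now fairly old. For that reason I pinned peft to 0.4.0 here, since newer peft releases may not work with this code.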
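Everything below depends on these exact versions, so it can be worth checking that the environment actually picked up the pins before you load the model. The sketch below uses only the standard library's `importlib.metadata`. `check_pins` and `PINS` are hypothetical names of my own and are not part of the Luotuo project. The pin list covers only the two packages pinned in the pip command above.

```python
# Hypothetical helper: compare installed package versions with the pins used
# in the pip command above (transformers==4.27.1, peft==0.4.0).
from importlib.metadata import PackageNotFoundError, version

PINS = {"transformers": "4.27.1", "peft": "0.4.0"}


def check_pins(pins, get_version=version):
    """Return {package: (expected, installed)} for every pin that does not match.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = get_version(name)
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    problems = check_pins(PINS)
    if problems:
        for name, (expected, installed) in problems.items():
            print(f"{name}: expected {expected}, found {installed}")
    else:
        print("All pinned versions match.")
```

Passing `get_version` as a parameter makes the check easy to test with a fake lookup instead of the real environment.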
Load the model
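```python
import torch
from transformers import AutoModel, AutoTokenizer

from utils import DeviceMap  # helper from the utils.py downloaded above

# Create new floating-point tensors as FP16 on the GPU, so the model weights
# should be created in half precision.
torch.set_default_tensor_type(torch.cuda.HalfTensor)

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b",
    trust_remote_code=True,
    device_map=DeviceMap("ChatGLM").get(),  # device map provided by utils.py
)
```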
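The `torch.set_default_tensor_type(torch.cuda.HalfTensor)` call matters mostly for memory. It presumably makes the weights be created as FP16 (2 bytes per parameter) instead of FP32 (4 bytes). The back-of-the-envelope sketch below makes two simplifications:
- It assumes about 6.2 billion parameters for ChatGLM-6B, the figure given in the THUDM README.
- It counts the weights only. Activations, the KV cache and CUDA overhead come on top, which is why the README quotes about 13 GB for FP16 inference.

`weight_memory_gib` is a hypothetical helper.

```python
# Back-of-the-envelope GPU memory for the model weights alone.
# Assumption: ChatGLM-6B has about 6.2e9 parameters (THUDM README figure).
# Activations, the KV cache and CUDA overhead are NOT included.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2}


def weight_memory_gib(n_params, dtype):
    """Memory taken by `n_params` weights stored as `dtype`, in GiB (2**30 bytes)."""
    return n_params * BYTES_PER_PARAM[dtype] / 2**30


if __name__ == "__main__":
    n_params = 6.2e9
    for dtype in ("fp32", "fp16"):
        print(f"{dtype}: {weight_memory_gib(n_params, dtype):.1f} GiB")
```

This prints about 23.1 GiB for FP32 and 11.5 GiB for FP16. Halving the weight size is what lets the model fit on a single consumer GPU.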
Get the PEFT model
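```python
from peft import LoraConfig, TaskType, get_peft_model

peft_path = "./luotuo_silk_road/TuoLing/output/luotuoC.pt"  # LoRA weights from the cloned repo
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=True,  # inference only, no training
    r=8,                  # LoRA rank
    lora_alpha=32,        # scaling is lora_alpha / r = 4
    lora_dropout=0.1,     # only active in training mode
)
model = get_peft_model(model, peft_config)
# The checkpoint presumably holds only the LoRA weights, so strict=False
# ignores the base-model keys it does not contain.
model.load_state_dict(torch.load(peft_path), strict=False)
model.eval()  # make sure the newly added LoRA dropout layers are disabled
# Restore FP32 as the default for tensors created from here on.
torch.set_default_tensor_type(torch.cuda.FloatTensor)
```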
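For reference, `r` and `lora_alpha` in `LoraConfig` enter the LoRA update through the scaling factor `lora_alpha / r`, which is 32 / 8 = 4 here. Each adapted weight behaves like W + (lora_alpha / r) · B · A, where A and B are the small trained matrices, presumably what `luotuoC.pt` stores. The pure-Python sketch below illustrates that formula on toy matrices. It is not how peft runs internally: peft normally adds the low-rank branch in the layer's forward pass rather than building the merged weight. All function names are hypothetical.

```python
# Toy illustration of the LoRA rule configured by LoraConfig:
#   W_eff = W + (lora_alpha / r) * (B @ A)
# W is d x k (frozen), B is d x r, A is r x k (the small trained matrices).
# peft normally adds the low-rank branch in forward() instead of building W_eff.


def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


def lora_effective_weight(W, A, B, lora_alpha, r):
    """Return W + (lora_alpha / r) * (B @ A)."""
    scaling = lora_alpha / r
    delta = matmul(B, A)
    return [
        [w + scaling * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(W, delta)
    ]


def lora_param_count(d, k, r):
    """Extra parameters LoRA stores for one d x k weight: r * (d + k)."""
    return r * (d + k)


if __name__ == "__main__":
    # Toy rank r=2 with lora_alpha=8 gives the same scaling (4) as r=8, lora_alpha=32.
    W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]   # 2 x 3 frozen weight
    B = [[1.0, 0.0], [0.0, 1.0]]             # 2 x r
    A = [[0.5, 0.0, 0.0], [0.0, 0.0, 0.25]]  # r x 3
    print(lora_effective_weight(W, A, B, lora_alpha=8, r=2))
    # For a 4096 x 4096 weight at r=8: 65,536 LoRA parameters vs 16,777,216 in W.
    print(lora_param_count(4096, 4096, 8), 4096 * 4096)
```

The parameter count shows why the adapter file stays small: r * (d + k) grows linearly in the matrix dimensions, while the full weight grows with d * k.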
Test the model
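```python
from luotuo_silk_road.TuoLing.cover_alpaca2jsonl import format_example

# Instruction sent to the model (kept in Chinese, since the model is tuned on Chinese):
# "Please summarize the following content in 20 characters or fewer:"
INSTRUCTION = "请用20个字以内帮我总结以下内容:"


def evaluate(text):
    """Summarize `text` with the LoRA-tuned ChatGLM model and print the result."""
    with torch.no_grad():
        # Build the prompt in the instruction/input/output format used for fine-tuning.
        feature = format_example({"instruction": INSTRUCTION, "output": "", "input": text})
        input_ids = tokenizer.encode(feature["context"], return_tensors="pt").to("cuda")
        # Greedy decoding (do_sample=False) gives deterministic output, which is what
        # temperature=0 was aiming for; max_length counts prompt + generated tokens.
        out = model.generate(input_ids=input_ids, max_length=2048, do_sample=False)
        # Decode only the newly generated tokens, i.e. the summary.
        answer = tokenizer.decode(out[0][input_ids.shape[-1]:])
    print(answer)


evaluate(input("Enter the article to summarize (length under 2048): "))
```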
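`max_length=2048` in `model.generate` caps the prompt tokens and the generated tokens together. An article close to the limit therefore leaves little or no room for the summary. One way to guard against this is to truncate the article's tokens before building the prompt. The sketch below is a hypothetical helper, `truncate_article_ids`, with two placeholder values you would measure or choose yourself:
- `prompt_overhead` is the number of tokens taken by the instruction template and special tokens.
- `reserve` is the number of tokens kept for the summary.

```python
# Hypothetical helper: make sure the article leaves room for the summary when
# model.generate(..., max_length=2048) counts prompt + generated tokens.


def truncate_article_ids(article_ids, prompt_overhead, max_length=2048, reserve=64):
    """Keep at most max_length - prompt_overhead - reserve article tokens.

    prompt_overhead: tokens used by the instruction template and special tokens.
    reserve: tokens left free for the generated summary.
    Tokens are dropped from the end of the article.
    """
    budget = max_length - prompt_overhead - reserve
    if budget <= 0:
        raise ValueError("prompt_overhead + reserve must be smaller than max_length")
    return article_ids[:budget]


if __name__ == "__main__":
    fake_ids = list(range(3000))  # stand-in for a long tokenized article
    kept = truncate_article_ids(fake_ids, prompt_overhead=40)
    print(len(kept))  # 2048 - 40 - 64 = 1944
```

In the notebook, the article ids could come from `tokenizer.encode(text, add_special_tokens=False)`. The truncated ids could then be turned back into text with `tokenizer.decode(...)` before calling `evaluate`. I have not verified this round trip with the ChatGLM tokenizer, so treat it as an assumption. The truncation works on the article rather than on the final prompt ids so that the instruction template and the special tokens at the end of the prompt are never cut off.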