Commit b7fb2b1

[Open Source Internship] BARTpho model fine-tuning (#2030)
1 parent 67ac64b commit b7fb2b1

File tree

3 files changed: +512 -0 lines changed

+27
@@ -0,0 +1,27 @@
# finetune bartpho

## dataset resources

HuyPham235711/BARTpho_Corrector

Text correction task
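For a quick look at the data, here is a minimal sketch of loading the corpus directly from the Hugging Face Hub. It assumes the dataset exposes the `error`/`original` text columns that the training script below expects; the script itself reads local `train.csv`/`test.csv` exports instead:

```python
from datasets import load_dataset

# Load the correction pairs from the Hub (the training script below uses
# local CSV exports of the same data instead).
dataset = load_dataset("HuyPham235711/BARTpho_Corrector")
print(dataset)
print(dataset["train"][0])  # expected: a noisy "error" sentence and its clean "original"
```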
## mindnlp+ascend

| Epoch | train_loss | eval_loss | BLEU    |
| :---: | ---------- | --------- | ------- |
| 1     | 0.4209     | 0.3085    | 86.5538 |
| 2     | 0.3052     | 0.2757    | 87.7926 |
| 3     | 0.2307     | 0.2531    | 88.8163 |
| 4     | 0.1864     | 0.2303    | 89.1350 |
| 5     | 0.1535     | 0.2048    | 90.4295 |

## pytorch+cuda

| Epoch | train_loss | eval_loss | BLEU    |
| :---: | ---------- | --------- | ------- |
| 1     | 0.9188     | 0.3580    | 86.3481 |
| 2     | 0.4843     | 0.3022    | 86.9550 |
| 3     | 0.3602     | 0.2728    | 87.5816 |
| 4     | 0.2866     | 0.2623    | 88.8584 |
| 5     | 0.2312     | 0.2659    | 89.8917 |
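After training, the saved weights can be used for correction directly. A minimal inference sketch, assuming the `./saved_model_weights` directory produced by the training script below (the input sentence is a made-up misspelled example):

```python
from mindnlp.transformers import AutoModelForSeq2SeqLM, BartphoTokenizer

# The tokenizer is unchanged by fine-tuning, so load it from the base checkpoint;
# the model weights come from the directory written by save_pretrained().
tokenizer = BartphoTokenizer.from_pretrained("vinai/bartpho-syllable")
model = AutoModelForSeq2SeqLM.from_pretrained("./saved_model_weights")

inputs = tokenizer("Toi dang hoc may hoc", return_tensors="ms")  # hypothetical noisy input
output_ids = model.generate(inputs["input_ids"], max_length=32)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```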
@@ -0,0 +1,178 @@
import os
import numpy as np
import mindspore as ms
from mindspore import Tensor
from mindspore.dataset import GeneratorDataset
from mindnlp.transformers import AutoModelForSeq2SeqLM, BartphoTokenizer
from mindnlp.engine import Trainer, TrainingArguments, TrainerCallback
from datasets import load_dataset

import evaluate

# Load the evaluation metric
sacrebleu_metric = evaluate.load("sacrebleu")

# Model and data paths
MODEL_NAME = "vinai/bartpho-syllable"
MAX_LENGTH = 32  # maximum sequence length
output_dir = './saved_model_weights'  # where to save the fine-tuned weights
os.makedirs(output_dir, exist_ok=True)


# Custom training callback that prints the loss at the end of each epoch
class LossLoggerCallback(TrainerCallback):
    def on_epoch_end(self, args, state, control, **kwargs):
        """Called at the end of every epoch."""
        if not state.log_history:
            return

        epoch = state.epoch
        loss = state.log_history[-1].get('loss', 0.0)

        # Print the training loss for the current epoch
        print(f"Epoch {epoch}: train_loss = {loss:.6f}")

        # If evaluation results are available, print them too
        if 'eval_loss' in state.log_history[-1]:
            eval_loss = state.log_history[-1].get('eval_loss', 0.0)
            eval_metric = state.log_history[-1].get('eval_sacrebleu', 0.0)
            print(f"Epoch {epoch}: eval_loss = {eval_loss:.6f}, eval_sacrebleu = {eval_metric:.4f}")


# Data preprocessing: tokenize the noisy input and the clean target text
def preprocess_function(examples):
    return tokenizer(
        examples["error"],
        text_target=examples["original"],
        max_length=MAX_LENGTH,
        truncation=True,
        padding="max_length"
    )


# Compute evaluation metrics
def compute_metrics(eval_preds):
    preds, labels = eval_preds

    # If the model returned a tuple, take the first element (the prediction ids)
    if isinstance(preds, tuple):
        preds = preds[0]

    # Decode predictions and labels
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100 in the labels with the pad token id before decoding
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Light post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]  # sacrebleu expects a list of reference lists

    # Compute the BLEU score
    result = sacrebleu_metric.compute(
        predictions=decoded_preds,
        references=decoded_labels
    )

    return {
        "sacrebleu": round(result["score"], 4)
    }


# Wrap the tokenized data in a MindSpore dataset
def create_mindspore_dataset(data, batch_size=8):
    data_list = list(data)

    def generator():
        for item in data_list:
            yield (
                Tensor(item["input_ids"], dtype=ms.int32),
                Tensor(item["attention_mask"], dtype=ms.int32),
                Tensor(item["labels"], dtype=ms.int32)
            )

    return GeneratorDataset(
        generator,
        column_names=["input_ids", "attention_mask", "labels"]
    ).batch(batch_size)


# Reduce logits to token ids before metric computation
def preprocess_logits_for_metrics(logits, labels):
    """Keep only the argmax token ids to avoid running out of memory."""
    pred_ids = ms.mint.argmax(logits[0], dim=-1)
    return pred_ids, labels


# Main entry point
def main():
    global tokenizer  # make the tokenizer visible to the helper functions above

    # Load the model and tokenizer
    print("Loading model and tokenizer...")
    tokenizer = BartphoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

    # Load the dataset
    print("Loading dataset...")
    train_path = './train.csv'
    test_path = './test.csv'
    dataset = load_dataset("csv", data_files={"train": train_path, "test": test_path})

    print(f"Training samples: {len(dataset['train'])}")
    print(f"Test samples: {len(dataset['test'])}")

    # Preprocess the data
    print("Preprocessing data...")
    tokenized_datasets = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset["train"].column_names,
    )

    # Build the MindSpore datasets
    print("Creating MindSpore datasets...")
    train_dataset = create_mindspore_dataset(tokenized_datasets["train"], batch_size=8)
    eval_dataset = create_mindspore_dataset(tokenized_datasets["test"], batch_size=8)

    # Training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="epoch",
        learning_rate=1e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=5,
        weight_decay=0.01,
        save_strategy="epoch",
        save_total_limit=2,
    )

    # Initialize the trainer
    print("Initializing trainer...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        callbacks=[LossLoggerCallback()]
    )

    # Train
    print("Starting training...")
    trainer.train()

    # Save the model
    print(f"Training finished, saving model to {output_dir}...")
    model.save_pretrained(output_dir)

    # Final evaluation
    print("Running final evaluation...")
    eval_results = trainer.evaluate()
    print(f"Final evaluation results: {eval_results}")


if __name__ == "__main__":
    main()
