在 dair-ai/emotion 上微调 Gemma 4 E4B-it
本笔记本基于 dair-ai/emotion 对 google/gemma-4-E4B-it 进行微调。
文章来源:如何微调 Gemma 4:基于人类情绪数据集的完整实战指南
Colab 来源:在 dair-ai/emotion 上微调 Gemma 4 E4B-it
1. 安装依赖
python
%%capture
!pip install -U transformers accelerate datasets trl peft bitsandbytes scikit-learn huggingface_hub2. 使用 HF_TOKEN 登录
python
from huggingface_hub import login
from google.colab import userdata
try:
hf_token = userdata.get('HF_TOKEN')
login(token=hf_token)
print("Logged in to Hugging Face successfully.")
except userdata.SecretNotFoundError:
raise ValueError("HF_TOKEN not found in Colab secrets. Please add it via the 🔑 panel.")
except Exception as e:
print(f"An error occurred during login: {e}")3. 加载数据集
python
from datasets import load_dataset, DatasetDict
TRAIN_LIMIT = 4000
VALIDATION_LIMIT = 400
TEST_LIMIT = 400
EVAL_LIMIT = 400
raw_dataset = load_dataset("dair-ai/emotion")
def maybe_limit(split, limit):
split = split.shuffle(seed=42)
if limit is None:
return split
return split.select(range(min(limit, len(split))))
dataset = DatasetDict({
"train": maybe_limit(raw_dataset["train"], TRAIN_LIMIT),
"validation": maybe_limit(raw_dataset["validation"], VALIDATION_LIMIT),
"test": maybe_limit(raw_dataset["test"], TEST_LIMIT),
})
datasetpython
label_names = dataset["train"].features["label"].names
label_namespython
dataset["train"][0]4. 创建系统提示词
python
SYSTEM_PROMPT = """You are an emotion classification assistant.
Read the user's text and answer with exactly one label.
Only choose from: sadness, joy, love, anger, fear, surprise.
Return only the label and nothing else."""5. 将数据集转换为 TRL 的 prompt-completion 对话格式
python
def to_prompt_completion(example):
text = example["text"]
label = label_names[example["label"]]
return {
"prompt": [
{
"role": "system",
"content": SYSTEM_PROMPT,
},
{
"role": "user",
"content": f"Classify the emotion of this text:\n\n{text}",
},
],
"completion": [
{
"role": "assistant",
"content": label,
}
],
}
sft_dataset = dataset.map(to_prompt_completion, remove_columns=dataset["train"].column_names)python
sft_dataset["train"][0]6. 加载处理器和基础模型,用于微调前评估
python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
MODEL_ID = "google/gemma-4-E4B-it"
MODEL_DTYPE = torch.bfloat16
USE_4BIT = True
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
processor = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if processor.pad_token is None:
processor.pad_token = processor.eos_token
bnb_config = None
model_kwargs = {
"device_map": "auto",
}
if USE_4BIT:
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=MODEL_DTYPE,
)
model_kwargs["quantization_config"] = bnb_config
else:
model_kwargs["torch_dtype"] = MODEL_DTYPE
base_model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_kwargs)
base_model.config.use_cache = False
base_model.config.pad_token_id = processor.pad_token_id
base_model.config.bos_token_id = processor.bos_token_id
base_model.config.eos_token_id = processor.eos_token_id
base_model.generation_config.pad_token_id = processor.pad_token_id
base_model.generation_config.bos_token_id = processor.bos_token_id
base_model.generation_config.eos_token_id = processor.eos_token_id
print(f"Base model loaded with 4-bit={USE_4BIT} and dtype={MODEL_DTYPE}.")7. 用于 Gemma 4 对话格式的推理辅助函数
python
import re
LABEL_PATTERN = re.compile(r"\b(sadness|joy|love|anger|fear|surprise)\b", re.IGNORECASE)
def extract_label(raw_text: str) -> str:
raw_text = raw_text.strip().lower()
match = LABEL_PATTERN.search(raw_text)
if match:
return match.group(1)
first_token = raw_text.split()[0].strip(".,!?:;\"'()[]{}") if raw_text.split() else ""
return first_token
def generate_label(model, processor, user_text, system_prompt, max_new_tokens=4):
messages = [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": f"Classify the emotion of this text:\n\n{user_text}",
},
]
device = next(model.parameters()).device
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
).to(device)
input_len = inputs["input_ids"].shape[-1]
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=False,
pad_token_id=processor.pad_token_id,
eos_token_id=processor.eos_token_id,
)
raw_pred = processor.decode(outputs[0][input_len:], skip_special_tokens=True).strip()
return extract_label(raw_pred)python
def predict_emotion(user_text: str, model=None, proc=None) -> str:
model = model or base_model
proc = proc or processor
return generate_label(model, proc, user_text, SYSTEM_PROMPT)
predict_emotion("I feel so happy and excited today!")8. 评估辅助函数
python
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
import pandas as pd
from tqdm.auto import tqdm
VALID_LABELS = set(label_names)
ALL_EVAL_LABELS = label_names + ["INVALID"]
def evaluate_model(model, processor, split="test", limit=EVAL_LIMIT):
y_true, y_pred, rows = [], [], []
raw_source = dataset[split]
if limit is not None:
raw_source = raw_source.select(range(min(limit, len(raw_source))))
model.eval()
for ex in tqdm(raw_source, desc=f"Evaluating {split}", leave=False):
true_label = label_names[ex["label"]]
raw_pred_label = generate_label(model, processor, ex["text"], SYSTEM_PROMPT)
pred_label = raw_pred_label if raw_pred_label in VALID_LABELS else "INVALID"
y_true.append(true_label)
y_pred.append(pred_label)
rows.append({
"text": ex["text"],
"true_label": true_label,
"pred_label": pred_label,
"raw_pred_label": raw_pred_label,
"correct": true_label == pred_label,
})
metrics = {
"accuracy": accuracy_score(y_true, y_pred),
"macro_f1": f1_score(y_true, y_pred, labels=label_names, average="macro", zero_division=0),
"invalid_predictions": sum(1 for p in y_pred if p == "INVALID"),
"evaluated_examples": len(y_true),
}
report = classification_report(
y_true,
y_pred,
labels=label_names,
output_dict=True,
zero_division=0,
)
df = pd.DataFrame(rows)
return metrics, report, df
def confusion_matrix_df(pred_df):
return pd.DataFrame(
confusion_matrix(pred_df["true_label"], pred_df["pred_label"], labels=ALL_EVAL_LABELS),
index=ALL_EVAL_LABELS,
columns=ALL_EVAL_LABELS,
)9. 微调前评估
python
pre_metrics, pre_report, pre_preds = evaluate_model(base_model, processor, "test")
pre_metricspython
pd.DataFrame(pre_report).transpose()python
confusion_matrix_df(pre_preds)10. 配置 LoRA 适配器
python
from peft import LoraConfig
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
target_modules="all-linear"
)11. 定义训练配置
python
from trl import SFTConfig, SFTTrainer
training_args = SFTConfig(
output_dir="./gemma4-emotion-lora",
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
gradient_accumulation_steps=2,
learning_rate=1e-4,
weight_decay=0.01,
lr_scheduler_type="linear",
warmup_steps=50,
num_train_epochs=1,
logging_steps=50,
eval_strategy="steps",
metric_for_best_model="eval_loss",
greater_is_better=False,
gradient_checkpointing=True,
bf16=True,
fp16=False,
tf32=False,
max_length=256,
packing=False,
completion_only_loss=True,
remove_unused_columns=False,
dataloader_num_workers=2,
optim="paged_adamw_8bit",
report_to="none",
)12. 训练模型
python
from peft import PeftModel
if isinstance(base_model, PeftModel):
base_model = base_model.unload()
base_model.config.use_cache = False
trainer = SFTTrainer(
model=base_model,
train_dataset=sft_dataset["train"],
eval_dataset=sft_dataset["validation"],
peft_config=lora_config,
args=training_args,
processing_class=processor,
)
trainable_params = 0
for param in trainer.model.parameters():
if param.requires_grad:
trainable_params += param.numel()
if trainable_params == 0:
raise RuntimeError("No trainable LoRA parameters were attached. Check target_modules before training.")
print(f"Trainable LoRA parameters: {trainable_params:,}")
train_result = trainer.train()
trainer.model.eval()
trainer.model.config.use_cache = True
train_result13. 保存适配器和处理器
python
trainer.model.save_pretrained("./gemma4-emotion-lora")
processor.save_pretrained("./gemma4-emotion-lora")
print("Saved adapter and processor to ./gemma4-emotion-lora")python
# Push adapter + processor to the Hub
repo_id = "kingabzpro/gemma4-emotion-lora"
trainer.model.push_to_hub(
repo_id,
private=False,
)
processor.push_to_hub(
repo_id,
private=False,
)
print(f"Pushed to https://huggingface.co/{repo_id}")14. 无需重新加载模型即可运行微调后评估
python
# Reuse the in-memory fine-tuned model to avoid a second base-model load.
# On smaller GPUs, reloading after training often causes fragmentation or OOMs.
ft_model = trainer.model
ft_model.eval()
ft_model.config.use_cache = True
post_metrics, post_report, post_preds = evaluate_model(ft_model, processor, "test")
post_metricspython
pd.DataFrame(post_report).transpose()python
confusion_matrix_df(post_preds)15. 对比微调前后效果
python
comparison_df = pd.DataFrame([
{"stage": "pre_finetuning", **pre_metrics},
{"stage": "post_finetuning", **post_metrics},
])
comparison_dfpython
merged_examples = pre_preds.copy()
merged_examples = merged_examples.rename(columns={"pred_label": "pre_pred", "correct": "pre_correct"})
merged_examples["post_pred"] = post_preds["pred_label"]
merged_examples["post_correct"] = post_preds["correct"]
changed_predictions = merged_examples[merged_examples["pre_pred"] != merged_examples["post_pred"]]
changed_predictions.head(20)16. 使用内存中的微调后模型进行预测
python
def predict_emotion_ft(user_text: str) -> str:
return generate_label(ft_model, processor, user_text, SYSTEM_PROMPT)
predict_emotion_ft("I feel completely heartbroken and alone.")python
predict_emotion_ft("This is the best day of my life!")17. 可选:保存对比结果
python
comparison_df.to_csv("gemma4_emotion_before_after_metrics.csv", index=False)
merged_examples.to_csv("gemma4_emotion_prediction_examples.csv", index=False)
print("Saved CSV outputs.")python