Commit d2210d4

arkin-dev committed:
notebook(notebook): add RLHF Demo (SFT/DPO, data-loading and model-loading cells)
1 parent a4c990a commit d2210d4

File tree

1 file changed: +332, -0 lines changed

RLHF-demo/RLHF-Demo.ipynb

Lines changed: 332 additions & 0 deletions
@@ -0,0 +1,332 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5483e876",
   "metadata": {},
   "source": [
    "### What is RLHF (Reinforcement Learning from Human Feedback)?\n",
    "\n",
    "RLHF is a training paradigm for aligning large language models with human preferences: first teach the model to do the task, then teach it what “better” means, and finally use reinforcement learning to optimize that notion of “better” into the generation policy.\n",
    "\n",
    "- **Goal**: make the model follow human intent more closely and be safer and more useful\n",
    "- **Core ideas**:\n",
    "  - Supervised fine-tuning (SFT) teaches basic instruction following\n",
    "  - A reward model (RM) trained on preference data learns to score “better vs. worse” answers\n",
    "  - Reinforcement learning (PPO) optimizes the policy against the reward signal, trading off quality, stability, and diversity\n",
    "- **Key components**: instruction data, preference data (A/B comparisons), reward model, RL algorithm, KL constraint / reference policy\n",
    "- **Typical artifacts**:\n",
    "  - the SFT model (does the task)\n",
    "  - the RM reward model (scores answers)\n",
    "  - the PPO-aligned model (does it better)\n",
    "  - a DPO-aligned model (DPO replaces the RM + PPO stages)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59ff10f6",
   "metadata": {},
   "source": [
    "### The Three-Stage RLHF Pipeline (an Engineering View)\n",
    "\n",
    "| Stage | Name | Purpose | Technique |\n",
    "|---|---|---|---|\n",
    "| 1️⃣ | SFT (supervised fine-tuning) | Teach the model to follow instructions | CrossEntropyLoss |\n",
    "| 2️⃣ | Reward model (RM) training | Learn which answer is better | Pairwise ranking (A > B) |\n",
    "| 3️⃣ | PPO optimization | Optimize the generation policy with the reward signal | PPO (policy gradient) |\n",
    "\n",
    "#### 1️⃣ SFT (supervised fine-tuning)\n",
    "- **Input**: instruction–response pairs (high quality, human written or curated)\n",
    "- **Goal**: teach the model to answer according to the instruction\n",
    "- **Training**: minimize the cross-entropy loss (see the commonly used instruction datasets)\n",
    "- **Output**: the SFT model (also the reference policy for the later RM/PPO stages)\n",
    "\n",
    "#### 2️⃣ Reward model (RM) training\n",
    "- **Input**: paired responses (A, B) to the same instruction, plus a preference label (A > B)\n",
    "- **Goal**: learn a preference scoring function r(x, y)\n",
    "- **Training**: pairwise ranking (e.g. Bradley–Terry / logistic loss; sketched in code below)\n",
    "- **Output**: a reward model that can score any response\n",
    "\n",
    "#### 3️⃣ PPO optimization\n",
    "- **Input**: the SFT model as the initial policy π_θ, the reward model r as the reward signal\n",
    "- **Goal**: maximize expected reward under a KL constraint, improving alignment and helpfulness\n",
    "- **Training**: PPO (clipped policy gradient) with a KL penalty that keeps the policy close to the reference policy\n",
    "- **Output**: the PPO-aligned model (better matches human preferences)\n",
    "\n",
    "> In practice: high-quality preference data and stable KL control are the keys to success; monitor length bias, mode collapse, and overfitting.\n",
    "\n",
    "#### DPO (Direct Preference Optimization)\n",
    "- **Positioning**: a common alternative to stage 3 (PPO) that optimizes the policy directly from preference pairs.\n",
    "- **Core**: given `(x, y_pos, y_neg)`, raise the probability of `y_pos` and lower that of `y_neg`, using the log-probability gap relative to the reference policy `π_ref` as an implicit KL constraint.\n",
    "- **Objective (intuitively)**: minimize `-log σ(β[(log πθ(y_pos|x) - log πθ(y_neg|x)) - (log πref(y_pos|x) - log πref(y_neg|x))])`; a minimal code sketch of this loss (and of the pairwise RM loss) follows right after this cell.\n",
    "- **Pros**: a simple pipeline with no reward model or RL loop, stable, easy to reproduce, high throughput.\n",
    "- **Cons**: depends on high-quality preference data; less controllable under extreme distribution shift.\n"
   ]
  },
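  {
   "cell_type": "markdown",
   "id": "c3a91b20",
   "metadata": {},
   "source": [
    "The next cell is a minimal sketch of the two preference losses described above, written against toy log-probability tensors so it runs without any model: the pairwise RM loss `-log σ(r_pos - r_neg)` and the DPO loss. The tensor values are illustrative placeholders, not real model outputs.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7f4e8a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def rm_pairwise_loss(r_pos, r_neg):\n",
    "    # Bradley–Terry / logistic pairwise ranking loss: -log σ(r(x, y_pos) - r(x, y_neg))\n",
    "    return -F.logsigmoid(r_pos - r_neg).mean()\n",
    "\n",
    "def dpo_loss(policy_logp_pos, policy_logp_neg, ref_logp_pos, ref_logp_neg, beta=0.1):\n",
    "    # DPO: -log σ(β[(log πθ(y_pos|x) - log πθ(y_neg|x)) - (log πref(y_pos|x) - log πref(y_neg|x))])\n",
    "    policy_margin = policy_logp_pos - policy_logp_neg\n",
    "    ref_margin = ref_logp_pos - ref_logp_neg\n",
    "    return -F.logsigmoid(beta * (policy_margin - ref_margin)).mean()\n",
    "\n",
    "# Toy sequence-level log-probabilities of the chosen (pos) and rejected (neg)\n",
    "# responses under the policy and the reference model (placeholder values).\n",
    "policy_logp_pos = torch.tensor([-12.3, -20.1])\n",
    "policy_logp_neg = torch.tensor([-14.0, -19.5])\n",
    "ref_logp_pos = torch.tensor([-13.0, -20.0])\n",
    "ref_logp_neg = torch.tensor([-13.5, -19.8])\n",
    "\n",
    "print(\"RM pairwise loss:\", rm_pairwise_loss(torch.tensor([1.2, 0.3]), torch.tensor([0.4, 0.7])).item())\n",
    "print(\"DPO loss:\", dpo_loss(policy_logp_pos, policy_logp_neg, ref_logp_pos, ref_logp_neg).item())\n"
   ]
  },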
  {
   "cell_type": "markdown",
   "id": "777789fd",
   "metadata": {},
   "source": [
    "### Experimental Setup: Model and Dataset Choices\n",
    "\n",
    "- Model: `Qwen2.5-1.5B-Instruct` (strong Chinese instruction following, small enough for easy LoRA/QLoRA fine-tuning)\n",
    "- SFT data: `BelleGroup/train_0.5M_CN` (Chinese instruction–response pairs, moderate size, easy to subsample)\n",
    "- Preference data (for DPO/RM): `argilla/ultrafeedback-binarized-preferences` (preference pairs, directly usable for DPO)\n",
    "\n",
    "Below we install the dependencies, load the model, and load a small sample of the SFT data (just enough for a quick end-to-end run).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8be2dae1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "zsh:1: 4.44.0 not found\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "# Install dependencies (first run only).\n",
    "# The version specifiers are quoted so the shell does not treat `>=` as a redirection\n",
    "# (the unquoted form produced the `zsh: 4.44.0 not found` error shown above).\n",
    "%pip -q install \"transformers>=4.44.0\" accelerate datasets peft bitsandbytes \"trl>=0.9.6\" sentencepiece\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a054388",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Info] CUDA not available; skipping bitsandbytes quantization, falling back to MPS/CPU.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import torch\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
    "\n",
    "model_name = \"Qwen/Qwen2.5-1.5B-Instruct\"\n",
    "\n",
    "use_cuda = torch.cuda.is_available()\n",
    "use_mps = torch.backends.mps.is_available()\n",
    "\n",
    "quant_config = None\n",
    "try:\n",
    "    if use_cuda:\n",
    "        from transformers import BitsAndBytesConfig  # only attempt 4-bit on CUDA\n",
    "        import importlib.metadata as im\n",
    "        im.version(\"bitsandbytes\")  # check that bitsandbytes is installed\n",
    "        quant_config = BitsAndBytesConfig(\n",
    "            load_in_4bit=True,\n",
    "            bnb_4bit_quant_type=\"nf4\",\n",
    "            bnb_4bit_use_double_quant=True,\n",
    "            bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "        )\n",
    "        print(\"[Info] Using bitsandbytes 4-bit on CUDA.\")\n",
    "    else:\n",
    "        print(\"[Info] CUDA not available; skipping bitsandbytes quantization, falling back to MPS/CPU.\")\n",
    "except Exception as e:\n",
    "    print(f\"[Warn] bitsandbytes unavailable or not installed: {e}. Falling back to non-quantized loading.\")\n",
    "\n",
    "# Device mapping\n",
    "if use_cuda:\n",
    "    device_map = \"auto\"\n",
    "    dtype = torch.bfloat16\n",
    "elif use_mps:\n",
    "    device_map = {\"\": \"mps\"}\n",
    "    dtype = torch.float16\n",
    "else:\n",
    "    device_map = {\"\": \"cpu\"}\n",
    "    dtype = torch.float32\n",
    "\n",
    "# Load tokenizer / model (quantized only when available)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, trust_remote_code=True)\n",
    "\n",
    "load_kwargs = dict(\n",
    "    device_map=device_map,\n",
    "    torch_dtype=dtype,\n",
    "    trust_remote_code=True,\n",
    ")\n",
    "if quant_config is not None:\n",
    "    load_kwargs[\"quantization_config\"] = quant_config\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)\n",
    "\n",
    "print(f\"[Device] cuda={use_cuda}, mps={use_mps}, dtype={dtype}\")\n",
    "\n",
    "# Quick self-check\n",
    "inputs = tokenizer(\"你好,简要介绍一下你自己。\", return_tensors=\"pt\")\n",
    "if use_mps:\n",
    "    inputs = {k: v.to(\"mps\") for k, v in inputs.items()}\n",
    "else:\n",
    "    inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
    "\n",
    "with torch.inference_mode():\n",
    "    out = model.generate(**inputs, max_new_tokens=64, do_sample=False)\n",
    "print(tokenizer.decode(out[0], skip_special_tokens=True))\n",
    "\n"
   ]
  },
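  {
   "cell_type": "markdown",
   "id": "e1b7c9f2",
   "metadata": {},
   "source": [
    "A helper sketch built on the `model` and `tokenizer` loaded above: it computes the sequence-level log-probability `log π(y|x)` that the DPO loss consumes, by summing the token log-probs of the response given the prompt. For simplicity it concatenates raw strings (a real pipeline would apply the chat template) and assumes the prompt tokenization is a prefix of the full tokenization; for actual DPO training you would evaluate this once under the trainable policy and once under a frozen reference copy.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f2a8d0b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "def sequence_logprob(model, tokenizer, prompt, response):\n",
    "    # Sum of log π(token | prefix) over the response tokens only.\n",
    "    prompt_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(model.device)\n",
    "    full_ids = tokenizer(prompt + response, return_tensors=\"pt\").input_ids.to(model.device)\n",
    "    with torch.inference_mode():\n",
    "        logits = model(full_ids).logits  # [1, seq_len, vocab]\n",
    "    logprobs = F.log_softmax(logits[:, :-1, :], dim=-1)\n",
    "    targets = full_ids[:, 1:]\n",
    "    token_logprobs = logprobs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)\n",
    "    response_start = prompt_ids.shape[1] - 1  # position that predicts the first response token\n",
    "    return token_logprobs[:, response_start:].sum().item()\n",
    "\n",
    "print(sequence_logprob(model, tokenizer, \"The capital of France is\", \" Paris.\"))\n"
   ]
  },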
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7715e506",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "def _to_sft(example):\n",
    "    # Flatten a BELLE record into a simple prompt/response pair.\n",
    "    instr = example.get(\"instruction\", \"\")\n",
    "    inp = example.get(\"input\", \"\")\n",
    "    output = example.get(\"output\", None)\n",
    "    prompt = (instr + (\"\\n\" + inp if inp else \"\")).strip()\n",
    "    return {\"prompt\": prompt, \"response\": output}\n",
    "\n",
    "# SFT: load a small sample of the BELLE Chinese instruction data\n",
    "sft_ds = load_dataset(\"BelleGroup/train_0.5M_CN\", split=\"train[:2000]\")\n",
    "sft_ds = sft_ds.map(_to_sft, remove_columns=sft_ds.column_names)\n",
    "print(\"SFT sample:\", sft_ds[0])\n",
    "\n",
    "# Preference data: UltraFeedback (for DPO/RM)\n",
    "pref = load_dataset(\"argilla/ultrafeedback-binarized-preferences\", split=\"train[:5000]\")\n",
    "\n",
    "def _to_pref(ex):\n",
    "    # Normalize column names across preference-dataset variants.\n",
    "    prompt = ex.get(\"prompt\") or ex.get(\"question\") or ex.get(\"instruction\")\n",
    "    y_pos = ex.get(\"chosen\") or ex.get(\"better_response\")\n",
    "    y_neg = ex.get(\"rejected\") or ex.get(\"worse_response\")\n",
    "    return {\"prompt\": prompt, \"y_pos\": y_pos, \"y_neg\": y_neg}\n",
    "\n",
    "pref = pref.map(_to_pref)\n",
    "pref = pref.filter(lambda e: e[\"prompt\"] and e[\"y_pos\"] and e[\"y_neg\"])  # keep only complete samples\n",
    "print(\"Preference sample:\", {k: pref[0][k][:60] + \"...\" for k in [\"prompt\", \"y_pos\", \"y_neg\"]})\n",
    "\n"
   ]
  },
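  {
   "cell_type": "markdown",
   "id": "a9c5e7d4",
   "metadata": {},
   "source": [
    "A quick sketch of the stage-1 SFT objective using the `model`, `tokenizer`, and `sft_ds` sample loaded above: format one prompt–response pair with the tokenizer's chat template (Qwen2.5 ships one), mask the prompt tokens, and compute the cross-entropy on the response tokens. This is a single-sample sanity check rather than a training loop; a full SFT run would batch examples and use a trainer such as TRL's `SFTTrainer`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8d6f0c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the conversation with and without the assistant turn; the length\n",
    "# difference tells us which token positions belong to the response.\n",
    "sample = sft_ds[0]\n",
    "messages = [\n",
    "    {\"role\": \"user\", \"content\": sample[\"prompt\"]},\n",
    "    {\"role\": \"assistant\", \"content\": sample[\"response\"]},\n",
    "]\n",
    "full_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)\n",
    "prompt_text = tokenizer.apply_chat_template(messages[:1], tokenize=False, add_generation_prompt=True)\n",
    "\n",
    "full_ids = tokenizer(full_text, return_tensors=\"pt\").input_ids.to(model.device)\n",
    "prompt_len = tokenizer(prompt_text, return_tensors=\"pt\").input_ids.shape[1]\n",
    "\n",
    "# SFT loss = token-level cross-entropy on the response; prompt tokens are masked\n",
    "# with -100 so they do not contribute (assumes the prompt rendering is a token-level\n",
    "# prefix of the full rendering, which standard chat templates satisfy).\n",
    "labels = full_ids.clone()\n",
    "labels[:, :prompt_len] = -100\n",
    "\n",
    "with torch.inference_mode():\n",
    "    out = model(input_ids=full_ids, labels=labels)\n",
    "print(\"SFT cross-entropy on one sample:\", out.loss.item())\n"
   ]
  },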
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5f63aa6f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
