Skip to content

Commit 207d77f

Browse files
authored
Bayesian models:naive bayes and bayesian networks.
1 parent b316a90 commit 207d77f

File tree

2 files changed

+448
-0
lines changed

2 files changed

+448
-0
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"### bayesian network"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": 1,
13+
"metadata": {},
14+
"outputs": [],
15+
"source": [
16+
"# 导入pgmpy相关模块\n",
17+
"from pgmpy.factors.discrete import TabularCPD\n",
18+
"from pgmpy.models import BayesianModel\n",
19+
"letter_model = BayesianModel([('D', 'G'),\n",
20+
" ('I', 'G'),\n",
21+
" ('G', 'L'),\n",
22+
" ('I', 'S')])"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 4,
28+
"metadata": {},
29+
"outputs": [],
30+
"source": [
31+
"# 学生成绩的条件概率分布\n",
32+
"grade_cpd = TabularCPD(\n",
33+
" variable='G', # 节点名称\n",
34+
" variable_card=3, # 节点取值个数\n",
35+
" values=[[0.3, 0.05, 0.9, 0.5], # 该节点的概率表\n",
36+
" [0.4, 0.25, 0.08, 0.3],\n",
37+
" [0.3, 0.7, 0.02, 0.2]],\n",
38+
" evidence=['I', 'D'], # 该节点的依赖节点\n",
39+
" evidence_card=[2, 2] # 依赖节点的取值个数\n",
40+
")\n",
41+
"# 考试难度的条件概率分布\n",
42+
"difficulty_cpd = TabularCPD(\n",
43+
" variable='D',\n",
44+
" variable_card=2,\n",
45+
" values=[[0.6], [0.4]]\n",
46+
")\n",
47+
"# 个人天赋的条件概率分布\n",
48+
"intel_cpd = TabularCPD(\n",
49+
" variable='I',\n",
50+
" variable_card=2,\n",
51+
" values=[[0.7], [0.3]]\n",
52+
")\n",
53+
"# 推荐信质量的条件概率分布\n",
54+
"letter_cpd = TabularCPD(\n",
55+
" variable='L',\n",
56+
" variable_card=2,\n",
57+
" values=[[0.1, 0.4, 0.99],\n",
58+
" [0.9, 0.6, 0.01]],\n",
59+
" evidence=['G'],\n",
60+
" evidence_card=[3]\n",
61+
")\n",
62+
"# SAT考试分数的条件概率分布\n",
63+
"sat_cpd = TabularCPD(\n",
64+
" variable='S',\n",
65+
" variable_card=2,\n",
66+
" values=[[0.95, 0.2],\n",
67+
" [0.05, 0.8]],\n",
68+
" evidence=['I'],\n",
69+
" evidence_card=[2]\n",
70+
")"
71+
]
72+
},
73+
{
74+
"cell_type": "code",
75+
"execution_count": 7,
76+
"metadata": {},
77+
"outputs": [
78+
{
79+
"name": "stderr",
80+
"output_type": "stream",
81+
"text": [
82+
"WARNING:root:Replacing existing CPD for G\n",
83+
"WARNING:root:Replacing existing CPD for D\n",
84+
"WARNING:root:Replacing existing CPD for I\n",
85+
"WARNING:root:Replacing existing CPD for L\n",
86+
"WARNING:root:Replacing existing CPD for S\n",
87+
"Finding Elimination Order: : 100%|██████████████████████████████████████████████████████| 2/2 [00:00<00:00, 668.95it/s]\n",
88+
"Eliminating: L: 100%|███████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 285.72it/s]"
89+
]
90+
},
91+
{
92+
"name": "stdout",
93+
"output_type": "stream",
94+
"text": [
95+
"+------+----------+\n",
96+
"| G | phi(G) |\n",
97+
"+======+==========+\n",
98+
"| G(0) | 0.9000 |\n",
99+
"+------+----------+\n",
100+
"| G(1) | 0.0800 |\n",
101+
"+------+----------+\n",
102+
"| G(2) | 0.0200 |\n",
103+
"+------+----------+\n"
104+
]
105+
},
106+
{
107+
"name": "stderr",
108+
"output_type": "stream",
109+
"text": [
110+
"\n"
111+
]
112+
}
113+
],
114+
"source": [
115+
"# 将各节点添加到模型中,构建贝叶斯网络\n",
116+
"letter_model.add_cpds(\n",
117+
" grade_cpd, \n",
118+
" difficulty_cpd,\n",
119+
" intel_cpd,\n",
120+
" letter_cpd,\n",
121+
" sat_cpd\n",
122+
")\n",
123+
"# 导入pgmpy贝叶斯推断模块\n",
124+
"from pgmpy.inference import VariableElimination\n",
125+
"# 贝叶斯网络推断\n",
126+
"letter_infer = VariableElimination(letter_model)\n",
127+
"# 天赋较好且考试不难的情况下推断该学生获得推荐信质量的好坏\n",
128+
"prob_G = letter_infer.query(\n",
129+
" variables=['G'],\n",
130+
" evidence={'I': 1, 'D': 0})\n",
131+
"print(prob_G)"
132+
]
133+
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": null,
137+
"metadata": {},
138+
"outputs": [],
139+
"source": []
140+
}
141+
],
142+
"metadata": {
143+
"kernelspec": {
144+
"display_name": "Python 3",
145+
"language": "python",
146+
"name": "python3"
147+
},
148+
"language_info": {
149+
"codemirror_mode": {
150+
"name": "ipython",
151+
"version": 3
152+
},
153+
"file_extension": ".py",
154+
"mimetype": "text/x-python",
155+
"name": "python",
156+
"nbconvert_exporter": "python",
157+
"pygments_lexer": "ipython3",
158+
"version": "3.7.3"
159+
},
160+
"toc": {
161+
"base_numbering": 1,
162+
"nav_menu": {},
163+
"number_sections": true,
164+
"sideBar": true,
165+
"skip_h1_title": false,
166+
"title_cell": "Table of Contents",
167+
"title_sidebar": "Contents",
168+
"toc_cell": false,
169+
"toc_position": {},
170+
"toc_section_display": true,
171+
"toc_window_display": false
172+
}
173+
},
174+
"nbformat": 4,
175+
"nbformat_minor": 4
176+
}

0 commit comments

Comments
 (0)