|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "### EM" |
| 8 | + ] |
| 9 | + }, |
| 10 | + { |
| 11 | + "cell_type": "code", |
| 12 | + "execution_count": 1, |
| 13 | + "metadata": {}, |
| 14 | + "outputs": [], |
| 15 | + "source": [ |
| 16 | + "# 导入numpy库 \n", |
| 17 | + "import numpy as np\n", |
| 18 | + "\n", |
| 19 | + "### EM算法过程函数定义\n", |
def em(data, thetas, max_iter=30, eps=1e-3, verbose=True):
    '''
    Fit per-component outcome probabilities with the EM algorithm.

    Every row of `data` is treated as having been generated by exactly one of
    the components described by the rows of `thetas`; which one is unknown.
    No mixing weights are estimated: the responsibilities are obtained by
    normalising the per-component likelihoods only.

    Args:
        data: array-like of shape (n_observations, n_outcomes) holding the
            observed count of each outcome per observation.
        thetas: array-like of shape (n_components, n_outcomes) holding the
            initial outcome probabilities of each component. Not modified.
        max_iter: maximum number of EM iterations.
        eps: convergence threshold; iteration stops once the objective
            changes by less than `eps` between two consecutive iterations.
        verbose: when True (the default), print the first-outcome probability
            of the first two components and the objective at every iteration.

    Returns:
        numpy.ndarray of shape (n_components, n_outcomes) with the estimated
        probabilities; each row sums to 1.

    Raises:
        ValueError: if `data` is not a non-empty 2-D array.
    '''
    data = np.asarray(data, dtype=float)
    thetas = np.asarray(thetas, dtype=float)
    if data.ndim != 2 or data.shape[0] == 0:
        raise ValueError("data must be a non-empty 2-D array of outcome counts")

    # Lower bound used only inside log(): a probability of exactly 0 would
    # otherwise give log(0) = -inf, and 0 * -inf = nan for a zero count.
    tiny = np.finfo(float).tiny

    # Objective value of the previous iteration (np.inf, not the np.infty
    # alias, which no longer exists in NumPy 2.0).
    ll_old = -np.inf
    for i in range(max_iter):
        ### E-step: posterior distribution of the hidden component
        # Log-likelihood of every observation under every component,
        # shape (n_components, n_observations).
        log_like = np.log(np.clip(thetas, tiny, None)) @ data.T
        # Normalise in log space: subtracting the per-observation maximum
        # before exponentiating leaves the ratios unchanged but prevents all
        # likelihoods from underflowing to 0 when the counts are large.
        shifted = np.exp(log_like - log_like.max(axis=0, keepdims=True))
        ws = shifted / shifted.sum(axis=0, keepdims=True)

        ### M-step: re-estimate the parameters
        # Responsibility-weighted outcome counts per component,
        # shape (n_components, n_outcomes), normalised row-wise.
        weighted = ws @ data
        thetas = weighted / weighted.sum(axis=1, keepdims=True)

        # Convergence objective: responsibility-weighted sum of the
        # log-likelihoods computed with the parameters from before this update.
        ll_new = np.sum(ws * log_like)
        if verbose:
            print("Iteration: %d" % (i+1))
            print("theta_B = %.2f, theta_C = %.2f, ll = %.2f"
                  % (thetas[0,0], thetas[1,0], ll_new))
        # Stop as soon as the objective has stabilised
        if np.abs(ll_new - ll_old) < eps:
            break
        ll_old = ll_new
    return thetas
| 54 | + ] |
| 55 | + }, |
| 56 | + { |
| 57 | + "cell_type": "code", |
| 58 | + "execution_count": 2, |
| 59 | + "metadata": {}, |
| 60 | + "outputs": [ |
| 61 | + { |
| 62 | + "name": "stdout", |
| 63 | + "output_type": "stream", |
| 64 | + "text": [ |
| 65 | + "Iteration: 1\n", |
| 66 | + "theta_B = 0.71, theta_C = 0.58, ll = -32.69\n", |
| 67 | + "Iteration: 2\n", |
| 68 | + "theta_B = 0.75, theta_C = 0.57, ll = -31.26\n", |
| 69 | + "Iteration: 3\n", |
| 70 | + "theta_B = 0.77, theta_C = 0.55, ll = -30.76\n", |
| 71 | + "Iteration: 4\n", |
| 72 | + "theta_B = 0.78, theta_C = 0.53, ll = -30.33\n", |
| 73 | + "Iteration: 5\n", |
| 74 | + "theta_B = 0.79, theta_C = 0.53, ll = -30.07\n", |
| 75 | + "Iteration: 6\n", |
| 76 | + "theta_B = 0.79, theta_C = 0.52, ll = -29.95\n", |
| 77 | + "Iteration: 7\n", |
| 78 | + "theta_B = 0.80, theta_C = 0.52, ll = -29.90\n", |
| 79 | + "Iteration: 8\n", |
| 80 | + "theta_B = 0.80, theta_C = 0.52, ll = -29.88\n", |
| 81 | + "Iteration: 9\n", |
| 82 | + "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n", |
| 83 | + "Iteration: 10\n", |
| 84 | + "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n", |
| 85 | + "Iteration: 11\n", |
| 86 | + "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n", |
| 87 | + "Iteration: 12\n", |
| 88 | + "theta_B = 0.80, theta_C = 0.52, ll = -29.87\n" |
| 89 | + ] |
| 90 | + } |
| 91 | + ], |
| 92 | + "source": [ |
# Observed data: 5 independent trials of 10 coin tosses each, recorded as
# (heads, tails) counts -- e.g. the first trial produced 5 heads and 5 tails.
observed_data = np.array([(5, 5), (9, 1), (8, 2), (4, 6), (7, 3)])
# Initial guess, one row per coin as [P(heads), P(tails)]:
# coin B lands heads with probability 0.6, coin C with probability 0.5.
initial_thetas = np.array([[0.6, 0.4], [0.5, 0.5]])
# The starting point keeps its own name so that re-running this cell always
# restarts EM from the same initial guess instead of from a previous result.
thetas = em(observed_data, initial_thetas, max_iter=30, eps=1e-3)
| 99 | + ] |
| 100 | + }, |
| 101 | + { |
| 102 | + "cell_type": "code", |
| 103 | + "execution_count": 3, |
| 104 | + "metadata": {}, |
| 105 | + "outputs": [ |
| 106 | + { |
| 107 | + "data": { |
| 108 | + "text/plain": [ |
| 109 | + "array([[0.7967829 , 0.2032171 ],\n", |
| 110 | + " [0.51959543, 0.48040457]])" |
| 111 | + ] |
| 112 | + }, |
| 113 | + "execution_count": 3, |
| 114 | + "metadata": {}, |
| 115 | + "output_type": "execute_result" |
| 116 | + } |
| 117 | + ], |
| 118 | + "source": [ |
# Display the parameter estimates currently bound to `thetas`.
thetas
| 120 | + ] |
| 121 | + }, |
| 122 | + { |
| 123 | + "cell_type": "code", |
| 124 | + "execution_count": null, |
| 125 | + "metadata": {}, |
| 126 | + "outputs": [], |
| 127 | + "source": [] |
| 128 | + } |
| 129 | + ], |
| 130 | + "metadata": { |
| 131 | + "kernelspec": { |
| 132 | + "display_name": "Python 3", |
| 133 | + "language": "python", |
| 134 | + "name": "python3" |
| 135 | + }, |
| 136 | + "language_info": { |
| 137 | + "codemirror_mode": { |
| 138 | + "name": "ipython", |
| 139 | + "version": 3 |
| 140 | + }, |
| 141 | + "file_extension": ".py", |
| 142 | + "mimetype": "text/x-python", |
| 143 | + "name": "python", |
| 144 | + "nbconvert_exporter": "python", |
| 145 | + "pygments_lexer": "ipython3", |
| 146 | + "version": "3.7.3" |
| 147 | + }, |
| 148 | + "toc": { |
| 149 | + "base_numbering": 1, |
| 150 | + "nav_menu": {}, |
| 151 | + "number_sections": true, |
| 152 | + "sideBar": true, |
| 153 | + "skip_h1_title": false, |
| 154 | + "title_cell": "Table of Contents", |
| 155 | + "title_sidebar": "Contents", |
| 156 | + "toc_cell": false, |
| 157 | + "toc_position": {}, |
| 158 | + "toc_section_display": true, |
| 159 | + "toc_window_display": false |
| 160 | + } |
| 161 | + }, |
| 162 | + "nbformat": 4, |
| 163 | + "nbformat_minor": 4 |
| 164 | +} |
0 commit comments