|
162 | 162 | }, |
163 | 163 | { |
164 | 164 | "cell_type": "code", |
165 | | - "execution_count": 22, |
| 165 | + "execution_count": 44, |
166 | 166 | "metadata": {}, |
167 | 167 | "outputs": [ |
168 | 168 | { |
169 | 169 | "name": "stdout", |
170 | 170 | "output_type": "stream", |
171 | 171 | "text": [ |
172 | | - "[1,1] \t:\t0.7128593117885544\n", |
173 | | - "[1,2] \t:\t0.7680398391451688\n", |
174 | | - "[1,3] \t:\t0.8178806550835265\n", |
175 | | - "[2,1] \t:\t0.6628583416987663\n", |
176 | | - "[2,3] \t:\t0.8746799974574001\n", |
| 172 | + "Cell \t:\tExpected Utility\n", |
| 173 | + "-------------------------------------\n", |
| 174 | + "[1,1] \t:\t0.7153193259024824\n", |
| 175 | + "[1,2] \t:\t0.7707398421463386\n", |
| 176 | + "[1,3] \t:\t0.8203828048081079\n", |
| 177 | + "[2,1] \t:\t0.6670047267920397\n", |
| 178 | + "[2,3] \t:\t0.8762199960076309\n", |
177 | 179 | "[3,1] \t:\tnull\n", |
178 | | - "[3,2] \t:\t0.6938189410949245\n", |
179 | | - "[3,3] \t:\t0.9241799994408929\n", |
| 180 | + "[3,2] \t:\t0.7344940650463005\n", |
| 181 | + "[3,3] \t:\t0.9266999990710265\n", |
180 | 182 | "[4,1] \t:\tnull\n", |
181 | 183 | "[4,2] \t:\t-1.0\n", |
182 | 184 | "[4,3] \t:\t1.0\n" |
|
188 | 190 | "null" |
189 | 191 | ] |
190 | 192 | }, |
191 | | - "execution_count": 22, |
| 193 | + "execution_count": 44, |
192 | 194 | "metadata": {}, |
193 | 195 | "output_type": "execute_result" |
194 | 196 | } |
|
228 | 230 | "cwe.addAgent(padpa);\n", |
229 | 231 | "padpa.reset();\n", |
230 | 232 | "cwe.executeTrials(2000);\n", |
231 | | - "\n", |
| 233 | + "System.out.println(\"Cell\" + \" \\t:\\t\" + \"Expected Utility\");\n", |
| 234 | + "System.out.println(\"-------------------------------------\");\n", |
232 | 235 | "Map<Cell<Double>, Double> U = padpa.getUtility();\n", |
233 | 236 | "for(int i = 1; i<=4; i++){\n", |
234 | 237 | " for(int j = 1; j<=3; j++){\n", |
|
254 | 257 | }, |
255 | 258 | { |
256 | 259 | "cell_type": "code", |
257 | | - "execution_count": 25, |
| 260 | + "execution_count": 42, |
258 | 261 | "metadata": {}, |
259 | | - "outputs": [ |
260 | | - { |
261 | | - "data": { |
262 | | - "text/plain": [ |
263 | | - "null" |
264 | | - ] |
265 | | - }, |
266 | | - "execution_count": 25, |
267 | | - "metadata": {}, |
268 | | - "output_type": "execute_result" |
269 | | - } |
270 | | - ], |
| 262 | + "outputs": [], |
271 | 263 | "source": [ |
272 | 264 | "import aima.core.environment.cellworld.*;\n", |
273 | 265 | "import aima.core.learning.reinforcement.agent.PassiveADPAgent;\n", |
|
318 | 310 | " }\n", |
319 | 311 | " }\n", |
320 | 312 | " runs.put(r, trials);\n", |
321 | | - "}" |
| 313 | + "}\n", |
| 314 | + "\n", |
| 315 | + "def T = [];\n", |
| 316 | + "def v4_3 = [];\n", |
| 317 | + "def v3_3 = [];\n", |
| 318 | + "def v1_3 = [];\n", |
| 319 | + "def v1_1 = [];\n", |
| 320 | + "def v3_2 = [];\n", |
| 321 | + "def v2_1 = [];\n", |
| 322 | + "double tmp = 0.0;\n", |
| 323 | + "for (int t = 0; t < (numTrialsPerRun / reportEveryN); t++) {\n", |
| 324 | + " T.add(t);\n", |
| 325 | + " Map<Cell<Double>, Double> u = runs.get(numRuns - 1).get(t);\n", |
| 326 | + " tmp = (u.containsKey(cw.getCellAt(4, 3)) ? u.get(cw.getCellAt(4, 3)) : 0.0);\n", |
| 327 | + " v4_3.add(tmp);\n", |
| 328 | + " tmp = (u.containsKey(cw.getCellAt(3, 3)) ? u.get(cw.getCellAt(3, 3)) : 0.0);\n", |
| 329 | + " v3_3.add(tmp);\n", |
| 330 | + " tmp = (u.containsKey(cw.getCellAt(1, 3)) ? u.get(cw.getCellAt(1, 3)) : 0.0);\n", |
| 331 | + " v1_3.add(tmp);\n", |
| 332 | + " tmp = (u.containsKey(cw.getCellAt(1, 1)) ? u.get(cw.getCellAt(1, 1)) : 0.0);\n", |
| 333 | + " v1_1.add(tmp);\n", |
| 334 | + " tmp = (u.containsKey(cw.getCellAt(3, 2)) ? u.get(cw.getCellAt(3, 2)) : 0.0);\n", |
| 335 | + " v3_2.add(tmp);\n", |
| 336 | + " tmp = (u.containsKey(cw.getCellAt(2, 1)) ? u.get(cw.getCellAt(2, 1)) : 0.0);\n", |
| 337 | + " v2_1.add(tmp);\n", |
| 338 | + "}\n", |
| 339 | + "\n", |
| 340 | + "def p1 = new Plot(title: \"Learning Curve\", yLabel: \"Utility estimates\", xLabel: \"Number of trails\");\n", |
| 341 | + "p1 << new Line(x: T, y: v4_3, displayName: \"v4_3\")\n", |
| 342 | + "p1 << new Line(x: T, y: v3_3, displayName: \"v3_3\")\n", |
| 343 | + "p1 << new Line(x: T, y: v1_3, displayName: \"v1_3\")\n", |
| 344 | + "p1 << new Line(x: T, y: v1_1, displayName: \"v1_1\")\n", |
| 345 | + "p1 << new Line(x: T, y: v3_2, displayName: \"v3_2\")\n", |
| 346 | + "p1 << new Line(x: T, y: v2_1, displayName: \"v2_1\")\n", |
| 347 | + "\n", |
| 348 | + "def trails = [];\n", |
| 349 | + "def rmseValues = [];\n", |
| 350 | + "for (int t = 0; t < rmseTrialsToReport; t++) {\n", |
| 351 | + " trails.add(t);\n", |
| 352 | + " double xSsquared = 0;\n", |
| 353 | + " for (int r = 0; r < numRuns; r++) {\n", |
| 354 | + " Map<Cell<Double>, Double> u = runs.get(r).get(t);\n", |
| 355 | + " Double val1_1 = u.get(cw.getCellAt(1, 1));\n", |
| 356 | + " xSsquared += Math.pow(0.705 - val1_1, 2);\n", |
| 357 | + " }\n", |
| 358 | + " double rmse = Math.sqrt(xSsquared / runs.size());\n", |
| 359 | + " rmseValues.add(rmse);\n", |
| 360 | + "}\n", |
| 361 | + "def p2 = new Plot(yLabel: \"RMS error in utility\", xLabel: \"Number of trails\");\n", |
| 362 | + "p2 << new Line(x: trails, y: rmseValues)\n", |
| 363 | + "OutputCell.HIDDEN" |
| 364 | + ] |
| 365 | + }, |
| 366 | + { |
| 367 | + "cell_type": "markdown", |
| 368 | + "metadata": {}, |
| 369 | + "source": [ |
| 370 | + "[![Utility estimates][1]][1]\n", |
| 371 | + "\n", |
| 372 | + "[1]: assets/reinforcement_learning/utility_estimates.png" |
| 373 | + ] |
| 374 | + }, |
| 375 | + { |
| 376 | + "cell_type": "markdown", |
| 377 | + "metadata": {}, |
| 378 | + "source": [ |
| 379 | + "[![RMS error in utility][1]][1]\n", |
| 380 | + "\n", |
| 381 | + "[1]: assets/reinforcement_learning/RMSerror.png" |
| 382 | + ] |
| 383 | + }, |
| 384 | + { |
| 385 | + "cell_type": "markdown", |
| 386 | + "metadata": {}, |
| 387 | + "source": [ |
| 388 | + "* The first figure shows the utility estimates for some of the states as a function of the number of trails. Notice the large changes occuring around the 63rd trial - this is the first time that the agent falls into the -1 terminal state at $[4,2]$. \n", |
| 389 | + "* The second plot shoes the root-mean-square error in the estimate for $U(1,1)$, averaged over 20 runs of 100 trails each." |
322 | 390 | ] |
323 | 391 | }, |
324 | 392 | { |
|
0 commit comments