
Commit ad42cb2

Add learning curves for passive-adp agent
1 parent 9625052 commit ad42cb2

File tree

3 files changed: +92 −24 lines changed


notebooks/ReinforcementLearning.ipynb

Lines changed: 92 additions & 24 deletions
@@ -162,21 +162,23 @@
 },
 {
 "cell_type": "code",
- "execution_count": 22,
+ "execution_count": 44,
 "metadata": {},
 "outputs": [
 {
 "name": "stdout",
 "output_type": "stream",
 "text": [
- "[1,1] \t:\t0.7128593117885544\n",
- "[1,2] \t:\t0.7680398391451688\n",
- "[1,3] \t:\t0.8178806550835265\n",
- "[2,1] \t:\t0.6628583416987663\n",
- "[2,3] \t:\t0.8746799974574001\n",
+ "Cell \t:\tExpected Utility\n",
+ "-------------------------------------\n",
+ "[1,1] \t:\t0.7153193259024824\n",
+ "[1,2] \t:\t0.7707398421463386\n",
+ "[1,3] \t:\t0.8203828048081079\n",
+ "[2,1] \t:\t0.6670047267920397\n",
+ "[2,3] \t:\t0.8762199960076309\n",
 "[3,1] \t:\tnull\n",
- "[3,2] \t:\t0.6938189410949245\n",
- "[3,3] \t:\t0.9241799994408929\n",
+ "[3,2] \t:\t0.7344940650463005\n",
+ "[3,3] \t:\t0.9266999990710265\n",
 "[4,1] \t:\tnull\n",
 "[4,2] \t:\t-1.0\n",
 "[4,3] \t:\t1.0\n"
@@ -188,7 +190,7 @@
 "null"
 ]
 },
- "execution_count": 22,
+ "execution_count": 44,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -228,7 +230,8 @@
 "cwe.addAgent(padpa);\n",
 "padpa.reset();\n",
 "cwe.executeTrials(2000);\n",
- "\n",
+ "System.out.println(\"Cell\" + \" \\t:\\t\" + \"Expected Utility\");\n",
+ "System.out.println(\"-------------------------------------\");\n",
 "Map<Cell<Double>, Double> U = padpa.getUtility();\n",
 "for(int i = 1; i<=4; i++){\n",
 " for(int j = 1; j<=3; j++){\n",
@@ -254,20 +257,9 @@
 },
 {
 "cell_type": "code",
- "execution_count": 25,
+ "execution_count": 42,
 "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "null"
- ]
- },
- "execution_count": 25,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
 "source": [
 "import aima.core.environment.cellworld.*;\n",
 "import aima.core.learning.reinforcement.agent.PassiveADPAgent;\n",
@@ -318,7 +310,83 @@
 " }\n",
 " }\n",
 " runs.put(r, trials);\n",
- "}"
+ "}\n",
+ "\n",
+ "def T = [];\n",
+ "def v4_3 = [];\n",
+ "def v3_3 = [];\n",
+ "def v1_3 = [];\n",
+ "def v1_1 = [];\n",
+ "def v3_2 = [];\n",
+ "def v2_1 = [];\n",
+ "double tmp = 0.0;\n",
+ "for (int t = 0; t < (numTrialsPerRun / reportEveryN); t++) {\n",
+ " T.add(t);\n",
+ " Map<Cell<Double>, Double> u = runs.get(numRuns - 1).get(t);\n",
+ " tmp = (u.containsKey(cw.getCellAt(4, 3)) ? u.get(cw.getCellAt(4, 3)) : 0.0);\n",
+ " v4_3.add(tmp);\n",
+ " tmp = (u.containsKey(cw.getCellAt(3, 3)) ? u.get(cw.getCellAt(3, 3)) : 0.0);\n",
+ " v3_3.add(tmp);\n",
+ " tmp = (u.containsKey(cw.getCellAt(1, 3)) ? u.get(cw.getCellAt(1, 3)) : 0.0);\n",
+ " v1_3.add(tmp);\n",
+ " tmp = (u.containsKey(cw.getCellAt(1, 1)) ? u.get(cw.getCellAt(1, 1)) : 0.0);\n",
+ " v1_1.add(tmp);\n",
+ " tmp = (u.containsKey(cw.getCellAt(3, 2)) ? u.get(cw.getCellAt(3, 2)) : 0.0);\n",
+ " v3_2.add(tmp);\n",
+ " tmp = (u.containsKey(cw.getCellAt(2, 1)) ? u.get(cw.getCellAt(2, 1)) : 0.0);\n",
+ " v2_1.add(tmp);\n",
+ "}\n",
+ "\n",
+ "def p1 = new Plot(title: \"Learning Curve\", yLabel: \"Utility estimates\", xLabel: \"Number of trials\");\n",
+ "p1 << new Line(x: T, y: v4_3, displayName: \"v4_3\")\n",
+ "p1 << new Line(x: T, y: v3_3, displayName: \"v3_3\")\n",
+ "p1 << new Line(x: T, y: v1_3, displayName: \"v1_3\")\n",
+ "p1 << new Line(x: T, y: v1_1, displayName: \"v1_1\")\n",
+ "p1 << new Line(x: T, y: v3_2, displayName: \"v3_2\")\n",
+ "p1 << new Line(x: T, y: v2_1, displayName: \"v2_1\")\n",
+ "\n",
+ "def trails = [];\n",
+ "def rmseValues = [];\n",
+ "for (int t = 0; t < rmseTrialsToReport; t++) {\n",
+ " trails.add(t);\n",
+ " double xSsquared = 0;\n",
+ " for (int r = 0; r < numRuns; r++) {\n",
+ " Map<Cell<Double>, Double> u = runs.get(r).get(t);\n",
+ " Double val1_1 = u.get(cw.getCellAt(1, 1));\n",
+ " xSsquared += Math.pow(0.705 - val1_1, 2);\n",
+ " }\n",
+ " double rmse = Math.sqrt(xSsquared / runs.size());\n",
+ " rmseValues.add(rmse);\n",
+ "}\n",
+ "def p2 = new Plot(yLabel: \"RMS error in utility\", xLabel: \"Number of trials\");\n",
+ "p2 << new Line(x: trails, y: rmseValues)\n",
+ "OutputCell.HIDDEN"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[![Utility estimates][1]][1]\n",
+ "\n",
+ "[1]: assets/reinforcement_learning/utility_estimates.png"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[![RMS error in utility][1]][1]\n",
+ "\n",
+ "[1]: assets/reinforcement_learning/RMSerror.png"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "* The first figure shows the utility estimates for some of the states as a function of the number of trials. Notice the large changes occurring around the 63rd trial - this is the first time that the agent falls into the -1 terminal state at $[4,2]$. \n",
+ "* The second plot shows the root-mean-square error in the estimate for $U(1,1)$, averaged over 20 runs of 100 trials each."
+ ]
+ },
 ]
 },
 {
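As background for the two plots above: the passive ADP agent estimates utilities by solving the policy-evaluation equations of its learned model for the fixed policy $\pi$ (AIMA, Section 21.2),

$$U^{\pi}(s) = R(s) + \gamma \sum_{s'} P(s' \mid s, \pi(s)) \, U^{\pi}(s'),$$

so each learning curve is a sequence of snapshots of that estimate for one state. The RMS-error loop in the added cell evaluates, at each checkpoint $t$,

$$\mathrm{RMSE}(t) = \sqrt{\frac{1}{N} \sum_{r=1}^{N} \left( \hat{U}^{(t)}_{r}(1,1) - 0.705 \right)^{2}},$$

where $N$ is the number of runs (20) and $0.705$ is the exact utility of $[1,1]$ in the $4 \times 3$ world, the reference constant used in the loop.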
assets/reinforcement_learning/RMSerror.png and assets/reinforcement_learning/utility_estimates.png (binary image files, 37 KB and 56.4 KB)
