Commit 926e89a

Add Q learning agent test
1 parent d69b005 commit 926e89a

File tree

3 files changed: +189 -0 lines changed

notebooks/ReinforcementLearning.ipynb

Lines changed: 189 additions & 0 deletions
@@ -786,6 +786,195 @@
     "pseudocode('Q Learning Agent')"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's test the Q-learning agent on the $4\\times 3$ cell world discussed above."
+   ]
+  },
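+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For reference, this is the standard tabular Q-learning update from the pseudocode above: after executing action $a$ in state $s$, receiving reward $r$ and reaching state $s'$, the agent performs\n",
+    "\n",
+    "$$Q[s,a] \\leftarrow Q[s,a] + \\alpha(N_{sa}[s,a])\\big(r + \\gamma \\max_{a'} Q[s',a'] - Q[s,a]\\big)$$\n",
+    "\n",
+    "where $\\alpha$ is the learning rate (in general a function of the visit count $N_{sa}[s,a]$) and $\\gamma$ is the discount factor; they correspond to the 0.2 and 1.0 passed to the constructor below."
+   ]
+  },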
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Cell \t:\tExpected Utility\n",
+      "----------------------------------\n",
+      "[1,1] \t:\t0.7447970297243381\n",
+      "[1,2] \t:\t0.7824342221178695\n",
+      "[1,3] \t:\t0.8183443976230272\n",
+      "[2,1] \t:\t0.6603248055209805\n",
+      "[2,3] \t:\t0.881618254644549\n",
+      "[3,1] \t:\t0.5780849497361085\n",
+      "[3,2] \t:\t0.41519898633959246\n",
+      "[3,3] \t:\t0.9503633060769898\n",
+      "[4,1] \t:\t-0.048891852584282955\n",
+      "[4,2] \t:\t-1.0\n",
+      "[4,3] \t:\t1.0\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "null"
+      ]
+     },
+     "execution_count": 11,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import aima.core.environment.cellworld.*;\n",
+    "import aima.core.learning.reinforcement.agent.QLearningAgent;\n",
+    "import aima.core.learning.reinforcement.example.CellWorldEnvironment;\n",
+    "import aima.core.probability.example.MDPFactory;\n",
+    "import aima.core.util.JavaRandomizer;\n",
+    "\n",
+    "import java.util.*;\n",
+    "\n",
+    "CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();\n",
+    "CellWorldEnvironment cwe = new CellWorldEnvironment(\n",
+    "        cw.getCellAt(1, 1), // start state\n",
+    "        cw.getCells(),\n",
+    "        MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),\n",
+    "        new JavaRandomizer());\n",
+    "// Arguments: actions function, the 'None' action, alpha (learning rate),\n",
+    "// gamma (discount factor), and the exploration parameters Ne and R+\n",
+    "QLearningAgent<Cell<Double>, CellWorldAction> qla = new QLearningAgent<Cell<Double>, CellWorldAction>(\n",
+    "        MDPFactory.createActionsFunctionForFigure17_1(cw),\n",
+    "        CellWorldAction.None, 0.2, 1.0, 5, 2.0);\n",
+    "cwe.addAgent(qla);\n",
+    "qla.reset();\n",
+    "cwe.executeTrials(100000);\n",
+    "System.out.println(\"Cell\" + \" \\t:\\t\" + \"Expected Utility\");\n",
+    "System.out.println(\"----------------------------------\");\n",
+    "Map<Cell<Double>, Double> U = qla.getUtility();\n",
+    "for (int i = 1; i <= 4; i++) {\n",
+    "    for (int j = 1; j <= 3; j++) {\n",
+    "        if (i == 2 && j == 2) continue; // ignore the wall at (2,2)\n",
+    "        System.out.println(\"[\" + i + \",\" + j + \"]\" + \" \\t:\\t\" + U.get(cw.getCellAt(i, j)));\n",
+    "    }\n",
+    "}"
+   ]
+  },
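+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The last two constructor arguments ($N_e = 5$ and $R^+ = 2.0$) parameterize the optimistic exploration function $f$ used in the Q-Learning-Agent pseudocode above,\n",
+    "\n",
+    "$$f(u, n) = \\begin{cases} R^+ & \\text{if } n < N_e \\\\ u & \\text{otherwise} \\end{cases}$$\n",
+    "\n",
+    "which drives the agent to try each action at least $N_e$ times in each state before relying on its learned Q-values."
+   ]
+  },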
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The learning curves of the Q-learning agent for the $4\\times 3$ cell world are shown below."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import aima.core.environment.cellworld.*;\n",
+    "import aima.core.learning.reinforcement.agent.QLearningAgent;\n",
+    "import aima.core.learning.reinforcement.example.CellWorldEnvironment;\n",
+    "import aima.core.probability.example.MDPFactory;\n",
+    "import aima.core.util.JavaRandomizer;\n",
+    "\n",
+    "import java.util.*;\n",
+    "\n",
+    "int numRuns = 20;\n",
+    "int numTrialsPerRun = 10000;\n",
+    "int rmseTrialsToReport = 500;\n",
+    "int reportEveryN = 20;\n",
+    "\n",
+    "CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();\n",
+    "CellWorldEnvironment cwe = new CellWorldEnvironment(\n",
+    "        cw.getCellAt(1, 1),\n",
+    "        cw.getCells(),\n",
+    "        MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),\n",
+    "        new JavaRandomizer());\n",
+    "QLearningAgent<Cell<Double>, CellWorldAction> qla = new QLearningAgent<Cell<Double>, CellWorldAction>(\n",
+    "        MDPFactory.createActionsFunctionForFigure17_1(cw),\n",
+    "        CellWorldAction.None, 0.2, 1.0, 5, 2.0);\n",
+    "cwe.addAgent(qla);\n",
+    "\n",
+    "// For each of numRuns independent runs, snapshot the utility estimates every reportEveryN trials\n",
+    "Map<Integer, List<Map<Cell<Double>, Double>>> runs = new HashMap<Integer, List<Map<Cell<Double>, Double>>>();\n",
+    "for (int r = 0; r < numRuns; r++) {\n",
+    "    qla.reset();\n",
+    "    List<Map<Cell<Double>, Double>> trials = new ArrayList<Map<Cell<Double>, Double>>();\n",
+    "    for (int t = 0; t < numTrialsPerRun; t++) {\n",
+    "        cwe.executeTrial();\n",
+    "        if (0 == t % reportEveryN) {\n",
+    "            Map<Cell<Double>, Double> u = qla.getUtility();\n",
+    "            trials.add(u);\n",
+    "        }\n",
+    "    }\n",
+    "    runs.put(r, trials);\n",
+    "}\n",
+    "\n",
+    "// Learning curves: utility estimates of selected cells over the last run\n",
+    "def T = [];\n",
+    "def v4_3 = [];\n",
+    "def v3_3 = [];\n",
+    "def v1_3 = [];\n",
+    "def v1_1 = [];\n",
+    "def v3_2 = [];\n",
+    "def v2_1 = [];\n",
+    "double tmp = 0.0;\n",
+    "for (int t = 0; t < (numTrialsPerRun/reportEveryN); t++) {\n",
+    "    T.add(t);\n",
+    "    Map<Cell<Double>, Double> u = runs.get(numRuns - 1).get(t);\n",
+    "    tmp = (u.containsKey(cw.getCellAt(4, 3)) ? u.get(cw.getCellAt(4, 3)) : 0.0);\n",
+    "    v4_3.add(tmp);\n",
+    "    tmp = (u.containsKey(cw.getCellAt(3, 3)) ? u.get(cw.getCellAt(3, 3)) : 0.0);\n",
+    "    v3_3.add(tmp);\n",
+    "    tmp = (u.containsKey(cw.getCellAt(1, 3)) ? u.get(cw.getCellAt(1, 3)) : 0.0);\n",
+    "    v1_3.add(tmp);\n",
+    "    tmp = (u.containsKey(cw.getCellAt(1, 1)) ? u.get(cw.getCellAt(1, 1)) : 0.0);\n",
+    "    v1_1.add(tmp);\n",
+    "    tmp = (u.containsKey(cw.getCellAt(3, 2)) ? u.get(cw.getCellAt(3, 2)) : 0.0);\n",
+    "    v3_2.add(tmp);\n",
+    "    tmp = (u.containsKey(cw.getCellAt(2, 1)) ? u.get(cw.getCellAt(2, 1)) : 0.0);\n",
+    "    v2_1.add(tmp);\n",
+    "}\n",
+    "\n",
+    "def p1 = new Plot(title: \"Learning Curve\", yLabel: \"Utility estimates\", xLabel: \"Number of trials\");\n",
+    "p1 << new Line(x: T, y: v4_3, displayName: \"v4_3\")\n",
+    "p1 << new Line(x: T, y: v3_3, displayName: \"v3_3\")\n",
+    "p1 << new Line(x: T, y: v1_3, displayName: \"v1_3\")\n",
+    "p1 << new Line(x: T, y: v1_1, displayName: \"v1_1\")\n",
+    "p1 << new Line(x: T, y: v3_2, displayName: \"v3_2\")\n",
+    "p1 << new Line(x: T, y: v2_1, displayName: \"v2_1\")\n",
+    "\n",
+    "// RMS error of the utility estimate for the start cell (1,1) across all runs;\n",
+    "// 0.705 is its true utility (AIMA3e Fig. 17.3)\n",
+    "def rmseTrials = [];\n",
+    "def rmseValues = [];\n",
+    "for (int t = 0; t < rmseTrialsToReport; t++) {\n",
+    "    rmseTrials.add(t);\n",
+    "    double sumSquaredError = 0;\n",
+    "    for (int r = 0; r < numRuns; r++) {\n",
+    "        Map<Cell<Double>, Double> u = runs.get(r).get(t);\n",
+    "        Double val1_1 = u.get(cw.getCellAt(1, 1)); // always present: (1,1) is the start cell\n",
+    "        sumSquaredError += Math.pow(0.705 - val1_1, 2);\n",
+    "    }\n",
+    "    double rmse = Math.sqrt(sumSquaredError/runs.size());\n",
+    "    rmseValues.add(rmse);\n",
+    "}\n",
+    "def p2 = new Plot(yLabel: \"RMS error in utility\", xLabel: \"Number of trials\");\n",
+    "p2 << new Line(x: rmseTrials, y: rmseValues)\n",
+    "OutputCell.HIDDEN"
+   ]
+  },
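+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The RMS-error curve plots, for each reporting point $t$, the root-mean-square error of the utility estimate of the start cell $(1,1)$ against its true utility of $0.705$, taken over the $R = 20$ runs:\n",
+    "\n",
+    "$$\\mathrm{RMSE}(t) = \\sqrt{\\frac{1}{R} \\sum_{r=1}^{R} \\big(0.705 - \\hat{U}_r^{(t)}(1,1)\\big)^2}$$"
+   ]
+  },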
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[![Utility estimates][1]][1]\n",
+    "\n",
+    "[1]: assets/reinforcement_learning/q_utility_estimates.png"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[![RMS error in utility][1]][1]\n",
+    "\n",
+    "[1]: assets/reinforcement_learning/q_RMSerror.png"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
Binary files added (52.3 KB and 133 KB): assets/reinforcement_learning/q_utility_estimates.png, assets/reinforcement_learning/q_RMSerror.png
