|
786 | 786 | "pseudocode('Q Learning Agent')" |
787 | 787 | ] |
788 | 788 | }, |
| 789 | + { |
| 790 | + "cell_type": "markdown", |
| 791 | + "metadata": {}, |
| 792 | + "source": [ |
| 793 | + "Let's test the Q-learning agent on the 4\\*3 cell world discussed above." |
| 794 | + ] |
| 795 | + }, |
| 796 | + { |
| 797 | + "cell_type": "code", |
| 798 | + "execution_count": 11, |
| 799 | + "metadata": {}, |
| 800 | + "outputs": [ |
| 801 | + { |
| 802 | + "name": "stdout", |
| 803 | + "output_type": "stream", |
| 804 | + "text": [ |
| 805 | + "Cell \t:\tExpected Utility\n", |
| 806 | + "----------------------------------\n", |
| 807 | + "[1,1] \t:\t0.7447970297243381\n", |
| 808 | + "[1,2] \t:\t0.7824342221178695\n", |
| 809 | + "[1,3] \t:\t0.8183443976230272\n", |
| 810 | + "[2,1] \t:\t0.6603248055209805\n", |
| 811 | + "[2,3] \t:\t0.881618254644549\n", |
| 812 | + "[3,1] \t:\t0.5780849497361085\n", |
| 813 | + "[3,2] \t:\t0.41519898633959246\n", |
| 814 | + "[3,3] \t:\t0.9503633060769898\n", |
| 815 | + "[4,1] \t:\t-0.048891852584282955\n", |
| 816 | + "[4,2] \t:\t-1.0\n", |
| 817 | + "[4,3] \t:\t1.0\n" |
| 818 | + ] |
| 819 | + }, |
| 820 | + { |
| 821 | + "data": { |
| 822 | + "text/plain": [ |
| 823 | + "null" |
| 824 | + ] |
| 825 | + }, |
| 826 | + "execution_count": 11, |
| 827 | + "metadata": {}, |
| 828 | + "output_type": "execute_result" |
| 829 | + } |
| 830 | + ], |
| 831 | + "source": [ |
| 832 | + "import aima.core.environment.cellworld.*;\n", |
| 833 | + "import aima.core.learning.reinforcement.agent.QLearningAgent;\n", |
| 834 | + "import aima.core.learning.reinforcement.example.CellWorldEnvironment;\n", |
| 835 | + "import aima.core.probability.example.MDPFactory;\n", |
| 836 | + "import aima.core.util.JavaRandomizer;\n", |
| 837 | + "\n", |
| 838 | + "import java.util.*;;\n", |
| 839 | + "\n", |
| 840 | + "CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();;\n", |
| 841 | + "CellWorldEnvironment cwe = new CellWorldEnvironment(\n", |
| 842 | + " cw.getCellAt(1, 1),\n", |
| 843 | + " cw.getCells(),\n", |
| 844 | + " MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),\n", |
| 845 | + " new JavaRandomizer());\n", |
| 846 | + "QLearningAgent<Cell<Double>, CellWorldAction> qla = new QLearningAgent<Cell<Double>, CellWorldAction>(MDPFactory.createActionsFunctionForFigure17_1(cw), CellWorldAction.None, 0.2, 1.0, 5, 2.0);\n", |
| 847 | + "cwe.addAgent(qla);\n", |
| 848 | + "qla.reset();\n", |
| 849 | + "cwe.executeTrials(100000);\n", |
| 850 | + "System.out.println(\"Cell\" + \" \\t:\\t\" + \"Expected Utility\");\n", |
| 851 | + "System.out.println(\"----------------------------------\");\n", |
| 852 | + "Map<Cell<Double>, Double> U = qla.getUtility();\n", |
| 853 | + "for(int i = 1; i<=4; i++){\n", |
| 854 | + " for(int j = 1; j<=3; j++){\n", |
| 855 | + " if(i==2 && j==2) continue; //Ignore wall\n", |
| 856 | + " System.out.println(\"[\" + i + \",\" + j + \"]\" + \" \\t:\\t\" + U.get(cw.getCellAt(i,j)));\n", |
| 857 | + " }\n", |
| 858 | + "}" |
| 859 | + ] |
| 860 | + }, |
| 861 | + { |
| 862 | + "cell_type": "markdown", |
| 863 | + "metadata": {}, |
| 864 | + "source": [ |
| 865 | + "The learning curves of the Q-Learning agent for the $4∗3$ cell world are shown below." |
| 866 | + ] |
| 867 | + }, |
| 868 | + { |
| 869 | + "cell_type": "code", |
| 870 | + "execution_count": 10, |
| 871 | + "metadata": {}, |
| 872 | + "outputs": [], |
| 873 | + "source": [ |
| 874 | + "import aima.core.environment.cellworld.*;\n", |
| 875 | + "import aima.core.learning.reinforcement.agent.QLearningAgent;\n", |
| 876 | + "import aima.core.learning.reinforcement.example.CellWorldEnvironment;\n", |
| 877 | + "import aima.core.probability.example.MDPFactory;\n", |
| 878 | + "import aima.core.util.JavaRandomizer;\n", |
| 879 | + "\n", |
| 880 | + "import java.util.*;\n", |
| 881 | + "\n", |
| 882 | + "int numRuns = 20;\n", |
| 883 | + "int numTrialsPerRun = 10000;\n", |
| 884 | + "int rmseTrialsToReport = 500;\n", |
| 885 | + "int reportEveryN = 20;\n", |
| 886 | + "\n", |
| 887 | + "CellWorld<Double> cw = CellWorldFactory.createCellWorldForFig17_1();;\n", |
| 888 | + "CellWorldEnvironment cwe = new CellWorldEnvironment(\n", |
| 889 | + " cw.getCellAt(1, 1),\n", |
| 890 | + " cw.getCells(),\n", |
| 891 | + " MDPFactory.createTransitionProbabilityFunctionForFigure17_1(cw),\n", |
| 892 | + " new JavaRandomizer());\n", |
| 893 | + "QLearningAgent<Cell<Double>, CellWorldAction> qla = new QLearningAgent<Cell<Double>, CellWorldAction>(MDPFactory.createActionsFunctionForFigure17_1(cw), CellWorldAction.None, 0.2, 1.0, 5, 2.0);\n", |
| 894 | + "cwe.addAgent(qla);\n", |
| 895 | + "Map<Integer, List<Map<Cell<Double>, Double>>> runs = new HashMap<Integer, List<Map<Cell<Double>, Double>>>();\n", |
| 896 | + "for (int r = 0; r < numRuns; r++) {\n", |
| 897 | + " qla.reset();\n", |
| 898 | + " List<Map<Cell<Double>, Double>> trials = new ArrayList<Map<Cell<Double>, Double>>();\n", |
| 899 | + " for (int t = 0; t < numTrialsPerRun; t++) {\n", |
| 900 | + " cwe.executeTrial();\n", |
| 901 | + " if (0 == t % reportEveryN) {\n", |
| 902 | + " Map<Cell<Double>, Double> u = qla.getUtility();\n", |
| 903 | + " trials.add(u);\n", |
| 904 | + " }\n", |
| 905 | + " }\n", |
| 906 | + " runs.put(r, trials);\n", |
| 907 | + "}\n", |
| 908 | + "\n", |
| 909 | + "def T = [];\n", |
| 910 | + "def v4_3 = [];\n", |
| 911 | + "def v3_3 = [];\n", |
| 912 | + "def v1_3 = [];\n", |
| 913 | + "def v1_1 = [];\n", |
| 914 | + "def v3_2 = [];\n", |
| 915 | + "def v2_1 = [];\n", |
| 916 | + "double tmp = 0.0;\n", |
| 917 | + "for (int t = 0; t < (numTrialsPerRun/reportEveryN); t++) {\n", |
| 918 | + " T.add(t);\n", |
| 919 | + " Map<Cell<Double>, Double> u = runs.get(numRuns - 1).get(t);\n", |
| 920 | + " tmp = (u.containsKey(cw.getCellAt(4, 3)) ? u.get(cw.getCellAt(4, 3)) : 0.0);\n", |
| 921 | + " v4_3.add(tmp);\n", |
| 922 | + " tmp = (u.containsKey(cw.getCellAt(3, 3)) ? u.get(cw.getCellAt(3, 3)) : 0.0);\n", |
| 923 | + " v3_3.add(tmp);\n", |
| 924 | + " tmp = (u.containsKey(cw.getCellAt(1, 3)) ? u.get(cw.getCellAt(1, 3)) : 0.0);\n", |
| 925 | + " v1_3.add(tmp);\n", |
| 926 | + " tmp = (u.containsKey(cw.getCellAt(1, 1)) ? u.get(cw.getCellAt(1, 1)) : 0.0);\n", |
| 927 | + " v1_1.add(tmp);\n", |
| 928 | + " tmp = (u.containsKey(cw.getCellAt(3, 2)) ? u.get(cw.getCellAt(3, 2)) : 0.0);\n", |
| 929 | + " v3_2.add(tmp);\n", |
| 930 | + " tmp = (u.containsKey(cw.getCellAt(2, 1)) ? u.get(cw.getCellAt(2, 1)) : 0.0);\n", |
| 931 | + " v2_1.add(tmp);\n", |
| 932 | + "}\n", |
| 933 | + "\n", |
| 934 | + "def p1 = new Plot(title: \"Learning Curve\", yLabel: \"Utility estimates\", xLabel: \"Number of trails\");\n", |
| 935 | + "p1 << new Line(x: T, y: v4_3, displayName: \"v4_3\")\n", |
| 936 | + "p1 << new Line(x: T, y: v3_3, displayName: \"v3_3\")\n", |
| 937 | + "p1 << new Line(x: T, y: v1_3, displayName: \"v1_3\")\n", |
| 938 | + "p1 << new Line(x: T, y: v1_1, displayName: \"v1_1\")\n", |
| 939 | + "p1 << new Line(x: T, y: v3_2, displayName: \"v3_2\")\n", |
| 940 | + "p1 << new Line(x: T, y: v2_1, displayName: \"v2_1\")\n", |
| 941 | + "\n", |
| 942 | + "def trails = [];\n", |
| 943 | + "def rmseValues = [];\n", |
| 944 | + "for (int t = 0; t < rmseTrialsToReport; t++) {\n", |
| 945 | + " trails.add(t);\n", |
| 946 | + " double xSsquared = 0;\n", |
| 947 | + " for (int r = 0; r < numRuns; r++) {\n", |
| 948 | + " Map<Cell<Double>, Double> u = runs.get(r).get(t);\n", |
| 949 | + " Double val1_1 = u.get(cw.getCellAt(1, 1));\n", |
| 950 | + " xSsquared += Math.pow(0.705 - val1_1, 2);\n", |
| 951 | + " }\n", |
| 952 | + " double rmse = Math.sqrt(xSsquared/runs.size());\n", |
| 953 | + " rmseValues.add(rmse);\n", |
| 954 | + "}\n", |
| 955 | + "def p2 = new Plot(yLabel: \"RMS error in utility\", xLabel: \"Number of trails\");\n", |
| 956 | + "p2 << new Line(x: trails, y: rmseValues)\n", |
| 957 | + "OutputCell.HIDDEN" |
| 958 | + ] |
| 959 | + }, |
| 960 | + { |
| 961 | + "cell_type": "markdown", |
| 962 | + "metadata": {}, |
| 963 | + "source": [ |
| 964 | + "[![Utility estimates][1]][1]\n", |
| 965 | + "\n", |
| 966 | + "[1]: assets/reinforcement_learning/q_utility_estimates.png" |
| 967 | + ] |
| 968 | + }, |
| 969 | + { |
| 970 | + "cell_type": "markdown", |
| 971 | + "metadata": {}, |
| 972 | + "source": [ |
| 973 | + "[![RMS error in utility][1]][1]\n", |
| 974 | + "\n", |
| 975 | + "[1]: assets/reinforcement_learning/q_RMSerror.png" |
| 976 | + ] |
| 977 | + }, |
789 | 978 | { |
790 | 979 | "cell_type": "code", |
791 | 980 | "execution_count": null, |
|