|
279 | 279 | "metadata": {}, |
280 | 280 | "outputs": [], |
281 | 281 | "source": [ |
282 | | - "# TODO (exercise): code here" |
| 282 | + "# Let's checkout the reward space\n", |
| 283 | + "obs = env.reset()\n", |
| 284 | + "rewards = []\n", |
| 285 | + "done = False\n", |
| 286 | + "while not done:\n", |
| 287 | + " action = # TODO (exercise): code here\n", |
| 288 | + " obs, reward, done, info = env.step(action)\n", |
| 289 | + " rewards.append(reward)" |
283 | 290 | ] |
284 | 291 | }, |
285 | 292 | { |
|
608 | 615 | " DQN,\n", |
609 | 616 | " param_space=bandit_config_offline.to_dict(),\n", |
610 | 617 | " run_config=air.RunConfig(\n", |
611 | | - " local_dir=\"./results_notebook/offline_bandit/\",\n", |
612 | | - " stop={\"training_iteration\": 100},\n", |
| 618 | + " local_dir=\"./results_notebook/offline_bandits/\",\n", |
| 619 | + " stop={\"training_iteration\": 1000},\n", |
613 | 620 | " )\n", |
614 | 621 | ")\n", |
615 | 622 | "offline_bandit_results = bandit_tuner.fit()" |
|
629 | 636 | "outputs": [], |
630 | 637 | "source": [ |
631 | 638 | "print('Mean Bandit Episode reward:')\n", |
632 | | - "offline_bandit_results[0].metrics['evaluation']['episode_reward_mean']" |
| 639 | + "offline_bandits_results[0].metrics['evaluation']['episode_reward_mean']" |
633 | 640 | ] |
634 | 641 | }, |
635 | 642 | { |
|
683 | 690 | " param_space=dqn_config_offline.to_dict(),\n", |
684 | 691 | " run_config=air.RunConfig(\n", |
685 | 692 | " local_dir=\"./results_notebook/offline_rl/\",\n", |
686 | | - " stop={\"training_iteration\": 100},\n", |
| 693 | + " stop={\"training_iteration\": 30},\n", |
687 | 694 | " )\n", |
688 | 695 | ")\n", |
689 | | - "offline_dqn_results = dqn_tuner.fit()" |
| 696 | + "offline_rl_results = dqn_tuner.fit()" |
690 | 697 | ] |
691 | 698 | }, |
692 | 699 | { |
|
696 | 703 | "outputs": [], |
697 | 704 | "source": [ |
698 | 705 | "print('Mean DQN Episode reward:')\n", |
699 | | - "offline_dqn_results[0].metrics['evaluation']['episode_reward_mean']" |
| 706 | + "offline_rl_results[0].metrics['evaluation']['episode_reward_mean']" |
700 | 707 | ] |
701 | 708 | }, |
702 | 709 | { |
|
705 | 712 | "metadata": {}, |
706 | 713 | "outputs": [], |
707 | 714 | "source": [ |
| 715 | + "import pandas as pd\n", |
| 716 | + "\n", |
708 | 717 | "# plot the results and compare to baselines\n", |
709 | | - "offline_dqn_df = pd.read_csv(\"saved_runs/dqn_offline/random_data/progress.csv\")" |
| 718 | + "offline_rl_df = pd.read_csv(\"saved_runs/offline_rl/progress.csv\")\n", |
| 719 | + "offline_bandits_df = pd.read_csv(\"saved_runs/offline_bandits/progress.csv\")" |
710 | 720 | ] |
711 | 721 | }, |
712 | 722 | { |
|
716 | 726 | "outputs": [], |
717 | 727 | "source": [ |
718 | 728 | "\n", |
719 | | - "sns.lineplot(data=offline_dqn_df, x=\"training_iteration\", y=\"evaluation/episode_reward_mean\", label=\"Offline_DQN\")\n", |
720 | | - "sns.lineplot(data=dqn_df, x=\"training_iteration\", y=\"episode_reward_mean\", label=\"Online_DQN\")\n", |
721 | | - "sns.lineplot(data=bandit_df, x=\"training_iteration\", y=\"episode_reward_mean\", label=\"Bandits\")\n", |
| 729 | + "sns.lineplot(data=offline_rl_df, x=\"training_iteration\", y=\"evaluation/episode_reward_mean\", label=\"Offline_DQN\")\n", |
| 730 | + "sns.lineplot(data=offline_bandits_df, x=\"training_iteration\", y=\"evaluation/episode_reward_mean\", label=\"Offline_Bandits\")\n", |
722 | 731 | "plt.axhline(random_baseline, color=\"red\", linestyle='--', label=\"random baseline\")\n", |
723 | 732 | "plt.legend()\n", |
724 | 733 | "plt.title('Offline RL vs. Baselines training performance')" |
|