1717 >>> import torch, numpy as np
1818 >>> from tianshou.data import Batch
1919 >>> data = Batch(a=4, b=[5, 5], c='2312312', d=('a', -2, -3))
20- >>> # the list will automatically be converted to numpy array
20+ >>> # 注意,list会自动变成numpy
2121 >>> data.b
2222 array([5, 5])
2323 >>> data.b = np.array([3, 4, 5])
3838 act: tensor([0., 6.]),
3939 )
4040
41- 总之就是可以定义任何key-value放在Batch里面,然后可以做一些常规的操作比如+-\* 、cat/stack之类的。`Understand Batch </en/master/tutorials/batch.html >`_ 里面详细描述了Batch的各种用法,非常值得一看。
41+ 总之就是可以定义任何key-value放在Batch里面,然后可以做一些常规的操作比如+-\* 、cat/stack之类的。`Understand Batch </en/master/tutorials/batch.html >`_ 里面详细描述了Batch的各种用法,非常值得一看(虽然它是英文的但只要看代码也还行?) 。
4242
4343
4444Buffer
@@ -58,39 +58,52 @@ Buffer
5858::
5959
6060 >>> import pickle, numpy as np
61- >>> from tianshou.data import ReplayBuffer
61+ >>> from tianshou.data import Batch, ReplayBuffer
6262 >>> buf = ReplayBuffer(size=20)
6363 >>> for i in range(3):
64- ... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
65- >>> buf.obs
66- # 因为设置了 size = 20,所以 len(buf.obs) == 20
67- array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
68- 0., 0., 0., 0.])
64+ ... buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))
65+
66+ >>> buf.obs # 因为设置了 size = 20,所以 len(buf.obs) == 20
67+ array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
6968 >>> # 但是里面只有3个合法的数据,因此 len(buf) == 3
7069 >>> len(buf)
7170 3
7271 >>> pickle.dump(buf, open('buf.pkl', 'wb')) # 把buffer所有数据保存到 "buf.pkl"
72+ >>> buf.save_hdf5('buf.hdf5') # 把buffer所有数据保存到 "buf.hdf5"
73+
7374 >>> buf2 = ReplayBuffer(size=10)
7475 >>> for i in range(15):
75- ... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
76+ ... done = i % 4 == 0
77+ ... buf2.add(Batch(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={}))
7678 >>> len(buf2)
7779 10
78- >>> buf2.obs
79- # 因为 buf2 的 size = 10,所以它只会存储最后10步的结果
80- array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
80+ >>> buf2.obs # 因为 buf2 的 size = 10,所以它只会存储最后10步的结果
81+ array([10, 11, 12, 13, 14, 5, 6, 7, 8, 9])
82+
83+ >>> buf.update(buf2) # 把 buf2 的数据挪到buf里面,同时保持相对时间顺序
84+ >>> buf.obs
85+ array([ 0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0,
86+ 0, 0, 0, 0])
8187
82- >>> # 把 buf2 的数据挪到buf里面,同时保持相对时间顺序
83- >>> buf.update(buf2)
84- array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
85- 0., 0., 0., 0., 0., 0., 0.])
88+ >>> indice = buf.sample_index(0) # 使用 batchsize=0 来获取buffer里面的全部数据
89+ >>> indice
90+ array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
91+ >>> buf.prev(indice) # 给定index,计算上一个transition所对应的index
92+ array([ 0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11])
93+ >>> buf.next(indice) # 给定index,计算下一个transition所对应的index
94+ array([ 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12])
8695
8796 >>> # 从buffer里面拿一个随机的数据,batch_data就是buf[indice]
8897 >>> batch_data, indice = buf.sample(batch_size=4)
8998 >>> batch_data.obs == buf[indice].obs
9099 array([ True, True, True, True])
91100 >>> len(buf)
92101 13
93- >>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
102+
103+ >>> buf = pickle.load(open('buf.pkl', 'rb')) # 从"buf.pkl"文件恢复出buffer
104+ >>> len(buf)
105+ 3
106+ >>> buf = ReplayBuffer.load_hdf5('buf.hdf5') # 从"buf.hdf5"导入完整的buffer
94107 >>> len(buf)
95108 3
96109
@@ -100,46 +113,67 @@ Buffer
100113 >>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
101114 >>> for i in range(16):
102115 ... done = i % 5 == 0
103- ... buf.add(obs={'id': i}, act=i, rew=i, done=done,
104- ... obs_next={'id': i + 1})
116+ ... ptr, ep_rew, ep_len, ep_idx = buf.add(
117+ ... Batch(obs={'id': i}, act=i, rew=i,
118+ ... done=done, obs_next={'id': i + 1}))
119+ ... print(i, ep_len, ep_rew)
120+ 0 [1] [0.]
121+ 1 [0] [0.]
122+ 2 [0] [0.]
123+ 3 [0] [0.]
124+ 4 [0] [0.]
125+ 5 [5] [15.]
126+ 6 [0] [0.]
127+ 7 [0] [0.]
128+ 8 [0] [0.]
129+ 9 [0] [0.]
130+ 10 [5] [40.]
131+ 11 [0] [0.]
132+ 12 [0] [0.]
133+ 13 [0] [0.]
134+ 14 [0] [0.]
135+ 15 [5] [65.]
105136 >>> print(buf) # 可以发现obs_next并不在里面存着
106137 ReplayBuffer(
107- act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
108- done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
109- info: Batch(),
110138 obs: Batch(
111- id: array([ 9. , 10. , 11. , 12. , 13. , 14. , 15. , 7. , 8. ]),
139+ id: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
112140 ),
113- policy: Batch( ),
141+ act: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8] ),
114142 rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
143+ done: array([False, True, False, False, False, False, True, False,
144+ False]),
115145 )
116146 >>> index = np.arange(len(buf))
117147 >>> print(buf.get(index, 'obs').id)
118- [[ 7. 7. 8. 9.]
119- [ 7. 8. 9. 10.]
120- [11. 11. 11. 11.]
121- [11. 11. 11. 12.]
122- [11. 11. 12. 13.]
123- [11. 12. 13. 14.]
124- [12. 13. 14. 15.]
125- [ 7. 7. 7. 7.]
126- [ 7. 7. 7. 8.]]
127- >>> # 也可以这样取出stacked过的obs
128- >>> np.allclose(buf.get(index, 'obs')['id'], buf[index].obs.id)
129- True
130- >>> # 可以通过 __getitem__ 来弄出obs_next(虽然并没存)
148+ [[ 7 7 8 9]
149+ [ 7 8 9 10]
150+ [11 11 11 11]
151+ [11 11 11 12]
152+ [11 11 12 13]
153+ [11 12 13 14]
154+ [12 13 14 15]
155+ [ 7 7 7 7]
156+ [ 7 7 7 8]]
157+ >>> # 也可以这样取出stacked过的obs(注意stack只对obs/obs_next/info/policy有效)
158+ >>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
159+ 0
160+ >>> # 可以通过 __getitem__ 来弄出obs_next(虽然并没存),但是[:]会按照时间顺序(而不是实际存储顺序)来取数据
161+ >>> # 比如下面这个就相当于 index == [7, 8, 0, 1, 2, 3, 4, 5, 6]
131162 >>> print(buf[:].obs_next.id)
132- [[ 7. 8. 9. 10.]
133- [ 7. 8. 9. 10.]
134- [11. 11. 11. 12.]
135- [11. 11. 12. 13.]
136- [11. 12. 13. 14.]
137- [12. 13. 14. 15.]
138- [12. 13. 14. 15.]
139- [ 7. 7. 7. 8.]
140- [ 7. 7. 8. 9.]]
163+ [[ 7 7 7 8]
164+ [ 7 7 8 9]
165+ [ 7 8 9 10]
166+ [ 7 8 9 10]
167+ [11 11 11 12]
168+ [11 11 12 13]
169+ [11 12 13 14]
170+ [12 13 14 15]
171+ [12 13 14 15]]
172+ >>> full_index = np.array([7, 8, 0, 1, 2, 3, 4, 5, 6])
173+ >>> np.allclose(buf[:].obs_next.id, buf[full_index].obs_next.id)
174+ True
141175
142- 天授还提供了其他类型的buffer比如 :class: `~tianshou.data.ListReplayBuffer ` (基于list), :class: `~tianshou.data.PrioritizedReplayBuffer ` (基于线段树 )。可以访问对应的文档来查看。
176+ 天授还提供了其他类型的buffer比如 :class: `~tianshou.data.PrioritizedReplayBuffer ` (基于线段树)、 :class: `~tianshou.data.VectorReplayBuffer ` (能够向其中添加不同episode的数据的同时维护时间顺序 )。可以访问对应的文档来查看。
143177
144178
145179Policy
@@ -175,7 +209,7 @@ Policy
175209| Testing state | False | False |
176210+-----------------------------------+-----------------+-----------------+
177211
178- ``policy.updating `` 实际情况下主要用于exploration,比如在各种Q-Learning算法中,在不同的state切换探索策略 。
212+ ``policy.updating `` 实际情况下主要用于exploration,比如在各种Q-Learning算法中,在不同的policy state切换探索策略 。
179213
180214
181215policy.forward
@@ -198,7 +232,7 @@ policy.forward
198232 act = policy(batch).act[0] # policy.forward 返回一个 batch,使用 ".act" 来取出里面action的数据
199233 obs, rew, done, info = env.step(act)
200234
201- 这边 ``Batch(obs=[obs]) `` 会自动为obs下面的所有数据创建第0维,让它为batch size=1,否则nn没法确定batch size。
235+ 这边 ``Batch(obs=[obs]) `` 会自动为obs下面的所有数据创建第0维,让它为batch size=1,否则神经网络没法确定batch size。
202236
203237
204238.. _process_fn :
@@ -257,14 +291,34 @@ policy.process_fn
257291Collector
258292---------
259293
260- :class: `~tianshou.data.Collector ` 负责policy与env的交互和数据存储。:meth: `~tianshou.data.Collector.collect ` 是collector的主要方法,它能够指定让policy和环境交互至少 ``n_step `` 个step或者 ``n_episode `` 个episode,并把该过程中产生的数据存储到buffer中。
294+ :class: `~tianshou.data.Collector ` 负责policy与env的交互和数据存储。:meth: `~tianshou.data.Collector.collect ` 是collector的主要方法,它能够指定让policy和环境交互给定数目 ``n_step `` 个step或者 ``n_episode `` 个episode,并把该过程中产生的数据存储到buffer中。
295+
296+ :ref: `pseudocode ` 给出了一个宏观层面的解释,其他collector的功能可参考对应文档。此处列出一些常用用法:
297+
298+ ::
299+
300+ policy = PGPolicy(...) # 或者其他policy都可以
301+ env = gym.make("CartPole-v0")
302+
303+ replay_buffer = ReplayBuffer(size=10000)
261304
262- 为啥这边强调 **至少 **?因为天授使用了一个buffer来处理这些事情(当然还有一种方法是每个env对应单独的一个buffer)。如果用一个buffer做的话,需要维护若干个cache buffer,然后必须等到episode结束才能将cache buffer里面的数据转移到main buffer中,否则不能保证其中的时间顺序。
305+ # 这里单个env对应ReplayBuffer
306+ collector = Collector(policy, env, buffer=replay_buffer)
263307
264- 这么做有优点也有缺点,缺点是老是有人提issue,得手动加一个 ``gym.wrappers.TimeLimit `` 在env上面(如果env的done一直是False的话);优点是delayed update能够带来一定的性能提升,以及会大幅简化其他部分代码(比如PER、nstep、GAE这种就很简单,还有buffer.sample也还算简单,如果n个buffer的话就得多写很多代码)。
308+ # 多个env的话得用VectorReplayBuffer,但是collector仍然适用
309+ vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num=3)
310+ # buffer_num推荐和env数量相等
311+ envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(3)])
312+ collector = Collector(policy, envs, buffer=vec_buffer)
265313
266- :ref: `pseudocode ` 给出了一个宏观层面的解释,其他collector的功能可参考对应文档。
314+ # 收集3个episode
315+ collector.collect(n_episode=3)
316+ # 收集至少俩step(这个会收集三个,因为有三个env,每次收集的次数得是3的倍数)
317+ collector.collect(n_step=2)
318+ # 边收集变直播,使用render参数就可以(render传入的是时间间隔,以秒为单位)
319+ collector.collect(n_episode=1, render=0.03)
267320
321+ 还有个:class: `~tianshou.data.AsyncCollector `,继承了:class: `~tianshou.data.Collector `,它支持异步的环境采样(比如环境很慢或者step时间差异很大)。不过AsyncCollector的collect的语义和上面Collector有所不同,由于异步的特性,它只能保证**至少** ``n_step `` 或者 ``n_episode `` 地收集数据。
268322
269323Trainer
270324-------
0 commit comments