Skip to content

Commit fafcd8d

Browse files
committed
update to v0.4.1
1 parent b5c5325 commit fafcd8d

File tree

8 files changed

+216
-136
lines changed

8 files changed

+216
-136
lines changed

benchmark.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
标准测试
2+
====
3+
4+
Mujoco Benchmark
5+
----------------
6+
7+
天授目前拥有市面上所有平台中最好的Mujoco结果(甚至比 `SpinningUp <https://spinningup.openai.com/en/latest/spinningup/bench.html>`_ 还要好!)
8+
9+
看这里:https://github.com/thu-ml/tianshou/tree/master/examples/mujoco
10+
11+
Atari Benchmark
12+
---------------
13+
14+
戳这里:https://github.com/thu-ml/tianshou/tree/master/examples/atari

cheatsheet.rst

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,19 +83,23 @@
8383

8484
本条目与 `Issue 42 <https://github.com/thu-ml/tianshou/issues/42>`_ 相关。
8585

86-
如果想收集训练log、预处理图像数据(比如Atari要resize到84x84x3)、根据环境信息修改奖励函数的值,可以在Collector中使用 ``preprocess_fn`` 接口,它会在数据存入Buffer之前被调用。
86+
如果想收集训练log、预处理图像数据(比如Atari要resize到84x84x3 -- 不过这个推荐直接wrapper做)、根据环境信息修改奖励函数的值,可以在Collector中使用 ``preprocess_fn`` 接口,它会在数据存入Buffer之前被调用。
87+
88+
``preprocess_fn`` 有两种输入接口:如果是env.reset()的话,它只会接收obs;如果是正常的env.step(),那么他会接收5个关键字 "obs_next"/"rew"/"done"/"info"/"policy"。返回一个字典或者Batch,里面包含着你想修改的东西。
8789

88-
``preprocess_fn`` 接收7个保留关键字(obs/act/rew/done/obs_next/info/policy),以数据组(Batch)的形式返回需要修改的部分,比如可以像下面这个例子一样:
8990
::
9091

9192
import numpy as np
9293
from collections import deque
94+
95+
9396
class MyProcessor:
9497
def __init__(self, size=100):
9598
self.episode_log = None
9699
self.main_log = deque(maxlen=size)
97100
self.main_log.append(0)
98101
self.baseline = 0
102+
99103
def preprocess_fn(**kwargs):
100104
"""把reward给归一化"""
101105
if 'rew' not in kwargs:
@@ -119,7 +123,7 @@
119123
::
120124

121125
test_processor = MyProcessor(size=100)
122-
collector = Collector(policy, env, buffer, test_processor.preprocess_fn)
126+
collector = Collector(policy, env, buffer, preprocess_fn=test_processor.preprocess_fn)
123127

124128
还有一些示例在 `test/base/test_collector.py <https://github.com/thu-ml/tianshou/blob/master/test/base/test_collector.py>`_ 中可以查看。
125129

@@ -249,6 +253,8 @@ RNN训练
249253
当然如果自定义的环境中,状态是一个自定义的类,也是可以的。不过天授只会把它的地址进行存储,就像下面这样(状态是nx.Graph):
250254
::
251255

256+
>>> # 这个例子可能现在不太能work,因为numpy升级了,以及nx.Graph重写了__getitem__,导致np.array([nx.Graph()])会出来空的数组……
257+
>>> # 不过正常的自定义class应该没啥问题
252258
>>> import networkx as nx
253259
>>> b = ReplayBuffer(size=3)
254260
>>> b.add(obs=nx.Graph(), act=0, rew=0, done=0)

concepts.rst

Lines changed: 107 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Batch
1717
>>> import torch, numpy as np
1818
>>> from tianshou.data import Batch
1919
>>> data = Batch(a=4, b=[5, 5], c='2312312', d=('a', -2, -3))
20-
>>> # the list will automatically be converted to numpy array
20+
>>> # 注意,list会自动变成numpy
2121
>>> data.b
2222
array([5, 5])
2323
>>> data.b = np.array([3, 4, 5])
@@ -38,7 +38,7 @@ Batch
3838
act: tensor([0., 6.]),
3939
)
4040

41-
总之就是可以定义任何key-value放在Batch里面,然后可以做一些常规的操作比如+-\*、cat/stack之类的。`Understand Batch </en/master/tutorials/batch.html>`_ 里面详细描述了Batch的各种用法,非常值得一看。
41+
总之就是可以定义任何key-value放在Batch里面,然后可以做一些常规的操作比如+-\*、cat/stack之类的。`Understand Batch </en/master/tutorials/batch.html>`_ 里面详细描述了Batch的各种用法,非常值得一看(虽然它是英文的但只要看代码也还行?)
4242

4343

4444
Buffer
@@ -58,39 +58,52 @@ Buffer
5858
::
5959

6060
>>> import pickle, numpy as np
61-
>>> from tianshou.data import ReplayBuffer
61+
>>> from tianshou.data import Batch, ReplayBuffer
6262
>>> buf = ReplayBuffer(size=20)
6363
>>> for i in range(3):
64-
... buf.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
65-
>>> buf.obs
66-
# 因为设置了 size = 20,所以 len(buf.obs) == 20
67-
array([0., 1., 2., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
68-
0., 0., 0., 0.])
64+
... buf.add(Batch(obs=i, act=i, rew=i, done=0, obs_next=i + 1, info={}))
65+
66+
>>> buf.obs # 因为设置了 size = 20,所以 len(buf.obs) == 20
67+
array([0, 1, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
6968
>>> # 但是里面只有3个合法的数据,因此 len(buf) == 3
7069
>>> len(buf)
7170
3
7271
>>> pickle.dump(buf, open('buf.pkl', 'wb')) # 把buffer所有数据保存到 "buf.pkl"
72+
>>> buf.save_hdf5('buf.hdf5') # 把buffer所有数据保存到 "buf.hdf5"
73+
7374
>>> buf2 = ReplayBuffer(size=10)
7475
>>> for i in range(15):
75-
... buf2.add(obs=i, act=i, rew=i, done=i, obs_next=i + 1, info={})
76+
... done = i % 4 == 0
77+
... buf2.add(Batch(obs=i, act=i, rew=i, done=done, obs_next=i + 1, info={}))
7678
>>> len(buf2)
7779
10
78-
>>> buf2.obs
79-
# 因为 buf2 的 size = 10,所以它只会存储最后10步的结果
80-
array([10., 11., 12., 13., 14., 5., 6., 7., 8., 9.])
80+
>>> buf2.obs # 因为 buf2 的 size = 10,所以它只会存储最后10步的结果
81+
array([10, 11, 12, 13, 14, 5, 6, 7, 8, 9])
82+
83+
>>> buf.update(buf2) # 把 buf2 的数据挪到buf里面,同时保持相对时间顺序
84+
>>> buf.obs
85+
array([ 0, 1, 2, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 0, 0, 0,
86+
0, 0, 0, 0])
8187

82-
>>> # 把 buf2 的数据挪到buf里面,同时保持相对时间顺序
83-
>>> buf.update(buf2)
84-
array([ 0., 1., 2., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14.,
85-
0., 0., 0., 0., 0., 0., 0.])
88+
>>> indice = buf.sample_index(0) # 使用 batchsize=0 来获取buffer里面的全部数据
89+
>>> indice
90+
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
91+
>>> buf.prev(indice) # 给定index,计算上一个transition所对应的index
92+
array([ 0, 0, 1, 2, 3, 4, 5, 7, 7, 8, 9, 11, 11])
93+
>>> buf.next(indice) # 给定index,计算下一个transition所对应的index
94+
array([ 1, 2, 3, 4, 5, 6, 6, 8, 9, 10, 10, 12, 12])
8695

8796
>>> # 从buffer里面拿一个随机的数据,batch_data就是buf[indice]
8897
>>> batch_data, indice = buf.sample(batch_size=4)
8998
>>> batch_data.obs == buf[indice].obs
9099
array([ True, True, True, True])
91100
>>> len(buf)
92101
13
93-
>>> buf = pickle.load(open('buf.pkl', 'rb')) # load from "buf.pkl"
102+
103+
>>> buf = pickle.load(open('buf.pkl', 'rb')) # 从"buf.pkl"文件恢复出buffer
104+
>>> len(buf)
105+
3
106+
>>> buf = ReplayBuffer.load_hdf5('buf.hdf5') # 从"buf.hdf5"导入完整的buffer
94107
>>> len(buf)
95108
3
96109

@@ -100,46 +113,67 @@ Buffer
100113
>>> buf = ReplayBuffer(size=9, stack_num=4, ignore_obs_next=True)
101114
>>> for i in range(16):
102115
... done = i % 5 == 0
103-
... buf.add(obs={'id': i}, act=i, rew=i, done=done,
104-
... obs_next={'id': i + 1})
116+
... ptr, ep_rew, ep_len, ep_idx = buf.add(
117+
... Batch(obs={'id': i}, act=i, rew=i,
118+
... done=done, obs_next={'id': i + 1}))
119+
... print(i, ep_len, ep_rew)
120+
0 [1] [0.]
121+
1 [0] [0.]
122+
2 [0] [0.]
123+
3 [0] [0.]
124+
4 [0] [0.]
125+
5 [5] [15.]
126+
6 [0] [0.]
127+
7 [0] [0.]
128+
8 [0] [0.]
129+
9 [0] [0.]
130+
10 [5] [40.]
131+
11 [0] [0.]
132+
12 [0] [0.]
133+
13 [0] [0.]
134+
14 [0] [0.]
135+
15 [5] [65.]
105136
>>> print(buf) # 可以发现obs_next并不在里面存着
106137
ReplayBuffer(
107-
act: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
108-
done: array([0., 1., 0., 0., 0., 0., 1., 0., 0.]),
109-
info: Batch(),
110138
obs: Batch(
111-
id: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
139+
id: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
112140
),
113-
policy: Batch(),
141+
act: array([ 9, 10, 11, 12, 13, 14, 15, 7, 8]),
114142
rew: array([ 9., 10., 11., 12., 13., 14., 15., 7., 8.]),
143+
done: array([False, True, False, False, False, False, True, False,
144+
False]),
115145
)
116146
>>> index = np.arange(len(buf))
117147
>>> print(buf.get(index, 'obs').id)
118-
[[ 7. 7. 8. 9.]
119-
[ 7. 8. 9. 10.]
120-
[11. 11. 11. 11.]
121-
[11. 11. 11. 12.]
122-
[11. 11. 12. 13.]
123-
[11. 12. 13. 14.]
124-
[12. 13. 14. 15.]
125-
[ 7. 7. 7. 7.]
126-
[ 7. 7. 7. 8.]]
127-
>>> # 也可以这样取出stacked过的obs
128-
>>> np.allclose(buf.get(index, 'obs')['id'], buf[index].obs.id)
129-
True
130-
>>> # 可以通过 __getitem__ 来弄出obs_next(虽然并没存)
148+
[[ 7 7 8 9]
149+
[ 7 8 9 10]
150+
[11 11 11 11]
151+
[11 11 11 12]
152+
[11 11 12 13]
153+
[11 12 13 14]
154+
[12 13 14 15]
155+
[ 7 7 7 7]
156+
[ 7 7 7 8]]
157+
>>> # 也可以这样取出stacked过的obs(注意stack只对obs/obs_next/info/policy有效)
158+
>>> abs(buf.get(index, 'obs')['id'] - buf[index].obs.id).sum().sum()
159+
0
160+
>>> # 可以通过 __getitem__ 来弄出obs_next(虽然并没存),但是[:]会按照时间顺序(而不是实际存储顺序)来取数据
161+
>>> # 比如下面这个就相当于 index == [7, 8, 0, 1, 2, 3, 4, 5, 6]
131162
>>> print(buf[:].obs_next.id)
132-
[[ 7. 8. 9. 10.]
133-
[ 7. 8. 9. 10.]
134-
[11. 11. 11. 12.]
135-
[11. 11. 12. 13.]
136-
[11. 12. 13. 14.]
137-
[12. 13. 14. 15.]
138-
[12. 13. 14. 15.]
139-
[ 7. 7. 7. 8.]
140-
[ 7. 7. 8. 9.]]
163+
[[ 7 7 7 8]
164+
[ 7 7 8 9]
165+
[ 7 8 9 10]
166+
[ 7 8 9 10]
167+
[11 11 11 12]
168+
[11 11 12 13]
169+
[11 12 13 14]
170+
[12 13 14 15]
171+
[12 13 14 15]]
172+
>>> full_index = np.array([7, 8, 0, 1, 2, 3, 4, 5, 6])
173+
>>> np.allclose(buf[:].obs_next.id, buf[full_index].obs_next.id)
174+
True
141175

142-
天授还提供了其他类型的buffer比如 :class:`~tianshou.data.ListReplayBuffer` (基于list),:class:`~tianshou.data.PrioritizedReplayBuffer` (基于线段树)。可以访问对应的文档来查看。
176+
天授还提供了其他类型的buffer比如 :class:`~tianshou.data.PrioritizedReplayBuffer` (基于线段树)、:class:`~tianshou.data.VectorReplayBuffer` (能够向其中添加不同episode的数据的同时维护时间顺序)。可以访问对应的文档来查看。
143177

144178

145179
Policy
@@ -175,7 +209,7 @@ Policy
175209
| Testing state | False | False |
176210
+-----------------------------------+-----------------+-----------------+
177211

178-
``policy.updating`` 实际情况下主要用于exploration,比如在各种Q-Learning算法中,在不同的state切换探索策略
212+
``policy.updating`` 实际情况下主要用于exploration,比如在各种Q-Learning算法中,在不同的policy state切换探索策略
179213

180214

181215
policy.forward
@@ -198,7 +232,7 @@ policy.forward
198232
act = policy(batch).act[0] # policy.forward 返回一个 batch,使用 ".act" 来取出里面action的数据
199233
obs, rew, done, info = env.step(act)
200234

201-
这边 ``Batch(obs=[obs])`` 会自动为obs下面的所有数据创建第0维,让它为batch size=1,否则nn没法确定batch size。
235+
这边 ``Batch(obs=[obs])`` 会自动为obs下面的所有数据创建第0维,让它为batch size=1,否则神经网络没法确定batch size。
202236

203237

204238
.. _process_fn:
@@ -257,14 +291,34 @@ policy.process_fn
257291
Collector
258292
---------
259293

260-
:class:`~tianshou.data.Collector` 负责policy与env的交互和数据存储。:meth:`~tianshou.data.Collector.collect` 是collector的主要方法,它能够指定让policy和环境交互至少 ``n_step`` 个step或者 ``n_episode`` 个episode,并把该过程中产生的数据存储到buffer中。
294+
:class:`~tianshou.data.Collector` 负责policy与env的交互和数据存储。:meth:`~tianshou.data.Collector.collect` 是collector的主要方法,它能够指定让policy和环境交互给定数目 ``n_step`` 个step或者 ``n_episode`` 个episode,并把该过程中产生的数据存储到buffer中。
295+
296+
:ref:`pseudocode` 给出了一个宏观层面的解释,其他collector的功能可参考对应文档。此处列出一些常用用法:
297+
298+
::
299+
300+
policy = PGPolicy(...) # 或者其他policy都可以
301+
env = gym.make("CartPole-v0")
302+
303+
replay_buffer = ReplayBuffer(size=10000)
261304

262-
为啥这边强调 **至少**?因为天授使用了一个buffer来处理这些事情(当然还有一种方法是每个env对应单独的一个buffer)。如果用一个buffer做的话,需要维护若干个cache buffer,然后必须等到episode结束才能将cache buffer里面的数据转移到main buffer中,否则不能保证其中的时间顺序。
305+
# 这里单个env对应ReplayBuffer
306+
collector = Collector(policy, env, buffer=replay_buffer)
263307

264-
这么做有优点也有缺点,缺点是老是有人提issue,得手动加一个 ``gym.wrappers.TimeLimit`` 在env上面(如果env的done一直是False的话);优点是delayed update能够带来一定的性能提升,以及会大幅简化其他部分代码(比如PER、nstep、GAE这种就很简单,还有buffer.sample也还算简单,如果n个buffer的话就得多写很多代码)。
308+
# 多个env的话得用VectorReplayBuffer,但是collector仍然适用
309+
vec_buffer = VectorReplayBuffer(total_size=10000, buffer_num=3)
310+
# buffer_num推荐和env数量相等
311+
envs = DummyVectorEnv([lambda: gym.make("CartPole-v0") for _ in range(3)])
312+
collector = Collector(policy, envs, buffer=vec_buffer)
265313

266-
:ref:`pseudocode` 给出了一个宏观层面的解释,其他collector的功能可参考对应文档。
314+
# 收集3个episode
315+
collector.collect(n_episode=3)
316+
# 收集至少俩step(这个会收集三个,因为有三个env,每次收集的次数得是3的倍数)
317+
collector.collect(n_step=2)
318+
# 边收集变直播,使用render参数就可以(render传入的是时间间隔,以秒为单位)
319+
collector.collect(n_episode=1, render=0.03)
267320

321+
还有个:class:`~tianshou.data.AsyncCollector`,继承了:class:`~tianshou.data.Collector`,它支持异步的环境采样(比如环境很慢或者step时间差异很大)。不过AsyncCollector的collect的语义和上面Collector有所不同,由于异步的特性,它只能保证**至少** ``n_step`` 或者 ``n_episode`` 地收集数据。
268322

269323
Trainer
270324
-------

0 commit comments

Comments
 (0)