Skip to content

Commit cd89c96

Browse files
authored
Update README.md
1 parent 9b0a17e commit cd89c96

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

README.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ For example, the figure above presents the processing time of a single mini-batc
2121
<br>
2222

2323
## Requirements
24-
- **GPU and CUDA are required**
24+
- **GPU and CUDA 8 are required**
2525
- [PyTorch](http://pytorch.org/)
2626
- [CuPy](https://cupy.chainer.org/)
2727
- [pynvrtc](https://github.com/NVIDIA/pynvrtc)
@@ -34,10 +34,11 @@ Install requirements via `pip install -r requirements.txt`. CuPy and pynvrtc nee
3434
The usage of SRU is similar to `nn.LSTM`.
3535
```python
3636
import torch
37+
from torch.autograd import Variable
3738
from cuda_functional import SRU, SRUCell
3839

3940
# input has length 20, batch size 32 and dimension 128
40-
x = torch.FloatTensor(20, 32, 128).cuda()
41+
x = Variable(torch.FloatTensor(20, 32, 128).cuda())
4142

4243
input_size, hidden_size = 128, 128
4344

@@ -48,6 +49,7 @@ rnn = SRU(input_size, hidden_size,
4849
use_tanh = 1, # use tanh or identity activation
4950
bidirectional = False # bidirectional RNN ?
5051
)
52+
rnn.cuda()
5153

5254
output, hidden = rnn(x) # forward pass
5355

0 commit comments

Comments
 (0)