Skip to content

Commit 971abcc

Browse files
committed
Fix embeddings
1 parent 10adc7a commit 971abcc

File tree

2 files changed

+29
-34
lines changed

2 files changed

+29
-34
lines changed

examples/lstm/imdb-classifier/util.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,38 +9,33 @@ def load_imdb():
99
subprocess.check_output(
1010
"curl -SL http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz | tar xz", shell=True)
1111

12-
if not os.path.exists('./aclImdb/cache.npz'):
13-
X_train = []
14-
y_train = []
15-
16-
path = './aclImdb/train/pos/'
17-
X_train.extend([open(path + f).read()
18-
for f in os.listdir(path) if f.endswith('.txt')])
19-
y_train.extend([1 for _ in range(12500)])
20-
21-
path = './aclImdb/train/neg/'
22-
X_train.extend([open(path + f).read()
23-
for f in os.listdir(path) if f.endswith('.txt')])
24-
y_train.extend([0 for _ in range(12500)])
25-
26-
X_test = []
27-
y_test = []
28-
29-
path = './aclImdb/test/pos/'
30-
X_test.extend([open(path + f).read()
31-
for f in os.listdir(path) if f.endswith('.txt')])
32-
y_test.extend([1 for _ in range(12500)])
33-
34-
path = './aclImdb/test/neg/'
35-
X_test.extend([open(path + f).read()
36-
for f in os.listdir(path) if f.endswith('.txt')])
37-
y_test.extend([0 for _ in range(12500)])
38-
39-
np.savez('./aclImdb/cache.npz', X_train=X_train, y_train=y_train,
40-
X_test=X_test, y_test=y_test)
41-
42-
cached = np.load('./aclImdb/cache.npz')
43-
return (cached['X_train'], cached['y_train']), (cached['X_test'], cached['y_test'])
12+
X_train = []
13+
y_train = []
14+
15+
path = './aclImdb/train/pos/'
16+
X_train.extend([open(path + f).read()
17+
for f in os.listdir(path) if f.endswith('.txt')])
18+
y_train.extend([1 for _ in range(12500)])
19+
20+
path = './aclImdb/train/neg/'
21+
X_train.extend([open(path + f).read()
22+
for f in os.listdir(path) if f.endswith('.txt')])
23+
y_train.extend([0 for _ in range(12500)])
24+
25+
X_test = []
26+
y_test = []
27+
28+
path = './aclImdb/test/pos/'
29+
X_test.extend([open(path + f).read()
30+
for f in os.listdir(path) if f.endswith('.txt')])
31+
y_test.extend([1 for _ in range(12500)])
32+
33+
path = './aclImdb/test/neg/'
34+
X_test.extend([open(path + f).read()
35+
for f in os.listdir(path) if f.endswith('.txt')])
36+
y_test.extend([0 for _ in range(12500)])
37+
38+
return (X_train, y_train), (X_test, y_test)
4439

4540

4641
if __name__ == '__main__':
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
[default]
2-
entity: oreilly
3-
project: imdb-sep10
2+
entity: qualcomm
3+
project: imdb-sep11
44
base_url: https://api.wandb.ai

0 commit comments

Comments
 (0)