Skip to content

Commit 6745966

Browse files
authored
Update self_attn_pool.py
Fix numpy initialization bug
1 parent b5e1df0 commit 6745966

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

chapter8/self_attn_pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(self, data_root="data", train_size=0.8):
8585
self.test_label = graph_labels[self.test_index]
8686

8787
def split_data(self, train_size):
88-
unique_indicator = np.asarray(set(self.graph_indicator))
88+
unique_indicator = np.asarray(list(set(self.graph_indicator)))
8989
train_index, test_index = train_test_split(unique_indicator,
9090
train_size=train_size,
9191
random_state=1234)

0 commit comments

Comments
 (0)