Skip to content

Commit aa55d80

Browse files
authored
Adding AQSOL Dataset (#240)
* AQSOL * aqsol working * remove redundant file * remove redundant file * Update graphs.jl * Update graphs.jl: Fixing tests * Update AQSOL.jl: docs * Update AQSOL.jl: better docstring * Update MLDatasets.jl * Update AQSOL.jl: fix edge_index * Update graphs.jl: update tests * Update graphs.jl: check only random graphs * Update graphs.md: add to docs
1 parent 6b7d256 commit aa55d80

File tree

5 files changed

+125
-2
lines changed

5 files changed

+125
-2
lines changed

docs/src/datasets/graphs.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ MLDatasets.HeteroGraph
2020
```
2121

2222
```@docs
23+
AQSOL
2324
ChickenPox
2425
CiteSeer
2526
Cora

src/MLDatasets.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,8 @@ include("datasets/graphs/citeseer.jl")
119119
export CiteSeer
120120
include("datasets/graphs/karateclub.jl")
121121
export KarateClub
122+
include("datasets/graphs/AQSOL.jl")
123+
export AQSOL
122124
include("datasets/graphs/movielens.jl")
123125
export MovieLens
124126
include("datasets/graphs/ogbdataset.jl")
@@ -151,6 +153,7 @@ function __init__()
151153
# TODO automatically find and execute __init__xxx functions
152154

153155
# graph
156+
__init__aqsol()
154157
__init__chickenpox()
155158
__init__citeseer()
156159
__init__cora()
@@ -166,7 +169,6 @@ function __init__()
166169
__init__temporalbrains()
167170
__init__windmillenergy()
168171

169-
170172
# misc
171173
__init__iris()
172174
__init__mutagenesis()

src/datasets/graphs/AQSOL.jl

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
function __init__aqsol()
2+
DEPNAME = "AQSOL"
3+
LINK = "https://www.dropbox.com/s/lzu9lmukwov12kt/aqsol_graph_raw.zip?dl=1"
4+
register(DataDep(DEPNAME,
5+
"""
6+
Dataset: The AQSOL dataset.
7+
Website: http://arxiv.org/abs/2003.00982
8+
""",
9+
LINK,
10+
post_fetch_method = unpack))
11+
end
12+
13+
struct AQSOL <: AbstractDataset
14+
split::Symbol
15+
metadata::Dict{String,Any}
16+
graphs::Vector{Graph}
17+
end
18+
19+
"""
20+
AQSOL(; split=:train, dir=nothing)
21+
22+
The AQSOL (Aqueous Solubility) dataset from the paper
23+
[Graph Neural Network for Predicting Aqueous Solubility of Organic Molecules](http://arxiv.org/abs/2003.00982).
24+
25+
The dataset contains 9,882 graphs representing small organic molecules. Each graph represents a molecule, where nodes correspond to atoms and edges to bonds. The node features represent the atomic number, and the edge features represent the bond type. The target is the aqueous solubility of the molecule, measured in mol/L.
26+
27+
# Arguments
28+
29+
- `split`: Which split of the dataset to load. Can be one of `:train`, `:val`, or `:test`. Defaults to `:train`.
30+
- `dir`: Directory in which the dataset is in.
31+
32+
# Examples
33+
34+
```julia-repl
35+
julia> using MLDatasets
36+
37+
julia> data = AQSOL()
38+
dataset AQSOL:
39+
split => :train
40+
metadata => Dict{String, Any} with 1 entry
41+
graphs => 7985-element Vector{MLDatasets.Graph}
42+
43+
julia> length(data)
44+
7985
45+
46+
julia> g = data[1]
47+
Graph:
48+
num_nodes => 23
49+
num_edges => 42
50+
edge_index => ("42-element Vector{Int64}", "42-element Vector{Int64}")
51+
node_data => (features = "23-element Vector{Int64}",)
52+
edge_data => (features = "42-element Vector{Int64}",)
53+
54+
julia> g.num_nodes
55+
23
56+
57+
julia> g.node_data.features
58+
23-element Vector{Int64}:
59+
0
60+
1
61+
1
62+
63+
1
64+
1
65+
1
66+
67+
julia> g.edge_index
68+
([2, 3, 3, 4, 4, 5, 5, 6, 6, 7 … 18, 19, 19, 20, 20, 21, 20, 22, 20, 23], [3, 2, 4, 3, 5, 4, 6, 5, 7, 6 … 19, 18, 20, 19, 21, 20, 22, 20, 23, 20])
69+
```
70+
"""
71+
function AQSOL(;split=:train, dir=nothing)
72+
@assert split [:train, :val, :test]
73+
DEPNAME = "AQSOL"
74+
path = datafile(DEPNAME, "asqol_graph_raw/$(split).pickle", dir)
75+
graphs = Pickle.npyload(path)
76+
g = [create_aqsol_graph(g...) for g in graphs]
77+
metadata = Dict{String, Any}("n_observations" => length(g))
78+
return AQSOL(split, metadata, g)
79+
end
80+
81+
function create_aqsol_graph(x, edge_attr, edge_index, y)
82+
x = Int.(x)
83+
edge_attr = Int.(edge_attr)
84+
edge_index = Int.(edge_index .+ 1)
85+
86+
if size(edge_index, 2) == 0
87+
s, t = Int[], Int[]
88+
else
89+
s, t = edge_index[1, :], edge_index[2, :]
90+
end
91+
92+
return Graph(; num_nodes = length(x),
93+
edge_index = (s, t),
94+
node_data = (features = x,),
95+
edge_data = (features = edge_attr,))
96+
end
97+
98+
Base.length(d::AQSOL) = length(d.graphs)
99+
Base.getindex(d::AQSOL, ::Colon) = d.graphs
100+
Base.getindex(d::AQSOL, i) = getindex(d.graphs, i)

src/datasets/graphs/movielens.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -549,4 +549,4 @@ Base.length(data::MovieLens) = length(data.graphs)
549549
function Base.getindex(data::MovieLens, ::Colon)
550550
length(data.graphs) == 1 ? data.graphs[1] : data.graphs
551551
end
552-
Base.getindex(data::MovieLens, i) = getobs(data.graphs, i)
552+
Base.getindex(data::MovieLens, i) = getobs(data.graphs, i)

test/datasets/graphs.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,3 +278,23 @@ end
278278
@test size(g.edge_data.features) == (2, g.num_edges)
279279
@test size(g.edge_data.targets) == (g.num_edges,)
280280
end
281+
282+
@testset "AQSOL" begin
283+
split_counts = Dict(:train => 7985, :val => 998, :test => 999)
284+
for split in [:train, :val, :test]
285+
data = AQSOL(split=split)
286+
@test data isa AbstractDataset
287+
@test data.split == split
288+
@test length(data) == data.metadata["n_observations"]
289+
@test length(data.graphs) == split_counts[split]
290+
291+
i = rand(1:length(data))
292+
g = data[i]
293+
@test g isa MLDatasets.Graph
294+
s, t = g.edge_index
295+
@test all(1 .<= s .<= g.num_nodes)
296+
@test all(1 .<= t .<= g.num_nodes)
297+
@test length(s) == g.num_edges
298+
@test length(t) == g.num_edges
299+
end
300+
end

0 commit comments

Comments
 (0)