Skip to content

Commit 09607d0

Browse files
authored
ChickenPox graph dataset (#233)
* Add ChickenPox export * Add ChickenPox dataset export * Add ChickenPox test * Add ChickenPox dataset * Fix a alphabetical order * Add more info and example * Add more tests
1 parent 3d40a68 commit 09607d0

File tree

4 files changed

+136
-3
lines changed

4 files changed

+136
-3
lines changed

docs/src/datasets/graphs.md

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,19 @@ MLDatasets.HeteroGraph
2020
```
2121

2222
```@docs
23+
ChickenPox
2324
CiteSeer
2425
Cora
2526
KarateClub
27+
METRLA
2628
MovieLens
2729
OGBDataset
2830
OrganicMaterialsDB
31+
PEMSBAY
2932
PolBlogs
3033
PubMed
3134
Reddit
32-
TUDataset
33-
METRLA
34-
PEMSBAY
3535
TemporalBrains
36+
TUDataset
3637
WindMillEnergy
3738
```

src/MLDatasets.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,8 @@ include("graph.jl")
111111
include("datasets/graphs/planetoid.jl")
112112
include("datasets/graphs/traffic.jl")
113113
# export read_planetoid_data
114+
include("datasets/graphs/chickenpox.jl")
115+
export ChickenPox
114116
include("datasets/graphs/cora.jl")
115117
export Cora
116118
include("datasets/graphs/citeseer.jl")
@@ -149,6 +151,7 @@ function __init__()
149151
# TODO automatically find and execute __init__xxx functions
150152

151153
# graph
154+
__init__chickenpox()
152155
__init__citeseer()
153156
__init__cora()
154157
__init__movielens()
@@ -163,6 +166,7 @@ function __init__()
163166
__init__temporalbrains()
164167
__init__windmillenergy()
165168

169+
166170
# misc
167171
__init__iris()
168172
__init__mutagenesis()

src/datasets/graphs/chickenpox.jl

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
function __init__chickenpox()
2+
DEPNAME = "ChickenPox"
3+
LINK = "https://graphmining.ai/temporal_datasets/"
4+
register(ManualDataDep(DEPNAME,
5+
"""
6+
Dataset: $DEPNAME
7+
Website : $LINK
8+
"""))
9+
end
10+
11+
function chickenpox_datadir(dir = nothing)
12+
dir = isnothing(dir) ? datadep"ChickenPox" : dir
13+
LINK = "http://www-sop.inria.fr/members/Aurora.Rossi/data/chickenpox.json"
14+
if length(readdir((dir))) == 0
15+
DataDeps.fetch_default(LINK, dir)
16+
end
17+
@assert isdir(dir)
18+
return dir
19+
end
20+
21+
function generate_task(data::AbstractArray, num_timesteps_in::Int, num_timesteps_out::Int)
22+
features = []
23+
targets = []
24+
for i in 1:(size(data,3)-num_timesteps_in-num_timesteps_out)
25+
push!(features, data[:,:,i:i+num_timesteps_in-1])
26+
push!(targets, data[:,:,i+num_timesteps_in:i+num_timesteps_in+num_timesteps_out-1])
27+
end
28+
return features, targets
29+
end
30+
31+
function create_chickenpox_dataset( normalize::Bool, num_timesteps_in::Int, num_timesteps_out::Int, dir)
32+
name_file = joinpath(dir, "chickenpox.json")
33+
data = read_json(name_file)
34+
src = zeros(Int, length(data["edges"]))
35+
dst = zeros(Int, length(data["edges"]))
36+
for (i, edge) in enumerate(data["edges"])
37+
src[i] = edge[1] + 1
38+
dst[i] = edge[2] + 1
39+
end
40+
f = Float32.(stack(data["FX"]))
41+
f = reshape(f, 1, size(f, 1), size(f, 2))
42+
43+
metadata = Dict(key => value + 1 for (key, value) in data["node_ids"])
44+
45+
if normalize
46+
f = (f .- Statistics.mean(f, dims=(2))) ./ Statistics.std(f, dims=(2)) #Z-score normalization
47+
end
48+
49+
x, y = generate_task(f, num_timesteps_in, num_timesteps_out)
50+
51+
g = Graph(; edge_index = (src, dst),
52+
node_data = (features = x, targets = y))
53+
return g, metadata
54+
end
55+
56+
"""
57+
ChickenPox(; normalize= true, num_timesteps_in = 8 , num_timesteps_out = 8, dir = nothing)
58+
59+
The ChickenPox dataset contains county-level chickenpox cases in Hungary between 2004 and 2014.
60+
61+
`ChickenPox` is composed of a graph with nodes representing counties and edges representing the neighborhoods, and a metadata dictionary containing the correspondence between the node indices and the county names.
62+
63+
The node features are the number of weekly chickenpox cases in each county. They are represented as an array of arrays of size `(1, num_nodes, num_timesteps_in)`. The target values are the number of weekly chickenpox cases in each county. They are represented as an array of arrays of size `(1, num_nodes, num_timesteps_out)`. In both cases. two consecutive arrays are shifted by one-time step.
64+
65+
The dataset was taken from the [Pytorch Geometric Temporal repository](https://pytorch-geometric-temporal.readthedocs.io/en/latest/modules/dataset.html#torch_geometric_temporal.dataset.chickenpox.ChickenpoxDatasetLoader) and more information about the dataset can be found in the paper ["Chickenpox Cases in Hungary: a Benchmark Dataset for
66+
Spatiotemporal Signal Processing with Graph Neural Networks"](https://arxiv.org/pdf/2102.08100).
67+
68+
69+
# Keyword Arguments
70+
- `normalize::Bool`: Whether to normalize the data using Z-score normalization. Default is `true`.
71+
- `num_timesteps_in::Int`: The number of time steps, in this case, the number of weeks, for the input features. Default is `8`.
72+
- `num_timesteps_out::Int`: The number of time steps, in this case, the number of weeks, for the target values. Default is `8`.
73+
- `dir::String`: The directory to save the dataset. Default is `nothing`.
74+
75+
# Examples
76+
```julia-repl
77+
julia> using JSON3 # import JSON3
78+
79+
julia> dataset = ChickenPox()
80+
dataset ChickenPox:
81+
metadata => Dict{Symbol, Any} with 20 entries
82+
graphs => 1-element Vector{MLDatasets.Graph}
83+
84+
julia> dataset.graphs[1].num_nodes # 20 counties
85+
20
86+
87+
julia> size(dataset.graphs[1].node_data.features[1])
88+
(1, 20, 8)
89+
90+
julia> dataset.metadata[:BUDAPEST] # The node 5 correponds to Budapest county
91+
5
92+
```
93+
"""
94+
struct ChickenPox <: AbstractDataset
95+
metadata::Dict{Symbol, Any}
96+
graphs::Vector{Graph}
97+
end
98+
99+
function ChickenPox(; normalize::Bool = true, num_timesteps_in::Int = 8 , num_timesteps_out::Int = 8, dir = nothing)
100+
create_default_dir("ChickenPox")
101+
dir = chickenpox_datadir(dir)
102+
g, metadata = create_chickenpox_dataset(normalize, num_timesteps_in, num_timesteps_out, dir)
103+
return ChickenPox(metadata, [g])
104+
end
105+
106+
Base.length(d::ChickenPox) = length(d.graphs)
107+
Base.getindex(d::ChickenPox, ::Colon) = d.graphs[1]
108+
Base.getindex(d::ChickenPox, i) = getindex(d.graphs, i)

test/datasets/graphs_no_ci.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,23 @@ end
392392
@test g.num_edges == 121
393393
@test g.node_data.features[1][:,:,2:end] == g.node_data.features[2][:,:,1:end-1]
394394
end
395+
396+
@testset "ChickenPox" begin
397+
data = ChickenPox()
398+
@test data isa AbstractDataset
399+
@test data.metadata isa Dict
400+
@test length(data) == 1
401+
g = data[1]
402+
@test g === data[:]
403+
@test g isa MLDatasets.Graph
404+
405+
@test g.num_nodes == 20
406+
@test g.num_edges == 102
407+
@test g.node_data.features[1][:,:,2:end] == g.node_data.features[2][:,:,1:end-1]
408+
409+
@test data.metadata[:BUDAPEST] == 5
410+
@test data.metadata[:BACS] == 1
411+
@test data.metadata[:ZALA] == 20
412+
@test data.metadata[:HEVES] == 10
413+
@test data.metadata isa Dict{Symbol, Any}
414+
end

0 commit comments

Comments
 (0)