Skip to content

Commit 3d40a68

Browse files
authored
Add WindMillEnergy dataset (#232)
* Add export and init * Add WindMillEnergy dataset implementation * Add WindMillEnergy dataset test * Add struct * Fixes * Fix check * Fix typo * Add docs * Improve docstring
1 parent 8255e6e commit 3d40a68

File tree

4 files changed

+126
-0
lines changed

4 files changed

+126
-0
lines changed

docs/src/datasets/graphs.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,4 +33,5 @@ TUDataset
3333
METRLA
3434
PEMSBAY
3535
TemporalBrains
36+
WindMillEnergy
3637
```

src/MLDatasets.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ include("datasets/graphs/pemsbay.jl")
137137
export PEMSBAY
138138
include("datasets/graphs/temporalbrains.jl")
139139
export TemporalBrains
140+
include("datasets/graphs/windmillenergy.jl")
141+
export WindMillEnergy
140142

141143
# Meshes
142144

@@ -159,6 +161,7 @@ function __init__()
159161
__init__metrla()
160162
__init__pemsbay()
161163
__init__temporalbrains()
164+
__init__windmillenergy()
162165

163166
# misc
164167
__init__iris()

src/datasets/graphs/windmillenergy.jl

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
# Register the manual `WindMillEnergy` data dependency.
#
# The registration message only records the dataset name and a website link.
function __init__windmillenergy()
    depname = "WindMillEnergy"
    website = "https://graphmining.ai/temporal_datasets/"
    message = """
    Dataset: $depname
    Website : $website
    """
    return register(ManualDataDep(depname, message))
end
10+
11+
# Return the directory that holds `windmill_output_<size>.json`, downloading the
# file into it first when it is not already present.
#
# Arguments:
# - `size`: dataset variant; must be "small", "medium" or "large".
# - `dir`:  target directory; when `nothing`, the `WindMillEnergy` datadep
#           directory is used.
#
# Throws an `ArgumentError` when `size` is not one of the supported values.
function windmillenergy_datadir(size::String, dir = nothing)
    # Validate before touching the filesystem: a download URL only exists for
    # these three variants, so anything else cannot be served.
    if !(size in ("small", "medium", "large"))
        throw(ArgumentError("Please choose a valid size: small, medium or large (got \"$(size)\")"))
    end

    dir = isnothing(dir) ? datadep"WindMillEnergy" : dir
    filename = "windmill_output_$(size).json"

    # Only hit the network when the file is missing locally.
    if !isfile(joinpath(dir, filename))
        link = "http://www-sop.inria.fr/members/Aurora.Rossi/data/$(filename)"
        DataDeps.fetch_default(link, dir)
    end

    @assert isdir(dir)
    return dir
end
24+
25+
# Cut a time series into sliding feature/target windows.
#
# Time runs along the third dimension of `data`. Window `t` takes time steps
# `t:t+num_timesteps_in-1` as features and the `num_timesteps_out` steps right
# after them as targets; consecutive windows are shifted by one time step.
# Window starts range over `1:(size(data, 3) - num_timesteps_in - num_timesteps_out)`.
#
# Returns `(features, targets)`, two `Vector{Any}` of equal length.
function generate_task(data::AbstractArray, num_timesteps_in::Int, num_timesteps_out::Int)
    offset = num_timesteps_in
    horizon = num_timesteps_out
    starts = 1:(size(data, 3) - offset - horizon)

    features = Any[data[:, :, t:(t + offset - 1)] for t in starts]
    targets = Any[data[:, :, (t + offset):(t + offset + horizon - 1)] for t in starts]
    return features, targets
end
34+
35+
# Build the WindMillEnergy `Graph` from `windmill_output_<s>.json` located in `dir`.
#
# Steps: read the JSON file, shift both endpoints of every edge by one
# (presumably converting 0-based indices to Julia's 1-based ones — confirm
# against the data file), optionally Z-score normalize the signal, and cut it
# into sliding windows with `generate_task`.
function create_windmillenergy_dataset(s::String, normalize::Bool, num_timesteps_in::Int, num_timesteps_out::Int, dir)
    raw = read_json(joinpath(dir, "windmill_output_$(s).json"))

    edges = raw["edges"]
    src = Int[edge[1] + 1 for edge in edges]
    dst = Int[edge[2] + 1 for edge in edges]
    weights = Float32.(raw["weights"])

    # Stack the per-entry vectors of "block" into a matrix, then add a leading
    # singleton dimension so the signal is a 3-d array.
    signal = Float32.(stack(raw["block"]))
    signal = reshape(signal, 1, size(signal, 1), size(signal, 2))

    if normalize
        # Z-score normalization, with statistics taken along dimension 2.
        # NOTE(review): dimension 2 is the first axis of the stacked matrix;
        # confirm this is the intended axis for the statistics.
        mu = Statistics.mean(signal, dims = 2)
        sigma = Statistics.std(signal, dims = 2)
        signal = (signal .- mu) ./ sigma
    end

    x, y = generate_task(signal, num_timesteps_in, num_timesteps_out)

    return Graph(; edge_index = (src, dst),
                 edge_data = weights,
                 node_data = (features = x, targets = y))
end
59+
"""
    WindMillEnergy(; size, normalize=true, num_timesteps_in=8, num_timesteps_out=8, dir=nothing)

The WindMillEnergy dataset contains a collection of hourly energy outputs of windmills from a European country for more than 2 years.

`WindMillEnergy` is a graph with nodes representing windmills. The edge weights represent the strength of the relationship between the windmills. The number of nodes is fixed and depends on the size of the dataset: 11 for `small`, 26 for `medium`, and 319 for `large`.

The node features and targets are the hourly energy outputs of the windmills. They are represented as arrays of arrays of size `(1, num_nodes, num_timesteps_in)` for the features and `(1, num_nodes, num_timesteps_out)` for the targets. In both cases, two consecutive arrays are shifted by one time step.

# Keyword Arguments

- `size::String`: The size of the dataset, can be `small`, `medium`, or `large`. This keyword is required.
- `normalize::Bool`: Whether to normalize the data using Z-score normalization. Default is `true`.
- `num_timesteps_in::Int`: The number of time steps, in this case, the number of hours, for the input features. Default is `8`.
- `num_timesteps_out::Int`: The number of time steps, in this case, the number of hours, for the target values. Default is `8`.
- `dir`: The directory to save the dataset. Default is `nothing`, in which case the `WindMillEnergy` datadep directory is used.

# Examples

```julia-repl
julia> using JSON3

julia> dataset = WindMillEnergy(;size= "small");

julia> dataset.graphs[1]
Graph:
  num_nodes   =>    11
  num_edges   =>    121
  edge_index  =>    ("121-element Vector{Int64}", "121-element Vector{Int64}")
  node_data   =>    (features = "17456-element Vector{Any}", targets = "17456-element Vector{Any}")
  edge_data   =>    121-element Vector{Float32}

julia> size(dataset.graphs[1].node_data.features[1])
(1, 11, 8)
```
"""
struct WindMillEnergy <: AbstractDataset
    graphs::Vector{Graph}
end
99+
100+
# Keyword constructor: create the default dataset directory, resolve (and if
# needed download into) the data directory, build the graph and wrap it in a
# one-element dataset.
function WindMillEnergy(; size::String, normalize::Bool = true, num_timesteps_in::Int = 8,
                        num_timesteps_out::Int = 8, dir = nothing)
    create_default_dir("WindMillEnergy")
    datadir = windmillenergy_datadir(size, dir)
    graph = create_windmillenergy_dataset(size, normalize, num_timesteps_in,
                                          num_timesteps_out, datadir)
    return WindMillEnergy([graph])
end
106+
107+
# Number of graphs stored in the dataset.
function Base.length(d::WindMillEnergy)
    return length(d.graphs)
end

# `d[:]` returns the first stored graph itself rather than a collection.
function Base.getindex(d::WindMillEnergy, ::Colon)
    return d.graphs[1]
end

# Any other index is forwarded to the underlying `graphs` vector.
function Base.getindex(d::WindMillEnergy, i)
    return getindex(d.graphs, i)
end

test/datasets/graphs_no_ci.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,3 +379,16 @@ end
379379
@test g.snapshots[1] isa MLDatasets.Graph
380380
@test length(g.snapshots[1].node_data) == 102
381381
end
382+
383+
@testset "WindMillEnergy" begin
    dataset = WindMillEnergy(size = "small")
    @test dataset isa AbstractDataset
    @test length(dataset) == 1

    # Integer and colon indexing must hand back the very same graph object.
    graph = dataset[1]
    @test graph === dataset[:]
    @test graph isa MLDatasets.Graph

    @test graph.num_nodes == 11
    @test graph.num_edges == 121

    # Consecutive feature windows are shifted by one time step, so they
    # overlap on all but one step.
    feats = graph.node_data.features
    @test feats[1][:, :, 2:end] == feats[2][:, :, 1:end-1]
end

0 commit comments

Comments
 (0)