Skip to content

Commit 5b57f3e

Browse files
authored
Add PEMS-BAY dataset (#213)
* Add `PEMSBAY` in docs * Export and include traffic * Add common script for traffic datasets * Modify `METRLA` using traffic * Add `PEMSBAY` dataset * Add `PEMSBAY ` test
1 parent d92f66d commit 5b57f3e

File tree

6 files changed

+121
-43
lines changed

6 files changed

+121
-43
lines changed

docs/src/datasets/graphs.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,5 @@ PubMed
3131
Reddit
3232
TUDataset
3333
METRLA
34+
PEMSBAY
3435
```

src/MLDatasets.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ include("graph.jl")
109109
# export Graph
110110

111111
include("datasets/graphs/planetoid.jl")
112+
include("datasets/graphs/traffic.jl")
112113
# export read_planetoid_data
113114
include("datasets/graphs/cora.jl")
114115
export Cora
@@ -132,6 +133,8 @@ include("datasets/graphs/tudataset.jl")
132133
export TUDataset
133134
include("datasets/graphs/metrla.jl")
134135
export METRLA
136+
include("datasets/graphs/pemsbay.jl")
137+
export PEMSBAY
135138

136139
# Meshes
137140

@@ -152,6 +155,7 @@ function __init__()
152155
__init__reddit()
153156
__init__tudataset()
154157
__init__metrla()
158+
__init__pemsbay()
155159

156160
# misc
157161
__init__iris()

src/datasets/graphs/metrla.jl

Lines changed: 1 addition & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,7 @@ struct METRLA <: AbstractDataset
2626
end
2727

2828
function METRLA(;num_timesteps_in::Int = 12, num_timesteps_out::Int=12, dir = nothing)
29-
create_default_dir("METR-LA")
30-
d = metrla_datadir(dir)
31-
adj_matrix, node_values = read_metrla(d)
32-
33-
node_values = permutedims(node_values,(1,3,2))
34-
node_values = (node_values .- Statistics.mean(node_values, dims=(3,1))) ./ Statistics.std(node_values, dims=(3,1)) #Z-score normalization
35-
36-
s, t, w = adjmatrix2edgeindex(adj_matrix; weighted = true)
37-
38-
x, y = metrla_generate_task(node_values, num_timesteps_in, num_timesteps_out)
29+
s, t, w, x, y = processed_traffic("METR-LA", num_timesteps_in, num_timesteps_out, dir)
3930

4031
g = Graph(; num_nodes = 207,
4132
edge_index = (s, t),
@@ -47,39 +38,6 @@ function METRLA(;num_timesteps_in::Int = 12, num_timesteps_out::Int=12, dir = no
4738
return METRLA([g])
4839
end
4940

50-
function metrla_datadir(dir = nothing)
51-
dir = isnothing(dir) ? datadep"METR-LA" : dir
52-
dname = "METR-LA"
53-
LINK = "https://graphmining.ai/temporal_datasets/$dname.zip"
54-
if length(readdir((dir))) == 0
55-
DataDeps.fetch_default(LINK, dir)
56-
currdir = pwd()
57-
cd(dir) # Needed since `unpack` extracts in working dir
58-
DataDeps.unpack(joinpath(dir, "$dname.zip"))
59-
# conditions when unzipped folder is our required data dir
60-
cd(currdir)
61-
end
62-
@assert isdir(dir)
63-
return dir
64-
end
65-
66-
function read_metrla(d::String)
67-
adj_matrix = NPZ.npzread(joinpath(d, "adj_mat.npy"))
68-
node_features = NPZ.npzread(joinpath(d, "node_values.npy"))
69-
return adj_matrix, node_features
70-
end
71-
72-
function metrla_generate_task(node_values::AbstractArray, num_timesteps_in::Int, num_timesteps_out::Int)
73-
indices = [(i, i + num_timesteps_in + num_timesteps_out) for i in 1:(size(node_values,1) - num_timesteps_in - num_timesteps_out)]
74-
features = []
75-
targets = []
76-
for (i,j) in indices
77-
push!(features, node_values[i:i+num_timesteps_in-1,:,:])
78-
push!(targets, reshape(node_values[i+num_timesteps_in:j-1,1,:], (num_timesteps_out, 1, size(node_values, 3))))
79-
end
80-
return features, targets
81-
end
82-
8341
Base.length(d::METRLA) = length(d.graphs)
8442
Base.getindex(d::METRLA, ::Colon) = d.graphs[1]
8543
Base.getindex(d::METRLA, i) = getindex(d.graphs, i)

src/datasets/graphs/pemsbay.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Register a manual DataDep named "PEMS-BAY" whose message lists the dataset name
# and the website it is hosted on.
function __init__pemsbay()
    depname = "PEMS-BAY"
    website = "https://graphmining.ai/temporal_datasets/"
    message = """
              Dataset: $depname
              Website : $website
              """
    return register(ManualDataDep(depname, message))
end
10+
11+
"""
12+
PEMSBAY(; num_timesteps_in::Int = 12, num_timesteps_out::Int=12, dir=nothing)
13+
14+
The PEMS-BAY dataset described in the [Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic Forecasting](https://arxiv.org/abs/1707.01926) paper.
15+
It is collected by the California Transportation Agencies (CalTrans)
16+
Performance Measurement System (PeMS).
17+
18+
`PEMSBAY` is a graph with 325 nodes representing traffic sensors in the Bay Area.
19+
20+
The edge weights `w` are contained as a feature array in `edge_data` and represent the distance between the sensors.
21+
22+
The node features are the traffic speed and the time of the measurements collected by the sensors, divided into `num_timesteps_in` time steps.
23+
24+
The target values are the traffic speed measured by the sensors (only the first feature channel is kept), divided into `num_timesteps_out` time steps.
25+
"""
26+
struct PEMSBAY <: AbstractDataset
27+
# Graphs held by the dataset; the keyword constructor stores exactly one graph here.
graphs::Vector{Graph}
28+
end
29+
30+
# Keyword constructor: loads the processed PEMS-BAY traffic data and wraps it in a
# single 325-node `Graph` whose node data carries the feature and target windows.
function PEMSBAY(;num_timesteps_in::Int = 12, num_timesteps_out::Int=12, dir = nothing)
    src, dst, weights, features, targets =
        processed_traffic("PEMS-BAY", num_timesteps_in, num_timesteps_out, dir)

    node_data = (features = features, targets = targets)
    graph = Graph(; num_nodes = 325,
                  edge_index = (src, dst),
                  edge_data = weights,
                  node_data = node_data)

    return PEMSBAY([graph])
end
42+
43+
# Number of graphs stored in the dataset.
function Base.length(d::PEMSBAY)
    return length(d.graphs)
end

# `d[:]` returns the first stored graph itself rather than the vector of graphs.
function Base.getindex(d::PEMSBAY, ::Colon)
    return d.graphs[1]
end

# Any other index is forwarded to the underlying vector of graphs.
function Base.getindex(d::PEMSBAY, i)
    return getindex(d.graphs, i)
end

src/datasets/graphs/traffic.jl

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Return the directory holding the raw files of the traffic dataset `dname`.
#
# When `dir` is `nothing`, the DataDep directory registered for `dname` is used; only
# "PEMS-BAY" and "METR-LA" have one, so any other name throws an `ArgumentError`.
# An explicitly given `dir` is used as is, whatever `dname` is.
#
# If the directory is empty, `<dname>.zip` is fetched from graphmining.ai into it and
# unpacked there. The (possibly just populated) directory is returned.
function traffic_datadir(dname ::String, dir = nothing)
    if isnothing(dir)
        # `datadep"..."` is a string macro and needs a literal name, hence the branches.
        dir = if dname == "PEMS-BAY"
            datadep"PEMS-BAY"
        elseif dname == "METR-LA"
            datadep"METR-LA"
        else
            throw(ArgumentError(
                "no default data directory for traffic dataset \"$dname\"; " *
                "expected \"PEMS-BAY\" or \"METR-LA\", or pass `dir` explicitly"))
        end
    end

    link = "https://graphmining.ai/temporal_datasets/$dname.zip"
    if isempty(readdir(dir))
        DataDeps.fetch_default(link, dir)
        # Resolve the archive path before changing directory so that a relative `dir`
        # is not interpreted relative to itself.
        archive = abspath(joinpath(dir, "$dname.zip"))
        # `unpack` extracts into the working directory; the `do` form of `cd` restores
        # the previous working directory even when unpacking throws.
        cd(dir) do
            DataDeps.unpack(archive)
        end
    end

    isdir(dir) || error("traffic data directory \"$dir\" does not exist")
    return dir
end
19+
20+
# Read the adjacency matrix and the node values of traffic dataset `dname` from the
# `.npy` files stored in directory `d`.
#
# The PEMS-BAY files carry a `pems_` prefix, the METR-LA files have none. Any other
# `dname` throws an `ArgumentError` instead of failing later on an undefined prefix.
#
# Returns the tuple `(adj_matrix, node_features)` as read by `NPZ.npzread`.
function read_traffic(d::String, dname::String)
    prefix = if dname == "PEMS-BAY"
        "pems_"
    elseif dname == "METR-LA"
        ""
    else
        throw(ArgumentError(
            "unsupported traffic dataset \"$dname\"; expected \"PEMS-BAY\" or \"METR-LA\""))
    end

    adj_matrix = NPZ.npzread(joinpath(d, "$(prefix)adj_mat.npy"))
    node_features = NPZ.npzread(joinpath(d, "$(prefix)node_values.npy"))

    return adj_matrix, node_features
end
32+
33+
# Slice `node_values` along its first dimension into sliding input/target windows.
#
# With `T = size(node_values, 1)`, window `i` runs over `i = 1:(T - in - out)` where
# `in = num_timesteps_in` and `out = num_timesteps_out`:
#   * `features[i]` is `node_values[i:i+in-1, :, :]`;
#   * `targets[i]` holds the next `out` steps of index 1 of the second dimension only,
#     reshaped to `(out, 1, size(node_values, 3))`.
#
# Both results are concretely typed vectors of arrays (empty when `T <= in + out`).
# Negative window lengths throw an `ArgumentError`.
#
# NOTE(review): the window whose targets end exactly on the last step of the first
# dimension is not produced; kept as is because the dataset tests pin the resulting
# counts.
function traffic_generate_task(node_values::AbstractArray, num_timesteps_in::Int, num_timesteps_out::Int)
    if num_timesteps_in < 0 || num_timesteps_out < 0
        throw(ArgumentError(
            "num_timesteps_in and num_timesteps_out must be non-negative, got " *
            "$num_timesteps_in and $num_timesteps_out"))
    end

    num_windows = size(node_values, 1) - num_timesteps_in - num_timesteps_out
    target_dims = (num_timesteps_out, 1, size(node_values, 3))

    features = [node_values[i:(i + num_timesteps_in - 1), :, :] for i in 1:num_windows]
    targets = [
        reshape(
            node_values[(i + num_timesteps_in):(i + num_timesteps_in + num_timesteps_out - 1), 1, :],
            target_dims,
        ) for i in 1:num_windows
    ]
    return features, targets
end
43+
44+
# Load traffic dataset `dname` and turn it into graph inputs.
#
# Steps: make sure the default directory exists, locate (and if needed download) the
# data, read the adjacency matrix and node values, swap the 2nd and 3rd dimensions of
# the node values, Z-score normalize them, convert the adjacency matrix into a weighted
# edge list, and cut the node values into input/target windows.
#
# Returns `(s, t, w, x, y)`: the three outputs of `adjmatrix2edgeindex`, followed by
# the feature windows and the target windows.
function processed_traffic(dname::String, num_timesteps_in::Int, num_timesteps_out::Int, dir = nothing)
    create_default_dir(dname)
    datadir = traffic_datadir(dname, dir)
    adjacency, raw_values = read_traffic(datadir, dname)

    permuted = permutedims(raw_values, (1, 3, 2))
    # Z-score normalization; the statistics are reduced over dimensions 1 and 3, so
    # there is one mean/std per index of dimension 2.
    center = Statistics.mean(permuted, dims = (3, 1))
    spread = Statistics.std(permuted, dims = (3, 1))
    normalized = (permuted .- center) ./ spread

    src, dst, weights = adjmatrix2edgeindex(adjacency; weighted = true)

    features, targets = traffic_generate_task(normalized, num_timesteps_in, num_timesteps_out)
    return src, dst, weights, features, targets
end

test/datasets/graphs_no_ci.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,4 +351,18 @@ end
351351
@test g.num_edges == 1722
352352
@test length(g.node_data.features) == 34248
353353
@test length(g.node_data.targets) == 34248
354+
end
355+
356+
@testset "PEMS-BAY" begin
    # Container-level checks: one graph, reachable through both `[1]` and `[:]`.
    dataset = PEMSBAY()
    @test dataset isa AbstractDataset
    @test length(dataset) == 1

    graph = dataset[1]
    @test graph === dataset[:]
    @test graph isa MLDatasets.Graph

    # Graph size and number of generated feature/target windows.
    @test graph.num_nodes == 325
    @test graph.num_edges == 2694
    @test length(graph.node_data.features) == 52081
    @test length(graph.node_data.targets) == 52081
end

0 commit comments

Comments
 (0)