Skip to content

Commit a0e7ea7

Browse files
authored
Merge pull request #244 from Team-RADDISH/mmg/grf-update-and-cov-fix
Update `GaussianRandomFields` version and add additional tests for linear Gaussian models
2 parents 87c0324 + aabb7c5 commit a0e7ea7

File tree

5 files changed

+179
-92
lines changed

5 files changed

+179
-92
lines changed

extra/linear_gaussian_validation.jl

Lines changed: 5 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
### A Pluto.jl notebook ###
2-
# v0.19.18
2+
# v0.19.22
33

44
using Markdown
55
using InteractiveUtils
@@ -195,35 +195,6 @@ function plot_filter_estimate_rmse_vs_n_particles(
195195
)
196196
end
197197

198-
# ╔═╡ 159ed63c-5dac-4f9b-a0cc-a5c13b6978e0
199-
function diagonal_linear_gaussian_model_parameters(
200-
state_dimension=3,
201-
state_transition_coefficient=0.8,
202-
observation_coefficient=1.0,
203-
initial_state_std=1.0,
204-
state_noise_std=0.6,
205-
observation_noise_std=0.5,
206-
)
207-
return Dict(
208-
:state_transition_matrix => ScalMat(
209-
state_dimension, state_transition_coefficient
210-
),
211-
:observation_matrix => ScalMat(
212-
state_dimension, observation_coefficient
213-
),
214-
:initial_state_mean => Zeros(state_dimension),
215-
:initial_state_covar => ScalMat(
216-
state_dimension, initial_state_std^2
217-
),
218-
:state_noise_covar => ScalMat(
219-
state_dimension, state_noise_std^2
220-
),
221-
:observation_noise_covar => ScalMat(
222-
state_dimension, observation_noise_std^2
223-
),
224-
)
225-
end
226-
227198
# ╔═╡ 89dae12b-0010-4ea1-ae69-490137196662
228199
let
229200
n_time_step = 200
@@ -235,7 +206,7 @@ let
235206
n_particle,
236207
filter_type,
237208
LinearGaussian.init,
238-
diagonal_linear_gaussian_model_parameters(),
209+
LinearGaussian.diagonal_linear_gaussian_model_parameters(),
239210
seed
240211
)
241212
end
@@ -249,59 +220,12 @@ let
249220
n_time_step,
250221
n_particles,
251222
LinearGaussian.init,
252-
diagonal_linear_gaussian_model_parameters(),
223+
LinearGaussian.diagonal_linear_gaussian_model_parameters(),
253224
seed
254225
)
255-
# savefig(figure, "diagonal_linear_gaussian_model_estimate_rmse_vs_n_particles.pdf")
256226
figure
257227
end
258228

259-
# ╔═╡ db091a48-589f-4393-8951-aadc351588ff
260-
function stochastically_driven_dsho_model_parameters(
261-
δ=0.2,
262-
ω=1.,
263-
Q=2.,
264-
σ=0.5,
265-
)
266-
β = sqrt(Q^2 - 1 / 4)
267-
return Dict(
268-
:state_transition_matrix => exp(-ω * δ / 2Q) * [
269-
[
270-
cos* β * δ / Q) + sin* β * δ / Q) / 2β,
271-
Q * sin* β * δ / Q) /* β)
272-
]';
273-
[
274-
-Q * ω * sin* δ * β / Q) / β,
275-
cos* δ * β / Q) - sin* δ * β / Q) / 2β
276-
]'
277-
],
278-
:observation_matrix => ScalMat(2, 1.),
279-
:initial_state_mean => Zeros(2),
280-
:initial_state_covar => ScalMat(2, 1.),
281-
:state_noise_covar => PDMat(
282-
Q * exp(-ω * δ / Q) * [
283-
[
284-
(
285-
(cos(2ω * δ * β / Q) - 1)
286-
- 2β * sin(2ω * δ * β / Q)
287-
+ 4β^2 * (exp* δ / Q) - 1)
288-
) / (8ω^3 * β^2),
289-
Q * sin* δ * β / Q)^2 / (2ω^2 * β^2)
290-
]';
291-
[
292-
Q * sin* δ * β / Q)^2 / (2ω^2 * β^2),
293-
(
294-
(cos(2ω * δ * β / Q) - 1)
295-
+ 2β * sin(2ω * δ * β / Q)
296-
+ 4β^2 * (exp* δ / Q) - 1)
297-
) / (8ω * β^2),
298-
]'
299-
]
300-
),
301-
:observation_noise_covar => ScalMat(2, σ^2)
302-
)
303-
end
304-
305229
# ╔═╡ 64a289be-75ce-42e2-9e43-8e0286f70a35
306230
let
307231
n_time_step = 200
@@ -313,7 +237,7 @@ let
313237
n_particle,
314238
filter_type,
315239
LinearGaussian.init,
316-
stochastically_driven_dsho_model_parameters(),
240+
LinearGaussian.stochastically_driven_dsho_model_parameters(),
317241
seed
318242
)
319243
end
@@ -328,10 +252,9 @@ let
328252
n_time_step,
329253
n_particles,
330254
LinearGaussian.init,
331-
stochastically_driven_dsho_model_parameters(),
255+
LinearGaussian.stochastically_driven_dsho_model_parameters(),
332256
seed
333257
)
334-
# savefig(figure, "dsho_linear_gaussian_model_estimate_rmse_vs_n_particles.pdf")
335258
figure
336259
end
337260

@@ -341,9 +264,7 @@ end
341264
# ╠═4d2656ca-eacb-4d2b-91cb-bc82fdb49520
342265
# ╠═a64762bb-3a9f-4b1c-83db-f1a366f282eb
343266
# ╠═2ad564f3-48a2-4c2a-8d7d-384a84f7d6d2
344-
# ╠═159ed63c-5dac-4f9b-a0cc-a5c13b6978e0
345267
# ╠═89dae12b-0010-4ea1-ae69-490137196662
346268
# ╠═3e0abdfc-8668-431c-8ad3-61802e21d34e
347-
# ╠═db091a48-589f-4393-8951-aadc351588ff
348269
# ╠═64a289be-75ce-42e2-9e43-8e0286f70a35
349270
# ╠═b396f776-885b-437a-94c3-693f318d7ed2

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ YAML = "ddb6d928-2868-570f-bddf-ab3f9cf99eb6"
1818
[compat]
1919
Distributions = "0.22, 0.23, 0.24, 0.25"
2020
FillArrays = "0.13"
21-
GaussianRandomFields = "2.1.1"
21+
GaussianRandomFields = "2.2.1"
2222
HDF5 = "0.14, 0.15, 0.16"
2323
MPI = "0.20.8"
2424
OrdinaryDiffEq = "6.40"

test/models/lineargaussian.jl

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,79 @@ struct LinearGaussianModel{S <: Real, T <: Real}
3434
observation_noise_distribution::MvNormal{T}
3535
end
3636

37+
function diagonal_linear_gaussian_model_parameters(
38+
state_dimension=3,
39+
state_transition_coefficient=0.8,
40+
observation_coefficient=1.0,
41+
initial_state_std=1.0,
42+
state_noise_std=0.6,
43+
observation_noise_std=0.5,
44+
)
45+
return Dict(
46+
:state_transition_matrix => ScalMat(
47+
state_dimension, state_transition_coefficient
48+
),
49+
:observation_matrix => ScalMat(
50+
state_dimension, observation_coefficient
51+
),
52+
:initial_state_mean => Zeros(state_dimension),
53+
:initial_state_covar => ScalMat(
54+
state_dimension, initial_state_std^2
55+
),
56+
:state_noise_covar => ScalMat(
57+
state_dimension, state_noise_std^2
58+
),
59+
:observation_noise_covar => ScalMat(
60+
state_dimension, observation_noise_std^2
61+
),
62+
)
63+
end
64+
65+
function stochastically_driven_dsho_model_parameters(
66+
δ=0.2,
67+
ω=1.,
68+
Q=2.,
69+
σ=0.5,
70+
)
71+
β = sqrt(Q^2 - 1 / 4)
72+
return Dict(
73+
:state_transition_matrix => exp(-ω * δ / 2Q) * [
74+
[
75+
cos* β * δ / Q) + sin* β * δ / Q) / 2β,
76+
Q * sin* β * δ / Q) /* β)
77+
]';
78+
[
79+
-Q * ω * sin* δ * β / Q) / β,
80+
cos* δ * β / Q) - sin* δ * β / Q) / 2β
81+
]'
82+
],
83+
:observation_matrix => ScalMat(2, 1.),
84+
:initial_state_mean => Zeros(2),
85+
:initial_state_covar => ScalMat(2, 1.),
86+
:state_noise_covar => PDMat(
87+
Q * exp(-ω * δ / Q) * [
88+
[
89+
(
90+
(cos(2ω * δ * β / Q) - 1)
91+
- 2β * sin(2ω * δ * β / Q)
92+
+ 4β^2 * (exp* δ / Q) - 1)
93+
) / (8ω^3 * β^2),
94+
Q * sin* δ * β / Q)^2 / (2ω^2 * β^2)
95+
]';
96+
[
97+
Q * sin* δ * β / Q)^2 / (2ω^2 * β^2),
98+
(
99+
(cos(2ω * δ * β / Q) - 1)
100+
+ 2β * sin(2ω * δ * β / Q)
101+
+ 4β^2 * (exp* δ / Q) - 1)
102+
) / (8ω * β^2),
103+
]'
104+
]
105+
),
106+
:observation_noise_covar => ScalMat(2, σ^2)
107+
)
108+
end
109+
37110
function init(parameters_dict::Dict)
38111
parameters = LinearGaussianModelParameters(; parameters_dict...)
39112
(observation_dimension, state_dimension) = size(

test/models/llw2d.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ function get_covariance_gaussian_random_fields(
396396
model_parameters, (x_index_2, y_index_2)
397397
)
398398
covariance_structure = gaussian_random_fields[var_index_1].grf.cov.cov
399-
return covariance_structure.σ^2 * apply(
399+
return apply(
400400
covariance_structure, abs.(grid_point_1 .- grid_point_2)
401401
)
402402
else

test/runtests.jl

Lines changed: 99 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ using HDF5, LinearAlgebra, MPI, PDMats, Random, StableRNGs, Statistics, Test, YA
33

44
include(joinpath(@__DIR__, "models", "llw2d.jl"))
55
include(joinpath(@__DIR__, "models", "lorenz63.jl"))
6+
include(joinpath(@__DIR__, "models", "lineargaussian.jl"))
7+
include(joinpath(@__DIR__, "kalman.jl"))
68

79
using .LLW2d
810
using .Lorenz63
@@ -226,10 +228,13 @@ function run_unit_tests_for_generic_model_interface(model, seed)
226228
end
227229

228230
@testset (
229-
"Generic model interface unit tests - $model_module"
230-
) for model_module in (LLW2d, Lorenz63)
231+
"Generic model interface unit tests - $(parentmodule(typeof(model)))"
232+
) for model in (
233+
LLW2d.init(Dict()),
234+
Lorenz63.init(Dict()),
235+
LinearGaussian.init(LinearGaussian.stochastically_driven_dsho_model_parameters())
236+
)
231237
seed = 1234
232-
model = model_module.init(Dict())
233238
run_unit_tests_for_generic_model_interface(model, seed)
234239
end
235240

@@ -508,10 +513,14 @@ function run_tests_for_optimal_proposal_model_interface(
508513
end
509514

510515
@testset (
511-
"Optimal proposal model interface unit tests - $(model_module)"
512-
) for model_module in (LLW2d, Lorenz63)
516+
"Optimal proposal model interface unit tests - $(parentmodule(typeof(model)))"
517+
) for model in (
518+
# Use sigma != 1. to test if covariance is being scaled by sigma correctly
519+
LLW2d.init(Dict("llw2d" => Dict("sigma" => [0.5, 1.5, 1.5]))),
520+
Lorenz63.init(Dict()),
521+
LinearGaussian.init(LinearGaussian.stochastically_driven_dsho_model_parameters())
522+
)
513523
seed = 1234
514-
model = model_module.init(Dict())
515524
# Number of samples to use in convergence tests of Monte Carlo estimates
516525
estimate_n_samples = [10, 100, 1000]
517526
# Constant factor used in Monte Carlo estimate convergence tests. Set based on some
@@ -523,6 +532,90 @@ end
523532
)
524533
end
525534

535+
function run_tests_for_convergence_of_filter_estimates_against_kalman_filter(
536+
filter_type,
537+
init_model,
538+
model_parameters_dict,
539+
seed,
540+
n_time_step,
541+
n_particles,
542+
mean_rmse_bound_constant,
543+
log_var_rmse_bound_constant,
544+
)
545+
rng = Random.TaskLocalRNG()
546+
Random.seed!(rng, seed)
547+
model = init_model(model_parameters_dict)
548+
observation_seq = ParticleDA.simulate_observations_from_model(
549+
model, n_time_step; rng=rng
550+
)
551+
true_state_mean_seq, true_state_var_seq = Kalman.run_kalman_filter(
552+
model, observation_seq
553+
)
554+
for n_particle in n_particles
555+
output_filename = tempname()
556+
filter_parameters = ParticleDA.FilterParameters(
557+
nprt=n_particle, verbose=true, output_filename=output_filename
558+
)
559+
states, statistics = ParticleDA.run_particle_filter(
560+
init_model,
561+
filter_parameters,
562+
model_parameters_dict,
563+
observation_seq,
564+
filter_type,
565+
ParticleDA.MeanAndVarSummaryStat;
566+
rng=rng
567+
)
568+
state_mean_seq = Matrix{ParticleDA.get_state_eltype(model)}(
569+
undef, ParticleDA.get_state_dimension(model), n_time_step
570+
)
571+
state_var_seq = Matrix{ParticleDA.get_state_eltype(model)}(
572+
undef, ParticleDA.get_state_dimension(model), n_time_step
573+
)
574+
weights_seq = Matrix{Float64}(undef, n_particle, n_time_step)
575+
h5open(output_filename, "r") do file
576+
for t in 1:n_time_step
577+
key = ParticleDA.time_index_to_hdf5_key(t)
578+
state_mean_seq[:, t] = read(file["state_avg"][key])
579+
state_var_seq[:, t] = read(file["state_var"][key])
580+
weights_seq[:, t] = read(file["weights"][key])
581+
end
582+
end
583+
mean_rmse = sqrt(
584+
mean(x -> x.^2, state_mean_seq .- true_state_mean_seq)
585+
)
586+
log_var_rmse = sqrt(
587+
mean(x -> x.^2, log.(state_var_seq) .- log.(true_state_var_seq))
588+
)
589+
# Monte Carlo estimates of mean and log variance should have O(sqrt(n_particle))
590+
# convergence to true values
591+
@test mean_rmse < mean_rmse_bound_constant / sqrt(n_particle)
592+
@test log_var_rmse < log_var_rmse_bound_constant / sqrt(n_particle)
593+
end
594+
end
595+
596+
@testset (
597+
"Filter estimate validation against Kalman filter - $(filter_type)"
598+
) for filter_type in (BootstrapFilter, OptimalFilter)
599+
seed = 1234
600+
n_time_step = 100
601+
n_particles = [30, 100, 300, 1000]
602+
# Constant factora used in Monte Carlo estimate convergence tests. Set based on some
603+
# trial and error to keep tests relatively sensitive while avoiding too high
604+
# probability of false failures
605+
mean_rmse_bound_constant = 1.
606+
log_var_rmse_bound_constant = 5.
607+
run_tests_for_convergence_of_filter_estimates_against_kalman_filter(
608+
filter_type,
609+
LinearGaussian.init,
610+
LinearGaussian.stochastically_driven_dsho_model_parameters(),
611+
seed,
612+
n_time_step,
613+
n_particles,
614+
mean_rmse_bound_constant,
615+
log_var_rmse_bound_constant,
616+
)
617+
end
618+
526619
@testset "Summary statistics unit tests" begin
527620
MPI.Init()
528621
seed = 5678

0 commit comments

Comments
 (0)