@@ -3,6 +3,8 @@ using HDF5, LinearAlgebra, MPI, PDMats, Random, StableRNGs, Statistics, Test, YA
33
44include (joinpath (@__DIR__ , " models" , " llw2d.jl" ))
55include (joinpath (@__DIR__ , " models" , " lorenz63.jl" ))
6+ include (joinpath (@__DIR__ , " models" , " lineargaussian.jl" ))
7+ include (joinpath (@__DIR__ , " kalman.jl" ))
68
79using . LLW2d
810using . Lorenz63
@@ -226,10 +228,13 @@ function run_unit_tests_for_generic_model_interface(model, seed)
226228end
227229
228230@testset (
229- " Generic model interface unit tests - $model_module "
230- ) for model_module in (LLW2d, Lorenz63)
231+ " Generic model interface unit tests - $(parentmodule (typeof (model))) "
232+ ) for model in (
233+ LLW2d. init (Dict ()),
234+ Lorenz63. init (Dict ()),
235+ LinearGaussian. init (LinearGaussian. stochastically_driven_dsho_model_parameters ())
236+ )
231237 seed = 1234
232- model = model_module. init (Dict ())
233238 run_unit_tests_for_generic_model_interface (model, seed)
234239end
235240
@@ -508,10 +513,14 @@ function run_tests_for_optimal_proposal_model_interface(
508513end
509514
510515@testset (
511- " Optimal proposal model interface unit tests - $(model_module) "
512- ) for model_module in (LLW2d, Lorenz63)
516+ " Optimal proposal model interface unit tests - $(parentmodule (typeof (model))) "
517+ ) for model in (
518+ # Use sigma != 1. to test if covariance is being scaled by sigma correctly
519+ LLW2d. init (Dict (" llw2d" => Dict (" sigma" => [0.5 , 1.5 , 1.5 ]))),
520+ Lorenz63. init (Dict ()),
521+ LinearGaussian. init (LinearGaussian. stochastically_driven_dsho_model_parameters ())
522+ )
513523 seed = 1234
514- model = model_module. init (Dict ())
515524 # Number of samples to use in convergence tests of Monte Carlo estimates
516525 estimate_n_samples = [10 , 100 , 1000 ]
517526 # Constant factor used in Monte Carlo estimate convergence tests. Set based on some
523532 )
524533end
525534
535+ function run_tests_for_convergence_of_filter_estimates_against_kalman_filter (
536+ filter_type,
537+ init_model,
538+ model_parameters_dict,
539+ seed,
540+ n_time_step,
541+ n_particles,
542+ mean_rmse_bound_constant,
543+ log_var_rmse_bound_constant,
544+ )
545+ rng = Random. TaskLocalRNG ()
546+ Random. seed! (rng, seed)
547+ model = init_model (model_parameters_dict)
548+ observation_seq = ParticleDA. simulate_observations_from_model (
549+ model, n_time_step; rng= rng
550+ )
551+ true_state_mean_seq, true_state_var_seq = Kalman. run_kalman_filter (
552+ model, observation_seq
553+ )
554+ for n_particle in n_particles
555+ output_filename = tempname ()
556+ filter_parameters = ParticleDA. FilterParameters (
557+ nprt= n_particle, verbose= true , output_filename= output_filename
558+ )
559+ states, statistics = ParticleDA. run_particle_filter (
560+ init_model,
561+ filter_parameters,
562+ model_parameters_dict,
563+ observation_seq,
564+ filter_type,
565+ ParticleDA. MeanAndVarSummaryStat;
566+ rng= rng
567+ )
568+ state_mean_seq = Matrix {ParticleDA.get_state_eltype(model)} (
569+ undef, ParticleDA. get_state_dimension (model), n_time_step
570+ )
571+ state_var_seq = Matrix {ParticleDA.get_state_eltype(model)} (
572+ undef, ParticleDA. get_state_dimension (model), n_time_step
573+ )
574+ weights_seq = Matrix {Float64} (undef, n_particle, n_time_step)
575+ h5open (output_filename, " r" ) do file
576+ for t in 1 : n_time_step
577+ key = ParticleDA. time_index_to_hdf5_key (t)
578+ state_mean_seq[:, t] = read (file[" state_avg" ][key])
579+ state_var_seq[:, t] = read (file[" state_var" ][key])
580+ weights_seq[:, t] = read (file[" weights" ][key])
581+ end
582+ end
583+ mean_rmse = sqrt (
584+ mean (x -> x.^ 2 , state_mean_seq .- true_state_mean_seq)
585+ )
586+ log_var_rmse = sqrt (
587+ mean (x -> x.^ 2 , log .(state_var_seq) .- log .(true_state_var_seq))
588+ )
589+ # Monte Carlo estimates of mean and log variance should have O(sqrt(n_particle))
590+ # convergence to true values
591+ @test mean_rmse < mean_rmse_bound_constant / sqrt (n_particle)
592+ @test log_var_rmse < log_var_rmse_bound_constant / sqrt (n_particle)
593+ end
594+ end
595+
596+ @testset (
597+ " Filter estimate validation against Kalman filter - $(filter_type) "
598+ ) for filter_type in (BootstrapFilter, OptimalFilter)
599+ seed = 1234
600+ n_time_step = 100
601+ n_particles = [30 , 100 , 300 , 1000 ]
602+ # Constant factora used in Monte Carlo estimate convergence tests. Set based on some
603+ # trial and error to keep tests relatively sensitive while avoiding too high
604+ # probability of false failures
605+ mean_rmse_bound_constant = 1.
606+ log_var_rmse_bound_constant = 5.
607+ run_tests_for_convergence_of_filter_estimates_against_kalman_filter (
608+ filter_type,
609+ LinearGaussian. init,
610+ LinearGaussian. stochastically_driven_dsho_model_parameters (),
611+ seed,
612+ n_time_step,
613+ n_particles,
614+ mean_rmse_bound_constant,
615+ log_var_rmse_bound_constant,
616+ )
617+ end
618+
526619@testset " Summary statistics unit tests" begin
527620 MPI. Init ()
528621 seed = 5678
0 commit comments