Skip to content

Make VI work on v4 #4582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 97 commits into from
Feb 25, 2022
Merged
Changes from 1 commit
Commits
Show all changes
97 commits
Select commit Hold shift + click to select a range
e59ebba
resolve merge conflicts
ferrine Mar 31, 2021
8aa290f
start fixing things
ferrine Jun 6, 2021
f96c626
make a simple test pass
ferrine Jun 7, 2021
0af6dac
fix some more tests
ferrine Jun 11, 2021
7e60bcc
fix some more tests
ferrine Jun 11, 2021
e0fbb98
add scaling for VI
ferrine Jun 18, 2021
e515217
add shape check
ferrine Jun 18, 2021
6dfc18c
aet -> at
ferrine Jun 18, 2021
39e635b
use rvs_to_values from the model in opi.py
ferrine Jun 21, 2021
9f61021
refactor cloning routines (fix pymc references)
ferrine Jun 21, 2021
8909ac7
Run pre-commit and include VI tests in pytest workflow (rebase)
michaelosthege Jul 2, 2021
1076fa1
Run pre-commit and include VI tests in pytest workflow
michaelosthege Jul 2, 2021
7e73cd7
seems like Grouped inference not working
ferrine Jul 28, 2021
64ba837
spot an error in a simple test case
ferrine Aug 3, 2021
4b91bce
fix the test case with grouping
ferrine Aug 3, 2021
c81458a
fix sampling with changed shape
ferrine Aug 3, 2021
11ef0b6
remove not implemented error for local inference
ferrine Aug 3, 2021
98dd81d
support inferencedata
ferrine Aug 8, 2021
c08eea3
get rid of shape error for batched mvnormal
ferrine Aug 8, 2021
77443f5
do not support AEVB with an error message
ferrine Aug 8, 2021
215f92b
fix some meore tests
ferrine Sep 8, 2021
94a28e5
fix some more tests
ferrine Sep 16, 2021
509f7ba
fix full rank test
ferrine Sep 16, 2021
c0c8fb9
fix tests
ferrine Sep 16, 2021
7745ac6
test vi
ferrine Sep 16, 2021
3dafc10
fix conversion function
ferrine Sep 16, 2021
2752ebd
propagate model
ferrine Sep 16, 2021
ff5f8c8
fix
ferrine Sep 16, 2021
c154063
fix elbo
ferrine Sep 19, 2021
af9c24d
fix elbo full rank
ferrine Sep 19, 2021
a9d40ef
Fixing broken scaling with float32
ferrine Sep 23, 2021
54d2a43
ignore a nasty test
ferrine Sep 23, 2021
6d46a2f
xfail one test with float 32
ferrine Sep 26, 2021
2ce5a7d
fix pre commit
ferrine Sep 26, 2021
69b9486
fix import
ferrine Sep 26, 2021
1beec12
fix import.1
ferrine Sep 26, 2021
894d5ce
Update pymc/variational/opvi.py
ferrine Sep 27, 2021
8d2ec8b
fix docstrings
ferrine Sep 27, 2021
60e5653
Merge branch 'v4-4523' of github.com:pymc-devs/pymc3 into v4-4523
ferrine Sep 27, 2021
c03352e
fix error with nans
ferrine Oct 14, 2021
00c1d14
remove TODO comments
ferrine Oct 14, 2021
27b4261
Merge branch 'main' into v4-4523
ferrine Oct 14, 2021
694286a
print statements to logging
ferrine Oct 14, 2021
8dba7d5
revert bart test
ferrine Oct 14, 2021
6a2fc35
apply changes from main
ferrine Oct 15, 2021
3a5915a
fix pylint issues
ferrine Oct 15, 2021
f6d9b98
fix test bart
ferrine Oct 20, 2021
9a79e27
fix interence_data in init
ferrine Oct 20, 2021
deafa96
ignore pickling problems
ferrine Oct 26, 2021
0f45e73
fix aevb test
ferrine Oct 26, 2021
4957765
Merge branch 'main' into v4-4523
ferrine Oct 26, 2021
0ab2fba
Merge branch 'main' into v4-4523
ferrine Nov 2, 2021
b1b4938
Merge branch 'main' into v4-4523
ferrine Nov 7, 2021
8d48870
fix name error
ferrine Nov 7, 2021
6efd630
xfail test ramdom fn
ferrine Nov 7, 2021
b2e9c0f
mark xfail
ferrine Nov 7, 2021
a92aad8
refactor test
ferrine Nov 7, 2021
f253417
xfail fix
ferrine Nov 7, 2021
f09d33a
fix xfail syntax
ferrine Nov 8, 2021
19ea8c9
pytest
ferrine Nov 8, 2021
f14cbc1
test fixed
ferrine Nov 8, 2021
02fc30f
5090 fixed
ferrine Nov 8, 2021
baefac6
do not test local flows
ferrine Nov 15, 2021
bf38d33
Merge branch 'main' into v4-4523
ferrine Nov 16, 2021
beb75ba
change model.logpt not to return float
ferrine Nov 16, 2021
74e19fd
Merge branch 'main' into v4-4523
ferrine Nov 23, 2021
c2d24de
add a test for the replacenent in the graph
ferrine Nov 27, 2021
8fdf9a2
Merge branch 'main' into v4-4523
michaelosthege Dec 20, 2021
3943e0f
merge main into PR
ferrine Jan 16, 2022
13a970e
fix sample node functionality
ferrine Jan 16, 2022
994fba5
Fix test with var replacement
ferrine Jan 16, 2022
6090029
add uncommited changes
ferrine Jan 16, 2022
48041f5
resolve @ricardoV94's comment about initial point
ferrine Jan 23, 2022
cb0fee9
restore test_bart.py as in main branch
ferrine Jan 23, 2022
c5911ac
resolve duplicated _get_scaling function
ferrine Jan 23, 2022
a466ffc
Merge branch 'main' into v4-4523
ferrine Jan 23, 2022
78ca582
change job order
ferrine Jan 23, 2022
e4cbb33
use commit initial point in the test file
ferrine Jan 23, 2022
8fad157
use compute initial point in the opvi.py
ferrine Jan 23, 2022
7f281bd
remove unnessesary pattern broadcast
ferrine Jan 24, 2022
8e8f63e
mark test as xfail before aesara release
ferrine Jan 24, 2022
72a7556
Do not mark anything but just wait for the new release
ferrine Jan 24, 2022
57e8342
Merge branch 'main' into v4-4523
ferrine Jan 30, 2022
a6f54ac
use compute_initial_point
ferrine Feb 13, 2022
1ee5536
Merge branch 'main' into v4-4523
ferrine Feb 14, 2022
4fab824
Merge branch 'main' into v4-4523
ferrine Feb 20, 2022
b4a2f62
Update pymc/variational/opvi.py
ferrine Feb 20, 2022
f9d16a7
run upgraded pre-commit
ferrine Feb 20, 2022
bc712ef
Merge branch 'v4-4523' of github.com:pymc-devs/pymc3 into v4-4523
ferrine Feb 20, 2022
6a3ee61
move pipe back
ferrine Feb 20, 2022
cd2cda9
Update pymc/variational/opvi.py
ferrine Feb 23, 2022
670edb9
Update pymc/variational/opvi.py
ferrine Feb 23, 2022
01fb223
Update pymc/variational/opvi.py
ferrine Feb 23, 2022
32006cd
Add removed newline
ricardoV94 Feb 23, 2022
1cb1418
Use compile_pymc instead of aesara.function
ricardoV94 Feb 23, 2022
ceddb5c
Replace None by empty list in output
ricardoV94 Feb 23, 2022
ef5f91b
Apply suggestions from code review
ferrine Feb 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Use compile_pymc instead of aesara.function
  • Loading branch information
ricardoV94 committed Feb 23, 2022
commit 1cb1418b5d06d4ea7a1bcf4ff5411372c8dd50bb
10 changes: 5 additions & 5 deletions pymc/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

import pymc as pm

from pymc.aesaraf import at_rng, identity, rvs_to_value_vars
from pymc.aesaraf import at_rng, compile_pymc, identity, rvs_to_value_vars
from pymc.backends import NDArray
from pymc.blocking import DictToArrayBijection
from pymc.initial_point import make_initial_point_fn
Expand Down Expand Up @@ -363,9 +363,9 @@ def step_function(
total_grad_norm_constraint=total_grad_norm_constraint,
)
if score:
step_fn = aesara.function([], updates.loss, updates=updates, **fn_kwargs)
step_fn = compile_pymc([], updates.loss, updates=updates, **fn_kwargs)
else:
step_fn = aesara.function([], None, updates=updates, **fn_kwargs)
step_fn = compile_pymc([], None, updates=updates, **fn_kwargs)
return step_fn

@aesara.config.change_flags(compute_test_value="off")
Expand Down Expand Up @@ -394,7 +394,7 @@ def score_function(
if more_replacements is None:
more_replacements = {}
loss = self(sc_n_mc, more_replacements=more_replacements)
return aesara.function([], loss, **fn_kwargs)
return compile_pymc([], loss, **fn_kwargs)

@aesara.config.change_flags(compute_test_value="off")
def __call__(self, nmc, **kwargs):
Expand Down Expand Up @@ -1637,7 +1637,7 @@ def sample_dict_fn(self):
names = [self.model.rvs_to_values[v].name for v in self.model.free_RVs]
sampled = [self.rslice(name) for name in names]
sampled = self.set_size_and_deterministic(sampled, s, 0)
sample_fn = aesara.function([s], sampled)
sample_fn = compile_pymc([s], sampled)

def inner(draws=100):
_samples = sample_fn(draws)
Expand Down