wrapping jax example code cell 22 returns error #625

Closed
AndreV84 opened this issue Jan 8, 2024 · 14 comments · Fixed by #630
Labels
good first issue Good for newcomers

Comments

@AndreV84

AndreV84 commented Jan 8, 2024

Describe the issue:

Cell 22 of the file https://github.com/pymc-devs/pymc-examples/blob/main/examples/howto/wrapping_jax_function.ipynb throws an error.

Reference issue: pymc-devs/pymc#7088 (comment)

Reproducible code example:

with pm.Model(rng_seeder=int(rng.integers(2**30))) as model:
    emission_signal = pm.Normal("emission_signal", 0, 1)
    emission_noise = pm.HalfNormal("emission_noise", 1)


    p_initial_state = pm.Dirichlet("p_initial_state", np.ones(3))
    logp_initial_state = pt.log(p_initial_state)

    p_transition = pm.Dirichlet("p_transition", np.ones(3), size=3)
    logp_transition = pt.log(p_transition)

    loglike = pm.Potential(
        "hmm_loglike",
        hmm_logp_op(
            emission_observed,
            emission_signal,
            emission_noise,
            logp_initial_state,
            logp_transition,
        ),
    )

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[22], line 1
----> 1 with pm.Model(rng_seeder=int(rng.integers(2**30))) as model:
      2     emission_signal = pm.Normal("emission_signal", 0, 1)
      3     emission_noise = pm.HalfNormal("emission_noise", 1)

File ~/.local/lib/python3.8/site-packages/pymc/model.py:221, in ContextMeta.__call__(cls, *args, **kwargs)
    219 instance: "Model" = cls.__new__(cls, *args, **kwargs)
    220 with instance:  # appends context
--> 221     instance.__init__(*args, **kwargs)
    222 return instance

TypeError: __init__() got an unexpected keyword argument 'rng_seeder'

PyMC version information:

5.6.1

Context for the issue:

Because of this, the cell 22 example in the IPython notebook doesn't work.

@ricardoV94
Member

That's an old API; rng_seeder no longer exists. We can just remove it.

@ricardoV94 ricardoV94 transferred this issue from pymc-devs/pymc Jan 8, 2024
@AndreV84
Author

AndreV84 commented Jan 8, 2024

@ricardoV94 Thank you for your prompt response.
Do you mean I should remove rng_seeder from this line:
with pm.Model(rng_seeder=int(rng.integers(2**30))) as model:
so that it reads:
with pm.Model() as model:
Is there an updated API example available for reference?

@ricardoV94
Member

Yes, that's it. Seeding has long been done by passing random_seed to sampling functions such as pm.sample.

No reference for that specific change, but let me know if you have any doubts.

@AndreV84
Author

AndreV84 commented Jan 8, 2024

The next cell fails after updating the model definition:

initial_point = model.compute_initial_point()
initial_point
AttributeError                            Traceback (most recent call last)
Cell In[24], line 1
----> 1 initial_point = model.compute_initial_point()
      2 initial_point

AttributeError: 'Model' object has no attribute 'compute_initial_point'

@ricardoV94
Member

Wow, that notebook is really outdated. That's just called model.initial_point() now. Hopefully that's the last change :/

@AndreV84
Copy link
Author

AndreV84 commented Jan 8, 2024

Maybe you know how to update this line too?

model_logp_jax_fn = model.compile_fn(model.logpt(sum=False), mode="JAX")
model_logp_jax_fn(initial_point)

@AndreV84
Author

AndreV84 commented Jan 8, 2024

Otherwise it errors:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[31], line 1
----> 1 model_logp_jax_fn = model.compile_fn(model.logpt(sum=False), mode="JAX")
      2 model_logp_jax_fn(initial_point)

AttributeError: 'Model' object has no attribute 'logpt'

@ricardoV94
Member

model.logpt is now model.logp :)

@AndreV84
Author

AndreV84 commented Jan 8, 2024

By now the entire notebook seems more or less patched; thank you.

Should I post concerns of the following kind on Discourse rather than on GitHub?
From the developer:
"There seem to be conflicts between JAX and PyMC/PyTensor arguments, especially in the "dH" and "RKAMethod_jax" functions, and I don't know how to find a workaround."
"For the moment, we work with the 2 likelihoods "Hubble(z)" and "SNIa-SCP"."
"At each proposal of the parameters to estimate, we plug them into the Planck Hi-CLASS code and compute the chi2 to decide whether or not to accept the point."
"The ideal would be to include the Planck likelihood (with clik etc.) in the sum of the chi2."
/.local/lib/python3.8/site-packages/pytensor/tensor/__init__.py", line 56, in _as_tensor_variable
    raise NotImplementedError(f"Cannot convert {x!r} to a tensor variable.")
NotImplementedError: Cannot convert Array(6.00012, dtype=float64) to a tensor variable.

@ricardoV94
Copy link
Member

by now the entire notebook seems somewhat patched; thank you

Would you consider opening a PR to fix the NB for everyone?

for concern of the following kind I shall post at the discourse rather than at github?

Yup. You will also get much more visibility there (on average)

@AndreV84
Copy link
Author

AndreV84 commented Jan 8, 2024

I can share the resulting code:
wrapping_jax_function_py.zip
wrapping_jax_function_ipynb.zip
I'm not certain how to open a PR to fix the NB; could you perhaps submit the PR?

@AndreV84 AndreV84 closed this as completed Jan 8, 2024
@ricardoV94 ricardoV94 reopened this Jan 8, 2024
@ricardoV94
Copy link
Member

Keeping the issue open so we don't forget to fix it.

Thanks for sharing the code

@OriolAbril OriolAbril added the good first issue Good for newcomers label Jan 10, 2024
@HarshvirSandhu
Copy link
Contributor

Hello @ricardoV94
If this issue is still open, can I open a PR for fixing the notebook?

@ricardoV94
Copy link
Member

Definitely @HarshvirSandhu

4 participants