Fast low-rank update (LRU) of matrix determinants and pfaffians in JAX
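In quantum physics and many other fields, it often happens that we have computed $\det A_0$ and $A_0^{-1}$ for an $n \times n$ matrix $A_0$, and then need $\det A_1$ for a matrix $A_1$ that differs from $A_0$ by a low-rank update. Recomputing $\det A_1$ from scratch costs $\mathcal{O}(n^3)$. For a rank-$k$ update $A_1 = A_0 + UV^T$, with $U$ and $V$ of shape $n \times k$, the matrix determinant lemma

$$\det A_1 = \det A_0 \, \det\left(I_k + V^T A_0^{-1} U\right)$$

reduces the cost to $\mathcal{O}(n^2 k)$ once $A_0^{-1}$ is known.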
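As a minimal sketch of the general rank-$k$ case in plain JAX (independent of lrux; the helper name `det_ratio_lowrank` is hypothetical), the matrix determinant lemma $\det(A + UV^T) = \det A \, \det(I_k + V^T A^{-1} U)$ turns the ratio into a $k \times k$ determinant once $A^{-1}$ is known:

```python
import jax

jax.config.update("jax_enable_x64", True)  # float64 for accurate comparisons
import jax.numpy as jnp
import jax.random as jr


def det_ratio_lowrank(Ainv, U, V):
    """Hypothetical helper: det(A + U V^T) / det(A), given Ainv = A^{-1}.

    U and V have shape (n, k). By the matrix determinant lemma the ratio
    equals det(I_k + V^T A^{-1} U), a k x k determinant.
    """
    k = U.shape[1]
    small = jnp.eye(k, dtype=Ainv.dtype) + V.T @ (Ainv @ U)
    return jnp.linalg.det(small)


n, k = 10, 3
A = jr.normal(jr.PRNGKey(0), (n, n))
U = jr.normal(jr.PRNGKey(1), (n, k))
V = jr.normal(jr.PRNGKey(2), (n, k))

ratio = det_ratio_lowrank(jnp.linalg.inv(A), U, V)
# Direct O(n^3) reference for comparison
direct = jnp.linalg.det(A + U @ V.T) / jnp.linalg.det(A)
print(jnp.isclose(ratio, direct))
```

The dominant cost is the product `Ainv @ U`, which is $\mathcal{O}(n^2 k)$; the remaining work involves only $n \times k$ and $k \times k$ arrays.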
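Consider the special case where $A_1$ differs from $A_0$ only in row $v$, i.e. $A_1 = A_0 + e_v u^T$, where the length-$n$ vector $u$ is added to row $v$ and $e_v$ is the one-hot vector whose only nonzero entry is a 1 at index $v$. Then

$$\frac{\det A_1}{\det A_0} = 1 + u^T A_0^{-1} e_v = 1 + \sum_i u_i \left(A_0^{-1}\right)_{iv},$$

which needs only column $v$ of $A_0^{-1}$ and costs $\mathcal{O}(n)$.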
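If $A_0^{-1}$ is known, `lrux.det_lru(Ainv, u, v)` returns the ratio $\det A_1 / \det A_0$ for this row update, with `Ainv` $= A_0^{-1}$ and the integer `v` as the row index: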
```python
import jax

# 64-bit recommended for numerical precision
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jr
from lrux import det_lru

n = 10
A0 = jr.normal(jr.key(0), (n, n))
u = jr.normal(jr.key(1), (n,))  # vector added to row v
v = 5  # one-hot vector index (the updated row)

detA0 = jnp.linalg.det(A0)
Ainv = jnp.linalg.inv(A0)

# Ratio det(A1) / det(A0) from the precomputed inverse
ratio = det_lru(Ainv, u, v)
detA1_lru = detA0 * ratio

# Check against a direct determinant of the updated matrix
A1 = A0.at[v, :].add(u)
assert jnp.isclose(detA1_lru, jnp.linalg.det(A1))
```
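For a single-row update $A_1 = A_0 + e_v u^T$, the matrix determinant lemma gives $\det A_1 / \det A_0 = 1 + u^T A_0^{-1} e_v$, which only touches column $v$ of $A_0^{-1}$ and therefore costs $\mathcal{O}(n)$. A hand-written sketch in plain JAX, for illustration only (the name `row_update_ratio` is hypothetical, and this is not necessarily how `det_lru` is implemented):

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jr


def row_update_ratio(Ainv, u, v):
    """Hypothetical helper: det(A1) / det(A0) for A1 = A0 + e_v u^T.

    1 + u^T A0^{-1} e_v is a single length-n dot product with
    column v of the inverse, so the cost is O(n).
    """
    return 1.0 + u @ Ainv[:, v]


n = 10
A0 = jr.normal(jr.PRNGKey(0), (n, n))
u = jr.normal(jr.PRNGKey(1), (n,))
v = 5
A1 = A0.at[v, :].add(u)  # add u to row v

ratio = row_update_ratio(jnp.linalg.inv(A0), u, v)
print(jnp.isclose(ratio, jnp.linalg.det(A1) / jnp.linalg.det(A0)))
```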
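Sometimes we need to keep computing $\det A_2, \det A_3, \ldots$ through a sequence of updates. This also requires updating the inverse, e.g. $A_1^{-1}$ from $A_0^{-1}$, where the complexity is again $\mathcal{O}(n^2 k)$, i.e. $\mathcal{O}(n^2)$ for a single-row update. With `return_update=True`, `det_lru` returns the updated inverse together with the ratio: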
```python
# Get the ratio and the updated inverse A1^{-1} in one call
ratio, Ainv = det_lru(Ainv, u, v, return_update=True)
assert jnp.allclose(Ainv, jnp.linalg.inv(A1))

# A second row update, applied to A1 using the updated inverse
u_new = jr.normal(jr.key(2), (n,))
v_new = 6
ratio, Ainv = det_lru(Ainv, u_new, v_new, return_update=True)
detA2_lru = detA1_lru * ratio

A2 = A1.at[v_new, :].add(u_new)
assert jnp.isclose(detA2_lru, jnp.linalg.det(A2))
assert jnp.allclose(Ainv, jnp.linalg.inv(A2))
```
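The inverse update for a single-row change can be sketched with the Sherman-Morrison formula, $A_1^{-1} = A_0^{-1} - \frac{A_0^{-1} e_v \, u^T A_0^{-1}}{1 + u^T A_0^{-1} e_v}$. This is a plain JAX illustration with a hypothetical helper `row_update_inverse`, not lrux's implementation; the vector-matrix product and the outer product make it $\mathcal{O}(n^2)$ per update:

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jr


def row_update_inverse(Ainv, u, v):
    """Hypothetical helper: ratio and new inverse after adding u to row v.

    Sherman-Morrison with a = e_v, b = u:
    (A + a b^T)^{-1} = A^{-1} - A^{-1} a b^T A^{-1} / (1 + b^T A^{-1} a)
    """
    col = Ainv[:, v]      # A^{-1} e_v
    row = u @ Ainv        # u^T A^{-1}: O(n^2), as is the outer product below
    ratio = 1.0 + row[v]  # 1 + u^T A^{-1} e_v
    Ainv_new = Ainv - jnp.outer(col, row) / ratio
    return ratio, Ainv_new


n = 10
A0 = jr.normal(jr.PRNGKey(0), (n, n))
u1 = jr.normal(jr.PRNGKey(1), (n,))
u2 = jr.normal(jr.PRNGKey(2), (n,))
A1 = A0.at[5, :].add(u1)
A2 = A1.at[6, :].add(u2)

# Two successive row updates, carrying the inverse forward
Ainv = jnp.linalg.inv(A0)
r1, Ainv = row_update_inverse(Ainv, u1, 5)
r2, Ainv = row_update_inverse(Ainv, u2, 6)
detA2 = jnp.linalg.det(A0) * r1 * r2
```

Reusing `row[v]` as $u^T A_0^{-1} e_v$ means the determinant ratio comes essentially for free once $u^T A_0^{-1}$ has been formed.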
The main functions of lrux include `det_lru`, `det_lru_delayed`, `pf_lru`, and `pf_lru_delayed`. They provide:
- Row and column updates
- General rank-$k$ updates
- Delayed updates
- `jit` and `vmap` compatibility
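To illustrate row versus column updates and why such updates fit naturally with `jit` and `vmap`, here is a plain-JAX sketch with hypothetical helpers `row_ratio` and `col_ratio` (not lrux's API). A column update $A_1 = A_0 + u e_v^T$ gives $\det A_1 / \det A_0 = 1 + e_v^T A_0^{-1} u$, i.e. it uses row $v$ of $A_0^{-1}$ instead of column $v$. Because the array shapes are fixed and there is no Python control flow on data, a batch of candidate updates can be evaluated with a single `vmap`:

```python
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jr


# Hypothetical helpers for illustration only (not part of lrux).
@jax.jit
def row_ratio(Ainv, u, v):
    # A1 = A0 + e_v u^T (u added to row v): ratio = 1 + u^T A0^{-1} e_v
    return 1.0 + u @ Ainv[:, v]


@jax.jit
def col_ratio(Ainv, u, v):
    # A1 = A0 + u e_v^T (u added to column v): ratio = 1 + e_v^T A0^{-1} u
    return 1.0 + Ainv[v, :] @ u


n, batch = 8, 4
A0 = jr.normal(jr.PRNGKey(0), (n, n))
Ainv = jnp.linalg.inv(A0)
us = jr.normal(jr.PRNGKey(1), (batch, n))  # one candidate update per batch entry
vs = jnp.array([0, 3, 3, 7])  # row/column index of each candidate

# Share one inverse across the batch, map over the (u, v) pairs
row_ratios = jax.vmap(row_ratio, in_axes=(None, 0, 0))(Ainv, us, vs)
col_ratios = jax.vmap(col_ratio, in_axes=(None, 0, 0))(Ainv, us, vs)

# Direct O(n^3) reference for each candidate
detA0 = jnp.linalg.det(A0)
row_direct = jnp.array(
    [jnp.linalg.det(A0.at[int(v), :].add(u)) / detA0 for u, v in zip(us, vs)]
)
col_direct = jnp.array(
    [jnp.linalg.det(A0.at[:, int(v)].add(u)) / detA0 for u, v in zip(us, vs)]
)
```

Here `in_axes=(None, 0, 0)` broadcasts the single inverse across the batch while mapping over `u` and `v`.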
As the pfaffian is not directly supported in JAX, we also provide backward-compatible functions `pf` and `slogpf` for pfaffian computations.
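For reference, the pfaffian of a skew-symmetric matrix satisfies $\operatorname{Pf}(A)^2 = \det A$, and it can be computed in $\mathcal{O}(n^3)$ by Parlett-Reid-style skew-symmetric elimination with pivoting. The sketch below is plain NumPy (the in-place pivoting keeps it short, but it is therefore not `jit`-compatible) with a hypothetical function name `pfaffian`; it should not be taken as the implementation behind lrux's `pf` or `slogpf`:

```python
import numpy as np


def pfaffian(A):
    """Hypothetical helper: pfaffian of a real skew-symmetric matrix, O(n^3)."""
    A = np.array(A, dtype=float)  # work on a copy
    n = A.shape[0]
    if n % 2 == 1:
        return 0.0  # odd-dimensional skew-symmetric matrices have zero pfaffian
    pf = 1.0
    for k in range(0, n - 1, 2):
        # Pivot: move the largest |A[k, j]|, j > k, into position (k, k+1)
        kp = k + 1 + int(np.argmax(np.abs(A[k, k + 1:])))
        if kp != k + 1:
            A[[k + 1, kp], :] = A[[kp, k + 1], :]
            A[:, [k + 1, kp]] = A[:, [kp, k + 1]]
            pf = -pf  # one row/column transposition flips the sign
        pivot = A[k, k + 1]
        if pivot == 0.0:
            return 0.0  # the whole row is zero, so the pfaffian vanishes
        pf *= pivot
        if k + 2 < n:
            tau = A[k, k + 2:] / pivot
            col = A[k + 2:, k + 1]
            # Eliminate rows/columns k, k+1; the trailing block stays skew-symmetric
            A[k + 2:, k + 2:] += np.outer(tau, col) - np.outer(col, tau)
    return pf


# 4x4 check: Pf = a12*a34 - a13*a24 + a14*a23 = 1*6 - 2*5 + 3*4 = 8
A4 = np.array([[0, 1, 2, 3], [-1, 0, 4, 5], [-2, -4, 0, 6], [-3, -5, -6, 0]], dtype=float)
pf4 = pfaffian(A4)

# Random check of Pf(S)^2 = det(S)
M = np.random.default_rng(0).standard_normal((6, 6))
S = M - M.T
pfS = pfaffian(S)
```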
Requires Python 3.8+ and JAX 0.4.4+
```bash
pip install lrux
```