Starred repositories
Step size adaptation for the No-U-Turn Sampler
Simulation-based calibration and generation of synthetic data.
ML Collections is a library of Python Collections designed for ML use cases.
Home for "How To Scale Your Model", a short blog-style textbook about scaling LLMs on TPUs
A pedagogical implementation of Autograd
High accuracy RAG for answering questions from scientific documents with citations
Probabilistic Programming and Nested sampling in JAX
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
Bayesian Neural Field models for prediction in large-scale spatiotemporal datasets
Normalizing-flow-enhanced sampling package for probabilistic inference in JAX
Hardware accelerated, batchable and differentiable optimizers in JAX.
tree is a library for working with nested data structures
State-of-the-art inference for your Bayesian models.
Application of the L2HMC algorithm to simulations in lattice QCD.
Optax is a gradient processing and optimization library for JAX.
Python-based research interface for blackbox and hyperparameter optimization, based on the internal Google Vizier Service.
Optimal transport tools implemented with the JAX framework, to solve large scale matching problems of any flavor.