BlackJAX

Composable Bayesian inference in JAX — on CPU, GPU and TPU



What is BlackJAX?

BlackJAX is a library of samplers for JAX that runs on CPU, GPU, and TPU. It is not a probabilistic programming library, but it works with any PPL that can expose a JAX-compatible log-probability density function.
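As an illustration of what a "JAX-compatible log-probability density function" can look like, here is a minimal sketch of an unnormalized log posterior for a toy linear model, written in plain JAX. The data, the parameter names `slope` and `intercept`, and the priors are all assumptions made up for this example; they are not part of BlackJAX. The point is only that the function takes a pytree of parameters, returns a scalar, and can be jitted and differentiated.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

# Hypothetical toy data, for illustration only.
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = jnp.array([0.1, 1.9, 4.2, 5.8])


def logdensity_fn(params):
    """Unnormalized log posterior of a linear model; `params` is a dict pytree."""
    slope, intercept = params["slope"], params["intercept"]
    # Assumed priors: wide normals on both parameters.
    log_prior = stats.norm.logpdf(slope, 0.0, 10.0) + stats.norm.logpdf(intercept, 0.0, 10.0)
    # Assumed likelihood: unit-variance Gaussian noise around the regression line.
    log_lik = jnp.sum(stats.norm.logpdf(y, slope * x + intercept, 1.0))
    return log_prior + log_lik


position = {"slope": jnp.array(0.0), "intercept": jnp.array(0.0)}
logp = jax.jit(logdensity_fn)(position)      # scalar log density
grads = jax.grad(logdensity_fn)(position)    # pytree with the same structure as `position`
```

A function of this shape is what gradient-based samplers need: a scalar output and gradients available through `jax.grad`.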

It bridges the gap between one-liner frameworks and fully modular, composable inference toolkits. Use it as a black box, or crack it open and compose your own algorithms from low-level building blocks.


Algorithms

BlackJAX ships a broad collection of production-ready samplers:

| Family | Algorithms |
|---|---|
| MCMC | NUTS, HMC, MALA, Random Walk MH, Elliptical Slice Sampling, Barker, ... |
| SMC | Sequential Monte Carlo with adaptive tempering |
| Stochastic gradient | SGLD, SGHMC, CSGLD, ... |
| Variational | Stein Variational Gradient Descent |
| Orbit | Generalized HMC / periodic orbit samplers |

Quick Start

pip install blackjax

import jax
import jax.numpy as jnp
import blackjax

def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)  # standard normal

# Build a NUTS sampler
nuts = blackjax.nuts(logdensity_fn, step_size=1e-3, inverse_mass_matrix=jnp.ones(2))

# Initialize and run a single step
state = nuts.init(jnp.zeros(2))
rng_key = jax.random.key(0)
state, info = jax.jit(nuts.step)(rng_key, state)

See the documentation for full inference loops, multi-chain sampling, Stan warmup, and more.
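The Quick Start runs a single step. A full inference loop and multi-chain sampling can be sketched with `jax.lax.scan` and `jax.vmap`, as below. To keep the sketch runnable with JAX alone, it uses a hypothetical random-walk Metropolis kernel, `rw_step`, as a stand-in; it is not BlackJAX code and only mimics the `(rng_key, state) -> (new_state, info)` calling convention shown above. The helper name `inference_loop` is also an assumption of this example.

```python
import jax
import jax.numpy as jnp


def logdensity_fn(x):
    return -0.5 * jnp.sum(x ** 2)  # standard normal, as in the Quick Start


def rw_step(rng_key, position, scale=0.5):
    """Stand-in random-walk Metropolis kernel (hypothetical, not BlackJAX code)."""
    key_prop, key_accept = jax.random.split(rng_key)
    proposal = position + scale * jax.random.normal(key_prop, position.shape)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(position)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    new_position = jnp.where(accept, proposal, position)
    return new_position, accept  # (new state, info)


def inference_loop(rng_key, step_fn, initial_state, num_samples):
    """Run `num_samples` steps with jax.lax.scan and stack every state."""

    def one_step(state, key):
        state, info = step_fn(key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)
    return states, infos


# Single chain: 1,000 draws in two dimensions.
states, infos = inference_loop(jax.random.key(0), rw_step, jnp.zeros(2), 1_000)

# Several chains: vmap the whole loop over per-chain keys and initial states.
num_chains = 4
chain_keys = jax.random.split(jax.random.key(1), num_chains)
initial_states = jnp.zeros((num_chains, 2))
run_chains = jax.jit(jax.vmap(lambda k, s: inference_loop(k, rw_step, s, 1_000)))
multi_states, multi_infos = run_chains(chain_keys, initial_states)
```

Because `nuts.step` from the Quick Start takes the same `(rng_key, state)` arguments, it should be possible to pass it as `step_fn` with `nuts.init(...)` supplying the initial state; the stacked output would then hold sampler states rather than raw positions. Using `scan` instead of a Python `for` loop keeps the whole run inside one compiled computation.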


Resources

| Resource | Description |
|---|---|
| 📖 Documentation | Full API reference, tutorials, and worked examples |
| 📓 Sampling Book | A cookbook of Bayesian inference with BlackJAX |
| 📄 arXiv Paper | Cabezas et al. (2024), BlackJAX: Composable Bayesian inference in JAX |
| 🐍 PyPI | Latest release on the Python Package Index |

Community

BlackJAX is built by researchers and engineers who care about sampling. We welcome contributions of all kinds — new algorithms, bug fixes, documentation, and discussions.


---

**What changed vs. the current 3-line README:**

| Before | After |
|---|---|
| No visual identity | Centered hero with the existing scatter.gif animation |
| No badges | CI, codecov, PyPI, arXiv, stars badges |
| No description | Clear pitch explaining what BlackJAX is and who it's for |
| No algorithm listing | Table covering MCMC, SMC, SG, variational, orbit families |
| No code | Minimal NUTS quick-start snippet |
| 2 bare links | Rich resource table + community section |

Pinned

  1. blackjax (Public)

    BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.

    Python 1.1k 140

  2. sampling-book (Public)

    Tutorials and sampling algorithm comparisons

    TeX 85 17
