Skip to content

Optional quantem-cuda kernel dispatch for K-Planes models and volume TV#243

Open
cedriclim1 wants to merge 10 commits into
electronmicroscopy:devfrom
cedriclim1:feat/quantem-cuda-extra
Open

Optional quantem-cuda kernel dispatch for K-Planes models and volume TV#243
cedriclim1 wants to merge 10 commits into
electronmicroscopy:devfrom
cedriclim1:feat/quantem-cuda-extra

Conversation

@cedriclim1

@cedriclim1 cedriclim1 commented Jun 10, 2026

Copy link
Copy Markdown
Collaborator

What problem this PR addresses

The sibling quantem-cuda package (electronmicroscopy/quantem-cuda#2) provides fused CUDA kernels, with hand-written analytic backwards, for two operations that pure torch handles inefficiently. Neither is tomography-specific:

  • Tilted K-Planes feature interpolation. KPlanesTILTED lives in quantem.core.ml.models.kplanes and is shared across applications (tomography object models are the current consumer). Its interpolate_ms_features_tilted — einsum rotate → gather → F.grid_sample → 3-plane Hadamard product — dominates every K-Planes training step (75–82 % of the GPU step in the tomography profile) and materializes a (3T, C, B) intermediate of up to several GB per batch just to multiply three planes pointwise. The kernel lives in quantem.cuda.core.ml, mirroring the model's home, so any K-Planes application picks it up.
  • Volume TV regularization (quantem.cuda.core, shared kernels usable from any module — see the explicit inventory below).

This PR wires quantem to use those kernels when available, with zero impact otherwise:

  • config.get("has_quantem_cuda") capability flag (try-import in core/config.py, mirroring has_torch/has_cupy) and a use_cuda_kernels config kill-switch (default true).
  • Per-multiscale-level dispatch inside interpolate_ms_features_tilted (core/ml/models/kplanes.py), used by every KPlanesTILTED forward in any application (data term, TV soft constraints, and create_volume decoding alike).
  • tv_loss_vol_sq() dispatch helper in tomography/utils.py, used by ObjectPixelated.get_tv_loss.
  • ObjectConstraints._calc_tv_loss (diffractive_imaging/object_models.py) dispatches the ptychography TV object constraint to the fused L1 kernel — same anisotropic-L1 functional, per-axis (tv_weight_z, tv_weight_xy) weights, and active-axis normalization, with identical gradients (sign(0) = 0). Weighted size-1 axes fall through to torch so degenerate inputs behave exactly as before.

Dispatch requires CUDA fp32 tensors + the package installed + the kill-switch on; the existing pure-torch code is unchanged and remains the fallback (CPU, other dtypes, kernels disabled, package absent).

Exactly which TV kernels exist (and which don't)

quantem.cuda.core implements three TV functionals, all fp32/CUDA/real-only, shape [..., D, H, W] (ndim >= 3, leading dims flattened in), returning a differentiable scalar ((3,) for the per-axis L1 sums), torch.compile(fullgraph=True)-safe:

Kernel Functional Reduction
tv_loss_sq_3d(volume) squared anisotropic: Σ(Δd)² + Σ(Δh)² + Σ(Δw)² (forward differences, full index range per axis) raw unnormalized sum — exact parity with our tv_vol before its weight / numel scaling
tv_loss_iso_3d(volume, eps=1e-8) isotropic, edge-preserving: sqrt((Δd)² + (Δh)² + (Δw)² + eps) with all three differences anchored on the common corner set [0,D−2]×[0,H−2]×[0,W−2] mean over corners (eps inside the sqrt)
tv_loss_l1_3d(volume) anisotropic L1: per-axis Σ|forward diff| — the exact functional of ptychography's _calc_tv_loss raw per-axis sums, shape (3,) — weighting/normalization stay with the caller, so any per-axis scheme is exact; d|x|/dx = sign(x), sign(0) = 0

Explicitly not implemented: per-axis weights inside a kernel (the L1 kernel's per-axis sums solve this in Python; a two-call recipe covers the squared variant), a 2-D-only variant (the anisotropic kernels on [1, H, W] degrade gracefully to pure-2D TV; tv_loss_iso_3d needs D >= 2 — its corner set is empty on a single slice and it returns 0.0), complex input (take .angle()/.abs() first), and fp16/bf16/fp64.

Notebooks (usage + reconstruction-level quality gate)

Both executed on an RTX PRO 6000 so outputs are reviewable; they are attached below as notebooks.zip rather than committed (the repo gitignores *.ipynb).

quantem_cuda_kernels.ipynb documents the direct-call API — the kernels are ordinary public torch functions, not just dispatch targets:

  • parity checks of all three TV kernels against pure-torch references, values and gradients, including the exact composition of tv_loss_l1_3d's per-axis sums into ptychography's _calc_tv_loss (weights, counts, active-axis normalization);
  • a ptychography section: the tv_weight_z/tv_weight_xy constraint now runs through the fused L1 kernel automatically; for different TV flavors, apply tv_loss_sq_3d/tv_loss_iso_3d to obj.angle() directly, with a verified two-call per-axis recipe for the squared variant. Op-level timings: the L1 kernel is ~4× the torch chain at (16, 1024, 1024) multislice and 256³ sizes (a few ms per iteration), neutral at single-slice sizes — a few-percent end-to-end win at large multislice sizes against an FFT-dominated iteration, roughly nothing for small single-slice objects;
  • the fused K-Planes interpolation contract (pts [B,3] in [−1,1]³, rotations [T,3,3], plane [3T,C,H,W][B,T·C]; grid_sample(align_corners=True, padding_mode="border") parity, gradients to all three inputs) with a direct-call vs torch-chain comparison;
  • the dispatch config surface and kill switch.

ptycho_tv_kernel_quality.ipynb closes the loop at the reconstruction level: it reruns the quantem-tutorials ptycho_iter_02_full reconstruction on its simulated ducky dataset with TV constraints (single-slice xy TV, and multislice z+xy TV à la ptycho_iter_04_multislice), with and without the kernel, and gates on SSIM of the reconstructed phase. The pipeline is bit-exact deterministic, so the fair floor is a torch run with the defocus perturbed by 2×10⁻⁷ relative — a 1-ulp TV rounding difference diverges a deterministic trajectory exactly like a 1-ulp physics perturbation. Both configs pass: SSIM(torch, kernel) 0.997 vs floor 0.996 (single-slice), 0.992 vs 0.988 (multislice), final data-fidelity losses within the floor's scatter.

Results (RTX PRO 6000, KPlanesTILTED T=4/M=8/200³, batch 2048 × 300 samples/ray)

  • Fused interpolation: 6.3× fwd+bwd over the torch chain (5.7× at T=8/M=32 — the win grows with model capacity).
  • TV kernels at tomography volumes: ~10× fwd+bwd at 256³, ~30× (sq) at 512³; the L1 kernel ~4× at multislice-ptychography and 256³ sizes.
  • End-to-end with device-resident batching (separate follow-up branch): 24.6 → 5.8 s/epoch, i.e. a 300-iteration 300³ reconstruction in ~29 min instead of ~2 h.
  • Quality gate on the atomic-resolution Au-nanoparticle phantom (clean dose, 300 iterations, seed 0, every interpolation call through the kernel): detection F1 0.858 vs 0.832 for the torch baseline (precision 1.000 both, recall 0.751 vs 0.713), PSNR 7.78 vs 7.98 dB — small mixed-sign deltas within run-to-run noise; no regression.

Packaging note: the cuda optional extra is intentionally not added yet — an unpublished package in [project.optional-dependencies] breaks uv lock (and the check-uv-lock gate) for everyone. Once quantem-cuda is published to PyPI this becomes a one-line follow-up; the soft-import dispatch in this PR works with any install of the package in the meantime.

What should the reviewer(s) do

Draft until the quantem-cuda PR lands — feedback welcome on the dispatch points, the config surface, and the notebook. Tests: tests/ml/test_kplanes_cuda_dispatch.py, tests/tomography/test_cuda_dispatch.py, and tests/diffractive_imaging/test_tv_cuda_dispatch.py (GPU/CPU parity incl. gradients, dispatch-spy, kill-switch; meaningful both with and without quantem-cuda installed). Full suite green in both environments.

  • This PR introduces a public-facing change (e.g., figures, CLI input/output, API).
    • For functional and algorithmic changes, tests are written or updated.
    • Documentation (usage notebook quantem_cuda_kernels.ipynb, attached below) has been updated.

notebooks.zip

Add a has_quantem_cuda capability flag (mirroring has_torch/has_cupy)
and a use_cuda_kernels config option (default true) as a kill-switch.

ObjectPixelated.get_tv_loss now goes through tv_loss_vol_sq in
tomography/utils.py, which dispatches to quantem-cuda's fused
squared-anisotropic TV kernel (identical math, analytic backward,
10-14x faster fwd+bwd at 256^3-512^3) when the package is installed,
the object is a CUDA fp32 tensor, and the kill-switch is on; otherwise
it falls back to the existing pure-torch expression.

quantem-cuda is not yet on PyPI, so the [cuda] extra is deferred to a
one-line follow-up after the 0.1.0 publish (an unpublished package in
optional-dependencies breaks uv lock for everyone). Until then,
installing quantem-cuda from source enables the dispatch.

Tests cover the fallback paths (CPU, no quantem-cuda, kill-switch) and
kernel parity for values and gradients when it is installed.
interpolate_ms_features_tilted routes each multiscale level through the
fused kernel (rotate + grid_sample + Hadamard product, analytic backward)
when quantem-cuda is installed, tensors are CUDA fp32, and the
use_cuda_kernels config flag is on; the torch path is unchanged and remains
the fallback. Covered by GPU/CPU parity, gradient parity, per-scale
dispatch-spy, and kill-switch tests.
Evaluate the four TV taps (base, +x, +y, +z) in one batched 4N-point
forward instead of four serial 10k-point calls, which were kernel-launch
bound (volume TV 1.62 -> 0.52 ms, full training step 7.16 -> 5.23 ms at
the 200^3 benchmark config). Also gate the plane and volume TV terms on
their own weights: tv_plane > 0 with tv_vol == 0 previously applied no
plane TV at all, and tv_vol > 0 computed a zero-weighted plane TV. The
batched call uses the unwrapped model, consistent with the base tap and
_get_plane_tv_loss.
Each training step previously made two model calls on a tensor-decomp
object: the main batch forward and a second 40k-point call for the
volume-TV finite-difference taps. Both backwards accumulate gradients
into the same plane parameters, so autograd runs the full
gradient-accumulation traffic twice; profiling showed that traffic
(copy_/add_/fill_, ~1.8 ms/step) exceeds the interpolation backward
itself.

The object model now exposes sample_tv_tap_coords(), the training loop
concatenates the returned tap points onto the main batch for a single
model call via forward_with_tv_taps() (mask and hard constraints apply
to the main chunk only; tap densities stay raw, matching the previous
TV semantics), and the tap densities reach get_volume_tv_loss through
ReconstructionContext. The two-call path remains as a fallback for
direct callers and non-tensor-decomp models.

Full step 5.67 -> 4.79 ms (-15.5%) at the 200^3 benchmark config; loss
and gradient parity verified against the two-call path, including taps
crossing the [-1, 1] boundary.
cedriclim1 and others added 3 commits June 10, 2026 11:44
Fold volume-TV tap evaluation into the main forward pass
quantem-cuda moved its kernels into per-module submodules mirroring quantem:
the K-Planes interpolation now lives in quantem.cuda.core.ml (KPlanesTILTED
is a quantem.core.ml model shared across applications) and the TV kernels in
quantem.cuda.core. Update the dispatch imports and the test monkeypatch
targets accordingly.
docs/notebooks/quantem_cuda_kernels.ipynb spells out exactly what the
quantem-cuda TV kernels implement (tv_loss_sq_3d: unnormalized squared
anisotropic sum, tv_vol parity; tv_loss_iso_3d: corner-restricted isotropic
mean, eps inside the sqrt) and what they do not (no anisotropic-L1 kernel,
no per-axis weights, iso needs D >= 2), with parity checks, a ptychography
section (complex multislice objects, per-axis weighting recipe, honest
op-level timings), the fused TILTED K-Planes interpolation contract, and
the transparent-dispatch config surface. Executed on an RTX PRO 6000 so the
outputs are visible in review. .gitignore gains an exemption for curated
docs notebooks.
@cedriclim1 cedriclim1 changed the title Optional quantem-cuda kernel dispatch for tomography hot spots Optional quantem-cuda kernel dispatch for K-Planes models and volume TV Jun 10, 2026
…kernel

ObjectConstraints._calc_tv_loss now routes fp32 CUDA 3-D arrays through
quantem.cuda.core.tv_loss_l1_3d (per-axis |diff| sums), composing the same
anisotropic-L1 functional — per-axis (tv_weight_z, tv_weight_xy) weights and
active-axis normalization included — with identical gradients (sign(0) = 0).
Weighted size-1 axes fall through to torch so degenerate inputs behave
exactly as before; the kill switch and capability flag are the same as the
other dispatch points. ~4x fwd+bwd at multislice (16, 1024, 1024) sizes,
neutral at single-slice sizes.
…ity gate

docs/notebooks/quantem_cuda_kernels.ipynb now documents tv_loss_l1_3d (the
ptychography functional) with a verified parity composition and updated
timings/dispatch tables.

docs/notebooks/ptycho_tv_kernel_quality.ipynb reruns the tutorial ducky
reconstruction (simulated dataset, single-slice xy TV and multislice z+xy
TV) with and without the kernel and gates on SSIM of the reconstructed
phase against an fp-perturbation chaos floor: the pipeline is bit-exact
deterministic, so a 1-ulp TV rounding difference diverges trajectories
exactly like a 2e-7 relative defocus perturbation does. Both configs pass
(SSIM 0.997 vs floor 0.996 single-slice; 0.992 vs 0.988 multislice), with
final data-fidelity losses agreeing within the floor's scatter.
The repo gitignores *.ipynb on purpose — drop the docs/notebooks
exemption and the two committed notebooks. They remain available as a
PR attachment for review.
@cedriclim1 cedriclim1 marked this pull request as ready for review June 11, 2026 00:46
@cedriclim1 cedriclim1 requested a review from arthurmccray June 11, 2026 00:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant