Optional quantem-cuda kernel dispatch for K-Planes models and volume TV #243
Open
cedriclim1 wants to merge 10 commits into
Add a has_quantem_cuda capability flag (mirroring has_torch/has_cupy) and a use_cuda_kernels config option (default true) as a kill-switch. ObjectPixelated.get_tv_loss now goes through tv_loss_vol_sq in tomography/utils.py, which dispatches to quantem-cuda's fused squared-anisotropic TV kernel (identical math, analytic backward, 10-14x faster fwd+bwd at 256^3-512^3) when the package is installed, the object is a CUDA fp32 tensor, and the kill-switch is on; otherwise it falls back to the existing pure-torch expression. quantem-cuda is not yet on PyPI, so the [cuda] extra is deferred to a one-line follow-up after the 0.1.0 publish (an unpublished package in optional-dependencies breaks uv lock for everyone). Until then, installing quantem-cuda from source enables the dispatch. Tests cover the fallback paths (CPU, no quantem-cuda, kill-switch) and kernel parity for values and gradients when it is installed.
interpolate_ms_features_tilted routes each multiscale level through the fused kernel (rotate + grid_sample + Hadamard product, analytic backward) when quantem-cuda is installed, tensors are CUDA fp32, and the use_cuda_kernels config flag is on; the torch path is unchanged and remains the fallback. Covered by GPU/CPU parity, gradient parity, per-scale dispatch-spy, and kill-switch tests.
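For orientation, the torch chain that the fused kernel replaces (rotate, gather, grid_sample, 3-plane Hadamard product) can be sketched roughly as below. This is a minimal single-level sketch, not the quantem implementation: the function name `interpolate_tilted_ref`, the t-major plane ordering, and the (x,y)/(x,z)/(y,z) coordinate pairing are assumptions made for illustration; only the tensor shapes and the `grid_sample` options come from the PR description.

```python
import torch
import torch.nn.functional as F


def interpolate_tilted_ref(pts, rotations, planes):
    """Hypothetical torch reference for one scale level.

    pts:       [B, 3] points in [-1, 1]^3
    rotations: [T, 3, 3]
    planes:    [3T, C, H, W], assumed ordered t-major (3 planes per rotation)
    returns:   [B, T*C]
    """
    T, B, C = rotations.shape[0], pts.shape[0], planes.shape[1]
    # Rotate every point by every rotation: [T, B, 3]
    rot = torch.einsum("tij,bj->tbi", rotations, pts)
    # Gather the 2-D coordinates for each of the three planes (assumed pairing).
    pairs = torch.tensor([[0, 1], [0, 2], [1, 2]])
    grid = rot[:, :, pairs]                                   # [T, B, 3, 2]
    grid = grid.permute(0, 2, 1, 3).reshape(3 * T, 1, B, 2)   # [3T, 1, B, 2]
    # Bilinear lookup; this materializes the (3T, C, B) intermediate.
    feats = F.grid_sample(
        planes, grid, mode="bilinear", padding_mode="border", align_corners=True
    )                                                         # [3T, C, 1, B]
    # Hadamard product across the three planes of each rotation.
    feats = feats.reshape(T, 3, C, B).prod(dim=1)             # [T, C, B]
    return feats.permute(2, 0, 1).reshape(B, T * C)


# Usage: constant planes make the expected output easy to reason about.
T, C, B = 2, 3, 5
pts = torch.rand(B, 3) * 2 - 1
rotations = torch.eye(3).repeat(T, 1, 1)
planes = torch.full((3 * T, C, 4, 4), 2.0)
out = interpolate_tilted_ref(pts, rotations, planes)
```

With constant planes of value 2, every output feature is the product of three samples, so the result is a `[B, T*C]` tensor of 8s regardless of the sample points.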
Evaluate the four TV taps (base, +x, +y, +z) in one batched 4N-point forward instead of four serial 10k-point calls, which were kernel-launch bound (volume TV 1.62 -> 0.52 ms, full training step 7.16 -> 5.23 ms at the 200^3 benchmark config). Also gate the plane and volume TV terms on their own weights: tv_plane > 0 with tv_vol == 0 previously applied no plane TV at all, and tv_vol > 0 computed a zero-weighted plane TV. The batched call uses the unwrapped model, consistent with the base tap and _get_plane_tv_loss.
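The batching idea can be illustrated with a toy model. The sketch below is an assumption-laden illustration rather than the code in this PR: `tv_taps_serial`, `tv_taps_batched`, the squared finite-difference loss, and the small MLP are all hypothetical stand-ins. It shows only that concatenating the base, +x, +y, and +z taps into one 4N-point forward gives the same loss as separate calls.

```python
import torch


def tv_taps_serial(model, pts, h):
    """Four separate forwards: base plus one shifted call per axis."""
    base = model(pts)
    total = torch.zeros_like(base)
    for axis in range(3):
        shift = torch.zeros(3, dtype=pts.dtype, device=pts.device)
        shift[axis] = h
        total = total + (model(pts + shift) - base) ** 2
    return total.mean()


def tv_taps_batched(model, pts, h):
    """One 4N-point forward covering base, +x, +y, +z."""
    n = pts.shape[0]
    # Row 0 is the zero offset (base tap); rows 1..3 shift one axis each.
    offsets = torch.cat([torch.zeros(1, 3), h * torch.eye(3)]).to(pts)
    taps = (pts[None, :, :] + offsets[:, None, :]).reshape(4 * n, 3)
    base, dx, dy, dz = model(taps).reshape(4, n, -1)
    return ((dx - base) ** 2 + (dy - base) ** 2 + (dz - base) ** 2).mean()


# Toy density model in float64 so the comparison is not limited by fp32 rounding.
torch.manual_seed(0)
toy_model = torch.nn.Sequential(
    torch.nn.Linear(3, 16), torch.nn.Tanh(), torch.nn.Linear(16, 1)
).double()
pts = torch.rand(10, 3, dtype=torch.float64) * 2 - 1
serial = tv_taps_serial(toy_model, pts, h=0.01)
batched = tv_taps_batched(toy_model, pts, h=0.01)
```

The batched variant trades four small launches for one larger one, which is where the launch-bound savings reported above would come from.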
Each training step previously made two model calls on a tensor-decomp object: the main batch forward and a second 40k-point call for the volume-TV finite-difference taps. Both backwards accumulate gradients into the same plane parameters, so autograd runs the full gradient-accumulation traffic twice; profiling showed that traffic (copy_/add_/fill_, ~1.8 ms/step) exceeds the interpolation backward itself. The object model now exposes sample_tv_tap_coords(), the training loop concatenates the returned tap points onto the main batch for a single model call via forward_with_tv_taps() (mask and hard constraints apply to the main chunk only; tap densities stay raw, matching the previous TV semantics), and the tap densities reach get_volume_tv_loss through ReconstructionContext. The two-call path remains as a fallback for direct callers and non-tensor-decomp models. Full step 5.67 -> 4.79 ms (-15.5%) at the 200^3 benchmark config; loss and gradient parity verified against the two-call path, including taps crossing the [-1, 1] boundary.
Fold volume-TV tap evaluation into the main forward pass
quantem-cuda moved its kernels into per-module submodules mirroring quantem: the K-Planes interpolation now lives in quantem.cuda.core.ml (KPlanesTILTED is a quantem.core.ml model shared across applications) and the TV kernels in quantem.cuda.core. Update the dispatch imports and the test monkeypatch targets accordingly.
docs/notebooks/quantem_cuda_kernels.ipynb spells out exactly what the quantem-cuda TV kernels implement (tv_loss_sq_3d: unnormalized squared anisotropic sum, tv_vol parity; tv_loss_iso_3d: corner-restricted isotropic mean, eps inside the sqrt) and what they do not (no anisotropic-L1 kernel, no per-axis weights, iso needs D >= 2), with parity checks, a ptychography section (complex multislice objects, per-axis weighting recipe, honest op-level timings), the fused TILTED K-Planes interpolation contract, and the transparent-dispatch config surface. Executed on an RTX PRO 6000 so the outputs are visible in review. .gitignore gains an exemption for curated docs notebooks.
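As a reading aid, the two functionals named above can be written in plain torch. This is a sketch under stated assumptions, not the quantem-cuda source: `tv_sq_ref` and `tv_iso_ref` are hypothetical names, and the isotropic variant is assumed to be a mean over the common corner set that returns 0.0 when that set is empty.

```python
import torch


def tv_sq_ref(vol):
    """Unnormalized squared anisotropic TV: sum of squared forward differences."""
    dd = vol[..., 1:, :, :] - vol[..., :-1, :, :]
    dh = vol[..., :, 1:, :] - vol[..., :, :-1, :]
    dw = vol[..., :, :, 1:] - vol[..., :, :, :-1]
    return (dd ** 2).sum() + (dh ** 2).sum() + (dw ** 2).sum()


def tv_iso_ref(vol, eps=1e-8):
    """Isotropic TV restricted to corners where all three differences exist."""
    dd = (vol[..., 1:, :, :] - vol[..., :-1, :, :])[..., :, :-1, :-1]
    dh = (vol[..., :, 1:, :] - vol[..., :, :-1, :])[..., :-1, :, :-1]
    dw = (vol[..., :, :, 1:] - vol[..., :, :, :-1])[..., :-1, :-1, :]
    mag = torch.sqrt(dd ** 2 + dh ** 2 + dw ** 2 + eps)  # eps inside the sqrt
    if mag.numel() == 0:
        # Assumed behavior for D < 2 (or H, W < 2): empty corner set gives 0.0.
        return vol.new_zeros(())
    return mag.mean()


# Usage: a 2x2x2 ramp with steps of 4, 2, and 1 along D, H, and W.
ramp = torch.arange(8.0).reshape(2, 2, 2)
sq_value = tv_sq_ref(ramp)        # 4*16 + 4*4 + 4*1
iso_value = tv_iso_ref(ramp)      # single corner: sqrt(16 + 4 + 1 + eps)
single_slice = tv_iso_ref(torch.rand(1, 4, 4))
```

On a single slice the squared variant still reduces to a 2-D TV because the D-axis difference is simply empty, while the isotropic variant has no valid corners at all.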
…kernel ObjectConstraints._calc_tv_loss now routes fp32 CUDA 3-D arrays through quantem.cuda.core.tv_loss_l1_3d (per-axis |diff| sums), composing the same anisotropic-L1 functional — per-axis (tv_weight_z, tv_weight_xy) weights and active-axis normalization included — with identical gradients (sign(0) = 0). Weighted size-1 axes fall through to torch so degenerate inputs behave exactly as before; the kill switch and capability flag are the same as the other dispatch points. ~4x fwd+bwd at multislice (16, 1024, 1024) sizes, neutral at single-slice sizes.
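The composition described here (per-axis |diff| sums returned by the kernel, with weighting and normalization applied by the caller) can be sketched in plain torch. The names `tv_l1_axis_sums` and `weighted_tv_l1` are hypothetical, and the normalization shown (per-axis difference counts, then division by the number of active axes) is an illustrative assumption, not a copy of `_calc_tv_loss`.

```python
import torch


def tv_l1_axis_sums(vol):
    """Per-axis sums of |forward difference| for a [D, H, W] volume, shape (3,)."""
    return torch.stack([
        (vol[1:] - vol[:-1]).abs().sum(),
        (vol[:, 1:] - vol[:, :-1]).abs().sum(),
        (vol[:, :, 1:] - vol[:, :, :-1]).abs().sum(),
    ])


def weighted_tv_l1(vol, weight_z, weight_xy):
    """Caller-side weighting on top of the per-axis sums (assumed normalization)."""
    sums = tv_l1_axis_sums(vol)
    D, H, W = vol.shape
    # Number of forward differences along each axis; clamp avoids 0-division.
    counts = torch.tensor(
        [(D - 1) * H * W, D * (H - 1) * W, D * H * (W - 1)], dtype=vol.dtype
    ).clamp(min=1)
    weights = torch.tensor([weight_z, weight_xy, weight_xy], dtype=vol.dtype)
    active = (weights > 0).sum().clamp(min=1)
    return (weights * sums / counts).sum() / active


# Usage: a ramp along W only, so the D and H sums vanish.
ramp = torch.arange(3.0).expand(2, 2, 3)
axis_sums = tv_l1_axis_sums(ramp)
loss_value = weighted_tv_l1(ramp, weight_z=1.0, weight_xy=0.5)

# sign(0) = 0: a flat volume yields an all-zero gradient, not NaN.
flat = torch.zeros(2, 3, 3, requires_grad=True)
weighted_tv_l1(flat, 1.0, 0.5).backward()
```

Because the sums come back per axis, any weighting scheme can be applied afterwards without a second pass over the volume.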
…ity gate docs/notebooks/quantem_cuda_kernels.ipynb now documents tv_loss_l1_3d (the ptychography functional) with a verified parity composition and updated timings/dispatch tables. docs/notebooks/ptycho_tv_kernel_quality.ipynb reruns the tutorial ducky reconstruction (simulated dataset, single-slice xy TV and multislice z+xy TV) with and without the kernel and gates on SSIM of the reconstructed phase against an fp-perturbation chaos floor: the pipeline is bit-exact deterministic, so a 1-ulp TV rounding difference diverges trajectories exactly like a 2e-7 relative defocus perturbation does. Both configs pass (SSIM 0.997 vs floor 0.996 single-slice; 0.992 vs 0.988 multislice), with final data-fidelity losses agreeing within the floor's scatter.
The repo gitignores *.ipynb on purpose — drop the docs/notebooks exemption and the two committed notebooks. They remain available as a PR attachment for review.
What problem this PR addresses
The sibling quantem-cuda package (electronmicroscopy/quantem-cuda#2) provides fused CUDA kernels, with hand-written analytic backwards, for two operations that pure torch handles inefficiently. Neither is tomography-specific:

- KPlanesTILTED lives in quantem.core.ml.models.kplanes and is shared across applications (tomography object models are the current consumer). Its interpolate_ms_features_tilted (einsum rotate → gather → F.grid_sample → 3-plane Hadamard product) dominates every K-Planes training step (75–82 % of the GPU step in the tomography profile) and materializes a (3T, C, B) intermediate of up to several GB per batch just to multiply three planes pointwise. The kernel lives in quantem.cuda.core.ml, mirroring the model's home, so any K-Planes application picks it up.
- The TV kernels (quantem.cuda.core, shared kernels usable from any module; see the explicit inventory below).

This PR wires quantem to use those kernels when available, with zero impact otherwise:

- config.get("has_quantem_cuda") capability flag (try-import in core/config.py, mirroring has_torch/has_cupy) and a use_cuda_kernels config kill-switch (default true).
- interpolate_ms_features_tilted (core/ml/models/kplanes.py), used by every KPlanesTILTED forward in any application (data term, TV soft constraints, and create_volume decoding alike).
- tv_loss_vol_sq() dispatch helper in tomography/utils.py, used by ObjectPixelated.get_tv_loss.
- ObjectConstraints._calc_tv_loss (diffractive_imaging/object_models.py) dispatches the ptychography TV object constraint to the fused L1 kernel: same anisotropic-L1 functional, per-axis (tv_weight_z, tv_weight_xy) weights, and active-axis normalization, with identical gradients (sign(0) = 0). Weighted size-1 axes fall through to torch so degenerate inputs behave exactly as before.

Dispatch requires CUDA fp32 tensors + the package installed + the kill-switch on; the existing pure-torch code is unchanged and remains the fallback (CPU, other dtypes, kernels disabled, package absent).
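The dispatch conditions listed above amount to a small gate in front of each call site. The following is a minimal sketch of that pattern using only the standard library; `probe_module`, `can_dispatch`, `dispatch`, and the plain-dict config are hypothetical stand-ins for illustration, and the tensors are duck-typed fakes so the example runs without a GPU. Only the two config keys and the three conditions are taken from the description.

```python
import importlib
from types import SimpleNamespace


def probe_module(module_name):
    """Try-import capability probe: True only if the module imports cleanly."""
    try:
        importlib.import_module(module_name)
        return True
    except ImportError:
        return False


def can_dispatch(tensor, config):
    """Fused path only for CUDA fp32 tensors, package present, kill-switch on."""
    return bool(
        config.get("has_quantem_cuda", False)
        and config.get("use_cuda_kernels", True)
        and getattr(tensor, "is_cuda", False)
        and str(getattr(tensor, "dtype", "")) == "torch.float32"
    )


def dispatch(tensor, config, fused, fallback):
    """Route to the fused kernel when allowed, otherwise to the torch fallback."""
    return fused(tensor) if can_dispatch(tensor, config) else fallback(tensor)


# Duck-typed stand-ins for tensors (hypothetical; real code would pass torch.Tensor).
gpu_fp32 = SimpleNamespace(is_cuda=True, dtype="torch.float32")
gpu_fp16 = SimpleNamespace(is_cuda=True, dtype="torch.float16")
cpu_fp32 = SimpleNamespace(is_cuda=False, dtype="torch.float32")

enabled = {"has_quantem_cuda": True, "use_cuda_kernels": True}
killed = {"has_quantem_cuda": True, "use_cuda_kernels": False}
absent = {"has_quantem_cuda": probe_module("no_such_package_for_this_sketch")}

routes = {
    "gpu_fp32_enabled": dispatch(gpu_fp32, enabled, lambda t: "fused", lambda t: "torch"),
    "gpu_fp16_enabled": dispatch(gpu_fp16, enabled, lambda t: "fused", lambda t: "torch"),
    "cpu_fp32_enabled": dispatch(cpu_fp32, enabled, lambda t: "fused", lambda t: "torch"),
    "gpu_fp32_killed": dispatch(gpu_fp32, killed, lambda t: "fused", lambda t: "torch"),
    "gpu_fp32_absent": dispatch(gpu_fp32, absent, lambda t: "fused", lambda t: "torch"),
}
```

Keeping the gate in one predicate means each dispatch point shares the same fallback semantics, which is also what makes a dispatch-spy test straightforward to write.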
Exactly which TV kernels exist (and which don't)
quantem.cuda.core implements three TV functionals, all fp32/CUDA/real-only, shape [..., D, H, W] (ndim >= 3, leading dims flattened in), returning a differentiable scalar ((3,) for the per-axis L1 sums), torch.compile(fullgraph=True)-safe:

- tv_loss_sq_3d(volume): Σ(Δd)² + Σ(Δh)² + Σ(Δw)² (forward differences, full index range per axis). Matches tv_vol before its weight / numel scaling.
- tv_loss_iso_3d(volume, eps=1e-8): sqrt((Δd)² + (Δh)² + (Δw)² + eps) with all three differences anchored on the common corner set [0,D−2]×[0,H−2]×[0,W−2] (eps inside the sqrt).
- tv_loss_l1_3d(volume): per-axis Σ|forward diff|, the exact functional of ptychography's _calc_tv_loss. Returns (3,); weighting/normalization stay with the caller, so any per-axis scheme is exact; d|x|/dx = sign(x), sign(0) = 0.

Explicitly not implemented: per-axis weights inside a kernel (the L1 kernel's per-axis sums solve this in Python; a two-call recipe covers the squared variant), a 2-D-only variant (the anisotropic kernels on [1, H, W] degrade gracefully to pure-2D TV; tv_loss_iso_3d needs D >= 2, since its corner set is empty on a single slice and it returns 0.0), complex input (take .angle()/.abs() first), and fp16/bf16/fp64.

Notebooks (usage + reconstruction-level quality gate)
Both executed on an RTX PRO 6000 so outputs are reviewable; they are attached below as notebooks.zip rather than committed (the repo gitignores *.ipynb).

- quantem_cuda_kernels.ipynb documents the direct-call API (the kernels are ordinary public torch functions, not just dispatch targets):
  - a verified parity composition of tv_loss_l1_3d's per-axis sums into ptychography's _calc_tv_loss (weights, counts, active-axis normalization);
  - a ptychography section: the tv_weight_z/tv_weight_xy constraint now runs through the fused L1 kernel automatically; for different TV flavors, apply tv_loss_sq_3d/tv_loss_iso_3d to obj.angle() directly, with a verified two-call per-axis recipe for the squared variant. Op-level timings: the L1 kernel is ~4× the torch chain at (16, 1024, 1024) multislice and 256³ sizes (a few ms per iteration), neutral at single-slice sizes. That is a few-percent end-to-end win at large multislice sizes against an FFT-dominated iteration, and roughly nothing for small single-slice objects;
  - the fused TILTED K-Planes interpolation contract (pts [B,3] in [−1,1]³, rotations [T,3,3], plane [3T,C,H,W] → [B,T·C]; grid_sample(align_corners=True, padding_mode="border") parity, gradients to all three inputs) with a direct-call vs torch-chain comparison.
- ptycho_tv_kernel_quality.ipynb closes the loop at the reconstruction level: it reruns the quantem-tutorials ptycho_iter_02_full reconstruction on its simulated ducky dataset with TV constraints (single-slice xy TV, and multislice z+xy TV à la ptycho_iter_04_multislice), with and without the kernel, and gates on SSIM of the reconstructed phase. The pipeline is bit-exact deterministic, so the fair floor is a torch run with the defocus perturbed by 2×10⁻⁷ relative: a 1-ulp TV rounding difference diverges a deterministic trajectory exactly like a 1-ulp physics perturbation. Both configs pass: SSIM(torch, kernel) 0.997 vs floor 0.996 (single-slice), 0.992 vs 0.988 (multislice), final data-fidelity losses within the floor's scatter.

Results (RTX PRO 6000, KPlanesTILTED T=4/M=8/200³, batch 2048 × 300 samples/ray)
Packaging note: the cuda optional extra is intentionally not added yet. An unpublished package in [project.optional-dependencies] breaks uv lock (and the check-uv-lock gate) for everyone. Once quantem-cuda is published to PyPI this becomes a one-line follow-up; the soft-import dispatch in this PR works with any install of the package in the meantime.

What should the reviewer(s) do
Draft until the quantem-cuda PR lands; feedback welcome on the dispatch points, the config surface, and the notebook. Tests: tests/ml/test_kplanes_cuda_dispatch.py, tests/tomography/test_cuda_dispatch.py, and tests/diffractive_imaging/test_tv_cuda_dispatch.py (GPU/CPU parity incl. gradients, dispatch-spy, kill-switch; meaningful both with and without quantem-cuda installed). Full suite green in both environments.

…quantem_cuda_kernels.ipynb, attached below) has been updated.

notebooks.zip