Device-resident batch sampling for INR tomography reconstruction #244
Closed
cedriclim1 wants to merge 2 commits into
The per-pixel DataLoader path costs ~43 ms/batch on CPU (8192 Python __getitem__ calls + collate + H2D copy) while the GPU step itself takes ~20 ms, so single-GPU reconstructions idle the GPU half the time. DeviceBatchSampler keeps the tilt stack resident on the device and builds each batch with index arithmetic and two tensor lookups, yielding the same batch dicts (and drop_last/val-split semantics) as the DataLoader path. DDP runs keep the DataLoader + DistributedSampler path. Also disable bf16 autocast in the validation loop to match the training pass: the so3 pose solve (lu_factor) has no BFloat16 kernel, and a bf16 val loss is not comparable to the fp32 train loss.
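The validation change described above can be sketched as follows. This is a minimal illustration, not the PR's code: `validation_loss`, its arguments, and the `(inputs, targets)` batch layout are hypothetical. The only library feature relied on is that `torch.autocast(..., enabled=False)` turns autocast off inside its block even when an outer autocast region is active.

```python
import torch


def validation_loss(model, batches, loss_fn, device_type="cpu"):
    """Mean loss over `batches` with autocast disabled (hypothetical helper)."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        # enabled=False overrides any enclosing autocast region, so the forward
        # pass and the loss run in the dtype of the inputs and parameters.
        with torch.autocast(device_type=device_type, enabled=False):
            for inputs, targets in batches:
                loss = loss_fn(model(inputs), targets)
                total += loss.item()
                count += 1
    return total / max(count, 1)
```

With fp32 parameters and inputs, the value returned here is computed in fp32 and can be compared directly with an fp32 training loss.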
Every rank derives the same epoch permutation from seed + epoch (CPU generator, identical across ranks and reproducible) and takes an equal-size contiguous shard, with the ragged tail dropped so per-rank batch counts always match and gradient sync cannot hang. The training loop's existing sampler.set_epoch call drives reshuffling, exactly like DistributedSampler; single-process runs auto-advance the epoch instead. The train/val split now uses a fixed-seed generator: identical across ranks (no leakage between a rank's train shard and another's val shard) and stable across save/reload, so resumed runs keep validating on the same held-out pixels. Verified with a 2-GPU torchrun run: equal batch counts on both ranks, finite identical reduced losses, no deadlock.
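The sharding rule above can be sketched as follows. This is an illustration under stated assumptions, not the PR's implementation: `rank_shard` and its parameters are hypothetical names, and it assumes `drop_last` batching inside each shard.

```python
import torch


def rank_shard(num_items, batch_size, seed, epoch, rank, world_size):
    """Return this rank's index batches for one epoch (hypothetical helper)."""
    # The same seed + epoch on every rank yields the same CPU permutation.
    gen = torch.Generator(device="cpu")
    gen.manual_seed(seed + epoch)
    perm = torch.randperm(num_items, generator=gen)

    # Equal-size contiguous shards; indices past world_size * per_rank are dropped.
    per_rank = num_items // world_size
    shard = perm[rank * per_rank:(rank + 1) * per_rank]

    # Assumption: drop_last batching, so every rank gets the same batch count.
    num_batches = per_rank // batch_size
    return [shard[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]
```

Because every rank computes the full permutation locally, no communication is needed to agree on the shards; changing `epoch` is what reshuffles.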
Closing this in favour of landing on the fork's consolidation branch (feat/tomography-inr-fixes) before submitting a single combined PR upstream.
What problem this PR addresses
Single-GPU INR tomography training is CPU-bound on the data path: at each step the DataLoader makes `batch_size` Python `__getitem__` calls (one per pixel), collates them, and copies the batch host→device. That costs ~43 ms/batch at batch 8192, while the GPU has only ~20 ms of work, so the GPU idles more than half of every step (2 % utilization measured on a 300³ KPlanesTILTED reconstruction).

A training batch here is just pixel indices plus lookups into the tilt stack, so this PR adds `DeviceBatchSampler` (`tomography/dataset_models.py`): the tilt stack and angles are made resident on the compute device once, and each batch is built with index arithmetic and two tensor lookups. It yields the same batch dicts and the same shuffle/drop_last/val-split semantics as the DataLoader path. `Tomography.reconstruct` routes INR datasets through it via `_setup_recon_dataloaders`; non-INR datasets keep the DataLoader path.

Distributed runs are sharded with `DistributedSampler` semantics: every rank derives the same seeded epoch permutation (seed + epoch, CPU generator, reproducible) and takes an equal-size contiguous shard. The shards are equal so that per-rank batch counts always match and DDP gradient sync cannot hang on a ragged tail. The training loop's existing `sampler.set_epoch` call drives reshuffling. The train/val split now uses a fixed-seed generator, making it identical across ranks (no train/val leakage between ranks) and stable across save/reload, so resumed runs keep validating on the same held-out pixels.

Also included: the validation loop now runs with bf16 autocast disabled, matching the training pass. The SO(3) pose solve (`lu_factor`) has no BFloat16 kernel, and a bf16 val loss is not comparable to the fp32 train loss it is checked against.

Results (RTX PRO 6000, KPlanesTILTED, 41×300² tilts, batch 2048, 300 samples/ray): 24.6 → 13.3 s/epoch on its own (1.85×); combined with the quantem-cuda kernel dispatch of #243, 24.6 → 5.8 s/epoch (4.24×, i.e. a 300-iteration reconstruction in ~29 min instead of ~2 h). GPU utilization during a full Au-phantom reconstruction went from 2 % to 82 %. Reconstruction quality is unchanged (same uniform sampling; verified end-to-end on the Au-phantom quality gate together with #243: detection F1 0.858 vs 0.832 baseline, precision 1.000 both, PSNR within 0.2 dB).

Independent of quantem-cuda: this PR is pure torch and stands alone.
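For reviewers unfamiliar with the idea, the "index arithmetic and two tensor lookups" batch construction can be sketched as below. This is a simplified illustration, not the code in `tomography/dataset_models.py`: `build_batch`, the dict keys, the `(T, H, W)` layout, and the angle range are all assumptions made for the example.

```python
import torch


def build_batch(tilt_stack, angles, flat_idx):
    """Decode flat pixel indices into a batch dict on the device (hypothetical).

    tilt_stack: (T, H, W) tensor, angles: (T,) tensor, flat_idx: (B,) int64
    tensor, all on the same device. The dict keys are illustrative only.
    """
    _, height, width = tilt_stack.shape
    pixels_per_tilt = height * width
    # Index arithmetic: flat index -> (tilt, row, col), no Python loop per pixel.
    tilt = torch.div(flat_idx, pixels_per_tilt, rounding_mode="floor")
    rem = flat_idx - tilt * pixels_per_tilt
    row = torch.div(rem, width, rounding_mode="floor")
    col = rem - row * width
    return {
        "tilt_idx": tilt,
        "row": row,
        "col": col,
        "angle": angles[tilt],                 # tensor lookup 1
        "target": tilt_stack[tilt, row, col],  # tensor lookup 2
    }


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tilt_stack = torch.rand(41, 300, 300, device=device)      # made resident once
angles = torch.linspace(-60.0, 60.0, 41, device=device)   # illustrative angles
perm = torch.randperm(tilt_stack.numel(), device=device)  # one epoch's shuffle
batch = build_batch(tilt_stack, angles, perm[:2048])
```

Each batch is then a handful of vectorised device operations, with no per-pixel Python call, collate step, or host→device copy.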
What should the reviewer(s) do
Review the sampler semantics (shuffle / drop_last / val split / DDP sharding) and the `reconstruct` wiring. Tests in `tests/tomography/test_device_batch_sampler.py`: batch content identical to `__getitem__` decode, exactly-once epoch coverage, disjoint equal DDP shards, shared reproducible epoch permutations, and a reconstruct smoke test; DDP additionally verified with a 2-GPU `torchrun` run (equal batch counts, identical reduced losses, no deadlock). Full suite green. (`num_workers` is ignored on the sampler path.)