Device-resident batch sampling for INR tomography reconstruction #244

Closed
cedriclim1 wants to merge 2 commits into electronmicroscopy:dev from cedriclim1:feat/gpu-batch-sampler
@cedriclim1

Collaborator

What problem this PR addresses

Single-GPU INR tomography training is CPU-bound on the data path: on every step the DataLoader makes batch_size Python __getitem__ calls (one per pixel), collates them, and copies the batch host→device. That costs ~43 ms/batch at batch 8192, while the GPU has only ~20 ms of work, so the GPU idles more than half of every step (2 % utilization measured on a 300³ KPlanesTILTED reconstruction).

A training batch here is just pixel indices plus lookups into the tilt stack, so this PR adds DeviceBatchSampler (tomography/dataset_models.py): the tilt stack and angles are made resident on the compute device once, and each batch is built with index arithmetic and two tensor lookups — same batch dicts, same shuffle/drop_last/val-split semantics as the DataLoader path. Tomography.reconstruct routes INR datasets through it via _setup_recon_dataloaders; non-INR datasets keep the DataLoader path.
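As a rough illustration of "index arithmetic and two tensor lookups", here is a minimal plain-Python sketch. It is not the code in this PR: the function names (`decode_pixel`, `build_batch`), the batch keys, and the row-major pixel layout are assumptions made for the example, and nested lists stand in for the device-resident tensors.

```python
def decode_pixel(flat_idx, height, width):
    """Map a flat pixel index to (tilt, row, col) with integer arithmetic.

    Assumes pixels are numbered row-major within a tilt, tilt after tilt.
    """
    tilt, rem = divmod(flat_idx, height * width)
    row, col = divmod(rem, width)
    return tilt, row, col


def build_batch(flat_indices, tilt_stack, angles):
    """Build one batch dict from flat pixel indices (hypothetical keys)."""
    height = len(tilt_stack[0])
    width = len(tilt_stack[0][0])
    decoded = [decode_pixel(i, height, width) for i in flat_indices]
    return {
        "tilt_idx": [t for t, _, _ in decoded],
        "pixel_rc": [(r, c) for _, r, c in decoded],
        # Lookup 1: measured intensity of each sampled pixel.
        "target": [tilt_stack[t][r][c] for t, r, c in decoded],
        # Lookup 2: tilt angle of the projection each pixel belongs to.
        "angle": [angles[t] for t, _, _ in decoded],
    }


# Toy stack: 2 tilts of 2x3 pixels, kept in memory once and reused per batch.
tilt_stack = [
    [[0.0, 0.1, 0.2], [0.3, 0.4, 0.5]],
    [[1.0, 1.1, 1.2], [1.3, 1.4, 1.5]],
]
angles = [-30.0, 30.0]

batch = build_batch([0, 4, 7, 11], tilt_stack, angles)
print(batch["target"])  # [0.0, 0.4, 1.1, 1.5]
print(batch["angle"])   # [-30.0, -30.0, 30.0, 30.0]
```

With tensors, the two list comprehensions marked as lookups would become two advanced-indexing operations on arrays that already live on the compute device, so no per-pixel Python call or host→device copy is needed per step.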

Distributed runs are sharded with DistributedSampler semantics: every rank derives the same seeded epoch permutation (seed + epoch, CPU generator, reproducible) and takes an equal-size contiguous shard — equal so per-rank batch counts always match and DDP gradient sync cannot hang on a ragged tail. The training loop's existing sampler.set_epoch call drives reshuffling. The train/val split now uses a fixed-seed generator, making it identical across ranks (no train/val leakage between ranks) and stable across save/reload, so resumed runs keep validating on the same held-out pixels.
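The sharding and split rules above can be sketched as follows. This is a plain-Python illustration under stated assumptions, not the PR's implementation: `random.Random` stands in for the seeded CPU generator, all function names and the `split_seed` default are hypothetical, and the ragged tail is dropped as described in the commit message.

```python
import random


def epoch_permutation(num_items, seed, epoch):
    """Every rank calls this with the same (seed, epoch) and gets the same order."""
    rng = random.Random(seed + epoch)
    perm = list(range(num_items))
    rng.shuffle(perm)
    return perm


def rank_shard(perm, rank, world_size):
    """Equal-size contiguous shard; leftover items at the end are dropped."""
    per_rank = len(perm) // world_size
    start = rank * per_rank
    return perm[start:start + per_rank]


def train_val_split(num_items, val_fraction, split_seed=0):
    """Fixed-seed split, so all ranks and resumed runs hold out the same items."""
    rng = random.Random(split_seed)
    order = list(range(num_items))
    rng.shuffle(order)
    n_val = int(num_items * val_fraction)
    return order[n_val:], order[:n_val]  # (train, val)


# 10 items over 3 ranks: each rank gets 3 items and 1 item is dropped.
perm = epoch_permutation(10, seed=0, epoch=3)
shards = [rank_shard(perm, r, world_size=3) for r in range(3)]
print([len(s) for s in shards])  # [3, 3, 3]

train_ids, val_ids = train_val_split(10, val_fraction=0.2)
print(len(train_ids), len(val_ids))  # 8 2
```

Because each shard has the same length, every rank runs the same number of batches per epoch, which is the property that keeps gradient synchronization from waiting on a rank that has already finished.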

Also included: the validation loop now runs with bf16 autocast disabled, matching the training pass — the SO(3) pose solve (lu_factor) has no BFloat16 kernel, and a bf16 val loss is not comparable to the fp32 train loss it is checked against.

Results (RTX PRO 6000, KPlanesTILTED, 41×300² tilts, batch 2048, 300 samples/ray): 24.6 → 13.3 s/epoch on its own (1.85×); combined with the quantem-cuda kernel dispatch of #243, 24.6 → 5.8 s/epoch (4.24× — a 300-iteration reconstruction in ~29 min instead of ~2 h). GPU utilization during a full Au-phantom reconstruction went from 2 % to 82 %. Reconstruction quality is unchanged (same uniform sampling; verified end-to-end on the Au-phantom quality gate together with #243: detection F1 0.858 vs 0.832 baseline, precision 1.000 both, PSNR within 0.2 dB).

Independent of quantem-cuda — this PR is pure torch and stands alone.

What should the reviewer(s) do

Review the sampler semantics (shuffle / drop_last / val split / DDP sharding) and the reconstruct wiring. Tests in tests/tomography/test_device_batch_sampler.py: batch content identical to __getitem__ decode, exactly-once epoch coverage, disjoint equal DDP shards, shared reproducible epoch permutations, and a reconstruct smoke test; DDP additionally verified with a 2-GPU torchrun run (equal batch counts, identical reduced losses, no deadlock). Full suite green.

  • This PR introduces a public-facing change (e.g., figures, CLI input/output, API).
  • This PR affects internal functionality only (no user-facing change — same batches, faster; num_workers is ignored on the sampler path).
    • For functional and algorithmic changes, tests are written or updated.

The per-pixel DataLoader path costs ~43 ms/batch on CPU (8192 Python
__getitem__ calls + collate + H2D copy) while the GPU step itself takes
~20 ms, so single-GPU reconstructions idle the GPU half the time.
DeviceBatchSampler keeps the tilt stack resident on the device and builds
each batch with index arithmetic and two tensor lookups, yielding the same
batch dicts (and drop_last/val-split semantics) as the DataLoader path.
DDP runs keep the DataLoader + DistributedSampler path.

Also disable bf16 autocast in the validation loop to match the training
pass: the so3 pose solve (lu_factor) has no BFloat16 kernel, and a bf16
val loss is not comparable to the fp32 train loss.

Every rank derives the same epoch permutation from seed + epoch (CPU
generator, identical across ranks and reproducible) and takes an
equal-size contiguous shard, with the ragged tail dropped so per-rank
batch counts always match and gradient sync cannot hang. The training
loop's existing sampler.set_epoch call drives reshuffling, exactly like
DistributedSampler; single-process runs auto-advance the epoch instead.

The train/val split now uses a fixed-seed generator: identical across
ranks (no leakage between a rank's train shard and another's val shard)
and stable across save/reload, so resumed runs keep validating on the
same held-out pixels.

Verified with a 2-GPU torchrun run: equal batch counts on both ranks,
finite identical reduced losses, no deadlock.
@cedriclim1

Collaborator Author

Closing this in favour of landing on the fork's consolidation branch (feat/tomography-inr-fixes) before submitting a single combined PR upstream.

cedriclim1 closed this Jun 10, 2026
cedriclim1 deleted the feat/gpu-batch-sampler branch June 10, 2026 22:21