fix(autobatching): detach completed states to stop InFlightAutoBatcher memory leak (UMA) #590
Draft
niklashoelter wants to merge 2 commits into
…memory leak

InFlightAutoBatcher accumulates every converged system in the caller's completed-states list for the whole run. A popped state preserves grad_fn, and some models return graph-carrying outputs (UMA's energy keeps requires_grad=True while its forces are already detached), so each completed state pins its swap's entire forward autograd graph. Across hundreds of in-flight swaps that is one retained graph per finished system (~tens of MB each), growing live (allocated, not cached) GPU memory monotonically until the device OOMs deep into a long optimization. empty_cache cannot reclaim it because the memory is allocated, not cached, which is why it surfaces as a late 'reserved but unallocated' fragmentation OOM.

Detach any grad-carrying tensors on completed states as they leave next_batch, before they are accumulated. Completed states are only read for their values, never differentiated, so dropping the graph is safe. Live memory then stays flat across the whole run.
Problem
Long optimization runs with InFlightAutoBatcher leak GPU memory and eventually OOM deep into the run, not because of batch sizing, but because completed states retain an autograd graph.

torch_sim.optimize accumulates every converged system (via InFlightAutoBatcher.next_batch → the caller's completed-states list) for the entire run so it can restore the original order at the end. A state popped out of the running batch preserves grad_fn, and some models return graph-carrying outputs, notably UMA, whose energy keeps requires_grad=True even though its forces are already detached. Each retained completed state therefore pins the entire forward autograd graph of the swap it finished in. Across hundreds of in-flight swaps that is one live graph per finished system, so torch.cuda.memory_allocated climbs monotonically (tens of MB per structure) until the device fills.

Because the leak is allocated memory (not cached), torch.cuda.empty_cache() cannot reclaim it, which is why it surfaces as a late "reserved but unallocated"-style OOM and why cache-clearing / allocator-config workarounds don't help.

Reproduced with uma-s-1p1 over a ~4300-structure molecular library (FIRE, D3 on): live memory climbs from ~13 GB to 24 GB over ~356 swaps and OOMs; with this fix it stays flat at ~13 GB across the entire run (validated to full completion over all structures, twice).

This is distinct from the LBFGS/BFGS history-buffer accumulation fixed in #559/#568: it affects the plain optimization loop (e.g. FIRE) via the model's own energy tensor.
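To illustrate the retention mechanism described above, here is a minimal CPU-only sketch. It does not use torch_sim or UMA: `toy_energy`, `weights`, and the `completed` list are hypothetical stand-ins for a model that returns a graph-carrying energy and for the caller's completed-states list.

```python
import torch

# Hypothetical stand-in for a model whose energy output stays attached
# to the autograd graph (as described for UMA above).
weights = torch.randn(64, 64, requires_grad=True)
positions = torch.randn(8, 64)


def toy_energy(pos: torch.Tensor) -> torch.Tensor:
    hidden = torch.tanh(pos @ weights)  # activations are saved for backward
    return hidden.pow(2).sum(dim=1)     # one "energy" per system


completed = []  # plays the role of the caller's completed-states list
for _ in range(3):  # each iteration mimics one swap
    energy = toy_energy(positions)
    # Indexing out one system keeps grad_fn, so the stored scalar keeps
    # that iteration's whole forward graph reachable.
    completed.append(energy[0])

graph_retained = all(e.grad_fn is not None for e in completed)

# Detaching yields tensors with the same values but no graph reference.
detached = [e.detach() for e in completed]
graph_dropped = all(d.grad_fn is None and not d.requires_grad for d in detached)
values_unchanged = all(torch.equal(e, d) for e, d in zip(completed, detached))
```

Each stored scalar is tiny, but while it holds a `grad_fn` the activations saved for backward in its forward pass cannot be freed. That matches why the growth shows up in allocated rather than cached memory.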
Fix
torch_sim/autobatching.py only: add a small _detach_state_graph() helper and call it on completed states as they leave InFlightAutoBatcher.next_batch, before they are accumulated. Completed states are only ever read for their values (never differentiated), so dropping the graph is safe and changes no numerical results.
Tests
tests/test_autobatching.py: test_detach_state_graph_drops_grad_but_keeps_values asserts grad-carrying tensors are detached in place, non-grad tensors are left untouched, and values are unchanged. Full autobatching suite green; ruff + ty clean.
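For readers without the diff at hand, the following is a rough sketch of what a helper of this kind might look like, together with checks in the spirit of the test described above. It is not the PR's actual `_detach_state_graph()` implementation: `ToyState` and `detach_state_graph_sketch` are hypothetical names, and the real state class in torch_sim may store its tensors differently.

```python
from dataclasses import dataclass

import torch


@dataclass
class ToyState:
    """Hypothetical minimal state; not the torch_sim state class."""

    positions: torch.Tensor
    energy: torch.Tensor
    forces: torch.Tensor


def detach_state_graph_sketch(state):
    """Replace grad-carrying tensor attributes with detached versions.

    Assumption: tensors live as plain instance attributes on the state.
    """
    for name, value in list(vars(state).items()):
        if isinstance(value, torch.Tensor) and value.requires_grad:
            # detach() shares storage with the original but drops grad_fn.
            setattr(state, name, value.detach())
    return state


w = torch.ones(3, requires_grad=True)
state = ToyState(
    positions=torch.zeros(2, 3),
    energy=(w * 2.0).sum(),      # carries a grad_fn, like the energy above
    forces=torch.zeros(2, 3),    # already graph-free
)
forces_before = state.forces
energy_before = state.energy.item()

detach_state_graph_sketch(state)
```

After the call, `state.energy` no longer references a graph, `state.forces` is the same tensor object as before, and the energy value is unchanged. The sketch reassigns attributes rather than calling `detach_()` because PyTorch does not allow in-place detaching of view tensors, and a state sliced out of a batch may well hold views.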