Skip to content

fix(autobatching): detach completed states to stop InFlightAutoBatcher memory leak (UMA)#590

Draft
niklashoelter wants to merge 2 commits into
TorchSim:mainfrom
niklashoelter:fix/uma-inflight-energy-graph-leak
Draft

fix(autobatching): detach completed states to stop InFlightAutoBatcher memory leak (UMA)#590
niklashoelter wants to merge 2 commits into
TorchSim:mainfrom
niklashoelter:fix/uma-inflight-energy-graph-leak

Conversation

@niklashoelter

Copy link
Copy Markdown
Contributor

Problem

Long optimization runs with InFlightAutoBatcher leak GPU memory and eventually OOM deep into the run — not because of batch sizing, but because completed states retain an autograd graph.

torch_sim.optimize accumulates every converged system (via InFlightAutoBatcher.next_batch → the caller's completed-states list) for the entire run so it can restore the original order at the end. A state popped out of the running batch preserves grad_fn, and some models return graph-carrying outputs — notably UMA, whose energy keeps requires_grad=True even though its forces are already detached. Each retained completed state therefore pins the entire forward autograd graph of the swap it finished in. Across hundreds of in-flight swaps that is one live graph per finished system, so torch.cuda.memory_allocated climbs monotonically (tens of MB per structure) until the device fills.

Because the leak is allocated memory (not cached), torch.cuda.empty_cache() cannot reclaim it, which is why it surfaces as a late "reserved but unallocated"-style OOM and why cache-clearing / allocator-config workarounds don't help.

Reproduced with uma-s-1p1 over a ~4300-structure molecular library (FIRE, D3 on): live memory climbs from ~13 GB to 24 GB over ~356 swaps and OOMs; with this fix it stays flat at ~13 GB across the entire run (validated to full completion over all structures, twice).

This is distinct from the LBFGS/BFGS history-buffer accumulation fixed in #559/#568 — it affects the plain optimization loop (e.g. FIRE) via the model's own energy tensor.

Fix

torch_sim/autobatching.py only: add a small _detach_state_graph() helper and call it on completed states as they leave InFlightAutoBatcher.next_batch, before they are accumulated. Completed states are only ever read for their values (never differentiated), so dropping the graph is safe and changes no numerical results.

for completed_state in completed_states:
    _detach_state_graph(completed_state)

Tests

tests/test_autobatching.py: test_detach_state_graph_drops_grad_but_keeps_values — asserts grad-carrying tensors are detached in place, non-grad tensors are left untouched, and values are unchanged. Full autobatching suite green; ruff + ty clean.

…memory leak

InFlightAutoBatcher accumulates every converged system in the caller's
completed-states list for the whole run. A popped state preserves grad_fn,
and some models return graph-carrying outputs - UMA's energy keeps
requires_grad=True while its forces are already detached - so each completed
state pins its swap's entire forward autograd graph. Across hundreds of
in-flight swaps that is one retained graph per finished system (~tens of MB
each), growing live (allocated, not cached) GPU memory monotonically until
the device OOMs deep into a long optimization. empty_cache cannot reclaim it
because the memory is allocated, not cached, which is why it surfaces as a
late 'reserved but unallocated' fragmentation OOM.

Detach any grad-carrying tensors on completed states as they leave
next_batch, before they are accumulated. Completed states are only read for
their values, never differentiated, so dropping the graph is safe. Live
memory then stays flat across the whole run.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant