perf: Speed up the small-model caller: call the model directly instead of Model.predict()#1089
perf: Speed up the small-model caller: call the model directly instead of Model.predict()#1089tfenne wants to merge 1 commit into
Model.predict()#1089Conversation
When --call_small_model_examples is set, make_examples runs a small Keras MLP once per region to pre-call easy candidates. That inference went through keras.Model.predict(), which rebuilds a data adapter, a CallbackList, and the prediction loop on every call. For the small per-region batches the model sees, that per-call scaffolding -- not the arithmetic -- dominates its runtime. Calling the model directly (classifier(examples, training=False)) runs the same cached forward pass without that per-call machinery. The result is numerically identical: the model is a plain Dense MLP with no train/inference-divergent layers (no BatchNorm, no Dropout), so __call__(training=False) and predict() compute the same forward pass over the same weights, leaving the downstream confidence gating unchanged. Because each region now runs in a single forward pass, --small_model_inference_batch_size no longer affects inference; its help text is updated to mark it deprecated/ignored but it is retained for command-line compatibility. Measured on chr20 (16 shards, c8a.4xlarge): make_examples wall time with --call_small_model_examples drops from 231.8s to 104.7s for WGS (-55%), with smaller gains on long-read data whose per-region pileup work dominates (PacBio -7%, ONT roughly flat). examples_written and the small-model call counts are identical before and after on all three datasets. A new ClassifyEquivalenceTest builds a Dense+softmax MLP of the small model's shape and asserts classify() matches Model.predict() (same shape, dtype, and arg-max; probabilities allclose to 1e-6) for batch sizes from 1 to 300.
|
Hi @tfenne , Thanks for the PR! Since I believe you're already familiar with our process, I'll go ahead and start the review. As a reminder, because of the way our project is set up, we aren't able to merge GitHub PRs directly. If the changes look good, I will commit them, crediting your GitHub username and referencing this PR in the commit description. Please let me know if you have any concerns with this approach. -pichuan |
|
A quick update for you @tfenne : I tested on our regular I confirmed the change is correctness-safe: hap.py metrics (F1, precision, recall) are identical across all configs. I'm going to test on a different machine type ( |
|
Thanks @pichuan - I think the speedup I'm seeing should be constant with respect to the number of cores. IIRC the small model predict() is called once per region, and I think the number of regions is not different depending on the number of cores, but I will freely admit I don't know the codebase very well, and I could have misunderstood. |
|
Oh one other thought. I've been doing most of my runs and benchmarking on AMD-based instances at AWS where vCPUs = physical CPU cores. I don't know as much about GCP, but I see that |
|
Well, it's not SMT vs. no-SMT. I tried running my same benchmark on matched c8a and c8i instances at AWS, and while the c8i instance is overall quite a bit slower, the r1.10 release vs. this branch both showed approximately 55% speedup (56% this time on AMD, 52% on intel). I'll try and find time to replicate my findings on the whole genome (not just chr20) on a 96 core instance and post back here. The other place there is possibly room for error is if I didn't faithfully replicate what the wrapper script is doing to call make_examples. Here's what I'm doing: |
|
For context, when I tested, I built a docker image and ran with docker, because that's how our standard release runtime metrics are evaluated. |
|
Hi Tim (@tfenne) and Pi-Chuan (@pichuan), A very neat idea indeed, and maybe this might help. It's been a while since I've looked that the code in depth, but the general idea is this. The number of shards determine the number of regions to process, and are ideally uniformly distributed to each CPU by the operating system, given the The number of instantiations calls of There are some obvious caching and performance opportunities via a shared-memory model (see my previous post #650), opening the door to further modularize DeepVariant's pipeline steps to optimize for HPC/GPU-branching with the flexibility of plugging in models from other frameworks (i.e. PyTorch, etc). Currently, if you pause the execution within each step, that will allow you to explore and fine-tune optimization opportunities as guided via a profiler . In any case, below is the code I mentioned above: import time
start_time = time.time()
from functools import cache
import sys
from keras.src import callbacks as callbacks_module
import keras
verbose=True
epochs=100
@cache
def callback_routine():
callbacks = callbacks_module.CallbackList(
callbacks=None,
add_history=True,
add_progbar=verbose != 0,
verbose=verbose,
epochs=epochs,
steps=16,
model=keras.Model(),
)
def main():
rounds = int( sys.argv[1] )
for i in range(0, rounds):
callback_routine()
print( f"{(time.time() - start_time):0.3} seconds" )
if __name__ == "__main__":
main()Hope it helps, |
|
(Sorry, the long comment I posted was meant for #1086 . I'll move it there) |
|
Hi @tfenne , Here is an update on my testing. I tested with the same approach described in #1086 (comment) Unfortunately, on The remaining runs haven't all finished, so I don't know if anything else will run OOM as well. I suspect it's the change in Based on this finding, I cannot recommend this PR to be reviewed or submitted internally at its current form. Let me know if you want to make any changes to fix the OOM, and I am happy to test again. |
Summary
When
--call_small_model_examplesis enabled,make_examplesruns a small Keras MLP once per region to pre-call easy candidates. That inference went throughkeras.Model.predict(), which rebuilds a data adapter, aCallbackList, and the prediction loop on every call. For the tiny per-region batches the small model sees, that per-call scaffolding — not the arithmetic — dominates its runtime.This replaces the
predict()call with a direct model call, which runs the same cached forward pass without the per-call machinery:Why it's output-preserving
The small model is a plain
DenseMLP ending insoftmax, with no train/inference-divergent layers (no BatchNorm, no Dropout), so__call__(training=False)andpredict()compute the identical forward pass over the same weights. The downstream gating (passes_confidence_threshold→_accept_call_result) consumes the same probability rows, so every genotype decision is preserved.A new test (
ClassifyEquivalenceTest) builds aDense+softmaxMLP of the small model's shape and assertsclassify()matchesmodel.predict()— same shape, same dtype, identical arg-max, probabilitiesallcloseto 1e-6 — across batch sizes from 1 to 300 (i.e. spanningpredict's internal 128-row batching boundary).Verified on the production models over every real chr20 example
As a direct check (not just the synthetic unit test),
classify()was instrumented to run bothpredict()and__call__()on every batch, using the real productionmodel.keras, across all of chr20:Across >1,000,000 real candidate examples the two paths produce zero genotype-call differences. WGS is bit-for-bit identical; on long-read data ~20 rows total differ by a single float32 ULP (1.19e-07) — an artifact of XLA choosing a slightly different matmul tiling for the larger long-read batches — but the resulting GQ difference is below 3e-4 phred, far too small to move any integer GQ threshold. So the change is decision-identical on real data.
(The per-call counts also explain the platform-dependent speedup: WGS made 44,133
classify()calls for its 222,978 rows (~5 rows/call), versus ~2,560 calls for PacBio/ONT (~120–190 rows/call). The per-callpredict()overhead is paid ~17× more often on WGS, which is why WGS gains ~55% while long-read — where each call already amortizes the overhead over a large batch — gains little.)Measured impact (chr20, 16 shards, c8a.4xlarge / 16 cores)
make_exampleswall time with--call_small_model_examples, 2 reps each:The WGS gain is large because short-read per-region work is cheap, so the fixed per-region
predict()overhead dominatedmake_examples. Long-read pileup work dominates PacBio/ONT, so the same overhead is a smaller fraction — never a regression.Output is unchanged:
examples_writtenand the number of small-model calls are identical between baseline and new on all three datasets (WGS 79759 / 143219, PacBio 125483 / 173744, ONT 134755 / 347477); the compressed CVO output differs only within the run-to-run non-determinism floor.Notes
--small_model_inference_batch_sizeno longer affects inference. Its help text is updated to mark it deprecated/ignored, but the flag is retained for command-line compatibility (no removal, no breakage).Files
deepvariant/small_model/inference.py—classify()uses__call__.deepvariant/small_model/inference_test.py— stub__call__instead of.predict; add the equivalence test.deepvariant/make_examples_options.py— deprecate the now-inert batch-size flag's help text.