perf: Speed up the small-model caller: call the model directly instead of Model.predict() #1089

Open: tfenne wants to merge 1 commit into google:r1.10 from tfenne:sm_inference-call

tfenne commented Jun 23, 2026

Summary

When --call_small_model_examples is enabled, make_examples runs a small Keras MLP once per region to pre-call easy candidates. That inference went through keras.Model.predict(), which rebuilds a data adapter, a CallbackList, and the prediction loop on every call. For the tiny per-region batches the small model sees, that per-call scaffolding — not the arithmetic — dominates its runtime.

This replaces the predict() call with a direct model call, which runs the same cached forward pass without the per-call machinery:

# deepvariant/small_model/inference.py
def classify(classifier, examples, batch_size):
  del batch_size  # each region's candidates run in a single forward pass
  return np.asarray(classifier(examples, training=False))

Why it's output-preserving

The small model is a plain Dense MLP ending in softmax, with no train/inference-divergent layers (no BatchNorm, no Dropout), so __call__(training=False) and predict() compute the identical forward pass over the same weights. The downstream gating (passes_confidence_threshold_accept_call_result) consumes the same probability rows, so every genotype decision is preserved.

A new test (ClassifyEquivalenceTest) builds a Dense+softmax MLP of the small model's shape and asserts classify() matches model.predict() — same shape, same dtype, identical arg-max, probabilities allclose to 1e-6 — across batch sizes from 1 to 300 (i.e. spanning predict's internal 128-row batching boundary).
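The shape of that check can be sketched without Keras. The block below is a minimal, framework-free illustration, not the PR's test: `forward`, `predict_batched`, and `check_equivalence` are hypothetical names, the single dense layer is a stand-in for the real MLP, and the 128-row chunk size is borrowed from the batching boundary mentioned above. It compares a single forward pass against a chunked pass on row counts either side of that boundary.

```python
import math
import random

BATCH = 128  # chunk size used by the batched path in this sketch


def softmax(row):
    """Numerically stable softmax over one row of logits."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def forward(weights, bias, examples):
    """One dense layer followed by softmax, applied row by row."""
    out = []
    for x in examples:
        logits = [
            sum(xi * w for xi, w in zip(x, col)) + b
            for col, b in zip(weights, bias)
        ]
        out.append(softmax(logits))
    return out


def predict_batched(weights, bias, examples, batch_size=BATCH):
    """Stand-in for the batched path: run fixed-size chunks and concatenate."""
    rows = []
    for start in range(0, len(examples), batch_size):
        rows.extend(forward(weights, bias, examples[start:start + batch_size]))
    return rows


def argmax(row):
    return max(range(len(row)), key=row.__getitem__)


def check_equivalence(n_rows, n_features=8, n_classes=3, seed=0):
    """Return True when both paths agree on shape, arg-max and probabilities."""
    rng = random.Random(seed)
    weights = [[rng.uniform(-1, 1) for _ in range(n_features)]
               for _ in range(n_classes)]
    bias = [rng.uniform(-1, 1) for _ in range(n_classes)]
    examples = [[rng.uniform(-1, 1) for _ in range(n_features)]
                for _ in range(n_rows)]
    single = forward(weights, bias, examples)
    batched = predict_batched(weights, bias, examples)
    if len(single) != len(batched):
        return False
    for a, b in zip(single, batched):
        if argmax(a) != argmax(b):
            return False
        if any(not math.isclose(p, q, abs_tol=1e-6) for p, q in zip(a, b)):
            return False
    return True


# Row counts chosen to sit below, on, and above the 128-row chunk boundary.
results = {n: check_equivalence(n) for n in (1, 5, 127, 128, 129, 300)}
```

Because each row is computed independently here, the two paths agree exactly; the tolerance only matters for a real backend, where batch shape can change how the matmul is scheduled.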

Verified on the production models over every real chr20 example

As a direct check (not just the synthetic unit test), classify() was instrumented to run both predict() and __call__() on every batch, using the real production model.keras, across all of chr20:

| model | examples compared | genotype-call (arg-max) changes | bit-identical probabilities | max prob diff | max GQ (phred) diff |
|---|---|---|---|---|---|
| WGS | 222,978 | 0 | 100.0000% | 0 | 0 |
| PacBio | 299,227 | 0 | 99.9983% (5 rows differ) | 1.19e-07 | 2.8e-04 |
| ONT | 482,232 | 0 | 99.9969% (15 rows differ) | 1.19e-07 | 8.7e-05 |

Across >1,000,000 real candidate examples the two paths produce zero genotype-call differences. WGS is bit-for-bit identical; on long-read data ~20 rows total differ by a single float32 ULP (1.19e-07) — an artifact of XLA choosing a slightly different matmul tiling for the larger long-read batches — but the resulting GQ difference is below 3e-4 phred, far too small to move any integer GQ threshold. So the change is decision-identical on real data.
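To see why a one-ULP probability difference translates into such a small GQ difference, here is a rough sketch. It assumes GQ is the phred-scaled probability that the call is wrong, i.e. -10·log10(1 − p_max); that definition and the example error probabilities are assumptions for illustration, not values taken from the instrumented run.

```python
import math

ULP_FLOAT32 = 2.0 ** -23  # float32 machine epsilon, about 1.19e-07


def phred(error_prob):
    """Phred-scale an error probability: -10 * log10(p_error)."""
    return -10.0 * math.log10(error_prob)


def gq_shift(p_error, delta=ULP_FLOAT32):
    """Absolute GQ change when the error probability moves by `delta`."""
    return abs(phred(p_error + delta) - phred(p_error))


# Hypothetical error probabilities (1 - p_max); not taken from the PR's data.
shifts = {p: gq_shift(p) for p in (1e-1, 1e-2, 2e-3)}
for p_error, shift in shifts.items():
    print(f"p_error={p_error:g}: GQ shift {shift:.2e} phred")
```

To first order the shift is about 4.34·δ/p_error, so it grows as the call gets more confident; under this assumed definition, an error probability of around 2e-3 already keeps the shift under 3e-4 phred.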

(The per-call counts also explain the platform-dependent speedup: WGS made 44,133 classify() calls for its 222,978 rows (~5 rows/call), versus ~2,560 calls for PacBio/ONT (~120–190 rows/call). The per-call predict() overhead is paid ~17× more often on WGS, which is why WGS gains ~55% while long-read — where each call already amortizes the overhead over a large batch — gains little.)
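The arithmetic behind those ratios, as a quick check. The row and call counts are the ones quoted above; treating the "~2,560 calls" figure as applying to PacBio and ONT individually is an assumption about how to read that sentence.

```python
# Figures quoted in the PR description.
rows = {"WGS": 222_978, "PacBio": 299_227, "ONT": 482_232}
# Long-read call counts are approximate ("~2,560") and assumed to be per dataset.
calls = {"WGS": 44_133, "PacBio": 2_560, "ONT": 2_560}

rows_per_call = {name: rows[name] / calls[name] for name in rows}
call_ratio = calls["WGS"] / calls["PacBio"]

for name, value in rows_per_call.items():
    print(f"{name}: {value:.1f} rows/call")
print(f"WGS pays the per-call overhead {call_ratio:.1f}x as often")
```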

Measured impact (chr20, 16 shards, c8a.4xlarge / 16 cores)

make_examples wall time with --call_small_model_examples, 2 reps each:

| dataset | baseline (predict) | new (call) | wall-time change |
|---|---|---|---|
| WGS (Illumina) | 231.8 / 231.7 s | 104.7 / 105.0 s | −54.8% |
| PacBio | 142.4 / 142.3 s | 132.7 / 133.5 s | −6.6% |
| ONT | 222.9 / 227.8 s | 213.0 / 233.4 s | ~−1–4% (within run noise) |

The WGS gain is large because short-read per-region work is cheap, so the fixed per-region predict() overhead dominated make_examples. Long-read pileup work dominates PacBio/ONT, so the same overhead is a smaller fraction — never a regression.
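For anyone re-deriving the percentages, a small sketch that averages the two reps per configuration and reports the relative change. The timings are copied from the table above; since they are already rounded to 0.1 s, the result can differ from the reported figure in the last digit (PacBio comes out near −6.5% here).

```python
from statistics import mean

# (baseline reps, new reps) in seconds, copied from the table above.
timings = {
    "WGS": ((231.8, 231.7), (104.7, 105.0)),
    "PacBio": ((142.4, 142.3), (132.7, 133.5)),
    "ONT": ((222.9, 227.8), (213.0, 233.4)),
}


def percent_change(baseline, new):
    """Relative change of the mean wall time; negative means faster."""
    return 100.0 * (mean(new) - mean(baseline)) / mean(baseline)


changes = {name: percent_change(b, n) for name, (b, n) in timings.items()}
for name, change in changes.items():
    print(f"{name}: {change:+.1f}%")
```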

Output is unchanged: examples_written and the number of small-model calls are identical between baseline and new on all three datasets (WGS 79759 / 143219, PacBio 125483 / 173744, ONT 134755 / 347477); the compressed CVO output differs only within the run-to-run non-determinism floor.

Notes

  • Because each region now runs in a single forward pass, --small_model_inference_batch_size no longer affects inference. Its help text is updated to mark it deprecated/ignored, but the flag is retained for command-line compatibility (no removal, no breakage).

Files

  • deepvariant/small_model/inference.py: classify() uses __call__.
  • deepvariant/small_model/inference_test.py: stub __call__ instead of .predict; add the equivalence test.
  • deepvariant/make_examples_options.py: deprecate the now-inert batch-size flag's help text.

When --call_small_model_examples is set, make_examples runs a small Keras MLP once per region to pre-call easy candidates. That inference went through keras.Model.predict(), which rebuilds a data adapter, a CallbackList, and the prediction loop on every call. For the small per-region batches the model sees, that per-call scaffolding -- not the arithmetic -- dominates its runtime.

Calling the model directly (classifier(examples, training=False)) runs the same cached forward pass without that per-call machinery. The result is numerically identical: the model is a plain Dense MLP with no train/inference-divergent layers (no BatchNorm, no Dropout), so __call__(training=False) and predict() compute the same forward pass over the same weights, leaving the downstream confidence gating unchanged.

Because each region now runs in a single forward pass, --small_model_inference_batch_size no longer affects inference; its help text is updated to mark it deprecated/ignored but it is retained for command-line compatibility.

Measured on chr20 (16 shards, c8a.4xlarge): make_examples wall time with --call_small_model_examples drops from 231.8s to 104.7s for WGS (-55%), with smaller gains on long-read data whose per-region pileup work dominates (PacBio -7%, ONT roughly flat). examples_written and the small-model call counts are identical before and after on all three datasets.

A new ClassifyEquivalenceTest builds a Dense+softmax MLP of the small model's shape and asserts classify() matches Model.predict() (same shape, dtype, and arg-max; probabilities allclose to 1e-6) for batch sizes from 1 to 300.
@pichuan

pichuan commented Jun 24, 2026

Collaborator

Hi @tfenne ,

Thanks for the PR!

Since I believe you're already familiar with our process, I'll go ahead and start the review.

As a reminder, because of the way our project is set up, we aren't able to merge GitHub PRs directly. If the changes look good, I will commit them, crediting your GitHub username and referencing this PR in the commit description.

Please let me know if you have any concerns with this approach.

-pichuan

@pichuan pichuan self-assigned this Jun 24, 2026
@pichuan pichuan self-requested a review June 25, 2026 04:48
@pichuan

pichuan commented Jun 25, 2026

Collaborator

A quick update for you @tfenne :

I tested on our regular n2-standard-96 (96 shards) setup, but I did not see a speed improvement in that setting -- likely because with 96 shards, the per-region predict() overhead is already amortized enough to be negligible.

I confirmed the change is correctness-safe: hap.py metrics (F1, precision, recall) are identical across all configs.

I'm going to test on a different machine type (c3d-standard-16, 16 shards — closer to your c8a.4xlarge setup) and I'll report back to see if there's a speed improvement there.

@tfenne

tfenne commented Jun 25, 2026

Author

Thanks @pichuan - I think the speedup I'm seeing should be constant with respect to the number of cores. IIRC the small model predict() is called once per region, and I think the number of regions is not different depending on the number of cores, but I will freely admit I don't know the codebase very well, and I could have misunderstood.
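The argument above can be written down as a toy model. This is a sketch under strong assumptions: regions are split evenly across shards, shards scale perfectly in parallel, and every region costs the same. The function names and all numbers are hypothetical placeholders, not measurements.

```python
def make_examples_wall_time(n_regions, n_shards, work_s, overhead_s):
    """Idealized wall time: regions split evenly, shards run fully in parallel."""
    regions_per_shard = n_regions / n_shards
    return regions_per_shard * (work_s + overhead_s)


def relative_saving(n_regions, n_shards, work_s, overhead_s):
    """Fraction of wall time removed if the per-region overhead drops to zero."""
    before = make_examples_wall_time(n_regions, n_shards, work_s, overhead_s)
    after = make_examples_wall_time(n_regions, n_shards, work_s, 0.0)
    return (before - after) / before


# All numbers below are hypothetical placeholders, not measurements.
for shards in (16, 96):
    saving = relative_saving(
        n_regions=48_000, n_shards=shards, work_s=0.04, overhead_s=0.05
    )
    print(f"{shards} shards: {saving:.0%} saved")
```

In this model the relative saving does not depend on the shard count, while the absolute seconds saved shrink as shards are added. Real machines can break the perfect-scaling assumption (memory bandwidth, SMT, container overhead), so the model does not by itself explain the 96-shard result.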

@tfenne

tfenne commented Jun 25, 2026

Author

Oh one other thought. I've been doing most of my runs and benchmarking on AMD-based instances at AWS where vCPUs = physical CPU cores. I don't know as much about GCP, but I see that n2-standard-96 is Intel-based (not AMD) and if Gemini is to be believed, it's a 48-core box with SMT. I wonder if that's what's driving the difference?

@tfenne

tfenne commented Jun 26, 2026

Author

Well, it's not SMT vs. no-SMT. I tried running my same benchmark on matched c8a and c8i instances at AWS, and while the c8i instance is overall quite a bit slower, comparing the r1.10 release against this branch showed approximately 55% speedup on both (56% this time on AMD, 52% on Intel).

I'll try and find time to replicate my findings on the whole genome (not just chr20) on a 96 core instance and post back here. The other place there is possibly room for error is if I didn't faithfully replicate what the wrapper script is doing to call make_examples. Here's what I'm doing:

python3 bazel-out/k8-opt/bin/deepvariant/make_examples.zip \
  --mode calling \
  --ref  /home/ubuntu/data/ref/GRCh38_no_alt_analysis_set.fasta \
  --reads /home/ubuntu/data/input/HG003.novaseq.pcr-free.35x.dedup.grch38_no_alt.chr20.bam \
  --examples '<out>/me.tfrecord@16.gz' \
  --checkpoint /home/ubuntu/models/wgs \
  --checkpoint_json /home/ubuntu/models/wgs/model.example_info.json \
  --call_small_model_examples \
  --regions chr20 \
  --task {}      # {} = 0..15, driven by: seq 0 15 | parallel -q --halt 2 --line-buffer

@pichuan

pichuan commented Jun 26, 2026

Collaborator

For context, when I tested, I built a docker image and ran with docker, because that's how our standard release runtime metrics are evaluated.

@pgrosu

pgrosu commented Jun 26, 2026


Hi Tim (@tfenne) and Pi-Chuan (@pichuan),

A very neat idea indeed, and maybe this might help. It's been a while since I've looked at the code in depth, but the general idea is this. The number of shards determines the number of regions to process, and these are ideally uniformly distributed to each CPU by the operating system, given the task_id (by launching multiple instances of make_examples in parallel). These are the tensors generated by the single-threaded make_examples using the reference in the beginning of the tensor (for more details see my previous post here).

The number of instantiation calls of CallbackList will increase a little bit with every multiple of 10, but even at 100,000 repeated calls, that took around 42.7 seconds on an 8-core (16 logical) machine. That can be driven further down given the number of cores and available memory (16->96 CPUs and 32->384 GB RAM), but with caching that can stay around 4.59 seconds for even 10,000,000 instantiations (with similar argument settings). The code for this is at the end of this post.

There are some obvious caching and performance opportunities via a shared-memory model (see my previous post #650), opening the door to further modularize DeepVariant's pipeline steps to optimize for HPC/GPU-branching with the flexibility of plugging in models from other frameworks (e.g. PyTorch). Currently, if you pause the execution within each step, that will allow you to explore and fine-tune optimization opportunities as guided by a profiler.

In any case, below is the code I mentioned above:

import time
start_time = time.time()  # started before the other imports, so import time is included

from functools import cache
import sys
from keras.src import callbacks as callbacks_module
import keras

verbose = True
epochs = 100


@cache  # zero-argument function: only the first call executes the body
def callback_routine():
    # Construct a CallbackList; the object is discarded, only the cost matters.
    callbacks = callbacks_module.CallbackList(
        callbacks=None,
        add_history=True,
        add_progbar=verbose != 0,
        verbose=verbose,
        epochs=epochs,
        steps=16,
        model=keras.Model(),
    )


def main():
    rounds = int(sys.argv[1])  # number of calls, given on the command line

    for i in range(0, rounds):
        callback_routine()

    print(f"{(time.time() - start_time):0.3} seconds")


if __name__ == "__main__":
    main()
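One thing worth keeping in mind when reading timings from the benchmark above: `functools.cache` memoizes on the call arguments, so a decorated function that takes no arguments executes its body once and serves every later call from the cache. The standard-library sketch below (with a hypothetical `build_once` standing in for the routine) shows that behaviour, which means the cached timing measures cache lookups rather than repeated construction.

```python
from functools import cache

body_runs = 0


@cache
def build_once():
    """Zero-argument function: every call shares the same (empty) cache key."""
    global body_runs
    body_runs += 1
    return object()


first = build_once()
for _ in range(1_000):
    assert build_once() is first  # served from the cache, body not re-run

print(f"body executed {body_runs} time(s) across 1,001 calls")
```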

Hope it helps,
Paul

@pichuan

pichuan commented Jun 26, 2026

Collaborator

(Sorry, the long comment I posted was meant for #1086 . I'll move it there)

@pichuan

pichuan commented Jun 26, 2026

Collaborator

Hi @tfenne ,

Here is an update on my testing. I tested with the same approach described in #1086 (comment)

Unfortunately, on c3d-standard-16 machines, when I test with your changes, the ONT_R104 and PACBIO data types are both crashing due to OOM. The same runs at HEAD (without your changes) on the same machine type complete successfully.

The remaining runs haven't all finished, so I don't know if anything else will run OOM as well.

I suspect it's the change in deepvariant/small_model/inference.py, where batch_size is deleted and all candidates for a region are processed in a single forward pass. With predict(batch_size=128), TensorFlow processes candidates in chunks and can free intermediate tensors between batches. With __call__(), all candidates are processed at once. ONT and PacBio tend to have denser regions with more candidates per region, which could explain why those data types hit OOM while others don't. This is just a guess; I didn't test any further changes.

Based on this finding, I cannot recommend this PR for internal review or submission in its current form. Let me know if you want to make any changes to fix the OOM, and I am happy to test again.
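If the single-pass batch is indeed the cause, one possible direction is to keep the direct call but restore an upper bound on rows per pass. The sketch below is hypothetical and untested against the OOM: `classify_in_chunks` is not part of the PR, and `classifier` is any callable mapping a list of rows to per-row outputs, standing in for a Keras model called with training=False. With a real model the per-chunk outputs would be arrays to concatenate rather than lists to extend.

```python
def classify_in_chunks(classifier, examples, batch_size):
    """Call `classifier` directly on slices of at most `batch_size` rows.

    `classifier` is any callable mapping a list of rows to a list of
    per-row outputs (a stand-in for a model called with training=False).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    outputs = []
    for start in range(0, len(examples), batch_size):
        chunk = examples[start:start + batch_size]
        outputs.extend(classifier(chunk))
    return outputs


# Stub classifier that records the chunk sizes it was given.
seen_chunk_sizes = []


def stub_classifier(chunk):
    seen_chunk_sizes.append(len(chunk))
    return [[1.0, 0.0, 0.0] for _ in chunk]


probabilities = classify_in_chunks(
    stub_classifier, list(range(300)), batch_size=128
)
```

Small regions would still run in one call, so the per-call saving on WGS should be largely retained, and the batch-size flag would regain a meaning. Whether this actually avoids the OOM would need the same c3d-standard-16 test.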
