feat: add ONNX Runtime inference backend with PyTorch fallback (#12) #32
jshaofa-ui wants to merge 1 commit into Climate-Vision:main
Conversation
- ONNXSession: Cached session manager with auto CPU/CUDA provider selection
- run_onnx_inference: Batch inference with latency tracking
- benchmark_onnx_model: Full benchmarking (p50/p95/p99 + FPS)
- export_unet_to_onnx / export_siamese_to_onnx: Dynamic axis, configurable opset
- run_inference_with_fallback: ONNX to PyTorch automatic fallback
- validate_onnx_model: Cross-validation with PyTorch output
- 32 unit tests across 11 test classes
- Graceful skip when torch/onnx not available

Closes Climate-Vision#12
Goldokpa left a comment
Thanks for the substantial PR! Solid scope and good test coverage. Before this can move forward, there are a few real bugs and a few things worth questioning:
Bugs that will break at runtime:
- The warm-up call in `run_onnx_inference` is malformed. In `onnx_runtime.py`: `session.run(session.input_name, {session.input_name: image[:1]})`. But `ONNXSession.run` only takes `input_data` (one positional arg), and the underlying `_session.run` expects `(output_names, input_feed)` where `output_names` is `None` or a list, not a string. This will throw on every call. It should just be `session.run(image[:1])` (see the sketch after this list).
- `all_outputs.extend(outputs)` then `all_outputs[0]`: `session.run()` returns a list of arrays (one per output). For a single-output model with one batch, `extend` flattens to `[array]`, but for multiple batches this stitches outputs from different batches together as siblings rather than concatenating along the batch axis. The downstream `argmax`/`softmax` then runs only on the first batch's logits and silently drops the rest. Use `np.concatenate([out[0] for out in batched_outputs], axis=0)` or similar (also shown in the sketch below).
- The fallback path calls `pytorch_inference(image, analysis_type=...)`, but `image` is a numpy array. Worth confirming that `run_inference` in `pipeline.py` accepts that; if it expects a tensor or a file path, the fallback will always fail.
- `_EXECUTION_PROVIDERS` is malformed. The CUDA entry is `("CUDAExecutionProvider", {"device_id": "0"})`, but `device_id` should be an int, not a string, per the ORT API. Also, `ort.InferenceSession(providers=...)` expects either a list of strings or a list of `(name, options_dict)` tuples; the code mixes both formats, and the `selected_providers` filter only checks the name, so it passes the malformed tuple through unchanged (see the sketch below).
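To make the wiring fixes concrete, here's a minimal sketch of the corrected provider table, warm-up call, and batch concatenation. It assumes the `ONNXSession`/`run_onnx_inference` shapes described above; `_select_providers` and the `batch_size` parameter are illustrative, not the PR's actual code:

```python
import numpy as np
import onnxruntime as ort

# Consistent (name, options) tuples; device_id is an int per the ORT API.
_EXECUTION_PROVIDERS = [
    ("CUDAExecutionProvider", {"device_id": 0}),
    ("CPUExecutionProvider", {}),
]

def _select_providers() -> list[tuple[str, dict]]:
    # Filter on the provider *name* but keep the whole tuple, so
    # ort.InferenceSession always receives one consistent format.
    available = ort.get_available_providers()
    return [(name, opts) for name, opts in _EXECUTION_PROVIDERS if name in available]

def run_onnx_inference(session, image: np.ndarray, batch_size: int = 8) -> np.ndarray:
    # Warm-up: ONNXSession.run takes only the input array; building the
    # (output_names, input_feed) pair for _session.run is the wrapper's job.
    session.run(image[:1])

    # Collect per-batch output lists, then concatenate the first output of
    # every batch along axis 0 instead of extending a flat sibling list.
    batched_outputs = [
        session.run(image[start:start + batch_size])
        for start in range(0, len(image), batch_size)
    ]
    return np.concatenate([out[0] for out in batched_outputs], axis=0)
```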
Things that look suspicious / worth challenging:
- The test file references an absolute home path, and so do the docs: `pip install -e /home/fa/projects/climatevision-work` appears in `docs/onnx-runtime-guide.md`, plus the docstring on `onnx_runtime.py`. There's also `_DEFAULT_ONNX_DIR = parents[3]`, the same parents-index pattern I'd want verified: the file is at `src/climatevision/inference/onnx_runtime.py`, so `parents[3]` should be the repo root; that one's actually correct here, but worth confirming for `onnx_export.py` too (see the quick check after this list). The hardcoded developer path in the docs should be removed.
- An 873-line test file with 40+ test cases for a new module is unusual. Many tests follow the pattern `try: import onnxruntime; except: pytest.skip()`, meaning that in CI without `onnxruntime` installed (which is a new dependency this PR adds), almost every test silently skips. `test_session_raises_without_onnxruntime` literally has `pass  # Skip this test as it requires complex mocking` inside it, so it asserts nothing. Worth pruning to focused tests that actually run, or pinning `onnxruntime` as a test dep (see the `importorskip` sketch after this list).
- Dependencies aren't added to `pyproject.toml`/`requirements.txt`. The PR introduces `onnx>=1.14.0` and `onnxruntime>=1.15.0` but only mentions them in the markdown doc, so imports will fail in any environment that hasn't been manually prepared.
- `run_onnx_inference` has no path that returns a `dict[str, Any]`, despite the return type annotation `ONNXInferenceResult | dict[str, Any]`. Either the dict branch is missing or the annotation is wrong.
- `export_model_from_checkpoint` calls `torch.load(...)` without `weights_only=True`. Recent PyTorch versions warn on this, and it's a security concern for untrusted checkpoints (a one-line fix is sketched below).
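On the `parents` indexing, a quick check of the claim (illustrative path):

```python
from pathlib import Path

f = Path("repo/src/climatevision/inference/onnx_runtime.py")
# parents[0] -> repo/src/climatevision/inference
# parents[1] -> repo/src/climatevision
# parents[2] -> repo/src
# parents[3] -> repo  (the repo root, as stated above)
assert f.parents[3] == Path("repo")
```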
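On the skip pattern: `pytest.importorskip` at module scope makes the skip visible in the test report instead of hiding it inside each test body. A sketch of what the pruned tests could use (the test name below is illustrative):

```python
import pytest

# Reported as an explicit skip when onnxruntime is absent, instead of
# dozens of tests that silently pass without asserting anything.
ort = pytest.importorskip("onnxruntime")

def test_cpu_provider_is_always_available():
    assert "CPUExecutionProvider" in ort.get_available_providers()
```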
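And the checkpoint fix is essentially a one-liner, assuming the checkpoint holds a plain state dict (the helper name here is illustrative):

```python
import torch

def load_checkpoint(model: torch.nn.Module, checkpoint_path: str) -> torch.nn.Module:
    # weights_only=True restricts unpickling to tensors and plain containers,
    # silencing the newer-PyTorch warning and hardening against untrusted files.
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    return model
```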
Happy to re-review once these are addressed. The overall architecture (session caching, fallback path, benchmark dataclass) is reasonable — the issues are mostly in the wiring.
[Good First Issue] Add ONNX Runtime inference path with PyTorch fallback
Resolves #12
Summary
Implements a complete ONNX Runtime inference backend for ClimateVision with automatic PyTorch fallback, enabling faster inference on CPU and edge devices while maintaining full compatibility with existing PyTorch models.
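For reference, a minimal sketch of the fallback wiring (names, import paths, and signatures are illustrative; the actual code lives in the diff):

```python
import logging
import numpy as np

from climatevision.inference.onnx_runtime import run_onnx_inference  # assumed path
from climatevision.inference.pipeline import run_inference  # PyTorch path

logger = logging.getLogger(__name__)

def run_inference_with_fallback(image: np.ndarray, analysis_type: str):
    """Try ONNX Runtime first; fall back to the PyTorch pipeline on failure."""
    try:
        return run_onnx_inference(image, analysis_type=analysis_type)
    except Exception as exc:  # e.g. onnxruntime missing or session init failed
        logger.warning("ONNX inference failed (%s); falling back to PyTorch", exc)
        return run_inference(image, analysis_type=analysis_type)
```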
Changes
Core Features
Test Coverage
Technical Details