
feat: add ONNX Runtime inference backend with PyTorch fallback (#12)#32

Open
jshaofa-ui wants to merge 1 commit into Climate-Vision:main from jshaofa-ui:feature/onnx-runtime-inference

Conversation

@jshaofa-ui

[Good First Issue] Add ONNX Runtime inference path with PyTorch fallback

Resolves #12

Summary

Implements a complete ONNX Runtime inference backend for ClimateVision with automatic PyTorch fallback, enabling faster inference on CPU and edge devices while maintaining full compatibility with existing PyTorch models.

Changes

  • onnx_runtime.py (540 lines) - ONNX Runtime engine with session management, inference, and benchmarking
  • onnx_export.py (422 lines) - PyTorch to ONNX model export for U-Net and Siamese networks
  • __init__.py (67 lines) - Unified module API combining PyTorch and ONNX inference
  • test_onnx_runtime.py (873 lines) - 32 unit tests across 11 test classes
  • onnx-runtime-guide.md - Complete usage documentation
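The export path in onnx_export.py could look roughly like the sketch below. The function name and the dynamic-axes/opset options come from this PR's description; the body is an illustrative guess, not the actual implementation.

```python
def export_unet_to_onnx(model, output_path, input_shape=(1, 3, 256, 256), opset=17):
    """Export a U-Net-style model to ONNX with dynamic batch/spatial axes.

    Sketch only: the shape, opset default, and axis names are assumptions.
    """
    try:
        import torch  # deferred so this sketch imports cleanly without torch
    except ImportError as exc:
        raise RuntimeError("PyTorch is required for ONNX export") from exc

    dummy = torch.randn(*input_shape)
    torch.onnx.export(
        model, dummy, output_path,
        opset_version=opset,
        input_names=["input"], output_names=["output"],
        # Dynamic axes let the exported model accept any batch size / resolution
        dynamic_axes={
            "input": {0: "batch", 2: "height", 3: "width"},
            "output": {0: "batch", 2: "height", 3: "width"},
        },
    )
    return output_path
```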

Core Features

  1. ONNXSession - Cached session manager with automatic CPU/CUDA provider selection
  2. run_onnx_inference() - Batch inference with latency tracking
  3. benchmark_onnx_model() - Full benchmarking (p50/p95/p99 latency + FPS)
  4. export_unet_to_onnx() / export_siamese_to_onnx() - Dynamic axis, configurable opset
  5. run_inference_with_fallback() - Automatic ONNX to PyTorch fallback
  6. validate_onnx_model() - Cross-validation with PyTorch output
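The fallback in feature 5 reduces to a try/except around the ONNX path. The real run_inference_with_fallback() lives in onnx_runtime.py; this stdlib-only sketch with placeholder callables shows the assumed control flow.

```python
def run_inference_with_fallback(inputs, onnx_fn, torch_fn):
    """Try the ONNX path first; on any runtime failure, fall back to PyTorch.

    Returns (result, backend_name) so callers can log which path ran.
    """
    try:
        return onnx_fn(inputs), "onnx"
    except Exception:
        return torch_fn(inputs), "pytorch"


# Usage: a failing ONNX path (e.g. missing provider) triggers the fallback.
def broken_onnx(x):
    raise RuntimeError("missing onnxruntime provider")

result, backend = run_inference_with_fallback(
    [1, 2], broken_onnx, lambda x: [v * 2 for v in x]
)
# backend == "pytorch", result == [2, 4]
```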

Test Coverage

  • 11 test classes: ONNXSession caching, device selection, inference, benchmarking, export, validation, fallback, integration
  • 32 unit tests total
  • Graceful skip when torch/onnx not available
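The "graceful skip" behavior presumably gates whole test classes on the optional dependencies. A minimal sketch of that pattern (the actual tests are in test_onnx_runtime.py; the class and test names here are hypothetical):

```python
import importlib.util
import unittest

# Detect optional dependencies without importing them
HAS_TORCH = importlib.util.find_spec("torch") is not None
HAS_ONNXRUNTIME = importlib.util.find_spec("onnxruntime") is not None


@unittest.skipUnless(HAS_TORCH and HAS_ONNXRUNTIME,
                     "torch/onnxruntime not installed")
class TestONNXInference(unittest.TestCase):
    def test_session_roundtrip(self):
        # Real assertions would exercise ONNXSession here
        self.assertTrue(True)
```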

Technical Details

  • Zero breaking changes to existing inference pipeline
  • Automatic provider selection based on hardware availability
  • Session caching for repeated inference calls
  • Full numerical validation against PyTorch baseline
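Session caching and provider selection could be combined as below. This is an assumed shape of what ONNXSession does internally, keyed on model path plus provider list; the `_factory` hook is purely for illustration/testing and is not part of the PR's API.

```python
_SESSION_CACHE = {}


def get_session(model_path, providers=("CPUExecutionProvider",), _factory=None):
    """Return a cached inference session, creating it at most once per
    (model_path, providers) pair. `_factory` (hypothetical) allows injecting
    a session constructor; by default onnxruntime.InferenceSession is used.
    """
    key = (model_path, tuple(providers))
    if key not in _SESSION_CACHE:
        if _factory is None:
            import onnxruntime as ort  # deferred import
            _factory = lambda p, pr: ort.InferenceSession(p, providers=list(pr))
        _SESSION_CACHE[key] = _factory(model_path, providers)
    return _SESSION_CACHE[key]
```

Repeated calls with the same model and providers reuse the session, which is what makes repeated inference calls cheap.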
