Skip to content

Enhance safety, error handling, and numerical stability in tensor ops#25

Merged
JGalego merged 15 commits intomainfrom
develop
Mar 25, 2026
Merged

Enhance safety, error handling, and numerical stability in tensor ops#25
JGalego merged 15 commits intomainfrom
develop

Conversation

@JGalego
Copy link
Copy Markdown
Owner

@JGalego JGalego commented Mar 25, 2026

No description provided.

Critical fixes and improvements from comprehensive code review:

Safety Improvements:
- Add SAFETY documentation to all unsafe get_unchecked calls in Conv operator
- Document bounds checking guarantees and invariant maintenance
- Ensure all unsafe operations have verified preconditions

Error Handling:
- Change Tensor::relu() and Tensor::sigmoid() to return Result<Tensor>
- Add validation for non-finite values (NaN, Inf) in activation functions
- Provide clear error messages for invalid inputs
- Update 20+ call sites across codebase to handle Result types

Numerical Stability:
- Implement numerically stable sigmoid with input clamping [-500, 500]
- Use different computation paths for positive/negative values
- Prevent exp() overflow/underflow in extreme value ranges
- Maintain accuracy and symmetry properties across full input domain

Test Improvements:
- Fix panic in formal.rs random_tensor helper (support 1D/N-D tensors)
- Update integration tests with realistic input ranges
- Fix formal verification tests to handle Result types
- Update examples (tensor_ops, simple_model) with proper error handling
- Remove undefined behavior tests (infinity inputs)

Verification:
- All 232 unit tests pass
- All 10 integration tests pass
- All 25 doctests pass
- Zero clippy warnings
- Clean compilation

Breaking Changes:
- Tensor::relu() and Tensor::sigmoid() now return Result<Tensor>
  Migration: Change `.relu()` to `.relu()?` and `.sigmoid()` to `.sigmoid()?`

Files Modified:
- src/operators.rs: Safety docs, error handling
- src/tensor.rs: Result types, numerical stability
- src/formal.rs: Test helper robustness
- tests/*.rs: Result handling
- examples/*.rs: Error propagation
Pin image crate to =0.25.5 to maintain Rust 1.85.0 MSRV compatibility.

The image 0.25.6+ versions require Rust 1.87.0 due to their dependency
on zune-jpeg@0.5.8. Since image and imageproc are only used in examples
(not the core library), pinning to the last compatible version is the
most pragmatic solution.

This resolves the CI error:
  error: rustc 1.85.1 is not supported by the following package:
    zune-jpeg@0.5.8 requires rustc 1.87.0

Dependency changes:
- image: "0.25" → "=0.25.5" (exact version pin)
- Removes zune-jpeg@0.5.8 from dependency tree
- Maintains all functionality (examples continue to work)

Alternative considered: Updating MSRV to 1.87, but decided to maintain
broader compatibility for users on stable Rust 1.85.
src/tensor.rs
- div: add explicit division-by-zero guard before performing the op
- sqrt: add non-negativity check; sqrt of negative is undefined
- pow: switch from exact shape match to broadcast_tensors, consistent
  with add/mul

src/operators.rs
- conv_op: remove unsafe get_unchecked indexing; use ndarray indexing
  instead. Fix bias application to cover all batches (previously only
  batch 0 received the bias). Simplify 4D bias shape handling.
- doc: change "verified with Why3" to "specified in formal/operators.mlw"
  to accurately reflect what the .mlw files currently provide

src/graph.rs
- validate: collect all tensor names (inputs + initializers + node
  outputs) before checking node inputs, so validation is independent
  of node listing order and forward references don't spuriously fail

src/runtime.rs
- add_tensor: subtract freed memory when overwriting an existing tensor
  so overwrites don't inflate the reported memory usage total

src/formal.rs
- softmax debug_assert and test: check per-row sums (last-axis softmax
  produces one distribution per row, not one for the whole tensor)
- concat_with_contracts: call Tensor::concat instead of returning
  unsupported error
- random_tensor: replace DefaultHasher (not cross-version stable) with
  a 64-bit LCG for reproducible property-test seeds
- to_why3_tensor: fix array literal format to match Why3 syntax
- generate_proof_obligations: replace fake goal strings with actual
  theory-qualified names matching the .mlw files

formal/*.mlw
- tensors.mlw: replace recursive shape_product (unterminating in Why3
  theories) with an axiomatised declaration; add same_shape equivalence
  lemmas and shape_product_zero
- operators.mlw: strengthen matmul_spec, sigmoid_spec, softmax_spec,
  conv_spec, concat_spec (add axis param); add 9 proved element-level
  lemmas (commutativity, identity, relu properties, sigmoid bounds)

Cargo.toml
- move image/imageproc to dev-dependencies (only used by examples);
  deduplicate env_logger (was in both [dependencies] and [dev-dependencies])
- Update runnx dependency version from 0.2.0 to 0.2.1 in README and quick-start guide
- Fix release notes link to point to CHANGELOG anchor for v0.2.1
- Remove duplicate content in docs/guides/quick-start.md (old partial version was prepended to the full file)
- Fix FORMAL_VERIFICATION.md: replace non-existent simple_specs.mlw with actual files (tensors.mlw, operators.mlw)
- Fix FORMAL_VERIFICATION.md: CVC4 → CVC5
Demonstrates building a binary classifier with sigmoid(X @ W + b) using
the graph API with MatMul, Add, and Sigmoid nodes, including manual
verification and a save/reload round-trip.
- Tensor.data is now Arc<ArrayD<f32>>, making clone O(1) instead of
  copying the full buffer; eliminates the per-node tensor copy hot path
- Runtime groups nodes into independent waves via topological_levels()
  and executes each wave with rayon under the new `parallel` feature
- Add serde "rc" feature to support Arc serialization
- Add rayon as an optional dependency behind a new `parallel` feature flag
- Add Graph::topological_levels() which groups nodes into independent
  execution waves using tensor availability levels
- Refactor Runtime::execute() to use wave-based scheduling: gather inputs
  sequentially, run operators in parallel per wave via rayon par_iter,
  then store outputs sequentially; falls back to iter when feature is off
- Remove the now-unused execute_node() helper
- Add wide = "0.7" dependency for portable 256-bit SIMD (f32x8)
- Add src/simd.rs with relu, sigmoid, exp, sqrt helpers that process 8 floats per iteration; fall back to mapv_inplace for non-contiguous arrays
- Wire SIMD helpers into Tensor::relu, sigmoid, exp, sqrt; sigmoid clamp tightened from [-500,500] to [-88,88] (sufficient for f32 exp range)
- Add Python reference script examples/yolov8_detect_and_draw.py mirroring the Rust example using ultralytics
Introduce three selectable Conv implementations behind feature flags:

- `naive-conv`: original 6-level nested loop, kept as a readable reference for comparing against the optimised paths
- default (im2col): rearranges input patches into a column matrix then calls a single ndarray dot (matrixmultiply GEMM); converts Conv into one cache-friendly matrix multiply instead of scattered memory reads
- `blas`: same im2col transform but delegates the GEMM to OpenBLAS sgemm via the cblas crate for maximum throughput on tuned systems

Add cblas and blas-src as optional dependencies; document the libopenblas-dev system requirement in the blas back-end doc comment.
Add 53 targeted tests covering previously uncovered branches in operators.rs: reshape error paths, concat negative axis and single-input early return, slice fallback and tensor inputs, split axis validation and zero-sized chunks, gather index clamping and negative indices, unsqueeze/ edge cases, pad wrong-length and non-constant mode, batch_norm non-4D and multi-batch, softmax axis-out-of-bounds and multi-dim path, resize fallback variants, conv channel mismatch, and maxpool padding.
Replace a no-op test (result discarded, no assertion) and a bare is_ok() check with meaningful assertions: verify the output shape and computed value for batch_norm with custom epsilon, and assert the correct output shape for the empty-axes unsqueeze path.
@JGalego JGalego merged commit d48f534 into main Mar 25, 2026
33 of 35 checks passed
@codecov
Copy link
Copy Markdown

codecov bot commented Mar 25, 2026

Codecov Report

❌ Patch coverage is 90.30837% with 44 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.09%. Comparing base (bf0bc6b) to head (95147d9).
⚠️ Report is 18 commits behind head on main.

Files with missing lines Patch % Lines
src/tensor.rs 85.39% 26 Missing ⚠️
src/runtime.rs 82.75% 10 Missing ⚠️
src/simd.rs 88.00% 6 Missing ⚠️
src/graph.rs 97.91% 1 Missing ⚠️
src/operators.rs 99.16% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #25      +/-   ##
==========================================
+ Coverage   89.41%   95.09%   +5.67%     
==========================================
  Files           8        9       +1     
  Lines        5151     5262     +111     
==========================================
+ Hits         4606     5004     +398     
+ Misses        545      258     -287     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant