Proposal: Enhance model architecture robustness for ONNX/TensorRT export and Multi-IO support #124

@Bahdmanbabzo

Is your feature request related to a problem? Please describe.

When attempting to export the SynthSeg unet architecture to portable formats like ONNX or TensorRT, several architectural rigidities in ext/neuron/models.py cause the conversion to fail. Specifically:

  • Shape Access Inconsistency: The code occasionally assumes a specific Keras/TF version’s behavior for TensorShape (e.g., expecting .as_list()), leading to AttributeError during graph tracing.

  • Multi-IO Fragility: The current unet and conv_dec logic assumes single-tensor inputs. If an input_model with multiple inputs/outputs is provided, the build fails with a TypeError.

  • Batch Dimension Interference: Inconsistent indexing can accidentally pull the batch size into spatial calculations, breaking decoder reconstruction for 3D volumes.
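The first and third failure modes can be illustrated without TensorFlow installed. This is a minimal sketch using a plain tuple as a stand-in for a tensor shape; the shape values are made up for illustration, and which shape type a real tensor exposes depends on the installed Keras/TF version.

```python
# Hypothetical 3D volume shape: (batch, x, y, z, channels).
# A plain tuple stands in for environments where tensor shapes
# do not expose TensorShape's .as_list() method.
tuple_shape = (None, 160, 160, 192, 1)

# Shape Access Inconsistency: calling .as_list() on a tuple fails.
try:
    tuple_shape.as_list()
    as_list_failed = False
except AttributeError:
    as_list_failed = True

# Batch Dimension Interference: slicing from index 0 keeps the batch
# dimension (None here) in what should be purely spatial dimensions.
with_batch = list(tuple_shape)[:-1]
spatial_only = list(tuple_shape)[1:-1]

print(as_list_failed)  # True
print(with_batch)      # [None, 160, 160, 192]
print(spatial_only)    # [160, 160, 192]
```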

Describe the solution you'd like

I have developed a set of backward-compatible refinements for the ext.neuron.models module that resolve these issues:

  • Sequence Handling: Added isinstance(..., (list, tuple)) checks to ensure the decoder safely selects the primary tensor in multi-IO models.

  • Version-Agnostic Shape Retrieval: Implemented a fallback mechanism (hasattr(shape, 'as_list')) to handle shape extraction across different TensorFlow/Keras environments.

  • Standardized Slicing: Used shape[1:] consistently to isolate the spatial and channel dimensions, so the batch dimension stays excluded during layer initialization.

The core of the change looks like this:

# Handle potential multi-input models
input_tensor = input_model.input
if isinstance(input_tensor, (list, tuple)):
    input_tensor = input_tensor[0]

# Handle potential multi-output models
last_tensor = input_model.output
if isinstance(last_tensor, (list, tuple)):
    last_tensor = last_tensor[0]

# Version-agnostic shape retrieval for stable dimension slicing
if hasattr(last_tensor.shape, 'as_list'):
    input_shape = last_tensor.shape.as_list()[1:]
else:
    input_shape = list(last_tensor.shape)[1:]
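The snippet above could also be factored into two small helpers so the same checks are not repeated in unet and conv_dec. This is only a sketch: the names first_tensor and shape_without_batch are hypothetical and not part of ext.neuron.models, and the stand-in objects below merely mimic the two shape styles so the example runs without TensorFlow.

```python
from types import SimpleNamespace


def first_tensor(maybe_sequence):
    """Return the primary tensor when given a list/tuple of tensors."""
    if isinstance(maybe_sequence, (list, tuple)):
        return maybe_sequence[0]
    return maybe_sequence


def shape_without_batch(tensor):
    """Return the tensor's dimensions as a list, excluding the batch axis."""
    shape = tensor.shape
    dims = shape.as_list() if hasattr(shape, 'as_list') else list(shape)
    return dims[1:]


class FakeTensorShape:
    """Stand-in for a shape object that exposes .as_list()."""

    def __init__(self, dims):
        self._dims = list(dims)

    def as_list(self):
        return list(self._dims)


# Stand-in tensors: one with an as_list()-style shape, one with a tuple shape.
legacy = SimpleNamespace(shape=FakeTensorShape([None, 160, 160, 192, 33]))
modern = SimpleNamespace(shape=(None, 160, 160, 192, 33))

# Stand-in for a multi-IO input_model (illustrative values only).
multi_io_model = SimpleNamespace(input=[modern, modern], output=[legacy, modern])

input_tensor = first_tensor(multi_io_model.input)
last_tensor = first_tensor(multi_io_model.output)
input_shape = shape_without_batch(last_tensor)
print(input_shape)  # [160, 160, 192, 33]
```

Both shape styles yield the same list, and a single-tensor model passes through first_tensor unchanged, which is what keeps the change backward compatible.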

Describe alternatives you've considered

Using keras.backend.int_shape is an alternative, but the proposed explicit checks offer higher stability across custom Keras-wrapped tensors and environments where the backend may not be fully initialized during a conversion trace.

Verification

I have uploaded the generated ONNX models, the conversion script, and the code I used to run inference (with the FP32 variant) to my Hugging Face account. Inference returns accurate results on NIfTI files.

Additional context

The primary goal of these changes is to enable the export of SynthSeg weights to FP32 or INT8 ONNX models. This facilitates high-performance, browser-based inference (via ONNX Runtime Web), which is critical for making these tools accessible in resource-limited research settings where maintaining a full Python/TensorFlow environment is impractical.
