Proposal: Enhance model architecture robustness for ONNX/TensorRT export and Multi-IO support #124
Description
Is your feature request related to a problem? Please describe.
When attempting to export the SynthSeg `unet` architecture to portable formats such as ONNX or TensorRT, several architectural rigidities in `ext/neuron/models.py` cause the conversion to fail. Specifically:
- **Shape Access Inconsistency:** The code occasionally assumes a specific Keras/TF version's behavior for `TensorShape` (e.g., expecting `.as_list()`), leading to an `AttributeError` during graph tracing.
- **Multi-IO Fragility:** The current `unet` and `conv_dec` logic assumes single-tensor inputs. If an `input_model` with multiple inputs/outputs is provided, the build fails with a `TypeError`.
- **Batch Dimension Interference:** Inconsistent indexing can accidentally pull the batch size into spatial calculations, breaking decoder reconstruction for 3D volumes.
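To make the first failure mode concrete outside of TensorFlow, here is a minimal pure-Python sketch; `TensorShapeLike` is an illustrative stand-in for `tf.TensorShape`, not SynthSeg code. Depending on the TF/Keras version, `tensor.shape` may expose `.as_list()` or behave like a plain sequence, and calling `.as_list()` unconditionally is exactly what raises the `AttributeError` during tracing:

```python
# Illustration of the .as_list() pitfall (stand-in objects, not real TF).
class TensorShapeLike:
    """Mimics a TensorShape that exposes .as_list()."""
    def __init__(self, dims):
        self._dims = dims

    def as_list(self):
        return list(self._dims)


old_style = TensorShapeLike([None, 160, 160, 160, 1])
new_style = (None, 160, 160, 160, 1)  # some traces expose a plain tuple

print(old_style.as_list()[1:])  # [160, 160, 160, 1]

try:
    new_style.as_list()  # plain sequences have no .as_list()
except AttributeError:
    print("AttributeError: shape has no as_list()")
```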
Describe the solution you'd like
I have developed a set of backward-compatible refinements for the ext.neuron.models module that resolve these issues:
- **Sequence Handling:** Added `isinstance(..., (list, tuple))` checks so the decoder safely selects the primary tensor in multi-IO models.
- **Version-Agnostic Shape Retrieval:** Implemented a fallback mechanism (`hasattr(shape, 'as_list')`) to handle shape extraction across different TensorFlow/Keras environments.
- **Standardized Slicing:** Consistently use `shape[1:]` to isolate the spatial and channel dimensions, ensuring the batch dimension remains excluded during layer initialization.
It looks like this:

```python
# Handle potential multi-input models
input_tensor = input_model.input
if isinstance(input_tensor, (list, tuple)):
    input_tensor = input_tensor[0]

# Handle potential multi-output models
last_tensor = input_model.output
if isinstance(last_tensor, (list, tuple)):
    last_tensor = last_tensor[0]

# Version-agnostic shape retrieval for stable dimension slicing
if hasattr(last_tensor.shape, 'as_list'):
    input_shape = last_tensor.shape.as_list()[1:]
else:
    input_shape = list(last_tensor.shape)[1:]
```
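For readability, the three checks could also be consolidated into a single helper. The sketch below is illustrative (the name `resolve_primary_io` is mine, not from the SynthSeg codebase) and assumes only that the model object exposes `.input`/`.output` attributes and that each tensor's `.shape` either has `.as_list()` or behaves like a plain sequence:

```python
def resolve_primary_io(input_model):
    """Return (input_tensor, last_tensor, non-batch dims) for a Keras-like model.

    Selects the first tensor of multi-IO models and extracts the shape
    without the batch dimension, regardless of whether .shape exposes
    .as_list() or is a plain sequence.
    """
    input_tensor = input_model.input
    if isinstance(input_tensor, (list, tuple)):
        input_tensor = input_tensor[0]

    last_tensor = input_model.output
    if isinstance(last_tensor, (list, tuple)):
        last_tensor = last_tensor[0]

    shape = last_tensor.shape
    dims = shape.as_list() if hasattr(shape, 'as_list') else list(shape)
    return input_tensor, last_tensor, dims[1:]  # drop the batch dimension
```

This keeps the call sites in `unet`/`conv_dec` to a single line and makes the batch-dimension exclusion explicit in one place.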
Describe alternatives you've considered
Using `keras.backend.int_shape` is an alternative, but the proposed explicit checks are more robust with custom Keras-wrapped tensors and in environments where the backend may not be fully initialized during a conversion trace.
Verification
I have uploaded the generated ONNX models, the conversion script, and the code I used to run inference (using the FP32 variant) to my Hugging Face account. It returns accurate results on NIfTI files.
Additional context
The primary goal of these changes is to enable the export of SynthSeg weights to FP32 or INT8 ONNX models. This facilitates high-performance, browser-based inference (via ONNX Runtime Web), which is critical for making these tools accessible in resource-limited research settings where maintaining a full Python/TensorFlow environment is impractical.