From 9b4410469a1ea83557e0ff6f37a8fb915a1e5fd1 Mon Sep 17 00:00:00 2001 From: lry89757 Date: Sun, 23 Oct 2022 23:59:07 +0800 Subject: [PATCH] Add the onnxsim support --- tnn_runtime.md | 3 ++- tools/convert2onnx.py | 29 +++++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/tnn_runtime.md b/tnn_runtime.md index 077adb4..c6e0c8e 100644 --- a/tnn_runtime.md +++ b/tnn_runtime.md @@ -4,8 +4,9 @@ The doc introduces how to convert models into TNN. **step 1**, convert your model to ONNX, run: -``` +```shell python3 tools/convert2onnx.py --input-img --shape 512 512 --checkpoint +# if you want to use onnxsim, you could invoke it just by adding `--onnxsim` ``` **step 2**, clone the TNN: diff --git a/tools/convert2onnx.py b/tools/convert2onnx.py index 366b340..46f962e 100644 --- a/tools/convert2onnx.py +++ b/tools/convert2onnx.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import os import numpy as np @@ -534,7 +535,7 @@ def pytorch2onnx(model, mm_inputs, show=False, output_file='tmp.onnx', - num_classes=150): + onnxsim=False): model.cpu().eval() imgs = mm_inputs.pop('imgs') @@ -549,6 +550,27 @@ def pytorch2onnx(model, keep_initializers_as_inputs=False, verbose=show) print(f'Successfully exported ONNX model: {output_file}') + + if onnxsim: + import onnx + + from onnxsim import simplify + + print(f'Simplifying the {output_file}...') + + # use onnxsimplify to reduce reduent model. + onnx_model = onnx.load(output_file) + + model_simp, check = simplify(onnx_model) + + assert check, "Simplified ONNX model could not be validated" + + stem, suffix = os.path.splitext(output_file) + outsim_file = stem + '-sim' + suffix + + onnx.save(model_simp, outsim_file) + + print(f'Successfully simplify the model: {outsim_file}') def parse_args(): @@ -561,6 +583,7 @@ def parse_args(): '--show', action='store_true', help='show onnx graph and segmentation results') + parser.add_argument("--onnxsim", action="store_true", help="use onnxsim or not") parser.add_argument('--output-file', type=str, default='tmp.onnx') parser.add_argument('--opset-version', type=int, default=11) parser.add_argument( @@ -613,4 +636,6 @@ def parse_args(): mm_inputs = _demo_mm_inputs(input_shape, num_classes=150) # convert model to onnx file - pytorch2onnx(segmentor, mm_inputs, show=args.show, output_file=args.output_file) + pytorch2onnx(segmentor, mm_inputs, show=args.show, output_file=args.output_file, onnxsim=args.onnxsim) + +