Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tnn_runtime.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ The doc introduces how to convert models into TNN.

**step 1**, convert your model to ONNX, run:

```
```shell
python3 tools/convert2onnx.py <config-file> --input-img <img-dir> --shape 512 512 --checkpoint <model-ckpt>
# if you want to use onnxsim, you could invoke it just by adding `--onnxsim`
```

**step 2**, clone the TNN:
Expand Down
29 changes: 27 additions & 2 deletions tools/convert2onnx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os

import numpy as np

Expand Down Expand Up @@ -534,7 +535,7 @@ def pytorch2onnx(model,
mm_inputs,
show=False,
output_file='tmp.onnx',
num_classes=150):
onnxsim=False):
model.cpu().eval()
imgs = mm_inputs.pop('imgs')

Expand All @@ -549,6 +550,27 @@ def pytorch2onnx(model,
keep_initializers_as_inputs=False,
verbose=show)
print(f'Successfully exported ONNX model: {output_file}')

if onnxsim:
import onnx

from onnxsim import simplify

print(f'Simplifying the {output_file}...')

# use onnxsimplify to reduce reduent model.
onnx_model = onnx.load(output_file)

model_simp, check = simplify(onnx_model)

assert check, "Simplified ONNX model could not be validated"

stem, suffix = os.path.splitext(output_file)
outsim_file = stem + '-sim' + suffix

onnx.save(model_simp, outsim_file)

print(f'Successfully simplify the model: {outsim_file}')


def parse_args():
Expand All @@ -561,6 +583,7 @@ def parse_args():
'--show',
action='store_true',
help='show onnx graph and segmentation results')
parser.add_argument("--onnxsim", action="store_true", help="use onnxsim or not")
parser.add_argument('--output-file', type=str, default='tmp.onnx')
parser.add_argument('--opset-version', type=int, default=11)
parser.add_argument(
Expand Down Expand Up @@ -613,4 +636,6 @@ def parse_args():
mm_inputs = _demo_mm_inputs(input_shape, num_classes=150)

# convert model to onnx file
pytorch2onnx(segmentor, mm_inputs, show=args.show, output_file=args.output_file)
pytorch2onnx(segmentor, mm_inputs, show=args.show, output_file=args.output_file, onnxsim=args.onnxsim)