Skip to content

Commit f705071

Browse files
authored
Remove redundancies in pytorch2onnx (open-mmlab#160)
* rm redundancies * re-add some packages
1 parent e1f4f51 commit f705071

File tree

1 file changed

+1
-6
lines changed

1 file changed

+1
-6
lines changed

tools/pytorch2onnx.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
import numpy as np
66
import onnxruntime as rt
77
import torch
8-
from torch import nn
98
import torch._C
109
import torch.serialization
1110
from mmcv.onnx import register_extra_symbolics
1211
from mmcv.runner import load_checkpoint
12+
from torch import nn
1313

1414
from mmseg.models import build_segmentor
1515

@@ -186,11 +186,6 @@ def parse_args():
186186
# convert SyncBN to BN
187187
segmentor = _convert_batchnorm(segmentor)
188188

189-
if isinstance(segmentor.decode_head, nn.ModuleList):
190-
num_classes = segmentor.decode_head[-1].num_classes
191-
else:
192-
num_classes = segmentor.decode_head.num_classes
193-
194189
if args.checkpoint:
195190
load_checkpoint(segmentor, args.checkpoint, map_location='cpu')
196191

0 commit comments

Comments
 (0)