Skip to content

Commit 63555c8

Browse files
Corallo, pre-commit-ci[bot], and glenn-jocher
authored
Add option to quantize per-tensor (ultralytics#12516)
* Add option to quantize per-tensor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  (for more information, see https://pre-commit.ci)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Glenn Jocher <[email protected]>
1 parent f33d42d commit 63555c8

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

export.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -448,7 +448,8 @@ def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
448448

449449

450450
@try_export
451-
def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
451+
def export_tflite(keras_model, im, file, int8, per_tensor, data, nms, agnostic_nms,
452+
prefix=colorstr('TensorFlow Lite:')):
452453
# YOLOv5 TensorFlow Lite export
453454
import tensorflow as tf
454455

@@ -469,6 +470,8 @@ def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=c
469470
converter.inference_input_type = tf.uint8 # or tf.int8
470471
converter.inference_output_type = tf.uint8 # or tf.int8
471472
converter.experimental_new_quantizer = True
473+
if per_tensor:
474+
converter._experimental_disable_per_channel = True
472475
f = str(file).replace('.pt', '-int8.tflite')
473476
if nms or agnostic_nms:
474477
converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
@@ -713,6 +716,7 @@ def run(
713716
keras=False, # use Keras
714717
optimize=False, # TorchScript: optimize for mobile
715718
int8=False, # CoreML/TF INT8 quantization
719+
per_tensor=False, # TF per tensor quantization
716720
dynamic=False, # ONNX/TF/TensorRT: dynamic axes
717721
simplify=False, # ONNX: simplify model
718722
opset=12, # ONNX: opset version
@@ -798,7 +802,14 @@ def run(
798802
if pb or tfjs: # pb prerequisite to tfjs
799803
f[6], _ = export_pb(s_model, file)
800804
if tflite or edgetpu:
801-
f[7], _ = export_tflite(s_model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
805+
f[7], _ = export_tflite(s_model,
806+
im,
807+
file,
808+
int8 or edgetpu,
809+
per_tensor,
810+
data=data,
811+
nms=nms,
812+
agnostic_nms=agnostic_nms)
802813
if edgetpu:
803814
f[8], _ = export_edgetpu(file)
804815
add_tflite_metadata(f[8] or f[7], metadata, num_outputs=len(s_model.outputs))
@@ -837,6 +848,7 @@ def parse_opt(known=False):
837848
parser.add_argument('--keras', action='store_true', help='TF: use Keras')
838849
parser.add_argument('--optimize', action='store_true', help='TorchScript: optimize for mobile')
839850
parser.add_argument('--int8', action='store_true', help='CoreML/TF/OpenVINO INT8 quantization')
851+
parser.add_argument('--per-tensor', action='store_true', help='TF per-tensor quantization')
840852
parser.add_argument('--dynamic', action='store_true', help='ONNX/TF/TensorRT: dynamic axes')
841853
parser.add_argument('--simplify', action='store_true', help='ONNX: simplify model')
842854
parser.add_argument('--opset', type=int, default=17, help='ONNX: opset version')

0 commit comments

Comments
 (0)