diff --git a/examples/tensorrt_qat/models.py b/examples/tensorrt_qat/models.py index 5c4e562..adfac9d 100644 --- a/examples/tensorrt_qat/models.py +++ b/examples/tensorrt_qat/models.py @@ -47,7 +47,7 @@ def Lenet(data): pool2 = flow.nn.max_pool2d( conv2, ksize=2, strides=2, padding="VALID", name="pool2", data_format="NCHW" ) - reshape = flow.reshape(pool2, [pool2.shape[0], -1]) + reshape = flow.reshape(pool2, [pool2.shape[0], -1], name="reshape") hidden = flow.layers.dense( reshape, 512, @@ -57,6 +57,7 @@ def Lenet(data): ) return flow.layers.dense(hidden, 10, kernel_initializer=initializer, name="dense2") + def _get_regularizer(model_name): # all decay return flow.regularizers.l2(0.00004) @@ -337,7 +338,7 @@ def build_network( name="pool5", ) fc = flow.layers.dense( - flow.reshape(pool, (pool.shape[0], -1)), + flow.reshape(pool, (pool.shape[0], -1), name="reshape"), units=class_num, use_bias=False, kernel_initializer=_get_initializer("dense_weight"), @@ -392,7 +393,7 @@ def get_lenet_job_function( func_config.qat.symmetric(True) func_config.qat.per_channel_weight_quantization(False) func_config.qat.moving_min_max_stop_update_after_iters(1000) - func_config.qat.target_backend("tensorrt7") + func_config.qat.target_backend("tensorrt") if func_type == "train": @flow.global_function(type="train", function_config=func_config) @@ -438,7 +439,7 @@ def get_mobilenet_job_function( func_config.qat.symmetric(True) func_config.qat.per_channel_weight_quantization(False) func_config.qat.moving_min_max_stop_update_after_iters(1000) - func_config.qat.target_backend("tensorrt7") + func_config.qat.target_backend("tensorrt") if func_type == "train": @flow.global_function(type="train", function_config=func_config) diff --git a/oneflow_onnx/oneflow2onnx/handlers/quantize.py b/oneflow_onnx/oneflow2onnx/handlers/quantize.py index 519b425..7a861e3 100644 --- a/oneflow_onnx/oneflow2onnx/handlers/quantize.py +++ b/oneflow_onnx/oneflow2onnx/handlers/quantize.py @@ -90,8 +90,14 @@ def get_min_or_max_value(get_min: bool, pre_func: Optional[Callable] = None): raise ValueError("invalid quantization formula: " + formula) ctx.RemoveNode(node.name) - ctx.MakeConst(node.output_tensor_names[0], scale) - ctx.MakeConst(node.output_tensor_names[1], zero_point) + ctx.MakeConst( + node.output_tensor_names[0], + scale.squeeze() if formula == "cambricon" or per_layer else scale, + ) + ctx.MakeConst( + node.output_tensor_names[1], + zero_point.squeeze() if formula == "cambricon" or per_layer else zero_point, + ) @classmethod def Version_10(cls, ctx: Graph, node: Node, **kwargs): @@ -144,8 +150,8 @@ def Version_10(cls, ctx: Graph, node: Node, **kwargs): raise ValueError("invalid quantization formula: " + formula) ctx.RemoveNode(node.name) - ctx.MakeConst(node.output_tensor_names[0], scale.flatten()) - ctx.MakeConst(node.output_tensor_names[1], zero_point) + ctx.MakeConst(node.output_tensor_names[0], scale.squeeze()) + ctx.MakeConst(node.output_tensor_names[1], zero_point.squeeze()) @flow_op( @@ -167,7 +173,7 @@ def _Convert(cls, ctx: Graph, node: Node, opset: int, **kwargs): ) if opset < 13: scale_shape = ctx.get_shape(node.input_tensor_names[1]) - if not (len(scale_shape) == 1 and scale_shape[0] == 1): + if not len(scale_shape) == 0: raise RuntimeError("per-channel mode is not supported in version 10") else: