h5模型转pb转tflite实现及报错解决

h5模型转pb转tflite实现及报错解决

    这里我以 Mask R-CNN 训练出来的模型为例。我的 h5 文件只保存了权重、不保存模型结构,所以代码中先将其转换为带结构的 h5 文件。话不多说,直接上代码。

def save_model(path, output_path='maskrcnnstrip.h5'):
    """
    Convert a weights-only .h5 file produced by training into an .h5 file
    that stores the whole Keras model (architecture + weights).

    :param path: path to the weights-only .h5 file produced by training.
    :param output_path: where to write the full-model .h5 file; defaults to
        'maskrcnnstrip.h5', the name that used to be hard-coded.
    :return: output_path, so callers can chain into the next conversion step.
    """
    test_config = InferenceConfig()
    # Build the network in inference mode: the exported graph should be the
    # prediction graph, not the training graph.
    model = MaskRCNN(config=test_config, mode="inference", name="strip", model_dir=root)
    # by_name=True is forwarded to MaskRCNN.load_weights; presumably weights
    # are matched to layers by name (Keras semantics) — confirm in MaskRCNN.
    model.load_weights(path, by_name=True)
    # keras_model is the underlying Keras model; Model.save() writes both the
    # structure and the weights, unlike the weights-only input file.
    model.keras_model.save(output_path)
    return output_path


def to_pb(model=None, export_path='../logs/saved_model', signature_key='root'):
    """
    Export a Keras model (here: Mask R-CNN) as a TensorFlow SavedModel
    (saved_model.pb + variables/) using the TF1 SavedModelBuilder API.

    A single PREDICT signature is registered under ``signature_key``. The
    default ``'root'`` is kept for backward compatibility. Because it is not
    the conventional ``'serving_default'`` key, loaders such as
    ``TFLiteConverter.from_saved_model`` must be given the same key explicitly.

    :param model: Keras model to export. Its ``input`` must hold at least 3
        tensors and its ``output`` at least 7, which are mapped to the names
        below by position. When None (the original behaviour), the
        module-level global ``model`` is used.
    :param export_path: destination directory. SavedModelBuilder refuses to
        write into an already existing export directory.
    :param signature_key: key under which the signature is stored.
    :return: None
    :raises ValueError: if no model is passed and no global ``model`` exists.
    """
    if model is None:
        # Backward compatibility: the original version read a global `model`.
        model = globals().get('model')
        if model is None:
            raise ValueError("to_pb() needs a Keras model: pass one or define a global `model`")

    # Do NOT use the session as a context manager: tf.Session.__exit__ closes
    # it, and K.get_session() returns Keras' shared global session, so closing
    # it would break any later use of the model in this process.
    sess = K.get_session()
    builder = tf.saved_model.builder.SavedModelBuilder(export_path)

    # Names are assigned to model.input / model.output positionally.
    input_names = ('input_image', 'input_image_meta', 'input_anchors')
    output_names = ('mrcnn_detection', 'mrcnn_class', 'mrcnn_bbox', 'mrcnn_mask',
                    'ROI', 'rpn_class', 'rpn_bbox')

    signature_inputs = {
        name: tf.saved_model.utils.build_tensor_info(model.input[i])
        for i, name in enumerate(input_names)
    }
    signature_outputs = {
        name: tf.saved_model.utils.build_tensor_info(model.output[i])
        for i, name in enumerate(output_names)
    }

    prediction_signature_def = tf.saved_model.signature_def_utils.build_signature_def(
        inputs=signature_inputs,
        outputs=signature_outputs,
        method_name=tf.saved_model.PREDICT_METHOD_NAME)

    builder.add_meta_graph_and_variables(
        sess,
        [tf.saved_model.SERVING],
        signature_def_map={
            signature_key: prediction_signature_def
        },
    )

    builder.save()


def pb_to_tflite(path, output_path="converted_model.tflite", signature_key='root'):
    """
    Convert a SavedModel directory (e.g. the one written by ``to_pb``) into a
    TFLite flatbuffer file.

    :param path: SavedModel directory to convert.
    :param output_path: where to write the .tflite file; defaults to
        "converted_model.tflite", the name that used to be hard-coded.
    :param signature_key: SignatureDef key to convert. It must match the key
        used at export time (``'root'`` by default). Otherwise the converter
        raises "No 'serving_default' in the SavedModel's SignatureDefs".
    :return: the serialized TFLite model (bytes), as returned by convert().
    """
    converter = tf.lite.TFLiteConverter.from_saved_model(path, signature_key=signature_key)
    # Enable TFLite's default post-training optimizations.
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    # Allow ops without a TFLite builtin kernel to fall back to TensorFlow
    # kernels (SELECT_TF_OPS); the runtime then needs the TF "flex" ops.
    converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter.allow_custom_ops = True
    tflite_model = converter.convert()
    # Use a context manager so the file handle is closed deterministically.
    with open(output_path, "wb") as f:
        f.write(tflite_model)
    return tflite_model

pb转tflite时报错:
ValueError: No 'serving_default' in the SavedModel's SignatureDefs. Possible values are 'XXX'.
(原文此处为报错截图)

    出现这种错误,是因为导出 SavedModel 时签名是以自定义的键(本例中为 'root')注册的,而不是默认的 'serving_default'。解决方法是在 tf.lite.TFLiteConverter.from_saved_model() 中传入 signature_key='XXX'(XXX 为报错信息中列出的签名名)即可。