h5 模型转 pb 再转 tflite：实现及报错解决
这里以 Mask R-CNN 训练得到的模型为例。我训练时保存的 h5 文件只包含权重、不包含模型结构，因此下面的代码会先把它转换为带完整结构的 h5 文件。话不多说，直接上代码。
def save_model(path, save_path='maskrcnnstrip.h5'):
    """Re-save a weights-only Mask R-CNN checkpoint as a full Keras H5 model.

    The training run stored only the weights, so the inference-mode model is
    rebuilt from ``InferenceConfig`` first, the weights are loaded into it, and
    the underlying Keras model (architecture + weights) is written to disk.

    :param path: path to the weights-only .h5 file produced by training.
    :param save_path: destination of the full-model .h5 file. Defaults to
        ``'maskrcnnstrip.h5'`` (relative to the current working directory),
        which is what this function always wrote before.
    :return: None
    """
    test_config = InferenceConfig()
    # ``root`` is not defined in this function; presumably it is a
    # module-level directory defined elsewhere in the script — confirm.
    model = MaskRCNN(config=test_config, mode="inference", name="strip", model_dir=root)
    # by_name=True presumably forwards to Keras' by-layer-name weight loading,
    # so layers missing from the checkpoint are skipped — verify in MaskRCNN.
    model.load_weights(path, by_name=True)
    # NOTE(review): Mask R-CNN uses custom layers, so reloading this file will
    # likely need ``custom_objects`` — confirm against the loading code.
    model.keras_model.save(save_path)
def to_pb(export_path='../logs/saved_model', keras_model=None, signature_key='root'):
    """Export a Keras Mask R-CNN model as a TF1 SavedModel with one PREDICT signature.

    The signature exposes three named inputs (``input_image``,
    ``input_image_meta``, ``input_anchors``) and seven named outputs
    (``mrcnn_detection`` ... ``rpn_bbox``). They are taken positionally from
    the model's ``input``/``output`` lists.

    :param export_path: SavedModel directory to create. Defaults to
        ``'../logs/saved_model'`` as before. NOTE: TF's SavedModelBuilder is
        documented to refuse an already-existing export directory.
    :param keras_model: Keras model to export. When None, the module-level
        ``model`` is used, as the original version of this function did.
    :param signature_key: key under which the signature is stored. Defaults to
        ``'root'``; the TFLite converter must be given the same key, or it
        looks for ``'serving_default'`` and fails.
    :return: None
    """
    if keras_model is None:
        # Backward compatibility: fall back to the global ``model``, presumably
        # defined elsewhere in the script.
        keras_model = model

    sess = K.get_session()
    # Make the Keras session and its graph the defaults while saving, as the
    # previous ``with K.get_session() as sess:`` did, but WITHOUT closing the
    # session on exit. Keras keeps returning that same session object, so
    # closing it would break every later Keras call in this process.
    with sess.as_default(), sess.graph.as_default():
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)

        input_names = ('input_image', 'input_image_meta', 'input_anchors')
        output_names = ('mrcnn_detection', 'mrcnn_class', 'mrcnn_bbox', 'mrcnn_mask',
                        'ROI', 'rpn_class', 'rpn_bbox')
        build_info = tf.saved_model.utils.build_tensor_info
        # Index explicitly (rather than zip) so a model with fewer tensors than
        # names still fails loudly with IndexError instead of silently
        # exporting a partial signature.
        signature_inputs = {
            name: build_info(keras_model.input[i]) for i, name in enumerate(input_names)
        }
        signature_outputs = {
            name: build_info(keras_model.output[i]) for i, name in enumerate(output_names)
        }

        prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs=signature_inputs,
            outputs=signature_outputs,
            method_name=tf.saved_model.PREDICT_METHOD_NAME)
        builder.add_meta_graph_and_variables(
            sess,
            [tf.saved_model.SERVING],
            signature_def_map={signature_key: prediction_signature},
        )
        builder.save()
def pb_to_tflite(path, output_path="converted_model.tflite", signature_key='root'):
    """Convert a SavedModel directory into a TensorFlow Lite flatbuffer file.

    Uses the default optimization setting and allows both builtin TFLite ops
    and select TensorFlow ops, plus custom ops. This is needed for models
    such as Mask R-CNN that contain ops without a builtin TFLite kernel.

    :param path: SavedModel directory (e.g. the one written by ``to_pb``).
    :param output_path: where to write the .tflite file. Defaults to
        ``"converted_model.tflite"`` as before.
    :param signature_key: SignatureDef key to convert. It must match the key
        used when exporting (``'root'`` here). Without it the converter looks
        for ``'serving_default'`` and raises
        ``ValueError: No 'serving_default' in the SavedModel's SignatureDefs``.
    :return: None
    """
    # NOTE(review): ``signature_key=`` and ``target_ops`` are the TF 1.x
    # converter API; TF 2.x uses ``signature_keys=[...]`` and
    # ``target_spec.supported_ops`` instead.
    converter = tf.lite.TFLiteConverter.from_saved_model(path, signature_key=signature_key)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter.allow_custom_ops = True
    tflite_model = converter.convert()
    # Close the file deterministically instead of leaking the handle.
    with open(output_path, "wb") as out_file:
        out_file.write(tflite_model)
pb 转 tflite 时报错：
ValueError: No 'serving_default' in the SavedModel's SignatureDefs. Possible values are 'XXX'.

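原因是导出 SavedModel 时，签名（SignatureDef）的键名不是默认的 'serving_default'（上面的代码用的是 'root'），而 from_saved_model() 默认按 'serving_default' 查找签名。解决方法是在 tf.lite.TFLiteConverter.from_saved_model() 中显式传入 signature_key='XXX'（XXX 为报错信息中列出的签名名，本例中为 'root'）即可。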