tf.estimator.Estimator类的用法

https://blog.csdn.net/liushuikong/article/details/79223407

https://www.cnblogs.com/arkenstone/p/8448208.html

https://zhuanlan.zhihu.com/p/41473323

https://yinguobing.com/facial-landmark-localization-by-deep-learning-save-model-application/#fn3

https://www.cnblogs.com/YouXiangLiThon/p/7435825.html

https://guillaumegenthial.github.io/serving-tensorflow-estimator.html

1、estimator

estimator类是机器学习模型的抽象,estimator允许开发者自定义任意的结构模型、损失函数、优化函数以及如何对这个模型进行训练、导出、评估等内容,同时屏蔽了与底层硬件设备、分布式网络数据传输等相关细节

tf.eatimator.Estimator(
    model_fn=model_fn,
    params=params,
    config=run_config
)

一个estimator,需要传入模型函数,参数和配置

参数应该是模型超参数的一个集合,可以是一个字典

配置用于指定模型如何运行训练和评估,以及在哪里存储结果,该对象会把相关信息高速estimator

模型函数一个python函数,它根据给定的输入构建模型

2、estimator类主要有三个方法:train、evaluate、predict,分别表示模型的训练、评估和预测,三个方法都接受一个用户自定义的输入函数input_fn,执行input_fn获取输入数据,estimator的三个方法都会调用model_fn执行具体操作,不同mode传入,返回的也不同

def input_fn(dataset):
    ***
    return feature,label

类内方法,model_fn,train、evaluate、predict参考http://www.cnblogs.com/zongfa/p/10149483.html

3、

def my_model(
    features,     #this is batch_features from input_fn
    labels,       #this is batch_labels from input_fn
    mode,         #an instance of tf.estimator.ModeKeys
    params        #configuration
)

这是固定格式,利用estimator进行train、eval、predict,下面是train方法。train_input_fn()传入所需的features,labels

http://www.cnblogs.com/wdmx/p/10010433.html

classifier.train(input_fn=lambda: train_input_fn(FILE_TRAIN, True, 500))

一个比较好的例子https://guillaumegenthial.github.io/serving-tensorflow-estimator.html     

                               https://github.com/tensorflow/models/blob/master/samples/outreach/blogs/blog_custom_estimators.py                   

4、保存模型    

 https://www.cnblogs.com/arkenstone/p/8448208.html

def serving_input_receiver_fn():
    """
    Build serving inputs
    """
    inputs = tf.placeholder(dtype=tf.string, name="input_image")
    feature_config = {'image/encoded': tf.FixedLenFeature(shape=[], dtype=tf.string)}
    tf_example = tf.parse_example(inputs, feature_config)
    patch_images = tf.map_fn(_preprocess_image, tf_example["image/encoded"], dtype=tf.float32)
    patch_images = tf.squeeze(patch_images, axis=[0])
    receive_tensors = {'example': inputs}
    features = {"input": patch_images}
    return tf.estimator.export.ServingInputReceiver(features, receive_tensors)


def save_serving_model():
    session_config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
    config = tf.estimator.RunConfig(model_dir=MODEL_WEIGHTS_PATH, session_config=session_config)  # session_config is used for configuration of session
    model = get_estimator_model(config=config)
    model.export_savedmodel(export_dir_base=SERVING_MODEL_SAVE_PATH, serving_input_receiver_fn=serving_input_receiver_fn)

 

将训练好的ckpt模型freeze成pb模型文件,estimator提供export_savedmodel函数

export_savedmodel(export_dir_base,serving_inpit_receiver_fn)

该方法首先建立一个图,以获得输入法特征Tensors,,然后调用estimator的model_fn(),以基于这些特征的模型曲线图,开始session,将最新的还原到其中,在export_bir_base下(保存目录)创建一个带时间戳的导出目录,并将savedmodel写入包含mateGraphDefault,保存单个文件

当自定义model_fn时,必须要填充export_output元素,https://www.tensorflow.org/api_docs/python/tf/estimator/EstimatorSpec?hl=zh-cn,这是{name:output}描述在投放期间要导出进而使用的输出签名和命令