tf中线程与graph读取的关系
def import_graph_fun(pb_model_name):
output_graph_def = tf.GraphDef()
with open(pb_model_name, "rb") as f:
output_graph_def.ParseFromString(f.read())
tf.import_graph_def(output_graph_def, name="")
sess = tf.Session()
# other code
当我们训练好模型之后,将模型存储成pb格式,然后上面这段代码是读取该文件,一直以来都没有问题,直到有一天当我们需要在一个线程中运行多个模型的时候报错,报错的内容大概就是在预测阶段,不能从graph中找到对应节点。
之前我们都知道在tf框架中:1. 所有的graph都要在session中运行,并且一个session中只能运行一个graph,但是同样的graph可以在不同的session中运行; 2. 如果没有指定graph,框架会为我们生成默认的graph;3. 所有的op操作都会添加到模型的graph上。
其实进行到了这里,基本就能解开上面说的报错的原因:如果我们在一个线程中跑多个graph,我们就必须要有多个session,并且要给每个session绑定它对应的graph,之所以在一个线程中只跑一个图一直没出错是因为就一个图,默认这个图就是绑定了这个session的,不存在歧义。
其实我当时在写这里的时候有个疑问:这条语句graph_def.ParseFromString(f.read())是从pb文件中将序列化的graph解析出来,然后根据这条语句tf.import_graph_def(graph_def, name=“”)将这个graph_def导入,那么问题是这个导入是导入到了哪里?该接口并没有将graph return回来,它去了哪里,我们怎么拿到它?答案是:该接口直接将import出来的graph中所有的op添加到了它对应的上下文的graph中,想要获取它,就要先构造一个上下文环境,然后才能拿到这个graph,具体代码为:
def import_graph_fun(pb_model_name):
output_graph_def = tf.GraphDef()
with open(pb_model_name, "rb") as f:
output_graph_def.ParseFromString(f.read())
# 注意:这条语句非常重要,通过接口as_default()构建了一个上下文环境,此时tf.import_graph_def就是将op添加到了这个环境所对应的graph中,也就是g_
with tf.Graph().as_default() as g_:
tf.import_graph_def(output_graph_def, name="")
# 这里指定一下这个sess1绑定的是g_
sess1 = tf.Session(graph=g_)
# other code
如此,就可以work了。
总结:graph是跟线程相关的。