tf.Variable用法详解
https://tensorflow.google.cn/api_docs/python/tf/Variable
通过创建Variable类的实例向graph中添加变量。
Variable()需要初始值,一旦初始值确定,那么该变量的类型和形状都确定了。
更改值通过assign方法。
想要改变形状,需要使用assign+validate_shape=False。
想要运行graph,必须先显式初始化变量。
初始化变量3种方法:
①运行它的initializerop,
②从save file中恢复变量
③运行 assign操作,把初值赋给变量。
# Launch the graph in a session.
with tf.Session() as sess:
# Run the variable initializer.
sess.run(w.initializer)
# ...you now can run ops that use the value of 'w'...
但是最常用的初始化方式,是global_variables_initializer(),它可以添加一个op到graph里,这个op可以初始化所有变量。在launch图后,运行这个op。
# Add an Op to initialize global variables.
init_op = tf.global_variables_initializer()
# Launch the graph in a session.
with tf.Session() as sess:
# Run the Op that initializes global variables.
sess.run(init_op)
# ...you can now run any Op that uses variable values...
若想创建一个依赖于其他变量的变量,(If you need to create a variable with an initial value dependent on another variable),用那个变量的initialized_value()(Variable类有这个方法).
所有变量都自动收集到创建它们的图中。默认情况下,创建器把新变量添加到图集合GraphKeys.GLOBAL_VARIABLES中。
这个方便的global_variables()函数返回GraphKeys.GLOBAL_VARIABLES的内容。
在构建机器学习模型时,通常可以方便地区分包含可训练模型参数的变量和其他变量,例如用于计算训练步骤的全局步骤变量。所以,变量构造器支持trainable=<bool>参数。若为True,则新变量会同时被添加到图集合GraphKeys.TRAINABLE_VARIABLES.中。
这个方便的trainable_variables()函数返回GraphKeys.TRAINABLE_VARIABLES的内容
多种多样的Optimizer类把这个集合作为默认的优化变量列表。
从图上看懂tf.Variable()运行机制
代码1
import tensorflow as tf
a = tf.Variable(3, name='lxq')
b = tf.Variable(4, name='lxq2')
# a.assign_add(1)
# c = a+b
init = tf.global_variables_initializer() # 替换成这样就好
sess = tf.Session()
writer=tf.summary.FileWriter("logs", sess.graph) # 文件写在该.py文件同级logs文件夹下
sess.run(init)


代码2(从代码1中增加)
import tensorflow as tf
a = tf.Variable(3, name='lxq')
b = tf.Variable(4, name='lxq2')
a.assign_add(1)
c = a+b
init = tf.global_variables_initializer() # 替换成这样就好
sess = tf.Session()
writer=tf.summary.FileWriter("logs", sess.graph) # 文件写在该.py文件同级logs文件夹下
sess.run(init)
sess.run(c)


解析一下:
initial_value是tf.Variable()的第一个参数,小圆圈是常量的意思,实心箭头是数据流;
把initial_value通过了assign操作,椭圆表示操作节点(OpNode),该操作节点依赖于init操作;
同理,value是常量,通过数据流传递给AssignAdd操作节点。
lxq这个变量,参考两个量(黄色箭头是参考边,猜测lxq这个操作节点有判断选取哪一个值的功能),存在lxq中;
如果外部需要lxq这个变量,lxq存入的值(标量)会先传递给读取节点,然后数据流传向下一个操作节点add。
collection参数的使用
tf.Variable()有一个collection参数,表示创建的变量需要添加到哪个集合(列表形式,每个元素是一个集合名称)
不使用collection参数(默认)时,系统默认添加到tf.GraphKeys.GLOBAL_VARIABLES
因为
trainable参数默认为True,所以也会自动添加到tf.GraphKeys.TRAINABLE_VARIABLES中
tf.get_variable函数同tf.Variable()一样,默认添加到tf.GraphKeys.GLOBAL_VARIABLES,且trainable默认为True
使用collection参数时,参数包含哪些集合名,就添加到哪些集合中。
若collection参数中没有tf.GraphKeys.GLOBAL_VARIABLES时,则不添加!
如,创建一个变量a,使其除了默认添加的tf.GraphKeys.GLOBAL_VARIABLES集合外,同时添加到自定义集合thu中:
正确:h = tf.Variable(4.0, collections=['thu', tf.GraphKeys.GLOBAL_VARIABLES])
错误:h = tf.Variable(4.0, collections=['thu'])
下面代码可以验证:
import numpy as np
import tensorflow as tf
# 自己试验的时候,下面两组变量定义,以组为单位注释掉
# a = tf.get_variable(name='lxq1', initializer=tf.random_uniform([2, 3])) # 默认添加到tf.GraphKeys.GLOBAL_VARIABLES
# b = tf.get_variable(name='lxq2', initializer=tf.random_uniform([2, 3]), collections=[tf.GraphKeys.GLOBAL_VARIABLES]) # 指定添加到tf.GraphKeys.GLOBAL_VARIABLES,效果同上
# c = tf.get_variable(name='lxq3', initializer=tf.random_uniform([2, 3]), collections=['thu']) # 指定添加到'thu'集合,不添加到tf.GraphKeys.GLOBAL_VARIABLES
# d = tf.get_variable(name='lxq4', initializer=tf.random_uniform([2, 3]), collections=['pku', tf.GraphKeys.GLOBAL_VARIABLES]) # 同时添加到两个集合中
e = tf.Variable(1.0) # 默认添加到tf.GraphKeys.GLOBAL_VARIABLES
f = tf.Variable(2.0, collections=[tf.GraphKeys.GLOBAL_VARIABLES]) # 指定添加到tf.GraphKeys.GLOBAL_VARIABLES,效果同上
g = tf.Variable(3.0, collections=['thu']) # 指定添加到'thu'集合,不添加到tf.GraphKeys.GLOBAL_VARIABLES
h = tf.Variable(4.0, collections=['pku', tf.GraphKeys.GLOBAL_VARIABLES]) # 同时添加到两个集合中
with tf.Session() as ss:
ss.run(tf.global_variables_initializer())
print(len(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)))
print(len(tf.get_collection('thu')))
print(len(tf.get_collection('pku')))
3
1
1