tf.flags用法
TF中使用flags来定义解析命令行参数,用法类似于Python中的argparse。尤其在我们编写shell脚本训练代码的时候比较方便,比如某shell脚本:
python run_classifier.py \
--task_name=$TASK_NAME \
--do_train=true \
--do_eval=true \
--data_dir=$GLUE_DATA_DIR/$TASK_NAME \
--vocab_file=$ALBERT_CONFIG_DIR/vocab.txt \
--bert_config_file=$ALBERT_CONFIG_DIR/albert_config_tiny.json \
--init_checkpoint=$ALBERT_TINY_DIR/albert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=64 \
--learning_rate=1e-4 \
--num_train_epochs=5.0 \
--output_dir=$CURRENT_DIR/${TASK_NAME}_output/
在运行run_classifier.py脚本的时候可以动态设置task_name, do_train等的值。
tf.flags只能设置int,string,float,bool四种类型的值,用法是:
1. 调用flags = tf.flags;
2. 开始赋值;
3. 运行tf.app.run()
举例说明:
import tensorflow as tf
#1、调用tf.flags: 第一个是参数名称,第二个参数是默认值,第三个是参数描述
tf.flags.DEFINE_string('str_name', 'def_v_1',"descrip1")
tf.flags.DEFINE_integer('int_name', 10,"descript2")
tf.flags.DEFINE_boolean('bool_name', False, "descript3")
# tf.flags.FLAGS是一个FlagValuesWrapper的实例化变量
FLAGS = tf.flags.FLAGS
#必须带参数,否则:'TypeError: main() takes no arguments (1 given)'; main的参数名随意定义,无要求
def main(_):
# 这里就是使用参数值的方法
print(FLAGS.str_name)
print(FLAGS.int_name)
print(FLAGS.bool_name)
if __name__ == '__main__':
# 如果使用flags.mark_flag_as_required表示强制指定命令行该参数的内容
flags.mark_flag_as_required("str_name")
tf.app.run() #2、执行main函数
参考博客: