tf.flags用法

TF中使用flags来定义解析命令行参数,用法类似于Python中的argparse。尤其在我们编写shell脚本训练代码的时候比较方便,比如某shell脚本:

python run_classifier.py \
  --task_name=$TASK_NAME \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DATA_DIR/$TASK_NAME \
  --vocab_file=$ALBERT_CONFIG_DIR/vocab.txt \
  --bert_config_file=$ALBERT_CONFIG_DIR/albert_config_tiny.json \
  --init_checkpoint=$ALBERT_TINY_DIR/albert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=64 \
  --learning_rate=1e-4 \
  --num_train_epochs=5.0 \
  --output_dir=$CURRENT_DIR/${TASK_NAME}_output/

在运行run_classifier.py脚本的时候可以动态设置task_name, do_train等的值。

tf.flags只能设置int,string,float,bool四种类型的值,用法是:

1. 调用flags = tf.flags;

2. 开始赋值;

3. 运行tf.app.run()

举例说明:

import tensorflow as tf
 
#1、调用tf.flags:  第一个是参数名称,第二个参数是默认值,第三个是参数描述
tf.flags.DEFINE_string('str_name', 'def_v_1',"descrip1")
tf.flags.DEFINE_integer('int_name', 10,"descript2")
tf.flags.DEFINE_boolean('bool_name', False, "descript3")
 
# tf.flags.FLAGS是一个FlagValuesWrapper的实例化变量
FLAGS = tf.flags.FLAGS
 
#必须带参数,否则:'TypeError: main() takes no arguments (1 given)';  main的参数名随意定义,无要求
def main(_): 
  # 这里就是使用参数值的方法
  print(FLAGS.str_name)
  print(FLAGS.int_name)
  print(FLAGS.bool_name)
 
if __name__ == '__main__':
  # 如果使用flags.mark_flag_as_required表示强制指定命令行该参数的内容
  flags.mark_flag_as_required("str_name")
  tf.app.run() #2、执行main函数

 

参考博客:

https://blog.csdn.net/qq_41185868/article/details/82913886