tf.app.run()如何工作?
python
python-3.x
tensorflow
10
0

tf.app.run()如何翻译演示?

tensorflow/models/rnn/translate/translate.py ,有一个对tf.app.run()的调用。如何处理?

if __name__ == "__main__":
    tf.app.run() 
参考资料:
Stack Overflow
收藏
评论
共 5 个回答
高赞 时间 活跃
if __name__ == "__main__":

表示当前文件在shell下执行,而不是作为模块导入。

tf.app.run()

您可以通过文件app.py看到

def run(main=None, argv=None):
  """Runs the program with an optional 'main' function and 'argv' list."""
  f = flags.FLAGS

  # Extract the args from the optional `argv` list.
  args = argv[1:] if argv else None

  # Parse the known flags from that list, or from the command
  # line otherwise.
  # pylint: disable=protected-access
  flags_passthrough = f._parse_flags(args=args)
  # pylint: enable=protected-access

  main = main or sys.modules['__main__'].main

  # Call the main function, passing through any arguments
  # to the final program.
  sys.exit(main(sys.argv[:1] + flags_passthrough))

让我们逐行中断:

flags_passthrough = f._parse_flags(args=args)

这样可确保您通过命令行传递的参数有效,例如python my_model.py --data_dir='...' --max_iteration=10000实际上,此功能是基于python标准argparse模块实现的。

main = main or sys.modules['__main__'].main

=右边的第一个main是当前函数run(main=None, argv=None)的第一个参数。 sys.modules['__main__']表示当前正在运行的文件(例如, my_model.py )。

因此有两种情况:

  1. 您在my_model.py没有main功能,那么您必须调用tf.app.run(my_main_running_function)

  2. 您在my_model.py有一个main功能。 (通常是这种情况。)

最后一行:

sys.exit(main(sys.argv[:1] + flags_passthrough))

确保使用解析后的参数正确调用main(argv)my_main_running_function(argv)函数。

收藏
评论

这只是一个非常快速的包装程序,可以处理标志解析,然后分派到您自己的主程序。参见代码

收藏
评论

2.0兼容的答案 :如果您想在Tensorflow 2.0使用tf.app.run() ,我们应该使用以下命令,

tf.compat.v1.app.run() ,也可以使用tf_upgrade_v21.x代码转换为2.0

收藏
评论

tf.app没有什么特别的。这只是一个通用的入口点脚本

使用可选的“ main”功能和“ argv”列表运行程序。

它与神经网络无关,它只是调用主函数,并传递给它的任何参数。

收藏
评论

简而言之, tf.app.run()的工作是首先设置全局标志以供以后使用,例如:

from tensorflow.python.platform import flags
f = flags.FLAGS

然后使用一组参数运行自定义的main函数。

例如,在TensorFlow NMT代码库中,用于训练/推理的程序执行的第一个入口点就是从这一点开始(请参见下面的代码)

if __name__ == "__main__":
  nmt_parser = argparse.ArgumentParser()
  add_arguments(nmt_parser)
  FLAGS, unparsed = nmt_parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

使用argparse解析参数后,使用tf.app.run()运行函数“ main”,其定义如下:

def main(unused_argv):
  default_hparams = create_hparams(FLAGS)
  train_fn = train.train
  inference_fn = inference.inference
  run_main(FLAGS, default_hparams, train_fn, inference_fn)

因此,在设置了供全局使用的标志之后, tf.app.run()只需运行您以argv作为参数传递给它的main功能。

PS:正如萨尔瓦多·达利(Salvador Dali)的回答所说,我猜这只是一个很好的软件工程实践,尽管我不确定TensorFlow是否会执行比使用常规CPython进行的main函数优化的运行。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题