Tensorflow获取范围内的所有变量
tensorflow
5
0

我在这样的一定范围内创建了一些变量:

with tf.variable_scope("my_scope"):
  createSomeVariables()
  ...

然后,我想获取“ my_scope”中所有变量的列表,以便将其传递给优化器。什么是正确的方法?

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

用户正确指出您需要tf.get_collection() 。我将给出一个简单的示例如何执行此操作:

import tensorflow as tf

with tf.name_scope('some_scope1'):
    a = tf.Variable(1, 'a')
    b = tf.Variable(2, 'b')
    c = tf.Variable(3, 'c')

with tf.name_scope('some_scope2'):
    d = tf.Variable(4, 'd')
    e = tf.Variable(5, 'e')
    f = tf.Variable(6, 'f')

h = tf.Variable(8, 'h')

for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='some_scope'):
    print i   # i.name if you want just a name

请注意,您可以提供任何graphKeys,并且scope是正则表达式:

scope:(可选。)如果提供,则将结果列表过滤为仅包含其名称属性使用re.match匹配的项目。如果提供了范围,则没有名称属性的项目将永远不会返回,而choice或re.match表示没有特殊标记的范围将按前缀过滤。

因此,如果您传递“ some_scope”,则将获得6个变量。

收藏
评论

我认为你想要tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,scope ='my_scope') 。这将获得作用域中的所有变量。

要传递给优化器,您不需要所有变量,而只需要可训练的变量。它们也保存在默认集合tf.GraphKeys.TRAINABLE_VARIABLES

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号