TensorFlow:通过名称获取变量
tensorflow
7
0

使用TensorFlow Python API时,我创建了一个变量(未在构造函数中指定其name ),并且其name属性的值为"Variable_23:0" 。当我尝试使用tf.get_variable("Variable23")选择此变量时,将tf.get_variable("Variable23")一个名为"Variable_23_1:0"的新变量。如何正确选择"Variable_23"而不是创建一个新的?

我要做的是按名称选择变量,然后重新初始化它,以便我可以调整权重。

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

通过名称获取变量的最简单方法是在tf.global_variables()集合中进行搜索:

var_23 = [v for v in tf.global_variables() if v.name == "Variable_23:0"][0]

这对于临时重用现有变量非常有效。 共享变量教程中介绍了一种更结构化的方法(当您要在模型的多个部分之间共享变量时)

收藏
评论

get_variable()函数创建一个新变量或返回由get_variable()先前创建的变量。它不会返回使用tf.Variable()创建的变量。这是一个简单的示例:

>>> with tf.variable_scope("foo"):
...   bar1 = tf.get_variable("bar", (2,3)) # create
... 
>>> with tf.variable_scope("foo", reuse=True):
...   bar2 = tf.get_variable("bar")  # reuse
... 

>>> with tf.variable_scope("", reuse=True): # root variable scope
...   bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
... 
>>> (bar1 is bar2) and (bar2 is bar3)
True

如果未使用tf.get_variable()创建变量,则有两种选择。首先,您可以使用tf.global_variables() (如@mrry所示):

>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = [var for var in tf.global_variables() if var.op.name=="bar"][0]
>>> bar1 is bar2
True

或者,您可以像这样使用tf.get_collection()

>>> bar1 = tf.Variable(0.0, name="bar")
>>> bar2 = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="bar")[0]
>>> bar1 is bar2
True

编辑

您也可以使用get_tensor_by_name()

>>> bar1 = tf.Variable(0.0, name="bar")
>>> graph = tf.get_default_graph()
>>> bar2 = graph.get_tensor_by_name("bar:0")
>>> bar1 is bar2
False, bar2 is a Tensor througn convert_to_tensor on bar1. but bar1 equal 
bar2 in value.

回想张量是操作的输出。它与操作名称相同,外加:0 。如果该操作具有多个输出,则它们与该操作具有相同的名称,外加:0:1:2等。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号