Tensorflow变量范围:如果变量存在则重用
python
tensorflow
6
0

我想要一段代码,如果不存在则在作用域内创建变量,如果已经存在则访问变量。我需要它是相同的代码,因为它将被多次调用。

但是,Tensorflow需要我指定是要创建还是重用该变量,如下所示:

with tf.variable_scope("foo"): #create the first time
    v = tf.get_variable("v", [1])

with tf.variable_scope("foo", reuse=True): #reuse the second time
    v = tf.get_variable("v", [1])

我怎样才能弄清楚是自动创建还是重用它?即,我希望以上两个代码块相同,并运行程序。

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

创建新变量且未声明形状或在变量创建过程中违反重用条件时, get_variable()引发ValueError 。因此,您可以尝试以下操作:

def get_scope_variable(scope_name, var, shape=None):
    with tf.variable_scope(scope_name) as scope:
        try:
            v = tf.get_variable(var, shape)
        except ValueError:
            scope.reuse_variables()
            v = tf.get_variable(var)
    return v

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v')
assert v1 == v2

请注意以下内容也适用:

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 == v2

更新。新的API现在支持自动重用:

def get_scope_variable(scope, var, shape=None):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        v = tf.get_variable(var, shape)
    return v
收藏
评论

新的AUTO_REUSE选项可以解决问题。

tf.variable_scope API文档 :如果reuse=tf.AUTO_REUSE ,我们将创建变量(如果变量不存在),否则将其返回。

共享变量AUTO_REUSE的基本示例:

def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2
收藏
评论

尽管使用“ try ... except ...”子句是可行的,但我认为更优雅和可维护的方法是将变量初始化过程与“重用”过程分开。

def initialize_variable(scope_name, var_name, shape):
    with tf.variable_scope(scope_name) as scope:
        v = tf.get_variable(var_name, shape)
        scope.reuse_variable()

def get_scope_variable(scope_name, var_name):
    with tf.variable_scope(scope_name, reuse=True):
        v = tf.get_variable(var_name)
    return v

由于通常我们只需要初始化变量变量,但是要多次重复使用/共享变量,将两个进程分开会使代码更简洁。同样,通过这种方式,我们不需要每次都通过“ try”子句来检查变量是否已经创建。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号