是否可以使可训练变量不可训练?
tensorflow
8
0

我在范围内创建了一个可训练的变量 。后来,我输入了相同的作用域,将作用域设置为reuse_variables ,并使用get_variable检索相同的变量。但是,我不能将变量的可训练属性设置为False 。我的get_variable行是这样的:

weight_var = tf.get_variable('weights', trainable = False)

但是变量'weights'仍然在tf.trainable_variables的输出中。

我可以使用get_variable将共享变量的trainable标志设置为False吗?

我要这样做的原因是,我试图重用模型中从VGG net预训练的低级过滤器,并且希望像以前一样构建图,检索权重变量,并分配VGG过滤器值设置为权重变量,然后在接下来的训练步骤中将其固定。

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

看文档和代码后,我能找到一种方法,从删除变量TRAINABLE_VARIABLES

这是发生了什么:

  • 第一次tf.get_variable('weights', trainable=True) ,该变量将添加到TRAINABLE_VARIABLES列表中。
  • 第二次调用tf.get_variable('weights', trainable=False) ,您将获得相同的变量,但参数trainable=False无效,因为该变量已存在于TRAINABLE_VARIABLES列表中(并且无法从那里删除它

第一个解决方案

当调用minimize的优化(见的方法文档。 ),你可以通过一个var_list=[...]与变量参数要优化。

例如,如果要冻结除最后两层之外的VGG的所有层,可以在var_list传递最后两层的权重。

第二解决方案

您可以使用tf.train.Saver()保存变量并稍后将其还原(请参阅本教程 )。

  • 首先,使用所有可训练变量训练整个VGG模型。您可以通过调用saver.save(sess, "/path/to/dir/model.ckpt")将它们保存在检查点文件中。
  • 然后(在另一个文件中)使用不可训练的变量训练第二个版本。您加载先前存储在saver.restore(sess, "/path/to/dir/model.ckpt")的变量。

(可选)您可以决定仅将某些变量保存在检查点文件中。有关更多信息,请参阅文档

收藏
评论

当您只想训练或优化预训练网络的某些层时,这就是您需要知道的。

TensorFlow的minimize方法采用可选参数var_list ,这是通过反向传播调整的变量列表。

如果不指定var_list ,那么优化器可以调整图中的任何TF变量。当您在var_list指定某些变量时,TF var_list所有其他变量保持不变。

这是jonbruner及其合作者使用的脚本示例。

tvars = tf.trainable_variables()
g_vars = [var for var in tvars if 'g_' in var.name]
g_trainer = tf.train.AdamOptimizer(0.0001).minimize(g_loss, var_list=g_vars)

这将找到他们先前定义的所有变量名称中都带有“ g_”的变量,将它们放入列表中,然后对它们运行ADAM优化器。

您可以在Quora上找到相关的答案

收藏
评论

为了消除从训练的变量列表中的变量,可以先通过访问集合: trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)那里, trainable_collection包含可训练变量的集合的引用。如果从此列表中弹出元素,例如进行trainable_collection.pop(0) ,则将从可训练变量中删除相应的变量,因此将不训练该变量。

尽管此方法适用于pop ,但我仍在努力寻找正确使用带有正确参数的remove的方法,因此我们不依赖于变量的索引。

编辑:鉴于您在图中具有变量的名称(您可以通过检查图protobuf或使用Tensorboard来更轻松地获得它),可以使用它遍历可训练变量列表,然后删除可训练集合中的变量。例如:说我想和名称的变量"batch_normalization/gamma:0""batch_normalization/beta:0" 进行训练,但他们已经加入到TRAINABLE_VARIABLES集合。我能做的是:

#gets a reference to the list containing the trainable variables
trainable_collection = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
variables_to_remove = list()
for vari in trainable_collection:
    #uses the attribute 'name' of the variable
    if vari.name=="batch_normalization/gamma:0" or vari.name=="batch_normalization/beta:0":
        variables_to_remove.append(vari)
for rem in variables_to_remove:
    trainable_collection.remove(rem)

`这样将成功从集合中删除这两个变量,并且不再对它们进行训练。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号