“冻结”张量流中的一些变量/作用域:stop_gradient与传递变量以最小化
python
tensorflow
7
0

我正在尝试实现Adversarial NN ,它需要在交替训练小批处理中“冻结”图形的一个或另一部分。即有两个子网:G和D。

G( Z ) ->  Xz
D( X ) ->  Y

G损失函数取决于D[G(Z)], D[X]

首先,我需要在固定所有G参数的情况下训练D中的参数,然后在固定D的参数中训练G中的参数。在第一种情况下的损耗函数将在第二种情况下为负损耗函数,并且更新将必须应用于第一子网或第二子网的参数。

我看到tensorflow具有tf.stop_gradient函数。为了训练D(下游)子网,我可以使用此功能来阻止梯度流向

 Z -> [ G ] -> tf.stop_gradient(Xz) -> [ D ] -> Y

tf.stop_gradient注释非常简洁,没有任何内联示例(示例seq2seq.py太长且不易阅读),但看起来必须在创建图形时调用它。 这是否意味着如果我想以交替批方式阻止/取消阻止梯度流,则需要重新创建并重新初始化图模型?

而且似乎还无法通过tf.stop_gradient阻止通过G(上游)网络流动的梯度,对吗?

作为替代方案,我看到可以将变量列表传递给opt_op = opt.minimize(cost, <list of variables>) ,如果可以获取每个变量范围内的所有变量,这将是一个简单的解决方案子网。 可以为tf.scope获取<list of variables>吗?

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

@mrry的答案是完全正确的,而且可能比我要建议的更笼统。但是我认为更简单的方法是直接将python引用直接传递给var_list

W = tf.Variable(...)
C = tf.Variable(...)
Y_est = tf.matmul(W,C)
loss = tf.reduce_sum((data-Y_est)**2)
optimizer = tf.train.AdamOptimizer(0.001)

# You can pass the python object directly
train_W = optimizer.minimize(loss, var_list=[W])
train_C = optimizer.minimize(loss, var_list=[C])

我在这里有一个独立的示例: https : //gist.github.com/ahwillia/8cedc710352eb919b684d8848bc2df3a

收藏
评论

您可能要考虑的另一种选择是可以在变量上设置trainable = False。这意味着不会通过培训对其进行修改。

tf.Variable(my_weights, trainable=False)
收藏
评论

正如您在问题中提到的,实现此目标的最简单方法是使用对opt.minimize(cost, ...)单独调用来创建两个优化程序操作。默认情况下,优化器将使用tf.trainable_variables()所有变量。如果要将变量过滤到特定范围,则可以对tf.get_collection()使用可选的scope参数,如下所示:

optimizer = tf.train.AdagradOptimzer(0.01)

first_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                     "scope/prefix/for/first/vars")
first_train_op = optimizer.minimize(cost, var_list=first_train_vars)

second_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                      "scope/prefix/for/second/vars")                     
second_train_op = optimizer.minimize(cost, var_list=second_train_vars)
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号