在TensorFlow训练的模型中获取某些权重的值
tensorflow
6
0

我已经使用TensorFlow训练了ConvNet模型,并且希望在图层中获得特定的权重。例如,在torch7中,我只需访问model.modules[2].weights 。获得第2层的权重。如何在TensorFlow中做同样的事情?

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

因此,如果逐步执行此代码,您将首先获得已使用/可训练变量的列表。然后,您可以将它们排序在一个列表中,在其中将权重矩阵/列表排序为变量名,例如,如何处理该信息。

vars = tf.trainable_variables()
print(vars) #some infos about variables...
vars_vals = sess.run(vars)
for var, val in zip(vars, vars_vals):
    print("var: {}, value: {}".format(var.name, val)) #...or sort it in a list....
收藏
评论

2.0兼容答案 :如果我们使用Keras Sequential API构建模型,则可以使用以下代码获取模型的权重:

!pip install tensorflow==2.1

from tf.keras import Sequential

model = Sequential()

model.add(Conv2D(filters=conv1_fmaps, kernel_size=conv1_ksize,
                         strides=conv1_stride, padding=conv1_pad,
                         activation=tf.nn.relu, input_shape=(height, width, channels),
                    data_format='channels_last'))

model.add(MaxPool2D(pool_size = (2,2), strides= (2,2), padding="VALID"))

model.add(Dropout(0.25))

model.add(Flatten())

model.add(Dense(units = 32, activation = 'relu'))

model.add(Dense(units = 10, activation = 'softmax'))

model.summary()

print(model.trainable_variables) 

最后print(model.trainable_variables)语句print(model.trainable_variables)将返回模型的权重,如下所示:

    [<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 32) dtype=float32>,
 <tf.Variable 'conv2d/bias:0' shape=(32,) dtype=float32>, <tf.Variable 
'dense/kernel:0' shape=(6272, 32) dtype=float32>, <tf.Variable 'dense/bias:0' 
shape=(32,) dtype=float32>, <tf.Variable 'dense_1/kernel:0' shape=(32, 10) 
dtype=float32>, <tf.Variable 'dense_1/bias:0' shape=(10,) dtype=float32>]
收藏
评论

在TensorFlow中,训练后的权重由tf.Variable对象表示。如果您自己创建了tf.Variable称为v ,则可以通过调用sess.run(v) (其中sesstf.Session )来将其作为NumPy数组tf.Session

如果当前没有指向tf.Variable的指针,则可以通过调用tf.trainable_variables()获得当前图中可训练变量的列表。此函数返回当前图形中所有可训练tf.Variable对象的列表,并且您可以通过匹配v.name属性来选择所需的v.name 。例如:

# Desired variable is called "tower_2/filter:0".
var = [v for v in tf.trainable_variables() if v.name == "tower_2/filter:0"][0]
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号