如何计算张量流模型中可训练参数的总数?
neural-network
tensorflow
5
0

是否有函数调用或其他方法来计算张量流模型中的参数总数?

通过参数,我的意思是:可训练变量的N个暗矢量具有N个参数, NxM矩阵具有N*M参数,依此类推。所以本质上,我想对张量流会话中所有可训练变量的形状尺寸的乘积求和。

参考资料:
Stack Overflow
收藏
评论
共 6 个回答
高赞 时间 活跃

如果您正在考虑自己计算参数数量,那么这两个现有答案很好。如果您的问题更像是“是否有一种简便的方法来分析我的TensorFlow模型?”,我强烈建议您研究tfprof 。它可以分析您的模型,包括计算参数数量。

收藏
评论

我有一个更短的版本,使用numpy的一种解决方案:

np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
收藏
评论

如果希望避免使用numpy(许多项目可以忽略它),则:

all_trainable_vars = tf.reduce_sum([tf.reduce_prod(v.shape) for v in tf.trainable_variables()])

这是Julius Kunze先前回答的TF翻译。

与任何TF操作一样,它需要运行会话来评估:

print(sess.run(all_trainable_vars))
收藏
评论

我将介绍我的等效但较短的实现:

def count_params():
    "print number of trainable variables"
    size = lambda v: reduce(lambda x, y: x*y, v.get_shape().as_list())
    n = sum(size(v) for v in tf.trainable_variables())
    print "Model size: %dK" % (n/1000,)
收藏
评论

不知道给出的答案是否真的有效(我发现您需要将dim对象转换为int才能起作用)。这是一个可行的方法,您可以复制粘贴函数并调用它们(也添加了一些注释):

def count_number_trainable_params():
    '''
    Counts the number of trainable variables.
    '''
    tot_nb_params = 0
    for trainable_variable in tf.trainable_variables():
        shape = trainable_variable.get_shape() # e.g [D,F] or [W,H,C]
        current_nb_params = get_nb_params_shape(shape)
        tot_nb_params = tot_nb_params + current_nb_params
    return tot_nb_params

def get_nb_params_shape(shape):
    '''
    Computes the total number of params for a given shap.
    Works for any number of shapes etc [D,F] or [W,H,C] computes D*F and W*H*C.
    '''
    nb_params = 1
    for dim in shape:
        nb_params = nb_params*int(dim)
    return nb_params 
收藏
评论

tf.trainable_variables()循环遍历每个变量的形状。

total_parameters = 0
for variable in tf.trainable_variables():
    # shape is an array of tf.Dimension
    shape = variable.get_shape()
    print(shape)
    print(len(shape))
    variable_parameters = 1
    for dim in shape:
        print(dim)
        variable_parameters *= dim.value
    print(variable_parameters)
    total_parameters += variable_parameters
print(total_parameters)

更新:由于这个答案,我写了一篇文章来澄清Tensorflow中的动态/静态形状: https ://pgaleone.eu/tensorflow/2018/07/28/understanding-tensorflow-tensors-shape-static-dynamic/

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号