如何在Tensorflow中仅使用Python制作自定义激活功能?
deep-learning
neural-network
python
tensorflow
6
0

假设您需要创建一个激活功能,而仅使用预定义的张量流构建块是不可能的,那该怎么办?

因此在Tensorflow中可以创建自己的激活功能。但这很复杂,您必须用C ++编写它并重新编译整个tensorflow [1] [2]

有没有更简单的方法?

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

为什么不简单地使用tensorflow中已经可用的功能来构建新功能呢?

对于答案中spiky函数,可能如下所示

def spiky(x):
    r = tf.floormod(x, tf.constant(1))
    cond = tf.less_equal(r, tf.constant(0.5))
    return tf.where(cond, r, tf.constant(0))

我认为这实际上要容易得多(甚至不需要计算任何梯度),除非您想做真正的奇特的事情,否则我几乎无法想象tensorflow不会为构建高度复杂的激活函数提供构建模块。

收藏
评论

就在这里!

信用:很难找到信息并使之正常工作,但这是从此处此处找到的原理和代码复制的示例。

要求:在我们开始之前,有两个要求才能成功。首先,您需要能够将激活作为函数写入numpy数组。其次,您必须能够将该函数的派生形式编写为Tensorflow中的函数(更简单),或者在最坏的情况下将其编写为numpy数组中的函数。

编写激活功能:

因此,让我们举一个我们想使用激活函数的函数为例:

def spiky(x):
    r = x % 1
    if r <= 0.5:
        return r
    else:
        return 0

外观如下: 尖峰激活

第一步是使其成为numpy函数,这很容易:

import numpy as np
np_spiky = np.vectorize(spiky)

现在我们应该写它的导数。

激活梯度:在我们的情况下很容易,如果x mod 1 <0.5则为1,否则为0。所以:

def d_spiky(x):
    r = x % 1
    if r <= 0.5:
        return 1
    else:
        return 0
np_d_spiky = np.vectorize(d_spiky)

现在是使用TensorFlow函数的困难部分。

将一个numpy fct转换为一个tensorflow fct:我们首先将np_d_spiky变成一个tensorflow函数。张量流tf.py_func(func, inp, Tout, stateful=stateful, name=name) [doc]中有一个函数,它将任何numpy函数转换为张量流函数,因此我们可以使用它:

import tensorflow as tf
from tensorflow.python.framework import ops

np_d_spiky_32 = lambda x: np_d_spiky(x).astype(np.float32)


def tf_d_spiky(x,name=None):
    with tf.name_scope(name, "d_spiky", [x]) as name:
        y = tf.py_func(np_d_spiky_32,
                        [x],
                        [tf.float32],
                        name=name,
                        stateful=False)
        return y[0]

tf.py_func作用于张量列表(并返回张量列表),这就是为什么我们有[x] (并返回y[0] )的原因。 stateful选项是告诉tensorflow函数是否总是为相同的输入提供相同的输出(有状态= False),在这种情况下,tensorflow可以只是tensorflow图,这是我们的情况,在大多数情况下可能是这种情况。此时要注意的一件事是numpy使用float64但tensorflow使用float32因此您需要先将函数转换为使用float32然后才能将其转换为tensorflow函数,否则tensorflow会抱怨。这就是为什么我们需要首先制作np_d_spiky_32

渐变呢?仅执行上述操作的问题是,即使我们现在有了tf_d_spiky的tensorflow版本的np_d_spiky ,但如果我们np_d_spiky ,我们也不能将其用作激活函数,因为tensorflow不知道如何计算该梯度功能。

破解以获取渐变:如上述tf.RegisterGradient所述,有一种破解方法是使用tf.RegisterGradient [doc]tf.Graph.gradient_override_map [doc]定义函数的渐变。从harpone复制代码,我们可以修改tf.py_func函数以使其同时定义渐变:

def py_func(func, inp, Tout, stateful=True, name=None, grad=None):

    # Need to generate a unique name to avoid duplicates:
    rnd_name = 'PyFuncGrad' + str(np.random.randint(0, 1E+8))

    tf.RegisterGradient(rnd_name)(grad)  # see _MySquareGrad for grad example
    g = tf.get_default_graph()
    with g.gradient_override_map({"PyFunc": rnd_name}):
        return tf.py_func(func, inp, Tout, stateful=stateful, name=name)

现在我们差不多完成了,唯一的事情是我们需要传递给上述py_func函数的grad函数需要采用一种特殊的形式。它需要接受一个操作,以及操作之前的先前梯度,并在操作之后向后传播这些梯度。

梯度函数:因此,对于我们尖刻的激活函数,我们将采用以下方法:

def spikygrad(op, grad):
    x = op.inputs[0]

    n_gr = tf_d_spiky(x)
    return grad * n_gr  

激活函数只有一个输入,这就是x = op.inputs[0] 。如果操作有很多输入,我们将需要返回一个元组,每个输入一个梯度。例如,如果该操作是ab相对于所述梯度a+1且相对于b-1 ,所以我们将具有return +1*grad,-1*grad 。请注意,我们需要返回输入的tensorflow函数,这就是为什么需要tf_d_spikynp_d_spiky不能起作用的原因,因为它无法作用于tensorflow张量。另外,我们可以使用张量流函数编写导数:

def spikygrad2(op, grad):
    x = op.inputs[0]
    r = tf.mod(x,1)
    n_gr = tf.to_float(tf.less_equal(r, 0.5))
    return grad * n_gr  

将所有内容组合在一起:现在我们拥有了所有片段,我们可以将它们全部组合在一起:

np_spiky_32 = lambda x: np_spiky(x).astype(np.float32)

def tf_spiky(x, name=None):

    with tf.name_scope(name, "spiky", [x]) as name:
        y = py_func(np_spiky_32,
                        [x],
                        [tf.float32],
                        name=name,
                        grad=spikygrad)  # <-- here's the call to the gradient
        return y[0]

现在我们完成了。我们可以对其进行测试。

测试:

with tf.Session() as sess:

    x = tf.constant([0.2,0.7,1.2,1.7])
    y = tf_spiky(x)
    tf.initialize_all_variables().run()

    print(x.eval(), y.eval(), tf.gradients(y, [x])[0].eval())

[0.2 0.69999999 1.20000005 1.70000005] [0.2 0. 0.20000005 0.] [1. 0. 1. 0.]

成功!

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号