當前位置 主頁 > 服務器問題 > Linux/apache問題 > 最大化 縮小

    基于TensorFlow中自定義梯度的2種方式

    欄目:Linux/apache問題 時間:2020-02-06 08:32

    前言

    在深度學習中,有時候我們需要對某些節點的梯度進行一些定制,特別是該節點操作不可導(比如階梯除法如 ),如果實在需要對這個節點進行操作,而且希望其可以反向傳播,那么就需要對其進行自定義反向傳播時的梯度。在有些場景,如[2]中介紹到的梯度反轉(gradient inverse)中,就必須在某層節點對反向傳播的梯度進行反轉,也就是需要更改正常的梯度傳播過程,如下圖的 所示。

    在tensorflow中有若干可以實現定制梯度的方法,這里介紹兩種。

    1. 重寫梯度法

    重寫梯度法指的是通過tensorflow自帶的機制,將某個節點的梯度重寫(override),這種方法的適用性最廣。我們這里舉個例子[3].

    符號函數的前向傳播采用的是階躍函數y=sign(x) y = \rm{sign}(x)y=sign(x),如下圖所示,我們知道階躍函數不是連續可導的,因此我們在反向傳播時,將其替代為一個可以連續求導的函數y=Htanh(x) y = \rm{Htanh(x)}y=Htanh(x),于是梯度就是大于1和小于-1時為0,在-1和1之間時是1。

    使用重寫梯度的方法如下,主要是涉及到tf.RegisterGradient()和tf.get_default_graph().gradient_override_map(),前者注冊新的梯度,后者重寫圖中具有名字name='Sign'的操作節點的梯度,用在新注冊的QuantizeGrad替代。

    #使用修飾器,建立梯度反向傳播函數。其中op.input包含輸入值、輸出值,grad包含上層傳來的梯度
    @tf.RegisterGradient("QuantizeGrad")
    def sign_grad(op, grad):
     input = op.inputs[0] # 取出當前的輸入
     cond = (input>=-1)&(input<=1) # 大于1或者小于-1的值的位置
     zeros = tf.zeros_like(grad) # 定義出0矩陣用于掩膜
     return tf.where(cond, grad, zeros) 
     # 將大于1或者小于-1的上一層的梯度置為0
     
    #使用with上下文管理器覆蓋原始的sign梯度函數
    def binary(input):
     x = input
     with tf.get_default_graph().gradient_override_map({"Sign":'QuantizeGrad'}):
     #重寫梯度
      x = tf.sign(x)
     return x
     
    #使用
    x = binary(x)

    其中的def sign_grad(op, grad):是注冊新的梯度的套路,其中的op是當前操作的輸入值/張量等,而grad指的是從反向而言的上一層的梯度。

    通常來說,在tensorflow中自定義梯度,函數tf.identity()是很重要的,其API手冊如下:

    tf.identity(
     input,
     name=None
    )

    其會返回一個形狀和內容都和輸入完全一樣的輸出,但是你可以自定義其反向傳播時的梯度,因此在梯度反轉等操作中特別有用。

    這里再舉個反向梯度[2]的例子,也就是梯度為 而不是

    import tensorflow as tf
    x1 = tf.Variable(1)
    x2 = tf.Variable(3)
    x3 = tf.Variable(6)
    @tf.RegisterGradient('CustomGrad')
    def CustomGrad(op, grad):
    #  tf.Print(grad)
     return -grad
     
    g = tf.get_default_graph()
    oo = x1+x2
    with g.gradient_override_map({"Identity": "CustomGrad"}):
     output = tf.identity(oo)
    grad_1 = tf.gradients(output, oo)
    with tf.Session() as sess:
     sess.run(tf.global_variables_initializer())
     print(sess.run(grad_1))

    因為-grad,所以這里的梯度輸出是[-1]而不是[1]。有一個我們需要注意的是,在自定義函數def CustomGrad()中,返回的值得是一個張量,而不能返回一個參數,比如return 0,這樣會報錯,如:

    下一篇:沒有了
教我怎样炒股