Обратное распространение после выбора подмножества элементов в последовательности - PullRequest
0 голосов
/ 26 марта 2020

У меня есть серия последовательностей размера (B, N, D), где B - количество последовательностей в пакете, N - количество элементов в последовательности, а D - размерность каждого элемента. Я хочу обратить внимание на элементы в каждой последовательности, чтобы найти два элемента в последовательности, для которых промежуточные элементы в среднем объединяются и распространяются на последующий слой моей нейронной сети. Конечно, я хотел бы узнать векторы внимания, поэтому мне интересно, возможно ли то, что я предлагаю, используя TF 2.0 таким образом, чтобы разрешить обратное распространение?

Минимальный рабочий пример:

B = 3  # batch size (number of sequences)
N = 10  # elements in sequence
D = 4  # dimension of each element

with tf.GradientTape() as tape:
    data = tf.Variable(initial_value=tf.random_normal_initializer(mean=0.0, stddev=0.05, seed=None)(shape=(B,N,D)))

    att1 = tf.Variable(initial_value=tf.random_normal_initializer(mean=0.0, stddev=0.05, seed=None)(shape=(D,)))
    att2 = tf.Variable(initial_value=tf.random_normal_initializer(mean=0.0, stddev=0.05, seed=None)(shape=(D,)))

    soft_start = tf.nn.softmax(tf.tensordot(data, att1, axes=[2, 0]), axis=1)
    start = tf.argmax(soft_start, axis=1)

    # To keep the example minimal, I ignore computing "soft_end" here.
    # In practice, I will mask the elements in "data" to the left of "start" and perform a similar softmax 
    # to compute "soft_end" using "att2" and only elements in "data" including and after the one identified 
    # by "start".
    end = tf.minimum(start + 3, N)

    result = []
    for i in range(B):
        result.append(tf.reduce_mean(tf.gather(data[i], tf.range(start[i], end[i])), axis=0))
    result = tf.stack(results)

    grads = tape.gradient(results, (att1, att2, data))
    print([g is not None for g in grads])  # [False, False, True]

После запуска этого кода я бы хотел, чтобы градиент result относительно att1 не был None, поэтому я могу обновить att1 с помощью обратного распространения.

...