Обратное распространение через несколько прямых проходов - PullRequest
1 голос
/ 03 августа 2020

В обычном бэкпропе мы один раз продвигаемся вперед, вычисляем градиенты, а затем применяем их для обновления весов. Но предположим, что мы wi sh для прямого продвижения дважды и обратного распространения через оба , и применяем градиенты только тогда (сначала пропустите).

Предположим следующее :

x = tf.Variable([2.])
w = tf.Variable([4.])

with tf.GradientTape(persistent=True) as tape:
    w.assign(w * x)
    y = w * w  # w^2 * x
print(tape.gradient(y, x))  # >>None

Из docs , tf.Variable - это объект с сохранением состояния , который блокирует градиенты , а его веса составляют tf.Variable s.

Примеры - дифференцируемое пристальное внимание (в отличие от RL) или просто передача скрытого состояния между слоями в последующих проходах вперед, как на диаграмме ниже. Ни TF, ни Keras не имеют поддержки API-уровня для градиентов с сохранением состояния, включая RNN s, которые сохраняют только тензор состояния с сохранением состояния; градиент не течет за пределы одной партии.

Как это можно сделать?

image

1 Ответ

1 голос
/ 03 августа 2020

Нам нужно будет тщательно применить tf.while_loop; from help(TensorArray):

Этот класс предназначен для использования с итерационными примитивами Dynami c, такими как while_loop и map_fn. Он поддерживает обратное распространение градиента через специальные зависимости потока управления.

Таким образом, мы стремимся записать al oop так, чтобы все выходные данные, через которые мы должны распространяться, записывались в TensorArray. Код, выполняющий это, и его высокоуровневое описание ниже. Внизу проверяющий пример.


Описание :

  • Код заимствован из K.rnn, переписан для простоты и релевантность
  • Для лучшего понимания я предлагаю проверить K.rnn, SimpleRNNCell.call и RNN.call.
  • model_rnn имеет несколько лишних проверок ради случая 3; свяжет более чистую версию
  • Идея следующая: мы проходим по сети сначала снизу вверх, затем слева направо и записываем все вперед перейти на одиночный TensorArray под одиночный tf.while_loop; это гарантирует, что TF кэширует тензорные операции на всем протяжении для обратного распространения.

from tensorflow.python.util import nest
from tensorflow.python.ops import array_ops, tensor_array_ops
from tensorflow.python.framework import ops


def model_rnn(model, inputs, states=None, swap_batch_timestep=True):
    def step_function(inputs, states):
        out = model([inputs, *states], training=True)
        output, new_states = (out if isinstance(out, (tuple, list)) else
                              (out, states))
        return output, new_states

    def _swap_batch_timestep(input_t):
        # (samples, timesteps, channels) -> (timesteps, samples, channels)
        # iterating dim0 to feed (samples, channels) slices expected by RNN
        axes = list(range(len(input_t.shape)))
        axes[0], axes[1] = 1, 0
        return array_ops.transpose(input_t, axes)

    if swap_batch_timestep:
        inputs = nest.map_structure(_swap_batch_timestep, inputs)

    if states is None:
        states = (tf.zeros(model.inputs[0].shape, dtype='float32'),)
    initial_states = states
    input_ta, output_ta, time, time_steps_t = _process_args(model, inputs)

    def _step(time, output_ta_t, *states):
        current_input = input_ta.read(time)
        output, new_states = step_function(current_input, tuple(states))

        flat_state = nest.flatten(states)
        flat_new_state = nest.flatten(new_states)
        for state, new_state in zip(flat_state, flat_new_state):
            if isinstance(new_state, ops.Tensor):
                new_state.set_shape(state.shape)

        output_ta_t = output_ta_t.write(time, output)
        new_states = nest.pack_sequence_as(initial_states, flat_new_state)
        return (time + 1, output_ta_t) + tuple(new_states)

    final_outputs = tf.while_loop(
        body=_step,
        loop_vars=(time, output_ta) + tuple(initial_states),
        cond=lambda time, *_: tf.math.less(time, time_steps_t))

    new_states = final_outputs[2:]
    output_ta = final_outputs[1]
    outputs = output_ta.stack()
    return outputs, new_states


def _process_args(model, inputs):
    time_steps_t = tf.constant(inputs.shape[0], dtype='int32')

    # assume single-input network (excluding states)
    input_ta = tensor_array_ops.TensorArray(
        dtype=inputs.dtype,
        size=time_steps_t,
        tensor_array_name='input_ta_0').unstack(inputs)

    # assume single-input network (excluding states)
    # if having states, infer info from non-state nodes
    output_ta = tensor_array_ops.TensorArray(
        dtype=model.outputs[0].dtype,
        size=time_steps_t,
        element_shape=model.outputs[0].shape,
        tensor_array_name='output_ta_0')

    time = tf.constant(0, dtype='int32', name='time')
    return input_ta, output_ta, time, time_steps_t

Примеры и проверка :

один и тот же ввод дважды, что позволяет проводить определенные сравнения с сохранением состояния и без него; результаты также сохраняются для разных входных данных.

  • Случай 0 : control; другие случаи должны соответствовать этому.
  • Случай 1 : сбой; градиенты не совпадают, хотя выходы и потери совпадают. Backprop терпит неудачу при загрузке половинной последовательности.
  • Случай 2 : градиенты соответствуют случаю 1. Может показаться, что мы использовали только один tf.while_loop, но SimpleRNN использует один из своих для 3 временных шага и записывает в TensorArray, который отбрасывается; это не пойдет. Обходной путь - реализовать SimpleRNN logi c самостоятельно.
  • Случай 3 : идеальное совпадение.

Обратите внимание, что не существует такой вещи, как RNN с отслеживанием состояния ячейка; Statefulness реализован в базовом классе RNN, и мы воссоздали его в model_rnn. То же самое и с любым другим слоем - подача одного шага среза за раз для каждого прямого прохода.

import random
import numpy as np
import tensorflow as tf

from tensorflow.keras.layers import Input, SimpleRNN, SimpleRNNCell
from tensorflow.keras.models import Model

def reset_seeds():
    random.seed(0)
    np.random.seed(1)
    tf.compat.v1.set_random_seed(2)  # graph-level seed
    tf.random.set_seed(3)  # global seed

def print_report(case, model, outs, loss, tape, idx=1):
    print("\nCASE #%s" % case)
    print("LOSS", loss)
    print("GRADS:\n", tape.gradient(loss, model.layers[idx].weights[0]))
    print("OUTS:\n", outs)


#%%# Make data ###############################################################
reset_seeds()
x0 = y0 = tf.constant(np.random.randn(2, 3, 4))
x0_2 = y0_2 = tf.concat([x0, x0], axis=1)
x00  = y00  = tf.stack([x0, x0], axis=0)

#%%# Case 0: Complete forward pass; control case #############################
reset_seeds()
ipt = Input(batch_shape=(2, 6, 4))
out = SimpleRNN(4, return_sequences=True)(ipt)
model0 = Model(ipt, out)
model0.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs = model0(x0_2, training=True)
    loss = model0.compiled_loss(y0_2, outs)
print_report(0, model0, outs, loss, tape)

#%%# Case 1: Two passes, stateful RNN, direct feeding ########################
reset_seeds()
ipt = Input(batch_shape=(2, 3, 4))
out = SimpleRNN(4, return_sequences=True, stateful=True)(ipt)
model1 = Model(ipt, out)
model1.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs0 = model1(x0, training=True)
    tape.watch(outs0)  # cannot even diff otherwise
    outs1 = model1(x0, training=True)
    tape.watch(outs1)
    outs = tf.concat([outs0, outs1], axis=1)
    tape.watch(outs)
    loss = model1.compiled_loss(y0_2, outs)
print_report(1, model1, outs, loss, tape)

#%%# Case 2: Two passes, stateful RNN, model_rnn #############################
reset_seeds()
ipt = Input(batch_shape=(2, 3, 4))
out = SimpleRNN(4, return_sequences=True, stateful=True)(ipt)
model2 = Model(ipt, out)
model2.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs, _ = model_rnn(model2, x00, swap_batch_timestep=False)
    outs = tf.concat(list(outs), axis=1)
    loss = model2.compiled_loss(y0_2, outs)
print_report(2, model2, outs, loss, tape)

#%%# Case 3: Single pass, stateless RNN, model_rnn ###########################
reset_seeds()
ipt  = Input(batch_shape=(2, 4))
sipt = Input(batch_shape=(2, 4))
out, state = SimpleRNNCell(4)(ipt, sipt)
model3 = Model([ipt, sipt], [out, state])
model3.compile('sgd', 'mse')
#%%#############################################################
with tf.GradientTape(persistent=True) as tape:
    outs, _ = model_rnn(model3, x0_2)
    outs = tf.transpose(outs, (1, 0, 2))
    loss = model3.compiled_loss(y0_2, outs)
print_report(3, model3, outs, loss, tape, idx=2)

Вертикальный поток : мы проверили горизонтальность , по времени - обратное распространение; как насчет вертикальной?

Для этого мы реализуем стековую RNN с отслеживанием состояния; результаты ниже. Все выходные данные на моей машине, здесь .

Настоящим мы проверили как вертикальный , так и горизонтальный обратное распространение с сохранением состояния. Это может быть использовано для реализации произвольно сложной логики прямой передачи c с правильной обратной связью. Прикладной пример здесь .

#%%# Case 4: Complete forward pass; control case ############################
reset_seeds()
ipt = Input(batch_shape=(2, 6, 4))
x   = SimpleRNN(4, return_sequences=True)(ipt)
out = SimpleRNN(4, return_sequences=True)(x)
model4 = Model(ipt, out)
model4.compile('sgd', 'mse')
#%%
with tf.GradientTape(persistent=True) as tape:
    outs = model4(x0_2, training=True)
    loss = model4.compiled_loss(y0_2, outs)
print("=" * 80)
print_report(4, model4, outs, loss, tape, idx=1)
print_report(4, model4, outs, loss, tape, idx=2)

#%%# Case 5: Two passes, stateless RNN; model_rnn ############################
reset_seeds()
ipt = Input(batch_shape=(2, 6, 4))
out = SimpleRNN(4, return_sequences=True)(ipt)
model5a = Model(ipt, out)
model5a.compile('sgd', 'mse')

ipt  = Input(batch_shape=(2, 4))
sipt = Input(batch_shape=(2, 4))
out, state = SimpleRNNCell(4)(ipt, sipt)
model5b = Model([ipt, sipt], [out, state])
model5b.compile('sgd', 'mse')
#%%
with tf.GradientTape(persistent=True) as tape:
    outs = model5a(x0_2, training=True)
    outs, _ = model_rnn(model5b, outs)
    outs = tf.transpose(outs, (1, 0, 2))
    loss = model5a.compiled_loss(y0_2, outs)
print_report(5, model5a, outs, loss, tape)
print_report(5, model5b, outs, loss, tape, idx=2)
...