Как изменить весовой формат LSTMCell с тензорного потока на tf.keras - PullRequest
1 голос
/ 11 июня 2019

У меня есть какой-то старый код из тензорного потока, который я хочу заставить работать для тензорного потока 2 / tf.keras. Я хотел бы сохранить тот же вес LSTM, но не могу понять, как преобразовать формат.

У меня есть старые веса, сохраненные в файле контрольных точек, а также они сохранены в CSV-файлах.

Мой старый код выглядит примерно так:

input_placeholder = tf.placeholder(tf.float32, [None, None, input_units])
lstm_layers = [tf.nn.rnn_cell.LSTMCell(layer_size), tf.nn.rnn_cell.LSTMCell(layer_size)]
stacked = tf.contrib.rnn.MultiRNNCell(lstm_layers)
features, state = tf.nn.dynamic_rnn(stacked, input_placeholder, dtype=tf.float32)

А мой новый код выглядит примерно так:

input_placeholder = tf.placeholder(tf.float32, [None, None, input_units])
lstm_layers = [tf.keras.layers.LSTMCell(layer_size),tf.keras.layers.LSTMCell(layer_size)]
stacked = tf.keras.layers.StackedRNNCells(lstm_layers)
features = stacked(input_placeholder)
... #later in the code
features.set_weights(previous_weights)

Старый уклон, кажется, соответствует новому уклону. Старое ядро ​​выглядит как соединение ядра и рекуррентного ядра. Я могу загрузить предыдущие_весы в модель (явно проверил правильно загруженные веса), однако тесты, которые мне не удалось получить, дают тот же результат. Копаясь в исходном коде, ядра, похоже, имеют другой формат.

Можно ли рассчитать ядро ​​и recurrent_kernel (tf.keras), используя эти старые сохраненные веса ядра?

Ссылки, если они полезны:

https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/rnn_cell_impl.py

https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/keras/layers/recurrent.py

1 Ответ

1 голос
/ 11 июня 2019

Вы можете разделить матрицу:

Если вы видите здесь , матрица ядра TF1 имеет форму (input_shape[-1], self.units).

Допустим, у вас есть 20 входов и 128 узлов в слое LSTM

input_units=20
layer_size = 128
input_placeholder = tf.placeholder(tf.float32, [None, None, input_units])
lstm_layers = [tf.nn.rnn_cell.LSTMCell(layer_size),     tf.nn.rnn_cell.LSTMCell(layer_size)]
stacked = tf.contrib.rnn.MultiRNNCell(lstm_layers)
output, state = tf.nn.dynamic_rnn(stacked, input_placeholder, dtype=tf.float32)

Ваши обучаемые параметры будут иметь следующие формы:

[<tf.Variable 'rnn/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(148, 512) dtype=float32_ref>,
 <tf.Variable 'rnn/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(512,) dtype=float32_ref>,
 <tf.Variable 'rnn/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(256, 512) dtype=float32_ref>,
 <tf.Variable 'rnn/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(512,) dtype=float32_ref>]

В TF 1.0, ядрои рекуррентное ядро ​​TF 2.0 сцеплено (см. здесь )

def build(self, input_shape):
    self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                              initializer='uniform',
                              name='kernel')
   self.recurrent_kernel = self.add_weight(
    shape=(self.units, self.units),
    initializer='uniform',
    name='recurrent_kernel')
    self.built = True

В этой новой версии у вас теперь есть две разные весовые матрицы.

input_placeholder = tf.placeholder(tf.float32, [None, None, input_units])
lstm_layers = [tf.keras.layers.LSTMCell(layer_size),tf.keras.layers.LSTMCell(layer_size)]
stacked = tf.keras.layers.StackedRNNCells(lstm_layers)
output = tf.keras.layers.RNN(stacked, return_sequences=True, return_state=True, dtype=tf.float32)

Таким образом,ваши обучаемые параметры:

<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/kernel:0' shape=(20, 512) dtype=float32>,
<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/recurrent_kernel:0' shape=(128, 512) dtype=float32>,
<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/bias:0' shape=(512,) dtype=float32>,
<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/kernel_1:0' shape=(128, 512) dtype=float32>,
<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/recurrent_kernel_1:0' shape=(128, 512) dtype=float32>,
<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/bias_1:0' shape=(512,) dtype=float32>]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...