Вы можете разделить матрицу:
Если вы видите здесь , матрица ядра TF1 имеет форму (input_shape[-1], self.units)
.
Допустим, у вас есть 20 входов и 128 узлов в слое LSTM
input_units=20
layer_size = 128
input_placeholder = tf.placeholder(tf.float32, [None, None, input_units])
lstm_layers = [tf.nn.rnn_cell.LSTMCell(layer_size), tf.nn.rnn_cell.LSTMCell(layer_size)]
stacked = tf.contrib.rnn.MultiRNNCell(lstm_layers)
output, state = tf.nn.dynamic_rnn(stacked, input_placeholder, dtype=tf.float32)
Ваши обучаемые параметры будут иметь следующие формы:
[<tf.Variable 'rnn/multi_rnn_cell/cell_0/lstm_cell/kernel:0' shape=(148, 512) dtype=float32_ref>,
<tf.Variable 'rnn/multi_rnn_cell/cell_0/lstm_cell/bias:0' shape=(512,) dtype=float32_ref>,
<tf.Variable 'rnn/multi_rnn_cell/cell_1/lstm_cell/kernel:0' shape=(256, 512) dtype=float32_ref>,
<tf.Variable 'rnn/multi_rnn_cell/cell_1/lstm_cell/bias:0' shape=(512,) dtype=float32_ref>]
В TF 1.0, ядрои рекуррентное ядро TF 2.0 сцеплено (см. здесь )
def build(self, input_shape):
self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
initializer='uniform',
name='kernel')
self.recurrent_kernel = self.add_weight(
shape=(self.units, self.units),
initializer='uniform',
name='recurrent_kernel')
self.built = True
В этой новой версии у вас теперь есть две разные весовые матрицы.
input_placeholder = tf.placeholder(tf.float32, [None, None, input_units])
lstm_layers = [tf.keras.layers.LSTMCell(layer_size),tf.keras.layers.LSTMCell(layer_size)]
stacked = tf.keras.layers.StackedRNNCells(lstm_layers)
output = tf.keras.layers.RNN(stacked, return_sequences=True, return_state=True, dtype=tf.float32)
Таким образом,ваши обучаемые параметры:
<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/kernel:0' shape=(20, 512) dtype=float32>,
<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/recurrent_kernel:0' shape=(128, 512) dtype=float32>,
<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/bias:0' shape=(512,) dtype=float32>,
<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/kernel_1:0' shape=(128, 512) dtype=float32>,
<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/recurrent_kernel_1:0' shape=(128, 512) dtype=float32>,
<tf.Variable 'rnn_1/while/stacked_rnn_cells_1/bias_1:0' shape=(512,) dtype=float32>]