Экстракт Cell State LSTM Keras - PullRequest
0 голосов
/ 05 июля 2018

Мне было интересно, можно ли будет извлечь последнее состояние ячейки LSTM в Керасе после тренировки модели. Например, в этой простой модели LSTM:

number_of_dimensions = 128
number_of_examples = 123456

input_ = Input(shape = (10,100,))
lstm, hidden, cell = CuDNNLSTM(units = number_of_dimensions, return_state=True)(input_)

dense = Dense(num_of_classes, activation='softmax')(lstm)

model = Model(inputs = input_, outputs = dense)
parallel_model = multi_gpu_model(model, gpus=2)
parallel_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['acc'])

# fit the model
parallel_model.fit(X1, onehot_encoded, epochs=100, verbose=1, batch_size = 128, validation_split = 0.2)

Я попытался напечатать 'cell', но результат был

tf.Tensor 'cu_dnnlstm_2/strided_slice_17:0' shape=(?, 128) dtype=float32 

Я хотел бы получить состояние ячейки в виде массива фигуры (number_of_examples, number_of_dimensions) или (123456, 128). Возможно ли сделать это керасом?

Спасибо!

Ответы [ 2 ]

0 голосов
/ 06 июля 2018

Опция, которая может вас заинтересовать, - это сохранение веса модели в файле hdf5:

model.save_weights('my_model_weights.h5')

(ref: https://keras.io/getting-started/faq/#how-can-i-save-a-keras-model)

Затем используйте средство просмотра HDF, такое как пакет Java HDFView: https://support.hdfgroup.org/products/java/hdfview/

Я считаю, что вы можете экспортировать данные в CSV для импорта в Numpy, например.

0 голосов
/ 06 июля 2018

Предполагая, что вы используете TensorFlow в качестве бэкэнда, вы можете специально запустить cell в рамках сеанса TensorFlow. Например:

from keras.layers import LSTM, Input, Dense
from keras.models import Model
import keras.backend as K
import numpy as np

number_of_dimensions = 128
number_of_examples = 123456

input_ = Input(shape=(10, 100,))
lstm, hidden, cell = LSTM(units=number_of_dimensions, return_state=True)(input_)
dense = Dense(10, activation='softmax')(lstm)
model = Model(inputs=input_, outputs=dense)

with K.get_session() as sess:
    x = np.zeros((number_of_examples, 10, 100))
    cell_state = sess.run(cell, feed_dict={input_: x})
    print(cell_state.shape)
...