Вывести скрытое состояние LSTM RNN в тензорном потоке. js - PullRequest
1 голос
/ 10 июля 2020

Я пытаюсь изменить модель, используемую в lstm-text-generation примере официального tf js -examples , чтобы вывести его скрытое состояние. Для этого я сначала изменил использование tf.sequential на tf.model в createModel из model.js.

Исходная строка 37-47 файла model.js

// (...)
const model = tf.sequential();
for (let i = 0; i < lstmLayerSizes.length; ++i) {
  const lstmLayerSize = lstmLayerSizes[i];
  model.add(tf.layers.lstm({
    units: lstmLayerSize,
    returnSequences: i < lstmLayerSizes.length - 1,
    inputShape: i === 0 ? [sampleLen, charSetSize] : undefined
  }));
}
model.add(
  tf.layers.dense({units: charSetSize, activation: 'softmax'}));
// (...)

Настроенные строки того же файла model.js.

// (...)
// This replaces `inputShape` of the first LSTM layer
const inputs = tf.input({ shape: [sampleLen, charSetSize] });

let outputs = inputs;
for (let i = 0; i < lstmLayerSizes.length; ++i) {
  const lstmLayerSize = lstmLayerSizes[i];
  const layer = tf.layers.lstm({
    units: lstmLayerSize,
    returnSequences: i < lstmLayerSizes.length - 1,
  });

  outputs = layer.apply(outputs);
}

outputs = tf.layers
  .dense({ units: charSetSize, activation: 'softmax' })
  .apply(outputs);

const model = tf.model({ inputs, outputs });
// (...)

Насколько я понял, мне нужно было бы добавить returnState: true к параметрам слоя lstm для достижения моей цели.

// (...)
// Inside of the for loop
const layer = tf.layers.lstm({
  units: lstmLayerSize,
  returnState: true, // <-- Here
  returnSequences: i < lstmLayerSizes.length - 1,
});
// (...)

Однако это не работает с последним слоем модели, так как он ожидает только один входной тензор, но установка returnState на true изменяет слои lstm на выходные 3 тензора [output, hiddenState, cellState].

Я не уверен, как достичь моей цели сохранить плотный слой для softmax activation и при этом печатать скрытое состояние.

Еще одна вещь, которая мне не ясна, заключается в том, что мне, вероятно, также потребуется настроить выходные данные метода .fit, поскольку он ожидает целевые данные в качестве второго аргумента y при обучении модели.

Любая помощь очень признателен.

...