Неожиданные результаты со слоем CuDNNLSTM (вместо LSTM) - PullRequest
5 голосов
/ 28 июня 2019

Я опубликовал этот вопрос как выпуск в Keras 'Github, но подумал, что он может охватить более широкую аудиторию здесь.


Информация о системе

  • Написал ли я собственный код (в отличие от использования примера каталога): Минимальное изменение официального руководства по Keras
  • Платформа и распространение ОС (например, Linux Ubuntu 16.04): Ubuntu 18.04.2 LTS
  • Серверная часть TensorFlow (да / нет): да
  • Версия TensorFlow: 1.13.1
  • Версия Keras: 2.2.4
  • Версия Python: 3.6.5
  • Версия CUDA / cuDNN: 10,1
  • Модель и память графического процессора: Tesla K80 11G

Опишите текущее поведение
Я выполняюкод из учебника Seq2Seq .Единственное изменение, которое я сделал, - это поменять слои LSTM на CuDNNLSTM .Что происходит, так это то, что модель предсказывает фиксированный результат для любого ввода, который я ему даю.Когда я запускаю исходный код, я получаю разумные результаты.

Опишите ожидаемое поведение
См. Предыдущий раздел.

Код для воспроизведения проблемы
Взято из здесь .Просто замените LSTM на CuDNNLSTM.


Любые идеи приветствуются.

1 Ответ

3 голосов
/ 06 июля 2019

Итак, здесь есть две проблемы.
Использование CuDNNLSTM и parameter tuning.
По сути, сетевые перегрузки в вашем наборе данных приводят к тому, что выводом является только одно предложение для каждого ввода. Это не вина CuDNNLSTM и LSTM.

Во-первых,
CuDNN имеет немного отличную математику от обычной LSTM, что делает его совместимым с Cuda и работает быстрее. LSTM занимает 11 секунд для запуска на англ. Файле для того же кода, который вы использовали, а CuDNNLSTM - 1 секунда для каждой эпохи.

В CuDNNLSTM time_major параметр установлен на false. По этой причине сеть перегружается. Вы можете проверить это здесь .
Для небольших наборов данных, таких как анг-хин или анг-маратхи, ясно видно, что val-loss увеличивается после 30 эпох. Нет смысла управлять сетью больше, когда ваш network loss уменьшается, а val_loss увеличивается. Случай с LSTM такой же.

Здесь вам нужно param tuning для небольших наборов данных.

Вот несколько ссылок, которые могут помочь:

  1. Eng-Mar
  2. Руководство по переводу Pytorch
  3. Аналогичный вопрос 2 и Аналогичный вопрос 2
  4. NMT-keras
...