Я создаю чат-бота, обученного на Cornell Movie Dialogs Corpus с использованием NMT .
Я частично основываю свой код с https://github.com/bshao001/ChatLearner иhttps://github.com/chiphuyen/stanford-tensorflow-tutorials/tree/master/assignments/chatbot
Во время обучения я печатаю случайный выходной ответ, поданный на декодер из пакета, и соответствующий ответ, который, по прогнозам моей модели, будет наблюдать за процессом обучения.
Моя проблема: После всего лишь 4 итераций обучения модель учится выводить токен EOS (<\s>
) для каждого временного шага.Он всегда выводит это как ответ (определяемый с использованием argmax логитов), даже когда обучение продолжается.Время от времени модель редко выдает в качестве ответа серию периодов.
Я также печатаю 10 лучших логит-значений во время обучения (не только argmax), чтобы увидеть, может быть, где-то там правильное слово, но, похоже, оно предсказывает наиболее распространенные слова в словаре (например, я, вы, ?, .).Даже эти 10 лучших слов не сильно меняются во время тренировок.
Я удостоверился в правильности подсчета длины входной последовательности для кодера и декодера и добавил соответственно токены SOS (<s>
) и EOS (также используются для заполнения).Я также выполняю маскирование в расчете потерь.
Вот пример вывода:
Итерация обучения 1:
Decoder Input: <s> sure . sure . <\s> <\s> <\s> <\s> <\s> <\s> <\s>
<\s> <\s>
Predicted Answer: wildlife bakery mentality mentality administration
administration winston winston winston magazines magazines magazines
magazines
...
Учебная итерация 4:
Decoder Input: <s> i guess i had it coming . let us call it settled .
<\s> <\s> <\s> <\s> <\s>
Predicted Answer: <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s>
<\s> <\s> <\s> <\s> <\s> <\s> <\s> <\s>
После еще нескольких итераций он основывается только на прогнозировании EOS (и редко на некоторых периодах))
Я не уверен, что может быть причиной этой проблемы, и застрял на этом некоторое время.Любая помощь будет принята с благодарностью!
Обновление: Я позволил ей тренироваться более ста тысяч итераций, и она по-прежнему выводит только EOS (и случайные периоды).Потеря тренировки также не уменьшается после нескольких итераций (она остается на уровне 47 с начала)