Я пытаюсь представить мод / миксин для проблемы . В частности, я сосредоточен здесь на SpeechRecognitionProblem
. Я намерен изменить эту проблему, и поэтому я стараюсь сделать следующее:
class SpeechRecognitionProblemMod(speech_recognition.SpeechRecognitionProblem):
def hparams(self, defaults, model_hparams):
SpeechRecognitionProblem.hparams(self, defaults, model_hparams)
vocab_size = self.feature_encoders(model_hparams.data_dir)['targets'].vocab_size
p = defaults
p.vocab_size['targets'] = vocab_size
def feature_encoders(self, data_dir):
# ...
Так что этот мало что делает. Вызывает функцию hparams()
из базового класса, а затем изменяет некоторые значения.
Теперь уже есть некоторые готовые проблемы, например, Libri Speech:
@registry.register_problem()
class Librispeech(speech_recognition.SpeechRecognitionProblem):
# ..
Однако, чтобы применить мои модификации, я делаю это:
@registry.register_problem()
class LibrispeechMod(SpeechRecognitionProblemMod, Librispeech):
# ..
Это должно, если я не ошибаюсь, переписать все (с идентичными подписями) в Librispeech
и вместо этого вызвать функции SpeechRecognitionProblemMod
.
Поскольку я смог обучить модель с этим кодом, я предполагаю, что она работает так, как задумано.
Теперь возникает проблема my :
После тренировки хочу сериализовать модель. Это обычно работает. Тем не менее, это не так с моим модом, и я действительно знаю, почему:
В определенный момент вызывается hparams()
. Отладка до этого момента покажет мне следующее:
self # {LibrispeechMod}
self.hparams # <bound method SpeechRecognitionProblem.hparams of ..>
self.feature_encoders # <bound method SpeechRecognitionProblemMod.feature_encoders of ..>
self.hparams
должно быть <bound method SpeechRecognitionProblemMod.hparams of ..>
! Казалось бы, по какой-то причине hparams()
из SpeechRecognitionProblem
вызывается напрямую вместо SpeechRecognitionProblemMod
. Но обратите внимание , что это правильный тип для feature_encoders()
!
Дело в том, что я знаю, что это работает во время тренировок. Я вижу, что гиперпараметры (hparams) применяются соответственно просто потому, что имена узлов графа модели меняются в результате моих модификаций.
Есть одна специальность, которую я должен указать. tensor2tensor
позволяет динамически загружать t2t_usr_dir
, которые являются дополнительными модулями Python, которые загружаются с помощью import_usr_dir
. Я также использую эту функцию в моем сценарии сериализации:
if usr_dir:
logging.info('Loading user dir %s' % usr_dir)
import_usr_dir(usr_dir)
Это может быть единственным виновником, которого я вижу в данный момент, хотя я не смогу сказать, почему это может вызвать проблему.
Если кто-нибудь увидит что-то, чего нет у меня, я был бы рад получить подсказку, что я здесь делаю неправильно.
Так что за ошибка вы получаете?
Ради полноты, это результат неправильного вызова hparams()
метода:
NotFoundError (see above for traceback): Restoring from checkpoint failed.
Key transformer/symbol_modality_256_256/softmax/weights_0 not found in checkpoint
symbol_modality_256_256
неверно. Это должно быть symbol_modality_<vocab-size>_256
, где <vocab-size>
- это размер словаря, который устанавливается в SpeechRecognitionProblemMod.hparams
.