Перезапись методов с помощью миксин-паттерна не работает должным образом - PullRequest
0 голосов
/ 02 ноября 2018

Я пытаюсь представить мод / миксин для проблемы . В частности, я сосредоточен здесь на SpeechRecognitionProblem. Я намерен изменить эту проблему, и поэтому я стараюсь сделать следующее:

class SpeechRecognitionProblemMod(speech_recognition.SpeechRecognitionProblem):

    def hparams(self, defaults, model_hparams):
        SpeechRecognitionProblem.hparams(self, defaults, model_hparams)
        vocab_size = self.feature_encoders(model_hparams.data_dir)['targets'].vocab_size
        p = defaults
        p.vocab_size['targets'] = vocab_size

    def feature_encoders(self, data_dir): 
        # ...

Так что этот мало что делает. Вызывает функцию hparams() из базового класса, а затем изменяет некоторые значения.

Теперь уже есть некоторые готовые проблемы, например, Libri Speech:

@registry.register_problem()
class Librispeech(speech_recognition.SpeechRecognitionProblem):
    # ..

Однако, чтобы применить мои модификации, я делаю это:

@registry.register_problem()
class LibrispeechMod(SpeechRecognitionProblemMod, Librispeech):
    # ..

Это должно, если я не ошибаюсь, переписать все (с идентичными подписями) в Librispeech и вместо этого вызвать функции SpeechRecognitionProblemMod.

Поскольку я смог обучить модель с этим кодом, я предполагаю, что она работает так, как задумано.

Теперь возникает проблема my :

После тренировки хочу сериализовать модель. Это обычно работает. Тем не менее, это не так с моим модом, и я действительно знаю, почему:

В определенный момент вызывается hparams(). Отладка до этого момента покажет мне следующее:

self                  # {LibrispeechMod}
self.hparams          # <bound method SpeechRecognitionProblem.hparams of ..>
self.feature_encoders # <bound method SpeechRecognitionProblemMod.feature_encoders of ..>

self.hparams должно быть <bound method SpeechRecognitionProblemMod.hparams of ..>! Казалось бы, по какой-то причине hparams() из SpeechRecognitionProblem вызывается напрямую вместо SpeechRecognitionProblemMod. Но обратите внимание , что это правильный тип для feature_encoders()!

Дело в том, что я знаю, что это работает во время тренировок. Я вижу, что гиперпараметры (hparams) применяются соответственно просто потому, что имена узлов графа модели меняются в результате моих модификаций.

Есть одна специальность, которую я должен указать. tensor2tensor позволяет динамически загружать t2t_usr_dir, которые являются дополнительными модулями Python, которые загружаются с помощью import_usr_dir. Я также использую эту функцию в моем сценарии сериализации:

if usr_dir:
    logging.info('Loading user dir %s' % usr_dir)
    import_usr_dir(usr_dir)

Это может быть единственным виновником, которого я вижу в данный момент, хотя я не смогу сказать, почему это может вызвать проблему.

Если кто-нибудь увидит что-то, чего нет у меня, я был бы рад получить подсказку, что я здесь делаю неправильно.


Так что за ошибка вы получаете?

Ради полноты, это результат неправильного вызова hparams() метода:

NotFoundError (see above for traceback): Restoring from checkpoint failed.
Key transformer/symbol_modality_256_256/softmax/weights_0 not found in checkpoint

symbol_modality_256_256 неверно. Это должно быть symbol_modality_<vocab-size>_256, где <vocab-size> - это размер словаря, который устанавливается в SpeechRecognitionProblemMod.hparams.

1 Ответ

0 голосов
/ 02 ноября 2018

Итак, это странное поведение произошло из-за того, что я выполнял удаленную отладку и что исходные файлы usr_dir не были правильно синхронизированы. Все работает как задумано, но исходные файлы там, где не совпадают.

Дело закрыто.

...