Тензор потока восстановить только некоторые переменные из контрольной точки - PullRequest
0 голосов
/ 02 мая 2018

После проверки контрольной точки (назовем ее моделью 1) я получил следующий список имен переменных (сокращен для простоты):

var_list = ["ex1_model/fc2/b",
"ex1_model/fc2/b/Adam",
"ex1_model/fc2/b/Adam_1",
"ex1_model/fc2/w",
"ex1_model/fc2/w/Adam"]

Предположим, у меня гораздо большая модель 2 и я хочу инициализировать ее части значениями из модели 1.

Получить переменные из имен, как описано здесь (потому что я не нашел простой способ сделать это):

def get_vars_by_name(names):
    return [v for v in tf.global_variables() if v.name in names]

Сборка модели 2 и заставка для восстановления:

logits = build_model(inputs)
saver = tf.train.Saver(var_list=get_vars_by_name(var_list))

В

saver.restore(sess, tf.train.latest_checkpoint(checkpoints_dir))

Я получаю сообщение об ошибке:

"ex1_model/fc2/w/Adam" [...] raise ValueError("No variables to save")

Пожалуйста, помогите мне найти ошибку, которую я сделал. Я также был бы благодарен за более простой способ, потому что это ужасно. Спасибо.

1 Ответ

0 голосов
/ 02 мая 2018

Простой способ решить эту проблему - угадать, должна ли переменная быть восстановлена ​​или нет.

def ignore_name(name):
    if name.endswith('/Adam') or name.endswith('/Adam_1'):
        return True
    return False

Вы должны быть в состоянии использовать эту идею напрямую через

def get_vars_by_name(names):
    return [v for v in tf.global_variables() if v.name in names and not ignore_name(v.name)]

Это даже позволяет обучать модель с использованием ADAM, а затем переключаться на SDG или наоборот.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...