Изменение графика тензорного потока и возобновление обучения - PullRequest
3 голосов
/ 05 мая 2020

Я пытаюсь загрузить предварительно тренированные веса M Cnet модель и возобновить тренировку. Предоставленная здесь предварительно обученная модель обучается с параметрами K=4, T=7. Но мне нужна модель с параметрами K=4,T=1. Вместо того, чтобы начинать тренировку с нуля, я хочу загрузить веса из этой предварительно обученной модели. Но поскольку график изменился, я не могу загрузить предварительно обученную модель.

InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Assign requires shapes of both tensors to match. lhs shape= [5,5,15,64] rhs shape= [5,5,33,64]
     [[node save/Assign_13 (defined at /media/nagabhushan/Data02/SNB/IISc/Research/04_Gaming_Video_Prediction/Workspace/VideoPrediction/Literature/01_MCnet/src/snb/mcnet.py:108) ]]

Можно ли загрузить предварительно обученную модель с новым графиком?

Что я пробовали :
Раньше я хотел перенести предварительно обученную модель со старой версии tensorflow на более новую. Я получил этот ответ в SO, который помог мне перенести модель. Идея состоит в том, чтобы создать новый график и загрузить переменные, существующие в новом графике, из сохраненного.

with tf.Session() as sess:
    _ = MCNET(image_size=[240, 320], batch_size=8, K=4, T=1, c_dim=3, checkpoint_dir=None, is_train=True)
    tf.global_variables_initializer().run(session=sess)

    ckpt_vars = tf.train.list_variables(model_path.as_posix())
    ass_ops = []
    for dst_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        for (ckpt_var, ckpt_shape) in ckpt_vars:
            if dst_var.name.split(":")[0] == ckpt_var and dst_var.shape == ckpt_shape:
                value = tf.train.load_variable(model_path.as_posix(), ckpt_var)
                ass_ops.append(tf.assign(dst_var, value))

    # Assign the variables
    sess.run(ass_ops)
    saver = tf.train.Saver()
    saver.save(sess, save_path.as_posix())

Я попробовал то же самое здесь, и это сработало, то есть я получил новую обученную модель для K=4,T=1. Но я не уверен, действительно ли это! Я имею в виду, будут ли веса иметь смысл? Это правильный способ сделать это?

Информация о модели :
M Cnet - это модель, используемая для прогнозирования видео, т.е. с учетом K прошедших кадров, она может предсказать следующие T кадров.

Любая помощь приветствуется

1 Ответ

4 голосов
/ 05 мая 2020
Модель

M Cnet имеет генератор и дискриминатор. Генератор основан на LSTM и, следовательно, нет проблем с загрузкой весов, варьируя количество временных шагов T. Однако дискриминатор, как они его закодировали, сверточный. Чтобы применить сверточные слои к видео, они объединяют кадры по размеру канала. С K=4,T=7 вы получите видео длиной 11 с 3 каналами. Когда вы объединяете их по размеру канала, вы получаете изображение с 33 каналами. Когда они определяют дискриминатор, они определяют первый уровень дискриминатора, имеющий 33 входных каналов, и, следовательно, веса имеют аналогичный размер. Но с K=4,T=1 длина видео составляет 5, а конечное изображение имеет 15 каналов, поэтому веса будут иметь 15 каналов. Это ошибка несоответствия, которую вы наблюдаете. Чтобы исправить это, вы можете выбрать веса только для первых 15 каналов ( из-за отсутствия лучшего способа, я могу придумать ). Код ниже:

with tf.Session() as sess:
    _ = MCNET(image_size=[240, 320], batch_size=8, K=4, T=1, c_dim=3, checkpoint_dir=None, is_train=True)
    tf.global_variables_initializer().run(session=sess)

    ckpt_vars = tf.train.list_variables(model_path.as_posix())
    ass_ops = []
    for dst_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        for (ckpt_var, ckpt_shape) in ckpt_vars:
            if dst_var.name.split(":")[0] == ckpt_var:
                if dst_var.shape == ckpt_shape:
                    value = tf.train.load_variable(model_path.as_posix(), ckpt_var)
                    ass_ops.append(tf.assign(dst_var, value))
                else:
                    value = tf.train.load_variable(model_path.as_posix(), ckpt_var)
                    if dst_var.shape[2] <= value.shape[2]:
                        adjusted_value = value[:, :, :dst_var.shape[2]]
                    else:
                        adjusted_value = numpy.random.random(dst_var.shape)
                        adjusted_value[:, :, :value.shape[2], ...] = value
                    ass_ops.append(tf.assign(dst_var, adjusted_value))

    # Assign the variables
    sess.run(ass_ops)
    saver = tf.train.Saver()
    saver.save(sess, save_path.as_posix())  
...