тензор потока: перегонка из модели на основе Resnet в модель на основе VGG - PullRequest
0 голосов
/ 30 ноября 2018

Модель, заданная https://github.com/iro-cp/FCRN-DepthPrediction, показывает хороший результат в прогнозировании глубины.Он в основном содержит слои понижающей дискретизации (на основе resnet50) и слои повышающей дискретизации (восходящая проекция).Чтобы настроить модель, я хочу использовать знания для обучения небольшой сети (VGG16 + up-projection).

Я импортирую сеть учителя и сеть ученика в одном сценарии.Я запускаю один сеанс и вычисляю потерю mse между мягкой целью учителя и мягкой целью студента и использую оптимизатор SGD tf для обратного распространения.Однако потеря не сходится.Я думаю, что это из-за моего использования тензорного потока.

Любой может сказать мне, достаточно ли только одного сеанса для одновременной работы сети преподавателей и студентов в тензорном потоке.Если да, то к какой проблеме это может привести.

      import vgg16 as student
      import resnet50 as teacher
      ...

      images, depths, invalid_depths = dataset.csv_inputs(TRAIN_FILE)
      keep_conv = tf.placeholder(tf.float32)

      studentnet = student.VGG16(images, keep_conv, True)
      logits_stu = studentnet.softout

      teachernet = teacher.Network({'data': images}, BATCH_SIZE, 1, False, trainable = False)
      logits_tea = teachernet.softout

      loss = op.loss(logits_tea, logits_stu)
      train_op = op.train(loss, global_step, BATCH_SIZE)

      sess = tf.Session(config=tf.ConfigProto(log_device_placement=LOG_DEVICE_PLACEMENT))
      sess.run(tf.global_variables_initializer())

      # define saver for studentnet
      saver_en = tf.train.Saver(studentnet.encoder_params)

      #load checkpoints for teachernet
      saver_tea = tf.train.Saver(teacher_params)
      saver_tea.restore(sess, teacher_data_path)

      coord = tf.train.Coordinator()
      threads = tf.train.start_queue_runners(sess=sess, coord=coord)
      for step in range(MAX_STEPS):
          for i in range(125):
              _, loss_value, tea_val, stu_val, images_val = sess.run([train_op, loss, logits_tea, logits_stu, images], feed_dict={keep_conv: 1})
              if i % 25 == 0:
                  print("%s: %d[epoch]: %d[iteration]: train loss %f" % (datetime.now(), step, i, loss_value))
                  assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

          if step % 5 == 0 or (step * 1) == MAX_STEPS:
              en_checkpoint_path = EN_DIR + "/%d_model.ckpt" % (step)
              saver_en.save(sess, en_checkpoint_path)

      coord.request_stop()
      coord.join(threads)
      sess.close()
      ...

      def loss(labels, predictions):

          mse = tf.losses.mean_squared_error(labels, predictions, scope="mse_per_example")
          return mse

Это неполный код.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...