Loading a model and using it to train another model in TensorFlow - PullRequest
November 14, 2018

I trained a model, Color_Model, in TensorFlow, and it works well. I want to use this trained model to help train another model, Motion_Model: the output of Color_Model is fed into Motion_Model as an extra input during training. The problem is that I don't know how to load the Color_Model graph and set up the Motion_Model graph so that TensorFlow treats them as separate. I have already renamed the weights in Motion_Model so that their names do not conflict with those in Color_Model.

Here is the part of the code that does the loading and training:

import numpy as np
import tensorflow as tf

# init_op, image, iterations, optimizer, costM, MN_out, XM, YM, phaseM and ZM
# are defined earlier in the script (not shown).
with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    # Loading the color model
    new_saver = tf.train.import_meta_graph('./Color_Model/Deep_CNN_Color_Arch16.ckpt-44.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./Color_Model/'))
    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name("X:0")
    Y = graph.get_tensor_by_name("Y:0")
    phase = graph.get_tensor_by_name("phase:0")
    A7 = graph.get_tensor_by_name("Finalo:0")
    # Training phase
    for step in range(1, iterations + 1):
        # Get a training batch: channels 0-7 are inputs, channel 8 is the target
        img = sess.run(image)
        X_temp = img[:, :, :, 0:8].astype(np.float32) / 255
        Y_temp = img[:, :, :, 8:9].astype(np.float32) / 255
        # Run the color model; fetch the tensor itself, not a one-element list,
        # so that `output` can be fed directly into ZM
        output = sess.run(A7, feed_dict={X: X_temp[:, :, :, 5:8], Y: Y_temp, phase: False})
        # Train the motion model on the color model's output
        _, c, outputM = sess.run([optimizer, costM, MN_out],
                                 feed_dict={XM: X_temp[:, :, :, 0:5], YM: Y_temp, phaseM: True, ZM: output})

As you can see, the first `sess.run` runs Color_Model to get its output, and the second `sess.run` takes that output and feeds it into Motion_Model for training.
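Because the two `sess.run` calls already exchange only NumPy arrays, one way to keep the models apart is to give each its own `tf.Graph` and its own `Session`. The following is a minimal sketch under that assumption, using the TF1-style API through `tf.compat.v1` (which also exists in TF 2.x). The Saver returned by `import_meta_graph` then lives in the Color_Model graph and can only see Color_Model's variables, while Motion_Model and its own Saver live in a second graph. The toy architectures, the 4×4 image size, the variable name `W1` and the helper function names are all hypothetical stand-ins, and the real model's `Y` and `phase` placeholders are omitted.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API (also present in TF 2.x)


def save_toy_color_model(ckpt_dir):
    """Hypothetical stand-in for Color_Model: placeholder X -> tensor Finalo."""
    g = tf.Graph()
    with g.as_default():
        X = tf1.placeholder(tf.float32, [None, 4, 4, 3], name="X")
        W1 = tf1.get_variable("W1", [1, 1, 3, 1])
        tf.identity(tf.nn.conv2d(X, W1, strides=[1, 1, 1, 1], padding="SAME"), name="Finalo")
        saver = tf1.train.Saver()
        with tf1.Session(graph=g) as sess:
            sess.run(tf1.global_variables_initializer())
            return saver.save(sess, os.path.join(ckpt_dir, "color.ckpt"))


def load_color_model(ckpt_dir):
    """Import Color_Model into its OWN graph and session."""
    color_graph = tf.Graph()
    with color_graph.as_default():
        ckpt = tf1.train.latest_checkpoint(ckpt_dir)
        loader = tf1.train.import_meta_graph(ckpt + ".meta")
        color_sess = tf1.Session(graph=color_graph)
        # This Saver belongs to color_graph, so it restores only Color_Model's variables
        loader.restore(color_sess, ckpt)
    X = color_graph.get_tensor_by_name("X:0")
    A7 = color_graph.get_tensor_by_name("Finalo:0")
    return color_sess, X, A7


def train_motion_model(color_sess, X, A7, steps=3):
    """Hypothetical toy Motion_Model, built and trained in a separate graph."""
    rng = np.random.default_rng(0)
    motion_graph = tf.Graph()
    with motion_graph.as_default():
        XM = tf1.placeholder(tf.float32, [None, 4, 4, 5], name="XM")
        ZM = tf1.placeholder(tf.float32, [None, 4, 4, 1], name="ZM")
        YM = tf1.placeholder(tf.float32, [None, 4, 4, 1], name="YM")
        WM1 = tf1.get_variable("WM1", [1, 1, 6, 1])
        MN_out = tf.nn.conv2d(tf.concat([XM, ZM], axis=3), WM1,
                              strides=[1, 1, 1, 1], padding="SAME")
        costM = tf.reduce_mean(tf.square(MN_out - YM))
        optimizer = tf1.train.GradientDescentOptimizer(0.1).minimize(costM)
        # Counterpart of the Saver on line 52 of Detection_Model1.py; it only sees WM1
        saver = tf1.train.Saver()
        with tf1.Session(graph=motion_graph) as sess:
            sess.run(tf1.global_variables_initializer())
            for _ in range(steps):
                img = rng.random((2, 4, 4, 9)).astype(np.float32)
                X_temp, Y_temp = img[..., 0:8], img[..., 8:9]
                # Step 1: run Color_Model in its own session (fetch a tensor, not a list)
                output = color_sess.run(A7, feed_dict={X: X_temp[..., 5:8]})
                # Step 2: feed the resulting NumPy array into Motion_Model
                _, c = sess.run([optimizer, costM],
                                feed_dict={XM: X_temp[..., 0:5], YM: Y_temp, ZM: output})
    return output.shape, float(c)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        save_toy_color_model(d)
        color_sess, X, A7 = load_color_model(d)
        print(train_motion_model(color_sess, X, A7))
        color_sess.close()
```

The design choice here is isolation: no op or variable name in one graph can ever resolve to the other, at the cost of copying Color_Model's output through NumPy on every step (which the original code already does).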

But when I run this code, I get the following error:

    Traceback (most recent call last):
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1292, in _do_call
        return fn(*args)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1277, in _run_fn
        options, feed_dict, fetch_list, target_list, run_metadata)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1367, in _call_tf_sessionrun
        run_metadata)
    tensorflow.python.framework.errors_impl.NotFoundError: Key WM1 not found in checkpoint
         [[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

    During handling of the above exception, another exception occurred:

    Traceback (most recent call last):
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1538, in restore
        {self.saver_def.filename_tensor_name: save_path})
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 887, in run
        run_metadata_ptr)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1110, in _run
        feed_dict_tensor, options, run_metadata)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1286, in _do_run
        run_metadata)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1308, in _do_call
        raise type(e)(node_def, op, message)
    tensorflow.python.framework.errors_impl.NotFoundError: Key WM1 not found in checkpoint
         [[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

    Caused by op 'save/RestoreV2', defined at:
      File "Detection_Model1.py", line 52, in <module>
        saver = tf.train.Saver()
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1094, in __init__
        self.build()
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1106, in build
        self._build(self._filename, build_save=True, build_restore=True)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1143, in _build
        build_save=build_save, build_restore=build_restore)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 787, in _build_internal
        restore_sequentially, reshape)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 406, in _AddRestoreOps
        restore_sequentially)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 854, in bulk_restore
        return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1466, in restore_v2
        shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
        op_def=op_def)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
        return func(*args, **kwargs)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3272, in create_op
        op_def=op_def)
      File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1768, in __init__
        self._traceback = tf_stack.extract_stack()

    NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

    Key WM1 not found in checkpoint
         [[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]

I'm quite sure the graphs are getting mixed up: WM1 is the weight of the first layer of Motion_Model, and the error says it cannot be found in the checkpoint, which I assume is the Color_Model checkpoint. I would really appreciate any help with this problem.
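The traceback supports this reading: the failing op `save/RestoreV2` was created by `saver = tf.train.Saver()` on line 52 of Detection_Model1.py, that is, by Motion_Model's Saver, not by the one returned from `import_meta_graph`. A likely explanation is that Color_Model was imported into the same default graph without a name scope, so the returned Saver still refers to its restore op by the name stored in the meta graph (`save/restore_all`), which in this graph belongs to Motion_Model's Saver, and that op tries to read `WM1` from the Color_Model checkpoint. If both models must share one graph and session, a sketch under that assumption is to pass `import_scope` to `import_meta_graph`, so every imported name, including the imported Saver's ops, gets a prefix (here the hypothetical scope `"color"`), and to look tensors up as `"color/X:0"` and `"color/Finalo:0"`. As before, the toy models, shapes and function names are hypothetical.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API (also present in TF 2.x)


def save_toy_color_model(ckpt_dir):
    """Hypothetical stand-in for Color_Model: placeholder X -> tensor Finalo."""
    g = tf.Graph()
    with g.as_default():
        X = tf1.placeholder(tf.float32, [None, 4, 4, 3], name="X")
        W1 = tf1.get_variable("W1", [1, 1, 3, 1])
        tf.identity(tf.nn.conv2d(X, W1, strides=[1, 1, 1, 1], padding="SAME"), name="Finalo")
        saver = tf1.train.Saver()
        with tf1.Session(graph=g) as sess:
            sess.run(tf1.global_variables_initializer())
            return saver.save(sess, os.path.join(ckpt_dir, "color.ckpt"))


def train_with_scoped_color_model(ckpt, steps=3):
    """One graph, one session; Color_Model is imported under the 'color' scope."""
    rng = np.random.default_rng(0)
    g = tf.Graph()
    with g.as_default():
        # Hypothetical toy Motion_Model, built before the import as in the question
        XM = tf1.placeholder(tf.float32, [None, 4, 4, 5], name="XM")
        ZM = tf1.placeholder(tf.float32, [None, 4, 4, 1], name="ZM")
        YM = tf1.placeholder(tf.float32, [None, 4, 4, 1], name="YM")
        WM1 = tf1.get_variable("WM1", [1, 1, 6, 1])
        MN_out = tf.nn.conv2d(tf.concat([XM, ZM], axis=3), WM1,
                              strides=[1, 1, 1, 1], padding="SAME")
        costM = tf.reduce_mean(tf.square(MN_out - YM))
        optimizer = tf1.train.GradientDescentOptimizer(0.1).minimize(costM)
        saver = tf1.train.Saver()  # ops named save/...; created before the import, so it covers WM1 only
        init_op = tf1.global_variables_initializer()

        # import_scope prefixes every imported name, including the imported
        # Saver's ops (color/save/...), so they cannot resolve to save/...
        new_saver = tf1.train.import_meta_graph(ckpt + ".meta", import_scope="color")
        X = g.get_tensor_by_name("color/X:0")
        A7 = g.get_tensor_by_name("color/Finalo:0")

        with tf1.Session(graph=g) as sess:
            sess.run(init_op)
            new_saver.restore(sess, ckpt)  # restores color/W1 from checkpoint key W1
            for _ in range(steps):
                img = rng.random((2, 4, 4, 9)).astype(np.float32)
                X_temp, Y_temp = img[..., 0:8], img[..., 8:9]
                output = sess.run(A7, feed_dict={X: X_temp[..., 5:8]})
                _, c = sess.run([optimizer, costM],
                                feed_dict={XM: X_temp[..., 0:5], YM: Y_temp, ZM: output})
    return output.shape, float(c)


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        print(train_with_scoped_color_model(save_toy_color_model(d)))
```

Compared with using two separate graphs, this keeps a single session, but Color_Model's variables now also appear in the shared graph's variable collections, so any Saver or optimizer created after the import without an explicit `var_list` would pick them up as well.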

...