I trained a model, Color_Model, in TensorFlow, and it works well. I now want to use this trained model to help train a second model, Motion_Model: the output of Color_Model is fed into Motion_Model as an additional input during training. The problem is that I don't know how to load the Color_Model graph and set up the Motion_Model graph so that TensorFlow treats them as separate. I have already renamed the weights in Motion_Model so that the variable names do not conflict.
Here is the part of the code that loads the model and runs training:
with tf.Session() as sess:
    sess.run(init_op)
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    ### Loading the color model
    new_saver = tf.train.import_meta_graph('./Color_Model/Deep_CNN_Color_Arch16.ckpt-44.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./Color_Model/'))
    graph = tf.get_default_graph()
    X = graph.get_tensor_by_name("X:0")
    Y = graph.get_tensor_by_name("Y:0")
    phase = graph.get_tensor_by_name("phase:0")
    A7 = graph.get_tensor_by_name("Finalo:0")

    ### Training phase
    for step in range(1, iterations+1):
        ### Getting the training data batch
        img = sess.run([image])
        X_temp = img[0][:,:,:,0:8]
        Y_temp = img[0][:,:,:,8:9]
        X_temp = X_temp.astype(np.float32)/255
        Y_temp = Y_temp.astype(np.float32)/255
        ### Getting the color model result
        output = sess.run([A7], feed_dict = {X: X_temp[:,:,:,5:8], Y: Y_temp, phase: False})
        ### Training the motion model
        _, c, outputM = sess.run([optimizer, costM, MN_out], feed_dict = {XM: X_temp[:,:,:,0:5], YM: Y_temp, phaseM: True, ZM: output})
As you can see, the first "sess.run" call runs Color_Model to get its output, and the second "sess.run" call takes that output and feeds it into Motion_Model for training.
But when I run this code, I get the following error:
Traceback (most recent call last):
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1292, in _do_call
return fn(*args)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1277, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1367, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.NotFoundError: Key WM1 not found in checkpoint
[[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1538, in restore
{self.saver_def.filename_tensor_name: save_path})
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 887, in run
run_metadata_ptr)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1110, in _run
feed_dict_tensor, options, run_metadata)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1286, in _do_run
run_metadata)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1308, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.NotFoundError: Key WM1 not found in checkpoint
[[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
Caused by op 'save/RestoreV2', defined at:
File "Detection_Model1.py", line 52, in <module>
saver = tf.train.Saver()
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1094, in __init__
self.build()
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1106, in build
self._build(self._filename, build_save=True, build_restore=True)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1143, in _build
build_save=build_save, build_restore=build_restore)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 787, in _build_internal
restore_sequentially, reshape)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 406, in _AddRestoreOps
restore_sequentially)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 854, in bulk_restore
return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1466, in restore_v2
shape_and_slices=shape_and_slices, dtypes=dtypes, name=name)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 787, in _apply_op_helper
op_def=op_def)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 488, in new_func
return func(*args, **kwargs)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3272, in create_op
op_def=op_def)
File "/home/hamidreza/venv/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1768, in __init__
self._traceback = tf_stack.extract_stack()
NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint.
Please ensure that you have not altered the graph expected based on the checkpoint. Original error:
Key WM1 not found in checkpoint
[[{{node save/RestoreV2}} = RestoreV2[dtypes=[DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, ..., DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2/tensor_names, save/RestoreV2/shape_and_slices)]]
I am fairly sure the graphs are getting mixed up: WM1 is the first-layer weight of Motion_Model, and the error says it cannot be found in the checkpoint, which I assume is the Color_Model checkpoint.
I would really appreciate any help with this problem.
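To make this suspicion concrete, here is a toy sketch in plain Python. It is not TensorFlow code, only a model of the behavior I think I am seeing, under the assumption that a saver asks the checkpoint for every variable name it was built with. The `restore` function, the dictionary `color_checkpoint`, and every variable name except WM1 are made up for illustration.

```python
# Toy model of checkpoint restore; no TensorFlow involved.
# Assumption: a saver requests every variable name it was built with,
# and a single missing key makes the whole restore fail.

def restore(saver_var_names, checkpoint):
    """Return {name: value} for each requested name, or fail on a missing key."""
    restored = {}
    for name in saver_var_names:
        if name not in checkpoint:
            raise KeyError("Key %s not found in checkpoint" % name)
        restored[name] = checkpoint[name]
    return restored

# Hypothetical contents of the Color_Model checkpoint.
color_checkpoint = {"W1": 0.1, "b1": 0.2}

# Hypothetical variable names in the combined graph.
color_vars = ["W1", "b1"]      # Color_Model
motion_vars = ["WM1", "bM1"]   # Motion_Model

# A saver covering all variables fails on the first Motion_Model weight.
try:
    restore(color_vars + motion_vars, color_checkpoint)
    error_message = None
except KeyError as exc:
    error_message = exc.args[0]

# A saver limited to the Color_Model variables restores cleanly.
restored = restore(color_vars, color_checkpoint)

print(error_message)
print(restored)
```

In TensorFlow terms, this would correspond to the restore op that actually runs covering the Motion_Model variables as well, even though the checkpoint only holds the Color_Model ones. Whether that is really what happens in my script is exactly what I am unsure about.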