Я пытаюсь восстановить предварительно обученную модель, которая называется моделью суфлена, и она была обучена другими людьми.Однако, когда я пытался извлечь модель и восстановить график тензорного потока для обучения нового набора данных (набор данных 10000 изображений), я получил сообщение об ошибке ключа в моем терминале: Вот мой код:
meta_path = './model/model.ckpt-0.meta'
tf.reset_default_graph()
saver = tf.train.import_meta_graph(meta_path)
restored_graph = tf.get_default_graph()
for tensor in restored_graph.get_operations():
print (tensor.name)
global_step_tensor = restored_graph.get_tensor_by_name('Softmax/prediction:0')
image_input_node = restored_graph.get_tensor_by_name('TFRecordIterator/IteratorGetNext:0')
label_node = restored_graph.get_tensor_by_name('TFRecordIterator/OneShotIterator:0')
loss = restored_graph.get_tensor_by_name('sparse_softmax_cross_entropy/add:0')
# tf.contrib.quantize.create_training_graph(input_graph=restored_graph, quant_delay=2000000)
optimizer = tf.train.GradientDescentOptimizer(0.01)
train_op = optimizer.minimize(loss)
# tf.contrib.quantize.create_training_graph(quant_delay=2000000)
iterations = 100
# run the session
with tf.Session() as sess:
# restore the saved vairable
saver.restore(sess, './model/model.ckpt-0')
# sess.run(optimizer)
image_files = []
labels = []
with open("/mnt/ficussweden/hhzhang/00_Data/04_TF_Autobot/gender/data/train.tsv", 'r') as f:
line = f.readline()
while line:
array = line.rstrip('\n').split()
image_files.append(array[0])
labels.append(int(array[1]))
line = f.readline()
if len(image_files)>=10001:
break
# print(image_files[len(image_files)-1])
data = []
data_labels = []
for i in range(200):
print("process batch {}-th data \n".format(i))
batch = []
batch_label = []
for j in range(50):
img_path = image_files[j+i*50]
img = np.array(Image.open(img_path))
batch.append(img)
batch_label.append(labels[j+i*50])
batch = np.array(batch)
batch_label = np.array(batch_label)
data.append(batch)
data_labels.append(batch_label)
data = np.array(data)
data_labels = np.array(data_labels)
print(data.shape)
print(data_labels.shape)
for i in range(iterations):
for j in range(len(data)):
batch_data = data[j]
batch_label = data_labels[j]
res = sess.run(train_op, feed_dict = {image_input_node: batch_data, label_node: batch_label})
print(res)
Обратите внимание, что я прочитал все данные изображения в список пустых строк и выбросил их в свой график восстановления на основе входного узла.Но я получил следующее сообщение об ошибке:
Traceback (most recent call last):
File "restore_graph_train.py", line 127, in <module>
res = sess.run(train_op, feed_dict = {image_input_node: batch_data, label_node: batch_label})
File "/mnt/ficusspain/cqli/virtual_env/quantize_model/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 900, in run
run_metadata_ptr)
File "/mnt/ficusspain/cqli/virtual_env/quantize_model/lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1088, in _run
subfeed_dtype = subfeed_t.dtype.as_numpy_dtype
File "/mnt/ficusspain/cqli/virtual_env/quantize_model/lib/python3.5/site-packages/tensorflow/python/framework/dtypes.py", line 128, in as_numpy_dtype
return _TF_TO_NP[self._type_enum]
KeyError: 20