Итак, я провел некоторое тестирование и выяснил следующее об этой проблеме.Поскольку я пытался повторно использовать мою созданную модель, мне пришлось использовать tf.global_variables_initializer ().При этом он переписал мой импортированный график, и все значения были случайными, что объясняет различные выходные данные сети.Это все еще оставило меня с проблемой, чтобы решить: как мне загрузить мою сеть?Обходной путь, который я сейчас использую, далеко не оптимален, но, по крайней мере, позволяет мне использовать мою сохраненную модель.Тензорный поток позволяет дать уникальные имена используемым функциям и тензорам.Таким образом я мог получить к ним доступ через график:
with tf.Session() as sess:
saver = tf.train.import_meta_graph('path to .meta')
saver.restore(sess, tf.train.latest_checkpoint('path to checkpoints'))
graph = tf.get_default_graph()
graph.get_tensor_by_name('name:0')
Используя этот метод, я мог получить доступ ко всем моим сохраненным значениям, но они были разделены!Это означает, что у меня был 1x вес и 1x смещение на одну использованную операцию, что привело к куче новых переменных.Если вы не знаете имен, используйте следующее:
print(graph.get_all_collection_keys())
Это печатает имена коллекций (наши переменные хранятся в коллекциях)
print(graph.get_collection('name'))
Это позволяет нам получить доступ к коллекции какпосмотрите, какие имена / ключи для наших переменных.
Это привело к другой проблеме.Я больше не мог использовать свою модель, так как инициализатор глобальных переменных переписал все.Таким образом, мне пришлось переопределить всю модель вручную с учетом веса и уклонов, которые я получил ранее.
К сожалению, это единственное, что я могу придумать.Если у кого-то есть идея получше, пожалуйста, дайте мне знать.
Все с ошибкой выглядело так:
imports...
placeholders for data...
def my_network(data):
## network definition with tf functions ##
return output
def train_my_net():
prediction = my_network(data)
cost function
optimizer
with tf.Session() as sess:
for i in how many epochs i want:
training routine
save
def use_my_net():
prediction = my_network(data)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph('path to .meta')
saver.restore(sess, tf.train.latest_checkpoint('path to checkpoints'))
print(sess.run(prediction.eval(feed_dict={placeholder:data})))
graph = tf.get_default_graph()