Я новичок в tenorflow и не могу понять, правильно ли я сохраняю свою модель. моя функция поезда выглядит так:
def train_nn(self, sess, epochs, batch_size, get_batches_fn, train_op, cross_entropy_loss, input_image, label, keep_prob, learning_rate):
self.keep_prob_value = .5
self.learning_rate_value = 0.001
for epoch in range(epochs):
self.total_loss = 0.0
print("Epoch %d out of %d."%(epoch+1, epochs))
for X_batch, gt_batch in get_batches_fn(batch_size):
print("Importing a new batch for training.")
self.loss, _ = sess.run([cross_entropy_loss, train_op], feed_dict={input_image:X_batch, label:gt_batch, keep_prob:self.keep_prob_value, learning_rate:self.learning_rate_value})
self.total_loss += self.loss
print("Loss = %s"%self.total_loss)
print("Trying to save the graph.")
g = sess.graph
print("defined graph g")
gdef = g.as_graph_def()
print("graph set as default")
tf.io.train.write_graph(gdef, self.model_directory, os.path.join(self.model_directory, self.model_filename)+".pb", True)
print("Graph saved successfully.")
print("Initializing train saver.")
self.saver = tf.train.Saver()
print("saving weights")
self.saver.save(sess, os.path.join(self.model_directory, "weights"))
print("weights saved")
При выводе я получаю следующее:
Importing a new batch for training.
Loss = 290.54996478557587
Trying to save the graph.
defined graph g
graph set as default
Затем, через некоторое время выполнения, мой код перестает выполняться без дополнительного вывода.