мои данные проверки имеют форму (4480, 48, 48, 1), они слишком велики, чтобы их можно было вставить в память при обучении, кроме того, я хочу наблюдать за средней потерей данных проверки с помощью тензорной доски. К настоящему времени я разделил большие данные проверки на количество маленьких кусочков, что (16,48,48,1), код моего тренировочного цикла находится здесь. Очевидно, что применение функции суммы к элементу loss_scalar является ошибкой, но как я могу суммироватьвсе раскололось? спасибо за помощь!
def train(self):
#train_data
hf = h5py.File("../checkpoint/data_train.h5", 'r')
num = hf['train_input'].shape
print(num)
index = num[0] // self.batchsize
with h5py.File("../checkpoint/data_training_test.h5", 'r') as hk:
validation_input = np.array(hk['validation_input']) #(4480, 48, 48, 1)
validation_label = np.array(hk['validation_label'])
indexValid = len(validation_input) // self.batchsize
#tensorboard
merged = tf.summary.merge_all()
train_writer = tf.summary.FileWriter(self.logs + +"_train", self.sess.graph)
test_writer = tf.summary.FileWriter(self.logs + "_test", self.sess.graph)
#Optimizer
self.train_op = tf.train.AdamOptimizer().minimize(self.loss)
init = tf.global_variables_initializer()
start_time = time.time()
#load model
if self.resume():
print(" [*] Load SUCCESS")
else:
print(" [!] Load failed...")
print("Begin training...")
count = 0
self.sess.run(init)
# logs
logging.basicConfig(level=logging.DEBUG,
filename='./logs/traingraphs/train.log',
filemode='w',
format='%(asctime)s : %(message)s'
)
for ep in tqdm(range(self.epoch)):
for item in range(index):
x = hf['train_input'][item * self.batchsize: (item + 1) * self.batchsize]
y = hf['train_label'][item * self.batchsize: (item + 1) * self.batchsize]
_,err = self.sess.run([self.train_op,self.loss],feed_dict={self.input: x,self.target: y})
logging.info("Epoch:[%2d], step:[%2d], time:[%4.4f], loss:[%.8f]" \
% (ep, ep, time.time() - start_time, err))
summary, _ = self.sess.run([merged, self.loss], feed_dict={self.input: x, self.target: y})
train_writer.add_summary(summary, ep)
self.save(ep)
loss_scalar_list = []
psnr_scalar_list = []
psnr2_scalar_list = []
for i in range(indexValid):
vx = validation_input[i * self.batchsize: (i + 1) * self.batchsize] #(16,48,48,1)
vy = validation_label[i * self.batchsize: (i + 1) * self.batchsize]
loss_scalar, psnr_scalar,psnr2_scalar = self.sess.run([self.loss_scalar, self.psnr_scalar,self.psnr2_scalar],
feed_dict={self.input: vx, self.target: vy})
loss_scalar_list.append(loss_scalar)
psnr_scalar_list.append(psnr_scalar)
psnr2_scalar_list.append(psnr2_scalar)
loss_average = sum(loss_scalar_list) #Obviously, apply sum function on the loss_scalar_list is fault
test_writer.add_summary(loss_average, ep)
self.save(ep)