Как передать большие данные проверки в Nework при обучении? - PullRequest
0 голосов
/ 25 октября 2019

мои данные проверки имеют форму (4480, 48, 48, 1), они слишком велики, чтобы их можно было вставить в память при обучении, кроме того, я хочу наблюдать за средней потерей данных проверки с помощью тензорной доски. К настоящему времени я разделил большие данные проверки на количество маленьких кусочков, что (16,48,48,1), код моего тренировочного цикла находится здесь. Очевидно, что применение функции суммы к элементу loss_scalar является ошибкой, но как я могу суммироватьвсе раскололось? спасибо за помощь!

def train(self):
        #train_data
        hf = h5py.File("../checkpoint/data_train.h5", 'r')
        num = hf['train_input'].shape
        print(num)
        index = num[0] // self.batchsize
        with h5py.File("../checkpoint/data_training_test.h5", 'r') as hk:
            validation_input = np.array(hk['validation_input']) #(4480, 48, 48, 1)
            validation_label = np.array(hk['validation_label'])

        indexValid = len(validation_input) // self.batchsize

        #tensorboard
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(self.logs + +"_train", self.sess.graph)
        test_writer = tf.summary.FileWriter(self.logs + "_test", self.sess.graph)
        #Optimizer
        self.train_op = tf.train.AdamOptimizer().minimize(self.loss)
        init = tf.global_variables_initializer()
        start_time = time.time()
        #load model
        if self.resume():
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        print("Begin training...")
        count = 0
        self.sess.run(init)
        # logs
        logging.basicConfig(level=logging.DEBUG,
                            filename='./logs/traingraphs/train.log',
                            filemode='w',
                            format='%(asctime)s : %(message)s'
                            )
        for ep in tqdm(range(self.epoch)):
            for item in range(index):
                x = hf['train_input'][item * self.batchsize: (item + 1) * self.batchsize]
                y = hf['train_label'][item * self.batchsize: (item + 1) * self.batchsize]
                _,err = self.sess.run([self.train_op,self.loss],feed_dict={self.input: x,self.target: y})
            logging.info("Epoch:[%2d], step:[%2d], time:[%4.4f], loss:[%.8f]" \
                         % (ep, ep, time.time() - start_time, err))
            summary, _ = self.sess.run([merged, self.loss], feed_dict={self.input: x, self.target: y})
            train_writer.add_summary(summary, ep)
            self.save(ep)
            loss_scalar_list = []
            psnr_scalar_list = []
            psnr2_scalar_list = []
            for i in range(indexValid):
                vx = validation_input[i * self.batchsize: (i + 1) * self.batchsize] #(16,48,48,1)
                vy = validation_label[i * self.batchsize: (i + 1) * self.batchsize]
                loss_scalar, psnr_scalar,psnr2_scalar  = self.sess.run([self.loss_scalar, self.psnr_scalar,self.psnr2_scalar],
                                             feed_dict={self.input: vx, self.target: vy})
                loss_scalar_list.append(loss_scalar)
                psnr_scalar_list.append(psnr_scalar)
                psnr2_scalar_list.append(psnr2_scalar)
            loss_average = sum(loss_scalar_list) #Obviously, apply sum function on the loss_scalar_list is fault
            test_writer.add_summary(loss_average, ep)
        self.save(ep)
...