Где мой код keras DataGenerator go неправильный? - PullRequest
0 голосов
/ 07 августа 2020

Я обучаю нейронную сеть с прямой связью (NN), поскольку обучающие данные слишком велики, я реализовал keras DataGenerator для подачи обучающих данных, но потеря проверки не сходилась. Я подозревал, что моя реализация DataGenerator работает неправильно, поэтому я попробовал DataGenerator на другом протестированном NN с прямой связью, чтобы проверить, верна ли моя концепция DataGenerator. Функция этого прямого NN заключается в преобразовании значений координат CIELCh в значения координат sRGB. Процесс обучения с использованием DataGenerator выглядит не так, как процесс обучения, когда я напрямую генерирую все данные перед обучением. И я не могу найти, где ошибка logi c. Пожалуйста, помогите мне. Спасибо. Я напрямую использую упакованный модуль keras в tensorflow, моя версия tenorflow - 1.13 Вот определение моего класса DataGenerator:

CIEL_MIN=0
CIEL_MAX=99
CIEC_MIN=0
CIEC_MAX=133
CIEH_MIN=0
CIEH_MAX=360

TRAIN_SINTERVAL = 3 # training sample interval
VAL_SINTERVAL = 1 # validation sample interval
DIM_IN = 3 # number of inputs
DIM_OUT = 3 # number of outputs
class CIELCh2sRGB_Generator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, batch_size=32, validation=False, shuffle=True):
        'Initialization'
        self.batch_size = batch_size
        self.validation = validation
        self.shuffle = shuffle
        
        self.lch_pool = {}
        for l in range(CIEL_MIN, CIEL_MAX):
            for c in range(CIEC_MIN, CIEC_MAX):
                for h in range(CIEH_MIN, CIEH_MAX):
                    r,g,b = CIELCh_to_sRGB(l, c, h)
                    if r<0 or r>255 or g<0 or g>255 or b<0 or b>255:
                        continue
                    self.lch_pool[l, c, h] = (r, g, b)
        
        if self.validation:
            self.l_pts = np.arange(CIEL_MIN, CIEL_MAX, VAL_SINTERVAL)
            self.c_pts = np.arange(CIEC_MIN, CIEC_MAX, VAL_SINTERVAL)
            self.h_pts = np.arange(CIEH_MIN, CIEH_MAX, VAL_SINTERVAL)
        else:
            self.l_pts = np.arange(CIEL_MIN, CIEL_MAX, TRAIN_SINTERVAL)
            self.c_pts = np.arange(CIEC_MIN, CIEC_MAX, TRAIN_SINTERVAL)
            self.h_pts = np.arange(CIEH_MIN, CIEH_MAX, TRAIN_SINTERVAL)

        if self.validation:
            self.samples_per_epoch = (CIEL_MAX//VAL_SINTERVAL)*(CIEC_MAX//VAL_SINTERVAL)*(CIEH_MAX//VAL_SINTERVAL)
        else:
            self.samples_per_epoch = len(self.l_pts)*len(self.c_pts)*len(self.h_pts)
            
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(self.samples_per_epoch / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        X = np.empty((self.batch_size, DIM_IN))
        Y = np.empty((self.batch_size, DIM_OUT), dtype=int)
        
        i = 0
        while i < self.batch_size:
            L = self.l_pts[self.l_idx]
            C = self.c_pts[self.c_idx]
            h = self.h_pts[self.h_idx]
            
            try:
                r, g, b = self.lch_pool[L, C, h]
            except KeyError:
                self.__update_inner_index()
                continue
            
            X[i,] = [float(L) / float(CIEL_MAX+1),
                     float(C) / float(CIEC_MAX+1),
                     float(h) / float(CIEH_MAX)]
            Y[i,] = [r1, g1, b1]
            
            i += 1
            self.__update_inner_index()

        return X, Y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        if (not self.validation) and (self.shuffle):
            np.random.shuffle(self.l_pts)
            np.random.shuffle(self.c_pts)
            np.random.shuffle(self.h_pts)
        self.l_idx = 0
        self.c_idx = 0
        self.h_idx = 0

        
    def __update_inner_index(self):
        
        self.h_idx += 1
        if self.h_idx!=len(self.h_pts): return
        else: self.h_idx = 0
        
        self.c_idx += 1
        if self.c_idx!=len(self.c_pts): return
        else: self.c_idx = 0
        
        self.l_idx += 1
        if self.l_idx!=len(self.l_pts): return
        else: self.l_idx = 0
        
        return

Вот мой призыв скомпилировать и подогнать NN:

data_gen=CIELCh2sRGB_Generator()
val_gen=CIELCh2sRGB_Generator(validation=True)
model.compile(optimizer='rmsprop', loss='mse')
history=model.fit_generator(
    data_gen,
    validation_data=val_gen,
    shuffle=True,
    epochs=50)

Ниже приведен некоторый печатный текст о тренировочном процессе, самая сложная его часть - указанное значение потерь справа от индикатора выполнения будет увеличиваться, а не уменьшаться.

Epoch 1/50
148128/148128 [==============================] - 232s 2ms/step - loss: 5242.1912
5568/5568 [==============================] - 252s 45ms/step - loss: 3669.4281 - val_loss: 5242.1912
Epoch 2/50
148128/148128 [==============================] - 233s 2ms/step - loss: 5296.7092
5568/5568 [==============================] - 251s 45ms/step - loss: 2608.8566 - val_loss: 5296.7092
Epoch 3/50
148128/148128 [==============================] - 235s 2ms/step - loss: 8954.4010
5568/5568 [==============================] - 253s 46ms/step - loss: 2131.9791 - val_loss: 8954.4010
Epoch 4/50
148128/148128 [==============================] - 232s 2ms/step - loss: 10761.9622
5568/5568 [==============================] - 251s 45ms/step - loss: 1778.0234 - val_loss: 10761.9622
Epoch 5/50
117460/148128 [======================>.......] - ETA: 49s - loss: 6467.2221
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...