Я обучаю нейронную сеть с прямой связью (NN), поскольку обучающие данные слишком велики, я реализовал keras DataGenerator
для подачи обучающих данных, но потеря проверки не сходилась. Я подозревал, что моя реализация DataGenerator
работает неправильно, поэтому я попробовал DataGenerator
на другом протестированном NN с прямой связью, чтобы проверить, верна ли моя концепция DataGenerator
. Функция этого прямого NN заключается в преобразовании значений координат CIELCh в значения координат sRGB. Процесс обучения с использованием DataGenerator выглядит не так, как процесс обучения, когда я напрямую генерирую все данные перед обучением. И я не могу найти, где ошибка logi c. Пожалуйста, помогите мне. Спасибо. Я напрямую использую упакованный модуль keras
в tensorflow
, моя версия tenorflow - 1.13
Вот определение моего класса DataGenerator:
CIEL_MIN=0
CIEL_MAX=99
CIEC_MIN=0
CIEC_MAX=133
CIEH_MIN=0
CIEH_MAX=360
TRAIN_SINTERVAL = 3 # training sample interval
VAL_SINTERVAL = 1 # validation sample interval
DIM_IN = 3 # number of inputs
DIM_OUT = 3 # number of outputs
class CIELCh2sRGB_Generator(keras.utils.Sequence):
'Generates data for Keras'
def __init__(self, batch_size=32, validation=False, shuffle=True):
'Initialization'
self.batch_size = batch_size
self.validation = validation
self.shuffle = shuffle
self.lch_pool = {}
for l in range(CIEL_MIN, CIEL_MAX):
for c in range(CIEC_MIN, CIEC_MAX):
for h in range(CIEH_MIN, CIEH_MAX):
r,g,b = CIELCh_to_sRGB(l, c, h)
if r<0 or r>255 or g<0 or g>255 or b<0 or b>255:
continue
self.lch_pool[l, c, h] = (r, g, b)
if self.validation:
self.l_pts = np.arange(CIEL_MIN, CIEL_MAX, VAL_SINTERVAL)
self.c_pts = np.arange(CIEC_MIN, CIEC_MAX, VAL_SINTERVAL)
self.h_pts = np.arange(CIEH_MIN, CIEH_MAX, VAL_SINTERVAL)
else:
self.l_pts = np.arange(CIEL_MIN, CIEL_MAX, TRAIN_SINTERVAL)
self.c_pts = np.arange(CIEC_MIN, CIEC_MAX, TRAIN_SINTERVAL)
self.h_pts = np.arange(CIEH_MIN, CIEH_MAX, TRAIN_SINTERVAL)
if self.validation:
self.samples_per_epoch = (CIEL_MAX//VAL_SINTERVAL)*(CIEC_MAX//VAL_SINTERVAL)*(CIEH_MAX//VAL_SINTERVAL)
else:
self.samples_per_epoch = len(self.l_pts)*len(self.c_pts)*len(self.h_pts)
self.on_epoch_end()
def __len__(self):
'Denotes the number of batches per epoch'
return int(np.floor(self.samples_per_epoch / self.batch_size))
def __getitem__(self, index):
'Generate one batch of data'
X = np.empty((self.batch_size, DIM_IN))
Y = np.empty((self.batch_size, DIM_OUT), dtype=int)
i = 0
while i < self.batch_size:
L = self.l_pts[self.l_idx]
C = self.c_pts[self.c_idx]
h = self.h_pts[self.h_idx]
try:
r, g, b = self.lch_pool[L, C, h]
except KeyError:
self.__update_inner_index()
continue
X[i,] = [float(L) / float(CIEL_MAX+1),
float(C) / float(CIEC_MAX+1),
float(h) / float(CIEH_MAX)]
Y[i,] = [r1, g1, b1]
i += 1
self.__update_inner_index()
return X, Y
def on_epoch_end(self):
'Updates indexes after each epoch'
if (not self.validation) and (self.shuffle):
np.random.shuffle(self.l_pts)
np.random.shuffle(self.c_pts)
np.random.shuffle(self.h_pts)
self.l_idx = 0
self.c_idx = 0
self.h_idx = 0
def __update_inner_index(self):
self.h_idx += 1
if self.h_idx!=len(self.h_pts): return
else: self.h_idx = 0
self.c_idx += 1
if self.c_idx!=len(self.c_pts): return
else: self.c_idx = 0
self.l_idx += 1
if self.l_idx!=len(self.l_pts): return
else: self.l_idx = 0
return
Вот мой призыв скомпилировать и подогнать NN:
data_gen=CIELCh2sRGB_Generator()
val_gen=CIELCh2sRGB_Generator(validation=True)
model.compile(optimizer='rmsprop', loss='mse')
history=model.fit_generator(
data_gen,
validation_data=val_gen,
shuffle=True,
epochs=50)
Ниже приведен некоторый печатный текст о тренировочном процессе, самая сложная его часть - указанное значение потерь справа от индикатора выполнения будет увеличиваться, а не уменьшаться.
Epoch 1/50
148128/148128 [==============================] - 232s 2ms/step - loss: 5242.1912
5568/5568 [==============================] - 252s 45ms/step - loss: 3669.4281 - val_loss: 5242.1912
Epoch 2/50
148128/148128 [==============================] - 233s 2ms/step - loss: 5296.7092
5568/5568 [==============================] - 251s 45ms/step - loss: 2608.8566 - val_loss: 5296.7092
Epoch 3/50
148128/148128 [==============================] - 235s 2ms/step - loss: 8954.4010
5568/5568 [==============================] - 253s 46ms/step - loss: 2131.9791 - val_loss: 8954.4010
Epoch 4/50
148128/148128 [==============================] - 232s 2ms/step - loss: 10761.9622
5568/5568 [==============================] - 251s 45ms/step - loss: 1778.0234 - val_loss: 10761.9622
Epoch 5/50
117460/148128 [======================>.......] - ETA: 49s - loss: 6467.2221