Я тренирую нейронную сеть с Керасом. Из-за размера набора данных мне нужно использовать генератор и метод fit_generator (). Я следую этому уроку:
https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly
Тем не менее, я подготовил небольшой пример для проверки выборок, подаваемых в сеть в каждую эпоху, и кажется, что это число превышает количество выборок.
class DataGenerator(keras.utils.Sequence):
'Generates data for Keras'
def __init__(self, files, batch_size=2, dim=(160, 160), n_channels=3,
n_classes=2, shuffle=False):
'Initialization'
self.dim = dim
self.files = files
self.batch_size = batch_size
self.n_channels = n_channels
self.n_classes = n_classes
self.shuffle = shuffle
self.on_epoch_end()
def __len__(self):
'Denotes the number of batches per epoch'
print ("Number of batches per epoch")
print(int(np.floor(len(self.files) / self.batch_size)))
return int(np.floor(len(self.files) / self.batch_size))
def __getitem__(self, index):
'Generate one batch of data'
# Generate indexes of the batch
indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
# Find list of IDs
files_temp = [self.files[k] for k in indexes]
# Generate data
X, y = self.__data_generation(files_temp)
return X, y
def on_epoch_end(self):
'Updates indexes after each epoch'
self.indexes = np.arange(len(self.files))
if self.shuffle == True:
np.random.shuffle(self.indexes)
def __data_generation(self, files_temp):
'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
# Initialization
X = np.empty((self.batch_size, *self.dim, self.n_channels))
y = np.empty((self.batch_size), dtype=int)
# Generate data
for i, ID in enumerate(files_temp):
# Store sample
X[i,] = read_image(ID)
# Store class
y[i] = get_label(ID)
return X, keras.utils.to_categorical(y, num_classes=self.n_classes)
...
params = {'dim': (160, 160),
'batch_size': 2,
'n_classes': 2,
'n_channels': 3,
'shuffle': True}
gen_train = DataGenerator(files, **params)
model.fit_generator(gen_train, steps_per_epoch=ceil(num_samples_train)/batch_size, validation_data=None,
epochs = 1, verbose=1,
callbacks = [tensorboard])
Где read_image
и get_label
- мои методы получения данных. Эти методы включают print () для загружаемого изображения, и я получаю больше, чем ожидаю. Например:
num_samples = 10
batch_size = 2
Шагов за эпоху будет равен 5, и это то, что показывает индикатор выполнения keras, но я получаю больше изображений (что я знаю из-за печати внутри метода).
Я попытался отладить и обнаружил, что функция __getitem__
вызывается более 5 раз! Первые пять раз будут иметь индексы от 0 до 4 (как и ожидалось), но затем я получу повторный индекс и загрузим больше данных.
Есть идеи, почему это происходит? Я отладил файл data_utils.py в керасе, но не могу найти точное место, где индекс передается в __getitem__
. Кажется, все внутри getitem работает нормально.