Проблема с использованием fit_generator во время обучения керасу - PullRequest
0 голосов
/ 10 июля 2019

Я работаю с очень большими наборами текстовых данных. Я думал об использовании model.fit_generator метода вместо простого model.fit, поэтому я попытался использовать этот генератор:

def TrainGenerator(inp, out):
  for i,o in zip(inp, out):
    yield i,o

Когда я пытаюсь использовать его во время тренировки, используя:

#inp_train, out_train are lists of sequences padded to 50 tokens
model.fit_generator(generator = TrainGenerator(inp_train,out_train),
                    steps_per_epoch = BATCH_SIZE * 100,
                    epochs = 20,
                    use_multiprocessing = True)

Я получаю:

ValueError: Error when checking input: expected embedding_input to have shape (50,) but got array with shape (1,)

Теперь я попытался использовать простой метод model.fit, и он отлично работает. Итак, я думаю, что моя проблема в генераторе, но, поскольку я новичок в использовании генераторов, я не знаю, как ее решить. Полная сводка модели:

Layer (type)                 Output Shape            
===========================================
Embedding (Embedding)      (None, 50, 400)           
___________________________________________
Bi_LSTM_1 (Bidirectional)  (None, 50, 1024)          
___________________________________________
Bi_LSTM_2 (Bidirectional)  (None, 50, 1024)          
___________________________________________
Output (Dense)             (None, 50, 153)           
===========================================

РЕДАКТИРОВАТЬ 1

Первый комментарий меня как-то вызвал. Я понял, что неправильно понял, как работают генераторы. Результатом моего генератора был список формы 50, а не список из N списков формы 50. Поэтому я покопался в документации keras и обнаружил это . Итак, я изменил свой способ работы, и этот класс работает как генератор:

class BatchGenerator(tf.keras.utils.Sequence):

    def __init__(self, x_set, y_set, batch_size):
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.

    def __getitem__(self, idx):
        batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]

        return batch_x, to_categorical(batch_y,num_labels)

Где to_categorical - функция:

def to_categorical(sequences, categories):
    cat_sequences = []
    for s in sequences:
        cats = []
        for item in s:
            cats.append(np.zeros(categories))
            cats[-1][item] = 1.0
        cat_sequences.append(cats)
    return np.array(cat_sequences)

Итак, что я заметил сейчас, это хороший прирост производительности моей сети, каждая эпоха длится уже половину. Это происходит потому, что у меня больше свободной оперативной памяти, так как теперь я не загружаю весь набор данных в память?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...