Как использовать fit_generator с последовательными данными, которые разбиты на пакеты? - PullRequest
0 голосов
/ 30 мая 2018

Я пытаюсь написать генератор для моей модели Keras lstm.Использовать его с методом fit_generator.Мой первый вопрос: что должен вернуть мой генератор?Партия?Последовательность?Пример в документации Keras возвращает x, y для каждой записи данных, но что если мои данные являются последовательными?И я хочу разделить его на пакеты?

Вот метод python, который создает пакет для данного ввода

def get_batch(data, batch_num, batch_size, seq_length):
    i_start = batch_num*batch_size;
    batch_sequences = []
    batch_labels = []
    batch_chunk = data.iloc[i_start:(i_start+batch_size)+seq_length].values
    for i in range(0, batch_size):
        sequence = batch_chunk[(i_start+i):(i_start+i)+seq_length];
        label = data.iloc[(i_start+i)+seq_length].values;
        batch_labels.append(label)
        batch_sequences.append(sequence)
    return np.array(batch_sequences), np.array(batch_labels);

Вывод этого метода для ввода, подобного этому:

get_batch(data, batch_num=0, batch_size=2, seq_length=3):

Было бы:

x = [
      [[1],[2],[3]],
      [[2],[3],[4]]
    ]

Вот как я представляю свою модель:

model = Sequential()
model.add(LSTM(256, return_sequences=True, input_shape=(seq_length, num_features)))
model.add(Dropout(0.2))
model.add(LSTM(256))
model.add(Dense(num_classes, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam')

Мой вопрос: как я могу перевести свой метод в генератор?

1 Ответ

0 голосов
/ 30 мая 2018

Вот решение, которое использует Sequence , который действует как генератор в Keras:

class MySequence(Sequence):
  def __init__(self, num_batches):
    self.num_batches = num_batches

  def __len__(self):
    return self.num_batches # the length is the number of batches

  def __getitem__(self, batch_id):
    return get_batch(data, batch_id, self.batch_size, seq_length)

Я думаю, что это чище и не меняет вашу первоначальную функцию.Теперь вы передаете экземпляр MySequence model.fit_generator.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...