Использование sample_weights с fit_generator () - PullRequest
0 голосов
/ 17 ноября 2018

Внутри авторегрессионной непрерывной задачи, когда нули занимают слишком много места, можно рассматривать ситуацию как проблему с завышенным нулем (т.е. ZIB). Другими словами, вместо того, чтобы работать для подгонки f(x), мы хотим подгонять g(x)*f(x), где f(x) - это функция, которую мы хотим аппроксимировать, т.е. y, а g(x) - это функция, которая выводит значение в диапазоне от 0 и 1 в зависимости от того, является ли значение нулевым или ненулевым.

В настоящее время у меня есть две модели. Одна модель, которая дает мне g(x), а другая модель, которая подходит g(x)*f(x).

Первая модель дает мне набор весов. Здесь мне нужна твоя помощь. Я могу использовать sample_weights аргументы с model.fit(). Поскольку я работаю с огромным количеством данных, мне нужно работать с model.fit_generator(). Однако fit_generator() не имеет аргумента sample_weights.

Есть ли способ работы с sample_weights внутри fit_generator()? Иначе, как я могу соответствовать g(x)*f(x), зная, что у меня уже есть обученная модель для g(x)?

1 Ответ

0 голосов
/ 29 ноября 2018

Вы можете указать веса выборки в качестве третьего элемента кортежа, возвращаемого генератором. Из документации Keras fit_generator:

generator: Генератор или экземпляр объекта Sequence (keras.utils.Sequence) во избежание дублирования данных при использовании многопроцессорной обработки. Выход генератора должен быть либо

  • кортеж (inputs, targets)
  • кортеж (inputs, targets, sample_weights).

Обновление: Вот приблизительный эскиз генератора, который возвращает входные выборки и цели, а также веса выборок, полученные из модели g(x):

def gen(args):
    while True:
        for i in range(num_batches):
            # get the i-th batch data
            inputs = ...
            targets = ...

            # get the sample weights
            weights = g.predict(inputs)

            yield inputs, targets, weights


model.fit_generator(gen(args), steps_per_epoch=num_batches, ...)
...