Передача обучаемых параметров для взвешенных потерь в Keras, при использовании генератора для входов - PullRequest
0 голосов
/ 25 февраля 2019

При создании пользовательской функции в Керасе я столкнулся с проблемой при применении обучаемых весов для другого типа потерь.Ниже приведен фрагмент кода, демонстрирующий мой подход к изучению весов - c_a, c_b, c_c для трех потерь одновременно.Я использовал генератор для получения входных данных, и это главное отличие от нескольких других вопросов, на которые ответил @Daniel Möller.

Я определил custom_loss функцию внутри основного метода train_test, который включает в себя часть генерации модели,Чтобы обучить этим параметрам, я определил входной слой weight_input = keras.Input((3,)) и добавил его как дополнительный вход для keras.Model, также поместил weight_input в качестве параметра для функции custom_loss.

Вот мой вопрос:

Я хотел бы определить три значения weight_inputs: c_a, c_b, c_c как [1,1,1] и хочу, чтобы они были изменены во время обученияна.Поэтому я предполагаю, что эти значения не нужно включать для каждой выборки данных (как и другие входные данные: visit_input, areas_input, visit_features_input в функции генератора) Но я до сих пор не понимаю, как передать эти три значения в Должен ли я добавить дополнительный вход вмои существующие генераторы?

В настоящее время результаты data.train_data_generator и data.test_data_generator включают visit_input, areas_input, visit_features_input.Тогда как мне передать эти веса в функцию customer_loss и сделать их обучаемыми?

from data import *
import keras

def train_test(self):

    def custom_loss(weight_inputs):
        def _custom_loss(y_true, y_pred):
            # Calculation part of loss_a, loss_b, loss_c are removed too.
            # Scalar value - c_a, c_b, c_c - should be an element of weight_inputs, respectively. 
            total_loss = self.c_a * loss_a + self.c_b * loss_b + self.c_c * loss_c 

            return total_loss

        return _custom_loss


    # ... Additional NN structure before logits 
    # (visit_input, areas_input, visit_features_input, weight_input were put in a proper way)

    # Final prediction layer    
    logits = keras.layers.Dense(365, activation=keras.activations.softmax)(concat)

    # weight parameter input for custom loss function
    weight_input = keras.Input((3,))

    # Define a model and compile
    self.model = keras.Model(inputs=[visit_input, areas_input, visit_features_input, weight_input], outputs=logits)
    self.model.compile(optimizer=keras.optimizers.Adam(0.001),
                       loss=custom_survival_loss(weight_input),
                       )

    # Train
    self.train_data = data.train_data_generator()
    self.test_data = data.test_data_generator()

    # Fit
    self.history = self.model.fit_generator(
        generator=self.train_data,
        steps_per_epoch=train_data_size//FLAGS.batch_size,
        epochs=FLAGS.train_epochs,
        callbacks=[TrackTestDataPerformanceCallback(data, self.test_data)]
    )

    self.result = self.model.predict_generator(
        generator=self.test_data,
        steps=1
    )

    # Evaluate function (The trained weights should be used in this method too)
    evaluate(data, self.result)

Генераторы определены ниже.(Перед добавлением weight_inputs part)

    def train_data_generator(self):
        def __gen__():
            while True:
                idxs = list(self.df_train.index)
                np.random.shuffle(idxs)
                for idx in idxs:
                    visit = self.train_visits.iloc[idx]
                    label = self.df_train.iloc[idx]
                    yield visit['visit_indices'], visit['area_indices'], \
                          [visit[ft] for ft in self.handcrafted_features], \
                          [label[ft] for ft in ['label', 'suppress_time']]

        gen = __gen__()

        while True:
            batch = [np.stack(x) for x in zip(*(next(gen) for _ in range(FLAGS.batch_size)))]
            yield [batch[0].reshape(-1, 1), batch[1], batch[2]], batch[-1]

Plus: Я бы хотел сохранить несколько функций суб-потерь в custom_loss, так как некоторые вычисления могут быть повторно использованы для различных потерь.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...