Разница между параметром «weighted_metrics» модели model.com и параметром «class_weight» модели.fit_generator в кератах? - PullRequest
1 голос
/ 03 июля 2019

Во время обучения модели keras для классификации изображений (120 классов из набора данных DOG BREED IDENTIFICATION, KAGGLE) мне необходимо сбалансировать классы с использованием весов классов, которые я где-то читал, и в примерах я видел людей, использующих параметр fit_generator, class_weight.Но я нашел еще один параметр в model.compile, weighted_metrics, описание которого в документации: «Список метрик, которые будут оцениваться и взвешиваться по sample_weight или class_weight во время обучения и тестирования».Должен ли я использовать это?Пожалуйста, объясните назначение этого параметра на любом примере.

#Calculating Class weights
counter = Counter(train_generator.classes)
max_value = float(max(counter.values()))

CLASS_WEIGHTS = {classid: max_value / num_occurences
                 for classid, num_occurences in counter.items()}
# Model Compile
model.compile(optimizer=Adam(lr=LR),
              loss=categorical_crossentropy,
              metrics=[categorical_accuracy],
              weighted_metrics=None) # <--------------- This parameter

STEPS_PER_EPOCH = train_generator.n//train_generator.batch_size
VAL_STEPS = val_generator.n//val_generator.batch_size

model.fit_generator(train_generator,
                    steps_per_epoch=STEPS_PER_EPOCH,
                    epochs=EPOCHS,
                    callbacks=callback_list,
                    verbose=1,
                    class_weight=CLASS_WEIGHTS,
                    validation_data=val_generator,
                    validation_steps=VAL_STEPS) # USED CLASS_WEIGHTS HERE

1 Ответ

0 голосов
/ 03 июля 2019

Да, вы можете использовать их для своего несбалансированного набора данных.

weighted_metrics

- это список метрик, которые учитывают

class_weights

, которые вы передаете в fit_generator.

Итак, в вашем примере вы можете установить

weighted_metrics=['accuracy']

и

class_weight = {0 : 3, 1: 4}

Цель параметра weighted_metrics - предоставить список метрик, которые будут учитывать значения class_weights, которые вы передаете в fit_generator.

...