Несбалансированные данные: глубокая сеть прямой связи со сбалансированным class_weight не учится за пределами доминирующего класса - PullRequest
0 голосов
/ 02 февраля 2019

У меня есть данные с 5 выходными классами.Учебные данные не имеют следующих образцов для этих 5 классов: [706326, 32211, 2856, 3050, 901]

Я использую следующий код keras (tf.keras):

class_weights = class_weight.compute_class_weight('balanced',
                                                 np.unique(y_train),
                                                 y_train)

model = tf.keras.Sequential([
    tf.keras.layers.Dense(50, input_shape=(dataX.shape[1],)),
    tf.keras.layers.Dropout(rate = 0.5),
    tf.keras.layers.Dense(50, activation=tf.nn.relu),
    tf.keras.layers.Dropout(rate = 0.5),
    tf.keras.layers.Dense(50, activation=tf.nn.relu),
    tf.keras.layers.Dropout(rate = 0.5),
    tf.keras.layers.Dense(50, activation=tf.nn.relu),
    tf.keras.layers.Dropout(rate = 0.5),
    tf.keras.layers.Dense(5, activation=tf.nn.softmax) ])
     adam = tf.keras.optimizers.Adam(lr=0.5)

model.compile(optimizer=adam, 
              loss='sparse_categorical_crossentropy',
              metrics=[metrics.sparse_categorical_accuracy])    
     model.fit(X_train,y_train, epochs=5, batch_size=32, class_weight=class_weights)

y_pred = np.argmax(model.predict(X_test), axis=1)

Я использую sparse_categorical_crossentropy, которая принимает категории в виде целых чисел (не нужно преобразовывать их в одноразовое кодирование), но я также пробовал categoryorical_crossentropy и все еще та же проблема.

Я, конечно, пробовал разныескорость обучения, размер партии, количество эпох, оптимизатор и глубина / длина сети.Но он всегда застрял с точностью ~ 0,94, что, по сути, я бы получил, если бы все время предсказывал первый класс.

Не уверен, чего здесь не хватает.Любая ошибка с моей стороны, или какая-то ошибка с class_weight в Keras?Или я должен использовать какую-то другую специализированную глубокую сеть?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...