Улучшение производительности MLP (для мультиклассовой классификации) с помощью меток Пуассона, отобранных с использованием Keras - PullRequest
0 голосов
/ 11 января 2019

Я пытаюсь использовать полностью подключенную нейронную сеть или многослойный персептрон для выполнения многоклассовой классификации: мои тренировочные данные (X) представляют собой разные строки ДНК одинаковой длины. Каждая из этих последовательностей имеет ассоциированное с ними значение с плавающей точкой (например, t_X), которое я использую для имитации меток (y) для моих данных следующим образом. y ~ np.random.poisson (постоянная * t_X) .

После обучения модели Keras (см. Ниже) я составил гистограмму прогнозируемых меток и тестовых меток, и проблема, с которой я сталкиваюсь, заключается в том, что моя модель неправильно классифицирует множество последовательностей, см. Изображение ниже.

Ссылка на гистограмму

Мои данные о тренировках выглядят следующим образом:

X , Y  
CTATTACCTGCCCACGGTAAAGGCGTTCTGG,    1
TTTCTGCCCGCGGCCTGGCAATTGATACCGC,    6
TTTTTACACGCCTTGCGTAAAGCGGCACGGC,    4
TTGCTGCCTGGCCGATGGTCTATGCCGCTGC,    7

Я одноразово кодирую свои Y и последовательности X превращаются в тензоры измерений: (размер пакета, длина последовательности, количество символов), эти числа примерно такие же, как 10 000 на 50 на 4

Моя модель keras выглядит так:

model = Sequential() 
model.add(Flatten())
model.add(Dense(100, activation='relu',input_shape=(50,4)))
model.add(Dropout(0.25))
model.add(Dense(50, activation='relu'))
model.add(Dropout(0.25))
model.add(Dense(len(one_hot_encoded_labels), activation='softmax'))

Я пробовал следующие различные функции потери

#model.compile(loss='mean_squared_error',optimizer=Adam(lr=0.00001), metrics=['accuracy'])
#model.compile(loss='mean_squared_error',optimizer=Adam(lr=0.0001), metrics=['mean_absolute_error',r_square])
#model.compile(loss='kullback_leibler_divergence',optimizer=Adam(lr=0.00001), metrics=['categorical_accuracy'])
#model.compile(loss=log_poisson_loss,optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
#model.compile(loss='categorical_crossentropy',optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])
model.compile(loss='poisson',optimizer=Adam(lr=0.0001), metrics=['categorical_accuracy'])

Потеря ведет себя разумно; он падает и сглаживается с увеличением эпох. Я пробовал разные скорости обучения, разные оптимизаторы, разное количество нейронов в каждом слое, разное количество скрытых слоев и разные типы регуляризации.

Я думаю, что моя модель всегда помещает большинство предсказанных меток вокруг пика тестовых данных (см. Связанную гистограмму), но она не может классифицировать последовательности с меньшим количеством подсчетов в тестовом наборе. Это общая проблема?

Не обращаясь к другим архитектурам (таким как свертка или рекуррентность), кто-нибудь знает, как я могу улучшить производительность классификации для этой модели?

Файл обучающих данных

1 Ответ

0 голосов
/ 12 января 2019

Из ваших распределений гистограммы ясно, что у вас очень несбалансированный набор тестовых данных. Я предполагаю, что у вас такое же распределение данных о тренировках. Тогда это может быть причиной того, что NN работает плохо, потому что у него не так много данных, чтобы многие классы могли изучить особенности. Вы можете попробовать некоторые методы выборки, чтобы сравнить соотношение между каждым классом.

Вот ссылка , которая объясняет различные методы для такого набора данных дисбаланса.

Во-вторых, вы можете проверить производительность модели путем перекрестной проверки, где вы можете легко найти, является ли это приводимой или неснижаемой ошибкой. Если это неустранимая ошибка, вы не можете больше улучшаться (вам нужно попробовать другой метод для этой ситуации).

В-третьих, существует взаимосвязь между последовательностями. Простая сеть с прямой связью не может уловить такое отношение. Recurrent-network может захватывать такие зависимости в наборе данных. Вот простой пример для этого. Этот пример для двоичного класса, который может быть расширен до multi-class, как в вашем случае.

Для выбора loss-function это полностью зависит от проблемы. Вы можете проверить эту ссылку , в которой объясняется, когда и какая функция потерь может быть полезна.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...