Многомерная мультиклассовая классификация с использованием CNN - PullRequest
0 голосов
/ 28 октября 2019

Фон

У меня есть набор данных тысяч траекторий и статистика ~ 120 классов объектов. В общей сложности я в настоящее время использую около 15 функций (координаты x и y, а также другие статистические данные об объекте) с около 800 выборками на последовательность. Все эти данные помечены.

Моей первоначальной целью было использование ConvLSTM для классификации, однако я изо всех сил пытаюсь даже реализовать базовую линию. Моя базовая архитектура предназначена для стандартного CNN, который свертывается вдоль временной области. Используя подмножество двух очень разных классов и учитывая только координаты x, y и очень описательную статистику, сеть работает хорошо. Тем не менее, по мере того, как я добавляю больше классов и больше статистики (которые, как я знаю, необходимы, чтобы различать эти другие, менее похожие классы), точность (то есть точность топ-1, но то же самое относится и к точности топ-5, которую я также измеряю)классификация ухудшается до степени случайного предположения.

Я ищу общие советы относительно ошибок в моей архитектуре, гиперпараметры, а также общие советы по поиску потенциальных проблем в моем подходе.

Данные игрушки

Вот пример некоторых функций для двух классов игрушек. Верхний левый график показывает координаты x, y. Например, верхняя средняя особенность очень различает эти два класса. Некоторые из других функций менее полезны для различения этих двух классов, но они полезны для различения других классов, поэтому для краткости я включил некоторые из них. Для базовой архитектуры я изначально использовал только верхнюю левую и верхнюю среднюю функцию.

Базовая архитектура

Мои первоначальные идеи основывались на следующей настройке:

+------------------------------------------------+--------------------+---------+
|                  Layer (type)                  |    Output Shape    | Param # |
+------------------------------------------------+--------------------+---------+
| input (InputLayer)                             | (None, 3, 800, 1)  |       0 |
| poolconv1conv (Conv2D)                         | (None, 1, 399, 32) |     288 |
| poolconv1avgpool (AveragePo (None, 1, 200, 32) | 0                  |         |
| poolconv1bn (BatchNormaliza (None, 1, 200, 32) | 96                 |         |
| poolconv1 (Activation)                         | (None, 1, 200, 32) |       0 |
| poolconv1do (Dropout)                          | (None, 1, 200, 32) |       0 |
| poolconv2conv (Conv2D)                         | (None, 1, 99, 32)  |    3072 |
| poolconv2avgpool (AveragePo (None, 1, 50, 32)  | 0                  |         |
| poolconv2bn (BatchNormaliza (None, 1, 50, 32)  | 96                 |         |
| poolconv2 (Activation)                         | (None, 1, 50, 32)  |       0 |
| poolconv2do (Dropout)                          | (None, 1, 50, 32)  |       0 |
| poolconv3conv (Conv2D)                         | (None, 1, 24, 64)  |    6144 |
| poolconv3avgpool (AveragePo (None, 1, 12, 64)  | 0                  |         |
| poolconv3bn (BatchNormaliza (None, 1, 12, 64)  | 192                |         |
| poolconv3 (Activation)                         | (None, 1, 12, 64)  |       0 |
| poolconv3do (Dropout)                          | (None, 1, 12, 64)  |       0 |
| poolconv4conv (Conv2D)                         | (None, 1, 5, 64)   |   12288 |
| poolconv4avgpool (AveragePo (None, 1, 3, 64)   | 0                  |         |
| poolconv4bn (BatchNormaliza (None, 1, 3, 64)   | 192                |         |
| poolconv4 (Activation)                         | (None, 1, 3, 64)   |       0 |
| poolconv4do (Dropout)                          | (None, 1, 3, 64)   |       0 |
| flatten2 (Flatten)                             | (None, 192)        |       0 |
| headfc0 (Dense)                                | (None, 1024)       |  197632 |
| headbn0 (BatchNormalization (None, 1024)       | 4096               |         |
| headdo0 (Dropout)                              | (None, 1024)       |       0 |
| headfc1 (Dense)                                | (None, 512)        |  524800 |
| headbn1 (BatchNormalization (None, 512)        | 2048               |         |
| headdo1 (Dropout)                              | (None, 512)        |       0 |
| headfc2 (Dense)                                | (None, 128)        |   65664 |
| headbn2 (BatchNormalization (None, 128)        | 512                |         |
| headdo2 (Dropout)                              | (None, 128)        |       0 |
| headout (Dense)                                | (None, 2)          |    1032 |
+------------------------------------------------+--------------------+---------+

Это может показаться преувеличенным, но интуиция заключалась в том, что с добавлением дополнительной статистики мою poolconv1conv форму просто нужно будет отрегулировать в соответствии с количеством функций, и она должна адаптироваться.

poolconv1conv предназначен для извлечения значимых взаимодействий объектов, а также поведения отдельных объектов в небольшом масштабе времени. В качестве альтернативы я попытался разделить это на два слоя Conv2D размером 1x3, а затем 3x1 (или nx1 с n = количеством объектов), чтобы зафиксировать эти два поведения по отдельности. На самом деле не улучшается.

headout использует активацию softmax, и я тренируюсь, используя categorical_crossentropy убыток для тренировки, так как это многоклассовая классификационная задача

Игрушкарезультаты

Используя описанный выше пример с классификацией двух классов, с 500 последовательностями для каждого класса и разделением 60/20/20 на обучение / валидацию / тестирование, подход работает нормально, так как графики для потери и точность изображены. Производительность валидации очень высока, но это, вероятно, из-за того, что я выбрал довольно высокий уровень обучения на ранней стадии, я полагаю? Может быть, это уже признак того, что что-то странное происходит?

Увеличение

Если я добавлю больше классов в микс, это уже станет менее многообещающим. При добавлении двух дополнительных классов, которые все еще достаточно различны (как можно видеть здесь ), не только потеря тренировки застаивается около 1, точность также резко снижается и требует гораздо большего обучения для достижения> 0,7.

Возможно, данные больше не являются такими дискриминационными для этих классов. Так как функции в нижнем левом графике хороши для различения 2 из 4 классов (предварительные знания с моей стороны), я решил добавить их. Теперь мой ввод 5х800, а мой poolconv1pool имеет kernel_size = (5,3). Предполагается, что это должно увеличить различимость, и поэтому точность должна увеличиваться быстрее, чем без этой функции как части ввода.

Я использую те же данные обучения и разделение обучения / проверки / теста, что и раньше, чтобы сделать их сопоставимыми. Результаты показывают небольшую производительность, но не до того уровня, на который я надеялся. Это послужило в некоторой степени регуляризатором точности проверки, но, кроме этого, никаких существенных улучшений не было.

На данный момент я не знаю, как действовать дальше. Я хотел бы иметь возможность получить более значимую базовую линию, используя CNN, прежде чем перейти к LSTM / ConvLSTM, но если бы мне пришлось увеличить количество классов до своего полного размера и включить все статистические функции, я сомневаюсь в производительности.

Редактировать: С тех пор я также запускаю модель, используя те же разбиения, что и выше, с большинством моих функций. это немного помогает в общей точности и скорости обучения ( потери показаны здесь ), но не так сильно, как хотелось бы, учитывая, что он все еще работает только в 4 классах. Вот результаты: красная линия - модель, которая увеличила количество объектов до 5, а цвет бирюзового цвета - модель с 10 объектами.

Так что я бы приветствовал любые указатели для улучшения этой настройки! В основном это касается:

  • Как увеличить скорость обучения от эпохи к эпохе. Под этим я подразумеваю не изменение скорости обучения, а, например, выбор архитектуры.
  • Советы по настройке сети некоторыми изящными способами. Я открыт и приветствую любые предложения, даже самые свежие.
  • Проблемы в моей текущей реализации, которые я мог пропустить.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...