Классификация Tensorflow 2.0 (Keras) с ограниченными классами - PullRequest
0 голосов
/ 29 мая 2020

Предпосылки проблемы

У меня есть основная c проблема классификации, классифицирующая каждую простую строку x в один из 20 классов.

Однако есть поворот. Каждая строка имеет связанный набор ограничений класса, выраженный в виде двоичной матрицы размером 20 x N. Идея состоит в том, чтобы обнулить логиты, которые не разрешены на последнем уровне, и не распространять ошибку, которая возникает для этих классов *. 1005 *

Вопрос

У меня есть рабочее решение для Tensorflow 1.0, это простое умножение на последнем слое. Однако я хочу переписать его в Tensorflow 2.0 и Keras.

Я предполагаю, что мне нужно будет передать матрицу ограничений класса в model.fit() вместе с входными данными. Как мне go сделать это?

Одна плохая идея решения

Одно простое решение - объединить входную матрицу и матрицу ограничений класса, и позволить нейронной сети изучить концепцию ограничения класса с нуля.

Однако это ненадежно и делает нейронную сеть излишне крупнее.

1 Ответ

0 голосов
/ 30 мая 2020

Но как это работает во время вывода? Знаете ли вы ограничения классов для новой строки во время вывода?

Если ответ «да»:

Я думаю, вам не следует использовать всю матрицу ограничений класса в качестве входных данных, а скорее вектор ограничения класса, использующий конкатенацию. Таким образом, вместо подачи row с формой (n,) вы подаете row_plus_class_restrictions с формой (n+20,).

row_feature_0
row_feature_1
...
row_feature_n
0
1
.
.
.
1

. Таким образом, вам также не нужно аннулировать любую ошибку, модель будет узнать, что он должен выводить на основе потери классификации.

Если ответ «нет»:

Тогда ваша модель не имеет особого смысла. Данные обучения представляют собой набор (row, class_restrictions, class_it_should_be) с размером (nb_row_features + 20 + 20), это правильно? Чему вы пытаетесь обучить - действительно практическому применению - какие данные находятся в ваших строках? Я не понимаю, чего бы вы хотели, если ответ отрицательный.

...