Использовать дискриминатор GAN для классификации по одному классу в MATLAB - PullRequest
1 голос
/ 04 февраля 2020

Я пытаюсь внедрить нейронную сеть для классификации различных дефектов для проверки качества. Я хочу использовать одноклассную классификацию. Чтобы выполнить sh этого, я хочу обучить генеративные состязательные сети и использовать дискриминатор для классификации.

Я использовал пример подсолнечника для реализации моей первой GAN. (https://de.mathworks.com/help/deeplearning/examples/train-generative-adversarial-network.html)

В этом примере есть строка, которая «классифицирует» сгенерированные выходные данные с помощью сети дискриминатора:

dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated); 

Теперь Я ожидал, что вывод будет состоять из 2 меток: «Оригинал» или «Подделка». Вместо этого я получаю длинный список чисел:

(:,:,1,1) =
    5.9427
(:,:,1,2) =
    7.5930
(:,:,1,3) =
    9.3393
etc.

Я думаю, что это значения потерь для сети дискриминатора.

Я хотел бы знать, как я могу использовать получившуюся сеть дискриминаторов для классификации входных изображений. Проблема состоит в том, что сеть дискриминатора не имеет полностью связанных уровней или классификационного уровня в конце структуры уровня. Похоже, что архитектура дискриминатора отличается от архитектуры «нормальной» сверточной нейронной сети.

Может кто-нибудь помочь мне с этой задачей?

TL; DR: Я хочу использовать пример подсолнечника matlab (https://de.mathworks.com/help/deeplearning/examples/train-generative-adversarial-network.html) для обучающих GAN и извлечь дискриминатор для функционирования в качестве сети классификации.

1 Ответ

0 голосов
/ 04 апреля 2020

Из примера matlab sunflower,

dlYPredGenerated = forward(dlnetDiscriminator, dlXGenerated);

выдает выходные данные последнего подключенного слоя F C без активации (не потери). Вот почему он поставляется с

probGenerated = sigmoid(dlYPredGenerated);

Следовательно, probGenerated - это реальный результат, который вы хотите видеть как фальшивую или реальную вероятность. Кстати, выход имеет 4 измерения, так как он имеет метку fmt SSCB (Spatial-пространственный-канал-пакет), и потеря составляет от

[lossGenerator, lossDiscriminator] = ganLoss(probReal,probGenerated);

...