CNN не учится правильно - PullRequest
       79

CNN не учится правильно

0 голосов
/ 06 апреля 2020

У меня есть небольшой набор данных из 500 изображений растений, и я должен предсказать число для одного изображения в диапазоне [1, 10]. Между числами существует порядковая связь (10> 9> ...> 1). Эта проблема похожа на оценку возраста на основе одной фотографии.

Я попытался регрессировать с помощью Resnet18, Resnet34 и VGG16. Ни один из них не дал очень хорошего результата.

Интересный момент заключается в том, что когда я построил тепловую карту для нескольких изображений, она показала, что модель выбирает неправильные точки, чтобы предсказать ответ. Это похоже на то, что если бы я предполагал прогнозировать возраст на основе фотографии лица, то cnn придавал больше значения фону, чем фактическому лицу.

Я пробовал и другие подходы, такие как классификация и обучение ранжированию, но то же самое происходит, когда я делаю карту тепла. В этих подходах лучшая точность, которую я получаю, составляет 30% при использовании классификации и 35% при использовании обучения для ранжирования.

В подходах регрессии и классификации я использовал реализацию Fastai с предварительной подготовкой. Подход обучения к рангу я использовал это: https://github.com/Raschka-research-group/coral-cnn. Я немного изменил, чтобы иметь возможность использовать и предварительно обученную модель.

Еще один важный момент - это то, что набор данных не сбалансирован. 80% набора данных соответствуют классам с 6 по 10.

У кого-нибудь есть какие-либо советы по его улучшению или другой подход, который я мог бы попробовать?

РЕДАКТИРОВАТЬ: мое увеличение данных выглядит следующим образом:

transforms.Compose([
                  transforms.Resize(256), transforms.CenterCrop(224),
                  transforms.RandomHorizontalFlip(p=0.5),
                  transforms.ColorJitter(brightness=0.15), 
                  transforms.ToTensor(),
                  transforms.Normalize([0.485, 0.456, 0.406], [0.299, 0.224, 0.225])
])

1 Ответ

0 голосов
/ 06 апреля 2020

Вы можете попробовать увеличить набор данных, чтобы получить больше данных (например, случайное кадрирование, вращение и т. Д. c) и убедиться, что вы нормализуете свои данные. Для решения проблемы дисбаланса классов вы можете попробовать использовать PyTorch WeightedRandomSampler:

#Let there be 9 samples in class 0 and 1 sample in class 1 respectively
class_counts = [9.0, 1.0]
num_samples = sum(class_counts)
labels = [0, 0,..., 0, 1] #corresponding labels of samples

class_weights = [num_samples/class_counts[i] for i in range(len(class_counts))]
weights = [class_weights[labels[i]] for i in range(int(num_samples))]
sampler = WeightedRandomSampler(torch.DoubleTensor(weights), int(num_samples))

. Вы сможете легко применить это к вашему случаю с 10 классами, надеюсь, это решит вашу проблему!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...