Как назначить матрицу данных в качестве метки для каждого входного изображения в моем наборе данных с помощью PyTorch? - PullRequest
0 голосов
/ 14 февраля 2019

Я хочу обучить сверточную нейронную сеть (CNN) в PyTorch для прогнозирования данных частотного спектра, связанных с входным изображением.Вместо того, чтобы назначать одну метку каждому изображению (собака, кошка, машина, самолет и т. Д.), Я хотел бы назначить матрицу меток (одна метка на частоту) для каждого изображения.В PyTorch, как мне назначить матрицу данных в качестве метки для каждого входного изображения в моем наборе данных?Я пытался сделать это с помощью ImageFolder.Спасибо!

1 Ответ

0 голосов
/ 16 февраля 2019

Набор данных ImageFolder ожидает определенную структуру, которая обрабатывает только одну метку класса и, следовательно, вероятно, не идеальна для ваших нужд.

Если вы посмотрите на DatasetFolder Class , который является родительским классомв ImageFolder обратите внимание на метод __getitem__, который берет индекс и возвращает образец и цель.Единственное, что вы можете сделать, это изменить атрибут samples загрузчика данных так, чтобы он возвращал ваш собственный путь и метку кортежа для каждого индекса.

В качестве альтернативы вы можете создать свой собственный класс набора данных, который наследуется от torch.utils.data.Datasetи имеет метод __getitem__ и __len__.

...