как загрузить данные и как сделать увеличение данных с помощью Pytorch - PullRequest
0 голосов
/ 05 марта 2020

Я новичок ie в Pytorch, у меня проблемы с классификацией изображений, но я не понимаю, как загрузить изображение из каталога загрузки, пожалуйста, помогите мне, как загрузить данные изображения и как увеличить.

здесь мои данные выглядят так:

train=pd.read_csv('dataset/train.csv')
test=pd.read_csv('dataset/test.csv') 
train.head()
Image   Class
0   image7042.jpg   Food
1   image3327.jpg   misc
2   image10335.jpg  Attire
3   image8019.jpg   Food
4   image2128.jpg   Attire

здесь моя папка изображений:

file_path='dataset/Train Images'

1 Ответ

0 голосов
/ 05 марта 2020

Вы можете использовать torchvision для этого. Предполагая, что у вас есть все обучающие / тестовые изображения, разделенные на две папки с именами train и test, вот пример кода для загрузки и перебора изображений:

import torchvision
from torchvision import datasets, transforms

def load_dataset(data_path):
    dataset = torchvision.datasets.ImageFolder(
        root=data_path,
        transform=transforms.Compose([torchvision.transforms.ToTensor()])
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        num_workers=0,
        shuffle=True
    )
    return data_loader

train_loader = load_dataset(f'{base_dir}/train')
test_loader = load_dataset(f'{base_dir}/test')

for batch_idx, (data, _) in enumerate(train_loader):
   # Train model

...

for batch_idx, (data, _) in enumerate(test_loader):
   # Evaluate model

Вы можете увеличить batch_size если вы хотите обучать вашу модель партиями, добавьте преобразователи к аргументу transform, чтобы увеличить изображения и многое другое.

Ознакомьтесь с документацией: https://pytorch.org/docs/stable/torchvision/index.html

...