Я начинающий пользователь pytorch и пытаюсь использовать dataloader.
На самом деле, я пытаюсь внедрить это в свою сеть, но загрузка занимает очень много времени.Итак, я отладил свою сеть, чтобы посмотреть, есть ли проблема в самой сети, но оказалось, что она имеет отношение к моему классу загрузчика данных.Вот код:
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
class DiabetesDataset(Dataset):
def __init__(self, csv):
self.xy = pd.read_csv(csv)
def __len__(self):
return len(self.xy)
def __getitem__(self, index):
self.x_data = torch.Tensor(xy.iloc[:, 0:-1].values)
self.y_data = torch.Tensor(xy.iloc[:, [-1]].values)
return self.x_data[index], self.y_data[index]
dataset = DiabetesDataset("trial.csv")
train_loader = DataLoader(dataset=dataset,
batch_size=1,
shuffle=True,
num_workers=2)`
for a in train_loader:
print(a)
Чтобы убедиться, что загрузчик данных вызывает всю задержку, я создал фиктивный CSV-файл с 2 столбцами 1 и 2, всего 10 выборок для каждого столбца.Затем я зациклился на объект train_loader, это было более 1 часа, и он все еще работает, учитывая, что размер выборки невелик, а размер партии установлен на 1.
Я не уверен в том, чтоошибка в моем коде, и это вызывает эту проблему.
Любые комментарии / комментарии приветствуются!