Ошибка класса DataLoader Pytorch - PullRequest
       51

Ошибка класса DataLoader Pytorch

0 голосов
/ 27 февраля 2019

Я начинающий пользователь pytorch и пытаюсь использовать dataloader.

На самом деле, я пытаюсь внедрить это в свою сеть, но загрузка занимает очень много времени.Итак, я отладил свою сеть, чтобы посмотреть, есть ли проблема в самой сети, но оказалось, что она имеет отношение к моему классу загрузчика данных.Вот код:

 from torch.utils.data import Dataset, DataLoader
 import numpy as np
 import pandas as pd

class DiabetesDataset(Dataset):

  def __init__(self, csv):
      self.xy = pd.read_csv(csv)

  def __len__(self):
      return len(self.xy)

  def __getitem__(self, index):
       self.x_data = torch.Tensor(xy.iloc[:, 0:-1].values)
       self.y_data = torch.Tensor(xy.iloc[:, [-1]].values)
       return self.x_data[index], self.y_data[index]

 dataset = DiabetesDataset("trial.csv")
 train_loader = DataLoader(dataset=dataset,
                      batch_size=1,
                      shuffle=True,
                      num_workers=2)`

 for a in train_loader:
    print(a)

Чтобы убедиться, что загрузчик данных вызывает всю задержку, я создал фиктивный CSV-файл с 2 столбцами 1 и 2, всего 10 выборок для каждого столбца.Затем я зациклился на объект train_loader, это было более 1 часа, и он все еще работает, учитывая, что размер выборки невелик, а размер партии установлен на 1.

Я не уверен в том, чтоошибка в моем коде, и это вызывает эту проблему.

Любые комментарии / комментарии приветствуются!

1 Ответ

0 голосов
/ 27 февраля 2019

В вашем коде есть некоторые ошибки - не могли бы вы проверить, работает ли это (это работает на моем компьютере с примером вашей игрушки):

from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import torch


class DiabetesDataset(Dataset):

    def __init__(self, csv):
        self.xy = pd.read_csv(csv)

    def __len__(self):
        return len(self.xy)

    def __getitem__(self, index):
        x_data = torch.Tensor(self.xy.iloc[:, 0:-1].values)
        y_data = torch.Tensor(self.xy.iloc[:, [-1]].values)
        return x_data[index], y_data[index]


dataset = DiabetesDataset("trial.csv")


train_loader = DataLoader(
    dataset=dataset,
    batch_size=1,
    shuffle=True,
    num_workers=2)

if __name__ == '__main__':
    for a in train_loader:
        print(a)

Редактировать : Ваш кодне работает, потому что у вас отсутствует self в методе __getitem__ (self.xy.iloc ...) и потому что у вас нет if __name__ == '__main__ в конце вашего скрипта.В отношении второй ошибки см. RuntimeError в Windows, в которой выполняется многопроцессорная обработка Python

...