Как сбалансировать несбалансированные данные в PyTorch с WeightedRandomSampler? - PullRequest
0 голосов
/ 29 января 2019

У меня проблема с 2 классами, и мои данные несбалансированы.класс 0 имеет 232550 образцов, а класс 1 имеет 13498 образцов.Документы PyTorch и Интернет говорят мне использовать класс WeightedRandomSampler для моего DataLoader.

Я пытался использовать WeightedRandomSampler, но продолжаю получать ошибки.

    trainratio = np.bincount(trainset.labels) #trainset.labels is a list of 
    float [0,1,0,0,0,...] 
    classcount = trainratio.tolist()
    train_weights = 1./torch.tensor(classcount, dtype=torch.float)
    train_sampleweights = train_weights[trainset.labels]
    train_sampler = WeightedRandomSampler(weights=train_sampleweights, 
                                 num_samples=len(train_sampleweights))
    trainloader = DataLoader(trainset, sampler=train_sampler, 
                                       shuffle=False)

Некоторые размеры, которые я распечатал:

train_weights = tensor([4.3002e-06, 4.3002e-06, 4.3002e-06,  ..., 
4.3002e-06, 4.3002e-06, 4.3002e-06])

train_weights shape=  torch.Size([246048])

Я не понимаю, почему я 'получаю эту ошибку:

UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  self.weights = torch.tensor(weights, dtype=torch.double)

Я пробовал другие подобные обходные пути, но пока все попытки приводят к некоторой ошибке.Как мне реализовать это, чтобы сбалансировать мои данные о поездах, валидации и тестировании?

1 Ответ

0 голосов
/ 30 января 2019

Так что, очевидно, это внутреннее предупреждение, а не ошибка.По словам ребят из PyTorch, я могу продолжать кодирование и не беспокоиться о предупреждении.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...