Как сбалансировать (передискретизировать) несбалансированные данные в PyTorch (с WeightedRandomSampler)? - PullRequest
0 голосов
/ 29 января 2019

У меня проблема с 2 классами, и мои данные сильно разбалансированы.У меня 232550 образцов из одного класса и 13498 из второго класса.Документы PyTorch и Интернет говорят мне использовать класс WeightedRandomSampler для моего DataLoader.

Я пытался использовать WeightedRandomSampler, но продолжаю получать ошибки.

    trainratio = np.bincount(trainset.labels)
    classcount = trainratio.tolist()
    train_weights = 1./torch.tensor(classcount, dtype=torch.float)
    train_sampleweights = train_weights[trainset.labels]
    train_sampler = WeightedRandomSampler(weights=train_sampleweights, 
    num_samples = len(train_sampleweights))
    trainloader = DataLoader(trainset, sampler=train_sampler, 
    shuffle=False)

Не могу понять, почему я получаю эту ошибку при инициализации класса WeightedRandomSampler?

Я пробовал другие подобные обходные пути, но пока все попытки приводят к некоторой ошибке.Как мне реализовать это, чтобы сбалансировать данные о моих поездах, валидации и тестах?

В настоящее время получаю эту ошибку:

train__sampleweights = train_weights [trainset.labels] ValueError: str слишком много измерений 'str'

1 Ответ

0 голосов
/ 29 января 2019

Проблема в типе trainset.labels. Чтобы исправить ошибку, можно конвертировать trainset.labels в float

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...