Симуляция Pytorch не сходится на выпуклой функции потерь, если не инициализирована с 0 - PullRequest
0 голосов
/ 25 октября 2019

Мой код работает, когда веса инициализируются с 0. Когда я инициализирую их согласно некоторому начальному числу, они не сходятся. Это должно быть ошибкой, так как функция потерь выпуклая.

Я отфильтровал две метки из MNIST (0 и 1), а затем я обучил модель логистической регрессии с использованием Pytorch. Поскольку я использую только 200 тренировочных образцов (и 784 параметра), модель должна быстро сходиться с точностью 100% на тренировочном наборе. Это не тот случай, когда веса инициализируются некоторым начальным числом.

У меня была некоторая проблема, чтобы поделиться своим кодом в stackoverflow, поэтому вот ссылка на код: https://drive.google.com/file/d/1ELe8TIWrXMiXgsB63B0Ss43GPr719rGc/view?usp=sharing

Ответы [ 2 ]

2 голосов
/ 25 октября 2019

Ваши данные не масштабируются и не нормализуются. Если вы посмотрите на переменную images в цикле тренировки, то это значение между 0 и 255. Это, по всей вероятности, повредит вашему тренировочному процессу.

Существуют более чистые способы подбора выборки набора данных, как вы хотите, но без внесения изменений. большая часть вашего кода, используя это определение загрузки данных

import torchvision.transforms as transforms

#Load Dataset
preprocessing = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.1307,), (0.3081,))
                                   ])
train_dataset = dsets.MNIST(root='./data', train=True, transform=preprocessing, download=True)

#Filter samples by label (to get binary classification) and by number of training samples
Binary_filter=torch.add(train_dataset.targets==1, train_dataset.targets==0)
train_dataset.data, train_dataset.targets = train_dataset.data[Binary_filter],train_dataset.targets[Binary_filter]

TrainSet_filter=torch.cat((torch.ones(num_of_training_samples)
                         ,torch.zeros(len(train_dataset.targets)-num_of_training_samples)),0).bool()
train_dataset.data, train_dataset.targets = train_dataset.data[TrainSet_filter], train_dataset.targets[TrainSet_filter]

#Make Dataset Iterable
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

У меня ~ 100% точности примерно за 5-10 эпох.

0 голосов
/ 26 октября 2019

Ваша функция потерь (BCE) является выпуклой только по отношению к выходам глубокой сети, а не по весам.

Вы определенно не можете предполагать, что любой локальный минимум также является глобальным минимумом. .

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...