Я тренирую модель глубокого обучения с использованием PyTorch.По неизвестным причинам память продолжает накапливаться, что приводит к сеансу, уничтоженному до 30 эпох и недостаточному оснащению.
Некоторые мысли здесь:
Интересно, вызвано ли это matplotlib
поэтому я добавил plt.close('all')
;не работает
Добавлено gc.collect()
;не работал
Интересно, вызвано ли это cv2.imwrite()
, но не знаю, как это проверить.Есть предложения?
Проблемы с PyTorch?
другие ...
model.train()
for epo in range(epoch):
for i, data in enumerate(trainloader, 0):
inputs = data
inputs = Variable(inputs)
optimizer.zero_grad()
top = model.upward(inputs + white(inputs))
outputs = model.downward(top, shortcut = True)
loss = criterion(inputs, outputs)
loss.backward()
optimizer.step()
# Print generated pictures every 100 iters
if i % 100 == 0:
inn = inputs[0].view(128, 128).detach().numpy() * 255
cv2.imwrite("/home/tk/Documents/recover/" + str(epo) + "_" + str(i) + ".png", inn)
out = outputs[0].view(128, 128).detach().numpy() * 255
cv2.imwrite("/home/tk/Documents/recover/" + str(epo) + "_" + str(i) + "_re.png", out)
# Print loss every 50 iters
if i % 50 == 0:
print ('[%d, %5d] loss: %.3f' % (epo, i, loss.item()))
gc.collect()
plt.close("all")
===================================================================
20181222 Обновление
Наборы данных и DalaLoader
class MSourceDataSet(Dataset):
def __init__(self, clean_dir):
for i in cleanfolder:
with open(clean_dir + '{}'.format(i)) as f:
clean_list.append(torch.Tensor(json.load(f)))
cleanblock = torch.cat(clean_list, 0)
self.spec = cleanblock
def __len__(self):
return self.spec.shape[0]
def __getitem__(self, index):
spec = self.spec[index]
return spec
trainset = MSourceDataSet(clean_dir)
trainloader = torch.utils.data.DataLoader(dataset = trainset,
batch_size = 4,
shuffle = True)
Модель действительно сложная и длинная.... плюс проблема накопления памяти раньше не возникала (при использовании той же модели), поэтому я не буду публиковать ее здесь ...