Память продолжает расти при тренировке FFNN pyTorch - PullRequest
0 голосов
/ 08 ноября 2018

У меня есть нейронная сеть с прямой связью, которая классифицирует набор данных MNIST. По какой-то причине объем памяти продолжает приближаться к 99% независимо от размера пакета. У меня нет ничего, что увеличивается в размере - каждая динамическая переменная переопределяется после первой эпохи, но даже после эпохи номер 70 память продолжает расти.

Я запускаю его на 8 ГБ памяти, четырехъядерном Intel i-5 (7-го поколения) 2,8 ГГц (Ubuntu 18.04)

batch_size =50 #Number of x's we pass through the net at each iteration
num_epochs =100 #Number of times the entire training data is through the    net

n_train = len(xtrain)
n_batch_train = n_train // batch_size
n_val =len(xval)
n_batch_val = n_val // batch_size

# loss/acc
train_acc,train_loss =[],[]
val_acc,val_loss =[],[]

test_acc,test_loss =[],[]


#Get parameters from the net

par=[]

for i in range(len(layers)-1):
    par=par+list(net.L[i].parameters())



#Optimizer

optimizer = optim.Adam(par,lr=0.001)

#interval of x

get_slice = lambda i,size: range(i*size,(i+1)*size)


for e in range(num_epochs):
  curr_loss =0
  net.train()
  for i in range(n_batch_train):

    x_interval = get_slice(i,batch_size)
    slze = get_slice(i,batch_size)
    #Batchnorm
    bn = nn.BatchNorm1d(num_features = num_features)
    x_batch = bn(Variable(torch.from_numpy(xtrain[slze])))

    out = (net(x_batch)).double()
    target_batch = Variable(torch.from_numpy(ytrain[slze]).double())
    L=criterion(out,target_batch)

    #Update gradients
    optimizer.zero_grad()
    L.backward()
    optimizer.step()

    #Store training accuracy and loss
    train_acc.append(accuracy(target_batch, out).data)
    train_loss.append(L.data.numpy())

#### Validate ####
  net.eval()
  for j in range(n_batch_val):
    slze = get_slice(j,batch_size)
    val_batch = Variable(torch.from_numpy(xval[slze]))
    val_out = (net(bn(val_batch))).double()
    target_batch = Variable(torch.from_numpy(yval[slze]).double())

    #Store val acc and loss
    val_acc.append(accuracy(target_batch,val_out).data)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...