Я новичок в PyTorch, и мне было интересно, не могли бы вы объяснить мне некоторые ключевые различия между стандартной моделью.train () в PyTorch и функцией train () здесь.
Другой поездФункция () включена в официальный учебник PyTorch по классификации текста и была сбита с толку относительно того, сохраняются ли веса моделей в конце обучения.
https://pytorch.org/tutorials/intermediate/char_rnn_classification_tutorial.html
learning_rate = 0.005
criterion = nn.NLLLoss()
def train(category_tensor, line_tensor):
hidden = rnn.initHidden()
rnn.zero_grad()
for i in range(line_tensor.size()[0]):
output, hidden = rnn(line_tensor[i], hidden)
loss = criterion(output, category_tensor)
loss.backward()
# Add parameters' gradients to their values, multiplied by learning rate
for p in rnn.parameters():
p.data.add_(-learning_rate, p.grad.data)
return output, loss.item()
Это функция.Затем эта функция вызывается несколько раз в таком виде:
n_iters = 100000
print_every = 5000
plot_every = 1000
record_every = 500
# Keep track of losses for plotting
current_loss = 0
all_losses = []
predictions = []
true_vals = []
def timeSince(since):
now = time.time()
s = now - since
m = math.floor(s / 60)
s -= m * 60
return '%dm %ds' % (m, s)
start = time.time()
for iter in range(1, n_iters + 1):
category, line, category_tensor, line_tensor = randomTrainingExample()
output, loss = train(category_tensor, line_tensor)
current_loss += loss
if iter % print_every == 0:
guess, guess_i = categoryFromOutput(output)
correct = 'O' if guess == category else 'X (%s)' % category
print('%d %d%% (%s) %.4f %s / %s %s' % (iter, iter / n_iters * 100, timeSince(start), loss, line, guess, correct))
if iter % plot_every == 0:
all_losses.append(current_loss / plot_every)
current_loss = 0
if iter % record_every == 0:
guess, guess_i = categoryFromOutput(output)
predictions.append(guess)
true_vals.append(category)
Мне кажется, что веса моделей не сохраняются и не обновляются, а перезаписываются на каждой итерации при написании таким образом.Это правильно?Или модель, кажется, обучается правильно?
Кроме того, если бы я использовал стандартную функцию model.train (), каково главное преимущество и выполняет ли model.train () более или менее то же самое?функциональность как функция train () выше?