Я студент и новичок в Python и PyTorch. У меня есть базовая c нейронная сеть, для которой я сталкиваюсь с упомянутой ошибкой RunTimeError. Код для воспроизведения ошибки следующий:
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
# Ensure Reproducibility
torch.manual_seed(0)
# Data Generation
x = torch.randn((100,1), requires_grad = True)
y = 1 + 2 * x + 0.3 * torch.randn(100,1)
# Shuffles the indices
idx = np.arange(100)
np.random.shuffle(idx)
# Uses first 80 random indices for train
train_idx = idx[:70]
# Uses the remaining indices for validation
val_idx = idx[70:]
# Generates train and validation sets
x_train, y_train = x[train_idx], y[train_idx]
x_val, y_val = x[val_idx], y[val_idx]
class OurFirstNeuralNetwork(nn.Module):
def __init__(self):
super(OurFirstNeuralNetwork, self).__init__()
# Here we "define" our Neural Network Architecture
self.fc1 = nn.Linear(1, 5)
self.non_linearity_fc1 = nn.ReLU()
self.fc2 = nn.Linear(5,1)
#self.non_linearity_fc2 = nn.ReLU()
def forward(self, x):
# The forward pass
# Here we define how activations "flow" between neurons. We've already discussed the "Sum" and "Transformation" steps of the forward pass.
sum_fc1 = self.fc1(x)
transformation_fc1 = self.non_linearity_fc1(sum_fc1)
sum_fc2 = self.fc2(transformation_fc1)
#transformation_fc2 = self.non_linearity_fc2(sum_fc2)
# The transformation_fc2 is also the output of our model which symbolises the end of our forward pass.
return sum_fc2
# Instantiate the model and train
model = OurFirstNeuralNetwork()
print(model)
print(model.state_dict())
n_epochs = 1000
loss_fn = nn.MSELoss(reduction='mean')
optimizer = optim.Adam(model.parameters())
for epoch in range(n_epochs):
model.train()
optimizer.zero_grad()
prediction = model(x_train)
loss = loss_fn(y_train, prediction)
print(epoch, loss)
loss.backward(retain_graph=True)
optimizer.step()
print(model.state_dict())
Все основано c и стандартно, и это прекрасно работает.
Однако, когда я вынимаю Аргумент "retain_graph = True", он вызывает RunTimeError. Прочитав различные форумы, я понимаю, что это связано с тем, что граф отбрасывается после первой итерации, но я видел много уроков и блогов, где loss.backward()
- это путь к go, тем более что он сохраняет память. Но я не могу концептуально gr asp, почему то же самое не работает для меня.
Любая помощь приветствуется, и мои извинения, если мой вопрос задан не в ожидаемом формате. Я открыт для обратной связи и обязуюсь включить больше деталей или перефразировать вопрос, чтобы всем было легче. Заранее спасибо!