Возможно, просто установите input.requires_grad = True
для каждой входной партии, в которую вы вводите, а затем после loss.backward()
вы увидите, что input.grad
содержит ожидаемый градиент.Другими словами, если ваши входные данные для модели (которую вы называете features
в вашем коде) имеют некоторый M x N x ...
тензор, features.grad
будет тензором той же формы, где каждый элемент grad
содержит градиентпо отношению к соответствующему элементу features
.В моих комментариях ниже я использую i
в качестве обобщенного индекса - если ваш parameters
имеет, например, 3 измерения, замените его на features.grad[i, j, k]
и т. Д.
Что касается получаемой ошибки: PyTorchОперации строят дерево, представляющее математическую операцию, которую они описывают, которая затем используется для дифференцирования.Например, c = a + b
создаст дерево, в котором a
и b
являются листовыми узлами, а c
не является листом (так как это результат других выражений).Ваша модель - это выражение, а его входы и параметры - это листья, тогда как все промежуточные и конечные результаты не являются листьями.Вы можете думать о листьях как о «константах» или «параметрах» и обо всех других переменных как о функциях тех.Это сообщение говорит вам, что вы можете установить только requires_grad
листовых переменных.
Ваша проблема в том, что на первой итерации features
является случайным (или, в противном случае, вы инициализируете) и, следовательно, действительным листом.После вашей первой итерации features
больше не является листом, поскольку он становится выражением, рассчитанным на основе предыдущих.В псевдокоде у вас есть
f_1 = initial_value # valid leaf
f_2 = f_1 + your_grad_stuff # not a leaf: f_2 is a function of f_1
, чтобы справиться с этим, вам нужно использовать detach
, который разрывает связи в дереве и заставляет автограда обрабатывать тензор так, как если бы онбыл постоянным, независимо от того, как он был создан.В частности, расчеты градиента не будут распространяться через detach
.Поэтому вам нужно что-то вроде
features = features.detach() - 0.01 * features.grad
Примечание: возможно, вам нужно посыпать еще пару detach
здесь и там, что трудно сказать, не видя весь ваш код и не зная точную цель.