каждый, кого я пытаюсь обучить, модель pytorch
содержала настроенный слой без изучаемых параметров, и модель не выдает ошибку grad_fn.
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
import torch.utils.data as Data
# custom layer
class Filter(nn.Module):
def __init__(self, filter_size=10, filter_step=1):
super(Filter, self).__init__()
self.filter_size = filter_size
self.filter_step = filter_step
def Output_filter(self, x):
frame_indx = 0
while 1:
if frame_indx + self.filter_size <= len(x):
if sum(x[frame_indx: frame_indx + self.filter_size]) >= self.filter_size / 2:
x[frame_indx: frame_indx + self.filter_size] = 1
else:
if sum(x[frame_indx:]) >= len(x[frame_indx:]) / 2:
x[frame_indx:] = 1
break
frame_indx = frame_indx + self.filter_step
return x
def forward(self, x):
x = F.relu(self.Output_filter(x))
return x
# model
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(7,2)
self.filter = Filter(3,1)
def forward(self, x):
x = self.fc1(x)
x = F.relu(x)
x = torch.argmax(x, 1).float()
x = self.filter(x)
return x
if __name__ == "__main__":
net = Net()
optimizer = torch.optim.Adam(net.parameters())
train_x = torch.rand((10,7),dtype=torch.float)
train_y = torch.tensor([1,1,1,1,0,1,0,0,1,0], dtype=torch.float)
train_dataset = Data.TensorDataset(train_x, train_y)
train_loader = Data.DataLoader(dataset=train_dataset, batch_size=10, shuffle=False, num_workers=2)
loss_func = nn.BCELoss()
net.train()
for step, (b_x, b_y) in enumerate(train_loader, 1):
pred_train = net(b_x)
loss_ = loss_func(pred_train, b_y)
optimizer.zero_grad()
loss_.backward()
optimizer.step()
Я предполагаю, что операция torch.argmax(x, 1)
НЕ различима,поэтому я добавляю require_grad_ (True), например, torch.argmax(x, 1).float().requires_grad_(True)
, но на этот раз модель выдает ' листовая переменная была перемещена во внутреннюю часть графика' ошибка.
В режиме отладки pycharm x.is_leaf
равен TRUE после операции torch.argmax
.У меня вопрос: есть ли любая другая дифференцируемая функция, которая может использоваться для замены torch.argmax
, или есть ли другие способы заставить мой код работать
?