Я сейчас загружаю в модель и 11 входных значений. Затем я отправляю эти 11 значений в тензор и пытаюсь предсказать результаты. Вот мой код:
# coding: utf-8
# In[5]:
import torch
import torchvision
from torchvision import transforms, datasets
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as utils
import numpy as np
data_np = np.loadtxt('input_preds.csv', delimiter=',')
train_ds = utils.TensorDataset(torch.tensor(data_np, dtype=torch.float32).view(-1,11))
trainset = torch.utils.data.DataLoader(train_ds, batch_size=1, shuffle=True)
# setting device on GPU if available, else CPU, replace .cuda() with .to(device)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
class Net(nn.Module):
def __init__(self):
super().__init__()
#self.bn = nn.BatchNorm2d(11)
self.fc1 = nn.Linear(11, 22)
self.fc2 = nn.Linear(22, 44)
self.fc3 = nn.Linear(44, 22)
self.fc4 = nn.Linear(22, 11)
def forward(self, x):
#x = x.view(-1, 11)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = self.fc4(x)
#return F.log_softmax(x, dim=1)
return x
model1 = torch.load('./1e-2')
model2 = torch.load('./1e-3')
for data in trainset:
X = data
X = X
output = model1(X).to(device)
print(output)
Тем не менее, я получаю эту ошибку
Traceback (most recent call last):
File "inference.py", line 53, in <module>
output = model1(X).to(device)
File "C:\Users\Happy\Miniconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "inference.py", line 40, in forward
x = F.relu(self.fc1(x))
File "C:\Users\Happy\Miniconda3\envs\torch\lib\site-packages\torch\nn\modules\module.py", line 477, in __call__
result = self.forward(*input, **kwargs)
File "C:\Users\Happy\Miniconda3\envs\torch\lib\site-packages\torch\nn\modules\linear.py", line 55, in forward
return F.linear(input, self.weight, self.bias)
File "C:\Users\Happy\Miniconda3\envs\torch\lib\site-packages\torch\nn\functional.py", line 1022, in linear
if input.dim() == 2 and bias is not None:
AttributeError: 'list' object has no attribute 'dim'
Я пытался преобразовать пакет в пустой массив, но это не помогло. Как мне решить эту ошибку? Спасибо за вашу помощь.