Как узнать форму ввода модели Pytorch? - PullRequest
0 голосов
/ 10 октября 2019

У меня есть список из 100 матриц с формой (20,48), и я хочу передать эту матрицу в pytorch.

Это пример кода

import torch.nn.functional as F
import torch.nn as nn
import torch

sample = torch.randn(100,20,48)

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv1d(20, 48, kernel_size=2)
    def forward(self, x):
        return self.conv(x)

net = Net()

for i in net.state_dict().keys():
    print(i)

for i in list(net.parameters()):
    print(i.shape)

#output

conv.weight
conv.bias

torch.Size([48, 20, 2])
torch.Size([48])

Как проверить, что моя модель принимает ввод определенной формы? В моем случае, как я могу подтвердить, что мой входной слой модели принимает матрицу размера (bs, 20,48)?

...