Похоже, вы пользователь Keras или пользователь Tensorflow и пытаетесь изучить Pytorch.
Вы должны перейти на веб-сайт документа Pytorch, чтобы узнать больше о каждой операции.
unsqueeze
- увеличить тусклость на 1 тензора. Подчеркивание в unsqueeze_()
означает, что это функция in-place
.
view()
можно понимать как .reshape()
в кератах.
permute()
- переключение нескольких измерений тензора Например:
x = torch.randn(1,2,3) # shape [1,2,3]
x = torch.permute(2,0,1) # shape [3,1,2]
Чтобы узнать форму тензора после каждой операции, просто добавьте print(x.size())
. Например:
image_conv = image_conv.permute(0, 2, 3, 1)
print(image_conv.size())
image_conv = image_conv.view(image_conv.shape[0], image_conv.shape[1],
print(image_conv.size())
image_conv.shape[2], 8, 8)
print(image_conv.size())
image_conv = image_conv.permute(0, 1, 3, 2, 4)
print(image_conv.size())
Большая разница между Pytorch и Tensorflow (серверная часть Keras) заключается в том, что Pytorch будет генерировать динамический граф, а не статический граф как Tensorflow. Ваш способ определения модели не будет работать должным образом в Pytorch, поскольку веса conv
не будут сохранены в model.parameters()
, который нельзя оптимизировать во время обратного распространения.
Еще один комментарий, пожалуйста, проверьте ссылку , чтобы узнать, как определить правильную модель с помощью Pytorch:
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
Код для комментария:
import torch
x = torch.randn(8, 3, 32, 32)
print(x.shape)
torch.Size([8, 3, 32, 32])
channel = 1
y = x[:, channel, :, :]
print(y.shape)
torch.Size([8, 32, 32])
y = y.unsqueeze_(1)
print(y.shape)
torch.Size([8, 1, 32, 32])
Надеюсь, это поможет вам и наслаждайтесь обучением!