В чем разница между unsqueez_ в pytorch и epxand_dim в keras и какой будет форма вывода после его использования? - PullRequest
1 голос
/ 29 апреля 2019

Я новичок в keras, и у меня есть код pytorch, который мне нужно заменить на keras, но я не мог понять какую-то его часть.особенно у меня проблемы с размером выходной фигуры.форма image - это (:, 3,32,32), а первое измерение image - это размер партии.теперь мой вопрос: что делает эта линия и какова выходная форма:

    image_yuv_ch = image[:, channel, :, :].unsqueeze_(1)

добавляет измерение в позиции 1?что такое выходная форма? :( размер фильтров был (64,8,8) и тогда у нас есть filters.unsqueez_(1), это означает, что новая форма filters равна (64,1,8,8)? что делает эта строка? image_conv = F.conv2d(image_yuv_ch, filters, stride=8) это то же самое, что и conv2d в керасе, какова форма выходного тензора из него? Я также не мог понять, что делает представление? Я знаю, что он пытается показать тензор в новой форме, но в приведенном ниже коде я не мог понять выводформа после каждой unsqueez_, permute или view. Подскажите, пожалуйста, какова выходная форма каждой строки? Заранее спасибо.

import torch.nn.functional as F
def apply_conv(self, image, filter_type: str):



        if filter_type == 'dct':
            filters = self.dct_conv_weights
        elif filter_type == 'idct':
            filters = self.idct_conv_weights
        else:
            raise('Unknown filter_type value.')

        image_conv_channels = []
        for channel in range(image.shape[1]):
            image_yuv_ch = image[:, channel, :, :].unsqueeze_(1)
            image_conv = F.conv2d(image_yuv_ch, filters, stride=8)
            image_conv = image_conv.permute(0, 2, 3, 1)
            image_conv = image_conv.view(image_conv.shape[0], image_conv.shape[1], image_conv.shape[2], 8, 8)
            image_conv = image_conv.permute(0, 1, 3, 2, 4)
            image_conv = image_conv.contiguous().view(image_conv.shape[0],
                                                  image_conv.shape[1]*image_conv.shape[2],
                                                  image_conv.shape[3]*image_conv.shape[4])

            image_conv.unsqueeze_(1)

            # image_conv = F.conv2d()
            image_conv_channels.append(image_conv)

        image_conv_stacked = torch.cat(image_conv_channels, dim=1)

        return image_conv_stacked

1 Ответ

1 голос
/ 30 апреля 2019

Похоже, вы пользователь Keras или пользователь Tensorflow и пытаетесь изучить Pytorch. Вы должны перейти на веб-сайт документа Pytorch, чтобы узнать больше о каждой операции.

  • unsqueeze - увеличить тусклость на 1 тензора. Подчеркивание в unsqueeze_() означает, что это функция in-place.
  • view() можно понимать как .reshape() в кератах.
  • permute() - переключение нескольких измерений тензора Например:
x = torch.randn(1,2,3) # shape [1,2,3]
x = torch.permute(2,0,1) # shape [3,1,2]

Чтобы узнать форму тензора после каждой операции, просто добавьте print(x.size()). Например:

image_conv = image_conv.permute(0, 2, 3, 1)
print(image_conv.size())

image_conv = image_conv.view(image_conv.shape[0], image_conv.shape[1], 
print(image_conv.size())

image_conv.shape[2], 8, 8)
print(image_conv.size())

image_conv = image_conv.permute(0, 1, 3, 2, 4)
print(image_conv.size())

Большая разница между Pytorch и Tensorflow (серверная часть Keras) заключается в том, что Pytorch будет генерировать динамический граф, а не статический граф как Tensorflow. Ваш способ определения модели не будет работать должным образом в Pytorch, поскольку веса conv не будут сохранены в model.parameters(), который нельзя оптимизировать во время обратного распространения.

Еще один комментарий, пожалуйста, проверьте ссылку , чтобы узнать, как определить правильную модель с помощью Pytorch:

import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)

    def forward(self, x):
       x = F.relu(self.conv1(x))
       return F.relu(self.conv2(x))

Код для комментария:


import torch

x = torch.randn(8, 3, 32, 32)
print(x.shape)
torch.Size([8, 3, 32, 32])
channel = 1
y = x[:, channel, :, :]
print(y.shape)
torch.Size([8, 32, 32])

y = y.unsqueeze_(1)
print(y.shape)
torch.Size([8, 1, 32, 32])

Надеюсь, это поможет вам и наслаждайтесь обучением!

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...