как сгладить ввод в `nn.Sequential` в Pytorch - PullRequest
0 голосов
/ 28 декабря 2018

как сгладить ввод внутри nn.Sequential

Model = nn.Sequential(x.view(x.shape[0],-1),
                     nn.Linear(784,256),
                     nn.ReLU(),
                     nn.Linear(256,128),
                     nn.ReLU(),
                     nn.Linear(128,64),
                     nn.ReLU(),
                     nn.Linear(64,10),
                     nn.LogSoftmax(dim=1))

Ответы [ 2 ]

0 голосов
/ 26 июня 2019

Как определено flatten метод

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

- скорость, сравнимая с view(), но reshape еще быстрее.

import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

flatten = Flatten()

t = torch.Tensor(3,2,2).random_(0, 10)
print(t, t.shape)


#https://pytorch.org/docs/master/torch.html#torch.flatten
f = torch.flatten(t, start_dim=1, end_dim=-1)
print(f, f.shape)


#https://pytorch.org/docs/master/torch.html#torch.view
f = t.view(t.size(0), -1)
print(f, f.shape)


#https://pytorch.org/docs/master/torch.html#torch.reshape
f = t.reshape(t.size(0), -1)
print(f, f.shape)

Проверка скорости

# flatten 3.49 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# view 3.23 µs ± 228 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# reshape 3.04 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Если бы мы использовали класс сверху

flatten = Flatten()
t = torch.Tensor(3,2,2).random_(0, 10)
%timeit f=flatten(t)


5.16 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

Этот результат показывает, что создание класса будет более медленным подходом.Вот почему быстрее сплющивать тензоры вперед.Я думаю, что это главная причина, по которой они не продвинулись nn.Flatten.

Так что мое предложение было бы использовать внутри вперед для скорости.Как то так:

out = inp.reshape(inp.size(0), -1)
0 голосов
/ 28 декабря 2018

Вы можете создать новый модуль / класс, как показано ниже, и использовать его последовательно, как и другие модули (вызов Flatten()).

class Flatten(torch.nn.Module):
    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)

Ссылка: https://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983

...