Как определено flatten
метод
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
- скорость, сравнимая с view()
, но reshape
еще быстрее.
import torch.nn as nn
class Flatten(nn.Module):
def forward(self, input):
return input.view(input.size(0), -1)
flatten = Flatten()
t = torch.Tensor(3,2,2).random_(0, 10)
print(t, t.shape)
#https://pytorch.org/docs/master/torch.html#torch.flatten
f = torch.flatten(t, start_dim=1, end_dim=-1)
print(f, f.shape)
#https://pytorch.org/docs/master/torch.html#torch.view
f = t.view(t.size(0), -1)
print(f, f.shape)
#https://pytorch.org/docs/master/torch.html#torch.reshape
f = t.reshape(t.size(0), -1)
print(f, f.shape)
Проверка скорости
# flatten 3.49 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# view 3.23 µs ± 228 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# reshape 3.04 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Если бы мы использовали класс сверху
flatten = Flatten()
t = torch.Tensor(3,2,2).random_(0, 10)
%timeit f=flatten(t)
5.16 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Этот результат показывает, что создание класса будет более медленным подходом.Вот почему быстрее сплющивать тензоры вперед.Я думаю, что это главная причина, по которой они не продвинулись nn.Flatten
.
Так что мое предложение было бы использовать внутри вперед для скорости.Как то так:
out = inp.reshape(inp.size(0), -1)