Сжатие факела и размер партии - PullRequest
1 голос
/ 10 марта 2020

Кто-нибудь здесь знает, учитывает ли функция torch.squeeze размер пакета (например, первое)? Из некоторого встроенного кода кажется, что это не так ... но, может быть, кто-то другой знает внутреннюю работу лучше, чем я.

Кстати, основная проблема в том, что у меня есть тензор формы (n_batch, channel, x, y, 1). Я хочу удалить последнее измерение с помощью простой функции, чтобы в итоге я получил форму (n_batch, channel, x, y).

Конечно, возможно изменение формы или даже выбор последней оси. Но я хочу встроить эту функциональность в слой, чтобы легко добавить ее к объекту ModuleList или Sequence.

1 Ответ

1 голос
/ 10 марта 2020

Нет! сжатие не учитывает размер пакета. Это потенциальный источник ошибки, если вы используете squeeze, когда размер пакета может быть 1. Эмпирическое правило заключается в том, что только классы и функции в torch.nn по умолчанию учитывают размеры пакета.

Это вызвало у меня головную боль в прошлое. Я рекомендую использовать reshape или только squeeze с необязательным аргументом входного измерения. В вашем случае вы можете использовать .squeeze(4), чтобы удалить только последнее измерение. Таким образом, ничего неожиданного не происходит. Сжатие без входного размера привело меня к неожиданным результатам, особенно когда

  1. форма ввода для модели может варьироваться
  2. размер партии может варьироваться
  3. nn.DataParallel используется (в этом случае размер партии для конкретного экземпляра может быть уменьшен до 1)
...