определение измерения смещения в кодировании нейронной сети - PullRequest
0 голосов
/ 21 февраля 2020

На следующем рисунке, показывающем код для определения размера смещения b1, мне интересно, почему первое измерение смещения b1 не является размером пакета? Значит ли это, что предполагается, что это смещение применяется ко всем партиям?

Если я указываю измерение смещения b1 равным (batch_size, 256), значит ли это, что я применяю другой b1 к другому пакету? Но теоретически это все еще должно работать правильно? Кроме того, в чем разница между тензорными (256), (256,) и (256,1) ...?

Рисунок: определение размера nn

1 Ответ

1 голос
/ 21 февраля 2020

Веса и смещения вашего уровня нейронной сети , а не , указанные в терминах размера пакета.

например: w1 = torch.randn(784,256): это 2D-матрица, которую вы собираетесь использовать для матрицы умножить. 784 - это размер вашего входного изображения без учета размера пакета. (Я предполагаю, что это для mnist? Похоже, что вы сглаживаете 2d-изображения в 1d-вектор, поэтому 28 * 28 = 784).

256 - это выходное измерение вашего вывода (сколько логитов вы используете)

Аналогично, b1 = torch.randn(256): Это одномерный вектор, который вы просто добавляете в логиты.

256 - это размер логитов

Pytorch автоматически передает (повторяет) их по размеру пакета для всех ваших операций, поэтому не имеет значения, какой был размер пакета.

Например. например, при добавлении b1 автоматически повторяется по первому измерению, поэтому фактическая форма для добавления равна (batch_size, 256).

По соглашению, pytorch "выравнивает" измерения справа налево.

  1. , если они отсутствуют, затем повторяет тензор по отсутствующему измерению
  2. Если любое измерение равно 1, оно повторяет тензор по этому измерению, чтобы соответствовать другому операнду. Например (скопировано из документации по трансляции)
>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# same shapes are always broadcastable (i.e. the above rules always hold)

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# x and y are not broadcastable, because x does not have at least 1 dimension

# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist

# but:
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# x and y are not broadcastable, because in the 3rd trailing dimension 2 != 3

Это действительно удобно, потому что это означает, что вам не нужно переопределять свой нейронный net каждый раз, когда вы хотите использовать другой batch_size

Вот ссылка, если вы хотите узнать больше о вещании в pytorch


Также в чем разница между тензором (256), (256, ) и (256,1)

первые два абсолютно одинаковы; python обычно позволяет использовать запятые в выражениях кортежей. Вы создаете одномерный вектор из 256 элементов.

Последний отличается; вы создаете 2D-тензор, где первое измерение равно 256, а второе - 1. Базовые данные одинаковы, и это не имеет значения, если вы согласны с тем, что вы используете, но если вы смешиваете их, это часто приводит к нежелательному поведению:

Например:

a = torch.randn(256)
b = torch.randn(256)
c = a + b
c.shape
>>> torch.Size([256])

Просто: они просто добавляют поэлементно.

Но обратите внимание, что происходит, когда один из них shape (-1,1):

b = b.view(-1,1) # -1 here means torch will infer the shape of this dimension based on the known size of the tensor and all other specified dimensions
b.shape
>>> torch.Size([256, 1])

c = a + b

Теперь из-за правил вещания

  • a повторяется по первому измерению, поэтому оно имеет то же число измерений, что и b, поэтому он автоматически интерпретирует a как тензор (256,256)
  • b повторяется, поэтому его последнее измерение (1) теперь соответствует измерению (256)

, поэтому:

c.shape
>>> torch.Size([256, 256])

Подсказка: правила вещания трудно запомнить, и они часто являются источником ошибок. Если вы сомневаетесь в тензорных формах, стоит запускать ваш код в интерпретаторе построчно с фиктивными данными и просто проверять, какова форма каждого тензора, например print(torch.mm(input,w1).shape)

...