Расчет дисперсии в Pytorch на пространственной оси - PullRequest
0 голосов
/ 17 февраля 2019

Я пытаюсь вычислить дисперсию в Pytorch, но не могу сделать это по нескольким осям.

Я сделал аналогичную вещь в Tensorflow, но не могу сделать это на Pytorch, так как функция torch.var принимает int в качестве измерения вместооси. Ниже кода - последний код канала, я ожидаю, что оси = [2,3]

Lambda(lambda x: tf.nn.moments(x, axes=[1, 2]))

Например, если input_dims = (5, 10, 25, 25), тогда output_dims должен быть(5,10,1,1).

1 Ответ

0 голосов
/ 17 февраля 2019

Одна вещь, которую вы можете сделать, это использовать tensor.view(), чтобы сгладить все измерения, для которых вы хотите рассчитать дисперсию, в одно измерение, прежде чем применять метод var():

torch.var(x.view(x.shape[0], x.shape[1], 1, -1,), dim=3, keepdim=True)

Я использовал keepdim=True, чтобы сохранить размер, для которого мы рассчитываем дисперсию, чтобы получить желаемую выходную форму.

...