средняя функция факела неправильно - PullRequest
2 голосов
/ 17 апреля 2019

Не могу понять, как оценить это выражение: x.view(*(x.shape[:-2]),-1).mean(-1), х в форме (N, C, H, W) что такое астрик? а что значит (-1)? Заранее спасибо

1 Ответ

1 голос
/ 17 апреля 2019

что такое *?
Для .view() pytorch ожидает, что новая форма будет предоставлена ​​ отдельными int аргументами (представленными вдок как *shape).Звездочка (*) может использоваться в python для распаковки списка в его отдельные элементы, таким образом передавая view правильную форму входных аргументов, которую он ожидает.
Итак, в вашем случае x.shape - это (N, C, H, W), если вы передадите x.shape[:-2] без звездочки, вы получите x.view((N, C), -1) - это не , что view() ожидает.При распаковке (N, C) с использованием звездочки view получает view(N, C, -1) аргументов, как и ожидалось.В результате получается форма (N, C, H*W) (трехмерный тензор вместо 4).

Что такое mean(-1)?
Просто посмотрите документацию .mean(): первый аргумент является dim аргументом.То есть x.mean(-1) применяется mean вдоль последнего измерения.В вашем случае, поскольку keepdim=False по умолчанию, вы получите тензор размером (B, C), где каждый элемент соответствует среднему значению по обоим пространственным измерениям.
Это эквивалентно

x.mean(-1).mean(-1)
...