Поведение Pytorch'а gumbel_softmax () - PullRequest
0 голосов
/ 09 марта 2020

Я пытаюсь реализовать дифференцируемую альтернативу для argmax для выходного уровня сети классификации. Я был в восторге от nn.functional.gumbel_softmax, тем более что он предлагает параметр 'hard', но, похоже, он работает не так, как я ожидаю.

Код:

import torch.nn as nn
import torch

size = 5
Y = torch.rand(size)
target = torch.zeros_like(Y)
values, index = Y.max(0)
target[index] = 1

soft = nn.Softmax(dim=0)

out1 = soft(Y)
out2 = nn.functional.gumbel_softmax(logits=Y, tau=1, hard=False, eps=1e-10, dim=0)
out3 = nn.functional.gumbel_softmax(logits=Y, tau=1, hard=True, eps=1e-10, dim=0)
out4 = nn.functional.gumbel_softmax(logits=out1, tau=1, hard=False, eps=1e-10, dim=0)
out5 = nn.functional.gumbel_softmax(logits=out1, tau=1, hard=True, eps=1e-10, dim=0)

n_digits = 4


print("      Input:",Y)
print("     Target:",target)
print("       Soft:",out1)
print("Gumbel Soft:",out2)
print("Gumbel Hard:",out3)
print("Gumbel Soft:",out4)
print("Gumbel Hard:",out5)

Это привело к следующему выводу:

  Input: tensor([0.9196, 0.7742, 0.2492, 0.0309, 0.8299])
  Target: tensor([1., 0., 0., 0., 0.])
  Soft: tensor([0.2702, 0.2336, 0.1382, 0.1111, 0.2470])
  Gumbel Soft: tensor([0.1094, 0.0914, 0.7770, 0.0032, 0.0190])
  Gumbel Hard: tensor([0., 0., 0., 0., 1.])
  Gumbel Soft: tensor([0.0354, 0.0061, 0.0248, 0.8245, 0.1093])
  Gumbel Hard: tensor([0., 0., 1., 0., 0.])

В то время как нормальная функция softmax сохраняет связь между значениями, функция Гумбеля, кажется, просто повсюду, заставляя меня поверь, что я не понимаю какой-то важный аспект этого.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...