Я пытаюсь реализовать дифференцируемую альтернативу для argmax для выходного уровня сети классификации. Я был в восторге от nn.functional.gumbel_softmax, тем более что он предлагает параметр 'hard', но, похоже, он работает не так, как я ожидаю.
Код:
import torch.nn as nn
import torch
size = 5
Y = torch.rand(size)
target = torch.zeros_like(Y)
values, index = Y.max(0)
target[index] = 1
soft = nn.Softmax(dim=0)
out1 = soft(Y)
out2 = nn.functional.gumbel_softmax(logits=Y, tau=1, hard=False, eps=1e-10, dim=0)
out3 = nn.functional.gumbel_softmax(logits=Y, tau=1, hard=True, eps=1e-10, dim=0)
out4 = nn.functional.gumbel_softmax(logits=out1, tau=1, hard=False, eps=1e-10, dim=0)
out5 = nn.functional.gumbel_softmax(logits=out1, tau=1, hard=True, eps=1e-10, dim=0)
n_digits = 4
print(" Input:",Y)
print(" Target:",target)
print(" Soft:",out1)
print("Gumbel Soft:",out2)
print("Gumbel Hard:",out3)
print("Gumbel Soft:",out4)
print("Gumbel Hard:",out5)
Это привело к следующему выводу:
Input: tensor([0.9196, 0.7742, 0.2492, 0.0309, 0.8299])
Target: tensor([1., 0., 0., 0., 0.])
Soft: tensor([0.2702, 0.2336, 0.1382, 0.1111, 0.2470])
Gumbel Soft: tensor([0.1094, 0.0914, 0.7770, 0.0032, 0.0190])
Gumbel Hard: tensor([0., 0., 0., 0., 1.])
Gumbel Soft: tensor([0.0354, 0.0061, 0.0248, 0.8245, 0.1093])
Gumbel Hard: tensor([0., 0., 1., 0., 0.])
В то время как нормальная функция softmax сохраняет связь между значениями, функция Гумбеля, кажется, просто повсюду, заставляя меня поверь, что я не понимаю какой-то важный аспект этого.