Применение Kullback-Leibler (он же дивергенция kl) поэтапно в Pytorch - PullRequest
1 голос
/ 02 апреля 2019

У меня есть два тензора с именами x_t, x_k со следующими формами NxHxW и KxNxHxW соответственно, где K - это количество автоэнкодеров, использованных для восстановления x_t (если вы не знаете, чтоэто, предположим, что они K разные сети, стремящиеся предсказать x_t, это вероятно не имеет никакого отношения к вопросу в любом случае) N - размер партии, H высота матрицы, W ширина матрицы.

Я пытаюсь применить алгоритм Расхождение Кульбака-Лейблера к обоим тензорам (после трансляции x_t как x_k вдоль K th размерность), используя метод Пиорча nn.functional.kl_div.

Однако, похоже, он не работает , как я ожидал .Я рассчитываю вычислить kl_div между каждым наблюдением в x_t и x_k, получая тензор размером KxN (т. Е. kl_div каждого наблюдения для каждого K автоэнкодера).

Фактическим выходным значением является одиночное значение , если я использую аргумент reduction, и тот же тензорный размер (т. Е. KxNxHxW), если я его не использую.

Кто-нибудь пробовал что-то подобное?


Воспроизводимый пример:

import torch
import torch.nn.functional as F
#                  K   N   H  W
x_t = torch.randn(    10, 5, 5)
x_k = torch.randn( 3, 10, 5, 5)

x_broadcasted = x_t.expand_as(x_k)

loss = F.kl_div(x_t, x_k, reduction="none") # or "batchmean", or there are many options

1 Ответ

1 голос
/ 02 апреля 2019

Мне неясно, что именно представляет собой распределение вероятностей в вашей модели.При reduction='none', kl_div, при log(x_n) и y_n вычисляется kl_div = y_n * (log(y_n) - log(x_n)), который является «суммированной» частью фактической дивергенции Кульбака-Лейблера.Суммирование (или, другими словами, принимая ожидание) зависит от вас.Если ваша точка зрения такова, что H, W - это два измерения, по которым вы хотите рассчитывать, это просто:

loss = F.kl_div(x_t, x_k, reduction="none").sum(dim=(-1, -2))

, который имеет форму [K, N].Если выходные данные вашей сети следует интерпретировать по-разному, вам нужно лучше указать, какие измерения событий и какие примеры измерений вашего дистрибутива.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...