Сохранение градиентов при перестановке данных в тензор с помощью pytorch - PullRequest
0 голосов
/ 29 января 2020

У меня есть схема, где я храню матрицу с нулями на диагонали как вектор. Позже я хочу оптимизировать этот вектор, поэтому мне нужно отслеживать градиент. Моя задача состоит в том, чтобы изменить форму между ними.

Я хочу - по причинам, определяемым доменом c - сохранить порядок данных в матрице, чтобы транспонированные элементы матрицы W располагались рядом друг с другом в векторная форма.

Размер матрицы W может быть изменен, поэтому я начну с нумерации элементов в левой верхней части матрицы и продолжу в сторону.

У меня есть придумать два способа сделать это. См. Фрагмент кода.

import torch
import torch.sparse

w = torch.tensor([10,11,12,13,14,15],requires_grad=True,dtype=torch.float)
i = torch.LongTensor([
    [0, 1,0],
    [1, 0,1], 
    [0, 2,2],
    [2, 0,3],
    [1, 2,4],
    [2, 1,5],
])
v = torch.FloatTensor([1,      1,      1 ,1,1,1   ])
reshaper = torch.sparse.FloatTensor(i.t(), v, torch.Size([3,3,6])).to_dense()
W_mat_with_reshaper = reshaper @ w
W_mat_directly = torch.tensor([
  [0,    w[0],  w[2],],
  [w[1],    0,  w[4],],
  [w[3], w[5],     0,],
])
print(W_mat_with_reshaper)
print(W_mat_directly)

, и это дает вывод


tensor([[ 0., 10., 12.],
        [11.,  0., 14.],
        [13., 15.,  0.]], grad_fn=<UnsafeViewBackward>)
tensor([[ 0., 10., 12.],
        [11.,  0., 14.],
        [13., 15.,  0.]])

Как видите, прямой способ преобразования вектора в матрицу не имеет функции grad, но умножить-с-восстановитель-тензор. Создание тензора изменения формы кажется трудным, но, с другой стороны, ручная запись матрицы также невозможна.

Есть ли способ сделать произвольные изменения в pytorch, которые сохраняют градиент?

1 Ответ

1 голос
/ 29 января 2020

Вместо построения W_mat_directly из элементов из w, попробуйте присвоить w в W:

W_mat_directly = torch.zeros((3, 3), dtype=w.dtype)
W_mat_directly[(0, 0, 1, 1, 2, 2), (1, 2, 0, 2, 0, 1)] = w

Вы получите

tensor([[ 0., 10., 11.],
        [12.,  0., 13.],
        [14., 15.,  0.]], grad_fn=<IndexPutBackward>)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...