torch.matmul
в pytorch имеет функции вещания, которые могут потреблять слишком много памяти. Я ищу эффективные реализации для предотвращения чрезмерного использования памяти.
Например, , входной тензор имеет размер как
adj.size()==[1,3000,3000]
s.size()==torch.Size([235, 3000, 10])
s.transpose(1, 2).size()==torch.Size([235, 10, 3000])
Задача состоит в том, чтобы рассчитать
link_loss = adj - torch.matmul(s, s.transpose(1, 2)) #
link_loss = torch.norm(link_loss, p=2)
Исходный код находится в пакете расширения горелки torch_geometric
. Он находится в определении функции dens_diff_pool .
torch.matmul(s, s.transpose(1,2))
будет использовать слишком много памяти (на моем компьютере только 2 ГБ), что приведет к ошибке:
Traceback (последний последний вызов):
Файл "", строка 1, в
torch.matmul (s, s.transpose (1, 2))
RuntimeError: $ Torch: недостаточно памяти: вы пытались выделить 7 ГБ.
Купи новую оперативку! at .. \ aten \ src \ TH \ THGeneral.cpp: 201
Оригинальный код автора пакета содержит torch.matmul(s, s.transpose(1, 2)).size()==[235,3000,3000]
размером более 7 ГБ.
Моя попытка состоит в том, что я пытался использовать for
итерацию
batch_size=235
link_loss=torch.sqrt(torch.stack([torch.norm(adj - torch.matmul(s[i], s[i].transpose(0, 1)), p=2)**2 for i in range(batch_size)]).sum(dim=0))
Этот цикл for
, как известно, работает медленнее, чем использование широковещательных или других встроенных функций pytorch.
Вопрос :
Есть ли более быстрая реализация, лучше, чем использовать [... for ...].
Я новичок в изучении pytorch. Спасибо.