Ищем эффективную реализацию умножения матрицы Pytorch для предотвращения большого использования памяти - PullRequest
2 голосов
/ 07 июня 2019

torch.matmul в pytorch имеет функции вещания, которые могут потреблять слишком много памяти. Я ищу эффективные реализации для предотвращения чрезмерного использования памяти.

Например, , входной тензор имеет размер как adj.size()==[1,3000,3000] s.size()==torch.Size([235, 3000, 10]) s.transpose(1, 2).size()==torch.Size([235, 10, 3000]) Задача состоит в том, чтобы рассчитать

link_loss = adj - torch.matmul(s, s.transpose(1, 2)) #
link_loss = torch.norm(link_loss, p=2)

Исходный код находится в пакете расширения горелки torch_geometric. Он находится в определении функции dens_diff_pool . torch.matmul(s, s.transpose(1,2)) будет использовать слишком много памяти (на моем компьютере только 2 ГБ), что приведет к ошибке:

Traceback (последний последний вызов):

Файл "", строка 1, в torch.matmul (s, s.transpose (1, 2))

RuntimeError: $ Torch: недостаточно памяти: вы пытались выделить 7 ГБ. Купи новую оперативку! at .. \ aten \ src \ TH \ THGeneral.cpp: 201

Оригинальный код автора пакета содержит torch.matmul(s, s.transpose(1, 2)).size()==[235,3000,3000] размером более 7 ГБ.

Моя попытка состоит в том, что я пытался использовать for итерацию

batch_size=235
link_loss=torch.sqrt(torch.stack([torch.norm(adj - torch.matmul(s[i], s[i].transpose(0, 1)), p=2)**2 for i in range(batch_size)]).sum(dim=0))

Этот цикл for, как известно, работает медленнее, чем использование широковещательных или других встроенных функций pytorch. Вопрос : Есть ли более быстрая реализация, лучше, чем использовать [... for ...]. Я новичок в изучении pytorch. Спасибо.

...