У меня есть следующий фрагмент кода для сложного умножения в Pytorch
Вариация 3
def complex_mult(x, y):
a, b = x[:, :, :, :, :, 0], x[:, :, :, :, :, 1]
c, d = y[:, :, :, :, :, 0], y[:, :, :, :, :, 1]
out = torch.stack([a*c - b*d, a*d + b*c], dim=-1)
return out
Поскольку я изучал, почему я RuntimeError: CUDA out of memory
, я понял, что распределение памяти в зависимости от количества инструкций, которые я явно даю.
Я написал этот код тремя различными способами, два других, как показано ниже.
Вариация 2
def complex_mult(x, y):
a, b = x[:, :, :, :, :, 0], x[:, :, :, :, :, 1]
c, d = y[:, :, :, :, :, 0], y[:, :, :, :, :, 1]
real = a*c - b*d
imag = a*d + b*c
pair = [real, imag]
out = torch.stack(pair, dim=-1)
return out
Вариация 1
def complex_mult(x, y):
a, b = x[:, :, :, :, :, 0], x[:, :, :, :, :, 1]
c, d = y[:, :, :, :, :, 0], y[:, :, :, :, :, 1]
r1 = a*c
r2 = b*d
real = r1 - r2
i1 = a*d
i2 = b*c
imag = i1 + i2
pair = [real, imag]
out = torch.stack(pair, dim=-1)
return out
С Вариация 1 Gist , у нас есть:
- Все тензоры используют
float32
, что составляет 4 bytes
- Размер
_x
равен (128, 64, 32, 32, 2)
, таким образом 64 MB
, 32 MB
каждый компонент (real, imag)
Умножение каждого компонента (a*c
, b*d
, a*d
и b*c
) будет иметь 2048 MB
- Что в
64
раза больше 32 MB
каждого компонента x
- Итого
8 GB
Plus 2 GB
, если рассматривать каждый sum
результат как хранилище по своей собственной переменной ble
Штабелирование real
и imag
требует больше 4 GB
RuntimeError: CUDA out of memory. Tried to allocate 4.00 GiB (GPU 0; 14.73 GiB total capacity; 12.12 GiB already allocated; 1.82 GiB free; 12.12 GiB reserved in total by PyTorch)
Интуитивно я решил придерживаться Вариация 3 , думая
- "Без этих временных переменных (мест) держателей Torch будет знать, что они больше не будут использоваться, освобождая их память. Таким образом, используя только
4 GB
к концу моего исполнения. "
Но, что я не понимаю, и вот мой вопрос , is
Почему при Variation 3 я все еще получаю CUDA out of memory
после нескольких вызовов этого метода? Суть показывает, что при третьем вызове происходит сбой из-за недостатка память.
Что еще более интересно, это то, что, хотя можно думать
- "Это потому, что уже есть
4 GB
, который выделяется с помощью предыдущий результат. "
Однако, учитывая, что он был в состоянии выполнить дважды, я не думаю, что это проблема.