Почему OOM происходит в моей модели init ()? - PullRequest
1 голос
/ 14 апреля 2019

Одна строка в моей модели, tr.nn.Linear(hw_flat * num_filters*8, num_fc), вызывает ошибку OOM при инициализации модели.Комментирование этого снимает проблему с памятью.

import torch as tr
from layers import Conv2dSame, Flatten

class Discriminator(tr.nn.Module):
    def __init__(self, cfg):
        super(Discriminator, self).__init__()
        num_filters = 64
        hw_flat = int(cfg.hr_resolution[0] / 2**4)**2
        num_fc = 1024

        self.model = tr.nn.Sequential(
            # Channels in, channels out, filter size, stride, padding
            Conv2dSame(cfg.num_channels, num_filters, 3),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters, num_filters, 3, 2),
            tr.nn.BatchNorm2d(num_filters),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters, num_filters*2, 3),
            tr.nn.BatchNorm2d(num_filters*2),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters*2, num_filters*2, 3, 2),
            tr.nn.BatchNorm2d(num_filters*2),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters*2, num_filters*4, 3),
            tr.nn.BatchNorm2d(num_filters*4),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters*4, num_filters*4, 3, 2),
            tr.nn.BatchNorm2d(num_filters*4),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters*4, num_filters*8, 3),
            tr.nn.BatchNorm2d(num_filters*8),
            tr.nn.LeakyReLU(),
            Conv2dSame(num_filters*8, num_filters*8, 3, 2),
            tr.nn.BatchNorm2d(num_filters*8),
            tr.nn.LeakyReLU(),
            Flatten(),
            tr.nn.Linear(hw_flat * num_filters*8, num_fc),
            tr.nn.LeakyReLU(),
            tr.nn.Linear(num_fc, 1),
            tr.nn.Sigmoid()
        )
        self.model.apply(self.init_weights)

    def forward(self, x_in):
        x_out = self.model(x_in)
        return x_out

    def init_weights(self, layer):
        if type(layer) in [tr.nn.Conv2d, tr.nn.Linear]:
            tr.nn.init.xavier_uniform_(layer.weight)

Это странно, так как hw_flat = 96 * 96 = 9216 и num_filters * 8 = 512, поэтому hw_flat * num_filters = 4718592, то есть число параметровв этом слое.Я подтвердил этот расчет, так как изменение слоя на tr.nn.Linear(4718592, num_fc) приводит к тому же результату.

Для меня это не имеет смысла, так как dtype = float32, поэтому ожидаемый размер этого будет 32 * 4718592 = 150 994 944 байт.,Это эквивалентно примерно 150 МБ.

Сообщение об ошибке:

Traceback (most recent call last):
  File "main.py", line 116, in <module>
    main()
  File "main.py", line 112, in main
    srgan = SRGAN(cfg)
  File "main.py", line 25, in __init__
    self.discriminator = Discriminator(cfg).to(device)
  File "/home/jpatts/Documents/ECE/ECE471-SRGAN/models.py", line 87, in __init__
    tr.nn.Linear(hw_flat * num_filters*8, num_fc),
  File "/home/jpatts/.local/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 51, in __init__
    self.weight = Parameter(torch.Tensor(out_features, in_features))
RuntimeError: $ Torch: not enough memory: you tried to allocate 18GB. Buy new RAM! at /pytorch/aten/src/TH/THGeneral.cpp:201

Я использую только пакеты размером 1 (не то, что влияет на эту ошибку), с общей формой ввода всеть (1, 3, 1536, 1536) и форма после сглаживания слоя (1, 4718592).

Почему это происходит?

1 Ответ

1 голос
/ 14 апреля 2019

Ваш линейный слой довольно большой - на самом деле ему требуется как минимум 18 ГБ памяти.(Ваша оценка неверна по двум причинам: (1) float32 занимает 4 байта памяти, а не 32, и (2) вы не умножили на выходной размер.)

Из Часто задаваемые вопросы по документации PyTorch :

Не используйте слишком большие линейные слои.Линейный слой nn.Linear(m, n) использует O(n*m) память: то есть требования к весам для весов масштабируются в квадрате с количеством признаков.Таким способом очень легко унести вашу память (и помните, что вам понадобится, по крайней мере, вдвое больший вес, поскольку вам также нужно хранить градиенты).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...