Объект DataParallel не имеет атрибута init_hidden - PullRequest
0 голосов
/ 21 мая 2018

То, что я хочу сделать, это использовать DataParallel в моем пользовательском классе RNN.

Похоже, я неправильно инициализировал hidden_0 ...

class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, n_layers=1):
    super(RNN, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.output_size = output_size
    self.n_layers = n_layers

    self.encoder = nn.Embedding(input_size, hidden_size)
    self.gru = nn.GRU(hidden_size, hidden_size, n_layers,batch_first = True)
    self.decoder = nn.Linear(hidden_size, output_size)
    self.init_hidden(batch_size)


def forward(self, input):
    input = self.encoder(input)
    output, self.hidden = self.gru(input,self.hidden)
    output = self.decoder(output.contiguous().view(-1,self.hidden_size))
    output = output.contiguous().view(batch_size,num_steps,N_CHARACTERS)
    #print (output.size())10,50,67

    return output

def init_hidden(self,batch_size):
    self.hidden = Variable(T.zeros(self.n_layers, batch_size, self.hidden_size).cuda())

И я вызываю сетьтаким образом:

decoder = T.nn.DataParallel(RNN(N_CHARACTERS, HIDDEN_SIZE, N_CHARACTERS), dim=1).cuda()

Затем начните тренировку:

for epoch in range(EPOCH_):
    hidden = decoder.init_hidden()

Но я получаю ошибку, и у меня нет идеального способа ее исправить ...

Объект 'DataParallel' не имеет атрибута 'init_hidden'

Спасибо за помощь!

Ответы [ 2 ]

0 голосов
/ 20 мая 2019

Обходной путь, который я сделал:

self.model = model 
# Since if the model is wrapped by the `DataParallel` class, you won't be able to access its attributes
# unless you write `model.module` which breaks the code compatibility. We use `model_attr_accessor` for attributes
# accessing only.
if isinstance(model, DataParallel):
    self.model_attr_accessor = model.module
else:
    self.model_attr_accessor = model

Это дает мне преимущество в том, что модель распределяется по моим графическим процессорам, когда я делаю self.model(input) (то есть, когда она обернута DataParallel);и когда мне нужно получить доступ к его атрибутам, я просто делаю self.model_attr_accessor.<<WHATEVER>>.Кроме того, этот дизайн дает мне более модульный способ доступа к атрибутам из нескольких функций без наличия if-statements во всех из них, чтобы проверить, обернуто ли оно DataParallel или нет.

С другой стороны, если вы написали model.module.<<WHATEVER>>, а модель не была обернута DataParallel, это приведет к ошибке, говорящей о том, что ваша модель не имеет атрибута module.


Однако, более компактной реализацией является создание настраиваемого DataParallel, например:

class _CustomDataParallel(nn.Module):
    def __init__(self, model):
        super(_CustomDataParallel, self).__init__()
        self.model = nn.DataParallel(model).cuda()
        print(type(self.model))

    def forward(self, *input):
        return self.model(*input)

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.model.module, name)
0 голосов
/ 17 июля 2018

При использовании DataParallel ваш исходный модуль будет иметь атрибут module параллельного модуля:

for epoch in range(EPOCH_):
    hidden = decoder.module.init_hidden()
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...