Каков тип xs в NStepBiLSTM в Chainer? - PullRequest
0 голосов
/ 07 мая 2019

В руководстве по NStepBiLSTM говорится, что его функция forward ожидает, что xs придет в формате списка последовательностей с переменными переменными.Однако я получаю ошибку, которая подразумевает, что xs должен быть массивом np.Чего мне не хватает?

Я использую эту функцию, чтобы превратить входной массив в список переменных (массив) с формой массива (n, 1).

def cut_data(data, batchsize):
    q = data.shape[0] // batchsize
    data = data[:q*batchsize]
    data = data.reshape((batchsize, q))
    xs = []
    for i in range(q):
        a = data[:,i].reshape(batchsize,1)
        xs.append(Variable(a))
    return xs

Но когда я вызываюмой предсказатель с таким хз я получаю эту ошибку:

<ipython-input-27-cb554613ad71> in __call__(self, xs, ts)
     10         batchlen = len(xs)
     11         loss = F.sum(F.mean_squared_error(
---> 12         self.predictor(xs), ts, reduce='no')) / batch
     13 
     14         chainer.report({'loss': loss}, self)

<ipython-input-28-ce9434e91153> in __call__(self, x)
     12     def __call__(self, x):
     13         self.h, self.c, y = self.lstm(self.h,self.c,x)
---> 14         output = self.out(y)
     15         return output
     16 

~\Anaconda2\lib\site-packages\chainer\links\connection\linear.py in __call__(self, x)
    127             in_size = functools.reduce(operator.mul, x.shape[1:], 1)
    128             self._initialize_params(in_size)
--> 129         return linear.linear(x, self.W, self.b)

~\Anaconda2\lib\site-packages\chainer\functions\connection\linear.py in linear(x, W, b)
    165 
    166     """
--> 167     if x.ndim > 2:
    168         x = x.reshape(len(x), -1)
    169 

AttributeError: 'list' object has no attribute 'ndim'

Это моя простая сеть:

class LSTM_RNN(Chain):

    def __init__(self, n_hidden, n_input=1, n_out=1):
        super(LSTM_RNN, self).__init__()
        with self.init_scope():
            self.lstm = L.NStepBiLSTM(n_layers=n_hidden, in_size=n_input, out_size=n_out, dropout=0.5)
            self.out = L.Linear(n_hidden, n_out)
            self.h = None
            self.c = None

    def __call__(self, x):
        self.h, self.c, y = self.lstm(self.h,self.c,x)
        output = self.out(y)
        return output

    def reset_state(self):
        self.h = None
        self.c = None

1 Ответ

0 голосов
/ 13 мая 2019

Сообщение об ошибке

---> 14         output = self.out(y)

указывает, что ошибка вызвана методом self.out(), а не self.lstm()

Согласно официальной ссылке API ,L.NStepBiLSTM.run() возвращает кортеж hy, cy и ys, который ys является списком.

Ваш код

    def __call__(self, x):
        self.h, self.c, y = self.lstm(self.h,self.c,x)
        output = self.out(y)
        return output

указывает, что y(который упоминается как ys в официальном документе) непосредственно передается в self.out, что составляет L.Linear.__call__.Это вызывает несоответствие типов.

В общем, форма y в ys отличается друг от друга, поскольку x в xs может быть последовательностью различной длины.

Если вам нужна дополнительная помощь, пожалуйста, не стесняйтесь задавать вопросы!

...