Реализация минимального LSTMCell в Keras с использованием классов RNN и Layer - PullRequest
3 голосов
/ 12 февраля 2020

Я пытаюсь реализовать простой LSTMCell без «причудливых kwargs», по умолчанию реализованных в классе tf.keras.layers.LSTMCell, следуя схеме c модели, подобной this . На самом деле это не имеет прямой цели, я просто хотел бы попрактиковаться в реализации более сложной RNNCell, чем та, которая описана здесь в разделе «Примеры». Мой код следующий:

from keras import Input
from keras.layers import Layer, RNN
from keras.models import Model
import keras.backend as K

class CustomLSTMCell(Layer):

    def __init__(self, units, **kwargs):
        self.state_size = units
        super(CustomLSTMCell, self).__init__(**kwargs)

    def build(self, input_shape):

        self.forget_w = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
                                    initializer='uniform',
                                    name='forget_w')
        self.forget_b = self.add_weight(shape=(self.state_size,),
                                    initializer='uniform',
                                    name='forget_b')

        self.input_w1 = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
                                    initializer='uniform',
                                    name='input_w1')
        self.input_b1 = self.add_weight(shape=(self.state_size,),
                                    initializer='uniform',
                                    name='input_b1')
        self.input_w2 = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
                                    initializer='uniform',
                                    name='input_w2')
        self.input_b2 = self.add_weight(shape=(self.state_size,),
                                    initializer='uniform',
                                    name='input_b2')

        self.output_w = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
                                    initializer='uniform',
                                    name='output_w')
        self.output_b = self.add_weight(shape=(self.state_size,),
                                    initializer='uniform',
                                    name='output_b')

        self.built = True

    def merge_with_state(self, inputs):
        self.stateH = K.concatenate([self.stateH, inputs], axis=-1)

    def forget_gate(self):
        forget = K.dot(self.forget_w, self.stateH) + self.forget_b
        forget = K.sigmoid(forget)
        self.stateC = self.stateC * forget

    def input_gate(self):
        candidate = K.dot(self.input_w1, self.stateH) + self.input_b1
        candidate = K.tanh(candidate)

        amount = K.dot(self.input_w2, self.stateH) + self.input_b2
        amount = K.tanh(amount)

        self.stateC = self.stateC + amount * candidate

    def output_gate(self):
        self.stateH = K.dot(self.output_w, self.stateH) + self.output_b
        self.stateH = K.sigmoid(self.stateH)

        self.stateH = self.stateH * K.tanh(self.stateC)

    def call(self, inputs, states):

        self.stateH = states[0]
        self.stateC = states[1]

        self.merge_with_state(inputs)
        self.forget_gate()
        self.input_gate()
        self.output_gate()

        return self.stateH, [self.stateH, self.stateC]

# Testing
inp = Input(shape=(None, 3))
lstm = RNN(CustomLSTMCell(10))(inp)

model = Model(inputs=inp, outputs=lstm)
inp_value = [[[[1,2,3], [2,3,4], [3,4,5]]]]
pred = model.predict(inp_value)
print(pred)

Однако, когда я попытался проверить это, возникла исключительная ситуация со следующим сообщением:

IndexError: tuple index out of range

в функции call в строке где я установил значение для self.stateC. Здесь я подумал, что изначально аргумент states функции call является тензором, а не списком тензоров, поэтому я и получаю ошибку. Поэтому я добавил строку self.already_called = False к классам __init__ и следующий сегмент к функции call:

 if not self.already_called:
        self.stateH = K.ones(self.state_size)
        self.stateC = K.ones(self.state_size)
        self.already_called = True
    else:
        self.stateH = states[0]
        self.stateC = states[1]

, надеясь, что это устранит проблему. Это привело к другой ошибке в функции merge_with_state:

 ValueError: Shape must be rank 1 but is rank 2 for 'rnn_1/concat' (op: 'ConcatV2') with input shapes: [10], [?,3], [].

, которую я искренне не получаю, поскольку слой RNN должен только «показывать» тензоры CustomLSTMCell с shape (3), а не (None, 3), поскольку ось 0 является осью, она должна повторяться вдоль. На данный момент я был убежден, что я делаю что-то действительно неправильно и должен обратиться за помощью к сообществу. В основном мой вопрос: что не так с моим кодом и если «почти все», то как мне реализовать LSTMCell с нуля?

1 Ответ

0 голосов
/ 13 февраля 2020

Хорошо, похоже, мне удалось решить проблему. Оказывается, что всегда полезно прочитать документацию, в этом случае документы для класса RNN . Во-первых, атрибут already_called не нужен, поскольку проблема заключается в первой строке функции __init__: атрибут state_size должен представлять собой список целых чисел, а не только одно целое число, например: self.state_size = [units, units] (так как нам нужно два состояния для LSTM размером units, а не одно). Когда я исправил это, я получил другую ошибку: тензоры не совместимы по размеру в forget_gate для сложения. Это произошло потому, что RNN видит всю партию сразу, а не каждый элемент в пакете отдельно (таким образом, форма None на оси 0). Исправление для этого состоит в том, чтобы добавить дополнительное измерение к каждому тензору размера 1 на оси 0 следующим образом:

 self.forget_w = self.add_weight(shape=(1, self.state_size, self.state_size + input_shape[-1]),
                                initializer='uniform',
                                name='forget_w')

и вместо точечных произведений мне пришлось использовать K.batch_dot функция. Итак, весь рабочий код выглядит следующим образом:

 from keras import Input
 from keras.layers import Layer, RNN
 from keras.models import Model
 import keras.backend as K

 class CustomLSTMCell(Layer):

     def __init__(self, units, **kwargs):
         self.state_size = [units, units]
         super(CustomLSTMCell, self).__init__(**kwargs)

     def build(self, input_shape):

         self.forget_w = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]),
                                         initializer='uniform',
                                         name='forget_w')
         self.forget_b = self.add_weight(shape=(1, self.state_size[0]),
                                         initializer='uniform',
                                         name='forget_b')

         self.input_w1 = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]),
                                         initializer='uniform',
                                         name='input_w1')
         self.input_b1 = self.add_weight(shape=(1, self.state_size[0]),
                                         initializer='uniform',
                                         name='input_b1')
         self.input_w2 = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]),
                                         initializer='uniform',
                                         name='input_w2')
         self.input_b2 = self.add_weight(shape=(1, self.state_size[0],),
                                         initializer='uniform',
                                         name='input_b2')

         self.output_w = self.add_weight(shape=(1, self.state_size[0], self.state_size[0] + input_shape[-1]),
                                         initializer='uniform',
                                         name='output_w')
         self.output_b = self.add_weight(shape=(1, self.state_size[0],),
                                         initializer='uniform',
                                         name='output_b')

         self.built = True

     def merge_with_state(self, inputs):
         self.stateH = K.concatenate([self.stateH, inputs], axis=-1)

     def forget_gate(self):        
         forget = K.batch_dot(self.forget_w, self.stateH) + self.forget_b
         forget = K.sigmoid(forget)
         self.stateC = self.stateC * forget

     def input_gate(self):
         candidate = K.batch_dot(self.input_w1, self.stateH) + self.input_b1
         candidate = K.tanh(candidate)

         amount = K.batch_dot(self.input_w2, self.stateH) + self.input_b2
         amount = K.sigmoid(amount)

         self.stateC = self.stateC + amount * candidate

     def output_gate(self):
         self.stateH = K.batch_dot(self.output_w, self.stateH) + self.output_b
         self.stateH = K.sigmoid(self.stateH)

         self.stateH = self.stateH * K.tanh(self.stateC)

     def call(self, inputs, states):

         self.stateH = states[0]
         self.stateC = states[1]

         self.merge_with_state(inputs)
         self.forget_gate()
         self.input_gate()
         self.output_gate()

         return self.stateH, [self.stateH, self.stateC]

 inp = Input(shape=(None, 3))
 lstm = RNN(CustomLSTMCell(10))(inp)

 model = Model(inputs=inp, outputs=lstm)
 inp_value = [[[[1,2,3], [2,3,4], [3,4,5]]]]
 pred = model.predict(inp_value)
 print(pred)

Редактировать: В вопросе я допустил ошибку в отношении связанной модели и использовал функцию tanh в input_gate за amount вместо сигмовидной. Здесь я отредактировал это в коде, так что теперь это правильно.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...