Я пытаюсь реализовать простой LSTMCell без «причудливых kwargs», по умолчанию реализованных в классе tf.keras.layers.LSTMCell, следуя схеме c модели, подобной this . На самом деле это не имеет прямой цели, я просто хотел бы попрактиковаться в реализации более сложной RNNCell, чем та, которая описана здесь в разделе «Примеры». Мой код следующий:
from keras import Input
from keras.layers import Layer, RNN
from keras.models import Model
import keras.backend as K
class CustomLSTMCell(Layer):
def __init__(self, units, **kwargs):
self.state_size = units
super(CustomLSTMCell, self).__init__(**kwargs)
def build(self, input_shape):
self.forget_w = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
initializer='uniform',
name='forget_w')
self.forget_b = self.add_weight(shape=(self.state_size,),
initializer='uniform',
name='forget_b')
self.input_w1 = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
initializer='uniform',
name='input_w1')
self.input_b1 = self.add_weight(shape=(self.state_size,),
initializer='uniform',
name='input_b1')
self.input_w2 = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
initializer='uniform',
name='input_w2')
self.input_b2 = self.add_weight(shape=(self.state_size,),
initializer='uniform',
name='input_b2')
self.output_w = self.add_weight(shape=(self.state_size, self.state_size + input_shape[-1]),
initializer='uniform',
name='output_w')
self.output_b = self.add_weight(shape=(self.state_size,),
initializer='uniform',
name='output_b')
self.built = True
def merge_with_state(self, inputs):
self.stateH = K.concatenate([self.stateH, inputs], axis=-1)
def forget_gate(self):
forget = K.dot(self.forget_w, self.stateH) + self.forget_b
forget = K.sigmoid(forget)
self.stateC = self.stateC * forget
def input_gate(self):
candidate = K.dot(self.input_w1, self.stateH) + self.input_b1
candidate = K.tanh(candidate)
amount = K.dot(self.input_w2, self.stateH) + self.input_b2
amount = K.tanh(amount)
self.stateC = self.stateC + amount * candidate
def output_gate(self):
self.stateH = K.dot(self.output_w, self.stateH) + self.output_b
self.stateH = K.sigmoid(self.stateH)
self.stateH = self.stateH * K.tanh(self.stateC)
def call(self, inputs, states):
self.stateH = states[0]
self.stateC = states[1]
self.merge_with_state(inputs)
self.forget_gate()
self.input_gate()
self.output_gate()
return self.stateH, [self.stateH, self.stateC]
# Testing
inp = Input(shape=(None, 3))
lstm = RNN(CustomLSTMCell(10))(inp)
model = Model(inputs=inp, outputs=lstm)
inp_value = [[[[1,2,3], [2,3,4], [3,4,5]]]]
pred = model.predict(inp_value)
print(pred)
Однако, когда я попытался проверить это, возникла исключительная ситуация со следующим сообщением:
IndexError: tuple index out of range
в функции call
в строке где я установил значение для self.stateC
. Здесь я подумал, что изначально аргумент states
функции call
является тензором, а не списком тензоров, поэтому я и получаю ошибку. Поэтому я добавил строку self.already_called = False
к классам __init__
и следующий сегмент к функции call
:
if not self.already_called:
self.stateH = K.ones(self.state_size)
self.stateC = K.ones(self.state_size)
self.already_called = True
else:
self.stateH = states[0]
self.stateC = states[1]
, надеясь, что это устранит проблему. Это привело к другой ошибке в функции merge_with_state
:
ValueError: Shape must be rank 1 but is rank 2 for 'rnn_1/concat' (op: 'ConcatV2') with input shapes: [10], [?,3], [].
, которую я искренне не получаю, поскольку слой RNN должен только «показывать» тензоры CustomLSTMCell с shape (3), а не (None, 3), поскольку ось 0 является осью, она должна повторяться вдоль. На данный момент я был убежден, что я делаю что-то действительно неправильно и должен обратиться за помощью к сообществу. В основном мой вопрос: что не так с моим кодом и если «почти все», то как мне реализовать LSTMCell с нуля?