Пользовательский изменяющий слой возвращает тензор с формой (?,?,?) - PullRequest
0 голосов
/ 13 мая 2019

Мне нужен слой, который преобразует 4-мерный тензор из сверточного слоя с формой (нет, 3, 3, 2048) в трехмерный тензор с формой (нет, 9, 2048) для подачи в LSTM, где 9 - эторазмер временного шага.

Когда я использую сам слой, он работает, но когда я использую его в последовательной модели, следующий слой получает (?,?,?) как input_shape из выводамоего пользовательского слоя.

Ниже вы можете найти мой код:

class Conv2LSTM(Layer):

    '''The :class:`Conv2LSTM` is a custom layer that reshapes the input tensor collapsing the width and height dimensions to a single dimension that represents the sequence accepted by the LSTM.
    '''

    def __init__(self, **kwargs):
        super(Conv2LSTM, self).__init__(**kwargs)

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        super(Conv2LSTM, self).build(input_shape)

    def call(self, x, mask=None):

        '''Overrides the :class:`keras.engine.topology.Layers` method. It collapses the second and third dimension of the tensor into a single dimension.

        :param x: input tensor
        :param mask: tensor mask
        :return: re-ordered tensor
        '''

        return K.reshape(x, (K.shape(x)[0],) + (K.shape(x)[1]*K.shape(x)[2], K.shape(x)[3]))

    def get_config(self):
        base_config = super(Conv2LSTM, self).get_config()
        return dict(list(base_config.items()))

    def compute_output_shape(self, input_shape):
        return (input_shape[0],) + (input_shape[1]*input_shape[2], input_shape[3])

Как такое возможно, что если я напечатаю фигуру внутри слоя, это будет правильно, если я создам модель сэтот единственный слой работает, но в сочетании с последовательным слоем возвращает форму NoneType?

1 Ответ

0 голосов
/ 13 мая 2019

По некоторым причинам, хотя в каждом ответе из других аналогичных вопросов предлагается использовать K.shape(x) для извлечения формы входного тензора, в этом случае ошибка была вызвана этим.

Достаточно было заменить K.shape(x)[i] на x.shape[i].value.

Новая реализация call:

return K.reshape(x, (-1, x.shape[1].value * x.shape[2].value, x.shape[3].value))

...