convLSTM2d с функциональным API - PullRequest
0 голосов
/ 31 марта 2019

У меня есть авто-кодер для сжатия изображений, где закодированный тензор имеет форму: (batch_size, 12, 64, 48).

batch_size - количество изображений, подаваемых в пакете, 12 - количество каналов этого последнего слоя кодера, который имеет ширину / высоту 64x48.

Я хочу ввести это вслой ConvLSTM2D, и я хотел бы, чтобы выход ConvLSTM2D имел то же измерение, что и вход ConvLSTM2D.

Цель состоит в том, чтобы увидеть восстановление изображения в видеопоследовательности, а не неупорядоченные изображения из набора данных.

Поместить ConvLSTM2d между кодером / декодером в архитектуре авто-кодера было сложно, особенно потому, что в большинстве примеров используется последовательный API, и я хочу использовать функциональный API в Keras.

Iпопытался изменить входные данные, но ошибка сохраняется

import tensorflow as tf
import tensorflow.keras.backend as K


def LSTM_layer(input):

    input = tf.keras.backend.expand_dims(input, axis=-1)
    lstm1 = tf.keras.layers.ConvLSTM2D(filters=12, kernel_size=(3, 3), strides=(1, 1), data_format="channels_first",
                                        input_shape=(None, 12, 64, 48), 
                                        padding='same', return_sequences=True)(input)

    return lstm1

def build_model(input_shape):

    #create an input with input_shape as the size
    input_ = tf.keras.Input(shape=input_shape, name="input_node")
    lstm_features = LSTM_layer(input_)

    model = tf.keras.Model(inputs=input_, outputs=[lstm_features])
    return model

def main():

    input_shape = (12, 64, 48) #this is the size of the tensor which is outputted by my encoder, with channels_first assumed
    model = build_model(input_shape)

if __name__ == '__main__':
    main()

К сожалению, это выдает эту ошибку:

Traceback (most recent call last):
  File "lstm.py", line 29, in <module>
    main()
  File "lstm.py", line 26, in main
    model = build_model(input_shape)
  File "lstm.py", line 20, in build_model
    model = tf.keras.Model(inputs=input_, outputs=[lstm_features])
  File "/home/hallab/.local/lib/python3.5/site-packages/tensorflow/python/keras/engine/training.py", line 121, in __init__
    super(Model, self).__init__(*args, **kwargs)
  File "/home/hallab/.local/lib/python3.5/site-packages/tensorflow/python/keras/engine/network.py", line 80, in __init__
    self._init_graph_network(*args, **kwargs)
  File "/home/hallab/.local/lib/python3.5/site-packages/tensorflow/python/training/checkpointable/base.py", line 474, in _method_wrapper
    method(self, *args, **kwargs)
  File "/home/hallab/.local/lib/python3.5/site-packages/tensorflow/python/keras/engine/network.py", line 224, in _init_graph_network
    '(thus holding past layer metadata). Found: ' + str(x))
ValueError: Output tensors to a Model must be the output of a TensorFlow `Layer` (thus holding past layer metadata). Found: Tensor("conv_lst_m2d/transpose_1:0", shape=(?, 12, 12, 48, 1), dtype=float32)

В большинстве сообщений об этой ошибке содержится указание обернуть операцию в лямбду ... ноя не реализую пользовательскую операцию здесь, это должен быть слой keras tf ... правильно?

Кроме того, в моей реализацииЯ хочу, чтобы выходной тензор из модуля LSTM был таким же, как и вход, могу ли я получить некоторую обратную связь об этом?

Спасибо.

1 Ответ

0 голосов
/ 01 апреля 2019

Вы можете использовать лямбду, чтобы обернуть форму вывода K.expand_dims, прежде чем вводить ее в следующий слой следующим образом:

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Lambda

def expand_dims(x):
    return K.expand_dims(x, 1)

def expand_dims_output_shape(input_shape):
    return (input_shape[0], 1, input_shape[1])

def LSTM_layer(input_):
    lstm1 = Lambda(expand_dims, expand_dims_output_shape)(input_)
    lstm1 = tf.keras.layers.ConvLSTM2D(filters=12, kernel_size=(3, 3), strides=(1, 1), data_format="channels_first",                             padding='same', return_sequences=False)(lstm1)
    return lstm1
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...