Как создать комбинированную модель tf.keras с условной оценкой подмоделей - PullRequest
0 голосов
/ 15 октября 2019

Я хочу создать комбинацию из нескольких tf.keras.Sequential моделей, чтобы только одна из подмоделей оценивалась в любой данный момент времени. Чтобы лучше объяснить, я создал следующую модель (код для модели находится в конце этого поста):

Graph of combined model

На этом графике модели sequential, sequential_1, sequential_2 и sequential_3 являются подмоделями на основе LSTM, а label_0 является простой пятой подмоделью. Последний слой arbiter решает на основе значения во входных данных (извлеченных из in_arb), какой из пяти параллельных путей будет фактически давать выходной сигнал сети. Другие четыре значения отбрасываются.

Конечно, вычисления в других четырех параллельных слоях (которые не влияют на результат) тратятся впустую. Поэтому мой вопрос: есть ли способ решить эту проблему внутри TensorFlow, например, с помощью некоторой условной маршрутизации графа вместо параллельного выполнения?

Пример кода для модели:

import tensorflow as tf
tf.compat.v1.enable_eager_execution()

config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(sess)

batch_size = 1

def gen_lstm(base_label, num_features, num_units):
    return tf.keras.Sequential([
        tf.keras.layers.InputLayer(input_shape=(1, num_features), batch_size=batch_size,
                                   name="input_{}".format(base_label)),
        tf.keras.layers.LSTM(num_units,
                             batch_input_shape=(batch_size, 1, num_features),
                             return_sequences=False, stateful=True,
                             name="lstm_{}".format(base_label)),
        tf.keras.layers.Dense(1, name="dense_{}".format(base_label)), # binary
        tf.keras.layers.Activation('sigmoid', name="activ_{}".format(base_label)), # binary
    ])

models = {}
for l in [1, 3, 4, 5]:
    global m
    m = gen_lstm(l, 130, 88)
    models[l] = m

in_all = tf.keras.layers.InputLayer(input_shape=(1, 132), batch_size=batch_size, name="input_all")
in_lstm = tf.keras.layers.Lambda(lambda x: tf.slice(x, [0, 0, 2], [-1, -1, -1]), name="in_lstm")(in_all.output)

def model_0(x):
    return x[:, :, 1]

out_model_0 = tf.keras.layers.Lambda(model_0, name="label_0")(in_all.output)

out_concat = tf.keras.layers.Concatenate(axis=1, name="concat_infer")([out_model_0] + [m(in_lstm) for m in models.values()])

in_arb = tf.keras.layers.Lambda(lambda x: tf.reshape(tf.slice(x, [0, 0, 0], [-1, -1, 1]), (batch_size, 1)), name="in_arb")(in_all.output)
out_merged = tf.keras.layers.Concatenate(axis=1, name="concat_arb")([in_arb, out_concat])

def arbiter(x):
    return tf.where(tf.equal(x[:, 0], tf.constant(0.0, dtype=tf.float32)), x[:, 1], tf.where(
        tf.equal(x[:, 0], tf.constant(2.0, dtype=tf.float32)), x[:, 2] + tf.constant(2.0), tf.where(
            tf.equal(x[:, 0], tf.constant(3.0, dtype=tf.float32)), x[:, 3] + tf.constant(3.0), tf.where(
                tf.equal(x[:, 0], tf.constant(4.0, dtype=tf.float32)), x[:, 4] + tf.constant(4.0),
                x[:, 5] + tf.constant(5.0)))))

out_merged = tf.keras.layers.Lambda(arbiter, name="arbiter")(out_merged)

lstm_model = tf.keras.Model([in_all.input], out_merged)

print(lstm_model.summary())
tf.keras.utils.plot_model(lstm_model, to_file="./temp.png", show_shapes=True)
...