Я хочу создать комбинацию из нескольких tf.keras.Sequential
моделей, чтобы только одна из подмоделей оценивалась в любой данный момент времени. Чтобы лучше объяснить, я создал следующую модель (код для модели находится в конце этого поста):

На этом графике модели sequential
, sequential_1
, sequential_2
и sequential_3
являются подмоделями на основе LSTM, а label_0
является простой пятой подмоделью. Последний слой arbiter
решает на основе значения во входных данных (извлеченных из in_arb
), какой из пяти параллельных путей будет фактически давать выходной сигнал сети. Другие четыре значения отбрасываются.
Конечно, вычисления в других четырех параллельных слоях (которые не влияют на результат) тратятся впустую. Поэтому мой вопрос: есть ли способ решить эту проблему внутри TensorFlow, например, с помощью некоторой условной маршрутизации графа вместо параллельного выполнения?
Пример кода для модели:
import tensorflow as tf
tf.compat.v1.enable_eager_execution()
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.compat.v1.Session(config=config)
tf.compat.v1.keras.backend.set_session(sess)
batch_size = 1
def gen_lstm(base_label, num_features, num_units):
return tf.keras.Sequential([
tf.keras.layers.InputLayer(input_shape=(1, num_features), batch_size=batch_size,
name="input_{}".format(base_label)),
tf.keras.layers.LSTM(num_units,
batch_input_shape=(batch_size, 1, num_features),
return_sequences=False, stateful=True,
name="lstm_{}".format(base_label)),
tf.keras.layers.Dense(1, name="dense_{}".format(base_label)), # binary
tf.keras.layers.Activation('sigmoid', name="activ_{}".format(base_label)), # binary
])
models = {}
for l in [1, 3, 4, 5]:
global m
m = gen_lstm(l, 130, 88)
models[l] = m
in_all = tf.keras.layers.InputLayer(input_shape=(1, 132), batch_size=batch_size, name="input_all")
in_lstm = tf.keras.layers.Lambda(lambda x: tf.slice(x, [0, 0, 2], [-1, -1, -1]), name="in_lstm")(in_all.output)
def model_0(x):
return x[:, :, 1]
out_model_0 = tf.keras.layers.Lambda(model_0, name="label_0")(in_all.output)
out_concat = tf.keras.layers.Concatenate(axis=1, name="concat_infer")([out_model_0] + [m(in_lstm) for m in models.values()])
in_arb = tf.keras.layers.Lambda(lambda x: tf.reshape(tf.slice(x, [0, 0, 0], [-1, -1, 1]), (batch_size, 1)), name="in_arb")(in_all.output)
out_merged = tf.keras.layers.Concatenate(axis=1, name="concat_arb")([in_arb, out_concat])
def arbiter(x):
return tf.where(tf.equal(x[:, 0], tf.constant(0.0, dtype=tf.float32)), x[:, 1], tf.where(
tf.equal(x[:, 0], tf.constant(2.0, dtype=tf.float32)), x[:, 2] + tf.constant(2.0), tf.where(
tf.equal(x[:, 0], tf.constant(3.0, dtype=tf.float32)), x[:, 3] + tf.constant(3.0), tf.where(
tf.equal(x[:, 0], tf.constant(4.0, dtype=tf.float32)), x[:, 4] + tf.constant(4.0),
x[:, 5] + tf.constant(5.0)))))
out_merged = tf.keras.layers.Lambda(arbiter, name="arbiter")(out_merged)
lstm_model = tf.keras.Model([in_all.input], out_merged)
print(lstm_model.summary())
tf.keras.utils.plot_model(lstm_model, to_file="./temp.png", show_shapes=True)