Вы на правильном пути. Вы можете сделать это двумя способами: построить одну модель, разделенную на функциональные части, или построить несколько моделей и связать их вместе. Функционально это одно и то же, поскольку слой является моделью; модель может быть слоем.
Я собираюсь использовать несколько простых Dense
слоев для представления от model_a
до model_c
. В качестве предостережения, есть ошибка, с которой вы можете столкнуться, если будете использовать последовательный API, поэтому я продемонстрирую с функциональным API, но уверяю вас, это так же просто (модель определяется только после слои, а не раньше).
Поскольку одна модель разделена на функции:
import tensorflow as tf
def model_a(x):
return tf.keras.layers.Dense(56,name='model_a_whatever')(x) # returning output of layer
def model_b(x):
x = tf.keras.layers.Dense(56,name='model_b_whatever')(x)
return tf.keras.layers.Dense(56,name='model_b_more_whatever')(x)
def model_c(x):
x = tf.keras.layers.Dense(56,name='model_c_whatever')(x)
x = tf.keras.layers.Dense(56,name='model_c_more_whatever')(x)
return tf.keras.layers.Dense(56,name='model_c_even_more_whatever')(x)
# in through the input layer
main_input = tf.keras.layers.Input(shape=(12,34,56),name='INPUT')
# now through functions containing different models' layers
left = model_a(main_input)
right = model_b(main_input)
# concatenate their outputs
concatenated = tf.keras.layers.Concatenate(axis=-1)([left,right])
# now through function containing layers of model c
left = model_c(concatenated)
# and the juke right to a fully connected layer
right = tf.keras.layers.Dense(56,name='FC')(concatenated)
# then add the outputs and apply softmax activation
added = tf.keras.layers.Add(name='add')([left,right])
outputs = tf.keras.layers.Activation('softmax',name='Softmax')(added)
# now define the model
model = tf.keras.models.Model(main_input,outputs) # Model(input layer, final output))
print(model.summary())
tf.keras.utils.plot_model(model, to_file='just_a_model.png')
Диаграмма будет выглядеть более загроможденной, чем ваша, поскольку все слои будут видны:
![Diagram of the model by this method](https://i.stack.imgur.com/irrzJ.png)
Столько моделей объединено:
# as separate models linked together
def build_model_a():
a = tf.keras.layers.Input(shape=(12,34,56),name='model_a_input')
b = tf.keras.layers.Dense(56,name='model_a_whatever')(a) # whatever layers
return tf.keras.models.Model(a,b,name='MODEL_A') # returning model, not just layer output
def build_model_b():
a = tf.keras.layers.Input(shape=(12,34,56),name='model_b_input')
b = tf.keras.layers.Dense(56,name='model_b_whatever')(a)
b = tf.keras.layers.Dense(56,name='model_b_more_whatever')(b)
return tf.keras.models.Model(a,b,name='MODEL_B')
def build_model_c():
a = tf.keras.layers.Input(shape=(12,34,112),name='model_c_input') # axis 2 is doubled because concatenation.
b = tf.keras.layers.Dense(56,name='model_c_whatever')(a)
b = tf.keras.layers.Dense(56,name='model_c_more_whatever')(b)
b = tf.keras.layers.Dense(56,name='model_c_even_more_whatever')(b)
return tf.keras.models.Model(a,b,name='MODEL_C')
# define the main input
main_input = tf.keras.layers.Input(shape=(12,34,56),name='INPUT')
# build the models
model_a = build_model_a()
model_b = build_model_b()
model_c = build_model_c()
# pass input through models a and b
a = model_a(main_input)
b = model_b(main_input)
# concatenate their outputs
ab = tf.keras.layers.Concatenate(axis=-1,name='Concatenate')([a,b])
# pass through model c and fully-connected layer
c = model_c(ab)
d = tf.keras.layers.Dense(56,name='FC')(ab)
# add their outputs and apply softmax activation
add = tf.keras.layers.Add(name="add")([c,d])
outputs = tf.keras.layers.Activation('softmax',name='Softmax')(add)
model = tf.keras.models.Model(main_input,outputs)
print(model.summary())
tf.keras.utils.plot_model(model, to_file='multi_model.png')
Хотя это функционально та же сеть, что и в первом случае, диаграмма теперь соответствует ваш:
![Diagram of model built by second method](https://i.stack.imgur.com/e7XkV.png)
Любой метод будет работать. Как видите, первый метод - это просто очистка кода; для наглядности помещаем отдельные конвейеры данных в функции. Если вы хотите усложнить задачу, например, иметь разные функции потерь для подмоделей и т. Д., Тогда второй метод может упростить процесс. Ошибка, о которой я упоминал, возникает, только если вы используете последовательный API со вторым методом.