Keras, как написать параллельную модель, для мультиклассового прогнозирования - PullRequest
1 голос
/ 06 ноября 2019

У меня есть следующая модель, где keep_features = 900 или около того, y - горячее кодирование классов. Хотя я ищу архитектуру ниже (возможно ли это с помощью keras и как должна выглядеть идея обозначения, особенно параллельная часть и конкатенация)

model = Sequential()
model.add(Dense(keep_features, activation='relu'))
model.add(BatchNormalization())
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dense(64, activation='relu'))
model.add(BatchNormalization())
model.add(Dense(3, activation='softmax'))
model.compile(loss=losses.categorical_crossentropy,optimizer='adam',metrics=['mae', 'acc'])

enter image description here

1 Ответ

1 голос
/ 06 ноября 2019

С главой «Модели с несколькими входами и несколькими выходами» здесь вы можете сделать что-то похожее на нужную модель:

K = tf.keras
input1 = K.layers.Input(keep_features_shape)

denseA1 = K.layers.Dense(256, activation='relu')(input1)
denseB1 = K.layers.Dense(256, activation='relu')(input1)
denseC1 = K.layers.Dense(256, activation='relu')(input1)

batchA1 = K.layers.BatchNormalization()(denseA1)
batchB1 = K.layers.BatchNormalization()(denseB1)
batchC1 = K.layers.BatchNormalization()(denseC1)

denseA2 = K.layers.Dense(64, activation='relu')(batchA1)
denseB2 = K.layers.Dense(64, activation='relu')(batchB1)
denseC2 = K.layers.Dense(64, activation='relu')(batchC1)

batchA2 = K.layers.BatchNormalization()(denseA2)
batchB2 = K.layers.BatchNormalization()(denseB2)
batchC2 = K.layers.BatchNormalization()(denseC2)

denseA3 = K.layers.Dense(32, activation='softmax')(batchA2) # individual layer
denseB3 = K.layers.Dense(16, activation='softmax')(batchB2) # individual layer
denseC3 = K.layers.Dense(8, activation='softmax')(batchC2) # individual layer

concat1 = K.layers.Concatenate(axis=-1)([denseA3, denseB3, denseC3])

model = K.Model(inputs=[input1], outputs=[concat1])

model.compile(loss = K.losses.categorical_crossentropy, optimizer='adam', metrics=['mae', 'acc'])

В результате: enter image description here enter image description here

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...