Многоканальная входная модель в TensorFlow 2 требует очень много времени для обучения, хотя большинство слоев не обучаются - PullRequest
0 голосов
/ 22 февраля 2020

Я замораживаю веса всех предварительно обученных моделей и создаю модель многоканального ввода. Модель работает на TensorFlow 2.0, и я использую 2 графических процессора. Обучающие выборки - 24912. Общее время, затрачиваемое каждой эпохой, составляет более одного часа, а обучаемые параметры - всего лишь 46 033 Может кто-нибудь предложить способ ускорить обучение.

model1 = load_model('a.h5')
model2 = load_model('b.h5')
model3 = load_model('c.h5')
model4 = load_model('d.h5')

for layer in model1.layers:
            layer.trainable = False

for layer in model2.layers:
            layer.trainable = False

for layer in model3.layers:
            layer.trainable = False

for layer in model4.layers:
            layer.trainable = False
merge = tf.reshape(concatenate([model1.output, model2.output, model3.output, model4.output]), (-1, 17, 4, 1))

conv1 = Conv2D(32, (6,2))(merge)
batch1 = BatchNormalization()(conv1)
prelu1 = PReLU()(batch1)

conv2 = Conv2D(32, (6,2))(prelu1)
batch2 = BatchNormalization()(conv2)
prelu2 = PReLU()(batch2)

conv3 = Conv2D(64, (7,2))(prelu2)
batch3 = BatchNormalization()(conv3)
prelu3 = PReLU()(batch3)

flat = Flatten()(prelu3)
dense1 = Dense(32, activation='tanh')(flat)
drop1 = Dropout(0.2)(dense1)
predictions = Dense(17, activation='softmax')(drop1)

final_model = Model(inputs=[model1.inputs, model2.inputs, model3.inputs, model4.inputs],outputs=predictions)

Модель выглядит так: Архитектура модели

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...