Я реализую MLP с Keras
и пользовательской функцией потерь. Я замечаю, что model.compile()
занимает очень много времени: кажется, это не конец. Потеря, которую я передал функции compile()
, является обычной. Я также использую другую функцию, которая используется в функции потерь.
Это моя пользовательская потеря:
def get_top_one_probability(vector):
return (K.exp(vector) / K.sum(K.exp(vector)))
def custom_loss(groups_id_count, tf_session):
def listnet_loss(real_labels, predicted_labels):
losses = tf.Variable([[0.0]], tf.float32)
for group in groups_id_count:
start_range = 0
end_range = (start_range + group[1])
batch_real_labels = real_labels[start_range:end_range]
batch_predicted_labels = predicted_labels[start_range:end_range]
loss = -K.sum(get_top_one_probability(batch_real_labels)) * tf.math.log(get_top_one_probability(batch_predicted_labels))
losses = tf.concat([losses, loss], axis=0)
start_range = end_range
return K.mean(losses)
return listnet_loss
И это код MLP:
mlp = keras.models.Sequential()
# add input layer
mlp.add(
keras.layers.Dense(
units=training_dataset.shape[1],
input_shape = (training_dataset.shape[1], ),
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
activation='tanh')
)
# add hidden layer
mlp.add(
keras.layers.Dense(
units=training_dataset.shape[1] + 10,
input_shape = (training_dataset.shape[1] + 10,),
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
activation='relu')
)
# add output layer
mlp.add(
keras.layers.Dense(
units=1,
input_shape = (1, ),
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
activation='softmax')
)
# define SGD optimizer
sgd_optimizer = keras.optimizers.SGD(
lr=0.01, decay=0.01, momentum=0.9, nesterov=True
)
# compile model
print('Compiling model...\n')
mlp.compile(
optimizer=sgd_optimizer,
loss=custom_loss(groups_id_count, tf.compat.v1.Session())
)
mlp.summary() # print model settings
# Training
with tf.device('/GPU:0'):
print('Start training')
mlp.fit(training_dataset, training_dataset_labels, epochs=50, verbose=2, batch_size=training_dataset.shape[0], workers=10)
Почему функция compile()
занимает очень много времени? Заранее спасибо