Я написал пользовательский обратный вызов для автоматического замораживания слоев в моей сети с несколькими входами / выходами в случае, если соответствующая потеря падает ниже определенного порога.
Как это работает: после каждой эпохи он проверяет, потеря липервый вес в self.weights ниже значения self.loss_thresholds. Если это так, он замораживает соответствующие слои и делает нижележащие слои обучаемыми. После этого он устанавливает флаг stop_training моделей, чтобы выйти из функции подгонки и перекомпилировать модель.
from keras.callbacks import Callback
class WeightLossCallback(Callback):
def __init__(self, loss_weights):
self.loss_weights = loss_weights
self.weights = ["akVol","biltarres","riskap1","sraj","zbj"]
self.loss_thresholds = {"akVol": 0.12, "biltarres": 0.12, "riskap1": 0.14, "sraj": 1.5, "zbj": 14.}
self.learning_rate = {"akVol": 0.05, "biltarres": 0.05, "riskap1": 0.05, "sraj": 0.01, "zbj": 0.05}
self.stopTraining=False
def on_train_begin(self, logs={}):
K.set_value(self.model.optimizer.lr, self.learning_rate[self.weights[0]])
def on_epoch_end(self, epoch, logs={}):
if epoch < 10:
return
def freezeAttAndLayer(weight):
print("Setting weight " + weight + " to zero.")
self.weights.remove(weight)
self.loss_weights[weight] = 0.0
self.stopTraining=True
print("Step 2: Freezing corresponding layers.")
for layer in self.model.layers:
if weight in layer.name:
print("Freezing layer " + layer.name + ".")
layer.trainable=False
def unfreezeAttAndLayer(weight):
print("Step 3: Unfreezing corresponding layers.")
self.loss_weights[nextWeight] = 1000.
for layer in self.model.layers:
if weight in layer.name:
print("Unfreezing layer " + layer.name + ".")
layer.trainable=True
# in case of last weight in self.weights list
if len(self.weights) < 2:
lastWeight = self.weights[0]
if logs[lastWeight+"_loss"] >= self.loss_thresholds[lastWeight]:
return
else:
freezeAttAndLayer(lastWeight)
self.model.stop_training=True
return
currentWeight = self.weights[0]
nextWeight = self.weights[1]
print("Step 1: Checking weight thresholds for " + currentWeight + " and " + nextWeight + "...")
if logs[currentWeight+"_loss"] < self.loss_thresholds[currentWeight]:
freezeAttAndLayer(currentWeight)
unfreezeAttAndLayer(nextWeight)
if self.stopTraining:
self.stopTraining=False
self.model.stop_training=True
print(self.loss_weights)
print(self.weights)
Функция подбора работает следующим образом:
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
ModelCpt = ModelCheckpoint("C:/Users/pan11811/Desktop/ModelCheckpoint/test.h5", monitor="loss",save_best_only=True, save_weights_only=False)
WeightLossCpt = WeightLossCallback(loss_weights)
epochCount=0
while(i<21):
ReduceLRCpt = ReduceLROnPlateau(patience=35, min_delta=0.1, factor=0.6, monitor=WeightLossCpt.weights[0]+"_loss", verbose=1, min_lr=0.001)
model.fit(inputDic, outputDic, epochs=2000+epochCount,batch_size=4000, callbacks=[ReduceLRCpt,ModelCpt,WeightLossCpt], verbose=1)
model.compile(optimizer=optimizers.Adamax(lr=0.025),loss=losses.mean_squared_error,loss_weights=Weig htLossCpt.loss_weights)
epochCount+=1
if len(WeightLossCpt.weights) == 0:
print("Training completed.")
break
print("#######################################")
print("RECOMPILED_MODEL")
print("#######################################")
Вот обучениежурнал для замораживания сети sraj:
Epoch 968/2003
64538/64538 [==============================] - 0s 5us/step - loss: 1489.0058 - tbaRenormalized_loss: 11.6664 - tbasum_loss: 180.4811 - vtstkj_loss: 0.0000e+00 - sraj_loss: 1.4885 - riskap1_loss: 0.1149 - biltarres_loss: 0.1144 - akVol_loss: 0.1210 - zbj_loss: 60316658.7545 - zb_loss: 18562112.8038 - sra_output_loss: 0.4971 - vtstk_loss: 0.0000e+00
Step 1: Checking weight thresholds for sraj and zbj...
Setting weight sraj to zero.
Step 2: Freezing corresponding layers.
Freezing layer input_sraj.
Freezing layer sraj_0.
Freezing layer sraj_1.
Freezing layer sraj_2.
Freezing layer sraj_3.
Freezing layer sraj_4.
Freezing layer sraj_5.
Freezing layer sraj.
Step 3: Unfreezing corresponding layers.
Unfreezing layer zbj.
{'tbaRenormalized': 0.0, 'tbasum': 0.0, 'sraj': 0.0, 'riskap1': 0.0, 'zb': 0.0, 'biltarres': 0.0, 'akVol': 0.0, 'vtstk': 0.0, 'zbj': 1000.0}
['zbj']
#######################################
RECOMPILED_MODEL
#######################################
Epoch 1/2004
64538/64538 [==============================] - 0s 6us/step - loss: 18151638696.1287 - tbaRenormalized_loss: 11.6664 - tbasum_loss: 180.4811 - vtstkj_loss: 0.0000e+00 - sraj_loss: 2.5473 - riskap1_loss: 0.1149 - biltarres_loss: 0.1144 - akVol_loss: 0.1210 - zbj_loss: 18151638.7103 - zb_loss: 5467256.6113 - sra_output_loss: 0.9667 - vtstk_loss: 0.0000e+00
Проблема здесь заключается в том, что прямо перед замораживанием слоев потери sraj ниже порогового значения в 1.4885, но после замораживания и перекомпиляции сети потери возрастают до 2,5473, но остаютсяконстанта.
Итак, мой вопрос: почему этот скачок происходит, когда я четко заморозил все слои?
Я ценю любую помощь!