Я пытаюсь реализовать Stochasti c Усреднение веса (SWA) с тензорным потоком 2.0 в стиле кераса, поэтому мне нужно обновлять веса модели SWA каждый шаг. Я написал специальный Callback для этого, но я получаю предупреждение на каждом шагу. Вот некоторые подробности:
Мой пользовательский обратный вызов:
class CustomCallback(tf.keras.callbacks.Callback):
def __init__(self, valid_data, output_path, swa_alpha=0.99, eval_every=500, eval_batch=16, fold=None):
self.valid_inputs = valid_data[0]
self.valid_outputs = valid_data[1]
self.eval_batch = eval_batch
self.swa_alpha = swa_alpha
self.fold = fold
self.output_path = output_path
self.rho_value = -1 # record the best rho for report
self.eval_every = eval_every
def on_train_begin(self, logs={}):
self.swa_weights = self.model.get_weights()
def on_batch_end(self, batch, logs={}):
# update swa parameters
alpha = min(1 - 1 / (batch + 1), self.swa_alpha)
current_weights = self.model.get_weights()
for i, layer in enumerate(self.model.layers):
self.swa_weights[i] = alpha * self.swa_weights[i] + (1 - alpha) * current_weights[i]
# validation
if batch > 0 and batch % self.eval_every == 0:
# do validation
val_pred = self.model.predict(self.valid_inputs, batch_size=self.eval_batch)
rho_val = compute_spearmanr(self.valid_outputs, val_pred) # the metric
# set the swa parameters and do validation
self.model.set_weights(self.swa_weights)
swa_val_pred = self.model.predict(self.valid_inputs, batch_size=self.eval_batch)
swa_rho_val = compute_spearmanr(self.valid_outputs, swa_val_pred)
# reset the original parameters
self.model.set_weights(current_weights)
# check whether to save model and update best rho value
if rho_val > self.rho_value:
self.rho_value = rho_val
self.model.save_weights(f'{self.output_path}/fold-{fold}-best.h5')
del current_weights
gc.collect()
Вывод выглядит примерно так:
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.428264). Check your callbacks.
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.464315). Check your callbacks.
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.502968). Check your callbacks.
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.518413). Check your callbacks.
Я получаю предупреждение каждый шаг, что означает без запуска кода проверки код для обновления параметров SWA (self.model.get_weights()
и следующих for
l oop) достаточно медленный.
Я понимаю, что обновление параметров происходит очень медленно, поскольку model.get_weights()
и model.set_weights()
оба сделают глубокую копию параметров (новый список новых numpy ndarray в соответствии с моим экспериментом).
Я думаю, что в моей реализации SWA нет ничего плохого (пожалуйста, дайте мне знать если есть какие-либо ошибки), поэтому я просто хочу отключить предупреждение.
Что я пробовал:
- Добавление кода
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
и os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
для отключения WARNING. - Установка
verbose
в 2
и 0
in model.fit()
, т.е. model.fit(..., verbose=2, ...)
и model.fit(..., verbose=0, ...)
Оба не работают.
Есть идеи? Заранее спасибо за любую помощь!