Как отключить предупреждение «tenorflow: метод (on_train_batch_end) медленный по сравнению с пакетным обновлением (). Проверьте ваши обратные вызовы» - PullRequest
3 голосов
/ 15 января 2020

Я пытаюсь реализовать Stochasti c Усреднение веса (SWA) с тензорным потоком 2.0 в стиле кераса, поэтому мне нужно обновлять веса модели SWA каждый шаг. Я написал специальный Callback для этого, но я получаю предупреждение на каждом шагу. Вот некоторые подробности:

Мой пользовательский обратный вызов:


class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, valid_data, output_path, swa_alpha=0.99, eval_every=500, eval_batch=16, fold=None):
        self.valid_inputs = valid_data[0]
        self.valid_outputs = valid_data[1]
        self.eval_batch = eval_batch
        self.swa_alpha = swa_alpha
        self.fold = fold
        self.output_path = output_path
        self.rho_value = -1  # record the best rho for report
        self.eval_every = eval_every

    def on_train_begin(self, logs={}):
        self.swa_weights = self.model.get_weights()

    def on_batch_end(self, batch, logs={}):

        # update swa parameters
        alpha = min(1 - 1 / (batch + 1), self.swa_alpha)
        current_weights = self.model.get_weights()
        for i, layer in enumerate(self.model.layers):
            self.swa_weights[i] = alpha * self.swa_weights[i] + (1 - alpha) * current_weights[i]

        # validation
        if batch > 0 and batch % self.eval_every == 0:
            # do validation
            val_pred = self.model.predict(self.valid_inputs, batch_size=self.eval_batch)
            rho_val = compute_spearmanr(self.valid_outputs, val_pred)  # the metric

            # set the swa parameters and do validation
            self.model.set_weights(self.swa_weights)
            swa_val_pred = self.model.predict(self.valid_inputs, batch_size=self.eval_batch)
            swa_rho_val = compute_spearmanr(self.valid_outputs, swa_val_pred)

            # reset the original parameters
            self.model.set_weights(current_weights)

            # check whether to save model and update best rho value
            if rho_val > self.rho_value:
                self.rho_value = rho_val
                self.model.save_weights(f'{self.output_path}/fold-{fold}-best.h5')

        del current_weights
        gc.collect()

Вывод выглядит примерно так:

WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.428264). Check your callbacks.
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.464315). Check your callbacks.
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.502968). Check your callbacks.
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (11.518413). Check your callbacks.

Я получаю предупреждение каждый шаг, что означает без запуска кода проверки код для обновления параметров SWA (self.model.get_weights() и следующих for l oop) достаточно медленный.

Я понимаю, что обновление параметров происходит очень медленно, поскольку model.get_weights() и model.set_weights() оба сделают глубокую копию параметров (новый список новых numpy ndarray в соответствии с моим экспериментом).

Я думаю, что в моей реализации SWA нет ничего плохого (пожалуйста, дайте мне знать если есть какие-либо ошибки), поэтому я просто хочу отключить предупреждение.

Что я пробовал:

  1. Добавление кода os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" и os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" для отключения WARNING.
  2. Установка verbose в 2 и 0 in model.fit(), т.е. model.fit(..., verbose=2, ...) и model.fit(..., verbose=0, ...)

Оба не работают.

Есть идеи? Заранее спасибо за любую помощь!

1 Ответ

0 голосов
/ 16 января 2020

Это не очень удовлетворительный ответ, но не работает TF_CPP_MIN_LOG_LEVEL - это известная проблема: TF_CPP_MIN_LOG_LEVEL не работает с TF2.0 dev20190820 .

Мне удалось воспроизвести вашу проблему на tensorflow==2.1.0-rc1 с примером игрушки здесь:

import os
import time
os.environ['TF_CPP_MIN_LOG_LEVEL'] = "2"
import tensorflow as tf
tf.get_logger().setLevel("WARNING")
tf.autograph.set_verbosity(2)

print(tf.__version__)

mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dropout(0.2),
  tf.keras.layers.Dense(10, activation='softmax')
])

class CustomCallback(tf.keras.callbacks.Callback):

  def on_train_batch_end(self, batch, logs=None):
    time.sleep(3)

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=1, callbacks=[CustomCallback()])
2.1.0
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11493376/11490434 [==============================] - 31s 3us/step
Train on 60000 samples
WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (3.002797). Check your callbacks.
   32/60000 [..............................] - ETA: 1:57:38 - loss: 2.4674 - accuracy: 0.0938WARNING:tensorflow:Method (on_train_batch_end) is slow compared to the batch update (3.002938). Check your callbacks.
...

Ни одно из стандартных предложений (os.environ['TF_CPP_MIN_LOG_LEVEL'], tf.get_logger().setLevel("WARNING") или tf.autograph.set_verbosity(2)) не работает, и я подозреваю, что вам придется подождать пока вышеупомянутая проблема не будет решена.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...