Невозможно выполнить функцию обратного вызова, созданную с использованием tenorflow - PullRequest
1 голос
/ 26 апреля 2020

В рамках учебного пособия по TF 2.0 я опробовал функцию обратного вызова в TensorFlow, которая позволяет модели останавливать обучение при достижении заданного значения c точности или потери. Пример, приведенный в этом Colab , работает нормально. Я попытался запустить подобный пример локально, используя pycharm (с tf gpu conda env), но функция обратного вызова вообще не выполняется и работает до последней эпохи. Нет никаких ошибок, и коды выглядят одинаково.

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten
from matplotlib import pyplot as plt
from tensorflow.keras.callbacks import Callback


class MyCallback(Callback):
    def on_epochs_end(self, epoch, logs={}):
        if(logs.get('accuracy') > 0.9):
            print("\n Training stopping now. accuracy reached 90 !")
            self.model.stop_training = True


callback = MyCallback()

# Input data
(training_data, training_labels), (testing_data, testing_labels) = fashion_mnist.load_data()
training_data = training_data / 255.0
testing_data = testing_data / 255.0
plt.imshow(training_data[0], cmap='gray')

# Network
model = Sequential([
    Flatten(input_shape=(28, 28)),
    Dense(units=128, activation='relu'),
    Dense(units=10, activation='softmax')])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
model.fit(training_data, training_labels, epochs=25, callbacks=[callback])

Я имел в виду различные примеры некоторых решений и натолкнулся на утверждения типа
- activation='relu'
- activation=tf.nn.relu
- activation=tf.keras.activation.relu

Какой из них лучше использовать? Ошибка вызвана неправильным импортом?

Если бы кто-нибудь мог дать некоторые подсказки, это было бы полезно.

1 Ответ

1 голос
/ 26 апреля 2020

Ошибка из-за опечатки в вашем классе обратного вызова. В определении функции on_epoch_end вы опечатали как on_epochs_end. Кроме этого все правильно.

class MyCallback(Callback):
 #def on_epochs_end(self, epoch, logs={}): # should be epoch (not epochs)
  def on_epoch_end(self, epoch, logs={}):
    if(logs.get('accuracy') > 0.9):
      print("\n Training stopping now. accuracy reached 90 !")
      self.model.stop_training = True

Полный код здесь для вашей справки.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...