Ошибка типа:> не поддерживается между экземплярами NoneType и float - PullRequest
1 голос
/ 18 января 2020

У меня есть этот код, и он вызывает ошибку в python 3, и такое сравнение может работать на python 2, как я могу его изменить

import tensorflow as tf 
def train_set():
    class MyCallBacks(tf.keras.callbacks.Callback):
        def on_epoch_end(self,epoch,logs={}):
            if(logs.get('acc')>0.95):
                print('the training will stop !')
                self.model.stop_training=True
    callbacks=MyCallBacks()
    mnist_dataset=tf.keras.datasets.mnist 
    (x_train,y_train),(x_test,y_test)=mnist_dataset.load_data()
    x_train=x_train/255.0
    x_test=x_test/255.0
    classifier=tf.keras.Sequential([
                                    tf.keras.layers.Flatten(input_shape=(28,28)),
                                    tf.keras.layers.Dense(512,activation=tf.nn.relu),
                                    tf.keras.layers.Dense(10,activation=tf.nn.softmax)
                                    ])
    classifier.compile(
                        optimizer='sgd',
                        loss='sparse_categorical_crossentropy',
                        metrics=['accuracy']
                       )    
    history=classifier.fit(x_train,y_train,epochs=20,callbacks=[callbacks])
    return history.epoch,history.history['acc'][-1]
train_set()

Ответы [ 3 ]

2 голосов
/ 01 апреля 2020

Tensorflow 2.0

DESIRED_ACCURACY = 0.979

class myCallback(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epochs, logs={}) :
        if(logs.get('acc') is not None and logs.get('acc') >= DESIRED_ACCURACY) :
            print('\nReached 99.9% accuracy so cancelling training!')
            self.model.stop_training = True

callbacks = myCallback()
2 голосов
/ 18 января 2020

похоже, что ваша ошибка похожа на Исключение с обратным вызовом в Keras - Tensorflow 2.0 - Python попробуйте заменить logs.get('acc') на logs.get('accuracy')

1 голос
/ 18 января 2020

Это работает в Python2, потому что в Python2 вы можете сравнить None с float, но это невозможно в Python3.

Эта строка

logs.get('acc')

возвращает None и возникает ваша проблема.

Быстрое решение состоит в замене условия на

if logs.get('acc') is not None and logs.get('acc') > 0.95:

Если logs.get('acc') равно None, то указанное выше условие будет коротким -схема и вторая часть, logs.get('acc') > 0.95, не будет оценена, поэтому она не вызовет упомянутую ошибку.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...