model.predict_generator и model.evaluate_generator возвращают совершенно разные точности - PullRequest
0 голосов
/ 26 июня 2019

Я обучил VGG как классификатор из 10 классов на 100 эпох, и это точность обучения / проверки.train/validation accuracy

Кроме того, я хотел протестировать модель на тестовом наборе, поэтому я оценил ее следующим образом:

test_datagen = ImageDataGenerator(
    rescale=1./255,
)

test_generator = test_datagen.flow_from_directory(
        '/content/drive/My Drive/Colab Notebooks/domat/solo-dataset/test/',
        target_size=(224, 224),
        batch_size=32,
        class_mode='categorical',
        shuffle=False
)

steps = 3616 // 32 

loss, accuracy = model_vgg_imagenet_dropout.evaluate_generator(test_generator,
                                             steps = steps,
                                             workers = 4,
                                             use_multiprocessing=True)

Когда я печатаюя получаю результаты (1.4021655139801776, 0.802820796460177), что похоже на то, что я ожидал.
Однако, когда я пытаюсь вручную оценить его с помощью model.predict_generator, я получаю только 13 % точности.
Ниже приведен код для его ручной оценки (генератор - это тот же объект):

predictions = model_vgg_imagenet_dropout.predict_generator(test_generator,
                                             steps = steps,
                                             workers = 4,
                                             use_multiprocessing=True)

y_pred = np.zeros(len(predictions))

for i, p in enumerate(predictions):
  max_index = np.argmax(p)
  y_pred[i] = max_index

# the y_pred array should contain the class index of each sample, as defined by test_generator.class_indices

y_true = test_generator.classes
from sklearn.metrics import accuracy_score
print(accuracy_score(y_true, y_pred))

Я не понимаю, где я делаю ошибку, мне кажется, это правильно.

Редактировать: когда я вручную наблюдаю результаты от model.predict_generator () и сопоставляю значения softmax с индексом класса, он буквально выводит как 3 или 4 класса большую часть времени.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...