Короче говоря, я обучил модель бинарной классификации с равными данными в каждом классе, чтобы не было дисбаланса класса. Модель обучена на 10 000 изображений с соответствующими метками и проверена на 6 000 изображений с соответствующими метками.
Результатом является модель с точностью 0,995, что должно означать, что реализация модели будет классифицировать правильные классы 0,995 времени. (Модель НЕ выбирает класс А все время и является правильной в 0.995 случаев, потому что нет дисбаланса класса)
Однако это не так. Кроме того, данные были перетасованы, поэтому модель также не угадывает класс A для первых 5000 изображений, а затем угадывает класс B для остальных, чтобы получить точность 0,995.
Полный код, вопрос и вещи, которые я принял к сведению, есть на моем github:
https://github.com/Nickclickflick/tutorials
Не стесняйтесь загружать и использовать модель, чтобы увидеть результаты робота flappy bird.
Изменить 1: 8 000 от общего количества изображений являются оригинальными, а остальные 8 000 дополнены, как описано ниже
В следующем фрагменте кода показано увеличение исходных изображений
datagen = ImageDataGenerator(featurewise_center=True, samplewise_center=True,
featurewise_std_normalization=True, samplewise_std_normalization=True,
zca_whitening=True, zca_epsilon=1e-06)
Редактировать 2: Следующий код был использован для генерации исходного набора данных (это доступно на github)
import numpy as np
from grabscreen import grab_screen
import cv2
import time
from getkeys import key_check
import os
jump = [1,0]
do_nothing = [0,1]
starting_value = 1
while True:
file_name = 'E:/flappy/tmp_data/training_data-{}.npy'.format(starting_value)
if os.path.isfile(file_name):
print('File exists, moving along',starting_value)
starting_value += 1
else:
print('File does not exist, starting fresh!',starting_value)
break
def keys_to_output(keys):
output = [0,0]
if ' ' in keys:
output = jump
else:
output = do_nothing
return output
def main(file_name, starting_value):
file_name = file_name
starting_value = starting_value
training_data = []
# countdown
for i in list(range(6))[::-1]:
print(i+1)
time.sleep(1)
paused = False
print('STARTING!!!')
while True:
if not paused:
screen = grab_screen(region=(0,200,600,1000))
last_time = time.time()
# resize to something a bit more acceptable for a CNN
screen = cv2.resize(screen, (150,250))
# run a color convert:
screen = cv2.cvtColor(screen, cv2.COLOR_BGR2RGB)
keys = key_check()
output = keys_to_output(keys)
training_data.append([screen,output])
if len(training_data) % 10 == 0:
print(len(training_data))
if len(training_data) == 100:
np.save(file_name,training_data)
print('SAVED')
training_data = []
starting_value += 1
file_name = 'E:/flappy/tmp_data/training_data-{}.npy'.format(starting_value)
keys = key_check()
# pause script
if 'T' in keys:
if paused:
paused = False
print('unpaused!')
time.sleep(1)
else:
print('Pausing!')
paused = True
time.sleep(1)
main(file_name, starting_value)