В настоящее время я слежу за курсом python от sentdex, на названном Python играет GTAV, однако, попытавшись продолжить, я понял, что многие модули и код, которые он использует, устарели, и после пробуя что-то новое, я начал использовать CNN, созданный с помощью Keras, и теперь я, кажется, застрял в попытке загрузить данные и обучить сеть.
import numpy as np
from alex_neural_net import alexnet
from keras import backend
backend.set_image_data_format('channels_last')
# save np.load
np_load_old = np.load
# modify the default parameters of np.load
np.load = lambda *a,**k: np_load_old(*a, allow_pickle=True, **k)
WIDTH = 80
HEIGHT = 60
LR = 1e-3
EPOCHS = 8
MODEL_NAME = 'pyslither-io-{}-{}-{}-epochs.model'.format(LR, 'alexnetv2', EPOCHS)
model = alexnet(WIDTH, HEIGHT, LR)
train_data = np.load('training_data_v2.npy')
train = train_data[:-500]
test = train_data[-500:]
X = np.array([i[0] for i, in train]).reshape(-1,WIDTH,HEIGHT,1)
Y = [i[1] for i in train]
test_x = np.array([i[0] for i in test]).reshape(-1,WIDTH,HEIGHT,1)
test_y = [i[1] for i in test]
#test y is the keys pressed array
model.fit({'input': X}, {'targets': Y}, nb_epoch=EPOCHS, validation_set=({'input': test_x},
{'targets': test_y}),
snapshot_step=500, run_id=MODEL_NAME)
# Save the model
model_json = model.to_json()
with open("weights/model.json", "w") as json_file:
json_file.write(model_json)
#model.save(MODEL_NAME)
# tensorboard --logdir=foo:E:\screengrab
# restore np.load for future normal usage
np.load = np_load_old
Пытаюсь обучить, но получаю эту ошибку:
Warning (from warnings module):
File "E:\screengrab\train_model.py", line 36
snapshot_step=500, run_id=MODEL_NAME)
UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.
Traceback (most recent call last):
File "E:\screengrab\train_model.py", line 36, in <module>
snapshot_step=500, run_id=MODEL_NAME)
File "E:\python64bit\lib\site-packages\keras\engine\training.py", line 1118, in fit
raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
TypeError: Unrecognized keyword arguments: {'validation_set': ({'input': array([[[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[236],
[236],
[236]],
[[236],
[236],
[236],
...,
[188],
[176],
[172]],
...,
[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[255],
[255],
[255]]],
[[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[243],
[243],
[243]],
[[243],
[243],
[243],
...,
[243],
[187],
[168]],
...,
[[254],
[247],
[254],
...,
[254],
[254],
[254]],
[[253],
[254],
[241],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[254],
[254],
[254]]],
[[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[243],
[243],
[243]],
[[243],
[243],
[243],
...,
[243],
[187],
[168]],
...,
[[254],
[247],
[254],
...,
[254],
[254],
[254]],
[[253],
[254],
[241],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[254],
[254],
[254]]],
...,
[[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[236],
[236],
[236]],
[[236],
[236],
[236],
...,
[188],
[176],
[172]],
...,
[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[220],
[255],
[255]],
[[255],
[255],
[255],
...,
[255],
[255],
[255]]],
[[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[236],
[236],
[236]],
[[236],
[236],
[236],
...,
[188],
[176],
[172]],
...,
[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[220],
[255],
[255]],
[[255],
[255],
[229],
...,
[255],
[255],
[255]]],
[[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[236],
[236],
[236]],
[[236],
[236],
[236],
...,
[188],
[176],
[172]],
...,
[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[255],
[255],
[255]],
[[255],
[255],
[255],
...,
[255],
[255],
[255]]]], dtype=uint8)}, {'targets': [[0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 0, 1], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0], [0, 1, 0],
[0, 0, 1], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [1, 0,
0], [0, 0,
1], [0, 0, 1], [0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0,
0, 1],
[0, 1, 0]]}), 'snapshot_step': 500, 'run_id': 'pyslither-io-0.001-alexnetv2-8-epochs.model'}
Изменить: получить новую ошибку
Traceback (most recent call last):
File "E:\python64bit\lib\site-packages\keras\engine\training_utils.py", line
80, in standardize_input_data
for x in names
File "E:\python64bit\lib\site-packages\keras\engine\training_utils.py", line
80, in <listcomp>
for x in names
KeyError: 'conv2d_1_input'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "E:\screengrab\train_model.py", line 36, in <module>
validation_steps=500)
File "E:\python64bit\lib\site-packages\keras\engine\training.py", line 1154,
in fit
batch_size=batch_size)
File "E:\python64bit\lib\site-packages\keras\engine\training.py", line 579,
in _standardize_user_data
exception_prefix='input')
File "E:\python64bit\lib\site-packages\keras\engine\training_utils.py", line
85, in standardize_input_data
'for each key in: ' + str(names))
ValueError: No data provided for "conv2d_1_input". Need data for each key in:
['conv2d_1_input']