Не понимаю, почему я получаю ошибку типа - PullRequest
0 голосов
/ 08 мая 2020

В настоящее время я слежу за курсом python от sentdex, на названном Python играет GTAV, однако, попытавшись продолжить, я понял, что многие модули и код, которые он использует, устарели, и после пробуя что-то новое, я начал использовать CNN, созданный с помощью Keras, и теперь я, кажется, застрял в попытке загрузить данные и обучить сеть.

import numpy as np
from alex_neural_net import alexnet
from keras import backend
backend.set_image_data_format('channels_last')


# save np.load
np_load_old = np.load

# modify the default parameters of np.load
np.load = lambda *a,**k: np_load_old(*a, allow_pickle=True, **k)


WIDTH = 80
HEIGHT = 60
LR = 1e-3
EPOCHS = 8
MODEL_NAME = 'pyslither-io-{}-{}-{}-epochs.model'.format(LR, 'alexnetv2', EPOCHS)

model = alexnet(WIDTH, HEIGHT, LR)

train_data = np.load('training_data_v2.npy')


train = train_data[:-500]
test = train_data[-500:]

X = np.array([i[0] for i, in train]).reshape(-1,WIDTH,HEIGHT,1)
Y = [i[1] for i in train]

test_x = np.array([i[0] for i in test]).reshape(-1,WIDTH,HEIGHT,1)
test_y = [i[1] for i in test]
#test y is the keys pressed array

model.fit({'input': X}, {'targets': Y}, nb_epoch=EPOCHS, validation_set=({'input': test_x}, 
{'targets': test_y}), 
snapshot_step=500, run_id=MODEL_NAME)

# Save the model
model_json = model.to_json()
with open("weights/model.json", "w") as json_file:
json_file.write(model_json)

#model.save(MODEL_NAME)
# tensorboard --logdir=foo:E:\screengrab

# restore np.load for future normal usage
np.load = np_load_old

Пытаюсь обучить, но получаю эту ошибку:

     Warning (from warnings module):
     File "E:\screengrab\train_model.py", line 36
     snapshot_step=500, run_id=MODEL_NAME)
     UserWarning: The `nb_epoch` argument in `fit` has been renamed `epochs`.
     Traceback (most recent call last):
     File "E:\screengrab\train_model.py", line 36, in <module>
     snapshot_step=500, run_id=MODEL_NAME)
     File "E:\python64bit\lib\site-packages\keras\engine\training.py", line 1118, in fit
     raise TypeError('Unrecognized keyword arguments: ' + str(kwargs))
     TypeError: Unrecognized keyword arguments: {'validation_set': ({'input': array([[[[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [236],
     [236],
     [236]],

    [[236],
     [236],
     [236],
     ...,
     [188],
     [176],
     [172]],

    ...,

    [[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]]],


   [[[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [243],
     [243],
     [243]],

    [[243],
     [243],
     [243],
     ...,
     [243],
     [187],
     [168]],

    ...,

    [[254],
     [247],
     [254],
     ...,
     [254],
     [254],
     [254]],

    [[253],
     [254],
     [241],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [254],
     [254],
     [254]]],


   [[[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [243],
     [243],
     [243]],

    [[243],
     [243],
     [243],
     ...,
     [243],
     [187],
     [168]],

    ...,

    [[254],
     [247],
     [254],
     ...,
     [254],
     [254],
     [254]],

    [[253],
     [254],
     [241],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [254],
     [254],
     [254]]],


   ...,


   [[[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [236],
     [236],
     [236]],

    [[236],
     [236],
     [236],
     ...,
     [188],
     [176],
     [172]],

    ...,

    [[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [220],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]]],


   [[[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [236],
     [236],
     [236]],

    [[236],
     [236],
     [236],
     ...,
     [188],
     [176],
     [172]],

    ...,

    [[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [220],
     [255],
     [255]],

    [[255],
     [255],
     [229],
     ...,
     [255],
     [255],
     [255]]],


   [[[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [236],
     [236],
     [236]],

    [[236],
     [236],
     [236],
     ...,
     [188],
     [176],
     [172]],

    ...,

    [[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]],

    [[255],
     [255],
     [255],
     ...,
     [255],
     [255],
     [255]]]], dtype=uint8)}, {'targets': [[0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 0, 1], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 0, 0], [1, 0, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1], [1, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0], [0, 1, 0], 
    [0, 0, 1], [0, 1, 0], [0, 0, 1], [0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [1, 0, 
    0], [0, 0, 
    1], [0, 0, 1], [0, 1, 0], [0, 1, 0], [0, 1, 0], [1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 
    0, 1], 
    [0, 1, 0]]}), 'snapshot_step': 500, 'run_id': 'pyslither-io-0.001-alexnetv2-8-epochs.model'}

Изменить: получить новую ошибку

Traceback (most recent call last):
File "E:\python64bit\lib\site-packages\keras\engine\training_utils.py", line 
80, in standardize_input_data
for x in names
File "E:\python64bit\lib\site-packages\keras\engine\training_utils.py", line 
80, in <listcomp>
for x in names
KeyError: 'conv2d_1_input'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "E:\screengrab\train_model.py", line 36, in <module>
validation_steps=500)
File "E:\python64bit\lib\site-packages\keras\engine\training.py", line 1154, 
in fit
batch_size=batch_size)
File "E:\python64bit\lib\site-packages\keras\engine\training.py", line 579, 
in _standardize_user_data
exception_prefix='input')
File "E:\python64bit\lib\site-packages\keras\engine\training_utils.py", line 
85, in standardize_input_data
'for each key in: ' + str(names))
ValueError: No data provided for "conv2d_1_input". Need data for each key in: 
['conv2d_1_input']

1 Ответ

0 голосов
/ 08 мая 2020

model.fit() имеет следующие аргументы:

Model.fit(
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose=1,
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    class_weight=None,
    sample_weight=None,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
    max_queue_size=10,
    workers=1,
    use_multiprocessing=False,
)

validation_set среди них нет. Вероятно, вы имели в виду validation_data.

...