Я пытаюсь предсказать идеальный ход в игре, используя последовательную сеть keras. Сеть довольно проста с (как мне кажется) входной формой (3216,). Код для определения сети прилагается ниже.
self.model = keras.Sequential()
self.model.add(Dense(2000, activation='relu', input_dim=(2 * 7 + 2 + Cfg.tiles_x * Cfg.tiles_y)))
self.model.add(Dense(1000, activation='relu'))
self.model.add(Dense(100, activation='relu'))
self.model.add(Dense(9)) self.model.compile(loss=keras.losses.categorical_crossentropy,
optimizer='adam',
metrics=['accuracy'])
Входные данные - это текущее состояние игры (игровое поле и инвентарь игрока). Он сохраняется в виде списка до тех пор, пока не будет передан в функцию прогнозирования, после чего он преобразуется в массив numpy с помощью np.asarray (list).
Вызов функции:
a1_action = agent1.get_move(np.asarray(a1_input))
Где a1_input - список игровых состояний длиной 3216. Agent1 - это объект класса, хранящего сеть и функцию get_move.
Функция, которая возвращает прогноз:
def get_move(self, info):
# This is where I print from in the next section
prediction = self.model.predict(info)
if max(prediction) > AI_Cfg.actionThreshold:
return Agent.moves[prediction.index(max(prediction))]
else:
return ""
Теперь ошибка I get при передаче информации напрямую находится в строке model.predict ():
ValueError: Ошибка при проверке ввода: ожидается, что density_input будет иметь форму (3216,), но получил массив с формой (1,)
Я также пробовал предложения из здесь , но это преобразует мою форму ввода (3216,) в (1, 3216) и дает ошибку:
ValueError: попытка преобразовать значение (75) неподдерживаемого типа (класс 'numpy .int32') в тензор.
Я также пробовал переустановить numpy.
Если есть дополнительная информация, я буду рад y сделать это. Спасибо!
Полезная печать над строкой прогноза:
info.shape -> (3216,)
info -> [75 35 0 ... 0 0 0]
РЕДАКТИРОВАТЬ
После изменения кода, чтобы попытаться исправить проблему, я получаю только указанную ниже ошибку.
Попытка преобразовать значение (75) с неподдерживаемым типом (класс 'numpy .int32') в тензор.
Интересно, что это та же ошибка Я получаю, когда (пропуская первый набор прогнозов и) пытаюсь обучить модель на данных.