Я сделал очень простую нейронную сеть, предназначенную для обучения с подкреплением. Однако получить предсказание не удаётся: при вызове predict возникает ошибка.
Сама ошибка:
Ошибка при проверке входных данных: ожидается, что dense_203_input будет иметь форму (1202,), но получен массив с формой (1,)
Рассматриваемая модель:
def _build_compile_model(self):
    """Build and compile the fully connected Q-network.

    The network takes a flat state vector of 1202 features and produces one
    output per action (``self._action_size`` units).

    Returns:
        The compiled ``Sequential`` model, using MSE loss and
        ``self._optimizer``.

    Note:
        Keras treats the first axis of the input as the batch axis, so a
        single state must be passed to ``predict`` with shape ``(1, 1202)``.
        An array of shape ``(1202, 1)`` is read as 1202 samples of one
        feature each and is rejected with
        "expected ... to have shape (1202,) but got array with shape (1,)".
    """
    model = Sequential()
    model.add(Dense(300, activation='relu', input_dim=1202))
    model.add(Dense(300, activation='relu'))
    model.add(Dense(200, activation='relu'))
    # Linear output instead of softmax: the model is fitted with MSE against
    # Q-learning targets of the form ``reward + gamma * max(Q)``, which are
    # unbounded. Softmax would squash the outputs into [0, 1] summing to 1,
    # so they could never match such targets.
    model.add(Dense(self._action_size, activation='linear'))
    model.compile(loss='mse', optimizer=self._optimizer)
    return model
Ошибка возникает при вызове model.predict(state), где state — это массив формы (1202, 1).
Полное сообщение об ошибке:
ValueError Traceback (most recent call last)
<ipython-input-148-06b7a01facef> in <module>
18 new_state, reward = env.step(action, new_demand_a, new_demand_b) # Take action, get new state and reward
19 new_state = np.reshape(new_state, [1202, -1])
---> 20 agent.update(old_state, new_state, action, reward) # Let the agent update internal
21 average_reward.append(reward) # Keep score
22 if i % 100 == 0 and i != 0: # Print out metadata every 100th iteration
<ipython-input-145-142ae54ce43f> in update(self, old_state, new_state, action, reward)
49 def update(self, old_state, new_state, action, reward):
50 print(old_state.shape)
---> 51 target = self.q_network.predict(old_state)
52 t = self.target_network.predict(new_state)
53 target[0][action] = reward + self.gamma * np.amax(t)
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
1011 max_queue_size=max_queue_size,
1012 workers=workers,
-> 1013 use_multiprocessing=use_multiprocessing)
1014
1015 def reset_metrics(self):
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in predict(self, model, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing, **kwargs)
496 model, ModeKeys.PREDICT, x=x, batch_size=batch_size, verbose=verbose,
497 steps=steps, callbacks=callbacks, max_queue_size=max_queue_size,
--> 498 workers=workers, use_multiprocessing=use_multiprocessing, **kwargs)
499
500
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _model_iteration(self, model, mode, x, y, batch_size, verbose, sample_weight, steps, callbacks, max_queue_size, workers, use_multiprocessing, **kwargs)
424 max_queue_size=max_queue_size,
425 workers=workers,
--> 426 use_multiprocessing=use_multiprocessing)
427 total_samples = _get_total_number_of_samples(adapter)
428 use_sample = total_samples is not None
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_inputs(model, mode, x, y, batch_size, epochs, sample_weights, class_weights, shuffle, steps, distribution_strategy, max_queue_size, workers, use_multiprocessing)
644 standardize_function = None
645 x, y, sample_weights = standardize(
--> 646 x, y, sample_weight=sample_weights)
647 elif adapter_cls is data_adapter.ListsOfScalarsDataAdapter:
648 standardize_function = standardize
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
2381 is_dataset=is_dataset,
2382 class_weight=class_weight,
-> 2383 batch_size=batch_size)
2384
2385 def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, is_dataset, class_weight, batch_size)
2408 feed_input_shapes,
2409 check_batch_axis=False, # Don't enforce the batch size.
-> 2410 exception_prefix='input')
2411
2412 # Get typespecs for the input data and sanitize it if necessary.
/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
580 ': expected ' + names[i] + ' to have shape ' +
581 str(shape) + ' but got array with shape ' +
--> 582 str(data_shape))
583 return data
584
ValueError: Error when checking input: expected dense_211_input to have shape (1202,) but got array with shape (1,)