Ошибка при проверке ввода: ожидалось, что dens_203_input будет иметь форму (1202,), но получил массив с формой (1,) - PullRequest
1 голос
/ 19 марта 2020

Я сделал очень простую нейронную сеть, которая предназначена для обучения с подкреплением. Тем не менее, я не могу ничего предсказать, так как при попытке предсказать получаю ошибку.

Ошибка в вопросе:

Ошибка при проверке входных данных: ожидается, что dens_203_input будет иметь форму (1202,), но получил массив с формой (1,)

Модель в вопросах:

 def _build_compile_model(self):
    model = Sequential()
    model.add(Dense(300, activation='relu', input_dim=1202))
    model.add(Dense(300, activation='relu'))
    model.add(Dense(200, activation='relu'))
    model.add(Dense(self._action_size, activation='softmax'))

    model.compile(loss='mse', optimizer=self._optimizer)
    return model

ошибка возникает при вызове model.predict (state), где state - это массив формы (1202, 1).

Полное сообщение об ошибке:

ValueError                                Traceback (most recent call last)
<ipython-input-148-06b7a01facef> in <module>
     18     new_state, reward = env.step(action, new_demand_a, new_demand_b) # Take action, get new state and reward
     19     new_state = np.reshape(new_state, [1202, -1])
---> 20     agent.update(old_state, new_state, action, reward) # Let the agent update internal
     21     average_reward.append(reward) # Keep score
     22     if i % 100 == 0 and i != 0: # Print out metadata every 100th iteration

<ipython-input-145-142ae54ce43f> in update(self, old_state, new_state, action, reward)
     49     def update(self, old_state, new_state, action, reward):
     50         print(old_state.shape)
---> 51         target = self.q_network.predict(old_state)
     52         t = self.target_network.predict(new_state)
     53         target[0][action] = reward + self.gamma * np.amax(t)

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
   1011         max_queue_size=max_queue_size,
   1012         workers=workers,
-> 1013         use_multiprocessing=use_multiprocessing)
   1014 
   1015   def reset_metrics(self):

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in predict(self, model, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing, **kwargs)
    496         model, ModeKeys.PREDICT, x=x, batch_size=batch_size, verbose=verbose,
    497         steps=steps, callbacks=callbacks, max_queue_size=max_queue_size,
--> 498         workers=workers, use_multiprocessing=use_multiprocessing, **kwargs)
    499 
    500 

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _model_iteration(self, model, mode, x, y, batch_size, verbose, sample_weight, steps, callbacks, max_queue_size, workers, use_multiprocessing, **kwargs)
    424           max_queue_size=max_queue_size,
    425           workers=workers,
--> 426           use_multiprocessing=use_multiprocessing)
    427       total_samples = _get_total_number_of_samples(adapter)
    428       use_sample = total_samples is not None

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_v2.py in _process_inputs(model, mode, x, y, batch_size, epochs, sample_weights, class_weights, shuffle, steps, distribution_strategy, max_queue_size, workers, use_multiprocessing)
    644     standardize_function = None
    645     x, y, sample_weights = standardize(
--> 646         x, y, sample_weight=sample_weights)
    647   elif adapter_cls is data_adapter.ListsOfScalarsDataAdapter:
    648     standardize_function = standardize

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split, shuffle, extract_tensors_from_dataset)
   2381         is_dataset=is_dataset,
   2382         class_weight=class_weight,
-> 2383         batch_size=batch_size)
   2384 
   2385   def _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs,

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in _standardize_tensors(self, x, y, sample_weight, run_eagerly, dict_inputs, is_dataset, class_weight, batch_size)
   2408           feed_input_shapes,
   2409           check_batch_axis=False,  # Don't enforce the batch size.
-> 2410           exception_prefix='input')
   2411 
   2412     # Get typespecs for the input data and sanitize it if necessary.

/opt/conda/envs/tensorflow2/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
    580                              ': expected ' + names[i] + ' to have shape ' +
    581                              str(shape) + ' but got array with shape ' +
--> 582                              str(data_shape))
    583   return data
    584 

ValueError: Error when checking input: expected dense_211_input to have shape (1202,) but got array with shape (1,)

1 Ответ

1 голос
/ 23 марта 2020

Существует два подхода при подаче входных данных в вашей модели:

1-й вариант: использование input_shape

model.add(Dense(300, activation='relu', input_shape=(1202,1)))

Здесь форма ввода в 2D , но вы должны подать в вашу сеть 3D вход ( Rank 3 ), поскольку вам необходимо включить batch_size .

Пример ввода:

state = np.array(np.ones((BATCH_SIZE,1202,1)))
print("Input Rank: {}".format(tf.rank(state))) # Check for the Rank of Input

2-й вариант: Использование input_dim

model_dim.add(Dense(300, activation='relu', input_dim=1202))

Здесь форма ввода находится в 1D , но вы должны подать в вашу сеть 2D вход ( Rank 2 ), так как вам нужно включить batch_size .

Пример ввода:

state = np.array(np.ones((1,1202,)))
print("Input Rank: {}".format(tf.rank(state))) # Check for the Rank of Input
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...