Итак, я делаю агентное моделирование, где у каждого агента есть сеть MLP, но у меня есть агент ограничения, который может иметь только один пример за раз, следовательно, агент может обучить только один пример
, поэтому, когда я пыталсятренировка с использованием функции соответствия keras Я получил цепочку ошибок
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(128,activation=tf.nn.tanh))
model.add(tf.keras.layers.Dense(128,activation=tf.nn.tanh))
model.add(tf.keras.layers.Dense(1,activation=tf.nn.tanh))
model.compile(optimizer='SGD',
loss='mean_squared_error',validation_split=0)
model.fit(x_train[0][np.newaxis,:,:],np.array([y_train[0]]),epochs=3,batch_size=1)
Я просто хотел тренировать один пример за раз, но получил следующую ошибку
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-81-688643bc66bc> in <module>()
----> 1 model.fit(x_train[0][np.newaxis,:,:],np.array([y_train[0]]),epochs=3,batch_size=1)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1261 steps_name='steps_per_epoch',
1262 steps=steps_per_epoch,
-> 1263 validation_split=validation_split)
1264
1265 # Prepare validation data.
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training.py in _standardize_user_data(self, x, y, sample_weight, class_weight, batch_size, check_steps, steps_name, steps, validation_split)
905 feed_output_shapes,
906 check_batch_axis=False, # Don't enforce the batch size.
--> 907 exception_prefix='target')
908
909 # Generate sample-wise weight values given the `sample_weight` and
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/training_utils.py in standardize_input_data(data, names, shapes, check_batch_axis, exception_prefix)
189 'Error when checking ' + exception_prefix + ': expected ' +
190 names[i] + ' to have shape ' + str(shape) +
--> 191 ' but got array with shape ' + str(data_shape))
192 return data
193
ValueError: Error when checking target: expected dense_5 to have shape (10,) but got array with shape (1,)