Я использую Keras для обучения глубокой нейронной сети.Я использую функцию train_on_batch для обучения моей модели.У моей модели два выхода.Что я намерен сделать, это изменить потери для каждого из образцов на определенное значение для каждого образца.Так что из-за документации Keras здесь
мне нужно присвоить два разных веса для аргумента sample_weight .Вот как выглядит мой код, где в каждом пакете у меня есть четыре обучающих примера:
wights=[12,10,31,1];
mod_loss = mymodel.train_on_batch([X_train], [Y1, Y2],sample_weight=[wights,[1.0,1.0,1.0,1.0]])
Я использую sample_weight для взвешивания только первого вывода, а не второго вывода.когда я запускаю код, я получаю эту ошибку:
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training.py", line 1211, in train_on_batch
class_weight=class_weight)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training.py", line 801, in _standardize_user_data
feed_sample_weight_modes)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training.py", line 799, in <listcomp>
for (ref, sw, cw, mode) in
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training_utils.py", line 470, in standardize_weights
if sample_weight is not None and len(sample_weight.shape) != 1:
AttributeError: 'list' object has no attribute 'shape'
Это дало мне идею, если я изменю присвоенное значение на sample_weight на массив numpy, проблема будет решена.Поэтому я изменил код на этот:
wights=[12,10,31,1];
mod_loss = mymodel.train_on_batch([X_train], [Y1, Y2],sample_weight=numpy.array([wights,[1.0,1.0,1.0,1.0]]))
И у меня появилась эта ошибка:
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training.py", line 1211, in train_on_batch
class_weight=class_weight)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training.py", line 794, in _standardize_user_data
sample_weight, feed_output_names)
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training_utils.py", line 200, in standardize_sample_weights
'sample_weight')
File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training_utils.py", line 188, in standardize_sample_or_class_weights
str(x_weight))
TypeError: The model has multiple outputs, so `sample_weight` should be either a list or a dict. Provided `sample_weight` type not understood: [[12.0 10.0 31.0 1.0]
[ 1. 1. 1. 1. ]]
Я был немного смущен, я не уверен, что это ошибка внутриРеализация кераса или нет.Я едва мог найти любую работу или проблему, связанную с этим, в сети.Есть мысли?