Итак, я обучил перцептрон тензорному потоку на наборе данных MNIST, но только с цифрами от 0 до 4. Затем я создал новую модель со всеми теми же слоями и весами, но с новым выходным слоем также с 5 выходными узлами. Я хочу обучить эту новую модель классифицировать цифры от 5 до 9.
Я создал новые x_train и y_train только с цифрами от 5 до 9 и запустил
transfer_model.fit(x_train[train_filter],y_train[train_filter], epoch=5)
, где train_filter определяется как np.where(np.logical_and(x_train<=5,x_train>=9))
.
На самом первом этапе обучения я получаю эту ошибку:
InvalidArgumentError: получено значение метки 9, которое является вне допустимого диапазона [0, 5). Значения ярлыков: 5 9 7 8 9 8 7 6 8 7 6 9 5 5 8 7 6 9 9 7 6 7 6 8 7 7 9 7 6 8 5 6
Это имеет смысл, потому что я изначально обучил сеть классифицировать в диапазоне [0,5), но теперь я хочу сделать диапазон [5,10). Я пропустил здесь шаг? Я не уверен, что мне не хватает ... Как определить, чему соответствует каждый выходной нейрон?
Вот сводка моей модели:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
flatten_7 (Flatten) (None, 784) 0
_________________________________________________________________
dense_49 (Dense) (None, 100) 78500
_________________________________________________________________
batch_normalization_10 (Batc (None, 100) 400
_________________________________________________________________
dropout_5 (Dropout) (None, 100) 0
_________________________________________________________________
dense_50 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_11 (Batc (None, 100) 400
_________________________________________________________________
dropout_6 (Dropout) (None, 100) 0
_________________________________________________________________
dense_51 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_12 (Batc (None, 100) 400
_________________________________________________________________
dropout_7 (Dropout) (None, 100) 0
_________________________________________________________________
dense_52 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_13 (Batc (None, 100) 400
_________________________________________________________________
dropout_8 (Dropout) (None, 100) 0
_________________________________________________________________
dense_53 (Dense) (None, 100) 10100
_________________________________________________________________
batch_normalization_14 (Batc (None, 100) 400
_________________________________________________________________
dropout_9 (Dropout) (None, 100) 0
_________________________________________________________________
dense_55 (Dense) (None, 5) 505
=================================================================
Total params: 121,405
Trainable params: 505
Non-trainable params: 120,900
_________________________________________________________________