Передача обучения на MNIST: ошибка неправильных меток - PullRequest
0 голосов
/ 30 мая 2020

Итак, я обучил перцептрон тензорному потоку на наборе данных MNIST, но только с цифрами от 0 до 4. Затем я создал новую модель со всеми теми же слоями и весами, но с новым выходным слоем также с 5 выходными узлами. Я хочу обучить эту новую модель классифицировать цифры от 5 до 9.

Я создал новые x_train и y_train только с цифрами от 5 до 9 и запустил

transfer_model.fit(x_train[train_filter],y_train[train_filter], epoch=5)

, где train_filter определяется как np.where(np.logical_and(x_train<=5,x_train>=9)).

На самом первом этапе обучения я получаю эту ошибку:

InvalidArgumentError: получено значение метки 9, которое является вне допустимого диапазона [0, 5). Значения ярлыков: 5 9 7 8 9 8 7 6 8 7 6 9 5 5 8 7 6 9 9 7 6 7 6 8 7 7 9 7 6 8 5 6

Это имеет смысл, потому что я изначально обучил сеть классифицировать в диапазоне [0,5), но теперь я хочу сделать диапазон [5,10). Я пропустил здесь шаг? Я не уверен, что мне не хватает ... Как определить, чему соответствует каждый выходной нейрон?

Вот сводка моей модели:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_7 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_49 (Dense)             (None, 100)               78500     
_________________________________________________________________
batch_normalization_10 (Batc (None, 100)               400       
_________________________________________________________________
dropout_5 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_50 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_11 (Batc (None, 100)               400       
_________________________________________________________________
dropout_6 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_51 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_12 (Batc (None, 100)               400       
_________________________________________________________________
dropout_7 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_52 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_13 (Batc (None, 100)               400       
_________________________________________________________________
dropout_8 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_53 (Dense)             (None, 100)               10100     
_________________________________________________________________
batch_normalization_14 (Batc (None, 100)               400       
_________________________________________________________________
dropout_9 (Dropout)          (None, 100)               0         
_________________________________________________________________
dense_55 (Dense)             (None, 5)                 505       
=================================================================
Total params: 121,405
Trainable params: 505
Non-trainable params: 120,900
_________________________________________________________________

Ответы [ 2 ]

1 голос
/ 30 мая 2020

Вам нужно сопоставить 5-9 с 0-4. Метки классов, вероятно, создаются с помощью одного горячего кодирования, у вас есть 5 уникальных меток, поэтому для его представления требуется только вектор длиной 5. Но поскольку метка 5-9, она будет вне диапазона. Вам не нужно настраивать модель, просто добавьте карту к выходным данным надписей.

0 голосов
/ 30 мая 2020

Поскольку вы используете numpy, вы можете попробовать следующее

import tensorflow as tf
import numpy as np

arr = np.array([5,6,7,8,9,8,7,6,5])
arr = tf.one_hot(arr,10,axis=0).numpy()
arr = arr[5:]

tf.argmax(arr).numpy() # returns array([0, 1, 2, 3, 4, 3, 2, 1, 0])

или используя tf.map_fn

arr = np.array([5,6,7,8,9,8,7,6,5])

tf.map_fn(lambda x : x-5, arr).numpy() # array([0, 1, 2, 3, 4, 3, 2, 1, 0])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...