Почему моя модель бинарной классификации не обучается, даже если она не подходит? - PullRequest
0 голосов
/ 14 июля 2020

У меня есть следующая модель, использующая тензорный поток 2.2.0 с keras:

def get_model(input_shape):
  model = keras.Sequential()
  
  model.add(Conv2D(32, input_shape=input_shape, kernel_size=(3, 3)))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Conv2D(64, kernel_size=(3, 3)))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  
  model.add(Conv2D(64, kernel_size=(3, 3)))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Flatten())

  model.add(Dense(64, activation='relu'))
  model.add(Dense(1, activation='sigmoid'))

  return model

Форма ввода (25, 25, 4) - 3-мерное изображение, 25x25 пикселей, с 4 каналами. Модель не учится - даже не переобучится! Я пытаюсь соответствовать, используя следующее заклинание:

model.compile(optimizer='sgd', metrics=['accuracy'], loss='binary_crossentropy')
model.fit(trainX, trainY, validation_split=0.2, epochs=10, batch_size=50)

Я также пытался изменить оптимизатор на sgd с теми же результатами и пробовал разные размеры пакетов (включая 1). Пример обучения для 10 эпох:

Epoch 1/10
763/763 [==============================] - 4s 5ms/step - loss: 0.6935 - accuracy: 0.5045 - val_loss: 0.6937 - val_accuracy: 0.5031
Epoch 2/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6935 - accuracy: 0.5020 - val_loss: 0.6946 - val_accuracy: 0.4972
Epoch 3/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6935 - accuracy: 0.5016 - val_loss: 0.6932 - val_accuracy: 0.4984
Epoch 4/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6934 - accuracy: 0.5020 - val_loss: 0.6932 - val_accuracy: 0.4986
Epoch 5/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5027 - val_loss: 0.6934 - val_accuracy: 0.4972
Epoch 6/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6932 - accuracy: 0.5051 - val_loss: 0.6946 - val_accuracy: 0.5019
Epoch 7/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5017 - val_loss: 0.6932 - val_accuracy: 0.4959
Epoch 8/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5017 - val_loss: 0.6934 - val_accuracy: 0.5056
Epoch 9/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6932 - accuracy: 0.5040 - val_loss: 0.6931 - val_accuracy: 0.5009
Epoch 10/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5018 - val_loss: 0.6931 - val_accuracy: 0.5020
<tensorflow.python.keras.callbacks.History at 0x7f761a0856d8>

Как бы то ни было, данные почти наверняка не проблема - я пробовал другие методы машинного обучения, такие как случайные леса и повышение градиента, и они могут overfit просто отлично.

Я не упускаю здесь чего-то принципиального?

Изменить: установка активации сверточных слоев на relu не помогает. Приведенный ниже вывод имеет relu:


Epoch 1/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6936 - accuracy: 0.4990 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 2/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6933 - accuracy: 0.5026 - val_loss: 0.6931 - val_accuracy: 0.5043
Epoch 3/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6933 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 4/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5004 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 5/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.4992 - val_loss: 0.6932 - val_accuracy: 0.5029
Epoch 6/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5031 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 7/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 8/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5001 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 9/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5029 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 10/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5012 - val_loss: 0.6931 - val_accuracy: 0.5029
<tensorflow.python.keras.callbacks.History at 0x7f29766804a8>

Я также попытался изменить метки на категориальные и использовать categorical_crossentropy, но безрезультатно.

Изменить 2: такое же поведение сохраняется за несколько эпох, с правильно установленной активацией.

Модель:

def get_model(input_shape):
  model = keras.Sequential()
  
  model.add(Conv2D(32, input_shape=input_shape, kernel_size=(3, 3), activation='relu'))
  model.add(MaxPooling2D(pool_size=(2, 2)))

  model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
  model.add(MaxPooling2D(pool_size=(2, 2)))
  
  model.add(Flatten())

  model.add(Dense(64, activation='relu'))
  model.add(Dense(1, activation='sigmoid'))

  return model

Вывод:

Epoch 1/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6937 - accuracy: 0.4998 - val_loss: 0.6931 - val_accuracy: 0.5008
...
Epoch 243/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 244/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5007 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 245/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5014 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 246/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5035 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 247/250
1907/1907 [==============================] - 7s 4ms/step - loss: 0.6932 - accuracy: 0.5031 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 248/250
1907/1907 [==============================] - 7s 4ms/step - loss: 0.6932 - accuracy: 0.5026 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 249/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5018 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 250/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5007 - val_loss: 0.6931 - val_accuracy: 0.5029

Пример данных:

display(trainX[0])
display(trainX[0].shape)
---
array([[[-0.81307793, -0.80876915, -0.80270227, -0.81340067],
        [-0.81323822, -0.80901267, -0.80424022, -0.81004681],
        [-0.80974839, -0.80952621, -0.80894936, -0.81924987],
        [-0.81901061, -0.81892894, -0.8198063 , -0.82950191],
        [-0.82926863, -0.82535357, -0.81962295, -0.82940024],
        [-0.82911602, -0.82669005, -0.81815252, -0.82725751],
        [-0.82717653, -0.82594539, -0.81691338, -0.82605227],
        [-0.82584266, -0.82452835, -0.81556359, -0.82556375],
        [-0.82525089, -0.82266387, -0.8177839 , -0.82243512],
        [-0.82222369, -0.82112803, -0.82649334, -0.83150323]],

       [[-0.81323822, -0.80901267, -0.80424022, -0.81004681],
        [-0.81339844, -0.80925606, -0.80577279, -0.80666623],
        [-0.80990994, -0.8097693 , -0.81046532, -0.81594339],
        [-0.81916858, -0.81916656, -0.82128286, -0.82628101],
        [-0.8294225 , -0.82558735, -0.82110019, -0.82617847],
        [-0.82926995, -0.82692302, -0.81963519, -0.82401759],
        [-0.82733125, -0.82617881, -0.8184006 , -0.82280219],
        [-0.82599791, -0.82476263, -0.81705573, -0.82230957],
        [-0.82540639, -0.82289927, -0.81926792, -0.81915487],
        [-0.8223804 , -0.82136435, -0.82794484, -0.82829943]],

       [[-0.80974839, -0.80952621, -0.80894936, -0.81924987],
        [-0.80990994, -0.8097693 , -0.81046532, -0.81594339],
        [-0.80639256, -0.81028192, -0.81510641, -0.82501505],
        [-0.81572868, -0.81966765, -0.82580199, -0.83511534],
        [-0.82607158, -0.82608032, -0.82562142, -0.8350152 ],
        [-0.82591768, -0.82741428, -0.8241732 , -0.83290467],
        [-0.8239619 , -0.82667103, -0.82295269, -0.83171742],
        [-0.82261689, -0.82525665, -0.82162309, -0.83123616],
        [-0.82202021, -0.82339567, -0.82381013, -0.82815378],
        [-0.818968  , -0.82186268, -0.83238638, -0.83708633]],

       [[-0.81901061, -0.81892894, -0.8198063 , -0.82950191],
        [-0.81916858, -0.81916656, -0.82128286, -0.82628101],
        [-0.81572868, -0.81966765, -0.82580199, -0.83511534],
        [-0.82485699, -0.82883834, -0.8362085 , -0.84494163],
        [-0.83496124, -0.83509975, -0.83603291, -0.84484426],
        [-0.83481096, -0.83640177, -0.83462448, -0.84279185],
        [-0.83290099, -0.83567633, -0.83343734, -0.84163708],
        [-0.83158729, -0.83429571, -0.83214391, -0.84116895],
        [-0.83100444, -0.83247886, -0.83427135, -0.83817006],
        [-0.82802254, -0.830982  , -0.84260906, -0.84685789]],

       [[-0.82926863, -0.82535357, -0.81962295, -0.82940024],
        [-0.8294225 , -0.82558735, -0.82110019, -0.82617847],
        [-0.82607158, -0.82608032, -0.82562142, -0.8350152 ],
        [-0.83496124, -0.83509975, -0.83603291, -0.84484426],
        [-0.84479157, -0.84125479, -0.83585723, -0.84474687],
        [-0.84464544, -0.84253437, -0.83444812, -0.84269387],
        [-0.84278804, -0.84182145, -0.8332604 , -0.84153877],
        [-0.84151027, -0.84046456, -0.83196635, -0.84107051],
        [-0.84094331, -0.83867874, -0.83409482, -0.83807077],
        [-0.83804212, -0.83720728, -0.84243663, -0.84676108]],

       [[-0.82911602, -0.82669005, -0.81815252, -0.82725751],
        [-0.82926995, -0.82692302, -0.81963519, -0.82401759],
        [-0.82591768, -0.82741428, -0.8241732 , -0.83290467],
        [-0.83481096, -0.83640177, -0.83462448, -0.84279185],
        [-0.84464544, -0.84253437, -0.83444812, -0.84269387],
        [-0.84449925, -0.84380921, -0.83303354, -0.84062854],
        [-0.84264105, -0.84309893, -0.83184123, -0.83946655],
        [-0.84136274, -0.84174705, -0.8305422 , -0.83899551],
        [-0.84079554, -0.83996778, -0.83267886, -0.83597806],
        [-0.83789312, -0.83850169, -0.84105351, -0.84472027]],

       [[-0.82717653, -0.82594539, -0.81691338, -0.82605227],
        [-0.82733125, -0.82617881, -0.8184006 , -0.82280219],
        [-0.8239619 , -0.82667103, -0.82295269, -0.83171742],
        [-0.83290099, -0.83567633, -0.83343734, -0.84163708],
        [-0.84278804, -0.84182145, -0.8332604 , -0.84153877],
        [-0.84264105, -0.84309893, -0.83184123, -0.83946655],
        [-0.84077276, -0.84238718, -0.83064506, -0.83830071],
        [-0.83948755, -0.84103251, -0.82934186, -0.83782811],
        [-0.8389173 , -0.83924958, -0.83148541, -0.83480076],
        [-0.8359994 , -0.8377805 , -0.83988758, -0.84357199]],

       [[-0.82584266, -0.82452835, -0.81556359, -0.82556375],
        [-0.82599791, -0.82476263, -0.81705573, -0.82230957],
        [-0.82261689, -0.82525665, -0.82162309, -0.83123616],
        [-0.83158729, -0.83429571, -0.83214391, -0.84116895],
        [-0.84151027, -0.84046456, -0.83196635, -0.84107051],
        [-0.84136274, -0.84174705, -0.8305422 , -0.83899551],
        [-0.83948755, -0.84103251, -0.82934186, -0.83782811],
        [-0.83819763, -0.83967254, -0.82803413, -0.83735488],
        [-0.8376253 , -0.83788269, -0.83018513, -0.83432354],
        [-0.83469681, -0.83640794, -0.83861716, -0.84310649]],

       [[-0.82525089, -0.82266387, -0.8177839 , -0.82243512],
        [-0.82540639, -0.82289927, -0.81926792, -0.81915487],
        [-0.82202021, -0.82339567, -0.82381013, -0.82815378],
        [-0.83100444, -0.83247886, -0.83427135, -0.83817006],
        [-0.84094331, -0.83867874, -0.83409482, -0.83807077],
        [-0.84079554, -0.83996778, -0.83267886, -0.83597806],
        [-0.8389173 , -0.83924958, -0.83148541, -0.83480076],
        [-0.8376253 , -0.83788269, -0.83018513, -0.83432354],
        [-0.83705204, -0.83608379, -0.83232385, -0.83126675],
        [-0.83411887, -0.83460162, -0.8407067 , -0.84012427]],

       [[-0.82222369, -0.82112803, -0.82649334, -0.83150323],
        [-0.8223804 , -0.82136435, -0.82794484, -0.82829943],
        [-0.818968  , -0.82186268, -0.83238638, -0.83708633],
        [-0.82802254, -0.830982  , -0.84260906, -0.84685789],
        [-0.83804212, -0.83720728, -0.84243663, -0.84676108],
        [-0.83789312, -0.83850169, -0.84105351, -0.84472027],
        [-0.8359994 , -0.8377805 , -0.83988758, -0.84357199],
        [-0.83469681, -0.83640794, -0.83861716, -0.84310649],
        [-0.83411887, -0.83460162, -0.8407067 , -0.84012427],
        [-0.83116192, -0.83311339, -0.84889275, -0.84876322]]])
(10, 10, 4)

display(trainY[0:5])
display(trainY.shape)
---
array([0, 1, 0, 1, 0], dtype=int64)
(47666,)

Ответы [ 2 ]

3 голосов
/ 15 июля 2020

Модель не обучается, потому что сверточные слои имеют линейную функцию активации, которая по умолчанию равна None, если вы ее не укажете. Обычно функция активации, используемая со слоями conv, равна Relu, поэтому просто добавьте activation='relu' к вашим сверточным слоям

0 голосов
/ 15 июля 2020

Суть нейронных сетей в том, чтобы вызывать нелинейности. Вы не упомянули какие-либо функции активации в первых трех слоях свертки.

См. Здесь различные функции активации

tf.keras.layers.Conv2D(64, 3, activation = 'relu')
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...