I have the following model, using TensorFlow 2.2.0 with Keras:
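from tensorflow import keras
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

def get_model(input_shape):
    model = keras.Sequential()
    # three Conv2D / MaxPooling2D blocks (no activation set on the conv layers here)
    model.add(Conv2D(32, input_shape=input_shape, kernel_size=(3, 3)))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, kernel_size=(3, 3)))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, kernel_size=(3, 3)))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    # classifier head: one hidden layer, one sigmoid unit for binary output
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    return model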
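To see what reaches Flatten, you can trace the spatial size by hand through each Conv2D (3x3, stride 1, Keras default padding='valid') and MaxPooling2D (2x2, stride defaulting to the pool size, padding='valid'). The helper names below are hypothetical and exist only for this sketch. The sample printed at the end of the question has shape (10, 10, 4) rather than the stated (25, 25, 4). With a 10x10 input, this three-block stack runs out of spatial size before the third convolution.

```python
def conv_out(n, kernel=3, stride=1):
    """Spatial size after a Conv2D with padding='valid' (the Keras default)."""
    out = (n - kernel) // stride + 1
    if out < 1:
        raise ValueError(f"size {n} is too small for a {kernel}x{kernel} valid convolution")
    return out

def pool_out(n, pool=2):
    """Spatial size after MaxPooling2D with padding='valid' and stride equal to the pool size."""
    out = (n - pool) // pool + 1
    if out < 1:
        raise ValueError(f"size {n} is too small for a {pool}x{pool} pooling window")
    return out

def trace(size, n_blocks):
    """Sizes after each Conv2D(3x3) / MaxPooling2D(2x2) pair, as in get_model."""
    sizes = []
    for _ in range(n_blocks):
        size = conv_out(size)
        sizes.append(size)
        size = pool_out(size)
        sizes.append(size)
    return sizes

print(trace(25, 3))  # [23, 11, 9, 4, 2, 1] -> Flatten gives 1 * 1 * 64 = 64 features
try:
    trace(10, 3)
except ValueError as err:
    print("10x10 input:", err)
```

For a (25, 25, 4) input, the three-block model therefore hands only 64 features to the Dense head.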
The input shape is (25, 25, 4): a 3-dimensional input, i.e. a 25x25-pixel image with 4 channels. The model does not learn at all; it will not even overfit! I am trying to fit it with the following incantation:
model.compile(optimizer='sgd', metrics=['accuracy'], loss='binary_crossentropy')
model.fit(trainX, trainY, validation_split=0.2, epochs=10, batch_size=50)
I also tried changing the optimizer to sgd, with the same results, and I tried different batch sizes (including 1). Example training run for 10 epochs:
Epoch 1/10
763/763 [==============================] - 4s 5ms/step - loss: 0.6935 - accuracy: 0.5045 - val_loss: 0.6937 - val_accuracy: 0.5031
Epoch 2/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6935 - accuracy: 0.5020 - val_loss: 0.6946 - val_accuracy: 0.4972
Epoch 3/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6935 - accuracy: 0.5016 - val_loss: 0.6932 - val_accuracy: 0.4984
Epoch 4/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6934 - accuracy: 0.5020 - val_loss: 0.6932 - val_accuracy: 0.4986
Epoch 5/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5027 - val_loss: 0.6934 - val_accuracy: 0.4972
Epoch 6/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6932 - accuracy: 0.5051 - val_loss: 0.6946 - val_accuracy: 0.5019
Epoch 7/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5017 - val_loss: 0.6932 - val_accuracy: 0.4959
Epoch 8/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5017 - val_loss: 0.6934 - val_accuracy: 0.5056
Epoch 9/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6932 - accuracy: 0.5040 - val_loss: 0.6931 - val_accuracy: 0.5009
Epoch 10/10
763/763 [==============================] - 3s 4ms/step - loss: 0.6933 - accuracy: 0.5018 - val_loss: 0.6931 - val_accuracy: 0.5020
<tensorflow.python.keras.callbacks.History at 0x7f761a0856d8>
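A loss of about 0.6931 in every epoch is exactly the binary cross-entropy you get when the model predicts 0.5 for every sample: -ln(0.5) = ln 2 ≈ 0.6931. Combined with accuracy near 0.5, this suggests the network has settled on a near-constant output. Below is a minimal NumPy check of that arithmetic. It uses the textbook formula (not Keras's own implementation) and the first five labels shown at the end of the question.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped away from 0 and 1."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred)))

labels = np.array([0, 1, 0, 1, 0])        # trainY[0:5] from the question
constant = np.full(labels.shape, 0.5)     # a model that always outputs 0.5
print(round(binary_crossentropy(labels, constant), 4))  # 0.6931
```

Every term contributes -ln(0.5) whatever the label, so the mean is ln 2 for any label mix.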
For what it's worth, the data is almost certainly not the problem: I have tried other machine-learning methods, such as random forests and gradient boosting, and they overfit just fine.
Am I missing something fundamental here?
Edit: setting the activation of the convolutional layers to relu does not help. The output below is with relu:
Epoch 1/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6936 - accuracy: 0.4990 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 2/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6933 - accuracy: 0.5026 - val_loss: 0.6931 - val_accuracy: 0.5043
Epoch 3/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6933 - accuracy: 0.5000 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 4/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5004 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 5/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.4992 - val_loss: 0.6932 - val_accuracy: 0.5029
Epoch 6/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5031 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 7/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 8/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5001 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 9/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5029 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 10/10
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5012 - val_loss: 0.6931 - val_accuracy: 0.5029
<tensorflow.python.keras.callbacks.History at 0x7f29766804a8>
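The step counts in both logs are consistent with the 47,666 samples reported for trainY and validation_split=0.2. Keras holds out the last fraction of the arrays for validation. The sketch below assumes the training part is the first floor(N * 0.8) samples; that rounding rule is an assumption, not taken from the question.

```python
import math

def steps_per_epoch(n_samples, validation_split, batch_size):
    # Assumption: training uses the first floor(n * (1 - split)) samples
    n_train = int(math.floor(n_samples * (1.0 - validation_split)))
    # the last batch may be partial, hence the ceiling
    return math.ceil(n_train / batch_size)

print(steps_per_epoch(47666, 0.2, 50))  # 763, as in the first log
print(steps_per_epoch(47666, 0.2, 20))  # 1907, as in the relu logs
```

The question does not say which batch size was used for the relu runs. That 1907 steps corresponds to batch_size=20 is only an inference from this arithmetic.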
I also tried converting the labels to categorical and using categorical_crossentropy, but to no avail.
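For the categorical variant, categorical_crossentropy expects one-hot targets of shape (N, 2). It is usually paired with a final Dense(2, activation='softmax') layer instead of Dense(1, activation='sigmoid'). keras.utils.to_categorical(trainY, num_classes=2) builds such an array. The NumPy sketch below does the same conversion without TensorFlow.

```python
import numpy as np

y = np.array([0, 1, 0, 1, 0])                 # trainY[0:5] from the question
y_onehot = np.eye(2, dtype=np.float32)[y]     # row i is the one-hot vector for class y[i]
print(y_onehot.shape)  # (5, 2)
print(y_onehot)
```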
Edit 2: the same behavior persists over many epochs (250 below), with the activations set correctly.
Model:
def get_model(input_shape):
    model = keras.Sequential()
    model.add(Conv2D(32, input_shape=input_shape, kernel_size=(3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(64, kernel_size=(3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(64, activation='relu'))
    model.add(Dense(1, activation='sigmoid'))
    return model
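As a rough capacity check, you can count the trainable parameters of this Edit 2 model by hand, assuming the (25, 25, 4) input stated at the top. The spatial size goes 25 -> 23 -> 11 -> 9 -> 4, so Flatten yields 4 * 4 * 64 features. With the (10, 10, 4) sample shape shown below, Flatten would yield only 64 and the first Dense layer would shrink accordingly. The helper names are hypothetical.

```python
def conv2d_params(kernel, in_channels, filters):
    # one kernel x kernel x in_channels weight tensor per filter, plus one bias per filter
    return (kernel * kernel * in_channels + 1) * filters

def dense_params(inputs, units):
    # full weight matrix plus one bias per unit
    return (inputs + 1) * units

# Pooling and Flatten layers have no trainable parameters
layers = {
    "conv1": conv2d_params(3, 4, 32),
    "conv2": conv2d_params(3, 32, 64),
    "dense64": dense_params(4 * 4 * 64, 64),
    "dense1": dense_params(64, 1),
}
print(layers)
print(sum(layers.values()))  # 85345
```

That is roughly 85k trainable parameters against roughly 38k training samples after the 20% validation split.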
Output:
Epoch 1/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6937 - accuracy: 0.4998 - val_loss: 0.6931 - val_accuracy: 0.5008
...
Epoch 243/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5006 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 244/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5007 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 245/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5014 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 246/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5035 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 247/250
1907/1907 [==============================] - 7s 4ms/step - loss: 0.6932 - accuracy: 0.5031 - val_loss: 0.6931 - val_accuracy: 0.5029
Epoch 248/250
1907/1907 [==============================] - 7s 4ms/step - loss: 0.6932 - accuracy: 0.5026 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 249/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5018 - val_loss: 0.6932 - val_accuracy: 0.4971
Epoch 250/250
1907/1907 [==============================] - 6s 3ms/step - loss: 0.6932 - accuracy: 0.5007 - val_loss: 0.6931 - val_accuracy: 0.5029
Sample data:
display(trainX[0])
display(trainX[0].shape)
---
array([[[-0.81307793, -0.80876915, -0.80270227, -0.81340067],
[-0.81323822, -0.80901267, -0.80424022, -0.81004681],
[-0.80974839, -0.80952621, -0.80894936, -0.81924987],
[-0.81901061, -0.81892894, -0.8198063 , -0.82950191],
[-0.82926863, -0.82535357, -0.81962295, -0.82940024],
[-0.82911602, -0.82669005, -0.81815252, -0.82725751],
[-0.82717653, -0.82594539, -0.81691338, -0.82605227],
[-0.82584266, -0.82452835, -0.81556359, -0.82556375],
[-0.82525089, -0.82266387, -0.8177839 , -0.82243512],
[-0.82222369, -0.82112803, -0.82649334, -0.83150323]],
[[-0.81323822, -0.80901267, -0.80424022, -0.81004681],
[-0.81339844, -0.80925606, -0.80577279, -0.80666623],
[-0.80990994, -0.8097693 , -0.81046532, -0.81594339],
[-0.81916858, -0.81916656, -0.82128286, -0.82628101],
[-0.8294225 , -0.82558735, -0.82110019, -0.82617847],
[-0.82926995, -0.82692302, -0.81963519, -0.82401759],
[-0.82733125, -0.82617881, -0.8184006 , -0.82280219],
[-0.82599791, -0.82476263, -0.81705573, -0.82230957],
[-0.82540639, -0.82289927, -0.81926792, -0.81915487],
[-0.8223804 , -0.82136435, -0.82794484, -0.82829943]],
[[-0.80974839, -0.80952621, -0.80894936, -0.81924987],
[-0.80990994, -0.8097693 , -0.81046532, -0.81594339],
[-0.80639256, -0.81028192, -0.81510641, -0.82501505],
[-0.81572868, -0.81966765, -0.82580199, -0.83511534],
[-0.82607158, -0.82608032, -0.82562142, -0.8350152 ],
[-0.82591768, -0.82741428, -0.8241732 , -0.83290467],
[-0.8239619 , -0.82667103, -0.82295269, -0.83171742],
[-0.82261689, -0.82525665, -0.82162309, -0.83123616],
[-0.82202021, -0.82339567, -0.82381013, -0.82815378],
[-0.818968 , -0.82186268, -0.83238638, -0.83708633]],
[[-0.81901061, -0.81892894, -0.8198063 , -0.82950191],
[-0.81916858, -0.81916656, -0.82128286, -0.82628101],
[-0.81572868, -0.81966765, -0.82580199, -0.83511534],
[-0.82485699, -0.82883834, -0.8362085 , -0.84494163],
[-0.83496124, -0.83509975, -0.83603291, -0.84484426],
[-0.83481096, -0.83640177, -0.83462448, -0.84279185],
[-0.83290099, -0.83567633, -0.83343734, -0.84163708],
[-0.83158729, -0.83429571, -0.83214391, -0.84116895],
[-0.83100444, -0.83247886, -0.83427135, -0.83817006],
[-0.82802254, -0.830982 , -0.84260906, -0.84685789]],
[[-0.82926863, -0.82535357, -0.81962295, -0.82940024],
[-0.8294225 , -0.82558735, -0.82110019, -0.82617847],
[-0.82607158, -0.82608032, -0.82562142, -0.8350152 ],
[-0.83496124, -0.83509975, -0.83603291, -0.84484426],
[-0.84479157, -0.84125479, -0.83585723, -0.84474687],
[-0.84464544, -0.84253437, -0.83444812, -0.84269387],
[-0.84278804, -0.84182145, -0.8332604 , -0.84153877],
[-0.84151027, -0.84046456, -0.83196635, -0.84107051],
[-0.84094331, -0.83867874, -0.83409482, -0.83807077],
[-0.83804212, -0.83720728, -0.84243663, -0.84676108]],
[[-0.82911602, -0.82669005, -0.81815252, -0.82725751],
[-0.82926995, -0.82692302, -0.81963519, -0.82401759],
[-0.82591768, -0.82741428, -0.8241732 , -0.83290467],
[-0.83481096, -0.83640177, -0.83462448, -0.84279185],
[-0.84464544, -0.84253437, -0.83444812, -0.84269387],
[-0.84449925, -0.84380921, -0.83303354, -0.84062854],
[-0.84264105, -0.84309893, -0.83184123, -0.83946655],
[-0.84136274, -0.84174705, -0.8305422 , -0.83899551],
[-0.84079554, -0.83996778, -0.83267886, -0.83597806],
[-0.83789312, -0.83850169, -0.84105351, -0.84472027]],
[[-0.82717653, -0.82594539, -0.81691338, -0.82605227],
[-0.82733125, -0.82617881, -0.8184006 , -0.82280219],
[-0.8239619 , -0.82667103, -0.82295269, -0.83171742],
[-0.83290099, -0.83567633, -0.83343734, -0.84163708],
[-0.84278804, -0.84182145, -0.8332604 , -0.84153877],
[-0.84264105, -0.84309893, -0.83184123, -0.83946655],
[-0.84077276, -0.84238718, -0.83064506, -0.83830071],
[-0.83948755, -0.84103251, -0.82934186, -0.83782811],
[-0.8389173 , -0.83924958, -0.83148541, -0.83480076],
[-0.8359994 , -0.8377805 , -0.83988758, -0.84357199]],
[[-0.82584266, -0.82452835, -0.81556359, -0.82556375],
[-0.82599791, -0.82476263, -0.81705573, -0.82230957],
[-0.82261689, -0.82525665, -0.82162309, -0.83123616],
[-0.83158729, -0.83429571, -0.83214391, -0.84116895],
[-0.84151027, -0.84046456, -0.83196635, -0.84107051],
[-0.84136274, -0.84174705, -0.8305422 , -0.83899551],
[-0.83948755, -0.84103251, -0.82934186, -0.83782811],
[-0.83819763, -0.83967254, -0.82803413, -0.83735488],
[-0.8376253 , -0.83788269, -0.83018513, -0.83432354],
[-0.83469681, -0.83640794, -0.83861716, -0.84310649]],
[[-0.82525089, -0.82266387, -0.8177839 , -0.82243512],
[-0.82540639, -0.82289927, -0.81926792, -0.81915487],
[-0.82202021, -0.82339567, -0.82381013, -0.82815378],
[-0.83100444, -0.83247886, -0.83427135, -0.83817006],
[-0.84094331, -0.83867874, -0.83409482, -0.83807077],
[-0.84079554, -0.83996778, -0.83267886, -0.83597806],
[-0.8389173 , -0.83924958, -0.83148541, -0.83480076],
[-0.8376253 , -0.83788269, -0.83018513, -0.83432354],
[-0.83705204, -0.83608379, -0.83232385, -0.83126675],
[-0.83411887, -0.83460162, -0.8407067 , -0.84012427]],
[[-0.82222369, -0.82112803, -0.82649334, -0.83150323],
[-0.8223804 , -0.82136435, -0.82794484, -0.82829943],
[-0.818968 , -0.82186268, -0.83238638, -0.83708633],
[-0.82802254, -0.830982 , -0.84260906, -0.84685789],
[-0.83804212, -0.83720728, -0.84243663, -0.84676108],
[-0.83789312, -0.83850169, -0.84105351, -0.84472027],
[-0.8359994 , -0.8377805 , -0.83988758, -0.84357199],
[-0.83469681, -0.83640794, -0.83861716, -0.84310649],
[-0.83411887, -0.83460162, -0.8407067 , -0.84012427],
[-0.83116192, -0.83311339, -0.84889275, -0.84876322]]])
(10, 10, 4)
display(trainY[0:5])
display(trainY.shape)
---
array([0, 1, 0, 1, 0], dtype=int64)
(47666,)
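Every value in the printed sample lies between about -0.80 and -0.85, across all four channels. If that is typical of the whole training set, the input varies very little. A cheap sanity check is to compute per-channel mean and standard deviation over trainX, then standardize with those training-set statistics if the spread turns out to be tiny. The sketch below runs on a hypothetical synthetic array shaped like the sample, not on the real data. The helper names are also hypothetical.

```python
import numpy as np

def channel_stats(X):
    """Per-channel mean and standard deviation of an (N, H, W, C) array."""
    return X.mean(axis=(0, 1, 2)), X.std(axis=(0, 1, 2))

def standardize(X, mean, std, eps=1e-8):
    """Shift and scale each channel to roughly zero mean and unit variance."""
    return (X - mean) / (std + eps)

# Hypothetical stand-in for trainX: values clustered near -0.82 with a small spread
rng = np.random.default_rng(0)
X = -0.82 + 0.01 * rng.standard_normal((100, 10, 10, 4))

mean, std = channel_stats(X)
Xs = standardize(X, mean, std)
print("before:", mean.round(3), std.round(4))
print("after: ", Xs.mean(axis=(0, 1, 2)).round(6), Xs.std(axis=(0, 1, 2)).round(3))
```

In practice, the statistics would come from the training portion only and be reused unchanged for the validation data.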