многослойная классификация тензорного потока - несовместимые формы: [7,5] и [7] - PullRequest
0 голосов
/ 04 апреля 2020

Я пытаюсь создать классификатор с несколькими метками и столкнулся с проблемой. У меня 5 классов, и я застреваю при попытке обучить сеть, я относительно новичок в машинном обучении, и это первый классификатор, который я создал.

Мой код:

```
def createModel(learn, act):

    model = models.Sequential()
    model.add(layers.Conv2D(32, (9,9), activation=act, input_shape=(512,512,1)))
    model.add(layers.AveragePooling2D((2,2)))
    model.add(layers.Conv2D(64, (9, 9), activation=act))
    model.add(layers.AveragePooling2D((2,2)))
    model.add(layers.Conv2D(64, (6, 6), activation=act))
    model.add(layers.MaxPooling2D((2,2)))
    model.add(layers.Conv2D(96, (6, 6), activation=act))
    model.add(layers.MaxPooling2D((2,2)))
    model.add(layers.Conv2D(128, (3, 3), activation=act))
    model.add(layers.MaxPooling2D((2,2)))
    model.add(layers.Conv2D(128, (3, 3), activation=act))
    model.add(layers.Flatten())
    model.add(layers.Dense(128, activation=act))
    model.add(layers.Dense(5, activation='sigmoid'))
    model.compile(optimizer=optimizers.Adam(learning_rate=learn), loss='binary_crossentropy', metrics=['accuracy'])

    return model


model = createModel(0.005, 'tanh')
History = model.fit(Xtrain, ytrain, epochs=300, validation_data=(Xtest, ytest), verbose=0)
```

Я использую свою собственную функцию разделения из-за того, что мой набор данных довольно странно отформатирован, поэтому я должен создать свои собственные метки с уже существующими данными, которые затем запускаются через горячий кодер. Создание меток типа s0:

```array([[[1., 0.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.]]```

Я использую 10 частей массивов изображений в качестве теста, который разделен на 70%, обучает на 30% тестов, но когда я начинаю обучать сеть, возникает следующая ошибка:>

```Incompatible shapes: [7,5] vs. [7]
     [[node Equal (defined at <ipython-input-54-eb6611e36e68>:3) ]] [Op:__inference_train_function_4978]```

Что это значит и как я могу это исправить?

1 Ответ

0 голосов
/ 04 апреля 2020

Я удалил горячий кодер, который вернул метки в предыдущую форму [0., 0., 0., 1., 0.], Что затем позволило мне обучить сеть.

...