Как сделать одну перекрестную проверку с тензорным потоком (Keras)? - PullRequest
0 голосов
/ 08 ноября 2019

У меня есть 20 предметов, и я хочу использовать перекрестную проверку без отрыва, когда я тренирую модель, реализованную с помощью Tensorflow. Я следую некоторым инструкциям и, наконец, вот мой псевдокод:

for train_index, test_index in loo.split(data):
print("TRAIN:", train_index, "TEST:", test_index)
train_X=np.concatenate(np.array([data[ii][0] for ii in train_index]))
train_y=np.concatenate(np.array([data[ii][1] for ii in train_index]))

test_X=np.concatenate(np.array([data[ii][0] for ii in test_index]))
test_y=np.concatenate(np.array([data[ii][1] for ii in test_index]))


train_X,train_y = shuffle(train_X, train_y)
test_X,test_y = shuffle(test_X, test_y)



#Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

#keras.initializers.RandomNormal(mean=0.0, stddev=0.05, seed=None)

#Train the model
model.fit(train_X, train_y, batch_size=16, epochs=30,validation_split=.10)#,callbacks=[])

#test accuracy
test_loss, test_acc = model.evaluate(test_X,test_y)
print('\nTest accuracy:', test_acc)

, но результаты после первого предмета выглядят так:

Epoch 30/30
3590/3590 [==============================] - 4s 1ms/sample - loss: 0.5976 - 
**acc: 0.8872** - val_loss: 1.3873 - val_acc: 0.6591


255/255 [==============================] - 0s 774us/sample - loss: 1.8592 - 
acc: 0.4471

Test accuracy: 0.44705883

вторая итерация (Тема):

TRAIN: [ 0  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17] TEST: [1]

Train on 3582 samples, validate on 398 samples
Epoch 1/30
3582/3582 [==============================] - 5s 1ms/sample - loss: 0.7252 - 
**acc: 0.8238** - val_loss: 1.0627 - val_acc: 0.6859

Звучит, что модель использует предыдущие веса! Если мы посмотрим на первую точность второй итерации, она начинается с соотв: 0,8238!

Правильна ли моя реализация? или мне нужно больше шагов к начальному весу для каждого предмета?

1 Ответ

0 голосов
/ 08 ноября 2019

0,8238 - это данные тренировки, а не данные ваших тестов. Ваш метод fit () также имеет разделение проверки для данных тренировки.

Модель отлично работает из того, что я вижу. Ваша реализация верна.

...