Матрица смешения и точность теста в мультимарочной классификации изображений для кера - PullRequest
0 голосов
/ 24 февраля 2019

Я пытаюсь выучить классификацию по нескольким меткам, используя Keras.Набор данных - классификация роботизированных инструментов (для хирургии почек) с 11 классами.Я построил модель для задачи классификации с несколькими метками и смог оценить производительность модели.Тем не менее, у меня есть проблема, чтобы определить точность теста и путаницу.Может кто-нибудь, пожалуйста, поделитесь фрагментом кода, который работает.Заранее всем спасибо!

name_label_dictionary = {
1:  "instrument-shaft",
2:  "instrument-clasper",
3:  "instrument-wrist",
4:  "kidney-parenchyma",
5:  "covered-kidney",
6:  "thread",
7:  "clamps",
8:  "suturing-needle",
9:  "suction-instrument",
10:  "small-intestine",
11:  "ultrasound-probe"}
image_dataframe = pd.read_csv('data/train.csv')


encoder = MultiLabelBinarizer()
encoder.fit_transform([(1,),(2,),(3,),(4,),(5,),(6,),(7,),(8,),(9,),(10,),  (11,)])
samples = list(zip(image_dataframe['Id'],image_dataframe['Target']))
train_samples, validation_samples = train_test_split(samples,test_size=0.15)



def generator(data,batch_size=8):
images_path_length = len(data)
while 1:
    for off in range(0,images_path_length,batch_size):
        images_list = data[off:off+batch_size]
        rgb_arr=[]
        label=[]
        for j in images_list:
            img=[]
            img = np.array(Image.open("data/train/"+j[0]+"_resized.png"))/255
            rgb_arr.append(img)
            label.append(encoder.transform([tuple(map(int,j[1].split()))]))

        yield np.array(rgb_arr),np.array(label).reshape(len(label),11)


train_images_gen = generator(train_samples)
val_images_gen = generator(validation_samples)

Вот код Keras для обучения данных.

   model = InceptionV3(include_top = True, weights = None, classes=11)
   model.layers.pop()
   x = model.layers[-1].output
   x = Dense(11, activation='sigmoid', name='predictions')(x)
   train_model = Model(input=model.input,output=x)
   train_model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

   train_model.fit_generator(train_images_gen,
                      epochs=10,
                      verbose=1,
                      steps_per_epoch=len(train_samples)/8, 
                      validation_data=val_images_gen,
                      validation_steps=len(validation_samples)/8)

 score = train_model.evaluate_generator(val_images_gen, steps=50)
 print ('Validation Score: ', score[0])
 print ('Validation Accuracy: ',score[1])

В следующей части я создаю набор тестов и соответствующие ему метки, аналогичныепоезд и валидация установлены.

test_image_dataframe = pd.read_csv('data/test.csv')
test_samples = list(zip(test_image_dataframe['Id'], test_image_dataframe['Target']))

def test_generator(data,batch_size=8):
images_path_length = len(data)
while 1:
    for off in range(0,images_path_length,batch_size):
        images_list = data[off:off+batch_size]
        rgb_arr=[]
        label=[]
        for j in images_list:
            img=[]
            img = np.array(Image.open("data/test/"+j[0]+"_resized.png"))/255
            rgb_arr.append(img)
            label.append(encoder.transform([tuple(map(int,j[1].split()))]))

        yield np.array(rgb_arr),np.array(label).reshape(len(label),11)

и в конце предсказывают метки:

results = train_model.predict_generator(test_images_gen,steps=len(test_samples)/8,verbose=1)

predictions=[]
for i in results:
    label_predict=np.arange(11)[i >=0.2]
    predictions.append(' '.join(str(l) for l in label_predict))
print(predictions)

вывод прогнозов:

['0 1 2 4 5',
 '0 1 2 4',
 '0 1 2 4',
 '0 1 2 3 4',
 '0 1 2 3 4',
 '0 1 2 3 4',
 '0 1 2 3 4',
 '0 1 2 3 4',
 '0 1 2 3 4',
 '1 2 3 4',
 '0 1 2 3 4',
 '0 1 2 3 4',
 '0 1 2 3 4',

вывод результатов:

array([[9.96413887e-01, 1.00000000e+00, 9.80881095e-01, ...,
    8.43488611e-04, 9.98612493e-03, 1.06460014e-07],
   [9.88547325e-01, 9.99999881e-01, 9.43120480e-01, ...,
    4.49081184e-03, 7.23437071e-02, 1.84218152e-05],
   [9.99697685e-01, 1.00000000e+00, 9.99999642e-01, ...,
    1.14535843e-03, 1.99837261e-03, 6.23652944e-03],
   ...,
   [9.99984503e-01, 1.00000000e+00, 9.99968529e-01, ...,
    4.65406990e-03, 5.77826202e-01, 3.62675011e-01],
...