Интерпретация результатов НЛП с использованием матрицы путаницы - PullRequest
1 голос
/ 06 марта 2020

Я пытаюсь использовать путаницу, чтобы определить, с какими проблемами столкнулась моя нейронная сеть при обучении на наборе данных 20newsgroup. Я смог запустить сеть и получил приличную точность, но Матрица замешательства, которую я создал для анализа результатов, не имеет большого смысла без дополнительного контекста. Теоретически, метод маркировки матрицы путаницы / изменения ее вывода в контексте слов, обрабатываемых сетью, который, вероятно, помог бы мне разобраться в этом.

Подготовка данных:

import sklearn
import numpy as np
from sklearn.datasets import fetch_20newsgroups
newsgroups_train = fetch_20newsgroups(subset='train')
newsgroups_test = fetch_20newsgroups(subset='test')

X_train, y_train = newsgroups_train['data'], newsgroups_train['target']
X_test, y_test = newsgroups_test['data'], newsgroups_test['target']

y_train = np.array(y_train)
y_test = np.array(y_test)

%tensorflow_version 2.x
import tensorflow as tf
encoder = tf.keras.preprocessing.text.Tokenizer()

encoder.fit_on_texts(X_train)
sequences_trainx = encoder.texts_to_sequences(X_train)
sequences_trainx 


sequences_testx = encoder.texts_to_sequences(X_test)
sequences_testx 

X_train_matrix=tf.keras.preprocessing.sequence.pad_sequences(sequences_trainx, padding='post')
X_test_matrix= tf.keras.preprocessing.sequence.pad_sequences(sequences_testx, padding='post'  )

Нейронный Net Модель:

model = keras.Sequential([
  keras.layers.Embedding(134143, 64),
  keras.layers.GlobalAveragePooling1D(),
  keras.layers.Dense(20)
])

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

history = model.fit(X_train_matrix, y_train,
                    epochs=10,
                    validation_data = (X_test_matrix, y_test),
                    batch_size = 100
                    )

Матрица путаницы:

from sklearn.metrics import confusion_matrix

y_pred = model.predict(X_test_matrix)
y_pred = y_pred.argmax(axis=-1)
conf =  confusion_matrix(y_test, y_pred)
conf

Вывод:

array([[199,   0,   0,   0,  12,   0,   0,   0,   0,   0,  15,   1,   4,
          7,   1,  59,  19,   1,   0,   1],
       [  0, 133,   1,   1, 168,   4,   1,   0,   0,   0,  31,   2,  45,
          0,   1,   2,   0,   0,   0,   0],
       [  0,  12,  80,   2, 253,   0,   0,   0,   0,   0,  24,   0,  13,
          1,   0,   0,   9,   0,   0,   0],
       [  0,   3,   3,  26, 313,   0,   0,   0,   0,   0,  11,   0,  36,
          0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0, 361,   0,   0,   0,   0,   0,  12,   0,  12,
          0,   0,   0,   0,   0,   0,   0],
       [  0,  16,   5,   1, 217,  96,   0,   0,   0,   0,  16,   0,  43,
          1,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   1, 125,   0,  25,  11,   0,   0, 179,   0,  49,
          0,   0,   0,   0,   0,   0,   0],
       [  0,   0,   0,   0,  38,   0,   0, 258,   0,   0,  64,   0,  35,
          0,   0,   0,   1,   0,   0,   0],
       [  0,   0,   0,   0,  18,   0,   0, 150,  84,   0,  55,   0,  77,
          0,   0,   1,  13,   0,   0,   0],
       [  1,   0,   0,   0,   2,   0,   0,   0,   0, 192, 201,   0,   0,
          0,   0,   1,   0,   0,   0,   0],
       [  0,   0,   0,   0,   2,   0,   0,   0,   0,   0, 395,   0,   0,
          0,   0,   1,   1,   0,   0,   0],
       [  1,   1,   0,   0,  22,   0,   0,   0,   0,   0,  21, 232,  81,
          0,   1,   1,  36,   0,   0,   0],
       [  0,   0,   0,   0,  58,   0,   0,   2,   0,   0,  21,   1, 311,
          0,   0,   0,   0,   0,   0,   0],
       [  1,   0,   0,   0,  29,   0,   0,   1,   0,   0,  94,   0,  86,
        163,   1,   7,  14,   0,   0,   0],
       [  0,   4,   0,   0,  19,   0,   0,   0,   0,   0,  27,   0,  69,
          3, 261,   3,   8,   0,   0,   0],       
[  2,   0,   0,   0,   6,   0,   0,   0,   0,   0,  37,   0,   5,
              1,   0, 346,   0,   0,   0,   1],
           [  0,   0,   0,   0,   1,   0,   0,   0,   0,   0,  39,   1,   0,
              0,   0,   1, 322,   0,   0,   0],
           [ 10,   0,   0,   0,   3,   0,   0,   0,   0,   0,  65,   0,   1,
              0,   0,  19,  31, 246,   1,   0],
           [ 10,   0,   0,   0,   2,   0,   0,   2,   0,   0,  55,   1,   7,
              2,   6,   7, 160,   0,  58,   0],
           [ 53,   1,   0,   0,   6,   0,   0,   0,   0,   0,  35,   0,   4,
              2,   2,  89,  47,   1,   1,  10]])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...