Я пытаюсь использовать путаницу, чтобы определить, с какими проблемами столкнулась моя нейронная сеть при обучении на наборе данных 20newsgroup. Я смог запустить сеть и получил приличную точность, но Матрица замешательства, которую я создал для анализа результатов, не имеет большого смысла без дополнительного контекста. Теоретически, метод маркировки матрицы путаницы / изменения ее вывода в контексте слов, обрабатываемых сетью, который, вероятно, помог бы мне разобраться в этом.
Подготовка данных:
import sklearn
import numpy as np
from sklearn.datasets import fetch_20newsgroups
newsgroups_train = fetch_20newsgroups(subset='train')
newsgroups_test = fetch_20newsgroups(subset='test')
X_train, y_train = newsgroups_train['data'], newsgroups_train['target']
X_test, y_test = newsgroups_test['data'], newsgroups_test['target']
y_train = np.array(y_train)
y_test = np.array(y_test)
%tensorflow_version 2.x
import tensorflow as tf
encoder = tf.keras.preprocessing.text.Tokenizer()
encoder.fit_on_texts(X_train)
sequences_trainx = encoder.texts_to_sequences(X_train)
sequences_trainx
sequences_testx = encoder.texts_to_sequences(X_test)
sequences_testx
X_train_matrix=tf.keras.preprocessing.sequence.pad_sequences(sequences_trainx, padding='post')
X_test_matrix= tf.keras.preprocessing.sequence.pad_sequences(sequences_testx, padding='post' )
Нейронный Net Модель:
model = keras.Sequential([
keras.layers.Embedding(134143, 64),
keras.layers.GlobalAveragePooling1D(),
keras.layers.Dense(20)
])
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.1),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
history = model.fit(X_train_matrix, y_train,
epochs=10,
validation_data = (X_test_matrix, y_test),
batch_size = 100
)
Матрица путаницы:
from sklearn.metrics import confusion_matrix
y_pred = model.predict(X_test_matrix)
y_pred = y_pred.argmax(axis=-1)
conf = confusion_matrix(y_test, y_pred)
conf
Вывод:
array([[199, 0, 0, 0, 12, 0, 0, 0, 0, 0, 15, 1, 4,
7, 1, 59, 19, 1, 0, 1],
[ 0, 133, 1, 1, 168, 4, 1, 0, 0, 0, 31, 2, 45,
0, 1, 2, 0, 0, 0, 0],
[ 0, 12, 80, 2, 253, 0, 0, 0, 0, 0, 24, 0, 13,
1, 0, 0, 9, 0, 0, 0],
[ 0, 3, 3, 26, 313, 0, 0, 0, 0, 0, 11, 0, 36,
0, 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 361, 0, 0, 0, 0, 0, 12, 0, 12,
0, 0, 0, 0, 0, 0, 0],
[ 0, 16, 5, 1, 217, 96, 0, 0, 0, 0, 16, 0, 43,
1, 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 1, 125, 0, 25, 11, 0, 0, 179, 0, 49,
0, 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 38, 0, 0, 258, 0, 0, 64, 0, 35,
0, 0, 0, 1, 0, 0, 0],
[ 0, 0, 0, 0, 18, 0, 0, 150, 84, 0, 55, 0, 77,
0, 0, 1, 13, 0, 0, 0],
[ 1, 0, 0, 0, 2, 0, 0, 0, 0, 192, 201, 0, 0,
0, 0, 1, 0, 0, 0, 0],
[ 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 395, 0, 0,
0, 0, 1, 1, 0, 0, 0],
[ 1, 1, 0, 0, 22, 0, 0, 0, 0, 0, 21, 232, 81,
0, 1, 1, 36, 0, 0, 0],
[ 0, 0, 0, 0, 58, 0, 0, 2, 0, 0, 21, 1, 311,
0, 0, 0, 0, 0, 0, 0],
[ 1, 0, 0, 0, 29, 0, 0, 1, 0, 0, 94, 0, 86,
163, 1, 7, 14, 0, 0, 0],
[ 0, 4, 0, 0, 19, 0, 0, 0, 0, 0, 27, 0, 69,
3, 261, 3, 8, 0, 0, 0],
[ 2, 0, 0, 0, 6, 0, 0, 0, 0, 0, 37, 0, 5,
1, 0, 346, 0, 0, 0, 1],
[ 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 39, 1, 0,
0, 0, 1, 322, 0, 0, 0],
[ 10, 0, 0, 0, 3, 0, 0, 0, 0, 0, 65, 0, 1,
0, 0, 19, 31, 246, 1, 0],
[ 10, 0, 0, 0, 2, 0, 0, 2, 0, 0, 55, 1, 7,
2, 6, 7, 160, 0, 58, 0],
[ 53, 1, 0, 0, 6, 0, 0, 0, 0, 0, 35, 0, 4,
2, 2, 89, 47, 1, 1, 10]])