Я сделал модель Keras для задачи классификации NLP Multi-class. Данные состоят из заголовков и тегов. Я обучил модель на названиях, чтобы предсказать теги. Я преобразовал теги в один горячий, используя sklearn.preprocessing LabelEncoder, OneHotEncoder.
Кодировка OneHot
def onehot(df):
values = array(df)
# integer encode
label_encoder = LabelEncoder()
integer_encoded = label_encoder.fit_transform(values)
# binary encode
onehot_encoder = OneHotEncoder(sparse=False)
integer_encoded = integer_encoded.reshape(len(integer_encoded), 1)
onehot_encoded = onehot_encoder.fit_transform(integer_encoded)
return label_encoder, onehot_encoded
Я использовал категорийный кроссцентроп и Адам для своей модели. Вот код для модели.
Keras CNN Модель
def ConvNet(embeddings, max_sequence_length, num_words, embedding_dim, labels_index):
embedding_layer = Embedding(num_words, embedding_dim, weights=[embeddings],
input_length = max_sequence_length,
trainable = False)
sequence_input = Input(shape = (max_sequence_length,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
convs = []
filter_sizes = [2,3,4,5,6]
for filter_size in filter_sizes:
l_conv = Conv1D(filters=200, kernel_size=filter_size, activation='relu')(embedded_sequences)
l_pool = GlobalMaxPooling1D()(l_conv)
convs.append(l_pool)
l_merge = concatenate(convs, axis=1)
x = Dropout(0.1)(l_merge)
x = Dense(128, activation='relu')(x)
x = Dropout(0.2)(x)
preds = Dense(labels_index, activation='sigmoid')(x)
model = Model(sequence_input, preds)
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics=['acc'])
model.summary()
return model
Прогноз
predictions = model.predict(test_cnn_data, batch_size=1024, verbose=1)
Я получаю массив, подобный этому из предсказаний
print(predictions[5, :])
Output:
array([8.8067267e-08, 5.1040554e-15, 1.9745098e-16, ..., 8.0959568e-17,
2.1070798e-17, 1.1202571e-18], dtype=float32)
Я понимаю, что это вероятности или показатель достоверности того, что следующее предложение принадлежит этому тегу.
Как преобразовать предсказанное массивы в теги, чтобы я мог сравнить его с тегами тестового набора данных для точности?