Ярлыки для модели Keras, предсказывающей проблему мульти-классификации - PullRequest
0 голосов
/ 03 января 2019

Если у меня есть набор целей, также известных как [1,0,9,9,7,5,4,0,4,1], и я использую model.predict(X) Keras возвращает массив из 6 элементов для каждого из 10 образцов.Он возвращает 6 элементов, поскольку существует 6 возможных целей (0,1,4,5,7,9), а keras возвращает десятичное число / число с плавающей запятой (для каждой метки), представляющее вероятность того, что одна из этих целей является правильной.Например, для первого примера - где y = 1 Keras возвращает массив, который выглядит следующим образом: [.1, .4,.003,.001,.5,.003].

Я хочу знать, какое значение соответствует какой цели (относится ли .1 к 1, потому что оно первое в наборе данных, или к 0, потому что это наименьшее число, или к 9, потому что это последнее число и т. Д.). Как Керас заказывает свои прогнозы? Документация , похоже, не формулирует это;он только говорит

«Генерирует выходные прогнозы для входных выборок.»

Так что я не уверен, как сопоставить метки с результатами прогноза.

РЕДАКТИРОВАТЬ:

Вот моя модель и учебный код:

X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.25, random_state=42)

Y_train = to_categorical(y_train)
Y_test = to_categorical(y_test)

sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
x = Conv1D(64, 5, activation='relu')(embedded_sequences)
x = MaxPooling1D(4)(x)
x = Conv1D(64, 5, activation='relu')(x)
x = MaxPooling1D(4)(x)
x = Conv1D(64, 5, activation='relu')(x)
x = MaxPooling1D(4)(x)  # global max pooling
x = Flatten()(x)
x = Dense(64, activation='relu')(x)
preds = Dense(labels_Index, activation='softmax')(x)

model = Model(sequence_input, preds)
model.fit(X_train, Y_train, epochs=10, verbose = 1) 

1 Ответ

0 голосов
/ 03 января 2019

Keras ничего не упорядочивает, все зависит от того, как классы данных, которые вы использовали для обучения модели, определены и закодированы в горячем виде.

Обычно можно восстановить целочисленную метку класса, взяв argmax массива вероятностей классов для каждой выборки.

В вашем примере 0,1 - это класс 0, 0,4 - это класс 1, 0,003 - это класс 2, 0,001 - это класс 3, 0,5 - это класс 4, а 0,003 - это класс 5 (всего 6 классов).

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...