Мультиклассовые прогнозы с потерей категорической_кросентропии - PullRequest
0 голосов
/ 23 января 2019

Допустим, этот пример реализует простую двоичную классификацию.

X = массив ([[1,2,3], [2,3,4], [3,4,5]])

y = массив ([0], [1], [0])

...
model.compile(loss='binary_crossentropy', optimizer='adam')
model.fit(X, y, epochs=50, verbose=0)
# new instance where we do not know the answer
Xnew = array([[4, 5, 6]])
# make a prediction
ynew = model.predict(Xnew)
#show the inputs and predicted outputs
print("X=%s, Predicted=%s" % (Xnew[0], ynew[0]))
...

results
X=[4, 5, 6], Predicted=[0 or 1]

И этот реализует мультиклассовую классификацию.

X = массив ([[1,2,3], [2,3,4], [3,4,5]])

y = массив ([4], [5], [6])

...

model.compile(loss='categorical_crossentropy', optimizer='adam')
# fit model
model.fit(X, y, epochs=50, verbose=2)
model.reset_states()
# evaluate model on new data
yhat = model.predict((X))
...

results decoded
X=[4, 5, 6], Predicted=[4, 5, 6]

Как реализовать мультиклассовую классификацию с одним выходом, чтобы получить что-то вроде этого? (аналогично прогнозированию временных рядов)

X = массив ([[1,2,3], [2,3,4], [3,4,5]])

y = массив ([4], [5], [6])

 # new instance where we do not know the answer
 Xnew = array([[4, 5, 6]])
 yhat = model.predict_classes(Xnew)

декодированные результаты X = [4, 5, 6], прогнозируемый = [7]

1 Ответ

0 голосов
/ 23 января 2019

То, что вы ищете, это функция loss='sparse_categorical_crossentropy', которая будет предполагать, что целочисленные цели являются метками классов. Поэтому, если ваша модель имеет 7 выходов, и вы задаете цель 2, sparse_categorical_crossentropy преобразует 2 в [0,0,1,0,0,0,0] в качестве цели и применяет categorical_crossentropy как обычно.

В этом случае ваша функция активации выходного слоя должна быть softmax, а количество выходов должно быть равно количеству классов. Скорее всего что-то вроде Dense(num_classes, activation='softmax')

Если ваши целочисленные классы просто [4,5,6], то вам нужно переместить их в [0,1,2], чтобы выполнить условие max(Y_targets) < num_classes.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...