Полиномиальное наивное байесовское изменение Softmax - PullRequest
0 голосов
/ 09 декабря 2018

В scikit learn я делал многоклассовую классификацию, используя MultinomialNB для помеченных текстовых данных.При прогнозировании я использовал функцию «предиката_пробы» в многочленеNB

clf=MultinomialNB()
print(clf.fit(X_train,Y_train))
clf.predict_proba(X_test[0])

В результате я получил вектор значений вероятности для каждого класса, который суммировал до 1. Я знаю, что это из-за функции кросс-энтропии softmax.

массив ([[0,01245064, 0,02346781, 0,84694063, 0,03238112, 0,01833107, 0,03103464, 0,03539408]])

Мой вопрос здесь, при прогнозировании мне нужно иметь двоичную_кросс_энтропию, чтобыЯ получаю вектор значений вероятности для каждого класса от 0 до 1 независимо друг от друга.Итак, как мне изменить функцию при прогнозировании в scikit-learn?

1 Ответ

0 голосов
/ 09 декабря 2018

Вы можете получить логарифмическую правдоподобие для каждого класса, используя:

_joint_log_likelihood(self, X):
        """Compute the unnormalized posterior log probability of X
        I.e. ``log P(c) + log P(x|c)`` for all rows x of X, as an array-like of
        shape [n_classes, n_samples].
        Input is passed to _joint_log_likelihood as-is by predict,
        predict_proba and predict_log_proba.
        """ 

Наивный байесовский предикат_tlog_proba работает просто путем нормализации функции выше.

def predict_log_proba(self, X):
        """
        Return log-probability estimates for the test vector X.
        Parameters
        ----------
        X : array-like, shape = [n_samples, n_features]
        Returns
        -------
        C : array-like, shape = [n_samples, n_classes]
            Returns the log-probability of the samples for each class in
            the model. The columns correspond to the classes in sorted
            order, as they appear in the attribute `classes_`.
        """
        jll = self._joint_log_likelihood(X)
        # normalize by P(x) = P(f_1, ..., f_n)
        log_prob_x = logsumexp(jll, axis=1)
        return jll - np.atleast_2d(log_prob_x).T 
...