Сохранение обученного алгоритма классификации с несколькими входами в Python - PullRequest
0 голосов
/ 14 июля 2020

Я разработал сценарий, который предсказывает возможные теги для некоторого текста на основе ранее помеченных вручную отзывов. Я использовал несколько онлайн-статей, чтобы помочь мне (а именно: https://towardsdatascience.com/multi-label-text-classification-with-scikit-learn-30714b7819c5).

Поскольку мне нужна вероятность для каждого тега, вот код, который я использовал:

NB_pipeline = Pipeline([
    ('clf', OneVsRestClassifier(MultinomialNB(alpha=0.3, fit_prior=True, class_prior=None))),
    ])

predictions_en = {}
for category in categories_en:
    NB_pipeline.fit(all_x_en, en_topics[category])
    proba_en = NB_pipeline.predict_proba(pred_x_en)
    predictions_en[category] = proba_en[-1][-1]

preds_en = pd.DataFrame(predictions_en.items())
preds_en = preds_en.sort_values(by=[1], ascending=False)
preds_en = preds_en.reset_index(drop=True)

Это очень хорошо работает для моих целей: он возвращает прогноз для каждого возможного тега. Но моя проблема в том, что он переобучает алгоритм каждый раз, когда я пытаюсь сделать прогноз. Я бы хотел обучить алгоритм в сценарии, сохранить обученный алгоритм, загрузить его в другой сценарий, в котором делается прогноз.

Я бы хотел сделать это в сценарии 1:

for category in categories_en:
    NB_pipeline.fit(all_x_en, en_topics[category])

И это в другом скрипте:

for category in categories_en:
    proba_en = NB_pipeline.predict_proba(pred_x_en)
    predictions_en[category] = proba_en[-1][-1]

Но я не могу заставить его работать. Он просто дает мне такое же предсказание, когда я пытаюсь его разделить.

1 Ответ

0 голосов
/ 14 июля 2020

Вы всегда можете использовать pickle для сериализации любого объекта python, включая ваш. Итак, самый простой и быстрый способ сохранить вашу модель - просто сериализовать ее в файл, скажем model.pickle. Это делается в первой части после обучения вашей модели. После этого все, что вам нужно сделать, это проверить, существует ли файл, и снова десериализовать его, используя pickle.

Это функция, которая сериализует python объектов в файлы:

import pickle

def serialize(obj, file):
    with open(file, 'wb') as f:
        pickle.dump(obj, f)

Это функция, которая десериализует python объекты из файлов:

import pickle

def deserialize(file):
    with open(file, 'rb') as f:
        return pickle.load(f)

После завершения обучения все, что вам нужно сделать, - это вызвать (если NB_pipeline является объектом вашей модели):

serialize(NB_pipeline, 'model.pickle')

И когда вам нужно его загрузить и использовать, просто позвоните:

NB_pipeline = deserialize('model.pickle')
...