Как сохранить глубокие графики супер ученика? - PullRequest
1 голос
/ 22 мая 2019

Я пытаюсь запустить пакет глубокого обучения (https://github.com/levyben/DeepSuperLearner), который дает 2 графика, но не показывает, как их сохранить - из этого кода можно что-нибудь добавить для сохранения графиков?

Вот мой код:

if __name__ == '__main__':
    MLP_learner = dcv.GridSearchCV(mlp, parameter_space, cv=inner_cv,iid=False, n_jobs=-1)
    GBM_learner = dcv.GridSearchCV(gbm, param, cv=inner_cv,iid=False, n_jobs=-1)
    LR_learner = dcv.GridSearchCV(logreg, LR_par, cv=inner_cv, iid=False, n_jobs=-1)
    RFC_learner = dcv.GridSearchCV(rfc, param_grid, cv=inner_cv,iid=False, n_jobs=-1)
    SVM_learner =  dcv.GridSearchCV(svm, tuned_parameters, cv=inner_cv, iid=False, n_jobs=-1)
    Keras_learner = GridSearchCV(estimator=keras, param_grid=kerasparams, cv=inner_cv,iid=False, n_jobs=-1)
    Base_learners = {'MultilayerPerceptron':MLP_learner, 'GradientBoostingMachine':GBM_learner, 
        'LogisticRegression':LR_learner,'RandomForest':RF_learner, 'SupportVectorMachine':SVM_learner, 'Keras':Keras_learner}

    X_train, X_test, Y_train, Y_test = train_test_split(X_res, y_res, test_size=0.2, random_state=0)

    DSL_learner = DeepSuperLearner(Base_learners)
    DSL_learner.fit(X_train, Y_train,max_iterations=1,sample_weight=None)
    DSL_learner.get_precision_recall(X_test, Y_test, show_graphs=True) 
    y_pred = DSL_learner.predict(X_test)
    y_pred = numpy.argmax(y_pred,axis=1) 
    print("Deep Super Learner Test Accuracy:", accuracy_score(y_pred, Y_test)*100, "%")

Я предполагаю, что это 'show_graphs = True', дающая мне графики, но мне нужно иметь возможность сохранить эти выходные данные, прочитав документацию в github, которую они не дают, чтобы добавить эту функцию в функцию .get_precision_recall (). Я пытался применить функцию savefig () в matplotlib, но пока безуспешно.

Я попытался добавить:

    plot = DSL_learner.get_precision_recall(X_test, Y_test, show_graphs=True)
    plt.show(plot)
    plot.savefig('DeepSuperLearner.png')

но это дает ошибку: AttributeError: 'tuple' object has no attribute 'savefig'

Вот изображение, которое я получаю (быстро показываю с меньшим количеством моделей для скорости) и которое я пытаюсь сохранить: graphs

Они также такие же, как в примере с github. Также, когда я бегу:

  plot = DSL_learner.get_precision_recall(X_test, Y_test, show_graphs=True)
  plt.show(plot)

графическое изображение выводится во второй раз. Я запускаю это в лаборатории jupyter (однако необходимо иметь возможность сохранять графики, чтобы запускать этот код в другом месте и по-прежнему получать графики)

...