scikit-learn DummyClassifier ошибка - PullRequest
0 голосов
/ 26 июня 2018

У меня есть скрипт на Python, который обучает модель RandomForestClassifier (scikit-learn) на тренировочном наборе, а затем выдает отчет о классификации на тестовом наборе. Работает отлично, несмотря на отсутствие точности. Я попытался изменить RandomForestClassifier на DummyClassifier и снова запустить скрипт, но он продолжает выдавать эту ошибку AttributeError: 'list' object has no attribute 'argmax'. Эта ошибка происходит от dummy.py. Фиктивная модель обучается, но когда я вызываю функции predict() или score(), ошибка продолжает появляться. Я убедился, что все используемые x и y имеют тип numpy.ndarray. Я также пошел к файлу dummy.py, попытался преобразовать список в массив, но это начинает выбрасывать IndexErrors из-за пределов. Наконец-то также попробовал check_X_y, но та же ошибка все еще там. У кого-нибудь еще была эта проблема? Любые решения? Вот код:

    s= RobustScaler()
    x0 = [[ro for ir, ro in enumerate(rows) if ir in xlim] for
          rows in self.data1]

    x0 = s.fit_transform(x0)
    y0 = array([[row[ylim]] for row in self.data1]).reshape(-1,1)

    xt = [[ro for ir, ro in enumerate(rows) if ir in xlim] for
          rows in self.data2]

    xt = s.transform(xt)
    yt = array([[row[ylim]] for row in self.data2]).reshape(-1,1)

    print(len(self.data2), array(x0).shape, array(y0).shape)
    trainNum = int(trainPC * len(y0) / 100)
    print(trainNum, len(y0))
    x1, y0 = check_X_y(x0, y0)
    x2, yt = check_X_y(xt, yt)
    print(x1[trainNum:][:].shape,x2.shape,y0.shape,yt.shape)
    self.data1.clear()

    try:
        classOne = joblib.load(mod_name)
    except FileNotFoundError:

        # classOne = RandomForestClassifier(criterion='gini', n_estimators=40, class_weight='balanced', verbose=2,)
        classOne = DummyClassifier(strategy='stratified')
        print("fit start")
        classOne.fit(x1[:trainNum], array(y0[:trainNum][:]).reshape(-1,1))
        print("fit stop")

    try:
        print(classOne.best_params_)
    except AttributeError:
        pass
    try:
        print(classOne.feature_importances_)
    except AttributeError:
        pass
    gc.collect()
    yp = classOne.predict(x1[trainNum:])

Вот полная ошибка:

Traceback (most recent call last):
  File "C:/Users/User/PycharmProjects/NewCrawl/new2_predict.py", line 454, in <module>
    p.classify1(xl,97,80,'dummy')
  File "C:/Users/User/PycharmProjects/NewCrawl/new2_predict.py", line 167, in classify1
    yp2 = classOne.predict(x2)
  File "C:\Users\User\AppData\Roaming\Python\Python36\site-packages\sklearn\dummy.py", line 227, in predict
    k in range(self.n_outputs_)).T
  File "C:\Users\User\AppData\Roaming\Python\Python36\site-packages\numpy\core\shape_base.py", line 237, in vstack
    return _nx.concatenate([atleast_2d(_m) for _m in tup], 0)
  File "C:\Users\User\AppData\Roaming\Python\Python36\site-packages\numpy\core\shape_base.py", line 237, in <listcomp>
    return _nx.concatenate([atleast_2d(_m) for _m in tup], 0)
  File "C:\Users\User\AppData\Roaming\Python\Python36\site-packages\sklearn\dummy.py", line 227, in <genexpr>
    k in range(self.n_outputs_)).T
AttributeError: 'list' object has no attribute 'argmax'
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...