Я участвую в конкурсе Quora Insincere Questions Classification (несвоевременная подача), но есть странная ошибка, которую я не могу понять. Вот мой код (соответствующие части):
def loss(predict, observed):
a = predict*observed
b = predict+observed
return 2*(a/b)
train = pd.read_csv('../input/train.csv')
test = pd.read_csv('../input/test.csv')
train = train.iloc[0:5000, :]
test = test.iloc[0:1000, :]
qid = test['qid']
train = train.drop('qid', axis=1)
test = test.drop('qid', axis=1)
x_train, x_val, y_train, y_val = train_test_split(train['question_text'], train['target'])
count = CountVectorizer(stop_words='english', ngram_range=(1,1), min_df=1, #tokenizer=LemmaTokenizer()
)
tfidf = TfidfVectorizer(stop_words='english', ngram_range=(1,1), min_df=1, #tokenizer=LemmaTokenizer()
)
count.fit(list(x_train), list(x_val))
x_train_count = count.transform(x_train)
x_val_count = count.transform(x_val)
logistic = LogisticRegression()
logistic.fit(x_train_count, y_train)
predictions = logistic.predict_proba(x_val_count)
print("loss: %0.3f " %loss(predictions, y_val))
Когда я запускаю его, я получаю эту ошибку:
ValueError: operands could not be broadcast together with shapes (1250,2) (1250,)
Я знаю, почему я получил ошибку: это потому, что я не могу напрямую умножить два массива. Но вот некоторые измерения, которые не имеют смысла:
x_val_count.shape - (1250, 8411)
Я предполагаю, что это расширенный массив комментариев (1250 тестовых примеров) в числовой форме. Но начало печатного массива таково:
(0, 1057) 1
(0, 4920) 1
(0, 5563) 1
(1, 2894) 1
(1, 3403) 1
(2, 3311) 1
(3, 1386) 1
(3, 1646) 1
(4, 3207) 1
(4, 3330) 1
(4, 6111) 1
(5, 2346) 1
(5, 4148) 1
(5, 4441) 1
(5, 5223) 1
(5, 5316) 1
(5, 5378) 1
(5, 5565) 2
(5, 7571) 1
(6, 746) 2
(6, 983) 1
(6, 985) 1
(6, 3182) 1
(6, 3455) 1
(6, 4636) 1
Это выглядит так, как будто у него две колонки Почему это несоответствие?
predictions.shape - (1250, 2)
Я не знаю, почему у прогнозов есть два столбца. Почему не один?
Я надеюсь, что если я узнаю больше, я смогу решить проблему. Но кто-нибудь знает, как я могу это исправить?