Я получил эту ошибку при переходе к моей функции классификации при выходе из модели Флобера (французская версия Берта). Как я могу преодолеть эту проблему.
моя функция взяла в качестве аргументов список id, texte (столбец предложений в строке) и связанные с ними метки (в данном случае три); Я тестирую 5 дифференцирующих классификаторов и хочу напечатать точность для каждой модели, а также отобразить точность для каждого делителя резюме.
def classify(int, texte, ylabels):
output_id = int
texte = texte
X_data = Flaubert(texte) # utilisation modèle Flaubert
model1 = LinearSVC()
model2 = MultinomialNB()
model3 = LogisticRegression()
model4 = RandomForestClassifier()
model5 = KNeighborsClassifier()
models = {'Model_SVC': model1, 'Model_NB': model2, 'Model_LR': model3, 'Model_RF': model4, 'Model_KN': model5}
cv_splitter = KFold(n_splits=10, shuffle=False, random_state=None)
for model_name, model in models.items():
y_pred = cross_val_predict(model, X_data, ylabels, cv=cv_splitter)
all_accuracies = cross_val_score(estimator=model, X_data, ylabels, cv=cv_splitter)
print(all_accuracies)
print(all_accuracies.mean()) #moyenne
print(all_accuracies.std()) #déterminer la variance
print(y_pred)
print("Model: {}".format(model_name))
print("Accuracy: {}".format(accuracy_score(ylabels, y_pred)))
функция flauber:
def Flaubert(texte):
# You could choose among ['flaubert-base-cased', 'flaubert-base-uncased', 'flaubert-large-cased']
modelname = 'flaubert-base-cased'
flaubert, log = FlaubertModel.from_pretrained(modelname, output_loading_info=True)
flaubert_tokenizer = FlaubertTokenizer.from_pretrained(modelname, do_lowercase=False) # do_lowercase=False if using the 'cased' model, otherwise it should be set to False
tokenized = texte.apply((lambda x: flaubert_tokenizer.encode(x, add_special_tokens=True)))
max_len = 0
for i in tokenized.values:
if len(i) > max_len:
max_len = len(i)
padded = np.array([i + [0] * (max_len - len(i)) for i in tokenized.values])
# Using model
token_ids = torch.tensor(padded)
#with torch.no_grad():
last_layer = flaubert(token_ids)[0]
return last_layer
last_layer выглядит следующим образом
трассировка стека:
Traceback (most recent call last):
File "test.py", line 49, in <module>
classify(data_id,data,ylabels)
File "/ho/ge/ke/classifieur.py", line 55, in classify
y_pred = cross_val_predict(model, X_data, ylabels, cv=cv_splitter)
File "/ho/ge/ke/anaconda3/envs/env/lib/python3.6/site-packages/sklearn/model_selection/_validation.py", line 755, in cross_val_predict
for train, test in cv.split(X, y, groups))
File "/ho/ge/ke/.local/lib/python3.6/site-packages/joblib/parallel.py", line 1004, in __call__
if self.dispatch_one_batch(iterator):
File "/ho/ge/ke/.local/lib/python3.6/site-packages/joblib/parallel.py", line 835, in dispatch_one_batch
self._dispatch(tasks)
File "/ho/ge/ke/.local/lib/python3.6/site-packages/joblib/parallel.py", line 754, in _dispatch
job = self._backend.apply_async(batch, callback=cb)
File "/ho/ge/ke/.local/lib/python3.6/site-packages/joblib/_parallel_backends.py", line 209, in apply_async
result = ImmediateResult(func)
File "/ho/ge/ke/.local/lib/python3.6/site-packages/joblib/_parallel_backends.py", line 590, in __init__
self.results = batch()
File "/ho/ge/ke/.local/lib/python3.6/site-packages/joblib/parallel.py", line 256, in __call__
for func, args, kwargs in self.items]
File "/ho/ge/ke/.local/lib/python3.6/site-packages/joblib/parallel.py", line 256, in <listcomp>
for func, args, kwargs in self.items]
File "/ho/ge/ke/anaconda3/envs/env/lib/python3.6/site-packages/sklearn/model_selection/_validation.py", line 842, in _fit_and_predict
estimator.fit(X_train, y_train, **fit_params)
File "/ho/ge/ke/anaconda3/envs/env/lib/python3.6/site-packages/sklearn/svm/_classes.py", line 235, in fit
accept_large_sparse=False)
File "/ho/ge/ke/anaconda3/envs/env/lib/python3.6/site-packages/sklearn/utils/validation.py", line 739, in check_X_y
estimator=estimator)
File "/ho/ge/ke/anaconda3/envs/env/lib/python3.6/site-packages/sklearn/utils/validation.py", line 558, in check_array
% (array.ndim, estimator_name))
ValueError: Found array with dim 3. Estimator expected <= 2.