Модель SimpleTransformers с K-Fold - PullRequest
       79

Модель SimpleTransformers с K-Fold

0 голосов
/ 06 февраля 2020

Я столкнулся с проблемой, пытаясь собрать прогнозы для out_of_folds (чтобы получить прогнозы для всего обучающего набора)

Когда я вычисляю f1 metri c после прогноза в каждой итерации, я получаю другое число, затем встроенный f1-metri c.

Изображение результатов. Моя метри c идет после встроенной метрики * метрика 1033 *

enter image description here

Код:

def cros_val(features, n_folds = 5):

labels = features['target']

features = np.array(features)
k_fold = StratifiedKFold(n_splits=n_folds, shuffle=True)
results = []

out_of_fold = np.zeros(features.shape[0])

for train_index, val_index in k_fold.split(features, labels):

    train_df = features[train_index]
    val_df = features[val_index]

    train_df = pd.DataFrame(train_df)
    val_df = pd.DataFrame(val_df)

    model = ClassificationModel('bert', 'bert-base-uncased', args=custom_args) 
    model.train_model(train_df)

    result, model_outputs, _ = model.eval_model(val_df, f1=f1_score)

    val_proba = softmax(model_outputs, axis=1)

    out_of_fold[val_index] = val_proba[:, 1]

    print(f1_score(labels.to_numpy()[val_index], out_of_fold[val_index]>0.5))

    results.append(result['f1'])

    gc.enable()
    del train_df, val_df
    gc.collect()

print()
print(f1_score(out_of_fold>0.5, labels.to_numpy()))


return out_of_fold, results

Но несколько раз я получал одни и те же числа для 2/5 итераций для встроенных и моих метрик.

В то же время, когда я делаю это:

model = ClassificationModel('bert', 'bert-base-uncased', args=custom_args)
model.train_model(train_data.iloc[:6500])

result, model_outputs, _ = model.eval_model(train_data.iloc[6500:], f1=f1_score)
val_proba = softmax(model_outputs, axis=1)

{'m cc': 0,6249772665105072, 'tp': 409, 'tn': 496, 'fp': 90, 'fn': 118, 'f1': 0,7972709551656921, 'eval_loss': 0,42122048985745225}

f1_score(train_data['target'].to_numpy()[6500:], val_proba[:,1] > 0.5)

0.7980295566502462

Я получаю то, что хотел. Но в то же время эти цифры тоже немного отличаются. Возможно, что-то не так с индексами, но где ...

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...