Я использую Google BERT для классификации текста (мультиклассовая классификация) из репозитория https://github.com/google-research/bert. Я изменил код под свои задачи: убрал TPUEstimator и использую модель для предсказания тегов вопросов на Stack Overflow. Но, к сожалению, BERT даёт очень плохие результаты; думаю, возможно, нужно изменить какие-то параметры.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import csv
import os
import modeling
import tokenization
import tensorflow as tf
import numpy as np
from sklearn.model_selection import train_test_split
from bert_model import *
# Pretrained checkpoint files. This is the *cased* BERT-Base model, so
# DO_LOWER_CASE must stay False; run_bert() checks the pairing with
# tokenization.validate_case_matches_checkpoint().
DO_LOWER_CASE = False
BERT_INIT_CHKPNT = "./src_bert/cased_L-12_H-768_A-12/bert_model.ckpt"
BERT_CONFIG = './src_bert/cased_L-12_H-768_A-12/bert_config.json'
BERT_VOCAB = "./src_bert/cased_L-12_H-768_A-12/vocab.txt"
# Maximum input length in WordPiece tokens; must not exceed the model's
# max_position_embeddings (checked in run_bert()). Longer inputs are
# presumably truncated by the feature converter, so raising this (at the
# cost of memory and speed) may help on long Stack Overflow posts.
MAX_SEQ_LENGTH = 128
# Fine-tuning artifacts (TFRecord files, checkpoints, eval results) go here.
OUTPUT_DIR = "./working/output"
# Fine-tuning hyperparameters. The BERT paper suggests searching batch size
# {16, 32}, learning rate {5e-5, 3e-5, 2e-5} and 2-4 epochs; the single
# epoch configured below is under that range.
TRAIN_BATCH_SIZE = 32
# NOTE(review): not referenced in run_bert(), which builds a single Estimator
# with params={"batch_size": TRAIN_BATCH_SIZE} shared by every mode.
PREDICT_BATCH_SIZE = EVAL_BATCH_SIZE = 8
LEARNING_RATE = 2e-5
NUM_TRAIN_EPOCHS = 1.0
# Warmup is a period during which the learning rate starts small and
# gradually increases; it usually helps training. This is the fraction of
# the total number of training steps spent warming up.
WARMUP_PROPORTION = 0.1
# Estimator checkpoint and summary frequency, in global steps.
SAVE_CHECKPOINTS_STEPS = 10000
SAVE_SUMMARY_STEPS = 500
# Fraction of the train/val data used for fine-tuning; the remainder is held
# out for evaluation in run_bert().
TRAIN_VAL_RATIO = 0.9
def _build_input_fn(examples, label_list, tokenizer, record_name, is_training):
    """Write ``examples`` to a TFRecord file and return an Estimator input_fn.

    Args:
        examples: Examples as produced by ``create_examples``.
        label_list: All label values, forwarded to the feature converter.
        tokenizer: The ``tokenization.FullTokenizer`` used for encoding.
        record_name: File name of the TFRecord written under ``OUTPUT_DIR``.
        is_training: Forwarded to ``file_based_input_fn_builder``; also used
            as ``drop_remainder`` so that only training drops a partial final
            batch (evaluation and prediction run through the entire set).

    Returns:
        The input_fn built by ``file_based_input_fn_builder``.
    """
    record_file = os.path.join(OUTPUT_DIR, record_name)
    file_based_convert_examples_to_features(
        examples, label_list, MAX_SEQ_LENGTH, tokenizer, record_file)
    return file_based_input_fn_builder(
        input_file=record_file,
        seq_length=MAX_SEQ_LENGTH,
        is_training=is_training,
        drop_remainder=is_training)


def run_bert(x_train_val, x_test, y_train_val, y_test, label_list):
    """Fine-tune BERT, evaluate on a held-out split, then predict the test set.

    The first ``TRAIN_VAL_RATIO`` fraction of ``x_train_val``/``y_train_val``
    is used for fine-tuning and the rest for evaluation. The split is
    positional, so the data should already be shuffled.

    Args:
        x_train_val: Sliceable, sized sequence of training+validation texts
            (e.g. a pandas Series).
        x_test: Sequence of test texts.
        y_train_val: Labels aligned with ``x_train_val``.
        y_test: Labels aligned with ``x_test``.
        label_list: All label values; classifier output index ``i`` is mapped
            back to ``label_list[i]``.

    Returns:
        A list of booleans, one per test example, ``True`` where the predicted
        label equals the true label.

    Raises:
        ValueError: If ``MAX_SEQ_LENGTH`` exceeds the model's
            ``max_position_embeddings``.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    size_train = int(TRAIN_VAL_RATIO * len(x_train_val))
    x_train, x_val = x_train_val[:size_train], x_train_val[size_train:]
    y_train, y_val = y_train_val[:size_train], y_train_val[size_train:]

    tokenization.validate_case_matches_checkpoint(DO_LOWER_CASE,
                                                  BERT_INIT_CHKPNT)
    bert_config = modeling.BertConfig.from_json_file(BERT_CONFIG)
    if MAX_SEQ_LENGTH > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length %d because the BERT model "
            "was only trained up to sequence length %d" %
            (MAX_SEQ_LENGTH, bert_config.max_position_embeddings))

    tokenizer = tokenization.FullTokenizer(
        vocab_file=BERT_VOCAB, do_lower_case=DO_LOWER_CASE)
    run_config = tf.estimator.RunConfig(
        model_dir=OUTPUT_DIR,
        save_summary_steps=SAVE_SUMMARY_STEPS,
        keep_checkpoint_max=1,
        save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS)

    # ---- Train ----
    train_examples = create_examples(x_train, y_train)
    # Estimator.train rejects max_steps <= 0, which a dataset smaller than
    # one batch would otherwise produce.
    num_train_steps = max(
        1, int(len(train_examples) / TRAIN_BATCH_SIZE * NUM_TRAIN_EPOCHS))
    num_warmup_steps = int(num_train_steps * WARMUP_PROPORTION)
    model_fn = model_fn_builder(
        bert_config=bert_config,
        num_labels=len(label_list),
        init_checkpoint=BERT_INIT_CHKPNT,
        learning_rate=LEARNING_RATE,
        num_train_steps=num_train_steps,
        num_warmup_steps=num_warmup_steps,
        use_tpu=False,
        use_one_hot_embeddings=False)
    # A plain (non-TPU) Estimator, so this runs on CPU or GPU. Its params are
    # shared by train/evaluate/predict, so EVAL_BATCH_SIZE and
    # PREDICT_BATCH_SIZE are not used here.
    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        config=run_config,
        params={"batch_size": TRAIN_BATCH_SIZE})

    # Estimator.train(max_steps=N) does no work if model_dir already holds a
    # checkpoint at global step >= N: the model fine-tuned by an earlier run
    # is silently reused, even if the data or label order changed since.
    if tf.train.latest_checkpoint(OUTPUT_DIR):
        tf.logging.warning(
            "Found an existing checkpoint in %s; training stops at global "
            "step %d and may be skipped entirely. Delete the directory to "
            "fine-tune again from the pretrained BERT checkpoint.",
            OUTPUT_DIR, num_train_steps)

    train_input_fn = _build_input_fn(
        train_examples, label_list, tokenizer, "train.tf_record",
        is_training=True)
    estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)

    # ---- Evaluation ----
    eval_examples = create_examples(x_val, y_val)
    eval_input_fn = _build_input_fn(
        eval_examples, label_list, tokenizer, "eval.tf_record",
        is_training=False)
    tf.logging.info("***** Running evaluation *****")
    # steps=None evaluates until the input is exhausted, i.e. the whole set.
    result = estimator.evaluate(input_fn=eval_input_fn, steps=None)
    output_eval_file = os.path.join(OUTPUT_DIR, "eval_results.txt")
    with tf.gfile.GFile(output_eval_file, "w") as writer:
        tf.logging.info("***** Eval results *****")
        for key in sorted(result.keys()):
            tf.logging.info(" %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))

    # ---- Testing ----
    predict_examples = create_examples(x_test, y_test, False)
    predict_input_fn = _build_input_fn(
        predict_examples, label_list, tokenizer, "predict.tf_record",
        is_training=False)
    tf.logging.info("***** Running prediction*****")
    predictions = estimator.predict(input_fn=predict_input_fn)
    tf.logging.info("***** Predict results *****")

    correct_labeling = []
    for text, prediction, true_label in zip(x_test, predictions, y_test):
        predicted_label = label_list[int(np.argmax(prediction["probabilities"]))]
        tf.logging.info("{0} orig label {1} predicted {2}".format(
            text, true_label, predicted_label))
        correct_labeling.append(predicted_label == true_label)

    if correct_labeling:
        accuracy = sum(correct_labeling) / len(correct_labeling)
        print("test accuracy is {0}".format(100 * accuracy))
    else:
        tf.logging.warning("No test examples; test accuracy is undefined.")
    return correct_labeling
if __name__ == "__main__":
input_csv = pd.read_csv('stack-overflow-data.csv')
input_csv = input_csv.dropna()
tweets = input_csv['post']
labels = []
print("loading DB")
map_index_to_label = list(set(input_csv['tags']))
map_label_to_index = {l:i for i,l in enumerate(map_index_to_label)}
for label in input_csv['tags']:
labels.append(map_label_to_index[label])
num_labels = len(map_index_to_label)
print("splitting")
X_train, X_test, y_train, y_test = train_test_split(tweets, labels, test_size=0.33,
random_state=42)
bert_test_res = run_bert(X_train, X_test, y_train, y_test, list(set(labels)))
Есть ли идеи, какие параметры в приведённой выше функции run_bert() стоит настроить, чтобы получить улучшение?