Как использовать обученные контрольные точки модели BERT для прогнозирования? - PullRequest
2 голосов
/ 28 июня 2019

Я обучил BERT с помощью SQUAD 2.0 и получил model.ckpt.data, model.ckpt.meta.model.ckpt.index (F1 балл: 81) в выходном каталоге вместе с Foretions.json и т. д., используя BERT-master / run_squad.py

python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt \
  --do_train=True \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

Я пыталсяскопируйте файл model.ckpt.meta, model.ckpt.index, model.ckpt.data в каталог $ BERT_LARGE_DIR и измените флаги run_squad.py следующим образом, чтобы только предсказать ответ, а не обучать с использованием набора данных:

python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/model.ckpt \
  --do_train=False \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

выдает ведро каталог / model.ckpt не существует ошибки.

Как использовать контрольные точки, сгенерированные после обучения, и использовать их для прогнозирования?

Ответы [ 2 ]

0 голосов
/ 29 июня 2019

Обычно обученные контрольные точки создаются в каталоге, указанном параметром --output_dir во время обучения.(В вашем случае это gs: // some_bucket / squad_large /).Каждый контрольно-пропускной пункт будет иметь номер.Вы должны определить наибольшее число;пример: model.ckpt-12345.Теперь установите параметр --init_checkpoint в вашей оценке / прогнозе, используя выходной каталог и последнюю сохраненную контрольную точку (модель с наибольшим числом).(В вашем случае это будет что-то вроде --init_checkpoint=gs://some_bucket/squad_large/model.ckpt-<highest number>)

0 голосов
/ 28 июня 2019

Во втором коде ФЛАГ init_checkpoint Я думаю, что это должно быть:

--init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt

как в приведенном выше, а не --init_checkpoint=$BERT_LARGE_DIR/model.ckpt.

Если проблема не устранена, используете ли вы multi_cased_L-12_H-768_A-12 предварительно обученные модели?

...