Проблема InfeedEnqueueTuple при попытке восстановить обновленную контрольную точку модели BERT с помощью облачного TPU - PullRequest
0 голосов
/ 16 ноября 2018

Буду признателен за любую помощь ниже, спасибо заранее. Я сделал копию блокнота Google Bert по тонкой настройке и обучил его на наборе данных SQUAD, используя Cloud TPU и Bucket. Предсказания для набора dev в порядке, поэтому я скачал файлы checkpoint, model.ckpt.meta, model.ckpt.index и model.ckpt.data локально и попытался восстановить с помощью кода:

sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
saver = tf.train.import_meta_graph(META_FILE) # META_FILE being path to .meta
saver.restore(sess, 'model.ckpt')

Однако я получил ошибку:

    op_def = op_dict[node.op]
KeyError: 'InfeedEnqueueTuple'

Я предполагаю, что это часть Cloud TPU Tools , и я должен продолжить работу с Cloud TPU, поэтому я попробовал следующее ( reference ):

# code from cells before includes
...
tf.contrib.cloud.configure_gcs(session, credentials=auth_info)
...
tpu_cluster_resolver = tf.contrib.cluster_resolver.TPUClusterResolver(TPU_ADDRESS)
run_config = tf.contrib.tpu.RunConfig(
    cluster=tpu_cluster_resolver,
    model_dir=OUTPUT_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    tpu_config=tf.contrib.tpu.TPUConfig(
        iterations_per_loop=ITERATIONS_PER_LOOP,
        num_shards=NUM_TPU_CORES,
        per_host_input_for_training=tf.contrib.tpu.InputPipelineConfig.PER_HOST_V2))
...

Проблемная ячейка:

"""
# not valid checkpoint error. <bucket> placeholder for cloud bucket name
sess = tf.Session()
META_FILE = "gs://<bucket>/bert/models/bertsquad/model.ckpt-10949.meta"
CKPT_FILE = "gs://<bucket>/bert/models/bertsquad/model.ckpt"
saver = tf.train.import_meta_graph(META_FILE)
saver.restore(sess, CKPT_FILE)
"""

from google.cloud import storage
from tensorflow import MetaGraphDef

client = storage.Client(project="agent-helper-4a014")
bucket = client.get_bucket(<bucket>)
metafile = "bert/models/bertsquad/model.ckpt-10949.meta"
# using full path gs://<bucket>/bert/models/bertsquad doesn't work

blob = bucket.get_blob(metafile)
#blob = bucket.blob(metafile)
#model_graph = blob.download_to_filename("model.ckpt")
model_graph = blob.download_as_string()

mgd = MetaGraphDef()
mgd.ParseFromString(model_graph)

with tf.Session() as sess:
    saver = tf.train.import_meta_graph(mgd, clear_devices=True)
    init_checkpoint = saver.restore(sess, 'model.ckpt')

Это, в свою очередь, дало следующую ошибку:

InvalidArgumentError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a mismatch between the current graph and the graph from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

No OpKernel was registered to support Op 'InfeedEnqueueTuple' with these attrs.  Registered devices: [CPU,XLA_CPU], Registered kernels:
  <no registered kernels>

     [[node input_pipeline_task0/while/InfeedQueue/enqueue/0 (defined at <ipython-input-67-e4b52b7b5944>:21)  = InfeedEnqueueTuple[_class=["loc:@input_pipeline_task0/while/IteratorGetNext"], device_ordinal=0, dtypes=[DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32, DT_INT32], shapes=[[2], [2,384], [2,384], [2,384], [2], [2]], _device="/job:worker/task:0/device:CPU:0"](input_pipeline_task0/while/IteratorGetNext, input_pipeline_task0/while/IteratorGetNext:1, input_pipeline_task0/while/IteratorGetNext:2, input_pipeline_task0/while/IteratorGetNext:3, input_pipeline_task0/while/IteratorGetNext:4, input_pipeline_task0/while/IteratorGetNext:5)]]

1 Ответ

0 голосов
/ 20 ноября 2018

Если ваш мотив - предсказание, просто укажите местоположение model_dir (должно быть поле GCS), где сохраняются контрольные точки и метафайл. Код больше не будет использоваться для обучения (поскольку контрольная точка сохраняется для количества этапов обучения, а график модели не изменяется). Он напрямую перейдет к прогнозу.

Но если ваш вариант использования действительно хочет сохранить контрольные точки и восстановить его только для логического вывода, выполните следующие действия:

  • Создайте модель сети для каждого слоя вручную, как для исходной модели, или используйте сохраненный файл .meta для воссоздания сети, используя функцию tf.train.import(), подобную этой:

saver = tf.train.import_meta_graph('<filename>.meta')

  • Теперь восстановите контрольные точки, используя: saver.restore(sess, 'model.ckpt')

ПРИМЕЧАНИЕ: График модели, на котором восстанавливаются контрольные точки, должен быть точно таким же, как и исходный график, для которого сохранены эти контрольные точки.

Надеюсь, это решит вашу проблему.

...