Функция TFRecord имеет неправильное значение - PullRequest
0 голосов
/ 09 марта 2019

Я пытаюсь обучить некоторые вложения и помещаю свой набор данных в форму записи. Когда я пишу один пример в файл примерно так:

tf_features = {
        'given': int64_feature(given),
        'context': bytes_feature(np.array(context).tostring())
}
writer.write(tf.train.Example(features=tf.train.Features(feature=tf_features)).SerializeToString())

, где int64_feature и bytes_feature определены как:

def bytes_feature(val):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[val]))

def int64_feature(val):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[val]))

и я распечатываю пример (данный, контекстный) пары, я получаю что-то вроде: (698, [686, 439, 464, 775]), что хорошо.

Однако, когда я пытаюсь прочитать из того же файла, вот так:

def parse_example(w, tf_example):
    feats_dict = {
        'given': tf.FixedLenFeature([], tf.int64, default_value=0),
        'context': tf.FixedLenFeature([], tf.string)
    }
    features = tf.parse_single_example(tf_example, feats_dict)
    context = tf.decode_raw(features['context'], tf.int64)

    context_feats = dict()
    ctx_idx = 0
    for i in range(w):
        if i == w//2: continue
        context_feats['context%d' % ctx_idx] = context[ctx_idx]
        ctx_idx += 1

    return context_feats, features['given']

dataset = tf.data.TFRecordDataset([fname]).map(partial(parse_example, 5))
iterator = dataset.make_one_shot_iterator()

with tf.Session() as sess:
    iter_features, iter_labels = iterator.get_next()
    features = sess.run(iter_features)
    labels = sess.run(iter_labels)
    print(features, labels)

Для той же пары контекста, что и раньше, я получаю (464, [686, 439, 464, 775]). Данная метка всегда является третьей от меток контекста.

Я часами смотрю на этот код, но я в тупике. Кто-нибудь знает, что происходит?

1 Ответ

0 голосов
/ 09 марта 2019

Я думаю, что понял, что происходит, и это довольно глупая ошибка. В следующих строках:

iter_features, iter_labels = iterator.get_next()
features = sess.run(iter_features)
labels = sess.run(iter_labels)

Я запускаю sess.run дважды, и из-за поведения итератора, когда я получал функции, он возвращал правильные функции, но когда я получал метки, он возвращал метки в следующем примере. .

Имеет смысл, чтобы полученная мной метка всегда была третьей в контексте из-за скользящего окна, используемого для получения пар заданного контекста.

Я изменил вышеупомянутые строки на:

iter_ex = iterator.get_next()
ex = sess.run(iter_ex)
print(ex)

И все работает как положено.

...