Это мой код для загрузки данных из tfrecord:
def read_tfrecord(tfrecord, epochs, batch_size):
dataset = tf.data.TFRecordDataset(tfrecord)
def parse(record):
features = {
"image": tf.io.FixedLenFeature([], tf.string),
"target": tf.io.FixedLenFeature([], tf.int64)
}
example = tf.io.parse_single_example(record, features)
image = decode_image(example["image"])
label = tf.cast(example["target"], tf.int32)
return image, label
dataset = dataset.map(parse)
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.prefetch(buffer_size=batch_size) #
dataset = dataset.batch(batch_size, drop_remainder=True)
dataset = dataset.repeat(epochs)
return dataset
x_train, y_train = read_tfrecord(tfrecord=train_files, epochs=EPOCHS, batch_size=BATCH_SIZE)
У меня следующая ошибка:
ValueError: too many values to unpack (expected 2)
Мой вопрос:
Как распаковать данные из набора данных?