Я хочу построить набор данных тензорного потока из tfrecords.это мой код:
def make_dataset():
filenames = [train_tfrecords_dir + name for name in os.listdir(train_tfrecords_dir)]
dataset = tf.data.TFRecordDataset(filenames)
def parser(record):
keys_to_features = {
"mhot_label_raw": tf.FixedLenFeature((), tf.string, default_value=""),
"mel_spec_raw": tf.FixedLenFeature((), tf.string, default_value=""),
}
parsed = tf.parse_single_example(record, keys_to_features)
mel_spec1d = tf.decode_raw(parsed['mel_spec_raw'], tf.float64)
mhot_label = tf.decode_raw(parsed['mhot_label_raw'], tf.float64)
mel_spec = tf.reshape(mel_spec1d, [30, 65,85])
return {"mel_data": mel_spec}, mhot_label
dataset = dataset.map(parser)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
return iterator
но это вызывает эту ошибку:
InvalidArgumentError: Input to DecodeRaw has length 165750 that is not a multiple of 8, the size of double
[[Node: DecodeRaw = DecodeRaw[little_endian=true, out_type=DT_DOUBLE](ParseSingleExample/Squeeze_mel_spec_raw)]]
[[Node: IteratorGetNext = IteratorGetNext[output_shapes=[[?,30,65,85], [?,?]], output_types=[DT_DOUBLE, DT_DOUBLE], _device="/job:localhost/replica:0/task:0/device:CPU:0"](OneShotIterator)]]
Как я могу это исправить?Я удалил tf.decode_raw , но он не работал