TensorFlow: извлечение данных с заданной функцией из набора данных NSynth - PullRequest
0 голосов
/ 14 ноября 2018

У меня есть набор данных файлов TFRecord сериализованных буферов протокола TensorFlow с одним примером Прото на заметку, загруженный из https://magenta.tensorflow.org/datasets/nsynth. Я использую набор тестов, который составляет приблизительно 1 Гб, на случай, если кто-то захочет скачайте его, чтобы проверить код ниже. Каждый пример содержит много функций: высота, инструмент ...

Код, который читает в этих данных:

import tensorflow as tf
import numpy as np

sess = tf.InteractiveSession()

# Reading input data
dataset = tf.data.TFRecordDataset('../data/nsynth-test.tfrecord')

# Convert features into tensors
features = {
"pitch": tf.FixedLenFeature([1], dtype=tf.int64),
"audio": tf.FixedLenFeature([64000], dtype=tf.float32),
"instrument_family": tf.FixedLenFeature([1], dtype=tf.int64)}

parse_function = lambda example_proto: tf.parse_single_example(example_proto,features)
dataset = dataset.map(parse_function)

# Consuming TFRecord data.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.batch(batch_size=3)
dataset = dataset.repeat()
iterator = dataset.make_one_shot_iterator()
batch = iterator.get_next()
sess.run(batch)

Теперь высота звука варьируется от 21 до 108. Но я хочу рассмотреть только данные определенного шага, например, pitch = 51. Как извлечь это подмножество "pitch = 51" из всего набора данных? Или, в качестве альтернативы, что мне сделать, чтобы мой итератор проходил только через это подмножество?

1 Ответ

0 голосов
/ 27 ноября 2018

То, что у вас есть, выглядит довольно хорошо, все, что вам не хватает, это функция фильтра.

Например, если вы хотите извлечь только pitch = 51, вы должны добавить после своей карты функцию

dataset = dataset.filter(lambda example: tf.equal(example["pitch"][0], 51))
...