Для задачи классификации аудио я нарезал некоторые аудиофайлы на куски фиксированного размера и сериализовал (см. data2serialized ниже) данные, метки, время начала и окончания и имя файла аудио в TFRecords сделать обучающие примеры.
Из сгенерированных TFRecords я делаю наборов данных для подачи tf.keras.models.Model.fit .
При выполнении прогноза мне нужно получить значение имя файла сериализованных данных, чтобы объединить результаты всех примеров из заданных аудиофайлов, но tf.keras.models.Model.predict принимает только входные данные, и я не вижу, как получить прогнозы + имена файлов в качестве выходных.
Я начал читать документацию для tf.estimator.Estimator но я все еще не вижу, как передать дополнительные данные, которые не являются ни входными данными, ни целями, через конвейер прогнозирования ...
Любое предложение?
def data2serialized(filename, start_time, end_time, data, labels):
feature = {
'filename': _bytes_feature([filename.encode()]),
'times': _float_feature([start_time, end_time]),
'data': _float_feature(data.flatten()),
'labels': _bytes_feature(["#".join(str(l) for l in labels).encode()]),
}
example = tf.train.Example(features=tf.train.Features(feature=feature))
return example.SerializeToString()
def serialized2data(serialized_data, feature_shape, class_list, nolabel_warning=True):
"""Generate features and labels.
Labels are indices of original label in class_list.
"""
features = {
'filename': tf.FixedLenFeature([], tf.string),
'times': tf.FixedLenFeature([2], tf.float32),
'data': tf.FixedLenFeature(feature_shape, tf.float32),
'labels': tf.FixedLenFeature([], tf.string),
}
example = tf.parse_single_example(serialized_data, features)
# reshape data to channels_first format
data = tf.reshape(example['data'], (1, feature_shape[0], feature_shape[1]))
# one-hot encode labels
labels = tf.strings.to_number(
tf.string_split([example['labels']], '#').values,
out_type=tf.int32
)
# get intersection of class_list and labels
labels = tf.squeeze(
tf.sparse.to_dense(
tf.sets.intersection(
tf.expand_dims(labels, axis=0),
tf.expand_dims(class_list, axis=0)
)
),
axis=0
)
# sort class_list and get indices of labels in class_list
class_list = tf.sort(class_list)
labels = tf.where(
tf.equal(
tf.expand_dims(labels, axis=1),
class_list)
)[:,1]
tf.cond(
tf.math.logical_and(nolabel_warning, tf.equal(tf.size(labels), 0)),
true_fn=lambda:myprint(tf.strings.format('File {} has no label', example['filename'])),
false_fn=lambda:1
)
one_hot = tf.cond(
tf.equal(tf.size(labels), 0),
true_fn=lambda: tf.zeros(tf.size(class_list)),
false_fn=lambda: tf.reduce_max(tf.one_hot(labels, tf.size(class_list)), 0)
)
return (data, one_hot)
def filelist2dataset(files, example_shape, class_list, training=True, batch_size=32, nolabel_warning=True):
files = tf.convert_to_tensor(files, dtype=dtypes.string)
files = tf.data.Dataset.from_tensor_slices(files)
# dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x).prefetch(100), cycle_length=8)
dataset = files.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=8)
dataset = dataset.map(lambda x: serialized2data(x, example_shape, class_list, nolabel_warning))
if training:
dataset = dataset.shuffle(10000)
dataset = dataset.repeat() # Repeat the input indefinitely.
dataset = dataset.batch(batch_size, drop_remainder=True)
return dataset