Мне нужно эффективно объединить небольшое количество данных при обучении модели TensorFlow на TFRecord
с.Как я могу сделать этот поиск, используя информацию из проанализированного TFRecord
?
Подробнее:
Я обучаю сверточную сеть на большом наборе данных, используя TFRecords
.Каждый TFRecord
содержит необработанное изображение вместе с меткой цели и некоторые метаданные об изображении.Часть обучения состоит в том, что мне нужно стандартизировать изображение, используя mean
и std
, которые являются специфическими для группировки изображений.Чтобы сделать это в прошлом, я жестко закодировал mean
и std
в TFRecord
.Затем он используется примерно так в моем parse_example
, который используется для отображения на Dataset
в моем input_fn
, например так:
def parse_example(..):
# ...
parsed = tf.parse_single_example(value, keys_to_features)
image_raw = tf.decode_raw(parsed['image/raw'], tf.uint16)
image = tf.reshape(image_raw, image_shape)
image.set_shape(image_shape)
# pull hardcoded pixels mean and std from the parsed TFExample
mean = parsed['mean']
std = parsed['std']
image = (tf.cast(image, tf.float32) - mean) / std
# ...
return image, label
Хотя вышеприведенное работает и ускоряет обучениеявляется ограничением в том, что я часто хочу изменить то, что mean
и std
я использую.Вместо того, чтобы записывать mean
и std
в TFRecord
s, я бы предпочел просмотреть соответствующую сводную статистику во время обучения.Это означает, что когда я тренируюсь, у меня есть небольшой словарь Python, который я могу найти в соответствующей сводной статистике, используя информацию об изображении, которая анализируется из TFRecord
.Проблема, с которой я сталкиваюсь, заключается в том, что я не могу использовать этот словарь python в своем графе тензорного потока.Если я попытаюсь выполнить поиск напрямую, это не сработает, потому что у меня есть тензорные объекты вместо реальных примитивов.Это имеет смысл, поскольку input_fn
выполняет символические манипуляции при построении графа вычислений для TensorFlow (верно?).Как мне обойти это?
Одна вещь, которую я попробовал, состоит в том, чтобы создать таблицу поиска из словаря, например, так:
def create_channel_hashtable(keys, values, default_val=-1):
initializer = tf.contrib.lookup.KeyValueTensorInitializer(keys, values)
return tf.contrib.lookup.HashTable(initializer, default_val)
Хеш-таблицы можно создавать и использовать в parse_example
функция для поиска.Это все «работает», но это слишком сильно замедляет обучениеВозможно, стоит отметить, что это обучение проводится на TPU.С оригинальным подходом использования значений из TFRecord
s обучение очень быстрое и не ограничивается вводом-выводом, однако это меняется, когда используется поиск по хешу.Каков предлагаемый способ обработки этих случаев?Хотя переупаковка TFRecord
s выполнима, она кажется глупой, когда данные, которые нужно искать, малы и могут быть эффективными.