Я пишу две модели в кератах (с бэкэндом тензорного потока), которые требуют перевода изображений несколько раз.Теперь, после некоторых поисков, я нашел метод, который работает.В основном все сводится к использованию tf.map_fn для отображения tf.contrib.image.translate на пакеты изображений, как показано в следующем коде:
import tensorflow as tf
import tensorflow.keras as keras
def sample(inputs):
""" Creates a random direction for every image in the batch."""
batch_size = tf.shape(inputs)[0]
eps = tf.random.normal((batch_size, 2))
return eps
def translate_single_set(args):
"""
Translates a single fixed set of images. Used as callable in tf.map_fn.
args should be [images, translations]
"""
assert isinstance(args, list)
assert len(args) == 2
return tf.contrib.image.translate(args[0], args[1])
translator_inputs = keras.layers.Input(shape=(28, 28))
shift = keras.layers.Lambda(sample)(translator_inputs)
translated = keras.layers.Lambda(
lambda args: tf.map_fn(translate_single_set, args, dtype=tf.float32)
)([translator_inputs, shift])
translator = keras.Model(inputs=translator_inputs, outputs=translated)
# use this on mnist
from tensorflow.keras.datasets import mnist
(original_train, _), (original_test, _) = mnist.load_data()
original_train = original_train.astype('float32')/255
original_test = original_test.astype('float32')/255
binary_train = np.round(original_train)
binary_test = np.round(original_test)
# Shuffle the data
np.random.shuffle(binary_train)
np.random.shuffle(binary_test)
translator.predict(binary_train, batch_size=100)
Проблема в том, что это очень медленно.При использовании только процессора, выполнение последней строки занимает несколько секунд, а в gpu - десятки секунд.Поскольку у моих моделей много тренировочных весов, я действительно хочу использовать графический процессор.Однако моя первая модель становится в десять-двадцать раз медленнее при переводе изображений по сравнению с непереводом (на GPU, на процессоре она становится примерно в пять-семь раз медленнее).Вторая модель еще не закончена, но требует в 25 раз больше переводов, чем первая.
Существует ли способ более эффективного перевода (пакетов) изображений?И правильно ли я считаю, что основная горлышко бутылки заключается в использовании tf.map_fn?(Я не знаю как это проверить)