Создание дополненных обучающих данных с помощью вращения тензорного потока - PullRequest
0 голосов
/ 03 мая 2018

Недавно начали с tensorflow и cnn, и я надеюсь натренировать простую сеть для поворота объектов вверх.

У меня есть набор данных 1k изображений, ориентированных вверх и использующих tensorflow.contrib.image.rotate Я бы хотел повернуть их со случайными углами. Что-то в пределах RotNet , но с tensorflow вместо keras.

Идея заключается в создании N повернутых обучающих примеров из каждого набора данных 1k изображений. Каждое из изображений имеет форму 30x30x1 (черно-белое).

with tf.Session() as sess:
    for curr in range(oriented_data.shape[0]):
        curr_image = loaded_oriented_data[curr]
        for i in range(augment_each_image):
            rotation_angle = np.random.randint(360)
            rotated_image = tfci.rotate(curr_image, np.float(rotation_angle) * math.pi/180.)
            training_data[curr + i] = sess.run(rotated_image)
            labels[curr + i] = rotation_angle

Теперь проблема в том, что выполнение строки sess.run(rotated_image) занимает очень много времени. например, создание только 5 примеров для каждого из 1К выполняется более 30 минут (на процессоре).
Если я просто удалю эту строку, изображения будут сгенерированы за минуту.

Я полагаю, что есть способ хранить и работать с данными в качестве тензоров вместо преобразования их обратно в ndarrays, как я делал до сих пор, или, возможно, есть более быстрая функция для оценки тензоров?

1 Ответ

0 голосов
/ 04 мая 2018

Проблема в том, что вы создаете оператор поворота для каждого изображения в augment_each_image, создавая потенциально очень большую сеть.

Решение состоит в том, чтобы создать одну операцию поворота, которую вы последовательно применяете к своим изображениям. Что-то в этом роде:

im_ph = tf.placeholder(...)
ang_ph = tf.placeholder(...)
rot_op = tfci.rotate(im_ph, ang_ph)

with tf.Session() as sess:
  for curr in range(oriented_data.shape[0]):
    curr_image = loaded_oriented_data[curr]
      for i in range(augment_each_image):
        rotation_angle = np.random.randint(360)
        rotated_image = sess.run(rot_op, {im_ph: curr_image, ang_ph: np.float(rotation_angle) * math.pi/180.})
        training_data[curr + i] = rotated_image
        labels[curr + i] = rotation_angle
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...