Как объединить два набора данных Tensorflow? - PullRequest
0 голосов
/ 26 апреля 2020

Я пытаюсь загрузить, а затем увеличить некоторый набор данных изображений (160 x 160 x 3), в котором изображения хранятся в папке, а имя папки для меня является меткой. Для создания копии данных применяются несколько преобразований, и они должны быть concatenated (or stacked may be), чтобы объединить данные и сохранить их на диск.

Ниже приведен простейший воспроизводимый фрагмент, который я могу написать, и я не может append/concatenate/stack два набора данных.

def some_transformation(image, label):
    # do something like rotation, clipping, noise add etc.
    return image, label

userA = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 160, 160, 3))))
userA_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userA_with_labels = tf.data.Dataset.zip((userA, userA_label))
transformed_userA_w_label = userA_with_labels.map(some_transformation)

userB = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 160, 160, 3))))
userB_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userB_with_labels = tf.data.Dataset.zip((userB, userB_label))
transformed_userB_w_label = userB_with_labels.map(some_transformation)

print('User A - {}'.format(transformed_userA_w_label))
print('User B - {}'.format(transformed_userB_w_label))
transformed_userA_w_label.concatenate(transformed_userB_w_label)

Вывод операторов печати выглядит следующим образом:

User A - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
User B - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Output ds - <ConcatenateDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>

Ожидается: 6 изображение

Output ds - <ConcatenateDataset shapes: ((6, 160, 160, 3), (6, 2)), types: (tf.float64, tf.float64)>

1 Ответ

1 голос
/ 27 апреля 2020

Ключевой проблемой здесь является использование tf.data.Dataset.from_tensors против tf.data.Dataset.from_tensor_slices.

  • tf.data.Dataset.from_tensors([t1,t2,t3]) - Создает набор данных, где каждый элемент списка задан как точка данных
  • tf.data.Dataset.from_tensor_slices(t) - Создает набор данных, где один элемент - это один элемент, проиндексированный в первая ось

Из имеющихся у вас данных (т.е. 3 изображения размером 160x160x3, т.е. 3x160x160x3) вам необходимо использовать второй метод. В противном случае все ваши 3 изображения воспринимаются как одна точка данных (что, вероятно, не то, что вы хотите).

Если не считать второй проблемы, вывод, который вы показываете,

User A - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
User B - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Output ds - <ConcatenateDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>

Это просто показывая, как выглядит один элемент. Таким образом, вы не увидите 6 так, как хотели бы, даже если у вас был правильный код. Чтобы увидеть количество элементов, вы должны выполнить итерацию набора данных. В вашем случае вы увидите 2 (так как этот набор данных рассматривает все 3 изображения как одну точку данных).

Итак, чтобы исправить свой код, сделайте это,

def some_transformation(image, label):
    # do something like rotation, clipping, noise add etc.
    return image, label

userA = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userA_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userA_with_labels = tf.data.Dataset.zip((userA, userA_label))
transformed_userA_w_label = userA_with_labels.map(some_transformation)

userB = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userB_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userB_with_labels = tf.data.Dataset.zip((userB, userB_label))
transformed_userB_w_label = userB_with_labels.map(some_transformation)

print('User A - {}'.format(transformed_userA_w_label))
print('User B - {}'.format(transformed_userB_w_label))
concat_ds = transformed_userA_w_label.concatenate(transformed_userB_w_label)
print(concat_ds)

for i,ii in enumerate(concat_ds):
  print(i)

Вы увидите значение i, напечатанное 6 раз. Что вам нужно.

...