Ключевой проблемой здесь является использование tf.data.Dataset.from_tensors
против tf.data.Dataset.from_tensor_slices
.
tf.data.Dataset.from_tensors([t1,t2,t3])
- Создает набор данных, где каждый элемент списка задан как точка данных tf.data.Dataset.from_tensor_slices(t)
- Создает набор данных, где один элемент - это один элемент, проиндексированный в первая ось
Из имеющихся у вас данных (т.е. 3 изображения размером 160x160x3, т.е. 3x160x160x3
) вам необходимо использовать второй метод. В противном случае все ваши 3 изображения воспринимаются как одна точка данных (что, вероятно, не то, что вы хотите).
Если не считать второй проблемы, вывод, который вы показываете,
User A - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
User B - <MapDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Output ds - <ConcatenateDataset shapes: ((3, 160, 160, 3), (3, 2)), types: (tf.float64, tf.float64)>
Это просто показывая, как выглядит один элемент. Таким образом, вы не увидите 6
так, как хотели бы, даже если у вас был правильный код. Чтобы увидеть количество элементов, вы должны выполнить итерацию набора данных. В вашем случае вы увидите 2
(так как этот набор данных рассматривает все 3 изображения как одну точку данных).
Итак, чтобы исправить свой код, сделайте это,
def some_transformation(image, label):
# do something like rotation, clipping, noise add etc.
return image, label
userA = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userA_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userA_with_labels = tf.data.Dataset.zip((userA, userA_label))
transformed_userA_w_label = userA_with_labels.map(some_transformation)
userB = tf.data.Dataset.from_tensor_slices(tf.constant(np.zeros((3, 160, 160, 3))))
userB_label = tf.data.Dataset.from_tensors(tf.constant(np.zeros((3, 2))))
userB_with_labels = tf.data.Dataset.zip((userB, userB_label))
transformed_userB_w_label = userB_with_labels.map(some_transformation)
print('User A - {}'.format(transformed_userA_w_label))
print('User B - {}'.format(transformed_userB_w_label))
concat_ds = transformed_userA_w_label.concatenate(transformed_userB_w_label)
print(concat_ds)
for i,ii in enumerate(concat_ds):
print(i)
Вы увидите значение i
, напечатанное 6 раз. Что вам нужно.