Я обучаю модель tensorflow
, которая выводит 3 метки: прямую, левую, правую на основе ввода изображения.
Я использую набор данных tf.data и хочу дополнить данные, перевернувобраз.Проблема в том, что если я переворачиваю изображение, мне также нужно перевернуть метку (так как левая становится правой при переворачивании изображения).
Я реализовал функцию карты и попытался перевернуть метки, используя tf.собрать (было предложено здесь Как переворачивать помеченные направления по горизонтали в тензорном потоке )
augmented2 = trainDataset.map(
lambda x, y: flip(x, y), num_parallel_calls=10)
def flip(x: tf.Tensor, y) -> tf.Tensor:
y = tf.gather( y,[0, 2, 1])
x = flip_left_right(x)
return x, y
Проблема заключается в том, что во время обучения возникает ошибка:
тензор потока.python.framework.errors_impl.InvalidArgumentError: логиты и метки должны быть транслируемыми: logits_size = [32,3] label_size = [3,3]
Что я делаю не так?Является ли лучший / более простой способ перевернуть этикетки?