Перелистывание меток в Tensorflow с помощью функции карты tf.data - PullRequest
0 голосов
/ 19 сентября 2019

Я обучаю модель tensorflow, которая выводит 3 метки: прямую, левую, правую на основе ввода изображения.

Я использую набор данных tf.data и хочу дополнить данные, перевернувобраз.Проблема в том, что если я переворачиваю изображение, мне также нужно перевернуть метку (так как левая становится правой при переворачивании изображения).

Я реализовал функцию карты и попытался перевернуть метки, используя tf.собрать (было предложено здесь Как переворачивать помеченные направления по горизонтали в тензорном потоке )

augmented2 = trainDataset.map(
    lambda x, y: flip(x, y), num_parallel_calls=10)


def flip(x: tf.Tensor, y) -> tf.Tensor:
    y = tf.gather( y,[0, 2, 1])
    x = flip_left_right(x)
    return x, y

Проблема заключается в том, что во время обучения возникает ошибка:

тензор потока.python.framework.errors_impl.InvalidArgumentError: логиты и метки должны быть транслируемыми: logits_size = [32,3] label_size = [3,3]

Что я делаю не так?Является ли лучший / более простой способ перевернуть этикетки?

...