Мне нужно написать сеть Keras (tf версия: 1.15.0), которая получает выходные данные другой сети и применяет значение выходных данных к изображению в качестве функции преобразования. В качестве примера рассмотрим, что предыдущая сеть генерирует значение r=0.4
, затем в этой сети я хочу применить вертикальный переворот на основе значения r
. Вот код:
def vertical_flip():
def VF(img, r):
img_vf = tf.image.flip_up_down(img)
out = K.switch(tf.math.greater_equal(r, 0.5), img_vf, img)
return img_vf
return VF
После применения преобразования полученное изображение будет передано в следующую сеть. Все эти три сети затем объединяются в одну сеть и будут обучаться по методу train_on_batch
. Код выше не имеет обучаемых / не обучаемых параметров, но находится внутри сети, которую необходимо обучить. При попытке обучить сеть, я получаю эту ошибку:
ValueError: Variable <tf.Variable 'conv2d_1/kernel:0' shape=(3, 3, 3, 64) dtype=float32> has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: K.argmax, K.round, K.eval.
Хотя упомянутая переменная находится внутри первой сети, похоже, что K.switch
выдает ошибку. Мой вопрос состоит в том, как обучить мою общую сеть, учитывая тот факт, что мне нужно реализовать функцию переключения?