Предположим, у меня есть пакет индексов
batch_idx = tf.constant([0, 1, 0, 1, 0])
и список функций
list_fn = [fn1, fn2]
Теперь на графике tf я хочу выбрать fn1
, если idx равен 0
и fn2
в противном случае. Это должно иметь
y[i] = list_fn[batch_idx[i]](x[i])
для каждого i=0,1,...,4
. Примечание. x
и y
имеют тот же размер партии, что и batch_idx
.
Я знаю, как это сделать, если размер партии равен 1 с использованием tf.cond
. Но не знаю, как для правильной партии.
Обратите внимание, что в моей ситуации batch_idx на самом деле динамический c, так как он либо заполнитель, либо из tf.data
.